Fine-Tuning LLMs on TPU with Tunix: A Step‑by‑Step QLoRA Guide

This article introduces Google’s Tunix library for JAX‑based LLM post‑training, explains its core features such as supervised fine‑tuning, reinforcement learning and knowledge distillation, and provides detailed installation steps and a complete TPU‑accelerated QLoRA fine‑tuning workflow on the Gemma 2B model, including code snippets and inference testing.

Data Party THU

Introduction

The JAX ecosystem has added strong support for large‑language‑model (LLM) post‑training, offering parallelism, TPU acceleration, and composable APIs. Google has released Tunix, a JAX‑native library that provides supervised fine‑tuning, reinforcement‑learning (RL) fine‑tuning, and knowledge‑distillation pipelines for LLMs.

Core capabilities

Supervised fine‑tuning: Full‑parameter fine‑tuning and parameter‑efficient methods such as LoRA and QLoRA.

Reinforcement learning: Implements PPO, GRPO, token‑level GSPO, and DPO for preference alignment.

Knowledge distillation: Logit‑based distribution matching, attention‑based transfer, and cross‑architecture feature pooling.
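To make "logit‑based distribution matching" concrete, here is a minimal sketch of a temperature‑softened KL distillation loss in plain JAX. It is a generic formulation, not Tunix's implementation; the function name distillation_loss, the temperature of 2.0, and the toy shapes are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions."""
    # Soften both distributions with the same temperature.
    teacher_log_probs = jax.nn.log_softmax(teacher_logits / temperature, axis=-1)
    student_log_probs = jax.nn.log_softmax(student_logits / temperature, axis=-1)
    teacher_probs = jnp.exp(teacher_log_probs)
    # KL divergence per position, summed over the vocabulary axis.
    kl = jnp.sum(teacher_probs * (teacher_log_probs - student_log_probs), axis=-1)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures.
    return (temperature ** 2) * jnp.mean(kl)


# Toy logits: 2 positions over a 5-token vocabulary (illustrative only).
key_t, key_s = jax.random.split(jax.random.PRNGKey(0))
teacher_logits = jax.random.normal(key_t, (2, 5))
student_logits = jax.random.normal(key_s, (2, 5))

loss_mismatch = distillation_loss(student_logits, teacher_logits)
loss_match = distillation_loss(teacher_logits, teacher_logits)
print("student != teacher:", float(loss_mismatch))
print("student == teacher:", float(loss_match))
```

The loss is zero when the student reproduces the teacher's distribution and positive otherwise, which is the signal a distillation trainer minimizes.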

The library is modular, allowing components to be combined freely, and supports distributed training strategies including data parallel (DP), fully‑sharded data parallel (FSDP), and tensor parallelism with TPU‑specific optimizations.
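The sharding strategies named above can be sketched with JAX's own sharding API, independent of Tunix. The mesh axis names "fsdp" and "tp" and the array shapes below are assumptions chosen for illustration; with a single device the mesh collapses to 1x1, so the sketch also runs on a CPU‑only machine.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange every visible device along the first mesh axis; the second axis
# has size 1 here, so no tensor parallelism is actually applied.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("fsdp", "tp"))
n = devices.shape[0]

# Data parallel / FSDP: split the batch dimension across the "fsdp" axis.
batch_sharding = NamedSharding(mesh, P("fsdp", None))
# FSDP plus tensor parallel: shard weight rows over "fsdp", columns over "tp".
weight_sharding = NamedSharding(mesh, P("fsdp", "tp"))

# Shapes are multiples of the device count so they divide evenly.
x = jax.device_put(jnp.ones((8 * n, 16 * n)), batch_sharding)
w = jax.device_put(jnp.ones((16 * n, 4)), weight_sharding)

# jit lets the compiler insert whatever collectives the shardings require.
y = jax.jit(lambda a, b: a @ b)(x, w)
print("mesh shape:", dict(mesh.shape))
print("output shape:", y.shape)
```

Changing the reshape to, for example, (-1, 2) on a multi‑device host would move half of the parallelism onto the "tp" axis without touching the computation itself.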

Installation

Three installation methods are provided:

From PyPI (recommended):

pip install "tunix[prod]"

Directly from the GitHub main branch:

pip install git+https://github.com/google/tunix

Development mode from source:

git clone https://github.com/google/tunix.git
cd tunix
pip install -e ".[dev]"

TPU‑accelerated QLoRA fine‑tuning example (Gemma 2B)

Environment setup

pip install -q kagglehub safetensors tensorflow tensorflow_datasets tensorboardX transformers grain datasets
pip install -q git+https://github.com/google/tunix
pip install -q git+https://github.com/google/qwix
# Upgrade Flax to the latest version
pip uninstall -q -y flax
pip install -q git+https://github.com/google/flax.git
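
Before moving on, it can help to confirm that JAX actually sees the accelerator. The check below is a small sketch that uses only core JAX calls; the warning text is illustrative.

```python
import jax

backend = jax.default_backend()  # name of the platform JAX will use by default
devices = jax.devices()

print(f"backend={backend}, device_count={len(devices)}")
for d in devices:
    # Each device reports its index, platform, and hardware kind.
    print(d.id, d.platform, d.device_kind)

if backend != "tpu":
    print("Warning: no TPU backend detected; the steps below will run on", backend)
```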

Data preparation

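from tunix.examples.data import translation_dataset

# `tokenizer` is created in the "Model and tokenizer initialization" step below;
# run that step first so the name is defined when this call executes.
train_ds, validation_ds = translation_dataset.create_datasets(
    dataset_name="mtnt/en-fr",
    global_batch_size=16,
    max_target_length=256,
    num_train_epochs=3,
    tokenizer=tokenizer,
)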

The example uses the MTNT English‑French parallel corpus with a global batch size of 16 and a maximum target length of 256 tokens.
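
To show what a fixed batch size and maximum length imply for the shapes fed to the model, here is a hedged sketch of padding, truncation, and batching in NumPy. It is not how translation_dataset is implemented; the helper pad_or_truncate, the padding id of 0, and the random token ids are hypothetical stand‑ins.

```python
import numpy as np

MAX_TARGET_LENGTH = 256
GLOBAL_BATCH_SIZE = 16
PAD_ID = 0  # hypothetical padding id; the real value depends on the tokenizer


def pad_or_truncate(token_ids, max_len=MAX_TARGET_LENGTH, pad_id=PAD_ID):
    """Returns (tokens, mask), both of fixed length max_len."""
    ids = np.asarray(token_ids[:max_len], dtype=np.int32)  # drop overflow tokens
    tokens = np.full((max_len,), pad_id, dtype=np.int32)
    tokens[: len(ids)] = ids
    mask = np.zeros((max_len,), dtype=bool)
    mask[: len(ids)] = True  # True marks real tokens, False marks padding
    return tokens, mask


# Fake tokenized sentences of varying length stand in for real MTNT examples.
rng = np.random.default_rng(0)
lengths = rng.integers(5, 400, size=GLOBAL_BATCH_SIZE)
examples = [rng.integers(1, 1000, size=n).tolist() for n in lengths]

pairs = [pad_or_truncate(e) for e in examples]
batch_tokens = np.stack([p[0] for p in pairs])
batch_mask = np.stack([p[1] for p in pairs])

tokens_per_step = batch_tokens.size  # 16 * 256 = 4096 token slots per step
print(batch_tokens.shape, batch_mask.shape, tokens_per_step)
```

The mask lets a loss function ignore padded positions, so short sentences do not contribute spurious gradient.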

Model and tokenizer initialization

from flax import nnx
from tunix.models.gemma import model as gemma_lib, params as params_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib

# kaggle_ckpt_path is not defined in this snippet: it must point to the local
# directory that holds the downloaded Gemma 2B checkpoint and tokenizer.model.
base_model = gemma_lib.Transformer.from_params(
    params_lib.load_and_format_params(kaggle_ckpt_path, "2b"),
    version="2b",
)

tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=f"{kaggle_ckpt_path}/tokenizer.model")

Attach QLoRA adapter

import qwix
lora_provider = qwix.LoraProvider(
    module_path=".*(q_einsum|kv_einsum|proj)",
    rank=16,
    alpha=2.0,
    weight_qtype="nf4",  # enable QLoRA quantization
)

lora_model = qwix.apply_lora_to_model(base_model, lora_provider)

The rank is set to 16, alpha to 2.0, and the weight quantization format to nf4.
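
For intuition about what rank and alpha control, the sketch below writes out a LoRA‑augmented linear layer in plain JAX. It assumes the standard LoRA convention in which the low‑rank update is scaled by alpha / rank; whether Qwix applies exactly this scaling is not confirmed by this article. The 2048x2048 layer size and the helper lora_linear are illustrative, and quantization of the frozen weight is omitted.

```python
import jax
import jax.numpy as jnp

RANK, ALPHA = 16, 2.0
SCALE = ALPHA / RANK  # 0.125 under the assumed alpha / rank convention

d_in, d_out = 2048, 2048  # illustrative layer size
k_w, k_a, k_x = jax.random.split(jax.random.PRNGKey(0), 3)

w_frozen = jax.random.normal(k_w, (d_in, d_out))      # base weight, never updated
lora_a = jax.random.normal(k_a, (d_in, RANK)) * 0.01  # trainable down-projection
lora_b = jnp.zeros((RANK, d_out))                     # trainable up-projection, zero-init


def lora_linear(x, w, a, b, scale=SCALE):
    # Frozen path plus a scaled low-rank correction.
    return x @ w + scale * ((x @ a) @ b)


x = jax.random.normal(k_x, (4, d_in))
y = lora_linear(x, w_frozen, lora_a, lora_b)

full_params = d_in * d_out               # 4,194,304 weights in the frozen matrix
adapter_params = RANK * (d_in + d_out)   # 65,536 trainable adapter weights
print(f"trainable fraction: {adapter_params / full_params:.4%}")
```

Because lora_b starts at zero, the adapted layer initially reproduces the base layer exactly, and training only has to learn the correction.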

Training

from tunix.sft import peft_trainer, utils
import optax

trainer = peft_trainer.PeftTrainer(
    lora_model,
    optimizer=optax.adamw(1e-3),
    config=peft_trainer.TrainingConfig(max_steps=100),
)
trainer.train(train_ds, validation_ds)

AdamW with a learning rate of 1e‑3 is used; the run performs 100 training steps as a quick sanity check.
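
To illustrate what one AdamW update does at this learning rate, the following sketch hand‑rolls the optimizer in plain JAX on a toy regression problem. It is a stand‑in for intuition, not the PeftTrainer loop. The hyperparameters other than the learning rate (b1, b2, eps, weight decay) are assumed to match Optax's defaults, and the toy data and the helper adamw_step are hypothetical.

```python
import jax
import jax.numpy as jnp

LR = 1e-3
B1, B2, EPS, WEIGHT_DECAY = 0.9, 0.999, 1e-8, 1e-4  # assumed defaults


def loss_fn(params, x, y):
    return jnp.mean((x @ params["w"] - y) ** 2)


@jax.jit
def adamw_step(params, m, v, step, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Exponential moving averages of the gradient and its square.
    m = jax.tree_util.tree_map(lambda m_, g: B1 * m_ + (1 - B1) * g, m, grads)
    v = jax.tree_util.tree_map(lambda v_, g: B2 * v_ + (1 - B2) * g * g, v, grads)

    def update(p, m_, v_):
        m_hat = m_ / (1 - B1 ** step)  # bias correction
        v_hat = v_ / (1 - B2 ** step)
        # Decoupled weight decay is added to the Adam direction before scaling.
        return p - LR * (m_hat / (jnp.sqrt(v_hat) + EPS) + WEIGHT_DECAY * p)

    return jax.tree_util.tree_map(update, params, m, v), m, v, loss


# Toy regression data standing in for a real training batch.
x = jax.random.normal(jax.random.PRNGKey(0), (64, 8))
true_w = jnp.arange(1.0, 9.0).reshape(8, 1)
y = x @ true_w

params = {"w": jnp.zeros((8, 1))}
m = jax.tree_util.tree_map(jnp.zeros_like, params)
v = jax.tree_util.tree_map(jnp.zeros_like, params)

losses = []
for step in range(1, 101):  # 100 steps, mirroring max_steps=100
    params, m, v, loss = adamw_step(params, m, v, jnp.float32(step), x, y)
    losses.append(float(loss))

print(f"first loss={losses[0]:.3f}, last loss={losses[-1]:.3f}")
```

With Adam‑style normalization each parameter moves by roughly the learning rate per step, so 100 steps at 1e‑3 shift weights by about 0.1 at most. That is enough to confirm the pipeline works, which is consistent with treating the run as a sanity check.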

Inference test

from tunix.generate import sampler as sampler_lib
sampler = sampler_lib.Sampler(
    transformer=lora_model,
    tokenizer=tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=256,
        num_layers=base_model.num_layers,
        num_kv_heads=base_model.num_kv_heads,
        head_dim=base_model.head_dim,
    ),
)

input_batch = [
    "Translate this into French:\nHello, my name is Morgane.",
    "Translate this into French:\nThis dish is delicious!",
    "Translate this into French:\nI am a student.",
    "Translate this into French:\nHow's the weather today?",
]

out_data = sampler(
    input_strings=input_batch,
    max_generation_steps=20,
)
for input_string, out_string in zip(input_batch, out_data.text):
    print("----------------------")
    print(f"Prompt:\n{input_string}")
    print(f"Output:\n{out_string}")

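The sampler above uses lora_model, which keeps the adapter separate from the base weights. If you have a merged model (called qlora_model here; this walkthrough does not show how to build it), pass it as transformer instead to reduce inference latency.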
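
Why merging helps can be seen in a few lines of plain JAX: folding the low‑rank update into the weight matrix gives the same outputs with a single matmul per layer. The sketch assumes the standard alpha / rank scaling and an unquantized base weight. With an NF4‑quantized base, the weight would first have to be dequantized, which is not shown. All shapes and names are illustrative.

```python
import jax
import jax.numpy as jnp

rank, alpha = 16, 2.0
scale = alpha / rank  # assumed standard LoRA scaling
d_in, d_out = 128, 64

k_w, k_a, k_b, k_x = jax.random.split(jax.random.PRNGKey(0), 4)
w = jax.random.normal(k_w, (d_in, d_out))   # base weight
a = jax.random.normal(k_a, (d_in, rank))    # trained down-projection (random here)
b = jax.random.normal(k_b, (rank, d_out))   # trained up-projection (random here)
x = jax.random.normal(k_x, (4, d_in))

# Unmerged: base matmul plus two extra low-rank matmuls on every call.
y_adapter = x @ w + scale * ((x @ a) @ b)

# Merged: fold the update into the weight once, then use a single matmul.
w_merged = w + scale * (a @ b)
y_merged = x @ w_merged

max_abs_diff = float(jnp.max(jnp.abs(y_adapter - y_merged)))
print("max abs difference:", max_abs_diff)
```

The two outputs agree up to floating‑point rounding, so merging changes latency rather than model behavior.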

Conclusion

After 100 steps the model can generate translations, though quality is limited; training for more steps should improve accuracy. Because QLoRA trains only the low‑rank adapters on top of quantized base weights, memory use stays modest as training is extended. Tunix’s TPU‑first design, modular API, LoRA/QLoRA support, and range of distributed‑training strategies make it a useful tool for researchers adapting LLMs. Future releases are expected to broaden model and algorithm support.

Repository: https://github.com/google/tunix

Written by

Data Party THU

Official platform of Tsinghua Big Data Research Center, sharing the team's latest research, teaching updates, and big data news.
