Understanding and Implementing LoRA (Low‑Rank Adaptation) for Model Training with PyTorch
This article explains the principle of LoRA (Low‑Rank Adaptation) for large language models, demonstrates how to decompose weight updates into low‑rank matrices, and provides a complete PyTorch implementation that adapts a pre‑trained VGG‑19 network to a tiny goldfish dataset.
1. Introduction
In the AIGC field, the term “LoRA” (Low‑Rank Adaptation of Large Language Models) frequently appears as an efficient training method for large models. Because these models have enormous parameter counts and full fine‑tuning can take weeks or months of compute, LoRA offers a resource‑saving alternative.
2. Model Training Basics
Standard gradient‑descent training consists of four steps: forward propagation to compute loss, backward propagation to compute gradients, parameter update using the gradients, and repeating until loss is sufficiently low. For a linear model with parameters W, input x, and output y, the pseudo‑code is:
# 4. Repeat steps 1-3
for i in range(10000):
    # 1. Forward pass
    L = MSE(W @ x, y)
    # 2. Backward pass
    dW = gradient(L, W)
    # 3. Update parameters
    W -= lr * dW
3. Introducing LoRA
Instead of updating the full weight matrix W (size m×n), LoRA assumes the weight update ΔW is low‑rank and factorises it as ΔW = A·B, where A is m×r and B is r×n with r ≪ m, n. This cuts the number of trainable parameters from m·n to r·(m+n), a dramatic reduction when r is small, while preserving most of the expressive power of the update.
When LoRA is applied, the prediction becomes y = W·x + (A·B)·x, which incurs only a slight overhead during inference.
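To make the savings concrete, here is a small self‑contained sketch (illustrative sizes only, not taken from the article's model) comparing the parameter count of a full update matrix with its rank‑r factorisation:

```python
import torch

m, n, r = 1000, 1000, 10        # illustrative sizes; r << m, n

# Full fine-tuning would train every entry of the m x n update matrix
full_params = m * n              # 1,000,000 trainable values

# LoRA trains only A (m x r) and B (r x n)
A = torch.randn(m, r)
B = torch.zeros(r, n)
lora_params = A.numel() + B.numel()   # 20,000 trainable values, a 50x reduction

# A @ B reconstructs a full-shape update, so y = W @ x + (A @ B) @ x still works
delta = A @ B
print(full_params, lora_params, tuple(delta.shape))
```

Because B starts at zero, A·B is initially the zero matrix, so the adapted model begins by reproducing the base model's predictions exactly.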
4. Practical Implementation
4.1 Dataset Preparation
The VGG‑19 model pre‑trained on ImageNet is used as the base model. A tiny dataset containing five images of goldfish is placed under data/goldfish. A custom LoraDataset class loads the images and creates one‑hot labels.
import os

import torch
from PIL import Image
from torch import nn, optim
from torch.utils import data
from torchvision import models, transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
class LoraDataset(data.Dataset):
    def __init__(self, data_path="datas"):
        categories = models.VGG19_Weights.IMAGENET1K_V1.value.meta["categories"]
        self.files = []
        self.labels = []
        for dir in os.listdir(data_path):
            dirname = os.path.join(data_path, dir)
            for file in os.listdir(dirname):
                self.files.append(os.path.join(dirname, file))
                self.labels.append(categories.index(dir))

    def __getitem__(self, item):
        image = Image.open(self.files[item]).convert("RGB")
        # One-hot label; float32 matches the model's output dtype
        label = torch.zeros(1000, dtype=torch.float32)
        label[self.labels[item]] = 1.
        return transform(image), label

    def __len__(self):
        return len(self.files)
4.2 Defining the LoRA Layer
The LoRA adaptation is encapsulated in a PyTorch nn.Module that holds the two low‑rank matrices A and B. Note that B is initialised to zeros, so the adapter contributes nothing at the start of training and the model begins from the base VGG‑19's predictions.
class Lora(nn.Module):
    def __init__(self, m, n, rank=10):
        super().__init__()
        self.m = m
        self.A = nn.Parameter(torch.randn(m, rank))
        self.B = nn.Parameter(torch.zeros(rank, n))

    def forward(self, inputs):
        inputs = inputs.view(-1, self.m)
        return torch.mm(torch.mm(inputs, self.A), self.B)
4.3 Training
The base VGG‑19 weights are frozen, and only the LoRA layer is trained on the goldfish data, using the Adam optimizer and cross‑entropy loss.
# Hyperparameters (example values)
lr = 1e-4
batch_size = 5
epochs = 10

# Load base model and LoRA (weights is a keyword-only argument)
vgg19 = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
for params in vgg19.parameters():
    params.requires_grad = False
vgg19.eval()
lora = Lora(224 * 224 * 3, 1000)

# Data loader
lora_loader = data.DataLoader(LoraDataset(), batch_size=batch_size, shuffle=True)

# Optimizer and loss
optimizer = optim.Adam(lora.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

# Training loop
for epoch in range(epochs):
    for image, label in lora_loader:
        pred = vgg19(image) + lora(image)
        loss = loss_fn(pred, label)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(f"loss: {loss.item()}")
4.4 Testing
After training, the model is evaluated on the same loader; the predicted category is printed and the LoRA weights are saved.
# Test
categories = models.VGG19_Weights.IMAGENET1K_V1.value.meta["categories"]
with torch.no_grad():
    for image, _ in lora_loader:
        pred = vgg19(image) + lora(image)
        # argmax over the class dimension; iterate in case the batch holds several images
        for idx in torch.argmax(pred, dim=1):
            print(categories[idx.item()])
torch.save(lora.state_dict(), 'lora.pth')
The output shows the model correctly predicts “goldfish” for all test images, demonstrating that the tiny LoRA module (≈5 MB) can adapt the large VGG‑19 backbone effectively.
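Because only the adapter is saved, reloading it later is cheap. A self‑contained sketch of the save/load round trip (the Lora class is repeated here so the snippet runs on its own; the random input stands in for a transformed photo):

```python
import torch
from torch import nn

class Lora(nn.Module):
    def __init__(self, m, n, rank=10):
        super().__init__()
        self.m = m
        self.A = nn.Parameter(torch.randn(m, rank))
        self.B = nn.Parameter(torch.zeros(rank, n))

    def forward(self, inputs):
        inputs = inputs.view(-1, self.m)
        return torch.mm(torch.mm(inputs, self.A), self.B)

lora = Lora(224 * 224 * 3, 1000)
torch.save(lora.state_dict(), "lora.pth")

# Later: rebuild an empty adapter and restore the trained weights
restored = Lora(224 * 224 * 3, 1000)
restored.load_state_dict(torch.load("lora.pth"))

image = torch.randn(1, 3, 224, 224)   # stand-in for a transformed photo
with torch.no_grad():
    out = restored(image)
print(tuple(out.shape))               # (1, 1000)
```

At inference time the restored adapter is simply added to the frozen backbone's output, exactly as in the test loop above.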
5. Conclusion
LoRA provides an efficient way to fine‑tune large models by training only low‑rank adapters. This tutorial applied LoRA to a simple image‑classification setup to illustrate the concept, though further experiments with larger datasets and models are needed to evaluate accuracy and efficiency trade‑offs.