Transfer Learning with ShuffleNetV2 for Flower Classification
This article walks through building a ShuffleNetV2 model in PyTorch, preparing the Kaggle Flowers Recognition dataset, training with transfer learning on a GPU, visualizing loss and accuracy, and running inference on five test images. The model reaches nearly 90% validation accuracy after 95 epochs.
Introduction
The ShuffleNetV2 architecture, introduced in the paper ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design, is an efficient convolutional neural network that runs well on ARM devices such as the Qualcomm Snapdragon 810. This guide demonstrates how to apply transfer learning with a pre-trained ShuffleNetV2 model to recognize five flower classes from the Kaggle Flowers Recognition dataset.
Dataset
The dataset contains RGB images of five flower types: daisy, dandelion, rose, sunflower, and tulip, totaling 4,317 images. The directory structure is:
├── input
│ ├── flowers
│ │ ├── daisy
│ │ ├── dandelion
│ │ ├── rose
│ │ ├── sunflower
│ │ └── tulip
│ └── test_data
│ ├── daisy.jpg
│ ├── dandelion.jpg
│ ├── rose.jpg
│ ├── sunflower.jpg
│ └── tulip.jpg
├── outputs
│ ├── accuracy.png
│ ├── loss.png
│ └── model.pth
├── datasets.py
├── inference.py
├── model.py
├── train.py
└── utils.py
Images are 3‑channel RGB files.
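Before training, a short check of the layout and per-class counts can help catch path mistakes early. This convenience snippet is not one of the five project files, and it assumes the images are .jpg files:
from pathlib import Path

root = Path('input/flowers')
for class_dir in sorted(root.iterdir()):
    if class_dir.is_dir():
        # Count the JPEG images in each class folder.
        print(class_dir.name, len(list(class_dir.glob('*.jpg'))))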
Utility Functions (utils.py)
Two helper functions save the trained model and plot training curves.
import torch
import matplotlib.pyplot as plt

def save_model(epochs, model, optimizer, criterion):
    """Save the trained model to disk."""
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': criterion,
    }, 'outputs/model.pth')

def save_plots(train_acc, valid_acc, train_loss, valid_loss):
    """Save loss and accuracy plots to disk."""
    # Accuracy curves.
    plt.figure(figsize=(10, 7))
    plt.plot(train_acc, color='green', linestyle='-', label='train accuracy')
    plt.plot(valid_acc, color='blue', linestyle='-', label='validation accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.savefig('outputs/accuracy.png')
    # Loss curves.
    plt.figure(figsize=(10, 7))
    plt.plot(train_loss, color='orange', linestyle='-', label='train loss')
    plt.plot(valid_loss, color='red', linestyle='-', label='validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('outputs/loss.png')
Dataset Preparation (datasets.py)
The script creates an 80/20 training/validation split with a batch size of 64 (reduce this to 32 or 16 if you run out of GPU memory). ImageNet normalization statistics are applied because the pre-trained model was trained on ImageNet.
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

valid_split = 0.2
batch_size = 64
root_dir = 'input/flowers'

# ImageNet mean and standard deviation, matching the pre-trained weights.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

dataset = datasets.ImageFolder(root_dir, transform=transform)
dataset_size = len(dataset)
valid_size = int(valid_split * dataset_size)
train_size = dataset_size - valid_size
train_data, valid_data = torch.utils.data.random_split(dataset, [train_size, valid_size])

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=4)
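A quick sanity check, not part of the original scripts, confirms the loaders yield correctly shaped batches. Given batch_size = 64 and the 224x224 resize above, the image tensor should be [64, 3, 224, 224]:
from datasets import train_loader

# Grab one batch and inspect its shape.
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)
print(labels[:8])  # integer class indices, assigned alphabetically by ImageFolder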
Model Definition (model.py)
ShuffleNetV2 is loaded from torchvision.models. The final fully-connected layer is replaced with a linear layer of 5 outputs (one per flower class). The function accepts pretrained and fine_tune flags; in this tutorial pretrained=True and fine_tune=False, so all pre-trained layers are frozen. Because the replacement fc layer is created after the freezing loop, its parameters keep the default requires_grad=True, so only the new classification head is trained.
import torchvision.models as models
import torch.nn as nn

def build_model(pretrained=True, fine_tune=True):
    if pretrained:
        print('[INFO]: Loading pre-trained weights')
    else:
        print('[INFO]: Not loading pre-trained weights')
    model = models.shufflenet_v2_x1_0(pretrained=pretrained)
    if fine_tune:
        print('[INFO]: Fine-tuning all layers...')
        for p in model.parameters():
            p.requires_grad = True
    else:
        print('[INFO]: Freezing hidden layers...')
        for p in model.parameters():
            p.requires_grad = False
    # Replace the 1000-class ImageNet head with a 5-class head.
    model.fc = nn.Linear(1024, 5)  # five flower classes
    return model
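A quick shape check (a convenience snippet, not part of the project files) confirms the new head produces five logits:
import torch
from model import build_model

model = build_model(pretrained=True, fine_tune=False)
model.eval()  # use running batch-norm statistics for a single dummy image
x = torch.randn(1, 3, 224, 224)  # one dummy RGB image at the training resolution
print(model(x).shape)  # expected: torch.Size([1, 5])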
Training Script (train.py)
The script parses --epochs, builds the model, defines an Adam optimizer with a learning rate of 0.001, and uses cross-entropy loss. It prints the total and trainable parameter counts, then runs a training loop that calls train() and validate() each epoch (both sketched after the listing), storing loss and accuracy for later plotting.
import torch, argparse
from torch import nn, optim
from model import build_model
from utils import save_model, save_plots
from datasets import train_loader, valid_loader
from tqdm.auto import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('-e', '--epochs', type=int, default=20,
                    help='number of epochs to train our network for')
args = vars(parser.parse_args())

lr = 0.001
epochs = args['epochs']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Computation device: {device}')

model = build_model(pretrained=True, fine_tune=False).to(device)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_params} total parameters.")
print(f"{trainable_params} training parameters.")

# Frozen parameters receive no gradients, so only the new head is updated.
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

train_loss, valid_loss = [], []
train_acc, valid_acc = [], []
for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_epoch_acc = train(model, train_loader, optimizer, criterion)
    valid_epoch_loss, valid_epoch_acc = validate(model, valid_loader, criterion)
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    train_acc.append(train_epoch_acc)
    valid_acc.append(valid_epoch_acc)
    print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
    print('-' * 50)

save_model(epochs, model, optimizer, criterion)
save_plots(train_acc, valid_acc, train_loss, valid_loss)
print('TRAINING COMPLETE')
The train() function performs a forward pass, computes the loss, updates accuracy counters, back-propagates, and steps the optimizer. The validate() function does the same under torch.no_grad() and returns loss and accuracy without updating weights.
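The article describes but does not list these two functions. The following is a minimal sketch consistent with that description and the loop above, averaging loss over batches and reporting accuracy as a percentage; it would sit in train.py above the training loop:
def train(model, loader, optimizer, criterion):
    model.train()
    running_loss, running_correct, counter = 0.0, 0, 0
    for images, labels in tqdm(loader):
        counter += 1
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        running_correct += (preds == labels).sum().item()
        loss.backward()
        optimizer.step()
    epoch_loss = running_loss / counter
    epoch_acc = 100. * running_correct / len(loader.dataset)
    return epoch_loss, epoch_acc

def validate(model, loader, criterion):
    model.eval()
    running_loss, running_correct, counter = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in tqdm(loader):
            counter += 1
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            running_loss += criterion(outputs, labels).item()
            _, preds = torch.max(outputs, 1)
            running_correct += (preds == labels).sum().item()
    epoch_loss = running_loss / counter
    epoch_acc = 100. * running_correct / len(loader.dataset)
    return epoch_loss, epoch_acc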
Training Results
After 95 epochs (python train.py --epochs 95), the model reaches approximately 90% validation accuracy with a validation loss of around 0.31. The accuracy and loss curves are saved as outputs/accuracy.png and outputs/loss.png.
Inference Script (inference.py)
The script loads the custom‑trained weights (CPU inference), defines the same ImageNet normalization transforms, reads an input image, preprocesses it, runs the model, and prints the ground‑truth and predicted class. It also draws the labels on the image and saves the result.
import torch, cv2, argparse
from torchvision import transforms
from model import build_model

parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', default='input/test_data/daisy.jpg',
                    help='path to the input image')
args = vars(parser.parse_args())

device = 'cpu'
# Class order matches ImageFolder's alphabetical folder ordering.
labels = ['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']

model = build_model(pretrained=False, fine_tune=False).to(device)
print('[INFO]: Loading custom-trained weights...')
checkpoint = torch.load('outputs/model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Same ImageNet normalization used during training.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

image = cv2.imread(args['input'])
gt_class = args['input'].split('/')[-1].split('.')[0]  # file name doubles as the ground-truth label
orig_image = image.copy()
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model expects RGB
image = transform(image)
image = torch.unsqueeze(image, 0).to(device)  # add the batch dimension

with torch.no_grad():
    outputs = model(image)
_, pred_idx = torch.max(outputs, 1)
pred_class = labels[int(pred_idx)]

cv2.putText(orig_image, f"GT: {gt_class}", (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
cv2.putText(orig_image, f"Pred: {pred_class}", (10, 55), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
print(f"GT: {gt_class}, pred: {pred_class}")
cv2.imshow('Result', orig_image)
cv2.waitKey(0)
cv2.imwrite(f"outputs/{gt_class}.png", orig_image)
Inference Results
Running the script on each of the five test images, for example python inference.py --input input/test_data/rose.jpg, yields the correct prediction for every class.
Conclusion
The guide demonstrates a complete workflow for transfer learning with ShuffleNetV2: dataset preparation, model adaptation, training on GPU, performance visualization, and CPU inference on individual images. The resulting model achieves high accuracy on the five‑class flower dataset.