Transfer Learning with ShuffleNetV2 for Flower Classification

This article walks through building a PyTorch ShuffleNetV2 model, preparing the Kaggle Flowers dataset, training it with transfer learning on a GPU, visualizing the loss and accuracy curves, and running inference on five test images. The model reaches nearly 90% validation accuracy after 95 epochs.

Introduction

The ShuffleNetV2 architecture, introduced in the paper ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design, provides an efficient convolutional neural network that runs well on ARM devices such as the Qualcomm Snapdragon 810. This guide demonstrates how to apply transfer learning with a pre-trained ShuffleNetV2 model to recognize five flower classes from the Kaggle Flowers Recognition dataset.
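Much of ShuffleNetV2's efficiency comes from its channel shuffle operation, which mixes information across channel groups after grouped convolutions. A minimal sketch of the idea (a reshape-transpose-reshape; this is an illustration, not torchvision's internal implementation verbatim):

```python
import torch

def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
    """Interleave channels across groups: (N, C, H, W) -> (N, C, H, W)."""
    n, c, h, w = x.shape
    # Split channels into groups, swap the group and per-group axes,
    # then flatten back to a single channel dimension.
    x = x.view(n, groups, c // groups, h, w)
    x = x.transpose(1, 2).contiguous()
    return x.view(n, c, h, w)

x = torch.arange(8.0).view(1, 8, 1, 1)
print(channel_shuffle(x, 2).flatten().tolist())
# -> [0.0, 4.0, 1.0, 5.0, 2.0, 6.0, 3.0, 7.0]
```

With 2 groups, channels from the first half ([0..3]) and second half ([4..7]) end up interleaved, so the next grouped convolution sees information from both groups.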

Dataset

The dataset contains RGB images of five flower types: daisy, dandelion, rose, sunflower, and tulip, totaling 4,317 images. The directory structure is:

├── input
│   ├── flowers
│   │   ├── daisy
│   │   ├── dandelion
│   │   ├── rose
│   │   ├── sunflower
│   │   └── tulip
│   └── test_data
│       ├── daisy.jpg
│       ├── dandelion.jpg
│       ├── rose.jpg
│       ├── sunflower.jpg
│       └── tulip.jpg
├── outputs
│   ├── accuracy.png
│   ├── loss.png
│   └── model.pth
├── datasets.py
├── inference.py
├── model.py
├── train.py
└── utils.py

Images are 3‑channel RGB files.

Utility Functions (utils.py)

Two helper functions save the trained model and plot training curves.

import torch
import matplotlib.pyplot as plt

def save_model(epochs, model, optimizer, criterion):
    """Save the trained model to disk."""
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': criterion,
    }, 'outputs/model.pth')

def save_plots(train_acc, valid_acc, train_loss, valid_loss):
    """Save loss and accuracy plots to disk."""
    plt.figure(figsize=(10, 7))
    plt.plot(train_acc, color='green', linestyle='-', label='train accuracy')
    plt.plot(valid_acc, color='blue', linestyle='-', label='validation accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.savefig('outputs/accuracy.png')

    plt.figure(figsize=(10, 7))
    plt.plot(train_loss, color='orange', linestyle='-', label='train loss')
    plt.plot(valid_loss, color='red', linestyle='-', label='validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('outputs/loss.png')

Dataset Preparation (datasets.py)

The script creates an 80/20 training/validation split with a batch size of 64 (reduce it to 32 or 16 if you run out of GPU memory). ImageNet normalization statistics are applied because the pre-trained model was trained on ImageNet.

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

valid_split = 0.2
batch_size = 64
root_dir = 'input/flowers'

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

dataset = datasets.ImageFolder(root_dir, transform=transform)

dataset_size = len(dataset)
valid_size = int(valid_split * dataset_size)
train_size = dataset_size - valid_size
train_data, valid_data = torch.utils.data.random_split(dataset, [train_size, valid_size])

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=4)
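One caveat: random_split draws a fresh random split on every run, so successive training runs see different validation sets. Passing a seeded generator (part of the torch.utils.data API) makes the split reproducible; sketched here on a small stand-in TensorDataset rather than the flower ImageFolder:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the flower ImageFolder: 10 fake images with 5 fake labels.
dataset = TensorDataset(torch.randn(10, 3, 8, 8), torch.randint(0, 5, (10,)))

valid_size = int(0.2 * len(dataset))
train_size = len(dataset) - valid_size

# A seeded generator makes the split identical across runs.
g = torch.Generator().manual_seed(42)
train_data, valid_data = random_split(dataset, [train_size, valid_size], generator=g)
print(len(train_data), len(valid_data))  # -> 8 2
```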

Model Definition (model.py)

ShuffleNetV2 is loaded from torchvision.models. The final fully‑connected layer is replaced with a linear layer of 5 outputs (one per flower class). The function accepts pretrained and fine_tune flags; in this tutorial pretrained=True and fine_tune=False, so hidden layers are frozen.

import torchvision.models as models
import torch.nn as nn

def build_model(pretrained=True, fine_tune=True):
    if pretrained:
        print('[INFO]: Loading pre-trained weights')
    else:
        print('[INFO]: Not loading pre-trained weights')
    model = models.shufflenet_v2_x1_0(pretrained=pretrained)
    if fine_tune:
        print('[INFO]: Fine-tuning all layers...')
        for p in model.parameters():
            p.requires_grad = True
    else:
        print('[INFO]: Freezing hidden layers...')
        for p in model.parameters():
            p.requires_grad = False
    # The new head is created after the freeze loop, so its parameters have
    # requires_grad=True and are trained even when the backbone is frozen.
    model.fc = nn.Linear(1024, 5)  # five flower classes
    return model
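With fine_tune=False, only the replacement head receives gradients. The same pattern can be verified on a small stand-in network (hypothetical, not ShuffleNetV2 itself, to avoid downloading pre-trained weights):

```python
import torch.nn as nn

class Toy(nn.Module):
    """Stand-in for ShuffleNetV2: a pretend backbone plus a pretend head."""
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(32, 1024)   # pretend backbone
        self.fc = nn.Linear(1024, 1000)       # pretend ImageNet head

model = Toy()
for p in model.parameters():   # freeze everything, as fine_tune=False does
    p.requires_grad = False
model.fc = nn.Linear(1024, 5)  # new head: requires_grad=True by default

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable)  # -> 5125 (1024 * 5 weights + 5 biases)
```

The same count (5,125 trainable parameters) is what train.py reports for the real model when the backbone is frozen and only the 1024-to-5 head is learned.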

Training Script (train.py)

The script parses --epochs, builds the model, defines the Adam optimizer (learning rate 0.001), and uses cross‑entropy loss. It prints the total and trainable parameter counts, then runs a training loop that calls train() and validate() each epoch, storing loss and accuracy for later plotting.

import torch, argparse
from torch import nn, optim
from model import build_model
from utils import save_model, save_plots
from datasets import train_loader, valid_loader
from tqdm.auto import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('-e', '--epochs', type=int, default=20,
                    help='number of epochs to train our network for')
args = vars(parser.parse_args())

lr = 0.001
epochs = args['epochs']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Computation device: {device}')

model = build_model(pretrained=True, fine_tune=False).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_params} total parameters.")
print(f"{trainable_params} training parameters.")

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

train_loss, valid_loss = [], []
train_acc, valid_acc = [], []

for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_epoch_acc = train(model, train_loader, optimizer, criterion)
    valid_epoch_loss, valid_epoch_acc = validate(model, valid_loader, criterion)
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    train_acc.append(train_epoch_acc)
    valid_acc.append(valid_epoch_acc)
    print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
    print('-' * 50)

save_model(epochs, model, optimizer, criterion)
save_plots(train_acc, valid_acc, train_loss, valid_loss)
print('TRAINING COMPLETE')

The train() function performs a forward pass, computes loss, updates accuracy counters, back‑propagates, and steps the optimizer. The validate() function runs in torch.no_grad() mode and returns loss and accuracy without updating weights.
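The article does not list train() and validate() themselves. A sketch consistent with the calls in the loop above (the device is recovered from the model's parameters, since the loop does not pass it explicitly; accuracy is returned in percent):

```python
import torch

def train(model, loader, optimizer, criterion):
    """One training epoch; returns (mean loss, accuracy in percent)."""
    model.train()
    device = next(model.parameters()).device
    running_loss, correct, total = 0.0, 0, 0
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        correct += (outputs.argmax(1) == targets).sum().item()
        total += targets.size(0)
    return running_loss / len(loader), 100.0 * correct / total

def validate(model, loader, criterion):
    """One validation pass; no gradient tracking, no weight updates."""
    model.eval()
    device = next(model.parameters()).device
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)
            outputs = model(images)
            running_loss += criterion(outputs, targets).item()
            correct += (outputs.argmax(1) == targets).sum().item()
            total += targets.size(0)
    return running_loss / len(loader), 100.0 * correct / total
```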

Training Results

After 95 epochs the model reaches approximately 90% validation accuracy with a validation loss around 0.31. Accuracy and loss curves are saved as outputs/accuracy.png and outputs/loss.png (see images below).

Training accuracy curve after 95 epochs
Training and validation loss curves after 95 epochs

Inference Script (inference.py)

The script loads the custom‑trained weights (CPU inference), defines the same ImageNet normalization transforms, reads an input image, preprocesses it, runs the model, and prints the ground‑truth and predicted class. It also draws the labels on the image and saves the result.

import torch, cv2, argparse
from torchvision import transforms
from model import build_model

parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', default='input/test_data/daisy.jpg',
                    help='path to the input image')
args = vars(parser.parse_args())

device = 'cpu'
labels = ['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']

model = build_model(pretrained=False, fine_tune=False).to(device)
print('[INFO]: Loading custom-trained weights...')
checkpoint = torch.load('outputs/model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

image = cv2.imread(args['input'])
gt_class = args['input'].split('/')[-1].split('.')[0]
orig_image = image.copy()
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = transform(image)
image = torch.unsqueeze(image, 0).to(device)

with torch.no_grad():
    outputs = model(image)
    _, pred_idx = torch.max(outputs, 1)
    pred_class = labels[int(pred_idx)]

cv2.putText(orig_image, f"GT: {gt_class}", (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,255,0), 2, cv2.LINE_AA)
cv2.putText(orig_image, f"Pred: {pred_class}", (10, 55), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,0,255), 2, cv2.LINE_AA)
print(f"GT: {gt_class}, pred: {pred_class}")
cv2.imshow('Result', orig_image)
cv2.waitKey(0)
cv2.imwrite(f"outputs/{gt_class}.png", orig_image)
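The script above reports only the argmax class. If you also want a confidence score, a softmax over the logits gives per-class probabilities; a sketch with hypothetical logits standing in for the model's output:

```python
import torch

# Hypothetical logits, as might be returned by the model for one image.
outputs = torch.tensor([[2.0, 0.1, -1.0, 0.5, -0.3]])
labels = ['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']

probs = torch.softmax(outputs, dim=1)   # normalize logits to probabilities
conf, idx = torch.max(probs, 1)
print(f"{labels[int(idx)]}: {conf.item():.2%}")
```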

Inference Results

Running the script on each of the five test images yields correct predictions for all classes, as shown in the screenshots below.

Correct prediction on daisy image
Correct prediction on dandelion image
Correct prediction on rose image
Correct prediction on sunflower image
Correct prediction on tulip image

Conclusion

The guide demonstrates a complete workflow for transfer learning with ShuffleNetV2: dataset preparation, model adaptation, training on GPU, performance visualization, and CPU inference on individual images. The resulting model achieves high accuracy on the five‑class flower dataset.

Republication Notice

This article has been distilled and summarized from source material, then republished for learning and reference. If you believe it infringes your rights, please contact admin@besthub.dev and we will review it promptly.

Tags: CNN, PyTorch, image recognition, transfer learning, ShuffleNetV2, flower classification
Written by Code DAO

We deliver AI algorithm tutorials and the latest news, curated by a team of researchers from Peking University, Shanghai Jiao Tong University, Central South University, and leading AI companies such as Huawei, Kuaishou, and SenseTime. Join us in the AI alchemy—making life better!
