Artificial Intelligence · 13 min read

Practical Implementation of Vision Transformer (ViT) for Image Classification in PyTorch

This article walks readers through building, training, and evaluating a Vision Transformer (ViT) model for a five‑class flower classification task, providing detailed code snippets, model architecture explanations, training script adjustments, and experimental results that highlight the importance of pre‑trained weights.


The article begins by introducing the Vision Transformer (ViT) model and its application to a five‑class flower classification problem, referencing previous posts that covered the theoretical background.

It then presents the ViT‑Base model specifications in a table, showing parameters such as patch size (16×16), number of layers (12), hidden size (768), MLP size (3072), number of heads (12), and total parameters (86M).

Next, the author explains the step‑by‑step construction of the ViT model in PyTorch, starting with the PatchEmbed class that converts images into patch embeddings using a convolutional layer, followed by the addition of a learnable class token and positional embeddings.

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split a 2D image into patches and project each patch to an embedding vector."""
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])  # 14x14 for ViT-Base
        self.num_patches = self.grid_size[0] * self.grid_size[1]  # 14 * 14 = 196 patches
        # A convolution with kernel_size == stride == patch_size performs the patch projection.
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input size ({H}x{W}) doesn't match model ({self.img_size[0]}x{self.img_size[1]})."
        # (B, C, H, W) -> (B, embed_dim, 14, 14) -> (B, embed_dim, 196) -> (B, 196, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x
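The class token and positional embeddings mentioned above can be sketched as follows. This is a minimal illustration rather than the article's exact code; the dimensions follow the ViT-Base table (196 patches, embedding size 768):

```python
import torch
import torch.nn as nn

embed_dim = 768
num_patches = 196  # (224 // 16) ** 2

# Learnable [class] token and positional embeddings (initialized to zeros here).
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

x = torch.randn(1, num_patches, embed_dim)  # output of PatchEmbed: (B, 196, 768)
cls = cls_token.expand(x.shape[0], -1, -1)  # replicate the token per sample: (B, 1, 768)
x = torch.cat((cls, x), dim=1)              # prepend -> (B, 197, 768)
x = x + pos_embed                           # add positional information
print(x.shape)  # torch.Size([1, 197, 768])
```

After this step the sequence length is 197 (196 patches plus the class token), which is the shape the encoder works with.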

The code then shows how to create a dummy input tensor to test the model:

if __name__ == '__main__':
    x = torch.ones(1, 3, 224, 224)  # batch_size=1, 3-channel 224x224 image
    print(x.shape)
    model = vit_base_patch16_224_in21k()  # factory function for ViT-Base/16
    output = model(x)
    print(output.shape)

Subsequent sections detail the encoder block, including LayerNorm, Multi-Head Attention, DropPath, and MLP components, with their respective implementations (the Block, Attention, and Mlp classes). The article emphasizes that the tensor shape remains (1, 197, 768) throughout the encoder.
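To make the shape-preservation point concrete, here is a minimal multi-head self-attention sketch. The hyperparameters come from the ViT-Base table (dim 768, 12 heads); the class below is an illustration under those assumptions, not the article's exact Attention implementation (it omits dropout and the qkv bias option):

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Minimal multi-head self-attention; input and output are both (B, N, dim)."""
    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5          # 1/sqrt(d_k) scaling
        self.qkv = nn.Linear(dim, dim * 3)          # fused query/key/value projection
        self.proj = nn.Linear(dim, dim)             # output projection

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # merge heads back
        return self.proj(x)

out = Attention()(torch.randn(1, 197, 768))
print(out.shape)  # torch.Size([1, 197, 768]) -- shape is preserved
```

Because the output shape equals the input shape, twelve such blocks can be stacked without any reshaping in between.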

After the encoder, the class token is extracted, optionally passed through a pre_logits layer (a linear layer followed by Tanh when a representation size is defined), and finally through a classification head:

self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
x = self.head(x)
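The pre_logits step described above (a Linear layer followed by Tanh when a representation size is defined) can be sketched like this; the variable names and the 5-class head are illustrative, matching the flower task:

```python
import torch
import torch.nn as nn
from collections import OrderedDict

embed_dim, representation_size, num_classes = 768, 768, 5

# Optional pre_logits: Linear + Tanh applied to the extracted class token.
pre_logits = nn.Sequential(OrderedDict([
    ('fc', nn.Linear(embed_dim, representation_size)),
    ('act', nn.Tanh()),
]))
head = nn.Linear(representation_size, num_classes)  # final classification head

cls_token_out = torch.randn(1, embed_dim)  # class token taken from (1, 197, 768)
logits = head(pre_logits(cls_token_out))
print(logits.shape)  # torch.Size([1, 5])
```

When no representation size is given, pre_logits is simply nn.Identity() and the class token goes straight to the head.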

The training script modifications are then described: updating the data path argument to point to the flower dataset and setting the path to the appropriate pre-trained weight file (e.g., ./vit_base_patch16_224_in21k.pth). The script also saves model checkpoints to a weights directory.
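Those two modifications typically reduce to a pair of command-line arguments. The sketch below shows one plausible shape for them; the argument names (--data-path, --weights) and the dataset folder name are assumptions, not the article's exact script:

```python
import argparse

# Illustrative training-script arguments; names and defaults are assumed.
parser = argparse.ArgumentParser()
parser.add_argument('--data-path', type=str, default='./flower_photos',
                    help='root directory of the five-class flower dataset')
parser.add_argument('--weights', type=str, default='./vit_base_patch16_224_in21k.pth',
                    help='path to the pre-trained ViT weight file (empty string to train from scratch)')
parser.add_argument('--num-classes', type=int, default=5)
args = parser.parse_args([])  # parse defaults only, for illustration
print(args.weights)
```

Pointing --weights at an empty string (or skipping the load) reproduces the from-scratch scenarios in the experiments below.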

Experimental results are presented for three scenarios: training without pre‑trained weights for 10 and 50 epochs, and training with pre‑trained weights for 10 epochs. The results demonstrate that ViT heavily relies on pre‑training, achieving up to 97.1% accuracy with pre‑trained weights compared to much lower accuracy without them.

Finally, a prediction example shows the probability distribution for a tulip image, illustrating the model's inference capability.
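The probability distribution in that example comes from applying softmax to the head's five logits. A minimal sketch of that inference step follows; the class names and logit values are made up for illustration (the names follow the common five-class flower_photos dataset):

```python
import torch

class_names = ['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']  # assumed ordering
logits = torch.tensor([[-1.2, 0.3, -0.5, 0.1, 4.8]])  # dummy output of the classification head
probs = torch.softmax(logits, dim=1).squeeze(0)        # convert logits to probabilities
pred = class_names[int(probs.argmax())]
print(pred, float(probs.max()))
```

In the article's example the tulip class receives the highest probability, which is exactly what argmax over the softmax output reports.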

Tags: image classification, Deep Learning, PyTorch, pretrained models, Vision Transformer, ViT
Written by

Rare Earth Juejin Tech Community

Juejin, a tech community that helps developers grow.
