Building a Satellite Image Classifier with PyTorch ResNet34

This article walks through building a satellite image classification pipeline with PyTorch and a pretrained ResNet34 model: dataset preparation, project structure, data loading, model definition, training, validation, loss/accuracy plotting, and inference on new images.

This guide demonstrates how to build a satellite image classification project using PyTorch and the pretrained ResNet34 architecture.

Dataset

The Kaggle satellite image dataset contains roughly 5,600 RGB images divided into four classes: cloudy (1,500 images), desert (1,131 images), green_area (1,500 images) and water (1,500 images). The directory layout matches the class names.

Sample satellite images for each class

Project Structure

├── input
│   ├── data
│   │   ├── cloudy
│   │   ├── desert
│   │   ├── green_area
│   │   └── water
│   └── test_data
│       ├── cloudy.jpeg
│       ├── desert.jpeg
│       ├── green_area.jpeg
│       └── water.jpeg
├── outputs
│   ├── accuracy.png
│   ├── loss.png
│   ├── model.pth
│   └── …
├── datasets.py
├── inference.py
├── model.py
├── train.py
└── utils.py

Data Preparation (datasets.py)

The script uses torchvision.datasets.ImageFolder with a 20% validation split and a batch size of 64 (adjustable to 32 or 16 if GPU memory is limited). Training transforms include resizing to 224×224, random horizontal/vertical flips, Gaussian blur, random rotation, conversion to tensor, and ImageNet normalization. Validation transforms only resize, convert to tensor, and normalize.

from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # square resize so batches stack, matching the validation transform
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
    transforms.RandomRotation(degrees=(30, 70)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

valid_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
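The article does not show how the 20% validation split is produced. A minimal sketch of one common approach, shuffling indices and carving off a validation share with `Subset`, is below; the seed, the `split_indices` helper, and the exact paths are assumptions, not from the original:

```python
# Hypothetical sketch of the datasets.py loading step: ImageFolder plus an
# 80/20 train/validation split. split_indices and the seed are assumptions.
import torch
from torch.utils.data import DataLoader, Subset

def split_indices(n, valid_split=0.2, seed=42):
    """Shuffle indices 0..n-1 and carve off the validation share."""
    g = torch.Generator().manual_seed(seed)
    idx = torch.randperm(n, generator=g).tolist()
    n_valid = int(n * valid_split)
    return idx[n_valid:], idx[:n_valid]

# With the directory layout above, usage would look roughly like:
# dataset = torchvision.datasets.ImageFolder('input/data', transform=train_transform)
# train_idx, valid_idx = split_indices(len(dataset))
# train_loader = DataLoader(Subset(dataset, train_idx), batch_size=64, shuffle=True)
# valid_loader = DataLoader(Subset(dataset, valid_idx), batch_size=64, shuffle=False)
```

In practice the validation subset is usually built from a second `ImageFolder` instance that uses `valid_transform`, so the training augmentations are not applied at evaluation time.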

Model Definition (model.py)

A helper function build_model loads the pretrained ResNet34 weights, optionally freezes hidden layers, and replaces the final fully‑connected layer with a nn.Linear(512, num_classes) head.

import torch.nn as nn
from torchvision import models

def build_model(pretrained=True, fine_tune=True, num_classes=1):
    if pretrained:
        print('[INFO]: Loading pre-trained weights')
    else:
        print('[INFO]: Not loading pre-trained weights')
    model = models.resnet34(pretrained=pretrained)
    if fine_tune:
        print('[INFO]: Fine-tuning all layers...')
        for p in model.parameters():
            p.requires_grad = True
    else:
        print('[INFO]: Freezing hidden layers...')
        for p in model.parameters():
            p.requires_grad = False
    model.fc = nn.Linear(512, num_classes)
    return model

Utility Functions (utils.py)

Two simple utilities save the model checkpoint and plot loss/accuracy curves.

import torch

def save_model(epoch, model, optimizer, criterion):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': criterion
    }, 'outputs/model.pth')

def save_plots(train_acc, valid_acc, train_loss, valid_loss):
    # code that creates and saves accuracy.png and loss.png
    ...
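The plotting code is elided in the article. A plausible implementation using matplotlib is sketched below; the `out_dir` parameter and figure styling are additions for illustration, not from the source:

```python
# Hypothetical save_plots: writes accuracy.png and loss.png from the
# per-epoch metric lists. Styling and the out_dir parameter are assumptions.
import os
import matplotlib
matplotlib.use('Agg')  # headless backend so this also runs without a display
import matplotlib.pyplot as plt

def save_plots(train_acc, valid_acc, train_loss, valid_loss, out_dir='outputs'):
    os.makedirs(out_dir, exist_ok=True)
    # accuracy curves
    plt.figure(figsize=(10, 7))
    plt.plot(train_acc, color='green', label='train accuracy')
    plt.plot(valid_acc, color='blue', label='validation accuracy')
    plt.xlabel('Epochs'); plt.ylabel('Accuracy'); plt.legend()
    plt.savefig(os.path.join(out_dir, 'accuracy.png'))
    plt.close()
    # loss curves
    plt.figure(figsize=(10, 7))
    plt.plot(train_loss, color='orange', label='train loss')
    plt.plot(valid_loss, color='red', label='validation loss')
    plt.xlabel('Epochs'); plt.ylabel('Loss'); plt.legend()
    plt.savefig(os.path.join(out_dir, 'loss.png'))
    plt.close()
```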

Training Script (train.py)

The script parses --epochs, builds the model (pretrained, frozen layers), moves it to CUDA if available, creates an Adam optimizer (lr=0.001) and a cross‑entropy loss, then runs a standard training loop. After each epoch it prints class‑wise validation accuracy, saves the checkpoint, and updates the loss/accuracy plots.

# training loop (simplified)
for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_epoch_acc = train(model, train_loader, optimizer, criterion)
    valid_epoch_loss, valid_epoch_acc = validate(model, valid_loader, criterion, dataset.classes)
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    train_acc.append(train_epoch_acc)
    valid_acc.append(valid_epoch_acc)
    print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
    save_model(epoch, model, optimizer, criterion)
    save_plots(train_acc, valid_acc, train_loss, valid_loss)
print('TRAINING COMPLETE')
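The `train` and `validate` helpers called by the loop are not shown in the article. A sketch of a standard PyTorch epoch that matches the signatures above is given here; the internals (running-loss bookkeeping, accuracy as a percentage, the `device` default) are assumptions:

```python
# Hedged sketch of the train/validate helpers; a standard epoch, not
# necessarily the author's exact implementation.
import torch
import torch.nn as nn

def train(model, loader, optimizer, criterion, device='cpu'):
    model.train()
    running_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()
        seen += labels.size(0)
    return running_loss / len(loader), 100.0 * correct / seen

def validate(model, loader, criterion, class_names=None, device='cpu'):
    model.eval()
    running_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            running_loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(1) == labels).sum().item()
            seen += labels.size(0)
            # class_names would be used here to tally per-class accuracy,
            # as printed in the article's console output
    return running_loss / len(loader), 100.0 * correct / seen
```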

Sample console output shows the model reaching ~99% training accuracy and ~95% validation accuracy after 100 epochs, with per‑class accuracies printed after each epoch.

Training/validation accuracy curves
Training/validation loss curves

Inference Script (inference.py)

The inference script runs on CPU, loads the saved model.pth, applies the same ImageNet normalization, and predicts the class of a user‑provided image. It draws the ground‑truth and predicted labels on the original image and saves the result.

# load model and weights; the script runs on CPU
device = torch.device('cpu')
model = build_model(pretrained=False, fine_tune=False, num_classes=4).to(device)
checkpoint = torch.load('outputs/model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# preprocess input image (transform must accept a NumPy array, e.g. by
# starting with transforms.ToPILImage(); labels holds the four class names)
image = cv2.imread(args['input'])
orig_image = image.copy()  # keep the BGR original for annotation
gt_class = args['input'].split('/')[-1].split('.')[0]  # e.g. 'cloudy' from cloudy.jpeg
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = transform(image)
image = torch.unsqueeze(image, 0)

# forward pass
with torch.no_grad():
    outputs = model(image.to(device))
    _, pred_idx = torch.topk(outputs, 1)
    pred_class = labels[int(pred_idx)]

# annotate and save
cv2.putText(orig_image, f"GT: {gt_class}", (10,25), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,255,0), 2, cv2.LINE_AA)
cv2.putText(orig_image, f"Pred: {pred_class}", (10,55), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,0,255), 2, cv2.LINE_AA)
cv2.imwrite(f"outputs/{gt_class}.png", orig_image)

Running the script on the four test images yields correct predictions for cloudy, desert, and water, while the green_area image is mistakenly classified as cloudy, reflecting the class-wise accuracy observed during validation.

Inference result for a cloudy image
Inference result for a desert image
Inference result for a green_area image (misclassified)
Inference result for a water image

Conclusion

The article presents a complete, reproducible workflow for satellite image classification using a pretrained ResNet34 model in PyTorch. It covers dataset handling, model customization, training with data augmentation, evaluation, visualization of metrics, and deployment-ready inference on new images.

Republication Notice

This article has been distilled and summarized from source material, then republished for learning and reference. If you believe it infringes your rights, please contact admin@besthub.dev and we will review it promptly.

Tags: Image Classification, Python, Deep Learning, PyTorch, Satellite Imagery, ResNet34
Written by Code DAO

We deliver AI algorithm tutorials and the latest news, curated by a team of researchers from Peking University, Shanghai Jiao Tong University, Central South University, and leading AI companies such as Huawei, Kuaishou, and SenseTime. Join us in the AI alchemy—making life better!
