Artificial Intelligence 13 min read

Image Mosaic Removal Using Autoencoder and UNet in PyTorch

This article explains the principle behind using deep‑learning autoencoders and UNet architectures to reconstruct mosaicked images, provides a complete PyTorch implementation with dataset preparation, network definition, training, and inference, and demonstrates the restored results.

Rare Earth Juejin Tech Community
Rare Earth Juejin Tech Community
Rare Earth Juejin Tech Community
Image Mosaic Removal Using Autoencoder and UNet in PyTorch

1. Introduction

People often think mosaic (pixelation) cannot be reversed, but deep learning makes it possible; this article explains the principle and demonstrates a PyTorch implementation that learns to reconstruct images from mosaicked inputs.

2. Principle

2.1 Autoencoder

Autoencoders are self‑supervised models that compress an image into a latent vector and reconstruct it, using a simple encoder‑decoder architecture.

2.2 Using Autoencoder for Mosaic Removal

By training the network with mosaicked images as input and original images as target, the model learns to generate the missing content.

2.3 Limitations of Plain Autoencoders

Standard autoencoders lose fine details because everything must pass through the low-resolution bottleneck; adding skip connections that feed each encoder stage's features directly into the decoder (similar in spirit to a Feature Pyramid Network) yields the UNet architecture, which preserves far more spatial information.

2.4 UNet Architecture

UNet adds skip connections between encoder and decoder layers, improving reconstruction quality.

2.4.1 Convolution Block

class ConvBlock(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by BatchNorm and ReLU.

    Spatial size is preserved; only the channel count changes
    (in_channels -> out_channels).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        # Same Sequential layout as before, so checkpoint keys are unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.model(inputs)

2.4.2 Downsampling Block

class ConvDown(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution.

    The channel count is left unchanged; BatchNorm and ReLU follow the
    convolution as in the other blocks.
    """

    def __init__(self, channels):
        super().__init__()
        stages = [
            nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, inputs):
        return self.model(inputs)

2.4.3 Upsampling Block

class ConvUp(nn.Module):
    """Double the spatial resolution and halve the channels.

    Uses a 2x2 stride-2 transposed convolution, followed by BatchNorm
    and ReLU, mirroring ConvDown on the expanding path.
    """

    def __init__(self, channels):
        super().__init__()
        half = channels // 2
        stages = [
            nn.ConvTranspose2d(channels, half, kernel_size=2, stride=2),
            nn.BatchNorm2d(half),
            nn.ReLU(),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, inputs):
        return self.model(inputs)

3. Full Implementation

3.1 Dataset

class ReConstructionDataset(data.Dataset):
    """Pairs of (degraded image, original image) for training a restoration net.

    Each sample is produced on the fly: a random square region of the source
    image is Gaussian-blurred (NOTE: this is a blur, not a true pixelation
    mosaic, despite the article title), then both versions are resized,
    center-cropped to ``image_size`` and converted to tensors in [0, 1].
    """

    def __init__(self, data_dir=r"G:/datasets/lbxx", image_size=64):
        self.image_size = image_size
        self.trans = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),  # float tensors in [0, 1]
        ])
        # Recursively collect every file under data_dir; assumes the tree
        # contains only readable image files — TODO confirm.
        self.image_paths = [
            os.path.join(root, name)
            for root, _dirs, files in os.walk(data_dir)
            for name in files
        ]

    def __getitem__(self, item):
        # Force 3 channels so grayscale/RGBA/palette files don't break the
        # network's Conv2d(3, ...) input layer.
        image = Image.open(self.image_paths[item]).convert("RGB")
        return self.trans(self.create_blur(image)), self.trans(image)

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def create_blur(image, return_mask=False, box_size=200):
        """Blur a random square region of ``image``.

        Returns the degraded image; when ``return_mask`` is True, also returns
        the mask (0 inside the blurred square, 255 elsewhere).
        """
        # Clamp the square so images smaller than box_size don't make
        # random.randint fail with an empty range.
        box = min(box_size, image.size[0], image.size[1])
        mask = Image.new('L', image.size, 255)
        draw = ImageDraw.Draw(mask)
        upper_left = (random.randint(0, image.size[0] - box),
                      random.randint(0, image.size[1] - box))
        lower_right = (upper_left[0] + box, upper_left[1] + box)
        # PIL expects the upper-left corner first; the original passed the
        # corners reversed, which newer Pillow versions reject outright.
        draw.rectangle([upper_left, lower_right], fill=0)
        # Composite keeps the original where mask=255 and the blurred copy
        # where mask=0.
        masked_image = Image.composite(image, image.filter(ImageFilter.GaussianBlur(15)), mask)
        return (masked_image, mask) if return_mask else masked_image

3.2 Network Construction

class UNetEncoder(nn.Module):
    """UNet contracting path: five conv stages with four 2x downsamples.

    Channel widths grow 64 -> 128 -> 256 -> 512 -> 1024 while the spatial
    resolution is halved between stages.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept so saved state_dicts remain loadable.
        self.blk0 = ConvBlock(3, 64)
        self.down0 = ConvDown(64)
        self.blk1 = ConvBlock(64, 128)
        self.down1 = ConvDown(128)
        self.blk2 = ConvBlock(128, 256)
        self.down2 = ConvDown(256)
        self.blk3 = ConvBlock(256, 512)
        self.down3 = ConvDown(512)
        self.blk4 = ConvBlock(512, 1024)

    def forward(self, inputs):
        """Return every stage's feature map (shallow to deep) for the decoder's skips."""
        skips = []
        x = inputs
        stages = ((self.blk0, self.down0), (self.blk1, self.down1),
                  (self.blk2, self.down2), (self.blk3, self.down3))
        for block, down in stages:
            x = block(x)
            skips.append(x)  # saved for the matching decoder stage
            x = down(x)
        skips.append(self.blk4(x))  # bottleneck features
        return tuple(skips)
class UNetDecoder(nn.Module):
    """UNet expanding path: upsample, concatenate the matching encoder skip, convolve."""

    def __init__(self):
        super().__init__()
        # Attribute names are kept so saved state_dicts remain loadable.
        self.up3 = ConvUp(1024)
        self.blk3 = ConvBlock(1024, 512)
        self.up2 = ConvUp(512)
        self.blk2 = ConvBlock(512, 256)
        self.up1 = ConvUp(256)
        self.blk1 = ConvBlock(256, 128)
        self.up0 = ConvUp(128)
        self.blk0 = ConvBlock(128, 64)
        self.last_conv = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, inputs):
        """Fuse the encoder features back into a 3-channel image in [-1, 1]."""
        f0, f1, f2, f3, x = inputs
        # Deepest stage first: each step doubles resolution, then mixes in the
        # encoder feature map of the same scale along the channel axis.
        for up, block, skip in ((self.up3, self.blk3, f3),
                                (self.up2, self.blk2, f2),
                                (self.up1, self.blk1, f1),
                                (self.up0, self.blk0, f0)):
            x = block(torch.concat((skip, up(x)), dim=1))
        # tanh bounds the output; note the dataset targets are ToTensor values
        # in [0, 1], so only the upper half of the range is actually used.
        return torch.tanh(self.last_conv(x))
class ReConstructionNetwork(nn.Module):
    """End-to-end restoration model: UNet encoder feeding the skip-connected decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = UNetEncoder()
        self.decoder = UNetDecoder()

    def forward(self, inputs):
        # The encoder's tuple of per-stage features is passed straight through.
        return self.decoder(self.encoder(inputs))

3.3 Training

device = "cuda" if torch.cuda.is_available() else "cpu"

def train(model, dataloader, optimizer, criterion, epochs):
    """Train ``model`` to map degraded images back to their originals.

    Logs the loss and shows a grid of current outputs every 100 iterations,
    and checkpoints the weights after every epoch.

    Args:
        model: network mapping degraded images to reconstructions.
        dataloader: yields (masked_images, images) batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: pixel-wise loss, e.g. nn.MSELoss.
        epochs: number of passes over the dataset.
    """
    model = model.to(device)
    model.train()  # ensure BatchNorm uses batch statistics during training
    for epoch in range(epochs):
        # ``step`` instead of ``iter`` — the original shadowed the builtin.
        for step, (masked_images, images) in enumerate(dataloader):
            masked_images, images = masked_images.to(device), images.to(device)
            outputs = model(masked_images)
            loss = criterion(outputs, images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (step + 1) % 100 == 1:  # fires on steps 0, 100, 200, ...
                print("epoch: %s, iter: %s, loss: %s" % (epoch + 1, step + 1, loss.item()))
                with torch.no_grad():
                    # Clamp to [0, 1] so imshow doesn't warn/clip: the model's
                    # tanh output can contain negative values.
                    grid = make_grid(outputs).clamp(0, 1)
                    plt.imshow(grid.cpu().numpy().transpose(1, 2, 0))
                    plt.show()
        torch.save(model.state_dict(), '../outputs/reconstruction.pth')

if __name__ == '__main__':
    # shuffle=True: without it the loader iterates in directory order, so
    # every batch is biased toward one folder of the dataset.
    dataloader = data.DataLoader(ReConstructionDataset(r"G:\datasets\lbxx"), 64, shuffle=True)
    unet = ReConstructionNetwork()
    optimizer = optim.Adam(unet.parameters(), lr=0.0002)
    criterion = nn.MSELoss()  # pixel-wise reconstruction loss
    train(unet, dataloader, optimizer, criterion, 20)

3.4 Inference

# Inference: show (original | degraded | restored) side by side for each batch.
dataloader = data.DataLoader(ReConstructionDataset(r"G:\datasets\lbxx"), 64, shuffle=True)
unet = ReConstructionNetwork().to(device)
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
unet.load_state_dict(torch.load('../outputs/reconstruction.pth', map_location=device))
unet.eval()  # use BatchNorm running stats instead of per-batch statistics
for masked_images, images in dataloader:
    masked_images, images = masked_images.to(device), images.to(device)
    with torch.no_grad():
        outputs = unet(masked_images)
        # Concatenate along width: ground truth | degraded input | reconstruction.
        outputs = torch.concatenate([images, masked_images, outputs], dim=-1)
        outputs = make_grid(outputs)
        img = outputs.cpu().numpy().transpose(1, 2, 0)
        # Rescale floats to displayable 8-bit values.
        img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        Image.fromarray(img).show()

The results show that the model can effectively restore mosaicked regions, and the same pipeline can be adapted for other image restoration tasks such as aged or ink‑stained photos.

PyTorchautoencoderUNetMosaic Removal
Rare Earth Juejin Tech Community
Written by

Rare Earth Juejin Tech Community

Juejin, a tech community that helps developers grow.

0 followers
Reader feedback

How this landed with the community

login Sign in to like

Rate this article

Was this worth your time?

Sign in to rate
Discussion

0 Comments

Thoughtful readers leave field notes, pushback, and hard-won operational detail here.