
Image Mosaic Removal Using Autoencoder and UNet in PyTorch

This article explains how to use autoencoders and a UNet architecture implemented in PyTorch to remove mosaic blocks from images, detailing the underlying principles, dataset preparation, network components, training procedure, and sample results.

Rare Earth Juejin Tech Community

Introduction

People often think that removing mosaic (pixelation) from an image is impossible, but deep learning makes it feasible. Mosaic destroys original pixel information, so restoration is actually an estimation rather than true recovery.

Principle

2.1 Autoencoder

Autoencoders are self‑supervised models that compress an image into a latent vector with an encoder and reconstruct it with a decoder. The reconstruction loss (MSE) guides training.
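As a minimal sketch of the idea (this tiny model is illustrative only, and the `TinyAutoencoder` name is made up for this example; it is not the UNet used later in the article), an encoder compresses the image into a latent vector, a decoder reconstructs it, and MSE measures the reconstruction error:

```python
import torch
from torch import nn

class TinyAutoencoder(nn.Module):
    # Hypothetical minimal autoencoder: the encoder squeezes a 3x64x64
    # image into a 128-dim latent vector, the decoder expands it back.
    def __init__(self, latent_dim=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1), nn.ReLU(),    # 64 -> 32
            nn.Conv2d(32, 64, 4, 2, 1), nn.ReLU(),   # 32 -> 16
            nn.Flatten(),
            nn.Linear(64 * 16 * 16, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64 * 16 * 16), nn.ReLU(),
            nn.Unflatten(1, (64, 16, 16)),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(),   # 16 -> 32
            nn.ConvTranspose2d(32, 3, 4, 2, 1), nn.Sigmoid(), # 32 -> 64
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

x = torch.rand(2, 3, 64, 64)
model = TinyAutoencoder()
recon = model(x)
loss = nn.functional.mse_loss(recon, x)  # reconstruction loss
```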

2.2 Autoencoder for Mosaic Removal

To remove mosaic, the network is trained with mosaic‑corrupted images as input and the original images as targets, so the decoder learns to generate plausible content for the masked regions.
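The corruption used later in this article is actually a Gaussian-blurred patch, but a literal mosaic effect is easy to generate as well. The `create_mosaic` helper below is a hypothetical sketch (not part of the article's code): it pixelates a random square region by downscaling it to a few cells and re-upscaling with nearest-neighbour resampling.

```python
import random
from PIL import Image

def create_mosaic(image, box_size=64, cell=8):
    # Hypothetical helper: pixelate a random square region by shrinking
    # it to (box_size // cell) pixels per side and blowing it back up
    # with nearest-neighbour resampling (the classic mosaic look).
    w, h = image.size
    box_size = min(box_size, w, h)  # guard against small images
    x0 = random.randint(0, w - box_size)
    y0 = random.randint(0, h - box_size)
    region = image.crop((x0, y0, x0 + box_size, y0 + box_size))
    small = region.resize((max(1, box_size // cell),) * 2, Image.BILINEAR)
    mosaic = small.resize((box_size, box_size), Image.NEAREST)
    out = image.copy()
    out.paste(mosaic, (x0, y0))
    return out

img = Image.new("RGB", (128, 128), (200, 120, 40))
corrupted = create_mosaic(img)
```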

2.3 Limitations of Plain Autoencoders

Standard autoencoders squeeze everything through a low-dimensional bottleneck, so fine details are lost and outputs look blurry. Adding skip connections that feed each encoder scale directly into the decoder (a feature-pyramid-style design) yields a UNet architecture that preserves far more spatial information.

2.4 UNet Architecture

UNet combines encoder features with decoder up‑sampling via concatenation, improving detail retention. Key blocks include ConvBlock, ConvDown, and ConvUp, implemented as follows:

ConvBlock

from torch import nn

class ConvBlock(nn.Module):
    # Two 3x3 convolutions (stride 1, padding 1): changes the channel
    # count while keeping the spatial size.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, inputs):
        return self.model(inputs)

ConvDown

class ConvDown(nn.Module):
    # A stride-2 convolution halves the height and width.
    def __init__(self, channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 2, 1),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )

    def forward(self, inputs):
        return self.model(inputs)

ConvUp

class ConvUp(nn.Module):
    # A 2x2 transposed convolution with stride 2 doubles the spatial
    # size and halves the channel count.
    def __init__(self, channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(channels, channels // 2, 2, 2),
            nn.BatchNorm2d(channels // 2),
            nn.ReLU()
        )

    def forward(self, inputs):
        return self.model(inputs)
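A quick shape check makes the role of each block concrete; the three classes are reproduced here (in compact form) so the snippet runs standalone.

```python
import torch
from torch import nn

# The three UNet building blocks from above, reproduced so this
# shape check is self-contained.
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels), nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels), nn.ReLU())
    def forward(self, x):
        return self.model(x)

class ConvDown(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 2, 1),
            nn.BatchNorm2d(channels), nn.ReLU())
    def forward(self, x):
        return self.model(x)

class ConvUp(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(channels, channels // 2, 2, 2),
            nn.BatchNorm2d(channels // 2), nn.ReLU())
    def forward(self, x):
        return self.model(x)

x = torch.rand(1, 64, 32, 32)
f = ConvBlock(64, 128)(x)  # channels change, spatial size kept
d = ConvDown(128)(f)       # spatial size halved
u = ConvUp(128)(d)         # spatial size doubled, channels halved
```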

Full Implementation

Dataset

import os
import random

from PIL import Image, ImageDraw, ImageFilter
from torch.utils import data
from torchvision import transforms


class ReConstructionDataset(data.Dataset):
    def __init__(self, data_dir=r"G:/datasets/lbxx", image_size=64):
        self.image_size = image_size
        self.trans = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])
        self.image_paths = []
        for root, dirs, files in os.walk(data_dir):
            for file in files:
                self.image_paths.append(os.path.join(root, file))

    def __getitem__(self, item):
        # convert("RGB") guards against grayscale or RGBA files
        image = Image.open(self.image_paths[item]).convert("RGB")
        return self.trans(self.create_blur(image)), self.trans(image)

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def create_blur(image, return_mask=False, box_size=200):
        # Clamp the box so random.randint does not fail on small images
        box_size = min(box_size, image.size[0], image.size[1])
        mask = Image.new('L', image.size, 255)
        draw = ImageDraw.Draw(mask)
        upper_left_corner = (random.randint(0, image.size[0] - box_size),
                             random.randint(0, image.size[1] - box_size))
        lower_right_corner = (upper_left_corner[0] + box_size,
                              upper_left_corner[1] + box_size)
        # Pillow expects the upper-left point first
        draw.rectangle([upper_left_corner, lower_right_corner], fill=0)
        # Where the mask is 0 the blurred copy is used; elsewhere the original
        masked_image = Image.composite(
            image, image.filter(ImageFilter.GaussianBlur(15)), mask)
        return (masked_image, mask) if return_mask else masked_image

Network Construction

import torch
from torch import nn


class UNetEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.blk0 = ConvBlock(3, 64)
        self.down0 = ConvDown(64)
        self.blk1 = ConvBlock(64, 128)
        self.down1 = ConvDown(128)
        self.blk2 = ConvBlock(128, 256)
        self.down2 = ConvDown(256)
        self.blk3 = ConvBlock(256, 512)
        self.down3 = ConvDown(512)
        self.blk4 = ConvBlock(512, 1024)

    def forward(self, inputs):
        # Return the feature map of every scale so the decoder can
        # concatenate them as skip connections.
        f0 = self.blk0(inputs)
        d0 = self.down0(f0)
        f1 = self.blk1(d0)
        d1 = self.down1(f1)
        f2 = self.blk2(d1)
        d2 = self.down2(f2)
        f3 = self.blk3(d2)
        d3 = self.down3(f3)
        f4 = self.blk4(d3)
        return f0, f1, f2, f3, f4


class UNetDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.up3 = ConvUp(1024)
        self.blk3 = ConvBlock(1024, 512)  # 512 (skip) + 512 (upsampled)
        self.up2 = ConvUp(512)
        self.blk2 = ConvBlock(512, 256)
        self.up1 = ConvUp(256)
        self.blk1 = ConvBlock(256, 128)
        self.up0 = ConvUp(128)
        self.blk0 = ConvBlock(128, 64)
        self.last_conv = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, inputs):
        f0, f1, f2, f3, f4 = inputs
        u3 = self.up3(f4)
        df2 = self.blk3(torch.cat((f3, u3), dim=1))
        u2 = self.up2(df2)
        df1 = self.blk2(torch.cat((f2, u2), dim=1))
        u1 = self.up1(df1)
        df0 = self.blk1(torch.cat((f1, u1), dim=1))
        u0 = self.up0(df0)
        f = self.blk0(torch.cat((f0, u0), dim=1))
        # tanh bounds the output to [-1, 1]; note that ToTensor targets
        # live in [0, 1], so sigmoid would also be a reasonable choice.
        return torch.tanh(self.last_conv(f))


class ReConstructionNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = UNetEncoder()
        self.decoder = UNetDecoder()

    def forward(self, inputs):
        fs = self.encoder(inputs)
        return self.decoder(fs)
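To make the skip-connection bookkeeping concrete without instantiating the full 1024-channel model, here is a hypothetical two-scale miniature (the `MiniUNet` name and layer sizes are chosen for this example) that follows the same pattern: encode, downsample, upsample, concatenate the skip feature, decode.

```python
import torch
from torch import nn

class MiniUNet(nn.Module):
    # Hypothetical two-scale miniature of the UNet above, showing how
    # upsampled features and encoder skips are concatenated.
    def __init__(self):
        super().__init__()
        self.enc0 = nn.Sequential(nn.Conv2d(3, 64, 3, 1, 1), nn.ReLU())
        self.down = nn.Sequential(nn.Conv2d(64, 64, 3, 2, 1), nn.ReLU())
        self.enc1 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU())
        self.up = nn.ConvTranspose2d(128, 64, 2, 2)
        # After concatenation the decoder sees 64 (skip) + 64 (up) = 128
        self.dec = nn.Sequential(nn.Conv2d(128, 64, 3, 1, 1), nn.ReLU())
        self.last = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, x):
        f0 = self.enc0(x)              # (B, 64, H, W)
        f1 = self.enc1(self.down(f0))  # (B, 128, H/2, W/2)
        u0 = self.up(f1)               # (B, 64, H, W)
        d0 = self.dec(torch.cat((f0, u0), dim=1))
        return torch.tanh(self.last(d0))

x = torch.rand(2, 3, 64, 64)
out = MiniUNet()(x)  # same spatial size as the input
```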

Training

import torch
from torch import nn, optim
from torch.utils import data

device = "cuda" if torch.cuda.is_available() else "cpu"


def train(model, dataloader, optimizer, criterion, epochs):
    model = model.to(device)
    model.train()
    for epoch in range(epochs):
        for step, (masked_images, images) in enumerate(dataloader):
            masked_images, images = masked_images.to(device), images.to(device)
            outputs = model(masked_images)
            loss = criterion(outputs, images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (step + 1) % 100 == 0:  # log every 100 iterations
                print(f"epoch: {epoch + 1}, iter: {step + 1}, loss: {loss.item()}")
        torch.save(model.state_dict(), '../outputs/reconstruction.pth')


if __name__ == '__main__':
    dataloader = data.DataLoader(ReConstructionDataset(r"G:\datasets\lbxx"),
                                 batch_size=64, shuffle=True)
    unet = ReConstructionNetwork()
    optimizer = optim.Adam(unet.parameters(), lr=0.0002)
    criterion = nn.MSELoss()
    train(unet, dataloader, optimizer, criterion, 20)

Inference

import cv2
import torch
from PIL import Image
from torch.utils import data
from torchvision.utils import make_grid

dataloader = data.DataLoader(ReConstructionDataset(r"G:\datasets\lbxx"),
                             batch_size=64, shuffle=True)
unet = ReConstructionNetwork().to(device)
unet.load_state_dict(torch.load('../outputs/reconstruction.pth', map_location=device))
unet.eval()  # freeze batch-norm statistics for inference
for masked_images, images in dataloader:
    masked_images, images = masked_images.to(device), images.to(device)
    with torch.no_grad():
        outputs = unet(masked_images)
        # Place target, corrupted input, and reconstruction side by side
        outputs = torch.cat([images, masked_images, outputs], dim=-1)
        outputs = make_grid(outputs)
        img = outputs.cpu().numpy().transpose(1, 2, 0)
        img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        Image.fromarray(img).show()
    break  # show only the first batch

Results show the original image, the mosaic‑corrupted version, and the network’s reconstruction, demonstrating that the approach can also be applied to aged or ink‑stained photos.

For more robust performance, refer to the CodeFormer project.

Tags: deep learning, image processing, PyTorch, autoencoder, UNet, mosaic removal
Written by

Rare Earth Juejin Tech Community

Juejin, a tech community that helps developers grow.
