Image Mosaic Removal Using Autoencoder and UNet in PyTorch
This article explains how to use autoencoders and a UNet architecture implemented in PyTorch to remove mosaic blocks from images, detailing the underlying principles, dataset preparation, network components, training procedure, and sample results.
Introduction
People often think that removing mosaic (pixelation) from an image is impossible, but deep learning makes it feasible. Mosaic destroys original pixel information, so restoration is actually an estimation rather than true recovery.
Principle
2.1 Autoencoder
Autoencoders are self‑supervised models that compress an image into a latent vector with an encoder and reconstruct it with a decoder. The reconstruction loss (MSE) guides training.
2.2 Autoencoder for Mosaic Removal
To remove mosaic, the network is trained with mosaic‑corrupted images as input and the original images as targets, so the decoder learns to generate plausible content for the masked regions.
2.3 Limitations of Plain Autoencoders
Standard autoencoders lose fine details, leading to blurry outputs. Adding encoder-to-decoder skip connections at multiple scales (a feature-pyramid-style design) yields a UNet-style architecture that preserves more spatial information.
2.4 UNet Architecture
UNet combines encoder features with decoder up‑sampling via concatenation, improving detail retention. Key blocks include ConvBlock, ConvDown, and ConvUp, implemented as follows:
ConvBlock
class ConvBlock(nn.Module):
    """Two 3x3 convolutions (stride 1, padding 1), each followed by
    BatchNorm and ReLU; spatial size is preserved."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, inputs):
        """Map (N, C_in, H, W) -> (N, C_out, H, W)."""
        return self.model(inputs)

# ConvDown
class ConvDown(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution (channel
    count unchanged), followed by BatchNorm and ReLU."""

    def __init__(self, channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 2, 1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    def forward(self, inputs):
        """Map (N, C, H, W) -> (N, C, ceil(H/2), ceil(W/2))."""
        return self.model(inputs)

# ConvUp
class ConvUp(nn.Module):
    """Double spatial resolution with a 2x2 stride-2 transposed
    convolution while halving the channel count, followed by BatchNorm
    and ReLU."""

    def __init__(self, channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(channels, channels // 2, 2, 2),
            nn.BatchNorm2d(channels // 2),
            nn.ReLU(),
        )

    def forward(self, inputs):
        """Map (N, C, H, W) -> (N, C // 2, 2H, 2W)."""
        return self.model(inputs)

# Full Implementation
Dataset
class ReConstructionDataset(data.Dataset):
    """Dataset of (corrupted image, original image) pairs for training a
    reconstruction network.

    Every file found under ``data_dir`` (recursively) is treated as an
    image; each sample is resized, center-cropped to ``image_size``, and
    returned as a float tensor in [0, 1].
    """

    def __init__(self, data_dir=r"G:/datasets/lbxx", image_size=64):
        self.image_size = image_size
        self.trans = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])
        # Collect every file path under data_dir; assumes the directory
        # holds only images -- TODO(review): filter by extension otherwise.
        self.image_paths = []
        for root, dirs, files in os.walk(data_dir):
            for file in files:
                self.image_paths.append(os.path.join(root, file))

    def __getitem__(self, item):
        # Force RGB: grayscale/RGBA files would otherwise produce tensors
        # whose channel count does not match the 3-channel network input.
        image = Image.open(self.image_paths[item]).convert("RGB")
        return self.trans(self.create_blur(image)), self.trans(image)

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def create_blur(image, return_mask=False, box_size=200):
        """Blur a random square region of ``image``.

        A box of side ``box_size`` (clamped to the image dimensions) is
        placed at a random position; inside the box the image is replaced
        by a Gaussian-blurred copy, approximating a mosaic patch.  Returns
        the corrupted image, plus the mask when ``return_mask`` is True.
        """
        # Clamp so the box always fits; the original code crashed in
        # random.randint for images smaller than box_size.
        box_size = min(box_size, image.size[0], image.size[1])
        mask = Image.new('L', image.size, 255)
        draw = ImageDraw.Draw(mask)
        upper_left_corner = (random.randint(0, image.size[0] - box_size),
                             random.randint(0, image.size[1] - box_size))
        lower_right_corner = (upper_left_corner[0] + box_size,
                              upper_left_corner[1] + box_size)
        # PIL expects the bounding box as [upper_left, lower_right]; the
        # reversed order raises ValueError in modern Pillow.
        draw.rectangle([upper_left_corner, lower_right_corner], fill=0)
        # Mask value 0 selects the blurred copy, 255 keeps the original.
        masked_image = Image.composite(image, image.filter(ImageFilter.GaussianBlur(15)), mask)
        return (masked_image, mask) if return_mask else masked_image

# Network Construction
class UNetEncoder(nn.Module):
    """Contracting path: five ConvBlocks with stride-2 downsampling in
    between.  Channel widths grow 64 -> 128 -> 256 -> 512 -> 1024."""

    def __init__(self):
        super().__init__()
        self.blk0 = ConvBlock(3, 64)
        self.down0 = ConvDown(64)
        self.blk1 = ConvBlock(64, 128)
        self.down1 = ConvDown(128)
        self.blk2 = ConvBlock(128, 256)
        self.down2 = ConvDown(256)
        self.blk3 = ConvBlock(256, 512)
        self.down3 = ConvDown(512)
        self.blk4 = ConvBlock(512, 1024)

    def forward(self, inputs):
        """Return the feature map of every level, shallowest first, for
        use as skip connections by the decoder."""
        skips = []
        x = inputs
        for blk, down in ((self.blk0, self.down0),
                          (self.blk1, self.down1),
                          (self.blk2, self.down2),
                          (self.blk3, self.down3)):
            x = blk(x)
            skips.append(x)
            x = down(x)
        skips.append(self.blk4(x))
        return tuple(skips)


class UNetDecoder(nn.Module):
    """Expanding path: upsample with ConvUp, concatenate the matching
    encoder feature along the channel axis, fuse with a ConvBlock, and
    map back to 3 channels with a final 3x3 convolution."""

    def __init__(self):
        super().__init__()
        self.up3 = ConvUp(1024)
        self.blk3 = ConvBlock(1024, 512)
        self.up2 = ConvUp(512)
        self.blk2 = ConvBlock(512, 256)
        self.up1 = ConvUp(256)
        self.blk1 = ConvBlock(256, 128)
        self.up0 = ConvUp(128)
        self.blk0 = ConvBlock(128, 64)
        self.last_conv = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, inputs):
        """``inputs`` is the 5-tuple produced by UNetEncoder.forward."""
        skip0, skip1, skip2, skip3, bottom = inputs
        x = self.blk3(torch.concat((skip3, self.up3(bottom)), dim=1))
        x = self.blk2(torch.concat((skip2, self.up2(x)), dim=1))
        x = self.blk1(torch.concat((skip1, self.up1(x)), dim=1))
        x = self.blk0(torch.concat((skip0, self.up0(x)), dim=1))
        # NOTE(review): tanh bounds the output to [-1, 1] while ToTensor
        # targets lie in [0, 1]; the network must learn to stay in the
        # positive half -- consider sigmoid. Kept as-is to preserve behavior.
        return torch.tanh(self.last_conv(x))


class ReConstructionNetwork(nn.Module):
    """Encoder + decoder wrapper: maps a corrupted 3-channel image to a
    reconstructed 3-channel image of the same spatial size."""

    def __init__(self):
        super().__init__()
        self.encoder = UNetEncoder()
        self.decoder = UNetDecoder()

    def forward(self, inputs):
        return self.decoder(self.encoder(inputs))

# Training
device = "cuda" if torch.cuda.is_available() else "cpu"


def train(model, dataloader, optimizer, criterion, epochs):
    """Train ``model`` on (corrupted, original) image-pair batches.

    Parameters:
        model: the reconstruction network (an nn.Module).
        dataloader: iterable yielding (masked_images, images) tensor batches.
        optimizer: torch.optim.Optimizer over ``model``'s parameters.
        criterion: loss callable, e.g. nn.MSELoss().
        epochs: number of passes over the data.

    The final weights are written to ../outputs/reconstruction.pth.
    """
    model = model.to(device)
    model.train()  # BatchNorm must use batch statistics while training
    for epoch in range(epochs):
        # `step` instead of `iter`: avoid shadowing the builtin.
        for step, (masked_images, images) in enumerate(dataloader):
            masked_images, images = masked_images.to(device), images.to(device)
            outputs = model(masked_images)
            loss = criterion(outputs, images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Log every 100 steps; the original `(iter + 1) % 100 == 1`
            # fired at steps 1, 101, ... -- an off-by-one.
            if (step + 1) % 100 == 0:
                print("epoch: %s, iter: %s, loss: %s" % (epoch + 1, step + 1, loss.item()))
    # Ensure the checkpoint directory exists before saving.
    os.makedirs('../outputs', exist_ok=True)
    torch.save(model.state_dict(), '../outputs/reconstruction.pth')
if __name__ == '__main__':
    # shuffle=True: the original iterated in directory order, so each
    # batch was biased toward a single folder's images.
    dataloader = data.DataLoader(ReConstructionDataset(r"G:\\datasets\\lbxx"), 64, shuffle=True)
    unet = ReConstructionNetwork()
    # Adam with lr=2e-4, the customary setting for image-generation nets.
    optimizer = optim.Adam(unet.parameters(), lr=0.0002)
    criterion = nn.MSELoss()
    train(unet, dataloader, optimizer, criterion, 20)

# Inference
# Inference: show (original | corrupted | reconstruction) side by side.
dataloader = data.DataLoader(ReConstructionDataset(r"G:\\datasets\\lbxx"), 64, shuffle=True)
unet = ReConstructionNetwork().to(device)
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
unet.load_state_dict(torch.load('../outputs/reconstruction.pth', map_location=device))
# eval() makes BatchNorm use its running statistics; without it the
# network normalizes with per-batch statistics and results are unstable.
unet.eval()
for masked_images, images in dataloader:
    masked_images, images = masked_images.to(device), images.to(device)
    with torch.no_grad():
        outputs = unet(masked_images)
    # Concatenate along width: original | masked | reconstructed.
    outputs = torch.concatenate([images, masked_images, outputs], dim=-1)
    outputs = make_grid(outputs)
    img = outputs.cpu().numpy().transpose(1, 2, 0)
    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    Image.fromarray(img).show()

# Results show the original image, the mosaic-corrupted version, and the
# network's reconstruction; the approach can also be applied to aged or
# ink-stained photos.
For more robust performance, refer to the CodeFormer project.
Rare Earth Juejin Tech Community
Juejin, a tech community that helps developers grow.
How this landed with the community
Was this worth your time?
0 Comments
Thoughtful readers leave field notes, pushback, and hard-won operational detail here.