Image Mosaic Removal Using Autoencoder and UNet in PyTorch
This article explains the principle behind using deep‑learning autoencoders and UNet architectures to reconstruct mosaicked images, provides a complete PyTorch implementation with dataset preparation, network definition, training, and inference, and demonstrates the restored results.
1. Introduction
People often think mosaic (pixelation) cannot be reversed, but deep learning makes it possible; this article explains the principle and demonstrates a PyTorch implementation that learns to reconstruct images from mosaicked inputs.
2. Principle
2.1 Autoencoder
Autoencoders are self‑supervised models that compress an image into a latent vector and reconstruct it, using a simple encoder‑decoder architecture.
2.2 Using Autoencoder for Mosaic Removal
By training the network with mosaicked images as input and original images as target, the model learns to generate the missing content.
2.3 Limitations of Plain Autoencoders
Standard autoencoders lose fine details because everything must pass through the compressed bottleneck; adding skip connections that carry encoder feature maps directly to the decoder — the UNet architecture — preserves more spatial information.
2.4 UNet Architecture
UNet adds skip connections between encoder and decoder layers, improving reconstruction quality.
2.4.1 Convolution Block
class ConvBlock(nn.Module):
    """Double 3x3 same-padding convolution, each conv followed by BatchNorm and ReLU.

    Spatial size is preserved (stride 1, padding 1); only the channel count
    changes, from `in_channels` to `out_channels`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        # Attribute name `model` is kept so saved state_dicts remain loadable.
        self.model = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.model(inputs)
# 2.4.2 Downsampling Block
class ConvDown(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution (+ BN + ReLU).

    The channel count is unchanged; only height and width shrink by 2x.
    """

    def __init__(self, channels):
        super().__init__()
        stages = [
            nn.Conv2d(channels, channels, 3, 2, 1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        ]
        # Attribute name `model` is kept so saved state_dicts remain loadable.
        self.model = nn.Sequential(*stages)

    def forward(self, inputs):
        return self.model(inputs)
# 2.4.3 Upsampling Block
class ConvUp(nn.Module):
    """Double spatial resolution with a 2x2 stride-2 transposed convolution.

    Also halves the channel count (`channels` -> `channels // 2`), matching the
    concat-with-skip pattern of the UNet decoder.
    """

    def __init__(self, channels):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(channels, channels // 2, 2, 2),
            nn.BatchNorm2d(channels // 2),
            nn.ReLU(),
        ]
        # Attribute name `model` is kept so saved state_dicts remain loadable.
        self.model = nn.Sequential(*stages)

    def forward(self, inputs):
        return self.model(inputs)
# 3. Full Implementation
3.1 Dataset
class ReConstructionDataset(data.Dataset):
    """Pairs of (degraded image, original image) for training a restoration model.

    Walks `data_dir` recursively, keeps files with common image extensions, and
    for each item returns a locally Gaussian-blurred copy alongside the original,
    both resized / center-cropped to `image_size` and converted to tensors.
    """

    # Only these files are collected; a bare os.walk would also pick up stray
    # non-image files and crash Image.open at __getitem__ time.
    IMAGE_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

    def __init__(self, data_dir=r"G:/datasets/lbxx", image_size=64):
        self.image_size = image_size
        self.trans = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])
        self.image_paths = []
        for root, dirs, files in os.walk(data_dir):
            for file in files:
                if file.lower().endswith(self.IMAGE_EXTS):
                    self.image_paths.append(os.path.join(root, file))

    def __getitem__(self, item):
        # Force RGB: grayscale/RGBA sources would otherwise yield tensors whose
        # channel count does not match the 3-channel network input.
        image = Image.open(self.image_paths[item]).convert("RGB")
        return self.trans(self.create_blur(image)), self.trans(image)

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def create_blur(image, return_mask=False, box_size=200):
        """Blur a random `box_size` x `box_size` square region of `image`.

        Returns the degraded image; when `return_mask` is True, also returns the
        mask (255 outside the degraded square, 0 inside it).
        """
        # Clamp so images smaller than box_size do not make randint's upper
        # bound negative (which would raise ValueError).
        box_size = min(box_size, image.size[0], image.size[1])
        mask = Image.new('L', image.size, 255)
        draw = ImageDraw.Draw(mask)
        upper_left_corner = (random.randint(0, image.size[0] - box_size),
                             random.randint(0, image.size[1] - box_size))
        lower_right_corner = (upper_left_corner[0] + box_size,
                              upper_left_corner[1] + box_size)
        # BUG FIX: rectangle coordinates must be [upper-left, lower-right];
        # the reversed order raises ValueError on current Pillow releases.
        draw.rectangle([upper_left_corner, lower_right_corner], fill=0)
        masked_image = Image.composite(image, image.filter(ImageFilter.GaussianBlur(15)), mask)
        return (masked_image, mask) if return_mask else masked_image
# 3.2 Network Construction
class UNetEncoder(nn.Module):
    """UNet contracting path: five ConvBlocks with stride-2 downsampling between them.

    Returns the five feature maps (64..1024 channels) so the decoder can fuse
    them back in via skip connections.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept exactly as before so
        # saved state_dicts remain loadable.
        self.blk0 = ConvBlock(3, 64)
        self.down0 = ConvDown(64)
        self.blk1 = ConvBlock(64, 128)
        self.down1 = ConvDown(128)
        self.blk2 = ConvBlock(128, 256)
        self.down2 = ConvDown(256)
        self.blk3 = ConvBlock(256, 512)
        self.down3 = ConvDown(512)
        self.blk4 = ConvBlock(512, 1024)

    def forward(self, inputs):
        features = []
        x = inputs
        for stage in range(4):
            x = getattr(self, f"blk{stage}")(x)
            features.append(x)
            x = getattr(self, f"down{stage}")(x)
        features.append(self.blk4(x))
        return tuple(features)


class UNetDecoder(nn.Module):
    """UNet expanding path: upsample, concat the matching encoder feature, refine."""

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept exactly as before so
        # saved state_dicts remain loadable.
        self.up3 = ConvUp(1024)
        self.blk3 = ConvBlock(1024, 512)
        self.up2 = ConvUp(512)
        self.blk2 = ConvBlock(512, 256)
        self.up1 = ConvUp(256)
        self.blk1 = ConvBlock(256, 128)
        self.up0 = ConvUp(128)
        self.blk0 = ConvBlock(128, 64)
        self.last_conv = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, inputs):
        f0, f1, f2, f3, f4 = inputs
        x = f4
        # Walk back up the pyramid, fusing each skip connection by channel concat.
        for stage, skip in ((3, f3), (2, f2), (1, f1), (0, f0)):
            x = getattr(self, f"up{stage}")(x)
            x = getattr(self, f"blk{stage}")(torch.concat((skip, x), dim=1))
        # NOTE(review): tanh maps to [-1, 1] while ToTensor targets lie in
        # [0, 1]; a sigmoid output may match the training targets better — confirm.
        return torch.tanh(self.last_conv(x))


class ReConstructionNetwork(nn.Module):
    """Full encoder-decoder restoration network: degraded image in, restored image out."""

    def __init__(self):
        super().__init__()
        self.encoder = UNetEncoder()
        self.decoder = UNetDecoder()

    def forward(self, inputs):
        return self.decoder(self.encoder(inputs))
# 3.3 Training
device = "cuda" if torch.cuda.is_available() else "cpu"


def train(model, dataloader, optimizer, criterion, epochs):
    """Train `model` on (degraded, original) pairs from `dataloader`.

    Logs the loss and shows a preview grid every 100 steps, then saves the
    final weights to ../outputs/reconstruction.pth.
    """
    model = model.to(device)
    model.train()
    for epoch in range(epochs):
        # Renamed the loop variable: `iter` shadowed the builtin.
        for step, (masked_images, images) in enumerate(dataloader):
            masked_images, images = masked_images.to(device), images.to(device)
            outputs = model(masked_images)
            loss = criterion(outputs, images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # BUG FIX: `% 100 == 1` logged at steps 1, 101, ...;
            # `% 100 == 0` logs every 100th step as intended.
            if (step + 1) % 100 == 0:
                print("epoch: %s, iter: %s, loss: %s" % (epoch + 1, step + 1, loss.item()))
                with torch.no_grad():
                    grid = make_grid(outputs)
                plt.imshow(grid.cpu().numpy().transpose(1, 2, 0))
                plt.show()
    # Create the output directory up front so torch.save cannot fail on a missing path.
    os.makedirs('../outputs', exist_ok=True)
    torch.save(model.state_dict(), '../outputs/reconstruction.pth')


if __name__ == '__main__':
    dataloader = data.DataLoader(ReConstructionDataset(r"G:\datasets\lbxx"), 64)
    unet = ReConstructionNetwork()
    optimizer = optim.Adam(unet.parameters(), lr=0.0002)
    criterion = nn.MSELoss()
    train(unet, dataloader, optimizer, criterion, 20)
# 3.4 Inference
# Inference: show ground truth | degraded input | restoration side by side.
dataloader = data.DataLoader(ReConstructionDataset(r"G:\datasets\lbxx"), 64, shuffle=True)
unet = ReConstructionNetwork().to(device)
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
unet.load_state_dict(torch.load('../outputs/reconstruction.pth', map_location=device))
# BUG FIX: switch to eval mode so BatchNorm uses its running statistics
# instead of per-batch statistics during inference.
unet.eval()
for masked_images, images in dataloader:
    masked_images, images = masked_images.to(device), images.to(device)
    with torch.no_grad():
        outputs = unet(masked_images)
    # Concatenate along the width axis: original, degraded, restored.
    outputs = torch.concatenate([images, masked_images, outputs], dim=-1)
    outputs = make_grid(outputs)
    img = outputs.cpu().numpy().transpose(1, 2, 0)
    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    Image.fromarray(img).show()
# The results show that the model can effectively restore mosaicked regions, and
# the same pipeline can be adapted for other image restoration tasks such as
# aged or ink-stained photos.
Rare Earth Juejin Tech Community
Juejin, a tech community that helps developers grow.
How this landed with the community
Was this worth your time?
0 Comments
Thoughtful readers leave field notes, pushback, and hard-won operational detail here.