EGBAD: Efficient GAN‑Based Anomaly Detection – Theory and Practical Implementation
This article introduces the EGBAD model, an efficient GAN‑based anomaly detection method that replaces AnoGAN's costly latent variable search with an encoder, provides detailed PyTorch code for data loading, model construction, training, and inference, and compares its testing speed with AnoGAN.
Introduction
In the previous post I presented AnoGAN, a GAN used for anomaly detection, and pointed out its major drawback: the need to iteratively search for the latent variable z during testing, which is time‑consuming. This article introduces EGBAD (Efficient GAN‑Based Anomaly Detection), which incorporates an encoder to obtain z directly, dramatically reducing inference time.
EGBAD Principle
EGBAD stands for Efficient GAN-Based Anomaly Detection. Like AnoGAN it uses a generator and a discriminator, but it adds a third network: an encoder. During training the encoder learns to map an input image to a latent vector. At test time the encoder produces the latent vector in a single forward pass, so the generator can reconstruct the image without any iterative optimisation.
The training pipeline resembles AnoGAN's: a DCGAN-style generator and discriminator are trained on normal data, but here the encoder is trained jointly with them to predict the latent code. In the testing phase the encoder replaces the costly optimisation loop.
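To make the speed difference concrete, here is a toy numpy sketch. The "generator" and "encoder" are illustrative linear/tanh stand-ins, not the real models: AnoGAN-style inference must run hundreds of gradient steps on z per test image, while EGBAD-style inference is a single encoder forward pass.

```python
import numpy as np

rng = np.random.default_rng(0)
ZDIM, IMG = 20, 784

# Illustrative stand-ins for trained networks (not the real models)
W_g = rng.standard_normal((IMG, ZDIM)) * 0.05   # "generator" weights
W_e = rng.standard_normal((ZDIM, IMG)) * 0.05   # "encoder" weights

def G(z):
    return np.tanh(W_g @ z)

def E(x):
    return W_e @ x

def anogan_infer(x, steps=500, lr=0.05):
    # AnoGAN: iteratively optimise z to minimise ||G(z) - x||^2
    z = np.zeros(ZDIM)
    for _ in range(steps):
        g = np.tanh(W_g @ z)
        grad = 2 * W_g.T @ ((1 - g**2) * (g - x))  # chain rule through tanh
        z -= lr * grad
    return z

def egbad_infer(x):
    # EGBAD: one forward pass through the encoder
    return E(x)

x = rng.standard_normal(IMG)
z_slow = anogan_infer(x)   # 500 gradient iterations
z_fast = egbad_infer(x)    # a single matrix multiply
```

The per-image cost difference is exactly this: hundreds of generator forward/backward passes versus one encoder forward pass.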
Code Implementation
Data Loading
# Import packages
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
"""mnist data loading"""
## Load training set (60000, 785)
train = pd.read_csv("./data/mnist_train.csv", dtype=np.float32)
## Load test set (10000, 785)
test = pd.read_csv("./data/mnist_test.csv", dtype=np.float32)
# Select digits 7 and 8 for training (first 400 samples)
train = train.query("label in [7.0, 8.0]").head(400)
# Select digits 2, 7, 8 for testing (first 600 samples)
test = test.query("label in [2.0, 7.0, 8.0]").head(600)
# Remove label column and reshape to (N, 28, 28)
train = train.iloc[:, 1:].values.astype('float32')
test = test.iloc[:, 1:].values.astype('float32')
train = train.reshape(train.shape[0], 28, 28)
test = test.reshape(test.shape[0], 28, 28)
# Scale pixel values to [-1, 1] to match the generator's Tanh output
train = train / 255.0 * 2.0 - 1.0
test = test / 255.0 * 2.0 - 1.0
# Note: the conv stacks below assume 64x64 inputs, so the 28x28 images
# are upsampled (e.g. via torchvision transforms) before being fed in
Model Construction
Generator
"""Define generator network"""
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1,
                activation=nn.ReLU(inplace=True), bn=True):
            seq = []
            seq += [nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding)]
            if bn:
                seq += [nn.BatchNorm2d(out_channel)]
            seq += [activation]
            return nn.Sequential(*seq)

        seq = []
        seq += [CBA(20, 64*8, stride=1, padding=0)]           # (N, 512, 4, 4)
        seq += [CBA(64*8, 64*4)]                              # (N, 256, 8, 8)
        seq += [CBA(64*4, 64*2)]                              # (N, 128, 16, 16)
        seq += [CBA(64*2, 64)]                                # (N, 64, 32, 32)
        seq += [CBA(64, 1, activation=nn.Tanh(), bn=False)]   # (N, 1, 64, 64)
        self.generator_network = nn.Sequential(*seq)

    def forward(self, z):
        return self.generator_network(z)
Encoder
"""Define encoder network"""
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()

        def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1,
                activation=nn.LeakyReLU(0.1, inplace=True)):
            seq = []
            seq += [nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding)]
            seq += [nn.BatchNorm2d(out_channel)]
            seq += [activation]
            return nn.Sequential(*seq)

        seq = []
        seq += [CBA(1, 64)]        # (N, 64, 32, 32)
        seq += [CBA(64, 64*2)]     # (N, 128, 16, 16)
        seq += [CBA(64*2, 64*4)]   # (N, 256, 8, 8)
        seq += [CBA(64*4, 64*8)]   # (N, 512, 4, 4)
        seq += [nn.Conv2d(64*8, 512, kernel_size=4, stride=1)]  # (N, 512, 1, 1)
        self.feature_network = nn.Sequential(*seq)
        self.embedding_network = nn.Linear(512, 20)

    def forward(self, x):
        feature = self.feature_network(x).view(-1, 512)
        return self.embedding_network(feature)
Discriminator
"""Define discriminator network"""
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1,
                activation=nn.LeakyReLU(0.1, inplace=True)):
            seq = []
            seq += [nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding)]
            seq += [nn.BatchNorm2d(out_channel)]
            seq += [activation]
            return nn.Sequential(*seq)

        seq = []
        seq += [CBA(1, 64)]        # (N, 64, 32, 32)
        seq += [CBA(64, 64*2)]     # (N, 128, 16, 16)
        seq += [CBA(64*2, 64*4)]   # (N, 256, 8, 8)
        seq += [CBA(64*4, 64*8)]   # (N, 512, 4, 4)
        seq += [nn.Conv2d(64*8, 512, kernel_size=4, stride=1)]  # (N, 512, 1, 1)
        self.feature_network = nn.Sequential(*seq)
        # Project the latent vector to 512 dims so it can be concatenated
        # with the image feature (512 + 512 = 1024)
        self.latent_network = nn.Sequential(
            nn.Linear(20, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.1, inplace=True)
        )
        self.critic_network = nn.Linear(1024, 1)

    def forward(self, x, z):
        feature = self.feature_network(x).view(x.size(0), -1)
        latent = self.latent_network(z)
        out = self.critic_network(torch.cat([feature, latent], dim=1))
        return out, feature
Training Procedure
The training loop performs three steps for each batch: (1) generate fake images from random latent vectors, (2) encode real images to obtain latent vectors, and (3) feed both pairs — (real image, encoded latent) and (fake image, sampled latent) — to the discriminator, then update the discriminator, generator, and encoder in turn. The discriminator and generator use binary cross-entropy losses, and the encoder uses a custom loss.
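The loop below keeps the loss computation abstract, so here is one plausible way to fill it in. This is a BiGAN-style sketch and an assumption on my part, not taken from the original implementation; tiny linear stand-in networks are used so the snippet runs on its own (in the article, G, E, and D are the classes defined above, and BCEWithLogitsLoss fits because the discriminator's critic head outputs raw logits).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, ZDIM, IMG = 8, 20, 784

# Tiny stand-in networks so the update step runs end to end
G = nn.Linear(ZDIM, IMG)      # generator: latent -> image
E = nn.Linear(IMG, ZDIM)      # encoder: image -> latent
D = nn.Linear(IMG + ZDIM, 1)  # joint critic over (image, latent)

opt_d = torch.optim.Adam(D.parameters(), lr=1e-4)
opt_ge = torch.optim.Adam(list(G.parameters()) + list(E.parameters()), lr=1e-4)
bce = nn.BCEWithLogitsLoss()  # D outputs raw logits

x = torch.rand(N, IMG)        # a batch of (flattened) real images
ones, zeros = torch.ones(N, 1), torch.zeros(N, 1)

# (1) fake images from random latents, (2) latents from real images
z = torch.randn(N, ZDIM)
x_fake = G(z)
z_real = E(x)

# (3a) discriminator update: (real image, encoded z) -> 1,
#      (fake image, sampled z) -> 0
d_real = D(torch.cat([x, z_real.detach()], dim=1))
d_fake = D(torch.cat([x_fake.detach(), z], dim=1))
loss_d = bce(d_real, ones) + bce(d_fake, zeros)
opt_d.zero_grad(); loss_d.backward(); opt_d.step()

# (3b) generator/encoder update: swap the labels to fool the discriminator
d_real = D(torch.cat([x, E(x)], dim=1))
d_fake = D(torch.cat([G(z), z], dim=1))
loss_ge = bce(d_fake, ones) + bce(d_real, zeros)
opt_ge.zero_grad(); loss_ge.backward(); opt_ge.step()
```

Detaching `x_fake` and `z_real` during the discriminator step keeps that update from also pushing gradients into G and E.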
# Example training loop (simplified)
for epoch in range(args.epochs):
    for images in train_loader:
        images = images.to(device)
        # Labels
        label_real = torch.full((images.size(0),), 1.0).to(device)
        label_fake = torch.full((images.size(0),), 0.0).to(device)
        # Random latent vector
        z = torch.randn(images.size(0), 20, 1, 1).to(device)
        fake_images = G(z)
        # Encode real images
        z_real = E(images)
        # Discriminator forward
        d_out_real, _ = D(images, z_real)
        d_out_fake, _ = D(fake_images, z.view(images.size(0), 20))
        # Compute losses and back-propagate for D, G, and E
        ...
Anomaly Scoring
The anomaly score combines a residual loss (pixel‑wise difference) and a discrimination loss (feature‑wise difference) with a weighting of 0.9 and 0.1 respectively.
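In symbols (notation mine), writing f_D for the discriminator's intermediate feature map, the score computed below is:

```latex
A(x) = 0.9 \, L_R(x) + 0.1 \, L_D(x), \qquad
L_R(x) = \sum \lvert x - G(E(x)) \rvert, \qquad
L_D(x) = \sum \lvert f_D(x, E(x)) - f_D(G(E(x)), E(x)) \rvert
```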
def anomaly_score(input_image, fake_image, z_real, D):
    # Residual loss: pixel-wise L1 distance between input and reconstruction
    residual_loss = torch.sum(torch.abs(input_image - fake_image), (1, 2, 3))
    # Discrimination loss: L1 distance between discriminator feature maps
    _, real_feature = D(input_image, z_real)
    _, fake_feature = D(fake_image, z_real)
    discrimination_loss = torch.sum(torch.abs(real_feature - fake_feature), 1)
    return 0.9 * residual_loss + 0.1 * discrimination_loss
Results and Discussion
Experimental results show that EGBAD reduces testing time by up to four orders of magnitude compared with AnoGAN. However, like many GANs it can suffer from mode collapse and unstable training. Visual inspection of generated samples reveals that the generator sometimes fails to reconstruct the correct digit, indicating limited generation quality.
Future work will explore techniques to improve GAN training stability.
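One widely used stabiliser, for instance, is spectral normalisation, which PyTorch provides out of the box; applying it to the discriminator's convolutions would be one option (a sketch, not part of the original code):

```python
import torch
import torch.nn as nn

# Wrap a convolution with spectral normalisation to constrain its
# Lipschitz constant, a common way to stabilise GAN discriminators.
layer = nn.utils.spectral_norm(
    nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1))

x = torch.randn(2, 1, 64, 64)
out = layer(x)  # (2, 64, 32, 32)
```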
Reference Links
EFFICIENT GAN-BASED ANOMALY DETECTION (the EGBAD paper)
GAN anomaly detection with PyTorch