
EGBAD: Efficient GAN‑Based Anomaly Detection – Theory and Practical Implementation

This article introduces the EGBAD model, an efficient GAN‑based anomaly detection method that replaces AnoGAN's costly latent variable search with an encoder, provides detailed PyTorch code for data loading, model construction, training, and inference, and compares its testing speed with AnoGAN.


Introduction

In the previous post I presented AnoGAN, a GAN used for anomaly detection, and pointed out its major drawback: the need to iteratively search for the latent variable z during testing, which is time‑consuming. This article introduces EGBAD (Efficient GAN‑Based Anomaly Detection), which incorporates an encoder to obtain z directly, dramatically reducing inference time.

EGBAD Principle

EGBAD stands for Efficient GAN-Based Anomaly Detection. Like AnoGAN it uses a generator and a discriminator, but it adds an encoder. During training the encoder learns to map an input image to a latent vector. At test time the encoder produces the latent vector in a single forward pass, so the generator can reconstruct the input without any iterative optimisation over z.

The training pipeline differs from AnoGAN in one important way: following the BiGAN setup, the discriminator receives image–latent pairs (x, z) rather than images alone, so the DCGAN-style generator and discriminator are trained jointly with the encoder on normal data. In the testing phase a single encoder forward pass replaces the costly optimisation loop.

Code Implementation

Data Loading

# Import packages
import numpy as np
import pandas as pd

"""mnist data loading"""
## Load training set (60000, 785)
train = pd.read_csv("./data/mnist_train.csv", dtype=np.float32)
## Load test set (10000, 785)
test = pd.read_csv("./data/mnist_test.csv", dtype=np.float32)

# Select digits 7 and 8 for training (first 400 samples)
train = train.query("label in [7.0, 8.0]").head(400)
# Select digits 2, 7, 8 for testing (first 600 samples)
test = test.query("label in [2.0, 7.0, 8.0]").head(600)

# Remove label column and reshape to (N, 28, 28)
train = train.iloc[:, 1:].values.astype('float32')
test = test.iloc[:, 1:].values.astype('float32')
train = train.reshape(train.shape[0], 28, 28)
test = test.reshape(test.shape[0], 28, 28)
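The article does not show how the NumPy arrays become batched tensors. A minimal sketch, assuming pixels are scaled from [0, 255] to [-1, 1] (to match the generator's Tanh output) and images are resized to 64×64 (the spatial size the four stride-2 layers in the networks below expect):

```python
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

def make_loader(images, batch_size=16):
    # (N, 28, 28) float array -> (N, 1, 28, 28) tensor
    x = torch.from_numpy(images).float().unsqueeze(1)
    # Resize to 64x64 to match the DCGAN-style architectures below
    x = F.interpolate(x, size=(64, 64), mode="bilinear", align_corners=False)
    # Scale [0, 255] -> [-1, 1] to match the generator's Tanh output
    x = x / 127.5 - 1.0
    return DataLoader(TensorDataset(x), batch_size=batch_size, shuffle=True)

# Random data standing in for the MNIST arrays loaded above
dummy = np.random.randint(0, 256, size=(32, 28, 28)).astype("float32")
loader = make_loader(dummy)
batch, = next(iter(loader))
print(batch.shape)   # torch.Size([16, 1, 64, 64])
```

Note that `TensorDataset` yields one-element tuples, hence the `batch, =` unpacking; the simplified training loop below iterates over batches directly.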

Model Construction

Generator

"""Define generator network"""
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1,
                activation=nn.ReLU(inplace=True), bn=True):
            seq = []
            seq += [nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding)]
            if bn:
                seq += [nn.BatchNorm2d(out_channel)]
            seq += [activation]
            return nn.Sequential(*seq)
        seq = []
        seq += [CBA(20, 64*8, stride=1, padding=0)]
        seq += [CBA(64*8, 64*4)]
        seq += [CBA(64*4, 64*2)]
        seq += [CBA(64*2, 64)]
        seq += [CBA(64, 1, activation=nn.Tanh(), bn=False)]
        self.generator_network = nn.Sequential(*seq)
    def forward(self, z):
        return self.generator_network(z)
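Each `ConvTranspose2d` with kernel 4, stride 2, padding 1 doubles the spatial size. A quick sanity check with the transposed-convolution formula (plain Python, no PyTorch needed) shows the generator maps a (20, 1, 1) latent vector to a 64×64 image, which implies the 28×28 MNIST digits are resized before training:

```python
# Transposed-convolution output size: out = (in - 1)*stride - 2*padding + kernel
def deconv_out(size, kernel=4, stride=2, padding=1):
    return (size - 1) * stride - 2 * padding + kernel

size = 1                                        # latent z is (20, 1, 1)
size = deconv_out(size, stride=1, padding=0)    # first CBA: 1 -> 4
for _ in range(4):                              # four stride-2 CBAs
    size = deconv_out(size)                     # 4 -> 8 -> 16 -> 32 -> 64
print(size)   # 64
```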

Encoder

"""Define encoder network"""
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1,
                activation=nn.LeakyReLU(0.1, inplace=True)):
            seq = []
            seq += [nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding)]
            seq += [nn.BatchNorm2d(out_channel)]
            seq += [activation]
            return nn.Sequential(*seq)
        seq = []
        seq += [CBA(1, 64)]
        seq += [CBA(64, 64*2)]
        seq += [CBA(64*2, 64*4)]
        seq += [CBA(64*4, 64*8)]
        seq += [nn.Conv2d(64*8, 512, kernel_size=4, stride=1)]
        self.feature_network = nn.Sequential(*seq)
        self.embedding_network = nn.Linear(512, 20)
    def forward(self, x):
        feature = self.feature_network(x).view(-1, 512)
        return self.embedding_network(feature)
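The encoder mirrors the generator: four stride-2 convolutions shrink a 64×64 input to 4×4, and the final kernel-4 convolution collapses it to 1×1, which is why the `.view(-1, 512)` reshape is valid. The same size arithmetic, checked in plain Python:

```python
# Convolution output size: out = floor((in + 2*padding - kernel) / stride) + 1
def conv_out(size, kernel=4, stride=2, padding=1):
    return (size + 2 * padding - kernel) // stride + 1

size = 64                                       # input image (assumed 64x64)
for _ in range(4):                              # four stride-2 CBAs
    size = conv_out(size)                       # 64 -> 32 -> 16 -> 8 -> 4
size = conv_out(size, stride=1, padding=0)      # final Conv2d: 4 -> 1
print(size)   # 1
```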

Discriminator

"""Define discriminator network"""
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        def CBA(in_channel, out_channel, kernel_size=4, stride=2, padding=1,
                activation=nn.LeakyReLU(0.1, inplace=True)):
            seq = []
            seq += [nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding)]
            seq += [nn.BatchNorm2d(out_channel)]
            seq += [activation]
            return nn.Sequential(*seq)
        seq = []
        seq += [CBA(1, 64)]
        seq += [CBA(64, 64*2)]
        seq += [CBA(64*2, 64*4)]
        seq += [CBA(64*4, 64*8)]
        seq += [nn.Conv2d(64*8, 512, kernel_size=4, stride=1)]
        self.feature_network = nn.Sequential(*seq)
        self.latent_network = nn.Sequential(
            nn.Linear(20, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.1, inplace=True)
        )
        self.critic_network = nn.Linear(1024, 1)
    def forward(self, x, z):
        feature = self.feature_network(x).view(x.size(0), -1)
        latent = self.latent_network(z)
        out = self.critic_network(torch.cat([feature, latent], dim=1))
        return out, feature

Training Procedure

The training loop follows three steps for each batch: (1) generate fake images from random latent vectors, (2) encode real images to obtain latent vectors, (3) feed both real pairs (x, E(x)) and fake pairs (G(z), z) to the discriminator, then update the discriminator, generator, and encoder in turn. The discriminator and generator use binary cross-entropy losses; the encoder is trained with an adversarial loss of its own.

# Example training loop (simplified)
for epoch in range(args.epochs):
    for images in train_loader:
        images = images.to(device)
        # Labels
        label_real = torch.full((images.size(0),), 1.0).to(device)
        label_fake = torch.full((images.size(0),), 0.0).to(device)
        # Random latent vector
        z = torch.randn(images.size(0), 20, 1, 1).to(device)
        fake_images = G(z)
        # Encode real images
        z_real = E(images)
        # Discriminator forward
        d_out_real, _ = D(images, z_real)
        d_out_fake, _ = D(fake_images, z.view(images.size(0), 20))
        # Compute losses and back‑propagate for D, G, and E
        ...
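The loss step elided above can be sketched as follows. This is an assumption about the common BiGAN-style formulation, not the article's exact code: the discriminator outputs raw logits, so `BCEWithLogitsLoss` is used; D pushes real pairs toward 1 and fake pairs toward 0, G tries to make fake pairs look real, and E tries to make real pairs look fake.

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
n = 8
label_real = torch.full((n,), 1.0)
label_fake = torch.full((n,), 0.0)

# Random logits standing in for d_out_real / d_out_fake from the loop above
d_out_real = torch.randn(n, 1)
d_out_fake = torch.randn(n, 1)

# Discriminator: real pairs (x, E(x)) -> 1, fake pairs (G(z), z) -> 0
d_loss = criterion(d_out_real.view(-1), label_real) + \
         criterion(d_out_fake.view(-1), label_fake)

# Generator: fool D into calling fake pairs real
g_loss = criterion(d_out_fake.view(-1), label_real)

# Encoder: fool D into calling real pairs fake (the encoder's adversarial loss)
e_loss = criterion(d_out_real.view(-1), label_fake)
```

In the actual loop each loss would be followed by the corresponding optimizer's `zero_grad()`, `backward()`, and `step()` calls.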

Anomaly Scoring

The anomaly score combines a residual loss (pixel‑wise difference) and a discrimination loss (feature‑wise difference) with a weighting of 0.9 and 0.1 respectively.

def anomaly_score(input_image, fake_image, z_real, D):
    residual_loss = torch.sum(torch.abs(input_image - fake_image), (1,2,3))
    _, real_feature = D(input_image, z_real)
    _, fake_feature = D(fake_image, z_real)
    discrimination_loss = torch.sum(torch.abs(real_feature - fake_feature), 1)
    return 0.9 * residual_loss + 0.1 * discrimination_loss
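At test time the score for each sample is `anomaly_score(x, G(E(x)), E(x), D)`, and samples above a threshold are flagged. A hedged usage sketch with made-up score values standing in for real outputs:

```python
import numpy as np

# Made-up scores for five test samples; 83.7 plays the anomaly
scores = np.array([12.3, 15.1, 14.2, 83.7, 13.5])
# In practice the threshold is chosen from scores on held-out normal data
threshold = 50.0
anomalies = scores > threshold
print(anomalies.tolist())   # [False, False, False, True, False]
```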

Results and Discussion

Experimental results show that EGBAD reduces testing time by up to four orders of magnitude compared with AnoGAN, but it can suffer from mode collapse and unstable training, a common issue for GANs. Visual inspection of generated samples reveals that the generator sometimes fails to reconstruct the correct digit, indicating limited generation quality.

Future work will explore techniques to improve GAN training stability.

Reference Links

EFFICIENT GAN‑BASED ANOMALY DETECTION
GAN anomaly detection with PyTorch
Written by

Rare Earth Juejin Tech Community

Juejin, a tech community that helps developers grow.
