Decomposing PointGAN: Teaching a Machine to Generate a Single Point

This article walks through building and analyzing a minimal GAN—PointGAN—that learns to output the single value 1, covering the linear generator, a two‑layer discriminator, training loops, loss visualizations, instability diagnostics, and practical fixes such as loss easing, weighted examples, weight decay, and noisy generator parameters.


Introduction

The goal is to construct a very simple GAN whose task is to generate a single numeric point (the value 1). The generator is a single‑neuron linear model, and the discriminator must learn a bell‑shaped probability function.

[Figure: Discriminator shape]

1. Baseline Training

1.1 Common Terms

The generator maps a random input x in [-1, 1] to an output y via a linear function y = w * x + b. Solving for the constant target output y = 1 gives w = 0 and b = 1, a horizontal line that outputs 1 regardless of x:

w = 0
b = 1
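
As a quick sanity check (a minimal sketch, not from the original article), setting these parameters on a one-neuron linear layer reproduces the constant output:

import torch

# One-neuron linear generator with the solved parameters w = 0, b = 1.
gen = torch.nn.Linear(1, 1)
with torch.no_grad():
    gen.weight.fill_(0.0)
    gen.bias.fill_(1.0)

x = torch.rand(5, 1) * 2 - 1  # random inputs in [-1, 1]
print(gen(x))                 # every output is 1.0, regardless of x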

The discriminator must output the probability that its input is real. A simple two‑hidden‑node MLP is used:

l1 = torch.nn.Linear(1, 2)   # 1 input -> 2 hidden units
l2 = torch.nn.Linear(2, 1)   # 2 hidden units -> 1 output

x = l1(x)
x = torch.tanh(x)            # bounded hidden activations
x = l2(x)
x = torch.sigmoid(x)         # probability that the input is real

Running a forward pass on a range of inputs visualizes both functions:
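
A minimal plotting sketch of such a pass might look as follows (Generator and Discriminator are the classes defined in section 1.3; the grid range matches the one used in the final script):

import numpy as np
import matplotlib.pyplot as plt
import torch

generator, discriminator = Generator(), Discriminator()   # classes from section 1.3

# Evaluate both networks on a dense grid of inputs.
src = torch.tensor(np.linspace(-3, 5, 100)).float().reshape(-1, 1)
with torch.no_grad():
    gen_curve = generator(src).reshape(-1).numpy()
    dis_curve = discriminator(src).reshape(-1).numpy()

plt.plot(src.reshape(-1).numpy(), gen_curve, label='generator')
plt.plot(src.reshape(-1).numpy(), dis_curve, label='discriminator')
plt.legend()
plt.show()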

[Figure: Generator and discriminator functions]

1.2 GAN Training

Training alternates between updating the generator and the discriminator. Random inputs are drawn from torch.rand([BATCH_SIZE, 1]) * 2 - 1. For the generator update, the generated output is fed to the discriminator and binary cross-entropy loss is computed against a target of all ones, since the generator wants its samples judged real; for the discriminator update, generated samples are labeled 0 and real samples 1.

# Generator update: fake samples should be judged real (target 1).
generated = generator(random_input)
gen_prediction = discriminator(generated)
gen_target = torch.ones(BATCH_SIZE, 1)
gen_loss = x_entropy(gen_prediction, gen_target)

# Discriminator update: fakes labeled 0, real samples (all ones) labeled 1.
real_prediction = discriminator(torch.ones_like(random_input))

dis_output = torch.cat([gen_prediction, real_prediction])
dis_target = torch.cat([torch.zeros(BATCH_SIZE, 1), torch.ones(BATCH_SIZE, 1)])
dis_loss = x_entropy(dis_output, dis_target)

During early epochs the discriminator outputs an average of ~0.48 for fake samples and ~0.56 for real samples, prompting both networks to adjust.

[Figure: Initial generator and discriminator]

1.3 Baseline Code

Key model definitions in PyTorch:

import torch

class Generator(torch.nn.Module):
    """Single-neuron linear model: y = w * x + b."""
    def __init__(self):
        super(Generator, self).__init__()
        self.l1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.l1(x)

class Discriminator(torch.nn.Module):
    """Two-hidden-node MLP outputting the probability that x is real."""
    def __init__(self):
        super(Discriminator, self).__init__()
        self.l1 = torch.nn.Linear(1, 2)
        self.l2 = torch.nn.Linear(2, 1)

    def forward(self, x):
        x = self.l1(x)
        x = torch.tanh(x)
        x = self.l2(x)
        return torch.sigmoid(x)
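
The baseline loop itself is only shown in full in section 4; a minimal sketch of how these definitions are wired up, assuming the learning-rate pair (0.01, 0.001) mentioned at the end of the article, could be:

generator, discriminator = Generator(), Discriminator()
gen_opt = torch.optim.Adam(generator.parameters(), 0.01)       # assumed: generator rate
dis_opt = torch.optim.Adam(discriminator.parameters(), 0.001)  # assumed: discriminator rate
x_entropy = torch.nn.BCELoss()
BATCH_SIZE = 5

for epoch in range(15001):
    gen_opt.zero_grad(); dis_opt.zero_grad()
    random_input = torch.rand([BATCH_SIZE, 1]) * 2 - 1
    if epoch % 2 == 0:
        # Generator step: fakes should be classified as real.
        gen_loss = x_entropy(discriminator(generator(random_input)),
                             torch.ones(BATCH_SIZE, 1))
        gen_loss.backward(); gen_opt.step()
    else:
        # Discriminator step: fakes labeled 0, real samples labeled 1.
        fake = discriminator(generator(random_input).detach())  # don't update the generator here
        real = discriminator(torch.ones(BATCH_SIZE, 1))
        dis_loss = x_entropy(torch.cat([fake, real]),
                             torch.cat([torch.zeros(BATCH_SIZE, 1),
                                        torch.ones(BATCH_SIZE, 1)]))
        dis_loss.backward(); dis_opt.step()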

2. Baseline Issues

2.1 Generator Works, Discriminator Fails

Sometimes the generator learns to output values close to 1, but the discriminator never becomes steep enough around the target. Training then stalls: the generated samples already look realistic to the discriminator, so neither network receives a useful gradient.

2.2 Loss Function Not Effective

Experiments show that a better-shaped discriminator can actually produce a higher loss, because its predictions for near-real generated values differ only slightly from the 0.5 baseline while the loss still labels those samples as fake.
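
A toy calculation (my own numbers, purely illustrative) shows the effect: a discriminator that correctly rates a near-real fake as looking real pays more under binary cross-entropy than a flat one stuck at 0.5.

import torch

x_entropy = torch.nn.BCELoss()
targets = torch.tensor([0.0, 1.0])   # one fake sample, one real sample

# Flat discriminator: always answers 0.5, paying the baseline -ln(0.5) ≈ 0.693.
flat_loss = x_entropy(torch.tensor([0.5, 0.5]), targets)

# Better-shaped discriminator: rates a near-real fake 0.7 and a real sample
# 0.75 — yet its loss is higher, because the fake is still labeled 0.
shaped_loss = x_entropy(torch.tensor([0.7, 0.75]), targets)

print(float(flat_loss), float(shaped_loss))   # ≈ 0.693 vs ≈ 0.746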

[Figure: Loss values for different discriminators]

2.3 Experimental Adjustments

To soften the binary cross‑entropy targets, the labels are moved from {0, 1} to {0.1, 0.9}:

easing = 0.1  # moves the labels from {0, 1} to {0.1, 0.9}
dis_target = torch.cat([
    torch.zeros(BATCH_SIZE, 1) + easing,   # fake labels: 0 -> 0.1
    torch.ones(BATCH_SIZE, 1) - easing     # real labels: 1 -> 0.9
])

Training with eased targets yields smoother loss curves but does not fully resolve instability.

3. Remedies

3.1 Weight Decay

Adding L2 regularization to the discriminator prevents it from learning an overly sharp function that stalls generator updates.

WEIGHT_DECAY = 1e-3   # one of the tested values: 0.1, 1e-3, 1e-5

dis_opt_obj = torch.optim.Adam(
    dis.parameters(),
    dis_lr,
    weight_decay=WEIGHT_DECAY   # L2 regularization on the discriminator
)

Sweeping the decay over the values 0.1, 1e-3, and 1e-5 shows how stronger regularization produces a smoother discriminator curve.
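
A hypothetical sweep over those values, reusing the train() function and models shown in section 4, might look like:

# Hypothetical sweep; train(), Generator and Discriminator as defined below.
for weight_decay in (0.1, 1e-3, 1e-5):
    gen, dis = Generator(), Discriminator()
    gen_opt = torch.optim.Adam(gen.parameters(), 0.01)
    dis_opt = torch.optim.Adam(dis.parameters(), 0.001, weight_decay=weight_decay)
    train(dis, gen, dis_opt, gen_opt, name=f'wd_{weight_decay}')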

[Figure: Training statistics with weight decay]

3.2 Noisy Generator Parameters

During discriminator updates, a cloned generator with added Gaussian noise (scale 0.5) produces slightly invalid samples, forcing the discriminator to keep learning:

NOISE_SCALE = 0.5

# Clone the generator and perturb its parameters with Gaussian noise so the
# discriminator keeps seeing slightly invalid samples.
gen_clone = Generator()
gen_clone.l1.weight.data = generator.l1.weight.data + torch.randn_like(generator.l1.weight.data) * NOISE_SCALE
gen_clone.l1.bias.data   = generator.l1.bias.data   + torch.randn_like(generator.l1.bias.data) * NOISE_SCALE

generated = gen_clone(random_input)
real_values = torch.ones_like(generated)

gen_prediction = discriminator(generated)
real_prediction = discriminator(real_values)

# Noisy fakes are labeled 0, real samples 1, as before.
dis_target = torch.cat([torch.zeros(BATCH_SIZE, 1), torch.ones(BATCH_SIZE, 1)])
dis_output = torch.cat([gen_prediction, real_prediction])

dis_loss = x_entropy(dis_output, dis_target)

Because gradients from multiple noisy passes are accumulated before the optimizer step, training becomes more stable.
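
This relies on the PyTorch mechanic that backward() sums gradients into .grad until zero_grad() is called; a minimal standalone demonstration:

import torch

lin = torch.nn.Linear(1, 1)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

opt.zero_grad()
for _ in range(3):
    loss = lin(torch.randn(4, 1)).pow(2).mean()
    loss.backward()        # each pass adds its gradient into .grad
print(lin.weight.grad)     # sum of the three per-pass gradients
opt.step()                 # one update over the accumulated gradient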

[Figure: Generator and discriminator evolution with noise]

3.3 Increased Noise Scale

Raising NOISE_SCALE to 1.5 improves the aggregate training statistics, but some runs still fail; combining the higher noise with weight decay yields the most consistent results.

[Figure: Training with higher noise]

4. Full Updated Training Script

The final script integrates the noisy‑generator trick, weight decay, and logging via Weights & Biases. It alternates updates every epoch, aggregates gradients over an internal batch of noisy clones, and saves animated visualizations of the evolving functions.

import os, torch, numpy as np
import matplotlib.pyplot as plt, matplotlib.animation as animation
import wandb as wnb
from models import Generator, Discriminator   # the classes from section 1.3

BATCH_SIZE = 5        # samples per generator/discriminator pass
INTERNAL_BATCH = 30   # noisy clones per discriminator update
NOISE_SCALE = 1.5
model_dir = 'models_noisy_bswd'
anim_dir = 'output_noisy_bswd'

def train(dis, gen, dis_opt, gen_opt, name):
    # Assumes wnb.init(...) has already been called by the caller.
    x_entropy = torch.nn.BCELoss()
    dis_outputs, gen_outputs = [], []
    # Fixed grid on which both functions are snapshotted for the animation.
    src = torch.tensor(np.linspace(-3, 5, 100)).float().reshape([-1, 1])
    for epoch in range(15001):
        dis_opt.zero_grad(); gen_opt.zero_grad()
        if epoch % 2 == 0:
            # Generator update: push fake samples toward the "real" label.
            input = torch.rand([BATCH_SIZE, 1]) * 2 - 1
            output = gen(input)
            pred = dis(output)
            gen_target = torch.ones(BATCH_SIZE, 1)
            gen_loss = x_entropy(pred, gen_target)
            gen_loss.backward(); gen_opt.step()
        else:
            # Discriminator update: accumulate gradients over noisy generator clones.
            for _ in range(INTERNAL_BATCH):
                gen_clone = Generator()
                gen_clone.l1.weight.data = gen.l1.weight.data + torch.randn_like(gen.l1.weight.data) * NOISE_SCALE
                gen_clone.l1.bias.data   = gen.l1.bias.data   + torch.randn_like(gen.l1.bias.data) * NOISE_SCALE
                input = torch.rand([BATCH_SIZE, 1]) * 2 - 1
                out = gen_clone(input)
                pred = dis(out)
                dis_output = torch.cat([pred, dis(torch.ones_like(input))])
                dis_target = torch.cat([torch.zeros(BATCH_SIZE, 1), torch.ones(BATCH_SIZE, 1)])
                dis_loss = x_entropy(dis_output, dis_target)
                dis_loss.backward()   # gradients sum across the internal batch
            dis_opt.step()            # one step over the accumulated gradients
        if epoch % 500 == 1:
            wnb.log({
                'dis_loss': float(dis_loss),
                'gen_loss': float(gen_loss),
                'bias': float(gen.l1.bias.data),
                'weight': float(gen.l1.weight.data)
            })
        if epoch % 250 == 1:
            # Snapshot both functions for the animated visualization.
            gen_outputs.append(gen(src).cpu().detach().reshape(-1).numpy())
            dis_outputs.append(dis(src).cpu().detach().reshape(-1).numpy())
    # animation creation omitted for brevity

Running the script with the learning-rate pair (0.01, 0.001) and seeds {22, 16, 48, 19, 73} reproduces the visualizations shown throughout the article.
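
A hypothetical driver for those runs (which learning rate belongs to which network, the weight-decay value, and the wandb project name are my assumptions, not stated in the article):

WEIGHT_DECAY = 1e-3   # assumed; the article tests 0.1, 1e-3, 1e-5

for seed in (22, 16, 48, 19, 73):
    torch.manual_seed(seed)
    wnb.init(project='pointgan', name=f'seed_{seed}', reinit=True)  # hypothetical project name
    gen, dis = Generator(), Discriminator()
    gen_opt = torch.optim.Adam(gen.parameters(), 0.01)
    dis_opt = torch.optim.Adam(dis.parameters(), 0.001, weight_decay=WEIGHT_DECAY)
    train(dis, gen, dis_opt, gen_opt, name=f'seed_{seed}')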
