Decomposing PointGAN: Teaching a Machine to Generate a Single Point
This article walks through building and analyzing a minimal GAN—PointGAN—that learns to output the single value 1, covering the linear generator, a two‑layer discriminator, training loops, loss visualizations, instability diagnostics, and practical fixes such as loss easing, weighted examples, weight decay, and noisy generator parameters.
Introduction
The goal is to construct a very simple GAN whose task is to generate a single numeric point (the value 1). The generator is a single‑neuron linear model, and the discriminator must learn a bell‑shaped probability function.
1. Baseline Training
1.1 Common Terms
The generator maps a random input x in [-1, 1] to an output y via a linear function: y = w * x + b. Solving for the constant target output y = 1 gives w = 0 and b = 1, a horizontal line that outputs 1 regardless of x:

w = 0
b = 1

The discriminator must output the probability that its input is real. A simple MLP with two hidden nodes is used:
l1 = torch.nn.Linear(1, 2)
l2 = torch.nn.Linear(2, 1)
x = l1(x)
x = torch.tanh(x)
x = l2(x)
x = torch.sigmoid(x)

Running a forward pass on a range of inputs visualizes both functions:
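A minimal sketch of such a forward pass is shown below; it pins the generator to the closed-form solution w = 0, b = 1, and the plotting details are reconstructions rather than the article's original code:

import torch
import matplotlib.pyplot as plt

# Untrained discriminator: 1 -> 2 -> 1 MLP with tanh and sigmoid
l1 = torch.nn.Linear(1, 2)
l2 = torch.nn.Linear(2, 1)

# Generator pinned to the closed-form solution y = 0 * x + 1
gen = torch.nn.Linear(1, 1)
with torch.no_grad():
    gen.weight.fill_(0.0)
    gen.bias.fill_(1.0)

    xs = torch.linspace(-3, 5, 100).reshape(-1, 1)
    gen_out = gen(xs)                                # constant 1 everywhere
    dis_out = torch.sigmoid(l2(torch.tanh(l1(xs))))  # untrained S-shaped curve

plt.plot(xs.squeeze().numpy(), gen_out.squeeze().numpy(), label='generator')
plt.plot(xs.squeeze().numpy(), dis_out.squeeze().numpy(), label='discriminator')
plt.legend()
plt.show()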
1.2 GAN Training
Training alternates between updating the generator and the discriminator. Random inputs are drawn from torch.rand([BATCH_SIZE, 1]) * 2 - 1. The generator's output and real samples (a tensor of ones) are fed to the discriminator, and binary cross-entropy losses are computed for both updates.
generated = generator(random_input)
gen_prediction = discriminator(generated)
gen_target = torch.ones(BATCH_SIZE, 1)
gen_loss = x_entropy(gen_prediction, gen_target)
real_prediction = discriminator(torch.ones_like(random_input))
real_target = torch.ones(BATCH_SIZE, 1)
dis_output = torch.cat([gen_prediction, real_prediction])
dis_target = torch.cat([torch.zeros(BATCH_SIZE, 1), torch.ones(BATCH_SIZE, 1)])
dis_loss = x_entropy(dis_output, dis_target)

During early epochs the discriminator outputs an average of ~0.48 for fake samples and ~0.56 for real samples, prompting both networks to adjust.
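The snippet above computes both losses but omits the optimizer plumbing. A minimal sketch of the alternating loop follows; the optimizer names, the detach() call, and the assignment of the two learning rates are my assumptions, not the article's exact code:

gen_opt = torch.optim.Adam(generator.parameters(), lr=0.01)
dis_opt = torch.optim.Adam(discriminator.parameters(), lr=0.001)
x_entropy = torch.nn.BCELoss()

for epoch in range(NUM_EPOCHS):
    random_input = torch.rand([BATCH_SIZE, 1]) * 2 - 1
    if epoch % 2 == 0:
        # Generator step: try to make the discriminator call fakes "real"
        gen_opt.zero_grad()
        gen_prediction = discriminator(generator(random_input))
        gen_loss = x_entropy(gen_prediction, torch.ones(BATCH_SIZE, 1))
        gen_loss.backward()
        gen_opt.step()
    else:
        # Discriminator step: label fakes 0 and real samples 1
        dis_opt.zero_grad()
        fake_pred = discriminator(generator(random_input).detach())
        real_pred = discriminator(torch.ones(BATCH_SIZE, 1))
        dis_output = torch.cat([fake_pred, real_pred])
        dis_target = torch.cat([torch.zeros(BATCH_SIZE, 1), torch.ones(BATCH_SIZE, 1)])
        dis_loss = x_entropy(dis_output, dis_target)
        dis_loss.backward()
        dis_opt.step()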
1.3 Baseline Code
Key model definitions in PyTorch:
import torch

class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.l1 = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.l1(x)

class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.l1 = torch.nn.Linear(1, 2)
        self.l2 = torch.nn.Linear(2, 1)

    def forward(self, x):
        x = self.l1(x)
        x = torch.tanh(x)
        x = self.l2(x)
        return torch.sigmoid(x)

2. Baseline Issues
2.1 Generator Works, Discriminator Fails
Sometimes the generator learns to output values close to 1 while the discriminator never becomes sufficiently steep: it assigns nearly the same probability to fake and real samples, so both networks effectively stop learning because the generated inputs already look realistic to the discriminator.
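One way to spot this failure mode (a hedged diagnostic sketch, not part of the original article) is to probe the trained discriminator over a range of inputs and check whether its output curve is nearly flat around 0.5:

# Probe the discriminator on a grid of inputs; a healthy discriminator
# should peak near the real value 1 and fall off away from it.
probe = torch.linspace(-3, 5, 9).reshape(-1, 1)
with torch.no_grad():
    probs = discriminator(probe).reshape(-1)

for x, p in zip(probe.reshape(-1).tolist(), probs.tolist()):
    print(f"input {x:+.1f} -> P(real) = {p:.3f}")

# If every printed probability sits in a narrow band (e.g. 0.45-0.55),
# the discriminator is too flat to provide a useful training signal.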
2.2 Loss Function Not Effective
Experiments show that a better‑shaped discriminator can actually produce a higher loss because its predictions for near‑real values are only slightly different from the ideal 0.5 baseline.
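One way this can happen with binary cross-entropy is sketched below; the concrete probabilities are illustrative assumptions, not measurements from the article. A discriminator that nudges near-real fakes slightly toward "real" is penalized more than one that sits at an uninformative 0.5:

import torch

x_entropy = torch.nn.BCELoss()
fake_label = torch.zeros(3, 1)   # generated samples are labelled "fake"

# A better-shaped discriminator rates near-real fakes slightly above 0.5 ...
shaped_pred = torch.tensor([[0.60], [0.65], [0.62]])
# ... while a flat, uninformative discriminator stays at exactly 0.5.
flat_pred = torch.full((3, 1), 0.5)

print(x_entropy(shaped_pred, fake_label))  # ~0.98: higher loss
print(x_entropy(flat_pred, fake_label))    # ~0.69: lower loss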
2.3 Experimental Adjustments
To soften the binary cross‑entropy targets, the labels are moved from {0, 1} to {0.1, 0.9}:
easing = 0.1
dis_target = torch.cat([
    torch.zeros(bs, 1) + easing,
    torch.ones(bs, 1) - easing
])

Training with eased targets yields smoother loss curves but does not fully resolve instability.
3. Remedies
3.1 Weight Decay
Adding L2 regularization to the discriminator prevents it from learning an overly sharp function that stalls generator updates.
dis_opt_obj = torch.optim.Adam(
    dis.parameters(),
    dis_lr,
    weight_decay=WEIGHT_DECAY
)

Experiments with decay values 0.1, 1e-3, and 1e-5 show progressively smoother discriminator curves.
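For intuition, the weight_decay argument of torch.optim.Adam adds WEIGHT_DECAY * w to every parameter's gradient, which has the same effect as attaching an explicit L2 penalty to the discriminator loss. A rough, illustrative sketch of that manual form (not code from the article):

# Roughly equivalent manual form of weight decay: penalize large
# discriminator weights so the learned function stays smooth.
l2_penalty = sum((p ** 2).sum() for p in dis.parameters())
dis_loss = x_entropy(dis_output, dis_target) + 0.5 * WEIGHT_DECAY * l2_penalty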
3.2 Noisy Generator Parameters
During discriminator updates, a cloned generator with added Gaussian noise (scale 0.5) produces slightly invalid samples, forcing the discriminator to keep learning:
NOISE_SCALE = 0.5
gen_clone = Generator()
gen_clone.l1.weight.data = generator.l1.weight.data + torch.randn_like(generator.l1.weight.data) * NOISE_SCALE
gen_clone.l1.bias.data = generator.l1.bias.data + torch.randn_like(generator.l1.bias.data) * NOISE_SCALE
generated = gen_clone(random_input)
real_values = torch.ones_like(generated)
gen_prediction = discriminator(generated)
real_prediction = discriminator(real_values)
dis_target = torch.cat([torch.zeros(BATCH_SIZE, 1), torch.ones(BATCH_SIZE, 1)])
dis_output = torch.cat([gen_prediction, real_prediction])
dis_loss = x_entropy(dis_output, dis_target)

Because gradients from multiple noisy passes are accumulated before the optimizer step, training becomes more stable.
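Concretely, the pattern is to zero the gradients once, call backward() once per noisy clone, and step the optimizer only after the loop. The sketch below reuses INTERNAL_BATCH and dis_opt from the full script in section 4:

dis_opt.zero_grad()
for _ in range(INTERNAL_BATCH):
    # fresh noisy clone per pass, exactly as in the snippet above
    gen_clone = Generator()
    gen_clone.l1.weight.data = generator.l1.weight.data + torch.randn_like(generator.l1.weight.data) * NOISE_SCALE
    gen_clone.l1.bias.data = generator.l1.bias.data + torch.randn_like(generator.l1.bias.data) * NOISE_SCALE
    random_input = torch.rand([BATCH_SIZE, 1]) * 2 - 1
    dis_output = torch.cat([discriminator(gen_clone(random_input)),
                            discriminator(torch.ones(BATCH_SIZE, 1))])
    dis_target = torch.cat([torch.zeros(BATCH_SIZE, 1), torch.ones(BATCH_SIZE, 1)])
    dis_loss = x_entropy(dis_output, dis_target)
    dis_loss.backward()   # gradients accumulate across the noisy passes
dis_opt.step()            # single update from the summed gradients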
3.3 Increased Noise Scale
Raising NOISE_SCALE to 1.5 improves the aggregate run statistics, but some runs still fail; combining the higher noise with weight decay yields the most consistent results.
4. Full Updated Training Script
The final script integrates the noisy‑generator trick, weight decay, and logging via Weights & Biases. It alternates updates every epoch, aggregates gradients over an internal batch of noisy clones, and saves animated visualizations of the evolving functions.
import os, torch, numpy as np
import matplotlib.pyplot as plt, matplotlib.animation as animation
import wandb as wnb
from models import Generator, Discriminator
BATCH_SIZE = 5
INTERNAL_BATCH = 30
NOISE_SCALE = 1.5
model_dir = 'models_noisy_bswd'
anim_dir = 'output_noisy_bswd'
def train(dis, gen, dis_opt, gen_opt, name):
    x_entropy = torch.nn.BCELoss()
    dis_outputs, gen_outputs = [], []
    src = torch.tensor(np.linspace(-3, 5, 100)).float().reshape([-1, 1])
    for epoch in range(15001):
        dis_opt.zero_grad()
        gen_opt.zero_grad()
        if epoch % 2 == 0:
            input = torch.rand([BATCH_SIZE, 1]) * 2 - 1
            output = gen(input)
            pred = dis(output)
            gen_target = torch.ones(BATCH_SIZE, 1)
            gen_loss = x_entropy(pred, gen_target)
            gen_loss.backward()
            gen_opt.step()
        else:
            for _ in range(INTERNAL_BATCH):
                gen_clone = Generator()
                gen_clone.l1.weight.data = gen.l1.weight.data + torch.randn_like(gen.l1.weight.data) * NOISE_SCALE
                gen_clone.l1.bias.data = gen.l1.bias.data + torch.randn_like(gen.l1.bias.data) * NOISE_SCALE
                input = torch.rand([BATCH_SIZE, 1]) * 2 - 1
                out = gen_clone(input)
                pred = dis(out)
                dis_output = torch.cat([pred, dis(torch.ones_like(input))])
                dis_target = torch.cat([torch.zeros(BATCH_SIZE, 1), torch.ones(BATCH_SIZE, 1)])
                dis_loss = x_entropy(dis_output, dis_target)
                dis_loss.backward()
            dis_opt.step()
        if epoch % 500 == 1:
            wnb.log({
                'dis_loss': float(dis_loss),
                'gen_loss': float(gen_loss),
                'bias': float(gen.l1.bias.data),
                'weight': float(gen.l1.weight.data)
            })
        if epoch % 250 == 1:
            gen_outputs.append(gen(src).cpu().detach().reshape(-1).numpy())
            dis_outputs.append(dis(src).cpu().detach().reshape(-1).numpy())
    # animation creation omitted for brevity

Running the script with learning-rate pairs (0.01, 0.001) and seeds {22, 16, 48, 19, 73} reproduces the visualizations shown throughout the article.
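A minimal driver for this function might look like the sketch below; the seed handling, the optimizer construction, the WEIGHT_DECAY value, the wandb project name, and the assignment of the two learning rates are assumptions reconstructed from details mentioned in the article rather than its original code:

GEN_LR, DIS_LR = 0.01, 0.001
WEIGHT_DECAY = 1e-3   # illustrative; the article experiments with 0.1, 1e-3 and 1e-5

for seed in [22, 16, 48, 19, 73]:
    torch.manual_seed(seed)
    gen = Generator()
    dis = Discriminator()
    gen_opt = torch.optim.Adam(gen.parameters(), GEN_LR)
    dis_opt = torch.optim.Adam(dis.parameters(), DIS_LR, weight_decay=WEIGHT_DECAY)
    wnb.init(project='pointgan', name=f'noisy_bswd_seed_{seed}', reinit=True)
    train(dis, gen, dis_opt, gen_opt, name=f'seed_{seed}')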