Understanding DeepMind’s PonderNet: A Network That Learns How Long to Think, on MNIST
This article explains DeepMind’s PonderNet framework, which lets any neural network allocate computation adaptively. It walks through the underlying theory, loss functions, and training procedure, implements the model with PyTorch Lightning on the MNIST dataset, and evaluates its pondering behavior on rotated-digit experiments.
Motivation
The author argues that most modern neural networks lack the ability to "think"—i.e., to adjust their computational budget based on task difficulty. PonderNet is introduced as a framework that allocates more resources to harder inputs, enabling networks to ponder longer when needed.
PonderNet Framework
Intuition
Unlike a traditional model that processes an input once, PonderNet can process the same input multiple times, deciding at each step whether to stop or continue. The decision is modeled as a Bernoulli coin flip with probability λ.
Formal Definition
At each step the network receives the original input x and the current hidden state h, and produces a prediction y, a halting probability λ, and an updated hidden state. The halt/continue decision is then sampled from a Bernoulli distribution with parameter λ: if the coin lands heads the network halts and outputs y, otherwise it proceeds to the next step.
Meaning of λ
λₙ represents the conditional probability of halting at step n given that the network has not halted at any earlier step. The unconditional probability of halting exactly at step n is therefore pₙ = λₙ · (1 − λ₁) · … · (1 − λₙ₋₁), the product of the earlier non-halting probabilities with λₙ.
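For illustration, the conditional probabilities can be turned into an unconditional halting distribution with a short loop (the λ values here are made up; the last one is forced to 1, as the framework does at the final step):

```python
# Per-step conditional halting probabilities lambda_n (illustrative values);
# the final step is forced to 1 so the process always terminates.
lambdas = [0.2, 0.3, 0.5, 1.0]

# Unconditional probability of halting exactly at step n:
# p_n = lambda_n * prod_{k<n} (1 - lambda_k)
p = []
not_halted = 1.0
for lam in lambdas:
    p.append(not_halted * lam)
    not_halted *= (1.0 - lam)

print(p)       # approximately [0.2, 0.24, 0.28, 0.28]
print(sum(p))  # 1.0 -- a valid distribution over halting steps
```

Because the last λ is 1, the probabilities sum to exactly one, which is what makes the reconstruction loss a proper expectation over steps.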
Thinking Steps
During inference a maximum number of pondering steps is set, and λ is forced to 1 at the final step to guarantee termination. During training, a regularization term in the loss function penalizes excessive steps.
Training PonderNet
The total loss is the sum of a reconstruction loss and a regularization loss weighted by a hyper-parameter β: L = L_rec + β·L_reg. The reconstruction loss is the expected cross-entropy over all steps, with each step's loss weighted by its halting probability pₙ. The regularization loss is the KL divergence between the learned halting distribution and a (truncated) geometric prior with parameter λₚ, which biases the expected number of steps toward 1/λₚ.
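The training code below wraps the two terms in a `Loss` object with `get_total_loss`, `get_rec_loss`, and `get_reg_loss` accessors, but its definition is not shown in the source. The following is a minimal sketch consistent with that usage (the class name and accessors come from the training code; the body is an assumption):

```python
class Loss:
    """Container pairing the reconstruction and regularization terms.

    Sketch of the `Loss` helper referenced by the LightningModule;
    implements L = L_rec + beta * L_reg.
    """

    def __init__(self, loss_rec, loss_reg, beta):
        self.loss_rec = loss_rec  # reconstruction term (tensor or float)
        self.loss_reg = loss_reg  # regularization term (tensor or float)
        self.beta = beta          # weight of the regularization term

    def get_total_loss(self):
        return self.loss_rec + self.beta * self.loss_reg

    def get_rec_loss(self):
        return self.loss_rec

    def get_reg_loss(self):
        return self.loss_reg
```

It works equally well on scalars and on PyTorch tensors, since it only uses addition and multiplication.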
class ReconstructionLoss(nn.Module):
    """Expected per-step loss, weighted by the halting probabilities.

    `loss_func` should be created with reduction='none' so the per-sample
    losses can be weighted by p[n] before averaging over the batch.
    """

    def __init__(self, loss_func: nn.Module):
        super().__init__()
        self.loss_func = loss_func

    def forward(self, p, y_pred, y):
        # p: (max_steps, batch) halting probabilities
        # y_pred: (max_steps, batch, n_classes) per-step predictions
        total_loss = p.new_tensor(0.)
        for n in range(p.shape[0]):
            loss = (p[n] * self.loss_func(y_pred[n], y)).mean()
            total_loss = total_loss + loss
        return total_loss


class RegularizationLoss(nn.Module):
    """KL divergence between the halting distribution and a geometric prior."""

    def __init__(self, lambda_p: float, max_steps: int = 1000, device=None):
        super().__init__()
        # Truncated geometric prior: p_g(k) = (1 - lambda_p)^k * lambda_p
        p_g = torch.zeros((max_steps,), device=device)
        not_halted = 1.
        for k in range(max_steps):
            p_g[k] = not_halted * lambda_p
            not_halted = not_halted * (1 - lambda_p)
        self.p_g = nn.Parameter(p_g, requires_grad=False)
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, p):
        # p: (max_steps, batch) -> (batch, max_steps)
        p = p.transpose(0, 1)
        p_g = self.p_g[None, :p.shape[1]].expand_as(p)
        # KLDivLoss expects log-probabilities as input, probabilities as target
        return self.kl_div(p.log(), p_g)

Implementation Details
Data Module
class MNIST_DataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./', train_transform=None, test_transform=None, batch_size=256):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_transform = train_transform
        self.test_transform = test_transform
        self.default_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage in [None, 'fit', 'validate']:
            mnist_train = MNIST(self.data_dir, train=True,
                                transform=self.train_transform or self.default_transform)
            self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
        if stage == 'test' or stage is None:
            # test_transform may be a single transform or a list of transforms;
            # a list yields one test set per transform (used for rotated digits)
            if self.test_transform is None or isinstance(self.test_transform, transforms.Compose):
                self.mnist_test = MNIST(self.data_dir, train=False,
                                        transform=self.test_transform or self.default_transform)
            else:
                self.mnist_test = [MNIST(self.data_dir, train=False, transform=t)
                                   for t in self.test_transform]

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        if isinstance(self.mnist_test, MNIST):
            return DataLoader(self.mnist_test, batch_size=self.batch_size)
        return [DataLoader(ds, batch_size=self.batch_size) for ds in self.mnist_test]

CNN and MLP Modules
class CNN(nn.Module):
    def __init__(self, n_input=28, n_output=50, kernel_size=5):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=kernel_size)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=kernel_size)
        self.conv2_drop = nn.Dropout2d()
        # spatial size remaining after two (conv, 2x2 max-pool) stages
        self.lin_size = floor((floor((n_input - (kernel_size - 1)) / 2) - (kernel_size - 1)) / 2)
        self.fc1 = nn.Linear(self.lin_size ** 2 * 20, n_output)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = torch.flatten(x, 1)
        return F.relu(self.fc1(x))


class MLP(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super().__init__()
        self.i2h = nn.Linear(n_input, n_hidden)
        self.h2o = nn.Linear(n_hidden, n_output)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = F.relu(self.i2h(x))
        x = self.dropout(x)
        return F.relu(self.h2o(x))

PonderMNIST Module
class PonderMNIST(pl.LightningModule):
    def __init__(self, n_hidden, n_hidden_lin, n_hidden_cnn, kernel_size,
                 max_steps, lambda_p, beta, lr):
        super().__init__()
        self.n_classes = 10
        self.max_steps = max_steps
        self.lambda_p = lambda_p
        self.beta = beta
        self.n_hidden = n_hidden
        self.lr = lr
        self.cnn = CNN(n_input=28, kernel_size=kernel_size, n_output=n_hidden_cnn)
        self.mlp = MLP(n_input=n_hidden_cnn + n_hidden, n_hidden=n_hidden_lin, n_output=n_hidden)
        self.outpt_layer = nn.Linear(n_hidden, self.n_classes)
        self.lambda_layer = nn.Linear(n_hidden, 1)
        # reduction='none' so ReconstructionLoss can weight each sample's loss by p[n]
        self.loss_rec = ReconstructionLoss(nn.CrossEntropyLoss(reduction='none'))
        self.loss_reg = RegularizationLoss(self.lambda_p, max_steps=self.max_steps, device=self.device)
        self.accuracy = torchmetrics.Accuracy()
        self.save_hyperparameters()

    def forward(self, x):
        batch_size = x.shape[0]
        h = x.new_zeros((batch_size, self.n_hidden))
        embedding = self.cnn(x)
        h = self.mlp(torch.cat([embedding, h], 1))
        p_list, y_list = [], []
        un_halted_prob = h.new_ones((batch_size,))
        halting_step = h.new_zeros((batch_size,), dtype=torch.long)
        for n in range(1, self.max_steps + 1):
            # force lambda = 1 at the last step so the process always terminates
            lambda_n = h.new_ones(batch_size) if n == self.max_steps \
                else torch.sigmoid(self.lambda_layer(h)).squeeze()
            y_n = self.outpt_layer(h)
            p_n = un_halted_prob * lambda_n  # unconditional halting probability at step n
            p_list.append(p_n)
            y_list.append(y_n)
            # record the first step at which each sample's Bernoulli draw halts
            halting_step = torch.maximum(
                n * (halting_step == 0) * torch.bernoulli(lambda_n).to(torch.long),
                halting_step)
            un_halted_prob = un_halted_prob * (1 - lambda_n)
            embedding = self.cnn(x)
            h = self.mlp(torch.cat([embedding, h], 1))
            # at inference, stop early once every sample in the batch has halted
            if not self.training and (halting_step > 0).sum() == batch_size:
                break
        return torch.stack(y_list), torch.stack(p_list), halting_step
    def _get_loss_and_metrics(self, batch):
        data, target = batch
        y, p, halted_step = self(data)
        # drop samples whose halting distribution contains zeros (log(0) in the KL term)
        if torch.any(p == 0) and self.training:
            valid = torch.all(p != 0, dim=0)
            p = p[:, valid]
            y = y[:, valid]
            halted_step = halted_step[valid]
            target = target[valid]
        loss_rec = self.loss_rec(p, y, target)
        loss_reg = self.loss_reg(p)
        loss = Loss(loss_rec, loss_reg, self.beta)
        # gather the prediction made at each sample's sampled halting step
        halted_index = (halted_step - 1).unsqueeze(0).unsqueeze(2).repeat(1, 1, self.n_classes)
        logits = y.gather(dim=0, index=halted_index).squeeze()
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, target)
        steps = (halted_step * 1.0).mean()
        return loss, preds, acc, steps
    def training_step(self, batch, batch_idx):
        loss, _, acc, steps = self._get_loss_and_metrics(batch)
        self.log('train/steps', steps)
        self.log('train/accuracy', acc)
        self.log('train/total_loss', loss.get_total_loss())
        self.log('train/reconstruction_loss', loss.get_rec_loss())
        self.log('train/regularization_loss', loss.get_reg_loss())
        return loss.get_total_loss()

    def validation_step(self, batch, batch_idx):
        loss, preds, acc, steps = self._get_loss_and_metrics(batch)
        self.log('val/steps', steps)
        self.log('val/accuracy', acc)
        self.log('val/total_loss', loss.get_total_loss())
        self.log('val/reconstruction_loss', loss.get_rec_loss())
        self.log('val/regularization_loss', loss.get_reg_loss())
        return preds

    def test_step(self, batch, batch_idx, dataset_idx=0):
        _, _, acc, steps = self._get_loss_and_metrics(batch)
        self.log(f'test_{dataset_idx}/steps', steps)
        self.log(f'test_{dataset_idx}/accuracy', acc)

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.lr)
        return {"optimizer": optimizer,
                "lr_scheduler": {"scheduler": ReduceLROnPlateau(optimizer, mode='max', verbose=True),
                                 "monitor": 'val/accuracy',
                                 "interval": 'epoch',
                                 "frequency": 1}}

    def configure_callbacks(self):
        early = EarlyStopping(monitor='val/accuracy', mode='max', patience=4)
        ckpt = ModelCheckpoint(monitor='val/accuracy', mode='max')
        return [early, ckpt]

Experiments
Interpolation Experiment
Using the default MNIST split, the trainer logs loss, accuracy, and average pondering steps. The reported average number of steps is close to 1/λₚ (≈5), confirming that the regularization term controls the expected computation budget.
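With λₚ = 0.2 (the value implied by the reported average of roughly 5 steps), the geometric prior's mean can be checked numerically; this snippet mirrors the truncated prior built in RegularizationLoss:

```python
lambda_p = 0.2
max_steps = 1000

# Truncated geometric prior: p_g(k) = (1 - lambda_p)^(k-1) * lambda_p, k = 1..max_steps
p_g = [(1 - lambda_p) ** (k - 1) * lambda_p for k in range(1, max_steps + 1)]

# Expected number of steps under the prior; the truncation error at 1000 steps
# is negligible, so this matches the closed form 1 / lambda_p = 5
expected_steps = sum(k * pk for k, pk in zip(range(1, max_steps + 1), p_g))
print(expected_steps)  # ~5.0
```

This is why pushing the halting distribution toward the prior via the KL term also pushes the average pondering time toward 1/λₚ.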
Extrapolation Experiment
The model is trained on digits rotated by 22.5°, then evaluated on rotations of 22.5°, 45°, 67.5°, and 90°. Accuracy drops as rotation increases, while the average number of pondering steps rises, suggesting the network allocates more computation to harder inputs.
Conclusion
PonderNet provides a mathematically grounded way for neural networks to adaptively allocate computation. The implementation on MNIST demonstrates that the model can learn to ponder longer on more difficult inputs, and the regularization term effectively controls the expected number of steps.
References
[1] A. Banino, J. Balaguer, C. Blundell, "PonderNet: Learning to Ponder," 2021, arXiv:2107.05407.
[2] A. Graves, "Adaptive Computation Time for Recurrent Neural Networks," 2017, arXiv:1603.08983.
Code DAO
We deliver AI algorithm tutorials and the latest news, curated by a team of researchers from Peking University, Shanghai Jiao Tong University, Central South University, and leading AI companies such as Huawei, Kuaishou, and SenseTime. Join us in the AI alchemy—making life better!