Understanding DeepMind’s PonderNet: A Thinkable Network for MNIST

This article explains DeepMind's PonderNet, a framework that lets a neural network allocate computation adaptively. It covers the underlying theory and loss functions, an implementation with PyTorch Lightning on the MNIST dataset, the training procedure, and an evaluation of the model's pondering behavior on rotated-digit experiments.

Motivation

The author argues that most modern neural networks lack the ability to "think"—i.e., to adjust their computational budget based on task difficulty. PonderNet is introduced as a framework that allocates more resources to harder inputs, enabling networks to ponder longer when needed.

PonderNet Framework

Intuition

Unlike a traditional model that processes an input once, PonderNet can process the same input multiple times, deciding at each step whether to stop or continue. The decision is modeled as a Bernoulli coin flip with probability λ.

Formal Definition

At each step the network receives the original input x and the current hidden state h, and produces a prediction y, a halting probability λ, and an updated hidden state. A halting decision is then sampled from a Bernoulli distribution with parameter λ: if the coin lands heads the network halts and outputs y, otherwise it proceeds to the next step.
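
As a minimal sketch (the class and variable names here are illustrative, not the article's code, which follows below), one ponder step can be written as a module that maps the input and hidden state to a prediction, a halting probability, and a new hidden state:

import torch
import torch.nn as nn

class PonderStep(nn.Module):
    """Illustrative single ponder step: (x, h) -> (y, lambda_n, h_next)."""
    def __init__(self, n_input, n_hidden, n_classes):
        super().__init__()
        self.cell = nn.Linear(n_input + n_hidden, n_hidden)  # updates the hidden state
        self.out = nn.Linear(n_hidden, n_classes)            # prediction head
        self.halt = nn.Linear(n_hidden, 1)                    # halting head
    def forward(self, x, h):
        h_next = torch.relu(self.cell(torch.cat([x, h], dim=1)))
        y = self.out(h_next)                                    # prediction at this step
        lambda_n = torch.sigmoid(self.halt(h_next)).squeeze(1)  # halting probability in (0, 1)
        return y, lambda_n, h_next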

Meaning of λ

λ represents the conditional probability of halting at the current step, given that no earlier step halted. The unconditional halting probability at step n, denoted pₙ, is therefore the product of the probabilities of not halting at each earlier step and λₙ: pₙ = λₙ·(1 − λ₁)·…·(1 − λₙ₋₁).
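
For example, the following short sketch (variable names are illustrative) converts a sequence of per-step conditional halting probabilities λ into the unconditional halting distribution pₙ:

import torch

lambdas = torch.tensor([0.1, 0.3, 0.5, 1.0])  # conditional halting probabilities; the last step is forced to 1
not_halted = 1.0
p = []
for lam in lambdas:
    p.append(not_halted * lam)                # p_n = lambda_n * prod_{j<n} (1 - lambda_j)
    not_halted = not_halted * (1 - lam)
p = torch.stack(p)                            # tensor([0.1000, 0.2700, 0.3150, 0.3150])
print(p.sum())                                # sums to 1 because the final lambda is 1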

Thinking Steps

During inference a maximum number of pondering steps can be set; the final step forces λ=1 to guarantee termination. In training the loss function incorporates a regularization term that penalizes excessive steps.

Training PonderNet

The total loss L is the sum of a reconstruction loss and a regularization loss, the latter weighted by a hyper‑parameter β. The reconstruction loss is the expected cross‑entropy over all steps, weighted by the halting probabilities. The regularization loss is the KL divergence between the learned halting distribution and a geometric prior with parameter λₚ, which encourages the expected number of steps to stay close to 1/λₚ.
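
Written out (following the PonderNet paper, with N the maximum number of steps and ŷₙ the prediction at step n):

\mathcal{L} = \underbrace{\sum_{n=1}^{N} p_n \, \mathcal{L}\left(y, \hat{y}_n\right)}_{\mathcal{L}_{\mathrm{Rec}}} + \beta \, \underbrace{\mathrm{KL}\left(p \,\Vert\, p_G(\lambda_p)\right)}_{\mathcal{L}_{\mathrm{Reg}}},
\qquad p_G(n) = \lambda_p \, (1 - \lambda_p)^{\,n-1}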

import torch
import torch.nn as nn


class ReconstructionLoss(nn.Module):
    """Expected task loss over all ponder steps, weighted by the halting probabilities p_n."""

    def __init__(self, loss_func: nn.Module):
        super().__init__()
        # loss_func should return per-sample losses (e.g. reduction='none')
        self.loss_func = loss_func

    def forward(self, p, y_pred, y):
        # p: (max_steps, batch) halting probabilities, y_pred: (max_steps, batch, classes)
        total_loss = p.new_tensor(0.)
        for n in range(p.shape[0]):
            # Per-sample loss at step n, weighted by the probability of halting at step n
            loss = (p[n] * self.loss_func(y_pred[n], y)).mean()
            total_loss = total_loss + loss
        return total_loss


class RegularizationLoss(nn.Module):
    """KL divergence between the halting distribution p and a geometric prior with parameter lambda_p."""

    def __init__(self, lambda_p: float, max_steps: int = 1000, device=None):
        super().__init__()
        # Truncated geometric prior: p_g[k] = lambda_p * (1 - lambda_p)^k
        p_g = torch.zeros((max_steps,), device=device)
        not_halted = 1.
        for k in range(max_steps):
            p_g[k] = not_halted * lambda_p
            not_halted = not_halted * (1 - lambda_p)
        self.p_g = nn.Parameter(p_g, requires_grad=False)
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, p):
        # p: (max_steps, batch) -> (batch, max_steps) so KLDivLoss averages over the batch
        p = p.transpose(0, 1)
        p_g = self.p_g[None, :p.shape[1]].expand_as(p)
        return self.kl_div(p.log(), p_g)
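
The Lightning module below also wraps the two terms in a small Loss container with get_total_loss, get_rec_loss, and get_reg_loss accessors. That class is not shown in the original listing; a minimal sketch consistent with how it is called could look like this:

class Loss:
    """Bundles the two loss terms; the total is L_rec + beta * L_reg."""
    def __init__(self, loss_rec, loss_reg, beta):
        self.loss_rec = loss_rec
        self.loss_reg = loss_reg
        self.beta = beta
    def get_rec_loss(self):
        return self.loss_rec
    def get_reg_loss(self):
        return self.loss_reg
    def get_total_loss(self):
        return self.loss_rec + self.beta * self.loss_reg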

Implementation Details

Data Module

import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNIST_DataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./', train_transform=None, test_transform=None, batch_size=256):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_transform = train_transform
        # test_transform may be a single transform or a list of transforms (one test set per entry)
        self.test_transform = test_transform
        self.default_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        # Download once; setup() builds the actual datasets
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage in [None, 'fit', 'validate']:
            mnist_train = MNIST(self.data_dir, train=True, transform=self.train_transform or self.default_transform)
            self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
        if stage == 'test' or stage is None:
            if self.test_transform is None or isinstance(self.test_transform, transforms.Compose):
                self.mnist_test = MNIST(self.data_dir, train=False, transform=self.test_transform or self.default_transform)
            else:
                # A list of transforms yields one test dataset per transform (used in the rotation experiments)
                self.mnist_test = [MNIST(self.data_dir, train=False, transform=t) for t in self.test_transform]

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        if isinstance(self.mnist_test, MNIST):
            return DataLoader(self.mnist_test, batch_size=self.batch_size)
        return [DataLoader(ds, batch_size=self.batch_size) for ds in self.mnist_test]
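
A quick, illustrative sanity check of the data module:

dm = MNIST_DataModule(batch_size=256)
dm.prepare_data()
dm.setup('fit')
x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)  # torch.Size([256, 1, 28, 28]) torch.Size([256])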

CNN and MLP Modules

from math import floor

import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    """Two convolution + max-pool blocks followed by a fully connected embedding layer."""

    def __init__(self, n_input=28, n_output=50, kernel_size=5):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=kernel_size)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=kernel_size)
        self.conv2_drop = nn.Dropout2d()
        # Spatial size after two (conv, stride-2 max-pool) blocks
        self.lin_size = floor((floor((n_input - (kernel_size - 1)) / 2) - (kernel_size - 1)) / 2)
        self.fc1 = nn.Linear(self.lin_size ** 2 * 20, n_output)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = torch.flatten(x, 1)
        return F.relu(self.fc1(x))


class MLP(nn.Module):
    """Single-hidden-layer perceptron used to update the recurrent hidden state."""

    def __init__(self, n_input, n_hidden, n_output):
        super().__init__()
        self.i2h = nn.Linear(n_input, n_hidden)
        self.h2o = nn.Linear(n_hidden, n_output)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = F.relu(self.i2h(x))
        x = self.dropout(x)
        return F.relu(self.h2o(x))
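
For the default 28×28 input and kernel size 5, the first (conv, pool) block shrinks the feature map to 12×12 and the second to 4×4, so the flattened size is 4·4·20 = 320. A quick check:

cnn = CNN(n_input=28, n_output=50, kernel_size=5)
print(cnn.lin_size)                          # 4
print(cnn(torch.zeros(1, 1, 28, 28)).shape)  # torch.Size([1, 50])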

PonderMNIST Module

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau


class PonderMNIST(pl.LightningModule):
    def __init__(self, n_hidden, n_hidden_lin, n_hidden_cnn, kernel_size, max_steps, lambda_p, beta, lr):
        super().__init__()
        self.n_classes = 10
        self.max_steps = max_steps
        self.lambda_p = lambda_p
        self.beta = beta
        self.n_hidden = n_hidden
        self.lr = lr
        self.cnn = CNN(n_input=28, kernel_size=kernel_size, n_output=n_hidden_cnn)
        self.mlp = MLP(n_input=n_hidden_cnn + n_hidden, n_hidden=n_hidden_lin, n_output=n_hidden)
        self.outpt_layer = nn.Linear(n_hidden, self.n_classes)  # prediction head y_n
        self.lambda_layer = nn.Linear(n_hidden, 1)               # halting head lambda_n
        # reduction='none' keeps per-sample losses so they can be weighted by the halting probabilities
        self.loss_rec = ReconstructionLoss(nn.CrossEntropyLoss(reduction='none'))
        self.loss_reg = RegularizationLoss(self.lambda_p, max_steps=self.max_steps, device=self.device)
        self.accuracy = torchmetrics.Accuracy()
        self.save_hyperparameters()
    def forward(self, x):
        batch_size = x.shape[0]
        h = x.new_zeros((batch_size, self.n_hidden))
        embedding = self.cnn(x)
        h = self.mlp(torch.cat([embedding, h], 1))
        p_list, y_list = [], []
        un_halted_prob = h.new_ones((batch_size,))                   # prod_{j<n} (1 - lambda_j)
        halting_step = h.new_zeros((batch_size,), dtype=torch.long)  # step at which each sample halted
        for n in range(1, self.max_steps + 1):
            # Force lambda = 1 on the last step so the process is guaranteed to terminate
            lambda_n = h.new_ones(batch_size) if n == self.max_steps else torch.sigmoid(self.lambda_layer(h)).squeeze()
            y_n = self.outpt_layer(h)
            p_n = un_halted_prob * lambda_n                           # unconditional halting probability at step n
            p_list.append(p_n)
            y_list.append(y_n)
            # Sample the Bernoulli coin flip and record the first step at which each sample halts
            halting_step = torch.maximum(
                n * (halting_step == 0) * torch.bernoulli(lambda_n).to(torch.long),
                halting_step)
            un_halted_prob = un_halted_prob * (1 - lambda_n)
            # Ponder: process the same input again with the updated hidden state
            embedding = self.cnn(x)
            h = self.mlp(torch.cat([embedding, h], 1))
            # At inference time, stop as soon as every sample in the batch has halted
            if not self.training and (halting_step > 0).sum() == batch_size:
                break
        return torch.stack(y_list), torch.stack(p_list), halting_step
    def _get_loss_and_metrics(self, batch):
        data, target = batch
        y, p, halted_step = self(data)
        # Drop samples with zero-probability steps during training to avoid log(0) in the KL term
        if torch.any(p == 0) and self.training:
            valid = torch.all(p != 0, dim=0)
            p = p[:, valid]
            y = y[:, valid]
            halted_step = halted_step[valid]
            target = target[valid]
        loss_rec = self.loss_rec(p, y, target)
        loss_reg = self.loss_reg(p)
        loss = Loss(loss_rec, loss_reg, self.beta)
        # Evaluate accuracy on the prediction made at the sampled halting step
        halted_index = (halted_step - 1).unsqueeze(0).unsqueeze(2).repeat(1, 1, self.n_classes)
        logits = y.gather(dim=0, index=halted_index).squeeze()
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, target)
        steps = (halted_step * 1.0).mean()
        return loss, preds, acc, steps
    def training_step(self, batch, batch_idx):
        loss, _, acc, steps = self._get_loss_and_metrics(batch)
        self.log('train/steps', steps)
        self.log('train/accuracy', acc)
        self.log('train/total_loss', loss.get_total_loss())
        self.log('train/reconstruction_loss', loss.get_rec_loss())
        self.log('train/regularization_loss', loss.get_reg_loss())
        return loss.get_total_loss()
    def validation_step(self, batch, batch_idx):
        loss, preds, acc, steps = self._get_loss_and_metrics(batch)
        self.log('val/steps', steps)
        self.log('val/accuracy', acc)
        self.log('val/total_loss', loss.get_total_loss())
        self.log('val/reconstruction_loss', loss.get_rec_loss())
        self.log('val/regularization_loss', loss.get_reg_loss())
        return preds
    def test_step(self, batch, batch_idx, dataset_idx=0):
        _, _, acc, steps = self._get_loss_and_metrics(batch)
        self.log(f'test_{dataset_idx}/steps', steps)
        self.log(f'test_{dataset_idx}/accuracy', acc)
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.lr)
        return {"optimizer": optimizer,
                "lr_scheduler": {"scheduler": ReduceLROnPlateau(optimizer, mode='max', verbose=True),
                                 "monitor": 'val/accuracy',
                                 "interval": 'epoch',
                                 "frequency": 1}}
    def configure_callbacks(self):
        early = EarlyStopping(monitor='val/accuracy', mode='max', patience=4)
        ckpt = ModelCheckpoint(monitor='val/accuracy', mode='max')
        return [early, ckpt]

Experiments

Interpolation Experiment

Using the default MNIST split, the trainer logs loss, accuracy, and average pondering steps. The reported average number of steps is close to 1/λₚ (≈5), confirming that the regularization term controls the expected computation budget.
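
A run along these lines can be launched with a standard Lightning loop; the hyper-parameter values below are illustrative (λₚ = 0.2 matches the expected ≈5 steps), not necessarily those of the original experiment:

dm = MNIST_DataModule(batch_size=256)
model = PonderMNIST(n_hidden=64, n_hidden_lin=128, n_hidden_cnn=64,
                    kernel_size=5, max_steps=20, lambda_p=0.2, beta=0.01, lr=3e-4)
trainer = pl.Trainer(max_epochs=30)
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)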

Extrapolation Experiment

The model is trained on digits rotated by 22.5°, then evaluated on rotations of 22.5°, 45°, 67.5°, and 90°. Accuracy drops as rotation increases, while the average number of pondering steps rises, suggesting the network allocates more computation to harder inputs.
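
The rotated datasets can be built by passing a list of transforms to the data module; the angles below come from the experiment, while the remaining details are an illustrative sketch:

def rotation_transform(angle):
    # Fixed rotation followed by the default MNIST normalization
    return transforms.Compose([
        transforms.RandomRotation((angle, angle)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

train_tf = rotation_transform(22.5)
test_tfs = [rotation_transform(a) for a in (22.5, 45.0, 67.5, 90.0)]
dm = MNIST_DataModule(train_transform=train_tf, test_transform=test_tfs, batch_size=256)
model = PonderMNIST(n_hidden=64, n_hidden_lin=128, n_hidden_cnn=64,
                    kernel_size=5, max_steps=20, lambda_p=0.2, beta=0.01, lr=3e-4)
trainer = pl.Trainer(max_epochs=30)
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)  # logs test_0/... through test_3/... metrics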

Conclusion

PonderNet provides a mathematically grounded way for neural networks to adaptively allocate computation. The implementation on MNIST demonstrates that the model can learn to ponder longer on more difficult inputs, and the regularization term effectively controls the expected number of steps.

References

[1] A. Banino, J. Balaguer, C. Blundell, PonderNet: Learning to Ponder, 2021, arXiv:2107.05407.

[2] A. Graves, Adaptive Computation Time for Recurrent Neural Networks, 2017, arXiv:1603.08983.
