Why Does Recompute Crash Distributed Training? A Deep Dive into Checkpoint Issues and Fixes

Developers training deep learning models with large batches often use recompute to trade computation for memory. In dynamic graph frameworks, however, recompute can trigger synchronization errors in distributed data parallel (DDP) training. This article explains the underlying DDP mechanics, illustrates the error, and offers a practical no_sync workaround with a code example.

Baidu Geek Talk

Background

Large batch sizes are essential for contrastive learning and other tasks, yet GPU memory limits batch size. Practitioners therefore use the recompute (or checkpoint) technique to reduce memory consumption by discarding intermediate activations during the forward pass and recomputing them during back‑propagation.

Basic Training Process

A typical training step consists of three stages:

Forward computation: operators process inputs to produce outputs, ultimately yielding the loss.

Backward computation: gradients are calculated via the chain rule.

Optimization: gradients update model parameters.
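
The three stages can be sketched for a one-parameter model. This is a minimal illustration in plain Python, not framework code: the model y = w * x, the squared-error loss, and the name train_step are all assumptions chosen for brevity.

```python
# Toy training step for the scalar model y = w * x (illustrative names only).

def train_step(w, x, target, lr):
    # 1. Forward computation: produce the prediction and the loss.
    y = w * x
    loss = (y - target) ** 2
    # 2. Backward computation: chain rule, dloss/dw = dloss/dy * dy/dw.
    dloss_dy = 2.0 * (y - target)
    dloss_dw = dloss_dy * x
    # 3. Optimization: plain SGD update of the parameter.
    new_w = w - lr * dloss_dw
    return new_w, loss


w = 0.0
losses = []
for _ in range(20):
    w, loss = train_step(w, x=2.0, target=4.0, lr=0.05)
    losses.append(loss)
```

With these inputs the parameter moves toward 2.0, the value at which w * 2.0 equals the target 4.0.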

Recompute Mechanics

Recompute splits the network into segments. During the forward pass, only a small set of variables (checkpoints) are kept; all other intermediates are discarded. In the backward pass, the forward segment is recomputed to restore the needed activations before gradient calculation. This incurs an extra forward pass but saves memory.
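
The trade-off can be made concrete with a toy chain of scalar layers h_next = tanh(w * h). The sketch below is plain Python under stated assumptions: the function names, the fixed segment length, and the "stored values" counter are illustrative and do not mirror any framework's recompute API.

```python
import math


def layer(w, h):
    return math.tanh(w * h)


def layer_backward(w, h_in, h_out, grad_out):
    # d tanh(u)/du = 1 - tanh(u)^2, with u = w * h_in and h_out = tanh(u).
    grad_u = grad_out * (1.0 - h_out ** 2)
    return grad_u * h_in, grad_u * w  # (grad w.r.t. w, grad w.r.t. h_in)


def backward_full(weights, x):
    # Ordinary training: every activation is stored during the forward pass.
    acts = [x]
    for w in weights:
        acts.append(layer(w, acts[-1]))
    grad_h = 1.0  # the final activation plays the role of the loss
    grads = [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        grads[i], grad_h = layer_backward(weights[i], acts[i], acts[i + 1], grad_h)
    return grads, len(acts)


def backward_recompute(weights, x, segment):
    # Forward pass: keep only the input of each segment (the checkpoints).
    checkpoints = {}
    h = x
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = h
        h = layer(w, h)
    grad_h = 1.0
    grads = [0.0] * len(weights)
    peak = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(weights))
        # Backward pass: re-run this segment's forward from its checkpoint.
        acts = [checkpoints[start]]
        for i in range(start, end):
            acts.append(layer(weights[i], acts[-1]))
        # Rough count of values alive: all checkpoints plus one segment.
        peak = max(peak, len(checkpoints) + len(acts) - 1)
        for i in reversed(range(start, end)):
            j = i - start
            grads[i], grad_h = layer_backward(weights[i], acts[j], acts[j + 1], grad_h)
    return grads, peak


weights = [0.5 + 0.1 * i for i in range(16)]
grads_full, stored_full = backward_full(weights, 0.3)
grads_rc, stored_rc = backward_recompute(weights, 0.3, segment=4)
```

Both functions produce the same gradients. In this toy count the recompute version holds 8 values at its peak instead of 17, at the cost of running each segment's forward computation a second time.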

Checkpoint Definition

Checkpoints are the layers or tensors retained across recompute. In both static and dynamic graph frameworks (e.g., PyTorch, TensorFlow, Paddle), users must explicitly specify checkpoints. In Paddle’s fleet API this is done via dist_strategy.recompute_configs = {"checkpoints": model.checkpoints}.

Problem in Dynamic Graph Distributed Training

When recompute is enabled in a dynamic‑graph distributed setting (e.g., Paddle), an error similar to the one shown below appears, and the same issue also occurs in PyTorch:

Error happened, when parameter[385][xxxxx@GRAD] has been ready before. Please set find_unused_parameters=True to traverse backward graph in each step to prepare reduce in advance.

Understanding the Error

Distributed Data Parallel (DDP) works by broadcasting model parameters from rank 0, creating a reducer that groups gradients into buckets, and synchronizing those buckets during back‑propagation as their gradients become ready. The reducer expects each parameter's gradient to be marked ready only once per backward pass. When recompute creates multiple checkpoints, the same parameter can be bound to several gradient hooks. Once one hook marks the gradient as ready, the others attempt to reduce it again, causing the reported error.
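
The failure mode can be imitated without any framework. ToyReducer and run_backward below are hypothetical stand-ins written for this illustration. They are not the Paddle or PyTorch reducer, and the parameter names are made up. The sketch only models the "ready" bookkeeping described above.

```python
class ToyReducer:
    """Hypothetical stand-in for a DDP reducer (illustration only)."""

    def __init__(self, param_names):
        self.ready = {name: False for name in param_names}

    def mark_ready(self, name):
        # A gradient hook calls this when a parameter's gradient is produced.
        if self.ready[name]:
            raise RuntimeError(f"parameter[{name}@GRAD] has been ready before")
        self.ready[name] = True

    def reset(self):
        # Start of a new training step: nothing is ready yet.
        for name in self.ready:
            self.ready[name] = False


def run_backward(reducer, hook_firings):
    # hook_firings lists the order in which gradient hooks fire.
    for name in hook_firings:
        reducer.mark_ready(name)


reducer = ToyReducer(["linear.weight", "linear.bias"])

# Ordinary backward pass: each hook fires once, so nothing goes wrong.
run_backward(reducer, ["linear.bias", "linear.weight"])

# Assumed recompute scenario: the hook for one parameter fires twice.
reducer.reset()
try:
    run_backward(reducer, ["linear.weight", "linear.weight"])
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)
```

The second run fails on the repeated firing, which matches the "has been ready before" wording of the real error.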

Solution: Use no_sync() Context

Wrap the forward‑backward pass of each DDP replica with model.no_sync() to suppress automatic gradient synchronization. After the local gradients are computed, manually invoke a fused all‑reduce before the optimizer step. This pattern works in both Paddle (≥ 2.2) and PyTorch.

# required: distributed
import numpy
import paddle
import paddle.distributed as dist
from paddle.autograd import PyLayer
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

class cus_tanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        ctx.save_for_backward(y)
        return y
    @staticmethod
    def backward(ctx, dy):
        y, = ctx.saved_tensor()
        grad = dy * (1 - paddle.square(y))
        return grad

class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = paddle.nn.Linear(2, 2)
    def forward(self, inputs):
        inputs = cus_tanh.apply(inputs)
        return self.linear(inputs)

if __name__ == '__main__':
    dist.init_parallel_env()
    model = SimpleNet()
    model = paddle.DataParallel(model)
    opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
    for step in range(10):
        x_data = numpy.random.randn(2, 2).astype(numpy.float32)
        x = paddle.to_tensor(x_data)
        x.stop_gradient = False
        # step 1: skip gradient sync
        with model.no_sync():
            y_pred = model(x)
            loss = y_pred.mean()
            loss.backward()
        # step 2: manual all‑reduce then optimizer step
        fused_allreduce_gradients(list(model.parameters()), None)
        opt.step()
        opt.clear_grad()

This approach avoids the premature ready‑state of gradients, but it adds synchronization latency: reduction starts only after the entire local backward pass has finished, so communication can no longer overlap with backward computation.
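
Why deferring synchronization sidesteps the error can be sketched in plain Python. ToyReplica and all_reduce_mean are hypothetical names for this illustration, and the averaging across replicas is an assumption about how the manual reduction combines gradients. None of this is framework code.

```python
from contextlib import contextmanager


class ToyReplica:
    """Hypothetical one-parameter model replica (illustration only)."""

    def __init__(self):
        self.grad = 0.0
        self.ready = False
        self.sync_enabled = True

    @contextmanager
    def no_sync(self):
        # Temporarily disable the per-hook synchronization bookkeeping.
        previous = self.sync_enabled
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = previous

    def grad_hook(self, g):
        self.grad += g  # local accumulation always happens
        if self.sync_enabled:
            if self.ready:
                raise RuntimeError("gradient has been ready before")
            self.ready = True


def all_reduce_mean(replicas):
    # One manual reduction after every replica has finished its backward pass.
    mean = sum(r.grad for r in replicas) / len(replicas)
    for r in replicas:
        r.grad = mean


replicas = [ToyReplica(), ToyReplica()]
local_grads = [[1.0, 2.0], [3.0, 4.0]]  # each hook fires twice per backward pass

for replica, grads in zip(replicas, local_grads):
    with replica.no_sync():
        for g in grads:
            replica.grad_hook(g)

all_reduce_mean(replicas)
```

Inside no_sync the repeated hook firings only accumulate locally. The single reduction afterwards leaves every replica with the same averaged gradient, which the optimizer step can then consume.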

References

Key references include PyTorch’s torch.utils.checkpoint, Paddle’s recompute documentation, TensorFlow’s recompute API, and several GitHub issues discussing the interaction between recompute and DDP.
