Why Does Recompute Crash Distributed Training? A Deep Dive into Checkpoint Issues and Fixes
When training large-batch deep learning models, developers often use recompute to trade computation for memory. In dynamic-graph frameworks, however, combining recompute with distributed data parallel (DDP) training can trigger gradient-synchronization errors. This article explains the underlying DDP mechanics, illustrates the error, and offers a practical no_sync workaround with code examples.
Background
Large batch sizes are important for contrastive learning and other tasks, but GPU memory limits how large a batch can be. Practitioners therefore use the recompute (also called checkpointing) technique, which reduces memory consumption by discarding intermediate activations during the forward pass and recomputing them during back-propagation.
Basic Training Process
A typical training step consists of three stages:
Forward computation: operators process the inputs to produce outputs, ultimately yielding the loss.
Backward computation: gradients are calculated via the chain rule.
Optimization: the gradients are used to update the model parameters.
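The three stages can be made concrete with a minimal sketch in plain Python. It assumes a one-weight linear model with mean-squared-error loss and replaces a framework's autograd with a hand-derived gradient; the function name `train_step` and the toy data are illustrative only.

```python
def train_step(w, xs, ys, lr=0.1):
    """One training step for y_hat = w * x with mean-squared-error loss."""
    n = len(xs)
    # 1) Forward computation: predictions, then the scalar loss.
    preds = [w * x for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
    # 2) Backward computation: chain rule gives dL/dw = sum(2 * (p - y) * x) / n.
    grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
    # 3) Optimization: plain SGD update of the parameter.
    w = w - lr * grad_w
    return w, loss


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0]
    ys = [2.0, 4.0, 6.0]  # generated by the true weight w = 2
    w = 0.0
    for step in range(50):
        w, loss = train_step(w, xs, ys)
    print(round(w, 4))
```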
Recompute Mechanics
Recompute splits the network into segments. During the forward pass, only a small set of variables (the checkpoints) is kept; all other intermediate activations are discarded. In the backward pass, each segment's forward computation is run again from its checkpoint to restore the activations needed for the gradient calculation. This costs roughly one extra forward pass per training step in exchange for lower activation memory.
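The sketch below illustrates the recompute idea on a toy chain of scalar layers, y = tanh(a * x), in plain Python rather than any framework. `backward_full` keeps every activation; `backward_recompute` keeps only each segment's input and re-runs that segment's forward pass during backward. The layer form, the helper names, and the fixed segment length are assumptions made for illustration; the point is that both paths produce identical gradients while the recompute path holds fewer activations at once.

```python
import math


def layer(a, x):
    """Toy layer: y = tanh(a * x)."""
    return math.tanh(a * x)


def layer_backward(a, x, y, dy):
    """Chain rule for one layer: return (dL/dx, dL/da) given y = tanh(a * x)."""
    local = dy * (1.0 - y * y)  # dL/d(a*x), since tanh' = 1 - tanh^2
    return local * a, local * x


def backward_full(params, x0):
    """Baseline: keep every activation from the forward pass."""
    acts = [x0]
    for a in params:
        acts.append(layer(a, acts[-1]))
    dy, grads = 1.0, [0.0] * len(params)  # loss = final output, so dL/dy = 1
    for i in reversed(range(len(params))):
        dy, grads[i] = layer_backward(params[i], acts[i], acts[i + 1], dy)
    return dy, grads, len(acts)


def backward_recompute(params, x0, seg_len):
    """Keep only each segment's input (its checkpoint); rebuild the rest in backward."""
    n = len(params)
    checkpoints = {}
    x = x0
    for i, a in enumerate(params):
        if i % seg_len == 0:
            checkpoints[i] = x  # retained across the forward pass
        x = layer(a, x)         # every other intermediate is discarded
    dy, grads = 1.0, [0.0] * n
    peak = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + seg_len, n)
        # Extra forward pass: rebuild this segment's activations from its checkpoint.
        acts = [checkpoints[start]]
        for i in range(start, end):
            acts.append(layer(params[i], acts[-1]))
        # acts[0] is the checkpoint itself, so it is not counted twice.
        peak = max(peak, len(checkpoints) + len(acts) - 1)
        for i in reversed(range(start, end)):
            dy, grads[i] = layer_backward(params[i], acts[i - start], acts[i - start + 1], dy)
    return dy, grads, peak


if __name__ == "__main__":
    params = [0.5 + 0.1 * i for i in range(16)]
    dx_full, g_full, held_full = backward_full(params, 0.3)
    dx_rc, g_rc, held_rc = backward_recompute(params, 0.3, seg_len=4)
    assert dx_full == dx_rc and g_full == g_rc  # identical gradients
    print(held_full, held_rc)  # 17 8
```

With 16 layers and segments of 4, the baseline holds 17 activation values while the recompute path holds at most 8 (4 checkpoints plus the 4 rebuilt outputs of one segment). The usual rule of thumb follows from the same counting: with n layers and segment length k, roughly n/k + k values are held, which is smallest near k ≈ √n.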
Checkpoint Definition
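Checkpoints are the tensors kept in memory across the forward pass (typically the outputs of selected layers); everything between two checkpoints is recomputed. In both static- and dynamic-graph frameworks (e.g., PyTorch, TensorFlow, Paddle), users must explicitly specify the checkpoints. In Paddle's fleet API this is done by enabling dist_strategy.recompute = True and listing the checkpoints in dist_strategy.recompute_configs = {"checkpoints": model.checkpoints}.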
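Conceptually, the checkpoint list tells the framework where to cut the network into segments. The hypothetical helper `plan_segments` below (plain Python, not a Paddle API) shows one way such a list could map to segments: each segment ends at a checkpoint layer whose output is kept, and the other outputs inside the segment are the ones recompute drops and rebuilds. How a real framework treats the trailing segment after the last checkpoint may differ.

```python
def plan_segments(layer_names, checkpoints):
    """Hypothetical planner: cut an ordered list of layers at checkpoint layers.

    Each returned segment ends at a checkpoint layer (its output is kept);
    outputs of the remaining layers in the segment would be recomputed.
    """
    unknown = set(checkpoints) - set(layer_names)
    if unknown:
        raise ValueError(f"unknown checkpoints: {sorted(unknown)}")
    marks = set(checkpoints)
    segments, current = [], []
    for name in layer_names:
        current.append(name)
        if name in marks:  # close the segment at each checkpoint
            segments.append(current)
            current = []
    if current:  # layers after the last checkpoint form a trailing segment
        segments.append(current)
    return segments


if __name__ == "__main__":
    layers = ["embed", "block1", "block2", "block3", "block4", "head"]
    print(plan_segments(layers, ["block2", "block4"]))
```

Choosing checkpoints further apart means fewer stored tensors but longer segments to recompute, which is the same trade-off as the segment length in the previous sketch.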
Problem in Dynamic Graph Distributed Training
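Enabling recompute in dynamic-graph data-parallel training (for example, with Paddle's DataParallel) can make the backward pass fail with the error below. PyTorch's DistributedDataParallel hits the same problem and reports it as "Expected to mark a variable ready only once".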
Error happened, when parameter[385][xxxxx@GRAD] has been ready before. Please set find_unused_parameters=True to traverse backward graph in each step to prepare reduce in advance.
Understanding the Error
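Distributed Data Parallel (DDP) broadcasts the model parameters from rank 0 so that every replica starts identical, then builds a reducer that groups gradients into buckets. During back-propagation, a hook on each parameter marks its gradient as ready; once every gradient in a bucket is ready, that bucket is all-reduced, overlapping communication with the rest of the backward pass. The reducer expects each parameter to be marked ready exactly once per iteration. Recompute breaks this assumption: each recomputed segment runs its own backward pass, so a parameter reached by more than one of these passes (for example, one shared across checkpointed segments) has its hook fire more than once. The first firing marks the gradient as ready; the next one finds it already ready, and the reducer raises the error above.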
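The failure mode can be reproduced without any GPUs using a toy reducer. `ToyReducer` and `run_backward` are hypothetical stand-ins, not Paddle or PyTorch internals: they only model the rule that each parameter may be marked ready once per iteration and that a bucket is reduced once all of its parameters are ready. The second call simulates a recompute pass in which `w1` is reached by two separate backward passes.

```python
class ToyReducer:
    """Hypothetical single-process model of a DDP reducer's bookkeeping."""

    def __init__(self, buckets):
        self.buckets = [list(b) for b in buckets]
        self.reset()

    def reset(self):
        # Called at the start of every iteration.
        self.ready = set()
        self.reduced_buckets = []

    def mark_ready(self, name):
        # Each parameter may be marked ready only once per iteration.
        if name in self.ready:
            raise RuntimeError(f"parameter {name} has been ready before")
        self.ready.add(name)
        for idx, bucket in enumerate(self.buckets):
            if idx not in self.reduced_buckets and all(p in self.ready for p in bucket):
                # A real reducer would launch an all-reduce for this bucket here.
                self.reduced_buckets.append(idx)


def run_backward(reducer, hook_order):
    """Fire gradient hooks in the given order, as autograd would."""
    reducer.reset()
    for name in hook_order:
        reducer.mark_ready(name)
    return reducer.reduced_buckets


if __name__ == "__main__":
    reducer = ToyReducer([["w2", "b2"], ["w1", "b1"]])
    print(run_backward(reducer, ["w2", "b2", "w1", "b1"]))  # [0, 1]
    try:
        # w1 is hit again by a second (recomputed) backward pass.
        run_backward(reducer, ["w2", "b2", "w1", "b1", "w1"])
    except RuntimeError as err:
        print(err)  # parameter w1 has been ready before
```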
Solution: Use no_sync() Context
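Wrap each replica's forward and backward pass in model.no_sync() to suppress DDP's automatic gradient synchronization. Once the local gradients are computed, call a fused all-reduce manually before the optimizer step. This pattern works in Paddle (2.2 and later) and in PyTorch; PyTorch has no fused_allreduce_gradients helper, so the manual step there is typically torch.distributed.all_reduce on each parameter's .grad followed by division by the world size.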
```python
# required: distributed
import numpy
import paddle
import paddle.distributed as dist
from paddle.autograd import PyLayer
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients


class cus_tanh(PyLayer):
    # Custom autograd op; Paddle's dynamic-graph recompute is also built on PyLayer.
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        y, = ctx.saved_tensor()
        # d/dx tanh(x) = 1 - tanh(x)^2
        grad = dy * (1 - paddle.square(y))
        return grad


class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = paddle.nn.Linear(2, 2)

    def forward(self, inputs):
        inputs = cus_tanh.apply(inputs)
        return self.linear(inputs)


if __name__ == '__main__':
    dist.init_parallel_env()
    model = SimpleNet()
    model = paddle.DataParallel(model)
    opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
    for step in range(10):
        x_data = numpy.random.randn(2, 2).astype(numpy.float32)
        x = paddle.to_tensor(x_data)
        x.stop_gradient = False
        # step 1: skip DDP's automatic gradient synchronization
        with model.no_sync():
            y_pred = model(x)
            loss = y_pred.mean()
            loss.backward()
        # step 2: all-reduce the local gradients manually, then update
        fused_allreduce_gradients(list(model.parameters()), None)
        opt.step()
        opt.clear_grad()
```

This approach avoids marking gradients ready prematurely, but it adds synchronization latency: the all-reduce starts only after the entire local backward pass finishes, so communication no longer overlaps with gradient computation the way DDP's bucketed reducer normally allows.
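The arithmetic behind the workaround can be checked in plain Python. In this sketch, assumed for illustration, two simulated ranks each compute a mean-squared-error gradient on an equally sized shard without communicating (the no_sync phase), then a single averaged all-reduce runs before the update. `local_grad`, `allreduce_mean`, and `data_parallel_step` are hypothetical helpers; the real code uses a collective such as `fused_allreduce_gradients` in Paddle or `torch.distributed.all_reduce` in PyTorch.

```python
def local_grad(w, xs, ys):
    """Gradient of mean-squared error for y_hat = w * x on one rank's shard."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def allreduce_mean(values):
    """Stand-in for all-reduce(sum) followed by division by the world size."""
    total = sum(values)
    return [total / len(values)] * len(values)  # every rank gets the same result


def data_parallel_step(w, shards, lr=0.1):
    # Phase 1 (inside no_sync): each rank runs forward/backward, no communication.
    grads = [local_grad(w, xs, ys) for xs, ys in shards]
    # Phase 2: one explicit all-reduce after the whole local backward pass.
    synced = allreduce_mean(grads)
    # Every rank applies the identical update, so the replicas stay in sync.
    return [w - lr * g for g in synced]


if __name__ == "__main__":
    shards = [([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 8.0])]
    print(data_parallel_step(0.0, shards))
```

Averaging reproduces the full-batch gradient here because each rank's loss is a mean over a shard of the same size; with unequal shard sizes, a weighted average would be needed instead.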
References
Key references include PyTorch's torch.utils.checkpoint, Paddle's recompute documentation, TensorFlow's tf.recompute_grad, and several GitHub issues discussing the interaction between recompute and DDP.
