Also known as: activation checkpointing, rematerialization, remat, checkpoint_segments
TL;DR
Trade compute for memory by recomputing forward activations during the backward pass instead of storing them. Roughly 5x memory savings on activations at a cost of ~30% slower training.
Gradient checkpointing, also called activation checkpointing or rematerialization, is the standard trick for cutting activation memory: instead of keeping every activation, you recompute most of them on the fly during the backward pass. In a normal forward pass, every intermediate activation is stored so that backpropagation can use it to compute gradients. With checkpointing, you store only a sparse subset (the “checkpoints”), discard the rest, and redo the forward computation for each segment when its turn comes during backward. The result is roughly 5x lower activation memory at the cost of running the forward pass approximately twice, a typical training-throughput hit of 25-35%.
The trade is almost always worth it at scale. Frontier-scale training runs, long-context fine-tunes, and any job whose activation memory exceeds its weight memory turn checkpointing on by default.
Why activation memory dominates
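A transformer’s resident memory during training is roughly:
Parameters: fixed by model size.
Optimizer state: fixed by model size (4-8x parameter bytes for Adam).
Activations: roughly $O(B \cdot S \cdot N \cdot h)$ for batch size $B$, sequence length $S$, $N$ layers, and hidden size $h$, plus an $O(S^2)$ attention term per head and layer when the attention matrix is materialized.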
Activations scale with batch, sequence length, and depth, while parameters and optimizer state scale with neither batch nor sequence length. At long context the activation term dominates: for a 7B model at 128K context with batch 1, activations are bigger than the model itself. FSDP shards parameters, gradients, and optimizer state across ranks, but each rank still holds the full activations for its own micro-batch. That is exactly the gap checkpointing fills.
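To make the 7B-at-128K claim concrete, here is a rough estimator. It assumes a commonly cited coefficient of about 34 bytes of 16-bit activations per token, per hidden unit, per layer for a block that does not materialize the attention matrix (the FlashAttention case), and a hypothetical Llama-style 7B shape of 32 layers with hidden size 4096. Both are assumptions for illustration, not measurements.

```python
def activation_bytes(batch, seq_len, n_layers, hidden, bytes_coeff=34):
    """Rough activation memory for one full forward pass, no checkpointing.

    bytes_coeff (~34 bytes per token x hidden unit x layer, 16-bit activations,
    no materialized attention matrix) is an assumed estimate, not a measurement.
    """
    return bytes_coeff * batch * seq_len * hidden * n_layers


def param_bytes(n_params, bytes_per_param=2):
    """bf16 weights: 2 bytes per parameter, independent of batch and sequence."""
    return n_params * bytes_per_param


# Hypothetical Llama-style 7B shape: 32 layers, hidden size 4096.
acts = activation_bytes(batch=1, seq_len=128 * 1024, n_layers=32, hidden=4096)
weights = param_bytes(7e9)
print(f"activations ~{acts / 1e9:.0f} GB vs bf16 weights ~{weights / 1e9:.0f} GB")
```

Under these assumptions the activations come out around 584 GB against 14 GB of bf16 weights, and doubling either batch or sequence length doubles the first number while leaving the second untouched.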
How the recomputation works
A normal training step does:
Forward. Run input through every layer, store all intermediate activations.
Backward. Walk the layers in reverse. Each layer’s gradient computation reads the stored activation it needs.
Checkpointing splits the model into segments (typically one transformer block per segment). The modified step:
Forward. Run input through every segment. At each segment boundary, store the segment input (the checkpoint) and discard every internal activation. Resident activation memory drops from every internal activation to one stored input per segment.
Backward. Walk segments in reverse. For each segment, re-run the forward pass on it using the stored input, regenerating all internal activations. Then run the backward pass on that segment as normal, freeing the regenerated activations afterward.
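Both procedures can be sketched in plain Python on a toy chain of scalar ops, each computing `w * sin(x)`. Think of each toy op as one internal operation of a block and each segment as a block. This is a hypothetical illustration of the mechanism, not how `torch.utils.checkpoint` is implemented: the plain step saves every op input, while the checkpointed step saves only segment inputs and replays each segment's forward right before running its backward.

```python
import math


def layer_forward(w, x):
    """Toy op (one internal operation of a block): y = w * sin(x)."""
    return w * math.sin(x)


def layer_backward(w, x, grad_y):
    """Backward of one toy op from its saved input x. Returns (grad_x, grad_w)."""
    return grad_y * w * math.cos(x), grad_y * math.sin(x)


def step_plain(weights, x0):
    """No checkpointing: keep every op input, read them back in backward."""
    acts = [x0]
    for w in weights:
        acts.append(layer_forward(w, acts[-1]))
    fwd_calls = len(weights)
    peak_stored = len(acts) - 1  # inputs of all ops; acts[-1] is the output
    grad, grads_w = 1.0, [0.0] * len(weights)  # toy loss = final output
    for i in reversed(range(len(weights))):
        grad, grads_w[i] = layer_backward(weights[i], acts[i], grad)
    return grads_w, fwd_calls, peak_stored


def step_checkpointed(weights, x0, seg_len):
    """Checkpointing: keep only segment inputs, replay each segment in backward."""
    n = len(weights)
    starts = list(range(0, n, seg_len))
    checkpoints, x, fwd_calls = {}, x0, 0
    for s in starts:
        checkpoints[s] = x  # store the segment input; internals are dropped
        for i in range(s, min(s + seg_len, n)):
            x = layer_forward(weights[i], x)
            fwd_calls += 1
    grad, grads_w, peak_stored = 1.0, [0.0] * n, len(checkpoints)
    for s in reversed(starts):
        end = min(s + seg_len, n)
        acts = [checkpoints[s]]
        for i in range(s, end - 1):  # recompute inputs of ops s+1 .. end-1
            acts.append(layer_forward(weights[i], acts[-1]))
            fwd_calls += 1
        # live now: remaining checkpoints plus this segment's recomputed inputs
        peak_stored = max(peak_stored, len(checkpoints) + len(acts) - 1)
        for i in reversed(range(s, end)):
            grad, grads_w[i] = layer_backward(weights[i], acts[i - s], grad)
        del checkpoints[s]  # segment finished: free its checkpoint
    return grads_w, fwd_calls, peak_stored


weights = [0.5 + 0.1 * i for i in range(16)]
g_plain, f_plain, m_plain = step_plain(weights, 0.3)
g_ckpt, f_ckpt, m_ckpt = step_checkpointed(weights, 0.3, seg_len=4)
assert g_plain == g_ckpt  # replay is deterministic, so gradients match exactly
print("plain:", f_plain, "forward calls,", m_plain, "saved inputs at peak")
print("ckpt: ", f_ckpt, "forward calls,", m_ckpt, "saved inputs at peak")
```

With 16 toy ops and segments of 4, the checkpointed step holds 7 saved inputs at peak instead of 16, in exchange for 12 extra forward calls (28 vs 16).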
The forward pass is effectively done twice — once during the original forward, once again during backward, segment by segment. That is where the ~30% throughput hit comes from. The memory win comes from never holding all internal activations at once: at any moment during backward, only one segment’s activations are live.
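The ~30% figure can be checked with a rule of thumb that is assumed here rather than measured: a backward pass costs roughly twice a forward pass $F$, since it computes gradients with respect to both the activations and the weights.

$$
T_{\text{plain}} \approx F + 2F = 3F,
\qquad
T_{\text{ckpt}} \approx F + \underbrace{F}_{\text{recompute}} + 2F = 4F,
\qquad
\frac{T_{\text{ckpt}}}{T_{\text{plain}}} \approx \frac{4}{3} \approx 1.33
$$

That lands at the upper end of the 25-35% range quoted above.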
A common pattern: checkpoint every transformer block. With $N$ blocks, peak activation memory drops from $O(N)$ to $O(\sqrt{N})$ if you checkpoint roughly $\sqrt{N}$ evenly spaced blocks (Chen et al., 2016), or to a single block’s internal activations plus the $N$ stored block inputs if you checkpoint every block. In practice “every block” is the easiest to reason about, and it is the usual granularity when wrapping blocks with PyTorch’s torch.utils.checkpoint or FSDP’s checkpoint wrapper.
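Why evenly spaced checkpoints land at $\sqrt{N}$ can be seen with a simple uniform-chain cost model, which is an assumption: every op input costs one unit, segment length $k$ means holding $\lceil N/k \rceil$ checkpoints, and during backward you additionally hold the $k - 1$ recomputed inputs of the current segment. The sketch brute-forces $k$.

```python
import math


def peak_units(n_layers, seg_len):
    """Peak saved activations under an assumed uniform-chain cost model:
    ceil(n/k) segment checkpoints plus the k - 1 inputs recomputed inside
    the segment currently in backward."""
    return math.ceil(n_layers / seg_len) + seg_len - 1


def best_segment_length(n_layers):
    """Brute-force the segment length k that minimizes peak_units."""
    return min(range(1, n_layers + 1), key=lambda k: peak_units(n_layers, k))


for n in (16, 64, 100):
    k = best_segment_length(n)
    print(f"N={n}: best k={k}, peak={peak_units(n, k)} units (vs {n} without checkpointing)")
```

In this model both extremes ($k = 1$ and $k = N$) cost $N$ units, and the minimum sits at $k = \sqrt{N}$ for perfect squares, which is the Chen et al. trade-off in miniature.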
What you actually save
A back-of-envelope for a 70B-parameter dense transformer at 8K context, bf16, batch 4, no checkpointing:
Parameters + grads + Adam state: ~1.1 TB (sharded across DP ranks via FSDP).
Activations: ~120 GB resident on each rank during forward.
With checkpoint-every-block:
Activations during forward: ~10 GB.
Activations during backward: the stored checkpoints plus one block’s recomputed activations at a time, ~3-5 GB peak for the recomputed block.
That is the difference between “training fits” and “training OOMs on H100” at this size. The throughput cost is real (recomputing the forward pass once more makes each step roughly 1.3x as long), but the alternative is reducing batch size, which typically costs more throughput than the recompute does.
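The ~1.1 TB model-state figure can be reproduced with the usual mixed-precision Adam accounting, assumed here: bf16 weights and gradients at 2 bytes each, plus fp32 master weights and two fp32 Adam moments at 4 bytes each, for 16 bytes per parameter. The 64-rank FSDP world size is hypothetical, chosen only to show why per-rank activations end up dominating once model state is sharded.

```python
# Assumed mixed-precision Adam layout, bytes per parameter.
BYTES_PER_PARAM = {
    "bf16 weights": 2,
    "bf16 grads": 2,
    "fp32 master weights": 4,
    "adam m (fp32)": 4,
    "adam v (fp32)": 4,
}


def model_state_bytes(n_params):
    """Parameters + gradients + Adam state, before any sharding."""
    return n_params * sum(BYTES_PER_PARAM.values())


total = model_state_bytes(70e9)
dp_ranks = 64  # hypothetical FSDP world size
per_rank = total / dp_ranks  # FSDP shards model state evenly across ranks
activations_per_rank = 120e9  # the no-checkpointing estimate quoted above; not sharded
print(f"model state: {total / 1e12:.2f} TB total, {per_rank / 1e9:.1f} GB per rank")
print(f"activations: {activations_per_rank / 1e9:.0f} GB per rank")
```

Sharded 64 ways, model state is about 17.5 GB per rank, so the unsharded activations are what push the rank past 80 GB.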
How it composes
Gradient checkpointing stacks cleanly on top of the other major performance tricks. Mixed precision stores activations in bf16, so the checkpoints are already half the size they would be in fp32. FlashAttention never materializes the full attention matrix (whose size grows with the square of the sequence length) and recomputes the softmax statistics in its own backward, so the attention sub-step holds little activation memory to begin with; most of what a checkpointed block saves comes from the FFN, projection, and norm activations. FSDP shards model state but leaves each rank’s activations unsharded, so checkpointing is the natural pair to FSDP for activation savings.
When to leave it off
Gradient checkpointing’s 30% throughput tax is unwelcome when activation memory is not the bottleneck. Small models, short sequences, and inference do not benefit. For inference specifically, you don’t run a backward pass at all — there is nothing to checkpoint. The decision is purely a training-time, memory-vs-compute trade.
Activation checkpointing was the unloved performance trick of the 2017-2020 era and became a frontier-training default the moment context lengths started growing. It is also the easiest knob to flip when a job OOMs at 99% memory: in PyTorch it amounts to wrapping each transformer block in a checkpoint call. The cost is predictable, the savings are large, and it composes cleanly with the rest of the stack.
Go further
What's the optimal checkpointing pattern for a transformer?
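Checkpoint each transformer block. Chen et al. (2016) showed that for a chain of N roughly uniform layers, checkpointing about every √N-th layer cuts activation memory to O(√N) for about one extra forward pass, a near-optimal trade-off under that uniform-cost assumption, which transformer stacks of near-identical blocks closely match. Per-block checkpointing sits close to that trade-off in practice because a block’s stored input is small next to its internal activations. More aggressive schemes save more memory but cost more than one extra forward pass: Chen et al.’s recursive variant reaches O(log N) memory at O(N log N) forward compute.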
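How does it interact with FlashAttention?
FlashAttention already does its own recomputation: it never materializes the full attention matrix (whose size grows with the square of the sequence length) and recomputes the softmax statistics during backward. So when you checkpoint a transformer block that uses FlashAttention, attention contributes little to the memory you would otherwise hold. Block-level checkpointing still replays the attention forward during recompute, but the bulk of what gets discarded and rematerialized is the FFN, projection, and norm activations. This is one reason the modern FlashAttention + checkpointing stack works so well at long context.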
Should I always turn it on?
Not for small models that fit comfortably. The 30% throughput hit is real and unnecessary if activation memory is not the bottleneck. The decision rule: enable it when you are activation-memory bound (long context, large batch, large model) and disable it when you are weight-memory bound (FSDP fixes that one) or purely compute bound (where you would rather spend the cycles on more steps).
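The rule can be written down as a tiny helper. This is a hypothetical encoding for illustration, not a tool from any library; it takes per-rank memory figures in GB after any FSDP sharding has been applied.

```python
def checkpointing_advice(weights_gb, optimizer_gb, activations_gb, device_gb):
    """Hypothetical encoding of the decision rule above (per-rank GB, post-sharding)."""
    model_state = weights_gb + optimizer_gb
    if model_state + activations_gb <= device_gb:
        # Everything fits: recompute would only cost throughput.
        return "leave it off: everything fits"
    if activations_gb >= model_state:
        return "turn it on: activation-memory bound"
    # Model state is the larger term, so shard it before paying for recompute.
    return "shard model state first (e.g. FSDP): weight-memory bound"


print(checkpointing_advice(weights_gb=2, optimizer_gb=15, activations_gb=120, device_gb=80))
```

The ordering matters: the "fits" check comes first so that a small job never pays the recompute tax just because its activations happen to outweigh its weights.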