Also known as: backprop, reverse-mode autodiff, automatic differentiation
TL;DR
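Backpropagation is the chain-rule application that computes the gradient of the loss with respect to every parameter in a neural network. A forward pass produces predictions and a loss; a backward pass then carries gradients from the loss back through the network to every parameter.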
Backpropagation is the algorithm that answers one question: given a neural network’s loss $L$, what is $\partial L / \partial \theta$ for every parameter $\theta$? It’s the chain rule from calculus, applied recursively along a computation graph, and it’s the engine that makes gradient descent tractable. Without backprop, modern deep learning would not exist.
A full training step is two passes followed by a parameter update:
Forward pass. Run the input through the network, layer by layer, producing predictions and a scalar loss. Cache every intermediate activation.
Backward pass. Starting from $\partial L / \partial L = 1$ at the output, walk the graph in reverse. At each operation, multiply the upstream gradient by the operation’s local Jacobian. End up with $\partial L / \partial \theta$ for every parameter $\theta$.
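Update. Apply the optimizer to step the parameters (for plain gradient descent, $\theta \leftarrow \theta - \eta \, \partial L / \partial \theta$ with learning rate $\eta$), then repeat.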
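Here is a minimal sketch of that loop in plain Python, assuming a hypothetical two-parameter model $\hat{y} = wx + b$ trained with mean squared error and plain gradient descent. The backward pass is written out by hand so every chain-rule step is visible; a framework would derive it for you.

```python
# Toy data generated from y = 2x + 1 (hypothetical example data).
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [2.0 * x + 1.0 for x in xs]

def train_step(w, b, lr=0.02):
    n = len(xs)
    # Forward pass: predictions, then a scalar loss (mean squared error).
    # The residuals are cached because the backward pass reuses them.
    preds = [w * x + b for x in xs]
    residuals = [p - y for p, y in zip(preds, ys)]
    loss = sum(r * r for r in residuals) / n

    # Backward pass: dL/dpred_i = 2 * r_i / n, then chain through pred = w*x + b.
    dpred = [2.0 * r / n for r in residuals]
    dw = sum(g * x for g, x in zip(dpred, xs))  # dpred/dw = x
    db = sum(dpred)                              # dpred/db = 1

    # Update: one plain gradient-descent step.
    return w - lr * dw, b - lr * db, loss

w, b = 0.0, 0.0
for step in range(2000):
    w, b, loss = train_step(w, b)
print(round(w, 3), round(b, 3))
```

It should print `2.0 1.0`, recovering the line the toy data came from.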
The chain rule, mechanically
If $y = f(x)$ and $L = g(y)$, then $\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x}$. Stack that pattern for hundreds of layers and you get backprop. The key trick is reverse mode: by propagating gradients from output to input rather than from input to output, you obtain the gradient of one scalar (the loss) with respect to all parameters in a single backward pass. Backprop costs roughly 2-3x the forward pass, and that ratio stays constant no matter how many parameters there are.
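A worked sketch of that pattern, assuming a made-up chain $x \to a = x^2 \to b = \sin a \to L = 3b$: the backward function seeds $\partial L/\partial L = 1$ and multiplies in one local derivative per stage, and a central finite difference serves as an independent check.

```python
import math

# Hypothetical three-stage chain: x -> a = x**2 -> b = sin(a) -> L = 3*b.
def forward(x):
    a = x * x
    b = math.sin(a)
    L = 3.0 * b
    return a, b, L  # cache intermediates for the backward pass

def backward(x, a):
    dL_dL = 1.0                  # seed: gradient of the loss w.r.t. itself
    dL_db = dL_dL * 3.0          # local derivative of L = 3b
    dL_da = dL_db * math.cos(a)  # local derivative of b = sin(a)
    dL_dx = dL_da * 2.0 * x      # local derivative of a = x^2
    return dL_dx

x = 0.7
a, b, L = forward(x)
grad = backward(x, a)

# Central finite difference as an independent check of the chain rule.
eps = 1e-6
numeric = (forward(x + eps)[2] - forward(x - eps)[2]) / (2 * eps)
print(grad, numeric)
```

The two printed numbers should agree closely; the small residual difference comes from the finite-difference approximation.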
That’s why training a 70B model is feasible at all. Forward-mode autodiff needs one pass per input (here, one per parameter); reverse mode needs one pass per output (here, just the scalar loss).
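To see the scaling argument in code, here is a sketch of forward-mode autodiff with dual numbers, assuming a hypothetical function $f(x) = \sum_i x_i x_{i+1}$. Each forward pass carries exactly one directional derivative, so recovering the full gradient takes one pass per input. The `Dual` class and helper names are illustrative, not any library's API.

```python
class Dual:
    """A number val + der*eps with eps**2 == 0; der carries one directional derivative."""
    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.der + other.der)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (a + a'eps)(b + b'eps) = ab + (a'b + ab')eps
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)
    __rmul__ = __mul__

def f(xs):
    # Hypothetical scalar "loss": sum of products of neighbouring inputs.
    total = 0.0
    for a, b in zip(xs, xs[1:]):
        total = total + a * b
    return total

def forward_mode_gradient(values):
    grad, passes = [], 0
    for i in range(len(values)):
        # Seed derivative 1 on input i only: one full forward pass per input.
        xs = [Dual(v, 1.0 if j == i else 0.0) for j, v in enumerate(values)]
        grad.append(f(xs).der)
        passes += 1
    return grad, passes

values = [1.0, 2.0, 3.0, 4.0]
grad, passes = forward_mode_gradient(values)
print(grad, passes)  # [2.0, 4.0, 6.0, 3.0] 4
```

With four inputs that is four passes; with 70 billion parameters it would be 70 billion. Reverse mode recovers all partial derivatives from one backward pass over the same graph.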
What an autodiff framework does
PyTorch, JAX, and TensorFlow all implement reverse-mode autodiff. During the forward pass, the framework records every tensor operation into a directed acyclic graph, the “computation graph” or “tape.” When you call .backward() on the loss (PyTorch’s spelling; JAX uses jax.grad and TensorFlow uses tf.GradientTape), the framework walks the graph in reverse, applying each operation’s hand-coded local gradient rule.
You write the forward pass; the framework derives the backward pass by traversing the recorded graph in reverse. This is the single largest productivity win of modern ML frameworks over the pre-2015 era of hand-derived gradients.
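A minimal sketch of the tape idea, loosely in the spirit of PyTorch's autograd but not its actual implementation: the hypothetical `Value` class records each node's parents and local derivatives during the forward pass, and its `backward()` topologically sorts that graph and walks it in reverse, accumulating upstream-times-local gradients.

```python
import math

class Value:
    """A scalar node in a recorded computation graph (illustrative only)."""
    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # nodes this value was computed from
        self._local_grads = local_grads  # d(self)/d(parent), one per parent

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def tanh(self):
        t = math.tanh(self.data)
        return Value(t, (self,), (1.0 - t * t,))

    def backward(self):
        # Topologically sort the graph recorded during the forward pass.
        order, seen = [], set()
        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for p in node._parents:
                    visit(p)
                order.append(node)
        visit(self)
        # Seed dL/dL = 1, then walk in reverse: parent.grad += upstream * local.
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in zip(node._parents, node._local_grads):
                parent.grad += node.grad * local

# Forward pass: a single tanh neuron, L = tanh(w*x + b).
w, x, b = Value(0.5), Value(2.0), Value(-0.3)
L = (w * x + b).tanh()
L.backward()
print(w.grad, x.grad, b.grad)
```

The gradients come out as $2(1-t^2)$, $0.5(1-t^2)$, and $1-t^2$ with $t = \tanh(0.7)$. The `+=` accumulation matters: if a node feeds several consumers, its gradient is the sum of their contributions.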
To compute $\partial L / \partial \theta_\ell$ for a parameter $\theta_\ell$ in layer $\ell$, the chain rule needs every intermediate activation between layer $\ell$ and the output. So the backward pass requires the full set of activations cached during the forward pass.
For a transformer with $N$ layers, batch size $B$, sequence length $S$, and hidden size $H$, the activation memory is $O(N \cdot B \cdot S \cdot H)$, and at long context this term often dominates the parameter memory itself. Hence the proliferation of memory-saving tricks:
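A back-of-the-envelope sketch of that scaling for $N$ layers, batch $B$, sequence length $S$, and hidden size $H$. The configuration below is hypothetical, and `tensors_per_layer` is an assumed multiplier: real layers cache several hidden-sized tensors each, so actual totals are a multiple of this figure.

```python
def activation_bytes(n_layers, batch, seq_len, hidden,
                     tensors_per_layer=1, bytes_per_elem=2):
    """Rough activation-memory estimate, proportional to N * B * S * H.

    tensors_per_layer is a hypothetical multiplier; the true count of cached
    hidden-sized tensors per layer depends on architecture and implementation.
    bytes_per_elem=2 assumes float16 / bfloat16 storage.
    """
    return n_layers * batch * seq_len * hidden * tensors_per_layer * bytes_per_elem

GiB = 2 ** 30
# Hypothetical configuration: 32 layers, batch 8, 4096 tokens, hidden 4096, bf16.
one_copy = activation_bytes(32, 8, 4096, 4096)
print(one_copy / GiB)  # 8.0
# Doubling the context length doubles the estimate.
print(activation_bytes(32, 8, 8192, 4096) / GiB)  # 16.0
```

Even a single bf16 tensor per layer comes to 8 GiB here, and the figure grows linearly with context length.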
Activation checkpointing: store only a subset of activations and recompute the rest during the backward pass. Trades compute for memory.
Mixed precision: store activations in float16 / bfloat16 instead of float32.
ZeRO and FSDP: shard optimizer states, gradients, and (at the highest ZeRO stage, or in FSDP) parameters across GPUs instead of replicating them on each one.
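A minimal sketch of activation checkpointing on a hypothetical chain of scalar $\tanh$ layers, treating the last activation as the loss. The full version stores every layer input; the checkpointed version keeps only every k-th one and recomputes each segment during the backward pass. PyTorch applies the same idea to tensors via `torch.utils.checkpoint`; the functions below are illustrative only.

```python
import math

def layer(w, x):
    return math.tanh(w * x)

def layer_backward(w, x, upstream):
    # y = tanh(w*x): dy/dx = w*(1 - y^2), dy/dw = x*(1 - y^2).
    # y is recomputed from the cached input x for brevity.
    y = math.tanh(w * x)
    d = upstream * (1.0 - y * y)
    return d * w, d * x  # (gradient w.r.t. input, gradient w.r.t. weight)

def grads_full(ws, x0):
    acts = [x0]  # keep every layer's input (plus the final output)
    for w in ws:
        acts.append(layer(w, acts[-1]))
    g, gws = 1.0, [0.0] * len(ws)  # seed d(output)/d(output) = 1
    for i in reversed(range(len(ws))):
        g, gws[i] = layer_backward(ws[i], acts[i], g)
    return gws, len(acts)

def grads_checkpointed(ws, x0, k):
    ckpts = {0: x0}  # keep only every k-th layer's input
    x = x0
    for i, w in enumerate(ws):
        x = layer(w, x)
        if (i + 1) % k == 0 and i + 1 < len(ws):
            ckpts[i + 1] = x
    g, gws = 1.0, [0.0] * len(ws)
    for start in sorted(ckpts, reverse=True):
        end = min(start + k, len(ws))
        seg = [ckpts[start]]  # recompute this segment's inputs from its checkpoint
        for i in range(start, end - 1):
            seg.append(layer(ws[i], seg[-1]))
        for i in reversed(range(start, end)):
            g, gws[i] = layer_backward(ws[i], seg[i - start], g)
    return gws, len(ckpts)

ws = [0.9, -1.1, 0.5, 1.3, -0.7, 0.8, 1.0, -0.4, 0.6, 1.2, -0.9, 0.3]
full, stored_full = grads_full(ws, 0.5)
ckpt, stored_ckpt = grads_checkpointed(ws, 0.5, k=4)
print(stored_full, stored_ckpt)  # 13 3
```

Both versions produce the same gradients. The checkpointed one holds 3 checkpoints plus at most one k-long segment at a time instead of 13 stored values, at the cost of recomputing roughly one extra forward pass.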
These are all responses to the fact that backprop is memory-hungry by design.
Where it can go wrong
Three failure modes haunt backprop in deep networks:
Vanishing gradients: gradients shrink toward zero as they flow backward through many layers, especially with saturating activations (sigmoid, tanh). Residual connections and ReLU-family activations became the standard fixes.
Exploding gradients: gradients grow without bound, often in RNNs and very deep transformers. The standard defense is gradient clipping, which rescales the gradient $g$ whenever its norm exceeds a threshold $c$: $g \leftarrow g \cdot \min(1, c / \lVert g \rVert)$.
NaN propagation: a single inf or nan anywhere in the backward pass corrupts the entire gradient. Modern frameworks include detection hooks for this, such as PyTorch’s torch.autograd.set_detect_anomaly.
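A small sketch of all three failure modes in plain Python, with hypothetical helper names: a chain of sigmoids whose gradient is a product of local derivatives no larger than 0.25, clipping by global norm ($g \leftarrow g \cdot \min(1, c/\lVert g \rVert)$, the rule PyTorch's `torch.nn.utils.clip_grad_norm_` applies), and a finite-value check of the kind you might run before an optimizer step.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_chain_grad(x, depth):
    # Hypothetical chain x_{l+1} = sigmoid(x_l) with unit weights.
    # Returns d x_depth / d x_0 as a product of local derivatives.
    g = 1.0
    for _ in range(depth):
        s = sigmoid(x)
        g *= s * (1.0 - s)  # sigmoid'(x) = s(1 - s), never above 0.25
        x = s
    return g

def clip_by_global_norm(grads, max_norm, eps=1e-12):
    # g <- g * min(1, c / ||g||); eps guards against division by zero.
    norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (norm + eps))
    return [g * scale for g in grads], norm

def has_non_finite(grads):
    # Any inf or nan here would poison the whole update.
    return any(not math.isfinite(g) for g in grads)

print(sigmoid_chain_grad(0.0, 5), sigmoid_chain_grad(0.0, 20))
print(clip_by_global_norm([3.0, 4.0], max_norm=1.0))
print(has_non_finite([0.1, float("inf"), -0.2]))  # True
```

At depth 20 the chained gradient is already below $0.25^{20} \approx 9 \times 10^{-13}$, which is why deep sigmoid stacks stall without residual paths.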
Go further
Why is backprop O(forward pass), not O(parameters)?
Reverse-mode automatic differentiation amortizes the chain rule by computing the gradient of one scalar output (the loss) with respect to many inputs (the parameters). You walk the graph backward exactly once, reusing every cached intermediate. Total cost is roughly 2-3x the forward pass, a ratio that does not depend on how many parameters there are.
How does an autodiff framework compute gradients?
It records every operation during the forward pass into a directed acyclic graph (the 'tape'), then walks that graph in reverse during .backward(), applying each operation's known local derivative. PyTorch, JAX, and TensorFlow are all variations on this pattern, differing mainly in when the graph is built (eager vs traced).
Why does memory blow up during training but not inference?
Backprop needs every intermediate activation from the forward pass to compute gradients. For a transformer, that's O(layers × batch × seq_len × hidden) of stored tensors per training step — often the dominant memory cost. Inference can free activations layer-by-layer because there's no backward pass.