Mixed-Precision Training

Also known as: AMP, automatic mixed precision, bf16 training, fp16 training

TL;DR

Train with bf16 or fp16 activations and weights instead of fp32, while keeping master weights and optimizer accumulations in fp32 for numerical stability.

[Figure: Mixed-precision training, fp32 master / fp16 compute. Panel 1, one training step as a closed loop: fp32 master weights θ_fp32 are cast down to fp16; the forward pass runs its matmuls in fp16; the loss is multiplied by S to clear the fp16 floor; the backward pass computes ∇θ in fp16; gradients are unscaled (÷ S) and cast up to fp32; the optimizer step, with Adam moments in fp32, updates the fp32 master weights. Panel 2, why the × S exists (fp16 underflow): gradient magnitudes from 1e-8 to 10 on a log axis against an fp16 floor of ≈ 6e-5, before and after scaling; S = 2^16 shifts the gradients above the floor. Caption: fp32 owns precision · fp16 owns throughput · ×S bridges the gap.]

Mixed-precision training is the standard practice of running a neural-network training loop with most tensors in 16-bit floating point (bf16 today, fp16 historically) while keeping a small set of numerically sensitive tensors in fp32. It buys two things. First, roughly 2x lower memory for activations and every other tensor stored in 16 bits, since half-precision tensors are half the bytes. Second, 2-4x higher compute throughput on hardware whose tensor cores are tuned for low-precision matmul (every NVIDIA GPU since Volta, every TPU, every modern AMD MI-class accelerator). Mixed precision has been the norm in essentially every serious training run since 2019; full-fp32 training is now an exotic choice.

What stays in fp32

Three things, every time:

  1. Master weights. A canonical fp32 copy of the parameters, owned by the optimizer. Forward / backward read a bf16 cast; the optimizer step updates fp32 and then re-casts to bf16. This avoids the cumulative weight drift that would otherwise come from doing every parameter update in low precision. In pure bf16, an update smaller than half the bf16 spacing around a weight rounds away entirely.
  2. Optimizer accumulators. Adam’s m and v, RMSProp’s running second moment, etc. These are running averages of gradients and squared gradients accumulated over thousands of steps. Each step changes them by only a small fraction, and bf16's precision is too coarse to register that change reliably.
  3. Loss and softmax internals. A bf16 cross-entropy loss is fine if you upcast the logits to fp32 before the log-softmax; a softmax computed entirely in bf16 can silently lose the small probabilities. Most production frameworks handle this via autocast policies.
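
The master-weights effect can be reproduced with a toy experiment. This is a minimal sketch under stated assumptions, not any framework's implementation:
- bf16 and fp32 rounding are emulated in pure Python with `struct`. bf16 is taken as the top 16 bits of an fp32 bit pattern, rounded to nearest even.
- The "optimizer" is a hypothetical constant update of 1e-4 per step, applied to a weight of 1.0.

```python
import struct

def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round to bf16: keep the top 16 bits of the fp32 pattern, round-to-nearest-even.
    NaN is not handled specially; illustration only."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    lsb = (bits >> 16) & 1                      # lowest bit that survives truncation
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000   # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

STEPS, UPDATE = 1_000, 1e-4   # hypothetical: constant lr * grad applied each step

# Variant 1: the weight itself lives in bf16 and is updated in bf16.
w_bf16 = to_bf16(1.0)
for _ in range(STEPS):
    # 1e-4 is far below half the bf16 spacing near 1.0 (2**-7 / 2 = 2**-8), so it rounds away
    w_bf16 = to_bf16(w_bf16 + UPDATE)

# Variant 2: fp32 master weight; the model only ever sees a bf16 cast of it.
master = to_fp32(1.0)
for _ in range(STEPS):
    master = to_fp32(master + UPDATE)   # optimizer step applied to the fp32 master copy
    w_model = to_bf16(master)           # re-cast for the next forward / backward

print(w_bf16)    # 1.0: every update was lost
print(master)    # ~1.1
print(w_model)   # bf16 view of ~1.1
```

The bf16-only weight never moves. The fp32 master absorbs every step, and its bf16 cast tracks it to within one bf16 spacing.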

Everything else — activations, the gradient flowing back through them, parameter casts used during forward / backward, intermediate matmul results — lives in bf16. That is where the memory savings and the tensor-core speedup come from.
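
A per-parameter byte count shows where those savings land. The sketch assumes the commonly cited Adam accounting:
- Mixed precision: bf16 weights and gradients, plus an fp32 master copy and fp32 m and v.
- Baseline: everything in fp32.

Real frameworks vary; some, for instance, keep gradients in fp32. The model size and activation count are hypothetical round numbers, not measurements.

```python
BYTES = {"fp32": 4, "bf16": 2}

def model_state_bytes_per_param(mixed: bool) -> int:
    """Weights + gradients + Adam state, in bytes per parameter (assumed accounting)."""
    if mixed:
        # bf16 working weights and grads, plus fp32 master weights and fp32 Adam m, v
        return 2 * BYTES["bf16"] + 3 * BYTES["fp32"]
    # fp32 weights, fp32 grads, fp32 Adam m and v
    return 4 * BYTES["fp32"]

def activation_bytes(num_activations: int, mixed: bool) -> int:
    """Bytes for stored activations: bf16 under mixed precision, fp32 otherwise."""
    return num_activations * (BYTES["bf16"] if mixed else BYTES["fp32"])

PARAMS = 1_000_000_000         # hypothetical 1B-parameter model
ACTIVATIONS = 20_000_000_000   # hypothetical number of activation values kept per step

for mixed in (False, True):
    state = PARAMS * model_state_bytes_per_param(mixed)
    acts = activation_bytes(ACTIVATIONS, mixed)
    print(f"mixed={mixed}: model state {state / 1e9:.0f} GB, activations {acts / 1e9:.0f} GB")
```

Under these assumptions, model state is 16 bytes per parameter either way, because the fp32 master copy and Adam moments are still present. The 2x shows up in the activations (80 GB versus 40 GB here).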

Why bf16 won

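fp16 has a narrow dynamic range: its largest value is ~6.5e4 and its smallest normal value is ~6e-5. Below that, values survive only as subnormals with a few bits of precision, and anything under ~3e-8 rounds to zero. Realistic gradient distributions have a long lower tail. Gradients of 1e-7 and smaller are common during training, and in fp16 they are either badly quantized or rounded to zero, killing those weight updates. The fix in the fp16 era was loss scaling: multiply the loss by a large constant (typically 2^16) before backward, run the backward pass in fp16 on correspondingly larger gradient values, then divide the gradients back down before the optimizer step.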
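
Both the underflow and the fix can be checked numerically. This sketch makes three assumptions:
- fp16 rounding is emulated with Python's `struct` `'e'` format (IEEE half precision, round-to-nearest-even).
- As a simplification, the backward pass is modeled as producing the true gradient times S and rounding it once to fp16.
- The gradient values are hypothetical.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE fp16 value (round-to-nearest-even).
    struct raises OverflowError if |x| is too large for fp16."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

S = 2.0 ** 16                  # loss-scale factor from the text

tiny = 1e-8                    # hypothetical true gradient, deep in the long tail
naive = to_fp16(tiny)          # unscaled backward: below half of fp16's smallest subnormal (2**-24)
scaled = to_fp16(tiny * S)     # scaled backward: ~6.6e-4 is a normal fp16 value
recovered = scaled / S         # unscale in higher precision before the optimizer step

coarse = to_fp16(1e-7)         # survives, but only as a subnormal: 2 * 2**-24 ~= 1.19e-7

print(naive)       # 0.0
print(recovered)   # ~1e-8, within fp16's ~0.05% relative rounding error
print(coarse)      # 1.1920928955078125e-07
```

Multiplying by a power of two and dividing by it again is exact in binary floating point. The only error left is the single fp16 rounding of the scaled value.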

This works but is fiddly. Dynamic loss scaling became standard: it cuts the scale factor (and skips the step) whenever gradients overflow to inf/NaN, then raises it again after a long run of clean steps. Even so, every framework has a story about a training run silently NaN-ing because the loss scaler did not adapt fast enough during a learning-rate ramp.
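
The control loop behind dynamic loss scaling is small. Below is a hypothetical minimal scaler, not any library's code. Its defaults mirror the documented defaults of PyTorch's `torch.amp.GradScaler`: initial scale 2^16, halve on overflow, double after 2,000 consecutive clean steps. fp16 rounding is emulated with `struct`, with overflow mapped to inf.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest fp16 value; overflow becomes +/-inf, as IEEE rounding would produce."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

class DynamicLossScaler:
    """Hypothetical minimal dynamic loss scaler."""

    def __init__(self, scale=2.0 ** 16, backoff=0.5, growth=2.0, growth_interval=2000):
        self.scale = scale
        self.backoff = backoff
        self.growth = growth
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def unscale_or_skip(self, fp16_grads):
        """Return unscaled gradients, or None if any overflowed (the caller skips the step)."""
        if not all(math.isfinite(g) for g in fp16_grads):
            self.scale *= self.backoff        # overflow: shrink the scale, skip this update
            self._clean_steps = 0
            return None
        unscaled = [g / self.scale for g in fp16_grads]  # use the scale these grads were computed with
        self._clean_steps += 1
        if self._clean_steps >= self.growth_interval:
            self.scale *= self.growth         # long clean run: probe a larger scale
            self._clean_steps = 0
        return unscaled

# Usage: a hypothetical large gradient (0.6) and a short growth interval so the cycle is visible.
scaler = DynamicLossScaler(growth_interval=2)
true_grads = [0.6, -2.5e-6]
history = []
for step in range(6):
    fp16_grads = [to_fp16(g * scaler.scale) for g in true_grads]  # "backward" on loss * scale
    grads = scaler.unscale_or_skip(fp16_grads)
    history.append(grads is not None)
    print(step, "applied" if grads is not None else "skipped", scaler.scale)
```

In this run the scale grows to 2^17, where 0.6 × 2^17 overflows fp16. The scaler then backs off to 2^16 and repeats the cycle. Skipped steps never touch the master weights.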

bf16 has the same 8-bit exponent as fp32, so its dynamic range is identical to fp32 (~1e-38 to ~3.4e38). Gradients essentially never underflow or overflow in bf16. No loss scaling. This single property is why bf16 displaced fp16 the moment hardware support landed (Ampere, 2020). The cost is bf16’s coarser mantissa — 7 bits vs fp16’s 10 — but training is much more sensitive to range than to precision, and modern recipes show no quality regression from bf16 vs fp32.
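
Both halves of that trade-off show up when a few values are rounded into each format. This is a sketch that again emulates the formats in pure Python:
- fp16 uses `struct`'s `'e'` code.
- bf16 rounds an fp32 bit pattern to its top 16 bits (round-to-nearest-even, NaN not handled).

```python
import struct

def to_fp16(x: float) -> float:
    """Nearest fp16 (1 sign, 5 exponent, 10 mantissa bits); OverflowError past 65504."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x: float) -> float:
    """Nearest bf16 (1 sign, 8 exponent, 7 mantissa bits): top half of an fp32, rounded to nearest even."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# Range: a tiny gradient and a large value.
print(to_fp16(1e-8), to_bf16(1e-8))   # fp16 rounds to 0.0; bf16 keeps ~1e-8
print(to_bf16(1e5))                   # 99840.0: in range for bf16, but coarsely rounded
try:
    to_fp16(1e5)
except OverflowError:
    print("fp16 overflow")            # 1e5 exceeds fp16's max of 65504

# Precision: can the format tell 1 + 2**-8 apart from 1.0?
print(to_fp16(1 + 2**-8))   # 1.00390625: 10 mantissa bits resolve it
print(to_bf16(1 + 2**-8))   # 1.0: 7 mantissa bits round it away
```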

The practical default in 2026 is: bf16 if you are on Ampere or newer, fp16 + dynamic loss scaling only if you are on Volta / Turing where bf16 is not natively supported. Apple Silicon, Hopper, Blackwell, and TPUs all do bf16 first-class.
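
That rule is simple enough to write down. The helper below is hypothetical and covers NVIDIA GPUs only. It assumes the usual compute-capability numbering: Volta 7.0, Turing 7.5, and Ampere and later with a major version of 8 or higher. In PyTorch, the `(major, minor)` tuple comes from `torch.cuda.get_device_capability()`.

```python
from typing import Tuple

def pick_training_dtype(capability: Tuple[int, int]) -> str:
    """Hypothetical helper: map an NVIDIA compute capability (major, minor) to the default above."""
    major, _minor = capability
    if major >= 8:
        # Ampere (8.0 / 8.6), Ada (8.9), Hopper (9.0) and later: native bf16, no loss scaling
        return "bf16"
    if major == 7:
        # Volta (7.0) and Turing (7.5): fp16 tensor cores but no native bf16
        return "fp16 + dynamic loss scaling"
    raise ValueError(f"compute capability {capability} predates Volta; not covered by this rule")

print(pick_training_dtype((8, 0)))   # bf16 (e.g. A100)
print(pick_training_dtype((7, 5)))   # fp16 + dynamic loss scaling (e.g. T4)
```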

What it composes with

The mixed-precision row in the training-stack matrix:
  • FSDP: bf16 parameters and gradients shrink the FSDP all-gather and reduce-scatter volume by 2x; the resident sharded state is also halved.
  • Activation checkpointing: bf16 activations roughly halve the recompute cost of stored activations and the memory of those that are kept.
  • Data parallelism: gradient allreduce volume halves; on bandwidth-bound cross-node DP this is a free 2x win.
  • Optimizer: fp32 moments inside the optimizer, bf16 weights presented to the model. Standard since the original mixed-precision paper.
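
The data-parallel point is easy to quantify. The sketch assumes a ring allreduce, in which each rank sends (and receives) 2(N-1)/N times the gradient payload per step. It compares fp32 and bf16 gradients for a hypothetical 7B-parameter model on 64 data-parallel ranks.

```python
def ring_allreduce_bytes_per_rank(num_params: int, bytes_per_elem: int, world_size: int) -> float:
    """Bytes each rank sends (and also receives) in one ring allreduce of the full gradient."""
    payload = num_params * bytes_per_elem
    return 2 * (world_size - 1) / world_size * payload

PARAMS = 7_000_000_000   # hypothetical model size
WORLD = 64               # hypothetical number of data-parallel ranks

fp32_bytes = ring_allreduce_bytes_per_rank(PARAMS, 4, WORLD)
bf16_bytes = ring_allreduce_bytes_per_rank(PARAMS, 2, WORLD)
print(f"fp32 gradients: {fp32_bytes / 1e9:.1f} GB per rank per step")   # 55.1
print(f"bf16 gradients: {bf16_bytes / 1e9:.1f} GB per rank per step")   # 27.6
```

The ratio is exactly 2 regardless of world size. On a bandwidth-bound link, that translates directly into halved communication time.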

Where it goes wrong

A second category of issue, beyond the fp16 loss-scaling failures above, is numerical equivalence with an fp32 baseline. Mixed precision is not bit-equivalent to fp32 training: runs differ in the trailing few digits of every metric. For most production models this is invisible noise. For ablation studies trying to detect 0.1% quality differences, run a fixed-seed fp32 control to confirm the gap is real.
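
A small sketch shows why results are close but not bit-identical. It emulates one common hardware pattern, assumed here for illustration: bf16 inputs with fp32 accumulation, as in tensor-core matmuls. It computes the same dot product from fp32 inputs and from bf16-rounded inputs, on fixed-seed random data.

```python
import random
import struct

def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round to bf16 via the top 16 bits of the fp32 pattern (round-to-nearest-even)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def dot(xs, ys, cast):
    """Dot product with inputs passed through `cast`; products and running sum rounded to fp32."""
    acc = 0.0
    for x, y in zip(xs, ys):
        acc = to_fp32(acc + to_fp32(cast(x) * cast(y)))
    return acc

rng = random.Random(0)   # fixed seed: identical data for both runs
xs = [rng.uniform(0.5, 1.5) for _ in range(256)]
ys = [rng.uniform(0.5, 1.5) for _ in range(256)]

ref = dot(xs, ys, to_fp32)     # fp32 inputs, fp32 accumulation
mixed = dot(xs, ys, to_bf16)   # bf16 inputs, fp32 accumulation
print(ref, mixed, abs(mixed - ref) / ref)   # close, but not bit-identical
```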

Mixed precision is one of those rare changes that gives a 2x speedup, halves the memory bill, and (on bf16) introduces almost no new failure modes. It is the cheapest performance win in the training stack and the one every serious project is already using.

Go further

Why bf16 over fp16?

Both are 16-bit, but they spend their bits differently. fp16 has a 5-bit exponent and 10 bits of mantissa. Its range tops out at ~65k; values below ~6e-5 lose precision as subnormals, and values below ~3e-8 round to zero. bf16 has the same 8-bit exponent as fp32 (range up to ~3.4e38) but only 7 bits of mantissa. For training, range matters more than precision: gradients regularly span many orders of magnitude, and bf16's fp32-equivalent range eliminates the underflow / overflow problems that fp16 needs loss scaling to manage.

What does the master-weights pattern actually do?

Forward, backward, and most matmuls run in bf16 / fp16. The optimizer keeps a separate copy of the parameters in fp32 plus its state (Adam's m, v) in fp32. The optimizer step computes the update in fp32 and then casts the parameter back to bf16 for the next forward. Without this, accumulated update noise from low-precision steps drifts the weights enough to hurt convergence.

Does mixed precision actually help inference?

Yes for compute and memory, but the inference story has moved past bf16. Modern serving uses int8 or fp8 for weights and activations on supported hardware, with int4 weight-only quantization for memory-bound regimes. Mixed precision is the training-time default; quantization is its inference-time successor.

ZeroEntropy