Also known as: Fully Sharded Data Parallel, ZeRO-3, FSDP2
TL;DR
FSDP shards parameters, gradients, and optimizer state across data-parallel ranks instead of replicating them. Each rank holds only 1/N of the weights at rest and gathers full layers on the fly during forward and backward.
FSDP (Fully Sharded Data Parallel) is the natural extension of data parallelism for models that no longer fit on one GPU. Instead of replicating the full model on every rank, FSDP shards the parameters, gradients, and optimizer state across all N data-parallel ranks. Each rank holds only 1/N of the model’s weights at rest, all-gathers the full parameters of a layer right before it needs them, runs forward or backward on that layer, and then drops the gathered weights again. The result is ZeRO-3 style memory savings, close to an Nx reduction in resident model state, with the programming model of plain data parallelism.
FSDP is PyTorch’s native answer to DeepSpeed ZeRO and is the default for training any model that is “bigger than one GPU but smaller than one node.” For larger jobs, it composes with tensor and pipeline parallelism rather than replacing them.
The sharded lifecycle
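Conceptually FSDP wraps each transformer block (or any user-chosen unit) and runs three collectives around it (an all-gather before forward, an all-gather before backward, and a reduce-scatter after backward), followed by a purely local optimizer step: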
All-gather parameters before the block’s forward — every rank now has the full layer weights.
Forward pass with full weights, then each rank frees the gathered copy and keeps only the shard it owns.
Same dance on the backward pass: all-gather, compute gradients, then reduce-scatter the gradients so each rank ends up with only the gradient slice it owns.
Optimizer step is local: each rank updates the parameters it owns using its local gradient slice and its local optimizer state.
The reduce-scatter is the FSDP analog of DDP’s gradient allreduce, just split across the sharding axis. The two extra all-gathers are the price of not replicating parameters.
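The lifecycle above can be sketched without any GPU by simulating ranks as Python lists. This is a toy under stated assumptions: the helper names (all_gather, reduce_scatter, fsdp_step, ddp_step) and the quadratic toy loss are illustrative and are not PyTorch APIs, and the collectives are plain list operations rather than real communication. It only checks that a sharded step produces the same weights as a replicated DDP-style step.

```python
from typing import List

Shard = List[float]


def all_gather(shards: List[Shard]) -> List[float]:
    """Every rank receives the concatenation of all shards (the full weights)."""
    return [w for shard in shards for w in shard]


def reduce_scatter(full_grads: List[List[float]]) -> List[Shard]:
    """Sum full-size gradients across ranks; rank r keeps only slice r of the sum."""
    world = len(full_grads)
    summed = [sum(col) for col in zip(*full_grads)]
    size = len(summed) // world
    return [summed[r * size:(r + 1) * size] for r in range(world)]


def toy_grad(weights: List[float], target: float) -> List[float]:
    """Gradient of the toy loss 0.5 * sum((w - target)^2) for one rank's batch."""
    return [w - target for w in weights]


def fsdp_step(shards: List[Shard], targets: List[float], lr: float) -> List[Shard]:
    """One sharded SGD step; targets[r] stands in for rank r's local batch."""
    world = len(shards)
    full = all_gather(shards)                      # all-gather before forward
    _losses = [0.5 * sum((w - t) ** 2 for w in full) for t in targets]
    del full                                       # drop gathered weights
    full = all_gather(shards)                      # all-gather before backward
    grads = [toy_grad(full, t) for t in targets]   # full-size gradient per rank
    del full
    grad_shards = reduce_scatter(grads)            # each rank keeps its slice
    # Local optimizer step: each rank updates only the shard it owns.
    return [
        [w - lr * g / world for w, g in zip(shards[r], grad_shards[r])]
        for r in range(world)
    ]


def ddp_step(weights: List[float], targets: List[float], lr: float) -> List[float]:
    """Reference: replicated weights with an averaged (allreduced) gradient."""
    grads = [toy_grad(weights, t) for t in targets]
    mean = [sum(col) / len(targets) for col in zip(*grads)]
    return [w - lr * g for w, g in zip(weights, mean)]


new_shards = fsdp_step([[1.0, 2.0], [3.0, 4.0]], targets=[0.0, 2.0], lr=0.5)
reference = ddp_step([1.0, 2.0, 3.0, 4.0], targets=[0.0, 2.0], lr=0.5)
print(all_gather(new_shards) == reference)
```

The only thing that differs between the two paths is where the weights and gradients live between steps; the arithmetic of the update is identical.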
Take a 70B-parameter model in mixed precision with Adam. Resident state per replica:
Parameters in bf16: 140 GB
Gradients in bf16: 140 GB
Optimizer state (Adam: m, v in fp32): 560 GB
Master weights in fp32 for stability: 280 GB
Total: ~1.12 TB per replica. No single GPU has that. With plain DDP across 16 H100s (80 GB each), every rank still needs 1.12 TB locally — DDP does not help.
With FSDP across the same 16 ranks, each rank holds 1/16 of every tensor: ~70 GB resident. Add transient gathered-parameter buffers (one layer’s worth in bf16, ~3-5 GB for typical transformer blocks) and only a few GB of each 80 GB card remain for activations, so the fit is tight and in practice leans on gradient checkpointing or CPU offloading. Still, a 70B model that DDP cannot place at all now fits on 16 GPUs. This is the entire point of FSDP.
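The arithmetic in this example is easy to reproduce. The sketch below assumes the byte counts listed above (bf16 parameters and gradients, fp32 Adam moments, fp32 master weights) and even sharding; the function name resident_state_gb is illustrative, and activations and transient buffers are deliberately left out.

```python
def resident_state_gb(params_billion: float, world_size: int = 1):
    """Per-rank resident model state in GB for mixed-precision Adam training."""
    bytes_per_param = {
        "params_bf16": 2,
        "grads_bf16": 2,
        "adam_m_v_fp32": 8,   # two fp32 moments per parameter
        "master_fp32": 4,
    }
    # billions of parameters * bytes per parameter = GB (1 GB = 1e9 bytes)
    per_tensor = {
        name: params_billion * nbytes / world_size
        for name, nbytes in bytes_per_param.items()
    }
    return per_tensor, sum(per_tensor.values())


breakdown, replicated_total = resident_state_gb(70)
_, sharded_total = resident_state_gb(70, world_size=16)
print(breakdown)
print(replicated_total, sharded_total)
```

Changing world_size shows how quickly the resident state falls as ranks are added, and how much headroom is left on a given card.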
Sharding strategies
PyTorch FSDP exposes a small set of sharding choices:
FULL_SHARD (ZeRO-3 equivalent). Parameters, gradients, optimizer state all sharded. Maximum memory saving, maximum comms.
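SHARD_GRAD_OP (ZeRO-2 equivalent). Gradients and optimizer state sharded; parameters are sharded at rest but, once gathered for forward, stay unsharded until backward finishes. Less comms (one all-gather per unit instead of two), less memory savings during the step. Useful when the full parameters fit during compute but optimizer state does not.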
HYBRID_SHARD. Full-shard within a node group, replicated across node groups. The standard pattern for cross-node frontier training: keep the heavy all-gathers on fast intra-node NVLink, do a normal allreduce across the slower inter-node fabric.
NO_SHARD. Equivalent to DDP. Mostly used for debugging.
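To make the HYBRID_SHARD layout concrete, the sketch below computes which ranks would shard together and which would replicate. It assumes one shard group per node and consecutive rank numbering within a node; hybrid_shard_groups is a hypothetical helper, not a PyTorch function, and a real job would hand equivalent groups to the framework as process groups.

```python
from typing import List, Tuple


def hybrid_shard_groups(
    world_size: int, gpus_per_node: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (shard_groups, replica_groups) for a node-local hybrid layout."""
    if world_size % gpus_per_node != 0:
        raise ValueError("world_size must be a multiple of gpus_per_node")
    num_nodes = world_size // gpus_per_node
    # Ranks on the same node shard the model: all-gather / reduce-scatter here.
    shard_groups = [
        list(range(node * gpus_per_node, (node + 1) * gpus_per_node))
        for node in range(num_nodes)
    ]
    # Ranks holding the same shard on different nodes: gradient allreduce here.
    replica_groups = [
        list(range(local_rank, world_size, gpus_per_node))
        for local_rank in range(gpus_per_node)
    ]
    return shard_groups, replica_groups


shard_groups, replica_groups = hybrid_shard_groups(world_size=16, gpus_per_node=8)
print(shard_groups)
print(replica_groups)
```

Each rank belongs to exactly one group of each kind, so the heavy parameter traffic stays inside a node and only a gradient allreduce over shard-sized tensors crosses the slower fabric.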
The sharding unit also matters. FSDP wraps modules — typically each transformer block is its own FSDP unit. Wrapping too coarsely (the whole model as one unit) means you all-gather every parameter at once and lose the memory savings during forward / backward. Wrapping too finely (every linear layer separately) burns time on tiny collectives.
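The granularity trade-off can be put in rough numbers. The sketch below assumes a hypothetical 70B model made of 80 identical blocks with 7 equally sized linear layers each, bf16 parameters, and no prefetching (only one unit gathered at a time); wrap_cost is an illustrative helper, not a library call.

```python
from typing import List, Tuple


def wrap_cost(unit_param_counts: List[int], bytes_per_param: int = 2) -> Tuple[float, int]:
    """Return (peak gathered GB, all-gathers per forward pass) for a wrapping choice."""
    peak_gathered_gb = max(unit_param_counts) * bytes_per_param / 1e9
    all_gathers_per_forward = len(unit_param_counts)
    return peak_gathered_gb, all_gathers_per_forward


# Hypothetical shapes: 80 blocks * 7 linears * 125M parameters = 70B parameters.
whole_model = wrap_cost([70_000_000_000])
per_block = wrap_cost([875_000_000] * 80)
per_linear = wrap_cost([125_000_000] * 560)

for name, cost in [("whole model", whole_model),
                   ("per block", per_block),
                   ("per linear", per_linear)]:
    print(name, cost)
```

Under these assumptions, wrapping the whole model gathers 140 GB at once, while per-linear wrapping issues 560 small collectives per forward pass; per-block wrapping sits between the two.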
What composes well
The full performance-engineering stack on top of FSDP:
Mixed-precision training — bf16 / fp16 compute with fp32 master weights, integrated cleanly via the FSDP MixedPrecision policy.
Gradient checkpointing — the standard activation-memory reducer; orthogonal to FSDP and almost always used together at scale.
CPU offloading — push optimizer state or even parameters to CPU RAM when GPU memory is the binding constraint and step time is not.
Tensor / pipeline parallelism — for jobs where activations dominate, FSDP across the data axis composes with TP / PP across the model axis.
Where FSDP gets you in trouble
One operational quirk is checkpointing. Naive state_dict() collects the full unsharded model on rank 0, which is fine for a 7B model and fatal for a 70B one. Production FSDP code uses sharded checkpointing (each rank writes only its slice) and reconstructs the full state at load time, with torch.distributed.checkpoint as the canonical path.
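The idea behind sharded checkpointing can be illustrated with a toy that uses only the standard library. This is not the torch.distributed.checkpoint format or API: the file naming, the JSON layout, and the helpers save_sharded and load_and_reshard are all assumptions made for illustration, and the loader here reads every file in one process, whereas a real loader would have each rank read only the slices it needs.

```python
import json
import os
import tempfile
from typing import List


def save_sharded(shards: List[List[float]], directory: str) -> None:
    """Each simulated rank writes only its own slice to its own file."""
    for rank, shard in enumerate(shards):
        path = os.path.join(directory, f"rank{rank}.json")
        with open(path, "w") as f:
            json.dump({"rank": rank, "world_size": len(shards), "values": shard}, f)


def load_and_reshard(directory: str, new_world_size: int) -> List[List[float]]:
    """Rebuild the flat state from per-rank files, then split for a new world size."""
    names = sorted(os.listdir(directory), key=lambda n: int(n[len("rank"):-len(".json")]))
    full: List[float] = []
    for name in names:
        with open(os.path.join(directory, name)) as f:
            full.extend(json.load(f)["values"])
    if len(full) % new_world_size != 0:
        raise ValueError("state does not divide evenly across the new world size")
    size = len(full) // new_world_size
    return [full[r * size:(r + 1) * size] for r in range(new_world_size)]


with tempfile.TemporaryDirectory() as ckpt_dir:
    save_sharded([[1, 2], [3, 4], [5, 6], [7, 8]], ckpt_dir)   # saved from 4 ranks
    written = sorted(os.listdir(ckpt_dir))
    resharded = load_and_reshard(ckpt_dir, new_world_size=2)   # loaded onto 2 ranks

print(written)
print(resharded)
```

No rank ever holds the full state while saving, and the checkpoint can be loaded onto a different number of ranks than it was written from.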
FSDP is the boring-but-correct answer for the meaty middle of the model-size distribution. The day a model stops fitting on one GPU is the day you reach for it; the day activations stop fitting is the day you start composing it with everything else.
Go further
How does FSDP relate to ZeRO-1 / 2 / 3?
ZeRO-1 shards optimizer state only. ZeRO-2 shards optimizer state and gradients. ZeRO-3 shards parameters, gradients, and optimizer state — which is what PyTorch FSDP implements. The naming differs, the technique is the same; ZeRO came first (DeepSpeed, Microsoft) and FSDP is PyTorch's native re-implementation.
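The three stages differ only in which tensors are divided by the number of ranks. The sketch below assumes mixed-precision Adam with 2 bytes per parameter for bf16 weights, 2 for bf16 gradients, and 12 for optimizer state (fp32 master weights plus two fp32 moments); zero_per_rank_gb is an illustrative helper, and stage 0 stands for plain replication.

```python
def zero_per_rank_gb(params_billion: float, world_size: int, stage: int) -> float:
    """Per-rank resident model state in GB for ZeRO stage 0 (replicated) to 3."""
    params = 2 * params_billion    # bf16 parameters
    grads = 2 * params_billion     # bf16 gradients
    optim = 12 * params_billion    # fp32 master weights + Adam m and v
    if stage >= 1:
        optim /= world_size        # ZeRO-1: shard optimizer state
    if stage >= 2:
        grads /= world_size        # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= world_size       # ZeRO-3 / full sharding: also shard parameters
    return params + grads + optim


per_stage = {stage: zero_per_rank_gb(70, world_size=16, stage=stage) for stage in range(4)}
print(per_stage)
```

For a 70B model on 16 ranks, most of the saving arrives with stage 1 because optimizer state dominates, but only stage 3 brings the resident state under the size of a single 80 GB card.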
Relative to DDP, FSDP costs roughly 1.3-1.5x the communication per step. DDP does one gradient allreduce, which is itself a reduce-scatter plus an all-gather. FSDP keeps the reduce-scatter for gradients and issues two parameter all-gathers, one before forward and one before backward, so three collective passes replace two. With overlap and bucketing, end-to-end throughput drop is usually 5-15% on intra-node NVLink. Cross-node, the gap widens and hybrid-shard becomes important.
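A simple cost model reproduces the upper end of that range. The sketch assumes ring-style collectives, where an all-gather or a reduce-scatter sends (N-1)/N of the tensor per rank and an allreduce sends twice that, and it assumes gradients are the same size as the parameters; both function names are illustrative.

```python
def ddp_bytes_per_step(param_bytes: float, world_size: int) -> float:
    """One gradient allreduce = reduce-scatter + all-gather."""
    return 2 * (world_size - 1) / world_size * param_bytes


def fsdp_bytes_per_step(param_bytes: float, world_size: int) -> float:
    """All-gather (forward) + all-gather (backward) + reduce-scatter (gradients)."""
    return 3 * (world_size - 1) / world_size * param_bytes


param_bytes = 140e9  # 70B parameters in bf16
ddp = ddp_bytes_per_step(param_bytes, world_size=16)
fsdp = fsdp_bytes_per_step(param_bytes, world_size=16)
print(ddp / 1e9, fsdp / 1e9, fsdp / ddp)
```

The ratio is 1.5 regardless of world size under this model; measured overhead depends on how much of the traffic overlaps with compute.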
When should you skip FSDP and reach for tensor parallelism?
When activations dominate memory rather than parameters: long-context training, MoE models with huge embedding tables, very wide FFN layers. FSDP shards model state but not activations; each rank still holds the full activations for its own batch. Tensor parallelism shards activations along the feature axis. Production frontier-scale training combines both: TP within a node, FSDP across nodes.