Inference Graph Compilation

Also known as: graph compilation, torch.compile, TensorRT-LLM, JIT compilation, AOT compilation

TL;DR

Capture a model's computation as a static graph, optimize it (operator fusion, constant folding, attention specialization, kernel selection), and emit a compiled artifact that runs without Python overhead. Common implementations include torch.compile, TensorRT-LLM, and ONNX Runtime.

Inference graph compilation is the practice of converting a model from eager-mode Python — where every operation is dispatched, type-checked, and launched as the interpreter walks the call stack — into a static computation graph that can be aggressively optimized and executed without Python in the loop. The compiler sees the whole graph at once: it can fuse adjacent operators, fold constants, replace generic implementations with specialized kernels, choose tile sizes for the actual input shapes, and emit a single compiled artifact. The result, on a well-structured transformer, is typically 1.5-3x faster than the same model running in eager PyTorch, with the gain rising as batch size shrinks and Python overhead becomes a larger share of step time.
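
The batch-size trend can be made concrete with a toy cost model. This is a sketch under loud assumptions: every number below is hypothetical and chosen only to show the shape of the curve, host overhead is treated as not overlapping with GPU compute, and the function names are illustrative rather than taken from any library.

```python
# Toy cost model. All constants are hypothetical, picked to show the trend only.
def step_time_us(batch, n_kernels, per_kernel_overhead_us, compute_us_per_sample):
    """Step time = fixed host overhead per kernel launch + compute that scales with batch."""
    return n_kernels * per_kernel_overhead_us + batch * compute_us_per_sample


def speedup(batch):
    # Eager: many small kernels, each paying Python dispatch plus launch overhead.
    eager = step_time_us(batch, n_kernels=30, per_kernel_overhead_us=10.0,
                         compute_us_per_sample=50.0)
    # Compiled: fewer fused kernels and cheaper launches; the compute term is unchanged.
    compiled = step_time_us(batch, n_kernels=6, per_kernel_overhead_us=5.0,
                            compute_us_per_sample=50.0)
    return eager / compiled


for b in (1, 8, 64):
    print(f"batch={b:3d}  speedup={speedup(b):.2f}x")
```

In this model the overhead term is fixed per step while compute grows with batch, so the ratio shrinks toward 1 as batch size rises. Real compilers also reduce the compute term through fusion, which this sketch deliberately ignores.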

torch.compile, TensorRT-LLM, ONNX Runtime, JAX jit, and vLLM’s CUDA-graph backend are all instances of the same pattern. The differences are in how aggressive they are, what hardware they target, and how much of the model they can actually capture without falling back to eager.

What the compiler is doing

Three families of optimization, applied in roughly this order:

  1. Graph capture. Trace the model under representative inputs to get a static graph (or a small set of dynamic-shape graphs). PyTorch captures with TorchDynamo and generates code with Inductor; JAX traces under jax.jit; ONNX Runtime imports a frozen ONNX file. Capture is where most bugs surface when a fragile frontier model meets a compiler.
  2. Graph-level rewrites. Operator fusion across adjacent ops, constant folding (compute once anything whose inputs are known at compile time and bake the result into the graph), dead-code elimination, common-subexpression elimination, layout transforms (NHWC vs NCHW), precision casts.
  3. Kernel specialization. Replace generic ops with hardware-specific implementations. The standard moves: attention -> FlashAttention; matmul -> cuBLAS / cuBLASLt with shape-tuned tile sizes; layernorm -> fused norm-residual; softmax -> fused-with-attention.
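
Two of the graph-level rewrites from step 2, constant folding and dead-code elimination, fit in a few lines. This is a minimal sketch over a hypothetical dict-based graph format invented for illustration; it is not how Inductor, XLA, or ONNX Runtime represent graphs.

```python
import operator

OPS = {"add": operator.add, "mul": operator.mul}


def fold_constants(graph):
    """Replace any op node whose inputs are all constants with a constant node.

    graph maps name -> ("const", value) | ("input",) | (op, arg_name, arg_name)
    and is assumed to be in topological order (dicts keep insertion order).
    """
    folded = {}
    for name, node in graph.items():
        if node[0] in OPS:
            args = [folded[a] for a in node[1:]]
            if all(a[0] == "const" for a in args):
                node = ("const", OPS[node[0]](*(a[1] for a in args)))
        folded[name] = node
    return folded


def eliminate_dead_code(graph, outputs):
    """Keep only the nodes reachable from the requested outputs."""
    live, stack = set(), list(outputs)
    while stack:
        name = stack.pop()
        if name in live:
            continue
        live.add(name)
        node = graph[name]
        if node[0] in OPS:
            stack.extend(node[1:])
    return {n: v for n, v in graph.items() if n in live}


graph = {
    "x": ("input",),
    "two": ("const", 2.0),
    "three": ("const", 3.0),
    "scale": ("mul", "two", "three"),  # both inputs known at compile time
    "y": ("mul", "x", "scale"),
    "unused": ("add", "x", "two"),     # never reaches the output
}
optimized = eliminate_dead_code(fold_constants(graph), outputs=["y"])
print(optimized)
```

After both passes the graph holds only the input, the folded constant, and the one multiply that depends on runtime data.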

A compiled transformer block typically collapses from ~30 individual eager-mode kernel launches to ~5-8 fused kernels per layer.
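
The drop in launch count comes from folding elementwise ops into the kernel that produces their input. A rough sketch of that greedy grouping, assuming a linear op sequence and hypothetical op names (real fusion passes also check data dependencies, shapes, and memory limits):

```python
# Hypothetical set of ops treated as cheap elementwise epilogues.
ELEMENTWISE = {"add_bias", "gelu", "scale", "residual_add"}


def fuse_elementwise(ops):
    """Greedily attach each elementwise op to the kernel that precedes it."""
    kernels = []
    for op in ops:
        if op in ELEMENTWISE and kernels:
            kernels[-1].append(op)  # becomes an epilogue of the previous kernel
        else:
            kernels.append([op])    # matmuls and other heavy ops start a new kernel
    return kernels


layer = ["matmul", "add_bias", "gelu", "matmul", "add_bias", "residual_add"]
fused = fuse_elementwise(layer)
print(len(layer), "launches ->", len(fused), "launches")
```

Each inner list stands for one launched kernel, so the six eager launches in this toy layer become two.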

Eager-mode Python is too dynamic. A transformer block written for research often contains: shape-dependent control flow (if seq_len > thresh), Python-level conditionals on dtype, custom autograd functions that the tracer cannot see through, and dictionary-keyed buffer accesses. Each of these is a potential graph break.

torch.compile handles graph breaks by switching back to eager mode for the offending span and resuming compilation afterward. This is workable, but every break forfeits the optimizations that would have spanned that boundary. A model with 50 graph breaks per forward pass compiles to something only marginally faster than eager.
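
A toy tracer shows why a data-dependent branch splits the graph. This is an illustrative sketch only: the Tracer and Value classes are hypothetical, and TorchDynamo works on Python bytecode rather than operator overloading. The mechanism is the same in spirit, though: once Python needs a concrete True or False, recording has to stop and restart.

```python
class Tracer:
    """Records ops into segments; a data-dependent branch closes the current segment."""

    def __init__(self):
        self.segments = [[]]
        self.breaks = 0

    def record(self, op):
        self.segments[-1].append(op)

    def graph_break(self):
        self.breaks += 1
        if self.segments[-1]:
            self.segments.append([])


class Value:
    """Scalar stand-in for a traced tensor."""

    def __init__(self, tracer, data):
        self.tracer, self.data = tracer, data

    def _op(self, name, data):
        self.tracer.record(name)
        return Value(self.tracer, data)

    def __add__(self, k):
        return self._op("add", self.data + k)

    def __mul__(self, k):
        return self._op("mul", self.data * k)

    def __gt__(self, k):
        return self._op("gt", self.data > k)

    def __bool__(self):
        # Python needs a concrete bool here, so the trace cannot continue as one graph.
        self.tracer.graph_break()
        return bool(self.data)


def forward(x):
    h = x * 2 + 1
    if h > 10:        # data-dependent Python branch
        h = h * 0.5
    return h + 3


t = Tracer()
out = forward(Value(t, 7.0))
print(t.breaks, t.segments)
```

The single `if` yields two separately optimizable segments, and no rewrite can fuse an op from the first with an op from the second.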

The standard pre-compilation hygiene pass: replace Python control flow with torch.cond / torch.where, eliminate dtype branching, prefer nn.Module containers over Python dicts, avoid .item() calls inside the forward. Production-ready model code looks subtly different from research code for exactly this reason.
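
The first item in that hygiene pass, replacing a Python branch with a select op, can be illustrated with the same kind of toy tracer. The Traced class and the where helper below are hypothetical stand-ins written for this sketch; where plays the role that a tensor-level select such as torch.where plays in real model code.

```python
class Traced:
    """Scalar stand-in for a traced tensor; logs ops and graph breaks."""

    def __init__(self, data, log):
        self.data, self.log = data, log

    def _new(self, op, data):
        self.log["ops"].append(op)
        return Traced(data, self.log)

    def __mul__(self, k):
        return self._new("mul", self.data * k)

    def __gt__(self, k):
        return self._new("gt", self.data > k)

    def __bool__(self):
        self.log["breaks"] += 1  # concrete value demanded: tracing is interrupted
        return bool(self.data)


def where(cond, a, b):
    """Select recorded as a graph op; it never asks Python to branch on a traced value."""
    return cond._new("where", a.data if cond.data else b.data)


def branchy(x):
    return x * 0.5 if x > 10 else x       # Python-level conditional


def branchless(x):
    return where(x > 10, x * 0.5, x)      # both sides computed, then selected


def trace(fn, value):
    log = {"ops": [], "breaks": 0}
    out = fn(Traced(value, log))
    return out.data, log


print(trace(branchy, 12.0))
print(trace(branchless, 12.0))
```

Both versions return the same value, but only the branchless one stays a single graph. The trade-off is that both sides of the select are always computed, which is usually acceptable for cheap elementwise work.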

TensorRT-LLM and ONNX Runtime are stricter — they generally require the model to be fully exported to a static graph (or a small set of graphs over discrete dynamic shapes) before they can compile at all. The compile step is more painful but the ceiling is higher.

Why it composes with everything else

The compilation row in the production stack
  • Operator fusion is the compiler’s primary value-add. The reason a compiled transformer is faster is mostly that elementwise ops get folded into matmul epilogues.
  • FlashAttention is selected automatically when the compiler detects the attention pattern. You write the math; the compiler picks the kernel.
  • Prefill/decode handling specializes per shape: prefill (large QK matmul, full softmax) and decode (single-token Q, paged-attention KV) compile to different kernels.
  • vLLM uses CUDA Graphs to capture and replay decode steps without Python overhead, a coarser-grained version of the same idea.
  • Speculative decoding typically requires the draft and target both compiled into the serving graph for the latency math to work out.

When eager mode wins

Eager mode still wins when shapes vary so much that recompilation eats the gain. The other failure mode is correctness drift. Compilers reorder floating-point ops, switch to lower-precision intermediate accumulators in fused kernels, and sometimes pick numerically aggressive math (FastMath, TF32 instead of fp32 internals). Outputs are generally not bit-identical to eager mode. For most production cases the differences are within numerical noise; for retrieval models, embedding spaces, and anything with downstream calibrated thresholds, run a compiled-vs-eager equivalence test before you ship. zerank-style pairwise reranking models, in particular, are sensitive to small score perturbations and warrant explicit validation.
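
A compiled-vs-eager equivalence test can be sketched without any framework. The example below is illustrative only: the two sum functions stand in for eager and compiled reductions, the scores are made-up numbers, and the tolerances are assumptions that a real test would tune per model.

```python
import math


def eager_sum(xs):
    # Left-to-right accumulation: the straightforward reference order.
    total = 0.0
    for x in xs:
        total += x
    return total


def reordered_sum(xs):
    # Pairwise (tree) reduction: same math, different association order,
    # similar in spirit to what a parallel or fused reduction does.
    xs = list(xs)
    while len(xs) > 1:
        xs = [sum(xs[i:i + 2]) for i in range(0, len(xs), 2)]
    return xs[0]


def allclose(ref, test, rtol=1e-5, atol=1e-6):
    """Elementwise tolerance check between reference and candidate outputs."""
    return all(math.isclose(r, t, rel_tol=rtol, abs_tol=atol) for r, t in zip(ref, test))


def ranking(scores):
    """Indices sorted by descending score, as a reranker would order documents."""
    return sorted(range(len(scores)), key=lambda i: -scores[i])


xs = [0.1] * 10
print(eager_sum(xs), reordered_sum(xs))  # may differ in the last bits

# Hypothetical reranker scores: numerically close, yet the top two swap places.
eager_scores = [0.5000001, 0.5000000, 0.1]
compiled_scores = [0.5000000, 0.5000001, 0.1]
print(allclose(eager_scores, compiled_scores))
print(ranking(eager_scores) == ranking(compiled_scores))
```

The second pair passes the tolerance check while the ranking changes, which is why score-sensitive models are better validated on the downstream decision (ordering, thresholded label) than on raw closeness alone.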

What it adds up to

In 2026, the production default for LLM serving on NVIDIA is some compiled stack: TensorRT-LLM, vLLM with CUDA Graphs, or torch.compile. Eager-mode inference is essentially never the right answer at production scale; the only place it survives is during development and in workloads with extremely variable shapes that defeat compilation. Together with mixed precision and FlashAttention, graph compilation is what closes the gap between “the math runs” and “the math runs at 80% of theoretical hardware peak.”

Go further

torch.compile vs TensorRT-LLM vs ONNX Runtime — when does each win?

torch.compile: easiest, broadest model coverage, ~1.3-2x over eager, but inference is not its primary target. TensorRT-LLM: best raw single-stream latency on NVIDIA GPUs (often 2-3x over eager) at the cost of a complex build process and NVIDIA lock-in. ONNX Runtime: cross-platform (CPU, GPU, mobile, edge), strong for non-LLM workloads and hybrid deployments, weaker on the latest LLM tricks. For LLM serving on NVIDIA hardware, TensorRT-LLM or vLLM-with-CUDA-graphs are the production choices.

What does 'attention specialization' mean?

Compilers detect attention patterns in the graph and swap them for hand-optimized kernels: FlashAttention for standard self-attention, FlashDecoding for the decode phase, paged-attention variants for KV-cache layouts. The same Python softmax(QK^T)V source compiles to wildly different kernels depending on shape, mask, and whether you are in prefill or decode. Done well, the compiler picks the right kernel automatically.

Why does compilation often hurt small batch sizes?

It usually does not; the common intuition runs backwards. Compilation pays off when per-step Python and kernel-launch overhead is non-trivial relative to GPU compute. At batch size 1 with short context, each kernel finishes in microseconds, so Python dispatch and CUDA launch overhead together can dominate step time, and that is exactly where removing them matters most. Eager mode's overhead is fully exposed at small batch. When compilation fails to help there, the usual cause is shape specialization: if the compiler specializes per input shape, a highly variable workload triggers recompilation on each new shape, and that cost can wipe out the gain.
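
The recompilation cost can be pictured as a cache keyed on input shape. The sketch below is a toy model with a hypothetical class and made-up request lengths. The bucketing option is an assumed mitigation, in line with compiling a small set of graphs over discrete shapes; it is not a description of any specific compiler's behavior.

```python
class ShapeSpecializingCompiler:
    """Toy model of a compiler that builds and caches one artifact per input shape."""

    def __init__(self, bucket=None):
        self.cache = {}
        self.compiles = 0
        self.bucket = bucket  # optional: round sequence length up to a multiple

    def _key(self, seq_len):
        if self.bucket is None:
            return seq_len
        return -(-seq_len // self.bucket) * self.bucket  # ceiling to bucket boundary

    def run(self, seq_len):
        key = self._key(seq_len)
        if key not in self.cache:
            self.compiles += 1  # the expensive step in a real compiler
            self.cache[key] = f"kernel_for_len_{key}"
        return self.cache[key]


requests = [17, 33, 18, 250, 31, 129, 17, 64]  # hypothetical sequence lengths
exact = ShapeSpecializingCompiler()
bucketed = ShapeSpecializingCompiler(bucket=64)
for n in requests:
    exact.run(n)
    bucketed.run(n)

print("exact shapes:", exact.compiles, "compiles")
print("bucketed to 64:", bucketed.compiles, "compiles")
```

Bucketing trades some padded compute for far fewer compiles; whether that trade is worth it depends on how expensive a compile is relative to the wasted padding.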
