Boosting Video Generation Inference: Full Graph Compilation with torch.compile

This article examines the challenges of optimizing video generation model inference, moving from operator-level tweaks to full-graph compilation with torch.compile. It details systematic strategies for eliminating Graph Breaks caused by dynamic shapes, KV-Cache indexing, and Python-side caches, achieving a 47.6% speedup on a 14B model without accuracy loss.


Introduction: From Operator-Level to Graph-Level Optimization

Video generation model inference is a multi‑layered engineering challenge. Early optimizations focus on individual operators (e.g., convolution, attention) to improve floating‑point performance, but as operator efficiency nears hardware limits, graph‑level optimizations—scheduling, memory reuse, and control‑flow overhead reduction—become essential for further gains.

This article concentrates on the inference execution pipeline itself, exploring how torch.compile can be used to fully compile the Self‑Forcing inference flow, thereby reducing Python interpretation and scheduling overhead and laying the groundwork for deeper graph‑level optimizations.

Self‑Forcing Inference Characteristics and Compilation Challenges

Self‑Forcing transforms full‑frame parallel diffusion models (e.g., Wan2.1) into a causal‑attention architecture for training and inference. Unlike traditional bidirectional diffusion, which processes all video frames simultaneously, Self‑Forcing generates video autoregressively in small blocks (typically three latent frames) and reuses historical context via a KV Cache.

From a compilation perspective, Self‑Forcing is both "compiler‑friendly" (highly structured computation) and "compiler‑unfriendly" (numerous Python‑side dependencies). The main obstacles are:

Python control flow that depends on tensor values.

Calls such as .item() and .tolist() that move tensors back to the host.

Dynamic indexing and slicing of the KV Cache.

Mixed‑in Python‑level caching and debugging logic.

These patterns trigger Graph Breaks during torch.compile, causing the compiler to emit fragmented sub‑graphs and limiting performance gains.

Full‑Graph Compilation Strategy

We adopt a progressive optimization workflow: first wrap critical modules with torch.compile to gauge potential benefits, then systematically identify and eliminate Graph Breaks, ultimately achieving full‑graph compilation.

In the attention module (CausalWanAttentionBlock) we apply the following decorator:

@torch.compile(dynamic=True, fullgraph=True)

dynamic=True allows the compiler to represent runtime‑determined dimensions symbolically, avoiding extra recompilations when input shapes vary (e.g., different token counts per GPU or KV Cache lengths).

fullgraph=True forces the entire function into a single FX graph; any untraced operation aborts compilation, exposing all Graph Break points for remediation.
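
As a minimal sketch of this setup (the layer layout and forward signature below are illustrative placeholders, not the real Self‑Forcing block), the decorator is applied directly to the block's forward method:

import torch
import torch.nn.functional as F

class CausalWanAttentionBlock(torch.nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.norm = torch.nn.LayerNorm(dim)
        self.qkv = torch.nn.Linear(dim, dim * 3)
        self.proj = torch.nn.Linear(dim, dim)

    @torch.compile(dynamic=True, fullgraph=True)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # fullgraph=True: any operation Dynamo cannot trace raises an error
        # here instead of silently splitting the computation into sub-graphs.
        b, s, d = x.shape
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=-1)
        q = q.view(b, s, self.num_heads, -1).transpose(1, 2)
        k = k.view(b, s, self.num_heads, -1).transpose(1, 2)
        v = v.view(b, s, self.num_heads, -1).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).reshape(b, s, d)
        return x + self.proj(out)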

Analyzing and Removing Graph Breaks

Control Flow and Scalar Extraction

Many control‑flow decisions rely on tensor‑to‑scalar conversions such as .item(). Example:

frame_seqlen = math.prod(grid_sizes[0][1:]).item()
local_end_index = kv_cache["local_end_index"].item() + current_end

Under fullgraph=True this raises:

Unsupported Tensor.item() call with capture_scalar_outputs=False

Although capture_scalar_outputs=True could silence the error, it adds host‑device synchronization overhead. Instead we rewrite the code to keep all calculations on the GPU:

frame_seqlen = torch.prod(grid_sizes[0][1:])
current_start_frame = current_start // frame_seqlen

We systematically remove the following scalar‑related patterns:

.item()
int(tensor)
.tolist()

Other similar tensor‑to‑Python conversions

Keeping data as tensors preserves static analysis opportunities for the compiler.
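
The same principle applies to value‑dependent branches. As an illustrative pattern (a hypothetical helper, not code from the Self‑Forcing repository), a Python if on a tensor value can usually be replaced with torch.where so that no scalar ever leaves the device:

import torch

def pick_scale_breaking(loss: torch.Tensor) -> float:
    # .item() copies the value to the host and forces a sync, so Dynamo
    # cannot keep this branch inside the compiled graph.
    return 0.5 if loss.item() > 1.0 else 1.0

def pick_scale_compiled(loss: torch.Tensor) -> torch.Tensor:
    # torch.where keeps the decision in the graph; nothing leaves the GPU.
    return torch.where(loss > 1.0,
                       torch.tensor(0.5, device=loss.device),
                       torch.tensor(1.0, device=loss.device))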

Data Dependency and Dynamic Shape

Dynamic shape inference failures also cause Graph Breaks. The original RoPE implementation iterates over grid_sizes.tolist():

for i, (f, h, w) in enumerate(grid_sizes.tolist()):
    seq_len = f * h * w
    x_i = x[i, :seq_len]

This triggers a guard‑failure error:

Could not guard on data-dependent expression u0*u1*u2 < 0 (unhinted: u0*u1*u2 < 0). (Size‑like symbols: none)

In inference we treat grid_sizes as a configuration‑time constant (video frame count, latent resolution, patch size). We therefore specialize the graph for a small set of discrete configurations, a practice common in CUDA‑Graph bucketing, eliminating the need for runtime list conversion.
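
A minimal sketch of this specialization, with hypothetical names: because the grid dimensions are plain Python integers fixed when the pipeline is built, every loop bound and slice length is a constant from the compiler's point of view:

import torch

class RopeApplier:
    def __init__(self, frames: int, height: int, width: int):
        # One (f, h, w) configuration per compiled graph; a new configuration
        # simply triggers one extra compilation, as in CUDA-Graph bucketing.
        self.seq_len = frames * height * width

    def apply(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        # x: [batch, max_seq, dim]; the rotation itself is elided here. The
        # point is that self.seq_len is a plain Python int, so the slice
        # bound is constant during tracing rather than a data-dependent value.
        x_rot = x[:, :self.seq_len] * freqs[:self.seq_len]
        return torch.cat([x_rot, x[:, self.seq_len:]], dim=1)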

KV Cache Dynamic Indexing

KV Cache reads and writes originally use tensor indices computed at runtime, which torch.compile cannot currently handle:

# local_start_index and local_end_index are Tensors
x = kv_cache["k"][:, local_start_index:local_end_index]
kv_cache["k"][:, local_start_index:local_end_index] = roped_key

Because the slice bounds are dynamic, compilation fails in fullgraph mode. By analyzing the causal attention algorithm we find that the read start is always zero and the write position can be expressed as a static offset derived from current_start and the block token count. We replace the dynamic slices with static indices and implement the write via a custom tile‑language kernel, fully removing the KV‑Cache‑induced Graph Breaks.
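
The static‑index idea can be sketched in plain PyTorch as follows (the actual write path uses the custom tile‑language kernel mentioned above; the helper name and block_tokens argument are illustrative):

import torch

def update_kv_cache(kv_cache: dict, roped_key: torch.Tensor,
                    current_start: int, block_tokens: int) -> torch.Tensor:
    # The write offset is plain Python arithmetic on current_start and the
    # per-block token count, so no slice bound depends on a tensor value.
    write_start = current_start
    write_end = write_start + block_tokens
    kv_cache["k"][:, write_start:write_end] = roped_key
    # Reads always start at index 0 and cover the prefix generated so far.
    return kv_cache["k"][:, :write_end]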

Host Calls and Python‑Side Caches

Python‑level caches (e.g., LRU dictionaries for RoPE cos/sin values) and debugging calls such as time.time() also break the graph. We pre‑compute the cosine/sine tensors during model initialization and store them as contiguous GPU tensors, accessing them via tensor indexing only. All debugging statements are stripped from the production inference path.
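
A minimal sketch of this precomputation, with hypothetical shapes and names:

import torch

class RopeTable:
    def __init__(self, max_positions: int, dim: int, device: str = "cuda"):
        # Precompute cos/sin once at model initialization and keep them as
        # contiguous GPU tensors, so the hot path never touches a Python cache.
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, device=device).float() / dim))
        t = torch.arange(max_positions, device=device, dtype=torch.float32)
        angles = torch.outer(t, inv_freq)  # [max_positions, dim // 2]
        self.cos = angles.cos().contiguous()
        self.sin = angles.sin().contiguous()

    def slice(self, start: int, length: int):
        # Pure tensor indexing; traceable under fullgraph=True.
        return self.cos[start:start + length], self.sin[start:start + length]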

Experimental Results

After eliminating Graph Breaks and enabling full‑graph compilation, we benchmarked a 14‑billion‑parameter Self‑Forcing model generating 5‑second, 480p video clips. The full‑graph optimization alone delivered a 47.6% speedup, reducing end‑to‑end inference time from 8.86 s to 6.00 s, with no observable degradation in visual quality.

These findings demonstrate that full‑graph compilation not only provides automatic acceleration but also imposes valuable constraints on data dependencies and control flow, exposing hidden complexity and paving the way for further operator fusion and system‑level optimizations.
