diff --git a/examples/diffusers/sparsity/README.md b/examples/diffusers/sparsity/README.md index dc44fbcd17..31620ec7d7 100644 --- a/examples/diffusers/sparsity/README.md +++ b/examples/diffusers/sparsity/README.md @@ -1,4 +1,4 @@ -# Skip-Softmax Sparse Attention for Diffusion Models +# Sparse Attention for Diffusion Models > [!WARNING] > **Third-Party License Notice — LTX-2** @@ -13,11 +13,32 @@ > fine-tuned weights produced from LTX-2 (including quantized, distilled, or sparsified > checkpoints) remain subject to the LTX Community License Agreement, not Apache 2.0. -Skip-softmax sparse attention (BLASST, ) skips KV -tiles whose attention scores are negligible during the FlashAttention computation, -reducing FLOPs without retraining. +Two sparse-attention methods are supported under +`modelopt.torch.sparsity.attention_sparsity` (`mtsa`): + +| Method | When to use | Calibration | +|--------|-------------|-------------| +| **Skip-Softmax** (BLASST) | Drop low-impact KV tiles inside FlashAttention. Works on any transformer with bidirectional attention. | Optional (exponential model) | +| **VSA** (Video Sparse Attention) | Block-level two-branch attention tuned for video models with long 3D token sequences. | None (fixed `top_k_ratio`) | + +Switching between methods is a CLI/config change — the pipelines, APIs, +and plugins are shared. + +| Model | Script | Methods | +|-------|--------|---------| +| Wan 2.2 5B / 14B | `wan22_sparse_attn.py` | `--method skip_softmax` (default), `--method vsa` | +| LTX-2 | `ltx2_vsa.py` | VSA only (LTX-2 uses a custom attention module; skip-softmax backend in progress) | + +--- + +## Skip-Softmax Sparse Attention + +Skip-softmax (BLASST, ) skips KV tiles whose attention +scores are negligible during the FlashAttention computation, reducing FLOPs without +retraining. + +Two threshold modes are supported: -Two modes are supported: - **Fixed raw threshold** — pass a log2-space threshold directly to the Triton kernel. No calibration needed. Good for quick testing and sweeps. - **Calibrated threshold** — an exponential model @@ -26,43 +47,38 @@ Two modes are supported: without recalibration. Log-space fitting (`fit_logspace=True`) is recommended for diffusion models where scale_factors span many orders of magnitude. 
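+
+As a concrete walkthrough of the calibrated mode (illustrative only — the coefficients
+`a`, `b`, the sequence length, and the head dimension below are made-up placeholders; the
+real values come from calibration and the model config), the fitted exponential model is
+converted to the kernel threshold in the three steps listed in the Threshold Modes table
+further down:
+
+```python
+import math
+
+a, b = 1.0e-6, 8.0                  # hypothetical fitted exponential-model coefficients
+target_sparsity = 0.5               # tunable at runtime, no recalibration needed
+seq_k = 75_600                      # KV sequence length of the current attention call
+sm_scale = 1.0 / math.sqrt(128)     # softmax scale, assuming head_dim = 128
+
+scale_factor = a * math.exp(b * target_sparsity)         # fitted model
+threshold = scale_factor / seq_k                         # backend adapts to sequence length
+skip_threshold_log2 = math.log2(threshold) * sm_scale    # value handed to the Triton kernel
+```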
-## Supported Models - -| Model | Script | Notes | -|-------|--------|-------| -| WAN 2.2 5B | `wan22_skip_softmax.py` | Single transformer, self-attention only | -| WAN 2.2 14B | `wan22_skip_softmax.py` | Dual transformer (auto-detected) | -| LTX-2 | (coming soon) | Via `ltx_triton_attention.py` backend | - -## Quick Start +### Quick Start (Wan 2.2) ```bash # Fixed raw threshold (no calibration, fast) -python wan22_skip_softmax.py \ +python wan22_sparse_attn.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ --raw-threshold -0.7 \ --prompt "A cat playing piano" --output out.mp4 # With calibration -python wan22_skip_softmax.py \ +python wan22_sparse_attn.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ --calibrate --target-sparsity 0.5 \ --prompt "A cat playing piano" --output out.mp4 # Dense baseline (no sparsity, for comparison) -python wan22_skip_softmax.py \ +python wan22_sparse_attn.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ --baseline \ --prompt "A cat playing piano" --output baseline.mp4 # Report runtime sparsity (per-layer tile skip ratios) -python wan22_skip_softmax.py \ +python wan22_sparse_attn.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ --raw-threshold -0.7 --report-avg-sparsity \ --prompt "A cat playing piano" --output out.mp4 ``` -## Threshold Modes +`--method skip_softmax` is the default, so it doesn't need to be passed +explicitly when using skip-softmax flags. + +### Threshold Modes | Mode | How threshold reaches the kernel | Use case | |------|----------------------------------|----------| @@ -70,7 +86,163 @@ python wan22_skip_softmax.py \ | **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`, then backend computes `threshold = scale_factor / seq_k`, then kernel converts `log2(threshold) * sm_scale` | Production use with automatic seqlen adaptation | | **Static lambda** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback when neither raw nor calibrated | -## Known Issues +### Known Issues + +- **14B dual transformer calibration**: Transformers are calibrated sequentially — + transformer_2's calibration runs while transformer_1 is already sparsified, + introducing asymmetric calibration conditions. +- **Minimum achievable sparsity**: Even the strictest threshold may yield 30–40% + sparsity on diffusion models (many tiles are inherently negligible). Targets + below this floor cause extrapolation; an inference-time warning is emitted. + +--- + +## Video Sparse Attention (VSA) + +VSA is a two-branch sparse attention architecture tailored for video diffusion +models: + +1. **Compression branch** — averages tokens within a 3D block (default `4,4,4` = + 64 tokens) and computes coarse-grained block-level attention for global context. +2. **Sparse branch** — selects the top-K most important blocks by the compression + branch's attention scores and computes fine-grained attention only on those. +3. **Gate blend** — `output = compression * gate_compress + sparse`. On models + without a learned `gate_compress` (Wan 2.2, and LTX-2 until fine-tuned), VSA + passes a zero tensor so `output = 0 * compression + sparse = sparse`. This + makes VSA at `top_k_ratio=1.0` (keep all blocks) mathematically equivalent to + dense attention, modulo `bfloat16` kernel rounding (~10⁻⁵ per call on a 75k + token sequence). + +VSA is **calibration-free** — sparsity is controlled by a fixed `top_k_ratio` +(`0.5` keeps 50% of blocks, `0.3` keeps 30%). 
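+
+The sketch below is illustrative only — the post-patchify shape is a made-up example
+(the real one is auto-derived by the plugins) — and shows what `top_k_ratio` and
+`block_size_3d` translate to, plus the gate blend from step 3:
+
+```python
+import math
+
+block_size_3d = (4, 4, 4)            # default: 4*4*4 = 64 tokens per block
+video_shape = (20, 44, 80)           # hypothetical post-patchify (T, H, W) = 70,400 tokens
+
+# Number of 3D blocks (assuming even division; the kernel pads otherwise)
+num_blocks = math.prod(v // b for v, b in zip(video_shape, block_size_3d))
+
+top_k_ratio = 0.5
+kept_blocks = max(1, int(num_blocks * top_k_ratio))  # blocks the sparse branch attends to
+
+# Gate blend (step 3 above): output = compression * gate_compress + sparse.
+# With no learned gate, gate_compress == 0, so output == sparse; at
+# top_k_ratio = 1.0 every block is kept and the result matches dense attention
+# up to bf16 kernel rounding.
+```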
+ +### Quick Start + +```bash +# Wan 2.2 — VSA with default 50% top-K ratio (video_shape auto-derived) +python wan22_sparse_attn.py --method vsa \ + --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ + --top-k-ratio 0.5 \ + --prompt "A cat playing piano" --output vsa.mp4 + +# Wan 2.2 — aggressive 30% top-K (70% sparsity), keep first/last 2 layers dense +python wan22_sparse_attn.py --method vsa \ + --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ + --top-k-ratio 0.3 --skip-first-last 2 \ + --prompt "A cat playing piano" --output vsa.mp4 + +# Wan 2.2 — 720p+ / 81+ frames can OOM during VAE decode since VSA reserves +# ~15 GB of GPU memory for its tile buffers. Enable VAE tiling to recover. +python wan22_sparse_attn.py --method vsa \ + --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ + --top-k-ratio 0.5 --enable-vae-tiling \ + --num-frames 81 --height 720 --width 1280 \ + --prompt "A cat playing piano" --output vsa.mp4 + +# LTX-2 — VSA +python ltx2_vsa.py \ + --checkpoint /path/to/ltx2.safetensors \ + --text-encoder-path /path/to/gemma \ + --top-k-ratio 0.5 \ + --prompt "A cat playing piano" --output vsa.mp4 + +# LTX-2 — baseline (no VSA) +python ltx2_vsa.py \ + --checkpoint /path/to/ltx2.safetensors \ + --text-encoder-path /path/to/gemma \ + --no-vsa --output baseline.mp4 +``` + +### Requirements + +- `fastvideo_kernel` at runtime (the Triton VSA kernel). Install with + `pip install fastvideo_kernel`. VSA imports this lazily, so the modelopt + sparsity API loads without it, but a VSA forward will raise a clear + `ImportError` if missing. +- For LTX-2 only: `ltx_core`, `ltx_trainer`, `ltx_pipelines` (see LICENSE + notice above). + +### Programmatic API + +```python +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.config import VSA_DEFAULT + +# Apply with the pre-built default (50% top-K, self-attention only) +transformer = mtsa.sparsify(transformer, VSA_DEFAULT) + +# Or with a custom config +config = { + "sparse_cfg": { + "*.attn1*": { + "method": "vsa", + "block_size_3d": (4, 4, 4), # 3D tile (T, H, W); 64 tokens per block + "top_k_ratio": 0.3, # 70% sparsity + "enable": True, + # "video_shape": (T, H, W), # optional; auto-derived by the plugin + }, + "*.attn2*": {"enable": False}, # skip cross-attention + "default": {"enable": False}, + }, +} +transformer = mtsa.sparsify(transformer, config) +``` + +### Configuration Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `block_size_3d` | `(4, 4, 4)` | Video tile dims (T, H, W) — default creates 64-token blocks | +| `top_k_ratio` | `0.5` | Fraction of blocks kept in the sparse branch (0 < ratio ≤ 1). `1.0` keeps all blocks = degenerate dense mode | +| `video_shape` | `None` | Post-patchify video shape (T, H, W). Plugins auto-derive it — set explicitly only to override. | +| `enable` | `True` | Per-layer toggle | + +### How VSA Routes Through the Sparse-Attention API + +- **Wan 2.2** uses diffusers' `WanAttention` whose processor calls + `F.scaled_dot_product_attention` — VSA's SDPA patch in + `SparseAttentionModule._forward_with_vsa_sdpa_patch` intercepts that call and + replaces the computation with the Triton VSA kernel. The Wan 2.2 plugin + registers a forward pre-hook that reads `hidden_states.shape = (B, C, T, H, W)` + and sets `video_shape = (T // p_t, H // p_h, W // p_w)` on each VSA method + instance before the transformer runs. 
+- **LTX-2** uses its native `LTXSelfAttention` whose forward signature is + `(x, context, pe, k_pe)` and does **not** call `F.scaled_dot_product_attention`. + The LTX-2 plugin installs a `_LTX2SparseAttention` wrapper that computes + Q/K/V (with LTX-2's RMSNorm and `ltx_core` RoPE), an optional trainable + `gate_compress` (zero-init), and then calls `VSA.forward_attention` directly. + A forward pre-hook on the root `LTXModel` extracts `video_shape` from + `Modality.positions`. +- Cross-attention is detected via Q/K sequence-length mismatch and falls + through to the original attention path (no behaviour change). + +### Verifying the Setup on Wan 2.2 + +A good sanity check is to compare `top_k_ratio=1.0` to the dense baseline — +since VSA without a learned gate becomes pure sparse attention and a full +mask is mathematically equivalent to dense, the two outputs should be close. +On a Wan 2.2 14B run at 720×1280 / 81 frames / 40 steps we measured: + +| Comparison | First-frame PSNR | +|---|---| +| baseline vs baseline w/ VAE tiling | 40.5 dB | +| baseline vs VSA `top_k_ratio=1.0` | 23.9 dB | +| baseline vs VSA `top_k_ratio=0.5` | 13.1 dB | + +The ~24 dB degrade at `top_k=1.0` is error accumulation (6400 attention +calls × bf16 rounding through the denoising loop) — a single-call PSNR vs +dense SDPA is 50 dB on random inputs. + +### Known Limits -- **14B dual transformer calibration**: Transformers are calibrated sequentially — transformer_2's calibration runs while transformer_1 is already sparsified, introducing asymmetric calibration conditions. -- **Minimum achievable sparsity**: Even the strictest threshold may yield 30-40% sparsity on diffusion models (many tiles are inherently negligible). Targets below this floor cause extrapolation; an inference-time warning is emitted. +- **Peak memory on 720p+**: VSA's tile buffers reserve ~15 GB of GPU memory + on top of the transformer, which can OOM the one-shot VAE decode at 720p / + 81 frames. Pass `--enable-vae-tiling` (or set + `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`) to recover. +- **Token count ≥ 16 tiles (≈1024 tokens)**: VSA's setup overhead dominates for + tiny sequences. For LTX-2, use ≥121 frames at ≥512×768 for meaningful speedups. +- **Mixing with skip-softmax**: VSA patches SDPA globally per module, while + skip-softmax needs `attn_implementation="eager"`. `conversion.py` rejects + configs that mix the two — run them separately. +- **Training**: `to_gate_compress` is zero-initialised and trainable, but no + training loop is wired up yet. This example covers inference only. diff --git a/examples/diffusers/sparsity/wan22_skip_softmax.py b/examples/diffusers/sparsity/wan22_sparse_attn.py similarity index 57% rename from examples/diffusers/sparsity/wan22_skip_softmax.py rename to examples/diffusers/sparsity/wan22_sparse_attn.py index e335451e2b..65a5618fb0 100644 --- a/examples/diffusers/sparsity/wan22_skip_softmax.py +++ b/examples/diffusers/sparsity/wan22_sparse_attn.py @@ -13,40 +13,52 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Wan 2.2 inference with skip-softmax sparse attention. +"""Wan 2.2 inference with sparse attention (skip-softmax or VSA). -This example applies skip-softmax sparse attention to the Wan 2.2 video -generation model (text-to-video). Four modes are supported: +Two sparse-attention methods are supported via ``--method``: -1. **Baseline** — pass ``--baseline`` for dense inference (default diffusers backend). -2. 
**Triton baseline** — pass ``--triton-baseline`` for dense Triton FA kernel - (no skip-softmax, same kernel as sparse runs for apples-to-apples comparison). -3. **Fixed raw threshold** — pass ``--raw-threshold`` to supply a log2-space - threshold directly to the Triton kernel. No calibration data is needed. -4. **Calibrated threshold** — pass ``--calibrate`` to run exponential-model - calibration (``scale_factor = a * exp(b * target_sparsity)``). +- **skip_softmax** (default, BLASST) — drops KV tiles whose attention + scores are negligible during the FlashAttention computation. Reduces FLOPs + without retraining. Supports ``--raw-threshold`` (log2-space, no + calibration) and ``--calibrate`` (exponential model fitted once, target + sparsity tunable at runtime). +- **vsa** (Video Sparse Attention) — two-branch (compression + sparse) + block-level attention tuned for video models. Calibration-free — sparsity + is controlled by a fixed ``top_k_ratio``. The Wan 2.2 plugin auto-derives + ``video_shape`` from each forward's ``hidden_states``. -During calibration, ``triton_skip_softmax`` with the Triton calibration kernel -collects sparsity statistics across multiple threshold trials. The fitted -exponential model then allows runtime control of the target sparsity ratio -without recalibration. +Run modes (method-agnostic): + +- ``--baseline`` — dense inference, no sparsity (default diffusers backend). +- ``--triton-baseline`` — dense Triton FA kernel, no skip-softmax + (apples-to-apples comparison with skip-softmax runs; skip_softmax only). The Wan 2.2 5B model has 40 transformer blocks with self-attention (attn1) -and cross-attention (attn2). Only self-attention is sparsified. +and cross-attention (attn2); the 14B model has two transformers. Only +self-attention is sparsified — cross-attention is always left dense. 
Usage:: - # Baseline (dense, no sparsity) - python wan22_skip_softmax.py --baseline --prompt "A cat playing piano" \\ - --output baseline.mp4 - - # Fixed raw threshold (no calibration needed) - python wan22_skip_softmax.py --raw-threshold -5.0 --report-avg-sparsity \\ + # Skip-softmax with fixed raw threshold (default method, no calibration) + python wan22_sparse_attn.py --raw-threshold -5.0 --report-avg-sparsity \\ --prompt "A cat playing piano" --output out.mp4 - # With calibration - python wan22_skip_softmax.py --calibrate --target-sparsity 0.25 \\ + # Skip-softmax with calibration + python wan22_sparse_attn.py --calibrate --target-sparsity 0.25 \\ --report-avg-sparsity --prompt "A cat playing piano" --output out.mp4 + + # VSA with 50% top-K (50% sparsity) + python wan22_sparse_attn.py --method vsa --top-k-ratio 0.5 \\ + --prompt "A cat playing piano" --output vsa.mp4 + + # VSA with aggressive 30% top-K (70% sparsity), keep first/last 2 layers dense + python wan22_sparse_attn.py --method vsa --top-k-ratio 0.3 \\ + --skip-first-last 2 --report-avg-sparsity \\ + --prompt "A cat playing piano" --output vsa.mp4 + + # Dense baseline (any method) + python wan22_sparse_attn.py --baseline --prompt "A cat playing piano" \\ + --output baseline.mp4 """ import argparse @@ -73,7 +85,7 @@ ) # fmt: on -# Default threshold trials for calibration +# Default threshold trials for calibration (skip_softmax only) DEFAULT_THRESHOLD_TRIALS = [ 1e-12, 1e-10, @@ -102,7 +114,7 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Wan 2.2 video generation with skip-softmax sparse attention" + description=("Wan 2.2 video generation with sparse attention (skip-softmax or VSA)") ) parser.add_argument( "--prompt", @@ -137,7 +149,15 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--seed", type=int, default=42, help="Random seed") - # Sparse attention options + # ---- Method selection ---- + parser.add_argument( + "--method", + choices=["skip_softmax", "vsa"], + default="skip_softmax", + help="Sparse attention method (default: skip_softmax)", + ) + + # ---- Run-mode flags (method-agnostic) ---- parser.add_argument( "--baseline", action="store_true", @@ -147,15 +167,7 @@ def parse_args() -> argparse.Namespace: "--triton-baseline", action="store_true", help="Run dense inference with Triton FA kernel (no skip-softmax, " - "apples-to-apples comparison with sparse runs)", - ) - parser.add_argument( - "--raw-threshold", - type=float, - default=None, - help="Raw skip_threshold_log2 value passed directly to the Triton kernel. " - "Negative values (e.g., -5.0 means tile must be within 5 units of running max). " - "Bypasses calibration and lambda conversion. Typical range: -1 to -30.", + "apples-to-apples comparison with sparse runs). skip_softmax only.", ) parser.add_argument( "--skip-first-last", @@ -166,40 +178,94 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--report-avg-sparsity", action="store_true", - help="Report per-layer and overall average tile sparsity after generation", + help="[skip_softmax] Report per-layer and overall average tile sparsity " + "measured via the Triton kernel's atomic counters. " + "No-op for VSA (sparsity is deterministic from --top-k-ratio).", + ) + parser.add_argument( + "--enable-vae-tiling", + action="store_true", + help="Enable VAE tiling in the Wan pipeline to reduce peak memory during " + "decode. 
Recommended at 720p+ when VSA is active, since VSA leaves less " + "GPU memory free than the dense baseline.", ) - # Calibration options + # ---- Skip-softmax options ---- + parser.add_argument( + "--raw-threshold", + type=float, + default=None, + help="[skip_softmax] Raw skip_threshold_log2 value passed directly to the Triton kernel. " + "Negative values (e.g., -5.0 means tile must be within 5 units of running max). " + "Bypasses calibration and lambda conversion. Typical range: -1 to -30.", + ) parser.add_argument( "--calibrate", action="store_true", - help="Calibrate threshold via exponential model (recommended)", + help="[skip_softmax] Calibrate threshold via exponential model (recommended)", ) parser.add_argument( "--target-sparsity", type=float, default=0.5, - help="Target sparsity ratio for calibration (0.0-1.0)", + help="[skip_softmax] Target sparsity ratio for calibration (0.0-1.0)", ) parser.add_argument( "--calib-steps", type=int, default=40, - help="Inference steps for calibration", + help="[skip_softmax] Inference steps for calibration", ) parser.add_argument( "--calib-frames", type=int, default=151, - help="Number of frames for calibration", + help="[skip_softmax] Number of frames for calibration", ) parser.add_argument( "--calib-size", type=int, default=4, - help="Number of calibration prompts from OpenVid-1M dataset", + help="[skip_softmax] Number of calibration prompts from OpenVid-1M dataset", + ) + + # ---- VSA options ---- + parser.add_argument( + "--top-k-ratio", + type=float, + default=0.5, + help="[vsa] Ratio of blocks kept in the sparse branch (0 < ratio ≤ 1). " + "Lower = more sparsity. 0.5 → 50%% sparsity, 0.3 → 70%%.", + ) + parser.add_argument( + "--block-size", + type=str, + default="4,4,4", + help="[vsa] VSA 3D block size as 'T,H,W' (default 4,4,4 → 64-token blocks)", + ) + parser.add_argument( + "--video-shape", + type=str, + default=None, + help="[vsa] Override post-patchify video shape as 'T,H,W'. " + "If unset, the Wan 2.2 plugin derives it automatically from hidden_states.", ) - return parser.parse_args() + + args = parser.parse_args() + + # Cross-method validation + if args.triton_baseline and args.method != "skip_softmax": + parser.error("--triton-baseline is only valid with --method skip_softmax") + + return args + + +def _parse_int_triple(spec: str) -> tuple[int, int, int]: + """Parse 'T,H,W' into a triple of positive ints.""" + parts = [int(p.strip()) for p in spec.split(",")] + if len(parts) != 3 or any(p <= 0 for p in parts): + raise ValueError(f"expected 3 positive integers T,H,W — got {spec!r}") + return (parts[0], parts[1], parts[2]) def build_pipeline(model_path: str) -> WanPipeline: @@ -210,8 +276,8 @@ def build_pipeline(model_path: str) -> WanPipeline: return pipe -def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: - """Build sparse attention config from CLI args. +def build_skip_softmax_config(args: argparse.Namespace, num_blocks: int) -> dict: + """Build a skip-softmax config from CLI args. Two modes: - **Raw threshold**: ``--raw-threshold`` sets ``skip_softmax_raw_threshold`` @@ -257,6 +323,37 @@ def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: return config +def build_vsa_config(args: argparse.Namespace, num_blocks: int) -> dict: + """Build a VSA sparse-attention config from CLI args. + + Applies VSA to self-attention (``attn1``) only — cross-attention + (``attn2``) is disabled because VSA's 3D-tile structure does not apply + to text KV. 
Optionally keeps the first/last N transformer layers dense. + """ + block_size = _parse_int_triple(args.block_size) + + attn_cfg: dict = { + "method": "vsa", + "block_size_3d": block_size, + "top_k_ratio": args.top_k_ratio, + "enable": True, + } + if args.video_shape is not None: + attn_cfg["video_shape"] = _parse_int_triple(args.video_shape) + + sparse_cfg: dict = { + "*.attn1*": attn_cfg, # Self-attention only + "*.attn2*": {"enable": False}, # Text cross-attention + "default": {"enable": False}, + } + + for i in range(args.skip_first_last): + sparse_cfg[f"*blocks.{i}.attn*"] = {"enable": False} + sparse_cfg[f"*blocks.{num_blocks - 1 - i}.attn*"] = {"enable": False} + + return {"sparse_cfg": sparse_cfg} + + def load_calib_prompts(calib_size: int) -> list[str]: """Load calibration prompts from OpenVid-1M dataset.""" dataset = load_dataset("nkp37/OpenVid-1M", split="train") @@ -277,7 +374,7 @@ def build_calibration_forward_loop( guidance_scale_2: float | None = 3.0, negative_prompt: str = "", ): - """Build a forward loop for exponential model calibration. + """Build a forward loop for exponential model calibration (skip_softmax). Uses prompts from OpenVid-1M dataset (same as quantization examples). Each prompt is run individually (batch_size=1). @@ -305,7 +402,12 @@ def forward_loop(model): def enable_sparsity_measurement(model: torch.nn.Module) -> None: - """Enable runtime sparsity measurement on all sparse attention modules.""" + """Enable runtime sparsity measurement on skip-softmax modules. + + Only applies to methods that expose ``enable_measure_sparsity`` (i.e. + the Triton skip-softmax kernel). VSA reports stats via its stats manager + instead — see ``print_vsa_runtime_stats``. + """ for _name, module in model.named_modules(): if isinstance(module, SparseAttentionModule) and module.is_enabled: method = module._sparse_method_instance @@ -315,7 +417,7 @@ def enable_sparsity_measurement(model: torch.nn.Module) -> None: def print_sparsity_summary(model: torch.nn.Module) -> None: - """Print per-module sparsity statistics including runtime kernel counters.""" + """Print per-module sparsity configuration (method-agnostic).""" enabled, disabled = [], [] for name, module in model.named_modules(): if isinstance(module, SparseAttentionModule): @@ -330,8 +432,8 @@ def print_sparsity_summary(model: torch.nn.Module) -> None: print(f" {name}: {info}") -def print_runtime_sparsity(model: torch.nn.Module) -> None: - """Print runtime tile sparsity measured via kernel atomic counters.""" +def print_skip_softmax_runtime_sparsity(model: torch.nn.Module) -> None: + """Print per-layer tile sparsity measured via the Triton kernel's atomic counters.""" total_all = 0 skipped_all = 0 per_module: list[tuple[str, int, int]] = [] @@ -378,6 +480,65 @@ def _get_num_blocks(transformer: torch.nn.Module) -> int: return max_idx + 1 +def _apply_skip_softmax( + pipe: WanPipeline, + transformers: list[tuple[str, torch.nn.Module]], + args: argparse.Namespace, + is_14b: bool, +): + """Sparsify every transformer with skip-softmax. + + Returns the calibration ``forward_loop`` (or None) so the caller can + free memory after calibration completes. 
+ """ + forward_loop = None + if args.triton_baseline: + print("Triton baseline: dense Triton FA kernel (no skip-softmax)") + elif args.raw_threshold is not None: + print(f"Skip-softmax: fixed raw threshold {args.raw_threshold} (no calibration)") + if args.calibrate: + print("Warning: --calibrate is ignored when --raw-threshold is set") + elif args.calibrate: + forward_loop = build_calibration_forward_loop( + pipe, + calib_size=args.calib_size, + num_steps=args.calib_steps, + num_frames=args.calib_frames, + height=args.height, + width=args.width, + seed=args.seed, + guidance_scale=args.guidance_scale, + guidance_scale_2=args.guidance_scale_2 if is_14b else None, + negative_prompt=args.negative_prompt, + ) + else: + print( + "Warning: skip_softmax requested without --raw-threshold or --calibrate; " + "falling back to static skip_softmax_threshold=0.1" + ) + + for name, transformer in transformers: + num_blocks = _get_num_blocks(transformer) + label = "Triton backend" if args.triton_baseline else "skip-softmax" + print(f"Applying {label} to {name} ({num_blocks} blocks)...") + config = build_skip_softmax_config(args, num_blocks=num_blocks) + mtsa.sparsify(transformer, config, forward_loop=forward_loop) + + return forward_loop + + +def _apply_vsa( + transformers: list[tuple[str, torch.nn.Module]], + args: argparse.Namespace, +): + """Sparsify every transformer with VSA. No calibration needed.""" + for name, transformer in transformers: + num_blocks = _get_num_blocks(transformer) + print(f"Applying VSA to {name} ({num_blocks} blocks)...") + config = build_vsa_config(args, num_blocks=num_blocks) + mtsa.sparsify(transformer, config) + + def main() -> None: args = parse_args() @@ -385,9 +546,16 @@ def main() -> None: print(f"Loading Wan 2.2 from {args.model_path}...") pipe = build_pipeline(args.model_path) + if args.enable_vae_tiling: + # VAE tiling decodes latents in tiles instead of one shot — essential at + # 720p+ when VSA is active (VSA's tile buffers leave ~15 GB less free GPU + # memory vs. dense, which can OOM the one-shot VAE decode). 
+ pipe.vae.enable_tiling() + print("Enabled VAE tiling (reduces peak memory during decode)") + # ---- Collect transformers ---- # Wan 2.2 5B has one transformer; 14B has two (transformer + transformer_2) - transformers = [] + transformers: list[tuple[str, torch.nn.Module]] = [] if pipe.transformer is not None: transformers.append(("transformer", pipe.transformer)) if getattr(pipe, "transformer_2", None) is not None: @@ -395,57 +563,24 @@ def main() -> None: is_14b = len(transformers) > 1 # ---- Sparsify (unless baseline) ---- + forward_loop = None if args.baseline: print("Baseline mode: running dense inference (default diffusers backend)") - elif args.triton_baseline: - print("Triton baseline: dense Triton FA kernel (no skip-softmax)") - for name, transformer in transformers: - num_blocks = _get_num_blocks(transformer) - print(f"Applying Triton backend to {name} ({num_blocks} blocks)...") - config = build_sparse_config(args, num_blocks=num_blocks) - mtsa.sparsify(transformer, config, forward_loop=None) - else: - # Build calibration forward loop if needed - forward_loop = None - if args.raw_threshold is not None: - print(f"Using fixed raw threshold: {args.raw_threshold} (skipping calibration)") - if args.calibrate: - print("Warning: --calibrate is ignored when --raw-threshold is set") - elif args.calibrate: - forward_loop = build_calibration_forward_loop( - pipe, - calib_size=args.calib_size, - num_steps=args.calib_steps, - num_frames=args.calib_frames, - height=args.height, - width=args.width, - seed=args.seed, - guidance_scale=args.guidance_scale, - guidance_scale_2=args.guidance_scale_2 if is_14b else None, - negative_prompt=args.negative_prompt, - ) - else: - print( - "Warning: neither --baseline, --raw-threshold, nor --calibrate specified; " - "using default static threshold" - ) - - for name, transformer in transformers: - num_blocks = _get_num_blocks(transformer) - print(f"Applying skip-softmax to {name} ({num_blocks} blocks)...") - config = build_sparse_config(args, num_blocks=num_blocks) - mtsa.sparsify(transformer, config, forward_loop=forward_loop) + elif args.method == "skip_softmax": + forward_loop = _apply_skip_softmax(pipe, transformers, args, is_14b) + elif args.method == "vsa": + _apply_vsa(transformers, args) # ---- Free calibration memory before inference ---- - if not args.baseline and not args.triton_baseline and forward_loop is not None: + if forward_loop is not None: gc.collect() torch.cuda.empty_cache() print("Cleared CUDA cache after calibration") # ---- Generate (optional) ---- if args.prompt: - # Enable runtime sparsity measurement before generation - if args.report_avg_sparsity and not args.baseline: + # Enable runtime sparsity measurement before generation (skip_softmax only) + if args.report_avg_sparsity and not args.baseline and args.method == "skip_softmax": for _name, transformer in transformers: enable_sparsity_measurement(transformer) @@ -477,8 +612,11 @@ def main() -> None: for name, transformer in transformers: print(f"\n{name}:") print_sparsity_summary(transformer) - if args.report_avg_sparsity: - print_runtime_sparsity(transformer) + # Runtime sparsity is meaningful only for skip-softmax (data-dependent). + # VSA sparsity is deterministic from top_k_ratio — the per-module + # summary above already reports it via get_threshold_info(). 
+ if args.report_avg_sparsity and args.method == "skip_softmax": + print_skip_softmax_runtime_sparsity(transformer) if __name__ == "__main__": diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py b/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py index 66acfb510c..c6bf33b63b 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py @@ -286,9 +286,21 @@ def forward_attention( query_tiled = self._tile_tensor(query, metadata) key_tiled = self._tile_tensor(key, metadata) value_tiled = self._tile_tensor(value, metadata) - gate_tiled = ( - self._tile_tensor(gate_compress, metadata) if gate_compress is not None else None - ) + if gate_compress is not None: + gate_tiled = self._tile_tensor(gate_compress, metadata) + else: + # The fastvideo kernel's default behaviour when + # ``compress_attn_weight is None`` is ``out_c + out_s`` — i.e. it + # *adds* the compression branch at full strength on top of the + # sparse branch. For models without a learned ``gate_compress`` + # (e.g. Wan 2.2), this doubles the attention signal and corrupts + # the output. The intended "no gate" semantics is + # ``gate_compress = 0`` → ``out = 0 * out_c + out_s = out_s``, + # which (a) matches an untrained LTX-2 whose ``to_gate_compress`` + # is zero-initialised, and (b) makes VSA at ``top_k_ratio=1.0`` + # reduce to dense attention (since ``out_s`` with all blocks + # selected is mathematically equivalent to dense SDPA). + gate_tiled = torch.zeros((), dtype=query_tiled.dtype, device=query_tiled.device) # ========== TRITON VSA KERNEL ========== # Kernel operates on tiled tensors in [batch, heads, padded_seq, dim] format diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py index 434fc18214..9c99e42c07 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py @@ -15,19 +15,32 @@ """Plugins for sparse attention integration with various frameworks.""" -# List of model plugins that are called during conversion -# Each plugin is a callable that takes (model) and performs validation/setup -CUSTOM_MODEL_PLUGINS: list = [] +from modelopt.torch.utils import import_plugin + +# Set of model plugins called during conversion. A set (rather than a list) +# keeps re-imports idempotent — the same callback inserted multiple times +# stays registered once. Matches the convention used by quantization and peft. +CUSTOM_MODEL_PLUGINS: set = set() def register_custom_model_plugins_on_the_fly(model): - """Applies all registered custom model plugins.""" + """Apply every registered custom model plugin to ``model``.""" for callback in CUSTOM_MODEL_PLUGINS: callback(model) +# Built-in plugins from . import huggingface # noqa: E402 +# Model-specific plugins for VSA. Guarded by ``import_plugin`` so the +# module-level imports stay soft — a missing dependency in one plugin must +# not break the core sparse-attention API. +with import_plugin("ltx2"): + from . import ltx2 + +with import_plugin("wan22"): + from . 
import wan22 + __all__ = [ "CUSTOM_MODEL_PLUGINS", "register_custom_model_plugins_on_the_fly", diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index d26b73f0b4..f4e43a40de 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -132,4 +132,4 @@ def _is_supported_model(model: nn.Module) -> bool: # Register plugins -CUSTOM_MODEL_PLUGINS.append(register_sparse_attention_on_the_fly) +CUSTOM_MODEL_PLUGINS.add(register_sparse_attention_on_the_fly) diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py b/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py new file mode 100644 index 0000000000..27bea131ae --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py @@ -0,0 +1,413 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Plugin for LTX-2 video diffusion models with VSA support. + +LTX-2 uses a native ``LTXSelfAttention`` module whose forward signature is +``(x, context, pe, k_pe)`` and which does not call +``F.scaled_dot_product_attention``. VSA's default SDPA patching in +``SparseAttentionModule`` therefore has no effect on it, so this plugin +installs a model-specific wrapper that: + +1. Projects Q/K/V from ``x`` (and ``context`` for self-attention: ``context = x``) +2. Applies LTX-2's ``q_norm`` / ``k_norm`` RMSNorms and RoPE via ``ltx_core`` +3. Computes an optional ``gate_compress`` from a trainable zero-initialised + projection (used by VSA's compression branch, trained later) +4. Calls ``VSA.forward_attention()`` directly, bypassing SDPA +5. Applies the original module's ``to_out`` projection + +A forward pre-hook on the root ``LTXModel`` extracts the ``(T, H, W)`` +shape from ``Modality.positions`` (same source FastVideo uses) and stores it +on the model, so the wrapper can read it per-step without module-level global +state. +""" + +import logging +import weakref + +import torch +import torch.nn as nn + +from modelopt.torch.utils.logging import warn_rank_0 + +from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry +from . import CUSTOM_MODEL_PLUGINS + +logger = logging.getLogger(__name__) + +_LTX2_LICENSE_WARNING = ( + "LTX-2 packages (ltx-core, ltx-pipelines, ltx-trainer) are provided by " + "Lightricks and are NOT covered by the Apache 2.0 license governing NVIDIA " + "Model Optimizer. You MUST comply with the LTX Community License Agreement " + "when installing and using LTX-2 with NVIDIA Model Optimizer. Any derivative " + "models or fine-tuned weights from LTX-2 (including quantized or distilled " + "checkpoints) remain subject to the LTX Community License Agreement, not " + "Apache 2.0. 
See: https://github.com/Lightricks/LTX-2/blob/main/LICENSE" +) + + +def _extract_video_shape_hook(module: nn.Module, args: tuple) -> None: + """Forward pre-hook on LTXModel to extract ``dit_seq_shape`` from Modality.positions. + + Mirrors FastVideo's ``VideoSparseAttentionMetadataBuilder.build()`` which + computes ``dit_seq_shape = raw_latent_shape // patch_size``. Here we + derive the same shape by counting unique position values per dimension in + ``Modality.positions``, which is available at the LTXModel entry point + (before ``TransformerArgsPreprocessor`` converts it to RoPE embeddings). + + The result is stored on the model instance as ``module._vsa_video_shape`` + so ``_LTX2SparseAttention._resolve_video_shape()`` can read it via its + weak reference to the root model. Using an instance attribute (not a + global) makes this safe for concurrent models. + """ + # LTXModel.forward(self, video: Modality | None, audio, perturbations) + video = args[0] if len(args) > 0 else None + if video is None or not hasattr(video, "positions") or video.positions is None: + return + + positions = video.positions # (B, 3, T) or (B, 3, T, 2) + + try: + if positions.ndim == 4: + # (B, 3, T, 2) -- take start coordinates + pos_per_dim = positions[0, :, :, 0] # (3, T) + elif positions.ndim == 3: + # (B, 3, T) + pos_per_dim = positions[0] # (3, T) + else: + return + + t_dim = pos_per_dim[0].unique().numel() + h_dim = pos_per_dim[1].unique().numel() + w_dim = pos_per_dim[2].unique().numel() + seq_len = positions.shape[2] + + if t_dim * h_dim * w_dim == seq_len: + module._vsa_video_shape = (t_dim, h_dim, w_dim) + logger.debug( + f"Extracted dit_seq_shape={module._vsa_video_shape} from " + f"Modality.positions (seq_len={seq_len})" + ) + else: + logger.debug( + f"Position-derived shape {(t_dim, h_dim, w_dim)} product " + f"({t_dim * h_dim * w_dim}) != seq_len ({seq_len}), skipping" + ) + except Exception: + logger.debug("Failed to extract video_shape from Modality.positions", exc_info=True) + + +def _is_ltx2_model(model: nn.Module) -> bool: + """Check if model is an LTX-2 model. + + Uses ``LTXModel`` / ``LTXSelfAttention`` class names to avoid false + positives from other DiTs (e.g., LongCat) that share similar attribute + patterns. + """ + if type(model).__name__ == "LTXModel": + return True + return any(type(m).__name__ == "LTXSelfAttention" for m in model.modules()) + + +def _is_ltx2_attention_module(module: nn.Module, name: str = "") -> bool: + """Check if a module is an LTX-2 Attention module by class name or structure. + + Primary: class name is ``LTXSelfAttention``. Fallback: has ``to_q/k/v``, + ``q_norm``, ``k_norm``, and ``rope_type`` (unique to LTX-2 among DiTs we + support). + """ + class_name = type(module).__name__ + if class_name == "LTXSelfAttention": + return True + return ( + hasattr(module, "to_q") + and hasattr(module, "to_k") + and hasattr(module, "to_v") + and hasattr(module, "q_norm") + and hasattr(module, "k_norm") + and hasattr(module, "rope_type") + ) + + +class _LTX2SparseAttention(SparseAttentionModule): + """Sparse-attention wrapper for LTX-2 ``LTXSelfAttention`` modules. + + Handles LTX-2 specifics (native forward args, RMSNorm, RoPE, trainable + ``gate_compress``) and delegates the actual attention computation to + ``VSA.forward_attention``. Falls back to the original module forward + for cross-attention / incompatible sequence lengths / missing video + shape, matching how the core SDPA patch falls through to original SDPA. 
+ """ + + def _setup(self): + super()._setup() + + # Add trainable gate_compress projection if not already present. + # Zero-init so its initial contribution is 0 — matches VSA's behaviour + # when gate_compress is None but leaves room for fine-tuning. + if not hasattr(self, "to_gate_compress"): + to_q = self.to_q + in_features = to_q.in_features + out_features = to_q.out_features + + self.to_gate_compress = nn.Linear(in_features, out_features, bias=True) + nn.init.zeros_(self.to_gate_compress.weight) + nn.init.zeros_(self.to_gate_compress.bias) + + self.to_gate_compress = self.to_gate_compress.to( + device=to_q.weight.device, + dtype=to_q.weight.dtype, + ) + + def _compute_qkv( + self, + x: torch.Tensor, + context: torch.Tensor | None, + pe: torch.Tensor | None = None, + k_pe: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute Q/K/V with LTX-2 norms and RoPE. + + Inputs are ``[batch, seq, hidden_dim]``; output tensors share the same + layout and are reshaped to heads later in ``forward``. + """ + context = context if context is not None else x + + query = self.to_q(x) + key = self.to_k(context) + value = self.to_v(context) + + if hasattr(self, "q_norm"): + query = self.q_norm(query) + if hasattr(self, "k_norm"): + key = self.k_norm(key) + + if pe is not None and hasattr(self, "rope_type"): + try: + from ltx_core.model.transformer.rope import apply_rotary_emb + except ModuleNotFoundError: + raise ModuleNotFoundError( + "LTX-2 VSA plugin requires the 'ltx_core' package for RoPE " + "support. The plugin registered successfully, but 'ltx_core' " + "is needed at runtime. Install with: pip install ltx-core" + ) from None + + query = apply_rotary_emb(query, pe, self.rope_type) + key = apply_rotary_emb(key, pe if k_pe is None else k_pe, self.rope_type) + + return query, key, value + + @staticmethod + def _reshape_for_vsa(tensor: torch.Tensor, num_heads: int) -> torch.Tensor: + """``[batch, seq, hidden]`` → ``[batch, heads, seq, head_dim]``.""" + batch, seq_len, hidden_dim = tensor.shape + head_dim = hidden_dim // num_heads + return tensor.view(batch, seq_len, num_heads, head_dim).transpose(1, 2) + + @staticmethod + def _reshape_from_vsa(tensor: torch.Tensor) -> torch.Tensor: + """``[batch, heads, seq, head_dim]`` → ``[batch, seq, hidden]``.""" + batch, heads, seq_len, head_dim = tensor.shape + return tensor.transpose(1, 2).contiguous().view(batch, seq_len, heads * head_dim) + + def _resolve_video_shape(self, seq_len: int) -> tuple[int, int, int] | None: + """Resolve video_shape for the current forward pass. + + Resolution order (mirrors FastVideo's metadata flow): + 1. ``root_model._vsa_video_shape`` -- set by the forward pre-hook from + ``Modality.positions`` + 2. ``method.video_shape`` -- explicitly set via the sparsify config + """ + root_ref = getattr(self, "_vsa_root_model_ref", None) + root = root_ref() if root_ref is not None else None + if root is not None: + shape = getattr(root, "_vsa_video_shape", None) + if shape is not None: + t, h, w = shape + if t * h * w == seq_len: + return shape + + method = getattr(self, "_sparse_method_instance", None) + if method is not None and method.video_shape is not None: + t, h, w = method.video_shape + if t * h * w == seq_len: + return method.video_shape + + return None + + def forward(self, *args, **kwargs): + """Run the LTX-2 attention forward through VSA. 
+ + Consumes LTX-2's native call signature (``x``, ``context``, ``pe``, + ``k_pe``) and dispatches to ``VSA.forward_attention``; falls through + to the original module for cross-attention or incompatible inputs. + """ + if not self.is_enabled: + return self._call_original_forward(*args, **kwargs) + + x = kwargs.get("x") + if x is None and len(args) > 0: + x = args[0] + + if x is None: + return self._call_original_forward(*args, **kwargs) + + context = kwargs.get("context") + pe = kwargs.get("pe") + k_pe = kwargs.get("k_pe") + + # Cross-attention: fall through to the original module + if context is not None and x.shape[1] != context.shape[1]: + return self._call_original_forward(*args, **kwargs) + + method = getattr(self, "_sparse_method_instance", None) + if method is None: + return self._call_original_forward(*args, **kwargs) + + query, key, value = self._compute_qkv(x, context, pe, k_pe) + + # Incompatible seq_len (e.g., audio attention with seq=32) + seq_len = query.shape[1] + block_size_3d = method.block_size_3d + block_elements = block_size_3d[0] * block_size_3d[1] * block_size_3d[2] + if seq_len < block_elements: + logger.debug(f"VSA skipped: seq_len={seq_len} < block_elements={block_elements}") + return self._call_original_forward(*args, **kwargs) + + video_shape = self._resolve_video_shape(seq_len) + if video_shape is None: + logger.debug(f"VSA skipped: no matching video_shape for seq_len={seq_len}") + return self._call_original_forward(*args, **kwargs) + + gate_compress = None + if hasattr(self, "to_gate_compress"): + gate_compress = self.to_gate_compress(x) + + # Reshape to [batch, heads, seq, head_dim] + query = self._reshape_for_vsa(query, self.heads) + key = self._reshape_for_vsa(key, self.heads) + value = self._reshape_for_vsa(value, self.heads) + if gate_compress is not None: + gate_compress = self._reshape_for_vsa(gate_compress, self.heads) + + output, stats = method.forward_attention( + query=query, + key=key, + value=value, + gate_compress=gate_compress, + video_shape=video_shape, + ) + + # Bubble stats up through SparseAttentionModule's stats path + self._last_stats = stats + if self._stats_manager is not None: + self._stats_manager.collect(stats) + self._last_stats = None + + output = self._reshape_from_vsa(output) + + if hasattr(self, "to_out"): + output = self.to_out(output) + + return output + + def _call_original_forward(self, *args, **kwargs): + """Invoke the original module's forward, bypassing VSA. + + ``SparseAttentionModule.forward`` passes through to the original + module when ``is_enabled`` is False — exploit that to avoid + reimplementing the fallback path. + """ + was_enabled = getattr(self, "_enabled", True) + self._enabled = False + try: + result = SparseAttentionModule.forward(self, *args, **kwargs) + finally: + self._enabled = was_enabled + return result + + def get_gate_compress_parameters(self): + """Return trainable ``gate_compress`` parameters for later fine-tuning.""" + if hasattr(self, "to_gate_compress"): + return self.to_gate_compress.parameters() + return iter([]) + + +def register_ltx2_attention(model: nn.Module) -> int: + """Register LTX-2 Attention modules for VSA wrapping. + + Replaces any existing generic wrapper in ``SparseAttentionRegistry`` + with ``_LTX2SparseAttention`` for each LTX-2 attention type found, wires + a weakref back to the root model on every attention instance, and + installs the ``Modality.positions`` extraction pre-hook. 
+ """ + if not _is_ltx2_model(model): + return 0 + + # Third-party-license notice: emit once per LTX-2 model detection, + # matching the pattern used by modelopt's quantization and kernel LTX-2 + # plugins. The wrapper touches ``ltx_core`` (RoPE) at forward time, so + # users must comply with the LTX Community License Agreement. + warn_rank_0(_LTX2_LICENSE_WARNING, UserWarning, stacklevel=2) + + registered_types = set() + num_modules = 0 + + for name, module in model.named_modules(): + if not _is_ltx2_attention_module(module, name): + continue + + num_modules += 1 + module_type = type(module) + + if module_type in registered_types: + continue + + if module_type in SparseAttentionRegistry: + logger.debug(f"Unregistering generic wrapper for {module_type.__name__}") + SparseAttentionRegistry.unregister(module_type) + + SparseAttentionRegistry.register({module_type: module_type.__name__})(_LTX2SparseAttention) + registered_types.add(module_type) + logger.info(f"Registered LTX-2 attention: {module_type.__name__}") + + if num_modules > 0: + logger.info(f"Found {num_modules} LTX-2 Attention modules in model") + + # Weakref avoids the circular-submodule problem (nn.Module.__setattr__ + # would otherwise register the root model as a submodule of every + # attention, causing infinite recursion in named_children()). + root_ref = weakref.ref(model) + for _, module in model.named_modules(): + if _is_ltx2_attention_module(module): + object.__setattr__(module, "_vsa_root_model_ref", root_ref) + + model.register_forward_pre_hook(_extract_video_shape_hook) + logger.debug("Registered VSA video_shape extraction hook on model") + + return len(registered_types) + + +def register_ltx2_on_the_fly(model: nn.Module) -> bool: + """Plugin entry point: wire up LTX-2 VSA if this is an LTX-2 model.""" + num_registered = register_ltx2_attention(model) + if num_registered > 0: + logger.info(f"Registered {num_registered} LTX-2 attention types for VSA") + return True + return False + + +# Idempotent: plugins/__init__.py stores plugins in a set so re-imports are safe. +CUSTOM_MODEL_PLUGINS.add(register_ltx2_on_the_fly) diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py b/modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py new file mode 100644 index 0000000000..f343f58e63 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Plugin for Wan 2.2 video diffusion models with VSA support. + +Wan 2.2 (``WanTransformer3DModel`` from diffusers) uses standard diffusers +``Attention`` modules whose ``AttnProcessor2_0`` calls +``F.scaled_dot_product_attention``. VSA's default SDPA patch in +``SparseAttentionModule._forward_with_vsa_sdpa_patch`` therefore intercepts +the right call — we only need to tell VSA the post-patchify ``(T, H, W)``. 
+ +This plugin installs a forward pre-hook on every ``WanTransformer3DModel`` +that: + +1. Reads ``hidden_states`` shape ``(B, C, T, H, W)`` from the transformer + input. +2. Divides by ``model.config.patch_size = (p_t, p_h, p_w)`` — same + computation diffusers does internally (see + ``WanTransformer3DModel.forward``: ``post_patch_num_frames = num_frames // p_t`` + etc.). +3. Propagates the resulting shape to every ``SparseAttentionModule`` in + the transformer whose method is VSA, via ``method.set_video_shape()``. + +Self-attention layers (``attn1``) then see a valid ``video_shape`` when the +SDPA patch fires. Cross-attention (``attn2``) is skipped by VSA's +``can_apply_vsa`` guard since Q/K lengths differ. +""" + +import logging + +import torch.nn as nn + +from ..sparse_attention import SparseAttentionModule +from . import CUSTOM_MODEL_PLUGINS + +logger = logging.getLogger(__name__) + + +def _is_wan22_model(model: nn.Module) -> bool: + """Detect a Wan 2.2 transformer by class name. + + Wan 2.1 / 2.2 both use ``WanTransformer3DModel`` in diffusers — matching + by name keeps the plugin decoupled from the diffusers import. + """ + if type(model).__name__ == "WanTransformer3DModel": + return True + return any(type(m).__name__ == "WanTransformer3DModel" for m in model.modules()) + + +def _find_wan22_transformers(model: nn.Module) -> list[nn.Module]: + """Return every ``WanTransformer3DModel`` reachable from ``model``. + + The 14B model is a ``WanPipeline`` with ``transformer`` and + ``transformer_2``, so we return every match. + """ + if type(model).__name__ == "WanTransformer3DModel": + return [model] + return [m for m in model.modules() if type(m).__name__ == "WanTransformer3DModel"] + + +def _get_patch_size(transformer: nn.Module) -> tuple[int, int, int] | None: + """Read ``patch_size`` from the transformer's config.""" + config = getattr(transformer, "config", None) + if config is None: + return None + patch_size = getattr(config, "patch_size", None) + if patch_size is None: + return None + try: + p_t, p_h, p_w = patch_size + except (TypeError, ValueError): + return None + return (int(p_t), int(p_h), int(p_w)) + + +def _extract_hidden_states(args: tuple, kwargs: dict): + """Pick out the ``hidden_states`` argument regardless of call style.""" + if "hidden_states" in kwargs: + return kwargs["hidden_states"] + return args[0] if len(args) > 0 else None + + +def _make_wan22_video_shape_hook(transformer: nn.Module): + """Create the per-transformer forward pre-hook. + + Closes over the specific ``transformer`` so it can walk its own + submodules, independent of other Wan 2.2 transformers in the same + pipeline. + """ + patch_size = _get_patch_size(transformer) + if patch_size is None: + logger.debug("Wan 2.2 transformer has no config.patch_size; hook inert") + + def _noop(module, args, kwargs): + return None + + return _noop + + p_t, p_h, p_w = patch_size + + def _hook(module: nn.Module, args: tuple, kwargs: dict) -> None: + hidden_states = _extract_hidden_states(args, kwargs) + if hidden_states is None or hidden_states.ndim != 5: + return + + _, _, num_frames, height, width = hidden_states.shape + video_shape = (num_frames // p_t, height // p_h, width // p_w) + if any(d <= 0 for d in video_shape): + logger.debug( + f"Wan 2.2 VSA hook: invalid video_shape {video_shape} for " + f"input {(num_frames, height, width)} / patch {patch_size}; skipping" + ) + return + + # Also expose on the transformer for debugging / external inspection. 
+ module._vsa_video_shape = video_shape + + # Propagate to every VSA method instance in this transformer. + for sub in module.modules(): + if not isinstance(sub, SparseAttentionModule): + continue + method = getattr(sub, "_sparse_method_instance", None) + if method is None: + continue + if getattr(method, "name", None) != "vsa": + continue + method.set_video_shape(video_shape) + + return _hook + + +def register_wan22_vsa(model: nn.Module) -> int: + """Install a VSA ``video_shape`` pre-hook on every Wan 2.2 transformer. + + Idempotent: the hook is re-registered on each call because + ``plugins/__init__.py`` stores callbacks in a set — re-invoking after + ``mtsa.sparsify`` is safe, but we guard against double-registration by + tagging the transformer with ``_vsa_hook_registered``. + """ + transformers = _find_wan22_transformers(model) + if not transformers: + return 0 + + registered = 0 + for transformer in transformers: + if getattr(transformer, "_vsa_hook_registered", False): + continue + hook = _make_wan22_video_shape_hook(transformer) + transformer.register_forward_pre_hook(hook, with_kwargs=True) + transformer._vsa_hook_registered = True + registered += 1 + logger.info(f"Registered Wan 2.2 VSA video_shape hook on {type(transformer).__name__}") + + return registered + + +def register_wan22_on_the_fly(model: nn.Module) -> bool: + """Plugin entry point: install the Wan 2.2 VSA hook if applicable.""" + if not _is_wan22_model(model): + return False + num = register_wan22_vsa(model) + if num > 0: + logger.info(f"Installed VSA video_shape hook on {num} Wan 2.2 transformer(s)") + return True + return False + + +CUSTOM_MODEL_PLUGINS.add(register_wan22_on_the_fly)