
feat: add KV caching support for Wan models #400

Open
Perseus14 wants to merge 1 commit into main from wan_kv_cache

Conversation

Collaborator

@Perseus14 Perseus14 commented May 6, 2026

This Pull Request implements KV Cache optimization for all WAN models (WAN 2.1 & 2.2, both Text-to-Video and Image-to-Video). The optimization pre-computes the Key and Value projections for the text and image embeddings before the denoising loop, since these embeddings remain constant across all denoising steps.

Additionally, this PR introduces a hardware-aware Dynamic Image Alignment Padding optimization for next-generation TPUs.


Key Changes

1. KV Cache Optimization

  • Attention Level (FlaxWanAttention): Modified attention_flax.py to accept cached_kv. If present, key/value projections are bypassed. Added a robust compute_kv method to pre-project text (and image) states.
  • Transformer Level (WanModel & WanTransformerBlock): Updated transformer_wan.py to support KV cache propagation. Added compute_kv_cache to precompute block-level cached keys/values and integrated skip_embeddings inside WanTimeTextImageEmbedding to bypass redundant embedding layers when using cached states.
  • Pipeline Level (T2V & I2V Pipelines):
    • Updated forward pass helper signatures in wan_pipeline.py to accept and propagate KV caches.
    • Updated all denoising loops in wan_pipeline_2_1.py, wan_pipeline_2_2.py, wan_pipeline_i2v_2p1.py, and wan_pipeline_i2v_2p2.py to pre-compute the KV cache before starting the loop and reuse it at every step when use_kv_cache=True.
  • Config Defaults: Added use_kv_cache: False to all default .yml configuration files to ensure backward compatibility.
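The `cached_kv` routing described above can be sketched as follows. This is a minimal single-head sketch under simplifying assumptions, not the actual `FlaxWanAttention` implementation: the explicit weight matrices, the `compute_kv` signature, and the absence of head splitting and sharding are all illustrative.

```python
import jax
import jax.numpy as jnp


def compute_kv(encoder_hidden_states, w_k, w_v):
    # Project the constant text/image states once, before the denoising loop.
    return encoder_hidden_states @ w_k, encoder_hidden_states @ w_v


def cross_attention(hidden_states, encoder_hidden_states, w_q, w_k, w_v, cached_kv=None):
    query = hidden_states @ w_q
    if cached_kv is not None:
        key, value = cached_kv  # bypass the K/V projections entirely
    else:
        key, value = compute_kv(encoder_hidden_states, w_k, w_v)
    weights = jax.nn.softmax(query @ key.T / jnp.sqrt(query.shape[-1]), axis=-1)
    return weights @ value
```

Because the text/image embeddings do not change between denoising steps, the cached path is numerically identical to the uncached one; only the redundant projections are skipped.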

2. Dynamic Image Alignment Padding (Trillium & Ironwood Optimization)

  • Problem: Image embeddings were previously hardcoded to pad to multiples of 128. While 128-alignment is optimal for the MXU tile sizes of older generations (TPU v4, v5e, v5p), next-generation hardware such as Trillium (v6e) and Ironwood (v7x) uses larger $256 \times 256$ MXU tile structures.
  • Solution: Replaced hardcoded values with dynamic TPU hardware detection (get_tpu_type()). Both attention_flax.py and embeddings_flax.py (NNXWanImageEmbedding) now dynamically adjust image alignment padding:
    • 256-alignment on Trillium (v6e) and Ironwood (v7x) to perfectly match larger hardware tiles.
    • 128-alignment fallback on older TPU architectures (v5p and below).
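The selection rule above amounts to a small sketch like the one below; `alignment_for_tpu` is a hypothetical helper name, and the version strings stand in for the real `get_tpu_type()` / `TpuType` API.

```python
import math


def alignment_for_tpu(tpu_type: str) -> int:
    # v6e (Trillium) and v7x (Ironwood) have 256x256 MXU tiles; older
    # generations (v4, v5e, v5p) are best served by 128-alignment.
    return 256 if tpu_type in ("v6e", "v7x") else 128


def pad_seq_len(seq_len: int, alignment: int) -> int:
    # Round the image-embedding sequence length up to the nearest tile multiple.
    return math.ceil(seq_len / alignment) * alignment
```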

Detailed File Changes

Models

  • attention_flax.py:
    • Imported get_tpu_type and TpuType for dynamic hardware-aware image alignment padding.
    • Integrated cached_kv routing inside FlaxWanAttention.__call__.
    • Implemented compute_kv support for both T2V and I2V cross-attentions.
  • transformer_wan.py:
    • Added skip_embeddings parameter inside WanTimeTextImageEmbedding to bypass redundant text/image projections.
    • Updated WanTransformerBlock and WanModel to handle cached_kv / kv_cache passing.
    • Implemented WanModel.compute_kv_cache to precompute block-level cached keys/values across scan and non-scan layers.
  • embeddings_flax.py:
    • Updated NNXWanImageEmbedding to dynamically align to 256 for v6e/v7x and 128 otherwise, avoiding shape mismatches during cross-attention.

Pipelines

  • wan_pipeline.py:
    • Updated transformer_forward_pass, transformer_forward_pass_full_cfg, and transformer_forward_pass_cfg_cache to accept and pass kv_cache.
  • wan_pipeline_2_1.py & wan_pipeline_2_2.py:
    • Added use_kv_cache parameter to pipeline calls and pre-computed kv_cache and rotary_emb prior to the denoising loop.
  • wan_pipeline_i2v_2p1.py & wan_pipeline_i2v_2p2.py:
    • Fixed RoPE dummy shape bug. Integrated dynamic pre-computed kv_cache support for I2V workflows.
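The loop restructuring in these pipelines can be illustrated with a toy model that counts K/V projections. `TinyCrossAttention` and `denoise` are hypothetical stand-ins, not the pipeline code; they only show the hoisting pattern.

```python
class TinyCrossAttention:
    """Toy stand-in that counts K/V projections so the caching effect is visible."""

    def __init__(self):
        self.kv_projections = 0

    def compute_kv_cache(self, text_embeds):
        self.kv_projections += 1  # in the real model: two matmuls per block
        return ("k", text_embeds), ("v", text_embeds)

    def __call__(self, latents, t, text_embeds, kv_cache=None):
        kv = kv_cache if kv_cache is not None else self.compute_kv_cache(text_embeds)
        del kv  # attention math elided
        return latents


def denoise(model, latents, text_embeds, timesteps, use_kv_cache=False):
    # The PR's change: hoist the K/V projection out of the loop when enabled.
    kv_cache = model.compute_kv_cache(text_embeds) if use_kv_cache else None
    for t in timesteps:
        latents = model(latents, t, text_embeds, kv_cache=kv_cache)
    return latents
```

With `use_kv_cache=True` the projection runs once per generation instead of once per denoising step.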

Configs

  • Configs (base_wan_1_3b.yml, base_wan_14b.yml, base_wan_27b.yml, base_wan_i2v_14b.yml, base_wan_i2v_27b.yml):
    • Added use_kv_cache: False for all default configs.
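For reference, the new key looks like this in the base configs (excerpt reconstructed from the description above; the comment is mine):

```yml
# e.g. base_wan_14b.yml
use_kv_cache: False  # opt-in; False preserves pre-PR behavior
```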

Performance Note

  • Observed Latency Savings:
    • ~0.5s on TPU v7x-8 (Ironwood)
    • ~0.7s on TPU v6e-8 (Trillium)
  • Analysis: The latency savings over a full denoising run are minimal, and this is mathematically expected: the cross-attention Key/Value projections operate on a very short text-prompt sequence (typically 512 tokens). The FLOPs saved by caching these projections are a negligible fraction ($< 0.01\%$) of the total workload, which is dominated by the self-attention and FFN layers processing the far longer latent sequence at every step of the denoising loop.
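A back-of-envelope check of that fraction, using assumed Wan-14B-scale dimensions (all numbers below are illustrative, not taken from the PR; the exact share depends on the latent sequence length and model width, but it lands within a couple of orders of magnitude of the estimate above either way).

```python
d = 5120          # hidden size (assumed)
d_ff = 4 * d      # FFN inner size (assumed)
s_text = 512      # text prompt tokens (the K/V that caching removes)
s_lat = 75_600    # latent video tokens processed at every step (assumed)

# Per transformer layer, per denoising step; each multiply-accumulate = 2 FLOPs.
kv_proj = 2 * 2 * s_text * d * d                        # the work caching saves
self_attn = 2 * 4 * s_lat * d * d + 2 * 2 * s_lat**2 * d  # Q/K/V/O + QK^T, AV
ffn = 2 * 2 * s_lat * d * d_ff                          # up- and down-projection

fraction = kv_proj / (self_attn + ffn)
print(f"cached K/V share of per-step FLOPs: {fraction:.4%}")
```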

@Perseus14 Perseus14 requested a review from entrpn as a code owner May 6, 2026 10:21

github-actions Bot commented May 6, 2026

@Perseus14 Perseus14 requested review from mbohlool and prishajain1 May 6, 2026 10:30
@Perseus14 Perseus14 self-assigned this May 7, 2026

github-actions Bot commented May 7, 2026

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


github-actions Bot commented May 7, 2026

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.


github-actions Bot commented May 7, 2026

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request introduces significant performance optimizations for WAN models, specifically KV caching and hardware-aware dynamic image alignment padding for next-gen TPUs. The implementation of pre-computed RoPE and KV caches outside the denoising loop is an excellent architectural improvement.

🔍 General Feedback

  • KV Cache Coverage: While the KV cache and RoPE pre-computation are implemented for all models, they are currently missing from the scan_diffusion_loop paths in the T2V 2.1, T2V 2.2, and I2V 2.1 pipelines. This prevents users of the performance-oriented scan mode from benefiting from these optimizations. I have provided suggestions to propagate these caches into the scan bodies.
  • Hardware-Aware Padding: The dynamic adjustment of image alignment padding (256 vs 128) based on TPU type is correctly implemented across the attention and embedding layers.
  • Code Organization: Moving the concatenation of embeds and RoPE computation outside the denoising loop significantly reduces redundant computation at every step.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request implements important performance optimizations for WAN models, including KV caching and hardware-aware dynamic image alignment padding for TPU v6e and v7x. The pre-computation of RoPE and KV caches is a significant improvement for inference efficiency.

🔍 General Feedback

  • Missing Scan Loop Updates: While the optimizations are correctly implemented for the standard and CFG-cache denoising loops, the scan_diffusion_loop paths in the T2V 2.1, T2V 2.2, and I2V 2.1 pipelines have not been updated to use the new pre-computed KV caches and RoPE. This means that users utilizing the scan mode will not see the expected performance gains and will still incur the cost of redundant RoPE computations at each step.
  • Hardware-Awareness: The dynamic padding logic (256 for Trillium/Ironwood, 128 otherwise) is well-integrated into the attention and embedding modules.
  • Consistency: I2V 2.2 was correctly updated to handle these changes in its scan loop, but the other pipelines were missed. Synchronizing these would ensure a consistent performance profile across all WAN model variants.

Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

github-actions Bot commented May 7, 2026

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request successfully implements KV Cache optimization for all WAN models (2.1 & 2.2, T2V & I2V), providing significant inference speedups on TPU. It also introduces dynamic image alignment padding, optimizing performance for Trillium (v6e) and Ironwood (v7x) accelerators. The overall implementation is solid and well-integrated into the existing pipeline and model structures.

🔍 General Feedback

  • KV Cache Integration: The propagation of the KV cache from the pipeline through the transformer and attention layers is handled correctly, including support for complex CFG and caching strategies.
  • Trillium Optimizations: Dynamically adjusting padding to 256 for newer TPU generations is a great performance win.
  • Critical Path Bug: Identified a NameError in the WanModel's TI2V path (per_token_t=True) that needs immediate attention.
  • Consistency: Minor inconsistencies in RoPE dummy shape calculations across pipelines were noted but do not affect correctness.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This second part of the review adds specific inline comments and suggestions for the core model and transformer changes.

🔍 General Feedback

  • TI2V Path Fix: Corrected a NameError in WanModel.__call__ when per_token_t is used.
  • Attention Optimizations: Suggested improvements for TPU alignment detection and sequence length handling in FlaxWanAttention.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

Detailed inline comments and suggestions for model improvements and bug fixes.

🔍 General Feedback

  • TI2V Path Fix: Corrected a NameError in WanModel.__call__ when per_token_t is used.
  • Attention Optimizations: Suggested improvements for TPU alignment detection and sequence length handling in FlaxWanAttention.

Comment thread src/maxdiffusion/models/attention_flax.py Outdated
Comment thread src/maxdiffusion/models/attention_flax.py

@github-actions github-actions Bot left a comment


## 📋 Review Summary

Final part of the review with the critical fix for the TI2V path.

🔍 General Feedback

  • TI2V Path Fix: Corrected a NameError in WanModel.__call__ when per_token_t is used.

@Perseus14 Perseus14 force-pushed the wan_kv_cache branch 3 times, most recently from ec2f6bb to 8467dbb on May 7, 2026 19:02

github-actions Bot commented May 7, 2026

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This PR implements KV caching for WAN models, providing a significant optimization for inference performance. It also introduces dynamic, hardware-aware alignment padding to ensure optimal execution on next-generation TPU architectures. The changes are largely consistent across the main model variants and pipelines, but there are some omissions in the Animate model and inconsistencies in parameter propagation that should be addressed.

🔍 General Feedback

  • Optimization: The pre-computation of RoPE and KV cache is a welcome improvement that aligns with JAX/Flax best practices for efficient inference.
  • Hardware-Awareness: The use of dynamic alignment (128 vs 256) based on TPU type is a great addition for performance portability across Trillium and earlier generations.
  • Omission: WanAnimateTransformer3DModel appears to have been missed in this update, despite the PR's intent to cover all WAN models.
  • Consistency: Ensure all WanTransformerBlock instantiations (especially in non-scan paths) propagate the new parameters (use_base2_exp, etc.) to avoid behavioral discrepancies.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This PR implements KV caching for WAN models, providing a significant optimization for inference performance. It also introduces dynamic, hardware-aware alignment padding to ensure optimal execution on next-generation TPU architectures (Trillium/Ironwood). While the core implementation is solid, there are omissions in the Animate model and inconsistencies in parameter propagation across non-scan code paths.

🔍 General Feedback

  • Optimization: The pre-computation of RoPE and KV cache is a significant performance improvement for long sequence generation.
  • Hardware-Awareness: The dynamic alignment logic based on TPU type is excellent for ensuring optimal MXU utilization across different hardware generations.
  • Omission: WanAnimateTransformer3DModel needs to be updated with KV cache support to fulfill the goal of supporting "all WAN models".
  • Consistency: Ensure all WanTransformerBlock instantiations propagate use_base2_exp and use_experimental_scheduler to maintain behavioral parity between scan and non-scan paths.

Comment thread src/maxdiffusion/models/wan/transformers/transformer_wan.py

github-actions Bot commented May 7, 2026

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.


github-actions Bot commented May 8, 2026

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request successfully implements KV Cache optimization for the WAN 2.1 and 2.2 model families, significantly improving inference efficiency for both Text-to-Video and Image-to-Video tasks. The implementation includes dynamic TPU alignment for newer hardware (Trillium/Ironwood) and correctly integrates with existing optimizations like CFG cache and SenCache.

🔍 General Feedback

  • Critical Logic Bug in transformer_wan.py: In WanModel.__call__, the concatenation of image and text embeddings is skipped when kv_cache is provided. However, the FlaxWanAttention layer still expects a concatenated sequence (or at least the correct sequence length) for its internal slicing logic in I2V models. This will lead to incorrect slicing where text tokens are treated as image tokens, and potentially out-of-bounds errors for shorter prompts. Please ensure that encoder_hidden_states is always correctly concatenated in I2V models, regardless of whether kv_cache is used.
  • Consistency: In generate_wan.py, consider using config.use_kv_cache directly instead of getattr for better consistency with neighboring lines.
  • TPU Optimization: The dynamic alignment logic (128 vs 256) is well-implemented and will ensure optimal memory access patterns across different TPU generations.
  • RoPE Handling: Pre-computing rotary_emb once per batch is a good efficiency gain, and the implementation correctly handles the different latent shapes between T2V and I2V pipelines.

guidance_scale_high=config.guidance_scale_high,
use_cfg_cache=config.use_cfg_cache,
use_sen_cache=config.use_sen_cache,
use_kv_cache=getattr(config, "use_kv_cache", False),


🟢 Use `config.use_kv_cache` directly for consistency with `use_cfg_cache` and `use_sen_cache`, as this property has been added to the base Wan configurations.
Suggested change
use_kv_cache=getattr(config, "use_kv_cache", False),
use_kv_cache=config.use_kv_cache,

@@ -330,9 +351,25 @@ def run_inference_2_1_i2v(
image_embeds_combined = image_embeds
condition_combined = condition


🟢 KV cache pre-computation is correctly placed outside the denoising loop.

attention_kernel = "tokamax_flash" # do not use ring attention for cross attention
self.added_kv_proj_dim = added_kv_proj_dim # New for I2V
self.image_seq_len = image_seq_len # New for I2V
tpu_type = get_tpu_type()


🟢 Dynamic alignment based on TPU type is a great addition for performance on Trillium (v6e) and newer generations.

@prishajain1
Collaborator

Has this branch been tested with WAN I2V pipelines to ensure no breakage or performance regressions? (asking as we are making changes to the padding alignment)

@prishajain1
Collaborator

Can we add some unittests with use_kv_cache=True for these pipelines?
