Conversation
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request introduces significant performance optimizations for WAN models, specifically KV caching and hardware-aware dynamic image alignment padding for next-gen TPUs. The implementation of pre-computed RoPE and KV caches outside the denoising loop is an excellent architectural improvement.
🔍 General Feedback
- KV Cache Coverage: While the KV cache and RoPE pre-computation are implemented for all models, they are currently missing from the `scan_diffusion_loop` paths in the T2V 2.1, T2V 2.2, and I2V 2.1 pipelines. This prevents users of the performance-oriented scan mode from benefiting from these optimizations. I have provided suggestions to propagate these caches into the scan bodies.
- Hardware-Aware Padding: The dynamic adjustment of image alignment padding (256 vs 128) based on TPU type is correctly implemented across the attention and embedding layers.
- Code Organization: Moving the concatenation of embeds and RoPE computation outside the denoising loop significantly reduces redundant computation at every step.
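The hoisting the review describes can be illustrated with a small framework-free sketch. The name `compute_kv_cache` mirrors the PR, but the shapes, weights, and wiring below are purely illustrative, not the actual maxdiffusion API:

```python
import numpy as np

rng = np.random.default_rng(0)

def make_block(dim):
    # Per-block cross-attention projection weights (toy stand-ins).
    return {name: rng.standard_normal((dim, dim)) for name in ("w_q", "w_k", "w_v")}

def compute_kv_cache(encoder_states, blocks):
    # Project the constant text/image embeddings once per block.
    return [(encoder_states @ b["w_k"], encoder_states @ b["w_v"]) for b in blocks]

def cross_attention(x, cached_kv, block):
    k, v = cached_kv  # cached projections: no K/V matmuls inside the loop
    scores = (x @ block["w_q"]) @ k.T
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return x + weights @ v

def denoise(latents, encoder_states, blocks, num_steps):
    # Hoisted: computed once, reused at every denoising step.
    kv_cache = compute_kv_cache(encoder_states, blocks)
    for _ in range(num_steps):
        for block, cached in zip(blocks, kv_cache):
            latents = cross_attention(latents, cached, block)
    return latents

dim = 8
blocks = [make_block(dim) for _ in range(2)]
out = denoise(rng.standard_normal((4, dim)), rng.standard_normal((6, dim)), blocks, num_steps=3)
print(out.shape)
```

The same reuse pattern applies inside a `jax.lax.scan` body: the precomputed caches are simply closed over (or threaded through the carry), which is what the suggestions for the scan paths amount to.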
This Pull Request implements important performance optimizations for WAN models, including KV caching and hardware-aware dynamic image alignment padding for TPU v6e and v7x. The pre-computation of RoPE and KV caches is a significant improvement for inference efficiency.
🔍 General Feedback
- Missing Scan Loop Updates: While the optimizations are correctly implemented for the standard and CFG-cache denoising loops, the `scan_diffusion_loop` paths in the T2V 2.1, T2V 2.2, and I2V 2.1 pipelines have not been updated to use the new pre-computed KV caches and RoPE. This means that users utilizing the scan mode will not see the expected performance gains and will still incur the cost of redundant RoPE computations at each step.
- Hardware-Awareness: The dynamic padding logic (256 for Trillium/Ironwood, 128 otherwise) is well-integrated into the attention and embedding modules.
- Consistency: I2V 2.2 was correctly updated to handle these changes in its scan loop, but the other pipelines were missed. Synchronizing these would ensure a consistent performance profile across all WAN model variants.
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request successfully implements KV Cache optimization for all WAN models (2.1 & 2.2, T2V & I2V), providing significant inference speedups on TPU. It also introduces dynamic image alignment padding, optimizing performance for Trillium (v6e) and Ironwood (v7x) accelerators. The overall implementation is solid and well-integrated into the existing pipeline and model structures.
🔍 General Feedback
- KV Cache Integration: The propagation of the KV cache from the pipeline through the transformer and attention layers is handled correctly, including support for complex CFG and caching strategies.
- Trillium Optimizations: Dynamically adjusting padding to 256 for newer TPU generations is a great performance win.
- Critical Path Bug: Identified a `NameError` in the `WanModel`'s TI2V path (`per_token_t=True`) that needs immediate attention.
- Consistency: Minor inconsistencies in RoPE dummy shape calculations across pipelines were noted but do not affect correctness.
This second part of the review adds specific inline comments and suggestions for the core model and transformer changes.
🔍 General Feedback
- TI2V Path Fix: Corrected a `NameError` in `WanModel.__call__` when `per_token_t` is used.
- Attention Optimizations: Suggested improvements for TPU alignment detection and sequence length handling in `FlaxWanAttention`.
Detailed inline comments and suggestions for model improvements and bug fixes.
🔍 General Feedback
- TI2V Path Fix: Corrected a `NameError` in `WanModel.__call__` when `per_token_t` is used.
- Attention Optimizations: Suggested improvements for TPU alignment detection and sequence length handling in `FlaxWanAttention`.
Force-pushed from `ec2f6bb` to `8467dbb`.
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This PR implements KV caching for WAN models, providing a significant optimization for inference performance. It also introduces dynamic, hardware-aware alignment padding to ensure optimal execution on next-generation TPU architectures. The changes are largely consistent across the main model variants and pipelines, but there are some omissions in the Animate model and inconsistencies in parameter propagation that should be addressed.
🔍 General Feedback
- Optimization: The pre-computation of RoPE and KV cache is a welcome improvement that aligns with JAX/Flax best practices for efficient inference.
- Hardware-Awareness: The use of dynamic alignment (128 vs 256) based on TPU type is a great addition for performance portability across Trillium and earlier generations.
- Omission: `WanAnimateTransformer3DModel` appears to have been missed in this update, despite the PR's intent to cover all WAN models.
- Consistency: Ensure all `WanTransformerBlock` instantiations (especially in non-scan paths) propagate the new parameters (`use_base2_exp`, etc.) to avoid behavioral discrepancies.
This PR implements KV caching for WAN models, providing a significant optimization for inference performance. It also introduces dynamic, hardware-aware alignment padding to ensure optimal execution on next-generation TPU architectures (Trillium/Ironwood). While the core implementation is solid, there are omissions in the Animate model and inconsistencies in parameter propagation across non-scan code paths.
🔍 General Feedback
- Optimization: The pre-computation of RoPE and KV cache is a significant performance improvement for long sequence generation.
- Hardware-Awareness: The dynamic alignment logic based on TPU type is excellent for ensuring optimal MXU utilization across different hardware generations.
- Omission: `WanAnimateTransformer3DModel` needs to be updated with KV cache support to fulfill the goal of supporting "all WAN models".
- Consistency: Ensure all `WanTransformerBlock` instantiations propagate `use_base2_exp` and `use_experimental_scheduler` to maintain behavioral parity between scan and non-scan paths.
🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request successfully implements KV Cache optimization for the WAN 2.1 and 2.2 model families, significantly improving inference efficiency for both Text-to-Video and Image-to-Video tasks. The implementation includes dynamic TPU alignment for newer hardware (Trillium/Ironwood) and correctly integrates with existing optimizations like CFG cache and SenCache.
🔍 General Feedback
- Wait! Critical Logic Bug in `transformer_wan.py`: I noticed that in `WanModel.__call__`, the concatenation of image and text embeddings is skipped when `kv_cache` is provided. However, the `FlaxWanAttention` layer still expects a concatenated sequence (or at least the correct sequence length) for its internal slicing logic in I2V models. This will lead to incorrect slicing where text tokens are treated as image tokens, and potentially out-of-bounds errors for shorter prompts. Please ensure that `encoder_hidden_states` is always correctly concatenated in I2V models, regardless of whether `kv_cache` is used.
- Consistency: In `generate_wan.py`, consider using `config.use_kv_cache` directly instead of `getattr` for better consistency with neighboring lines.
- TPU Optimization: The dynamic `alignment` logic (128 vs 256) is well-implemented and will ensure optimal memory access patterns across different TPU generations.
- RoPE Handling: Pre-computing `rotary_emb` once per batch is a good efficiency gain, and the implementation correctly handles the different latent shapes between T2V and I2V pipelines.
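The slicing hazard flagged above can be demonstrated with toy shapes. The marker values stand in for real embeddings, and `image_seq_len` is an arbitrary illustrative value, not taken from the actual models:

```python
import numpy as np

image_seq_len, text_len, dim = 257, 512, 4
image_embeds = np.zeros((image_seq_len, dim))  # marker 0.0 = image token
text_embeds = np.ones((text_len, dim))         # marker 1.0 = text token

# Correct path: image and text embeds are always concatenated,
# so slicing the first image_seq_len rows yields image tokens.
correct = np.concatenate([image_embeds, text_embeds])
assert (correct[:image_seq_len] == 0.0).all()

# Buggy path: concat skipped when kv_cache is set; the same slice
# now grabs text tokens, and the total sequence is shorter than expected.
broken = text_embeds
assert (broken[:image_seq_len] == 1.0).all()          # text misread as image
assert broken.shape[0] < image_seq_len + text_len     # length mismatch
```

With a shorter prompt (`text_len < image_seq_len`), the same slice would also read out of bounds, which is the second failure mode the comment mentions.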
```python
guidance_scale_high=config.guidance_scale_high,
use_cfg_cache=config.use_cfg_cache,
use_sen_cache=config.use_sen_cache,
use_kv_cache=getattr(config, "use_kv_cache", False),
```

Suggested change:

```python
use_kv_cache=config.use_kv_cache,
```
```diff
@@ -330,9 +351,25 @@ def run_inference_2_1_i2v(
 image_embeds_combined = image_embeds
 condition_combined = condition
```
```python
attention_kernel = "tokamax_flash"  # do not use ring attention for cross attention
self.added_kv_proj_dim = added_kv_proj_dim  # New for I2V
self.image_seq_len = image_seq_len  # New for I2V
tpu_type = get_tpu_type()
```

Has this branch been tested with WAN I2V pipelines to ensure no breakage or performance regressions? (asking as we are making changes to the padding alignment)
Can we add some unittests with `use_kv_cache=True` for these pipelines?
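A minimal shape for such a test, on a toy cross-attention rather than the real pipelines, is an equivalence check between the cached and uncached paths. All names and shapes below are illustrative:

```python
import numpy as np

rng = np.random.default_rng(42)
dim = 8
w_q, w_k, w_v = (rng.standard_normal((dim, dim)) for _ in range(3))

def cross_attn(x, enc, cached_kv=None):
    # With a cache, skip the K/V projections of the constant encoder states.
    k, v = cached_kv if cached_kv is not None else (enc @ w_k, enc @ w_v)
    scores = (x @ w_q) @ k.T
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def test_kv_cache_matches_uncached():
    x = rng.standard_normal((4, dim))
    enc = rng.standard_normal((6, dim))
    cache = (enc @ w_k, enc @ w_v)  # pre-computed once, outside the loop
    np.testing.assert_allclose(cross_attn(x, enc), cross_attn(x, enc, cached_kv=cache))

test_kv_cache_matches_uncached()
print("ok")
```

A pipeline-level version of this test would run each WAN pipeline with `use_kv_cache=True` and `use_kv_cache=False` on identical seeds and assert the latents match within tolerance.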
This Pull Request implements the KV Cache optimization for all WAN models (WAN 2.1 & 2.2, both Text-to-Video and Image-to-Video). This optimization pre-computes the Key and Value projections for the text and image embeddings before the denoising loop, since they remain constant throughout.
Additionally, this PR introduces a hardware-aware Dynamic Image Alignment Padding optimization for next-generation TPUs.
Key Changes
1. KV Cache Optimization

- Attention (`FlaxWanAttention`): Modified `attention_flax.py` to accept `cached_kv`. If present, key/value projections are bypassed. Added a robust `compute_kv` method to pre-project text (and image) states.
- Transformer (`WanModel` & `WanTransformerBlock`): Updated `transformer_wan.py` to support KV cache propagation. Added `compute_kv_cache` to precompute block-level cached keys/values and integrated `skip_embeddings` inside `WanTimeTextImageEmbedding` to bypass redundant embedding layers when using cached states.
- Pipelines: Updated `wan_pipeline.py` to accept and propagate KV caches, and updated `wan_pipeline_2_1.py`, `wan_pipeline_2_2.py`, `wan_pipeline_i2v_2p1.py`, and `wan_pipeline_i2v_2p2.py` to pre-compute the KV cache before starting the loop and reuse it at every step when `use_kv_cache=True`.
- Configs: Added `use_kv_cache: False` to all default `.yml` configuration files to ensure backward compatibility.

2. Dynamic Image Alignment Padding (Trillium & Ironwood Optimization)
Previously, image alignment padding was fixed at `128`. While optimal for older MXU tile sizes (TPU v4, v5p/v5e), next-generation hardware like Trillium (v6e) and Ironwood (v7x) utilize larger tiles. The hardware is now detected at runtime (via `get_tpu_type()`), and both `attention_flax.py` and `embeddings_flax.py` (`NNXWanImageEmbedding`) dynamically adjust image alignment padding: `256` on Trillium (`v6e`) and Ironwood (`v7x`) to perfectly match the larger hardware tiles, and `128` on earlier generations (`v5p` and below).
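The selection rule can be sketched as follows. The PR's real detection goes through `get_tpu_type()`/`TpuType`; the string-keyed standalone version below is only illustrative:

```python
def image_alignment(tpu_type: str) -> int:
    # Trillium (v6e) and Ironwood (v7x) have larger MXU tiles -> pad to 256;
    # earlier generations (v5p and below) keep the 128 alignment.
    return 256 if tpu_type in ("v6e", "v7x") else 128

def padded_len(seq_len: int, tpu_type: str) -> int:
    # Round the image sequence length up to the next multiple of the alignment.
    a = image_alignment(tpu_type)
    return -(-seq_len // a) * a

print(padded_len(300, "v6e"))  # 512
print(padded_len(300, "v5p"))  # 384
```

Note that the attention and embedding modules must agree on the alignment, otherwise the padded image sequence lengths diverge and cross-attention shapes mismatch, which is exactly the failure the `NNXWanImageEmbedding` change guards against.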
Models
- `attention_flax.py`:
  - Added `get_tpu_type` and `TpuType` for dynamic hardware-aware image alignment padding.
  - Implemented `cached_kv` routing inside `FlaxWanAttention.__call__`.
  - Added `compute_kv` support for both T2V and I2V cross-attentions.
- `transformer_wan.py`:
  - Added a `skip_embeddings` parameter inside `WanTimeTextImageEmbedding` to bypass redundant text/image projections.
  - Updated `WanTransformerBlock` and `WanModel` to handle `cached_kv`/`kv_cache` passing.
  - Added `WanModel.compute_kv_cache` to precompute block-level cached keys/values across scan and non-scan layers.
- `embeddings_flax.py`:
  - Updated `NNXWanImageEmbedding` to dynamically align to `256` for `v6e`/`v7x` and `128` otherwise, avoiding shape mismatches during cross-attention.
- `wan_pipeline.py`: Updated `transformer_forward_pass`, `transformer_forward_pass_full_cfg`, and `transformer_forward_pass_cfg_cache` to accept and pass `kv_cache`.
- `wan_pipeline_2_1.py` & `wan_pipeline_2_2.py`: Added a `use_kv_cache` parameter to pipeline calls and pre-computed `kv_cache` and `rotary_emb` prior to the denoising loop.
- `wan_pipeline_i2v_2p1.py` & `wan_pipeline_i2v_2p2.py`: Added `kv_cache` support for I2V workflows.
- `base_wan_1_3b.yml`, `base_wan_14b.yml`, `base_wan_27b.yml`, `base_wan_i2v_14b.yml`, `base_wan_i2v_27b.yml`: Added `use_kv_cache: False` for all default configs.

Performance Note