From 7e002481ce4c3147415ba749f42ad61a62d7f545 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Mon, 20 Apr 2026 21:15:16 +0000 Subject: [PATCH 1/9] Reorg the kernel dir Signed-off-by: Jingyu Xin --- examples/deepseek/ptq.py | 2 +- examples/deepseek/quantize_to_nvfp4.py | 2 +- modelopt/torch/kernels/__init__.py | 38 +------------- modelopt/torch/kernels/common/__init__.py | 50 +++++++++++++++++++ .../{ => common}/hf_triton_attention.py | 2 +- .../torch/kernels/{ => common}/triton_fa.py | 0 .../torch/kernels/quantization/__init__.py | 16 ++++++ .../quantization}/conv/README.md | 6 +-- .../kernels/quantization/conv/__init__.py | 16 ++++++ .../quantization}/conv/bench_implicit_gemm.py | 4 +- .../conv/implicit_gemm_binding.cpp | 0 .../quantization}/conv/implicit_gemm_cuda.py | 0 .../conv/implicit_gemm_kernel.cu | 0 .../quantization/gemm}/__init__.py | 0 .../quantization/gemm}/fp4_kernel.py | 0 .../quantization/gemm}/fp4_kernel_hopper.py | 0 .../quantization/gemm}/fp8_kernel.py | 0 .../quantization/gemm}/tensor_quant.cpp | 0 .../quantization/gemm}/tensor_quant.h | 0 .../quantization/gemm}/tensor_quant_gpu.cu | 0 .../gemm}/tensor_quant_gpu_fp8.cu | 0 .../quantization/gemm}/tensor_quant_mx.cu | 0 .../quantization/gemm}/tensor_quant_mx.h | 0 modelopt/torch/kernels/sparsity/__init__.py | 16 ++++++ .../sparsity/attention}/__init__.py | 2 +- .../attention}/diffusers_triton_attention.py | 2 +- .../attention}/ltx_triton_attention.py | 2 +- .../torch/kernels/sparsity/gemm/__init__.py | 16 ++++++ modelopt/torch/quantization/extensions.py | 7 +-- .../quantization/nn/modules/quant_conv.py | 2 +- .../torch/quantization/plugins/huggingface.py | 4 +- .../quantization/qtensor/nvfp4_tensor.py | 2 +- modelopt/torch/quantization/tensor_quant.py | 2 +- .../sparsity/attention_sparsity/conversion.py | 8 +-- .../methods/flash_skip_softmax.py | 2 +- .../methods/triton_skip_softmax.py | 28 ++++++++--- pyproject.toml | 4 +- .../kernels/test_implicit_gemm.py | 16 +++--- .../quantization/test_tensor_quant_cuda.py | 2 +- .../test_diffusers_triton_attention.py | 8 ++- .../attention_sparsity/test_triton_fa.py | 4 +- .../test_triton_fa_calibrate.py | 4 +- .../test_triton_fa_skip_softmax.py | 4 +- .../test_triton_fa_sparse_nm.py | 6 +-- .../test_wan22_skip_softmax.py | 2 +- tests/unit/torch/kernels/test_triton_fa.py | 2 +- .../test_kernel_backends.py | 18 +++---- .../test_ltx_triton_attention.py | 4 +- .../test_sparse_attention_conversion.py | 6 +-- 49 files changed, 201 insertions(+), 108 deletions(-) create mode 100644 modelopt/torch/kernels/common/__init__.py rename modelopt/torch/kernels/{ => common}/hf_triton_attention.py (99%) rename modelopt/torch/kernels/{ => common}/triton_fa.py (100%) create mode 100644 modelopt/torch/kernels/quantization/__init__.py rename modelopt/torch/{quantization/src => kernels/quantization}/conv/README.md (95%) create mode 100644 modelopt/torch/kernels/quantization/conv/__init__.py rename modelopt/torch/{quantization/src => kernels/quantization}/conv/bench_implicit_gemm.py (98%) rename modelopt/torch/{quantization/src => kernels/quantization}/conv/implicit_gemm_binding.cpp (100%) rename modelopt/torch/{quantization/src => kernels/quantization}/conv/implicit_gemm_cuda.py (100%) rename modelopt/torch/{quantization/src => kernels/quantization}/conv/implicit_gemm_kernel.cu (100%) rename modelopt/torch/{quantization/triton => kernels/quantization/gemm}/__init__.py (100%) rename modelopt/torch/{quantization/triton => kernels/quantization/gemm}/fp4_kernel.py (100%) rename 
modelopt/torch/{quantization/triton => kernels/quantization/gemm}/fp4_kernel_hopper.py (100%) rename modelopt/torch/{quantization/triton => kernels/quantization/gemm}/fp8_kernel.py (100%) rename modelopt/torch/{quantization/src => kernels/quantization/gemm}/tensor_quant.cpp (100%) rename modelopt/torch/{quantization/src => kernels/quantization/gemm}/tensor_quant.h (100%) rename modelopt/torch/{quantization/src => kernels/quantization/gemm}/tensor_quant_gpu.cu (100%) rename modelopt/torch/{quantization/src => kernels/quantization/gemm}/tensor_quant_gpu_fp8.cu (100%) rename modelopt/torch/{quantization/src => kernels/quantization/gemm}/tensor_quant_mx.cu (100%) rename modelopt/torch/{quantization/src => kernels/quantization/gemm}/tensor_quant_mx.h (100%) create mode 100644 modelopt/torch/kernels/sparsity/__init__.py rename modelopt/torch/{sparsity/attention_sparsity/kernels => kernels/sparsity/attention}/__init__.py (96%) rename modelopt/torch/{sparsity/attention_sparsity/kernels => kernels/sparsity/attention}/diffusers_triton_attention.py (99%) rename modelopt/torch/{sparsity/attention_sparsity/kernels => kernels/sparsity/attention}/ltx_triton_attention.py (99%) create mode 100644 modelopt/torch/kernels/sparsity/gemm/__init__.py diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py index 30574b3eee..d60d011ed0 100644 --- a/examples/deepseek/ptq.py +++ b/examples/deepseek/ptq.py @@ -55,8 +55,8 @@ import modelopt.torch.quantization as mtq from modelopt.torch.export.model_config import KV_CACHE_FP8 from modelopt.torch.export.quant_utils import get_quant_config +from modelopt.torch.kernels.quantization.gemm import weight_dequant from modelopt.torch.quantization.nn import TensorQuantizer -from modelopt.torch.quantization.triton import weight_dequant from modelopt.torch.quantization.utils import ( is_quantized_column_parallel_linear, is_quantized_parallel_linear, diff --git a/examples/deepseek/quantize_to_nvfp4.py b/examples/deepseek/quantize_to_nvfp4.py index db6d5f6a24..e54fdbebf4 100644 --- a/examples/deepseek/quantize_to_nvfp4.py +++ b/examples/deepseek/quantize_to_nvfp4.py @@ -47,8 +47,8 @@ from safetensors.torch import load_file, save_file from tqdm import tqdm +from modelopt.torch.kernels.quantization.gemm import weight_dequant from modelopt.torch.quantization.qtensor import NVFP4QTensor -from modelopt.torch.quantization.triton import weight_dequant def _remap_key(key_dict: dict[str, Any]): diff --git a/modelopt/torch/kernels/__init__.py b/modelopt/torch/kernels/__init__.py index fa07b06e20..151b6e21d9 100644 --- a/modelopt/torch/kernels/__init__.py +++ b/modelopt/torch/kernels/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,38 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Shared Triton kernels for modelopt (attention, quantization, etc.).""" - -import torch - -from modelopt.torch.utils import import_plugin - -IS_AVAILABLE = False -attention = None -attention_calibrate = None -register_triton_attention = None - -if torch.cuda.is_available(): - with import_plugin( - "triton", - msg_if_missing=( - "Your device is potentially capable of using the triton attention " - "kernel. 
Try to install triton with `pip install triton`." - ), - ): - from .triton_fa import attention as _attention - from .triton_fa import attention_calibrate as _attention_calibrate - - attention = _attention - attention_calibrate = _attention_calibrate - IS_AVAILABLE = True - from .hf_triton_attention import register_triton_attention as _register_triton_attention - - register_triton_attention = _register_triton_attention - -__all__ = [ - "IS_AVAILABLE", - "attention", - "attention_calibrate", - "register_triton_attention", -] +"""ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm).""" diff --git a/modelopt/torch/kernels/common/__init__.py b/modelopt/torch/kernels/common/__init__.py new file mode 100644 index 0000000000..fa07b06e20 --- /dev/null +++ b/modelopt/torch/kernels/common/__init__.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared Triton kernels for modelopt (attention, quantization, etc.).""" + +import torch + +from modelopt.torch.utils import import_plugin + +IS_AVAILABLE = False +attention = None +attention_calibrate = None +register_triton_attention = None + +if torch.cuda.is_available(): + with import_plugin( + "triton", + msg_if_missing=( + "Your device is potentially capable of using the triton attention " + "kernel. Try to install triton with `pip install triton`." 
+ ), + ): + from .triton_fa import attention as _attention + from .triton_fa import attention_calibrate as _attention_calibrate + + attention = _attention + attention_calibrate = _attention_calibrate + IS_AVAILABLE = True + from .hf_triton_attention import register_triton_attention as _register_triton_attention + + register_triton_attention = _register_triton_attention + +__all__ = [ + "IS_AVAILABLE", + "attention", + "attention_calibrate", + "register_triton_attention", +] diff --git a/modelopt/torch/kernels/hf_triton_attention.py b/modelopt/torch/kernels/common/hf_triton_attention.py similarity index 99% rename from modelopt/torch/kernels/hf_triton_attention.py rename to modelopt/torch/kernels/common/hf_triton_attention.py index 5021d34e37..d73f281129 100644 --- a/modelopt/torch/kernels/hf_triton_attention.py +++ b/modelopt/torch/kernels/common/hf_triton_attention.py @@ -25,7 +25,7 @@ import torch import torch.nn as nn -from modelopt.torch.kernels.triton_fa import attention +from modelopt.torch.kernels.common.triton_fa import attention def _seq_lens_from_mask( diff --git a/modelopt/torch/kernels/triton_fa.py b/modelopt/torch/kernels/common/triton_fa.py similarity index 100% rename from modelopt/torch/kernels/triton_fa.py rename to modelopt/torch/kernels/common/triton_fa.py diff --git a/modelopt/torch/kernels/quantization/__init__.py b/modelopt/torch/kernels/quantization/__init__.py new file mode 100644 index 0000000000..1ae6845c90 --- /dev/null +++ b/modelopt/torch/kernels/quantization/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Quantization kernels: conv (implicit GEMM) and gemm (tensor_quant + Triton FP4/FP8).""" diff --git a/modelopt/torch/quantization/src/conv/README.md b/modelopt/torch/kernels/quantization/conv/README.md similarity index 95% rename from modelopt/torch/quantization/src/conv/README.md rename to modelopt/torch/kernels/quantization/conv/README.md index 6b14fd5953..ae61235514 100644 --- a/modelopt/torch/quantization/src/conv/README.md +++ b/modelopt/torch/kernels/quantization/conv/README.md @@ -32,7 +32,7 @@ When NVFP4 quantization is configured on a `Conv3d` layer via ModelOpt PTQ, the ```python import torch -from modelopt.torch.quantization.src.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda +from modelopt.torch.kernels.quantization.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda from modelopt.torch.quantization.tensor_quant import dynamic_block_quantize_op x = torch.randn(1, 128, 21, 60, 106, device="cuda") @@ -75,7 +75,7 @@ out_q = conv3d_implicit_gemm_cuda( ### `conv3d_implicit_gemm_cuda` -`from modelopt.torch.quantization.src.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda` +`from modelopt.torch.kernels.quantization.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda` | Parameter | Description | |-----------|-------------| @@ -91,7 +91,7 @@ out_q = conv3d_implicit_gemm_cuda( ### `fp4_fake_quant` -`from modelopt.torch.quantization.src.conv.implicit_gemm_cuda import fp4_fake_quant` +`from modelopt.torch.kernels.quantization.conv.implicit_gemm_cuda import fp4_fake_quant` Standalone FP4 (E2M1) blockwise fake quantization with FP8 E4M3 scale quantization. Uses the same CUDA device functions as the fused path inside the GEMM kernel. diff --git a/modelopt/torch/kernels/quantization/conv/__init__.py b/modelopt/torch/kernels/quantization/conv/__init__.py new file mode 100644 index 0000000000..ac5091fa9d --- /dev/null +++ b/modelopt/torch/kernels/quantization/conv/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Implicit-GEMM CUDA kernel for quantized 3D convolution.""" diff --git a/modelopt/torch/quantization/src/conv/bench_implicit_gemm.py b/modelopt/torch/kernels/quantization/conv/bench_implicit_gemm.py similarity index 98% rename from modelopt/torch/quantization/src/conv/bench_implicit_gemm.py rename to modelopt/torch/kernels/quantization/conv/bench_implicit_gemm.py index 807ce17838..66e3f96851 100644 --- a/modelopt/torch/quantization/src/conv/bench_implicit_gemm.py +++ b/modelopt/torch/kernels/quantization/conv/bench_implicit_gemm.py @@ -94,7 +94,9 @@ def bench_fn(fn, warmup: int, iters: int) -> float: def run_benchmark(shapes_name: str, warmup: int, iters: int, fp4_block_size: int): """Run latency benchmark for the given shapes.""" - from modelopt.torch.quantization.src.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda + from modelopt.torch.kernels.quantization.conv.implicit_gemm_cuda import ( + conv3d_implicit_gemm_cuda, + ) shapes = get_shapes(shapes_name) diff --git a/modelopt/torch/quantization/src/conv/implicit_gemm_binding.cpp b/modelopt/torch/kernels/quantization/conv/implicit_gemm_binding.cpp similarity index 100% rename from modelopt/torch/quantization/src/conv/implicit_gemm_binding.cpp rename to modelopt/torch/kernels/quantization/conv/implicit_gemm_binding.cpp diff --git a/modelopt/torch/quantization/src/conv/implicit_gemm_cuda.py b/modelopt/torch/kernels/quantization/conv/implicit_gemm_cuda.py similarity index 100% rename from modelopt/torch/quantization/src/conv/implicit_gemm_cuda.py rename to modelopt/torch/kernels/quantization/conv/implicit_gemm_cuda.py diff --git a/modelopt/torch/quantization/src/conv/implicit_gemm_kernel.cu b/modelopt/torch/kernels/quantization/conv/implicit_gemm_kernel.cu similarity index 100% rename from modelopt/torch/quantization/src/conv/implicit_gemm_kernel.cu rename to modelopt/torch/kernels/quantization/conv/implicit_gemm_kernel.cu diff --git a/modelopt/torch/quantization/triton/__init__.py b/modelopt/torch/kernels/quantization/gemm/__init__.py similarity index 100% rename from modelopt/torch/quantization/triton/__init__.py rename to modelopt/torch/kernels/quantization/gemm/__init__.py diff --git a/modelopt/torch/quantization/triton/fp4_kernel.py b/modelopt/torch/kernels/quantization/gemm/fp4_kernel.py similarity index 100% rename from modelopt/torch/quantization/triton/fp4_kernel.py rename to modelopt/torch/kernels/quantization/gemm/fp4_kernel.py diff --git a/modelopt/torch/quantization/triton/fp4_kernel_hopper.py b/modelopt/torch/kernels/quantization/gemm/fp4_kernel_hopper.py similarity index 100% rename from modelopt/torch/quantization/triton/fp4_kernel_hopper.py rename to modelopt/torch/kernels/quantization/gemm/fp4_kernel_hopper.py diff --git a/modelopt/torch/quantization/triton/fp8_kernel.py b/modelopt/torch/kernels/quantization/gemm/fp8_kernel.py similarity index 100% rename from modelopt/torch/quantization/triton/fp8_kernel.py rename to modelopt/torch/kernels/quantization/gemm/fp8_kernel.py diff --git a/modelopt/torch/quantization/src/tensor_quant.cpp b/modelopt/torch/kernels/quantization/gemm/tensor_quant.cpp similarity index 100% rename from modelopt/torch/quantization/src/tensor_quant.cpp rename to modelopt/torch/kernels/quantization/gemm/tensor_quant.cpp diff --git a/modelopt/torch/quantization/src/tensor_quant.h b/modelopt/torch/kernels/quantization/gemm/tensor_quant.h similarity index 100% rename from modelopt/torch/quantization/src/tensor_quant.h rename to modelopt/torch/kernels/quantization/gemm/tensor_quant.h diff --git 
a/modelopt/torch/quantization/src/tensor_quant_gpu.cu b/modelopt/torch/kernels/quantization/gemm/tensor_quant_gpu.cu similarity index 100% rename from modelopt/torch/quantization/src/tensor_quant_gpu.cu rename to modelopt/torch/kernels/quantization/gemm/tensor_quant_gpu.cu diff --git a/modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu b/modelopt/torch/kernels/quantization/gemm/tensor_quant_gpu_fp8.cu similarity index 100% rename from modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu rename to modelopt/torch/kernels/quantization/gemm/tensor_quant_gpu_fp8.cu diff --git a/modelopt/torch/quantization/src/tensor_quant_mx.cu b/modelopt/torch/kernels/quantization/gemm/tensor_quant_mx.cu similarity index 100% rename from modelopt/torch/quantization/src/tensor_quant_mx.cu rename to modelopt/torch/kernels/quantization/gemm/tensor_quant_mx.cu diff --git a/modelopt/torch/quantization/src/tensor_quant_mx.h b/modelopt/torch/kernels/quantization/gemm/tensor_quant_mx.h similarity index 100% rename from modelopt/torch/quantization/src/tensor_quant_mx.h rename to modelopt/torch/kernels/quantization/gemm/tensor_quant_mx.h diff --git a/modelopt/torch/kernels/sparsity/__init__.py b/modelopt/torch/kernels/sparsity/__init__.py new file mode 100644 index 0000000000..ca2bfdb128 --- /dev/null +++ b/modelopt/torch/kernels/sparsity/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Sparsity kernels: attention (Triton skip-softmax backends) and gemm (placeholder).""" diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py b/modelopt/torch/kernels/sparsity/attention/__init__.py similarity index 96% rename from modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py rename to modelopt/torch/kernels/sparsity/attention/__init__.py index 0cc4a202f5..561d8470e3 100644 --- a/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py +++ b/modelopt/torch/kernels/sparsity/attention/__init__.py @@ -18,7 +18,7 @@ import contextlib import threading -from modelopt.torch.kernels import IS_AVAILABLE, attention, register_triton_attention +from modelopt.torch.kernels.common import IS_AVAILABLE, attention, register_triton_attention # --------------------------------------------------------------------------- # Optional backend registrations (depend on diffusers / ltx_core) diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py b/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py similarity index 99% rename from modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py rename to modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py index 2923447cf0..bfc2f73b2a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py +++ b/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py @@ -36,7 +36,7 @@ attention_backend, ) -from modelopt.torch.kernels import attention, attention_calibrate +from modelopt.torch.kernels.common import attention, attention_calibrate _BACKEND_NAME = "modelopt_triton" _BACKEND_REGISTERED = False diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py b/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py similarity index 99% rename from modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py rename to modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py index fd53e7f9f4..6eba2004d8 100644 --- a/modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py +++ b/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py @@ -25,7 +25,7 @@ import torch -from modelopt.torch.kernels import attention, attention_calibrate +from modelopt.torch.kernels.common import attention, attention_calibrate from modelopt.torch.utils.logging import warn_rank_0 # Thread-local storage for skip-softmax configuration diff --git a/modelopt/torch/kernels/sparsity/gemm/__init__.py b/modelopt/torch/kernels/sparsity/gemm/__init__.py new file mode 100644 index 0000000000..5a366019db --- /dev/null +++ b/modelopt/torch/kernels/sparsity/gemm/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Sparsity GEMM kernels (placeholder for future implementations).""" diff --git a/modelopt/torch/quantization/extensions.py b/modelopt/torch/quantization/extensions.py index 003703567a..a65396d64f 100644 --- a/modelopt/torch/quantization/extensions.py +++ b/modelopt/torch/quantization/extensions.py @@ -22,6 +22,7 @@ __all__ = ["get_cuda_ext", "get_cuda_ext_fp8", "get_cuda_ext_mx", "precompile"] path = Path(__file__).parent +kernels_gemm = path.parent / "kernels" / "quantization" / "gemm" def get_cuda_ext(raise_if_failed: bool = False): @@ -29,7 +30,7 @@ def get_cuda_ext(raise_if_failed: bool = False): if not hasattr(get_cuda_ext, "extension"): get_cuda_ext.extension = load_cpp_extension( # type:ignore[attr-defined] name="modelopt_cuda_ext", - sources=[path / "src/tensor_quant.cpp", path / "src/tensor_quant_gpu.cu"], + sources=[kernels_gemm / "tensor_quant.cpp", kernels_gemm / "tensor_quant_gpu.cu"], cuda_version_specifiers=">=11", raise_if_failed=raise_if_failed, ) @@ -41,7 +42,7 @@ def get_cuda_ext_fp8(raise_if_failed: bool = False): if not hasattr(get_cuda_ext_fp8, "extension"): get_cuda_ext_fp8.extension = load_cpp_extension( # type:ignore[attr-defined] name="modelopt_cuda_ext_fp8", - sources=[path / "src/tensor_quant_gpu_fp8.cu"], + sources=[kernels_gemm / "tensor_quant_gpu_fp8.cu"], cuda_version_specifiers=">=11.8", fail_msg=( "CUDA extension for FP8 quantization could not be built and loaded, FP8 simulated" @@ -58,7 +59,7 @@ def get_cuda_ext_mx(raise_if_failed: bool = False): get_cuda_ext_mx.extension = load_cpp_extension( # type:ignore[attr-defined] name="modelopt_cuda_ext_mx", sources=[ - path / "src/tensor_quant_mx.cu", + kernels_gemm / "tensor_quant_mx.cu", ], cuda_version_specifiers=">=11.8", fail_msg=( diff --git a/modelopt/torch/quantization/nn/modules/quant_conv.py b/modelopt/torch/quantization/nn/modules/quant_conv.py index ed16555624..375fdc96e0 100644 --- a/modelopt/torch/quantization/nn/modules/quant_conv.py +++ b/modelopt/torch/quantization/nn/modules/quant_conv.py @@ -19,7 +19,7 @@ import torch.nn as nn -from modelopt.torch.quantization.src.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda +from modelopt.torch.kernels.quantization.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda from ... 
import tensor_quant from .quant_module import QuantLinearConvBase, QuantModuleRegistry, _LegacyQuantLinearConvBaseMixin diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 59bcd215bb..92eaf12ece 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -30,6 +30,7 @@ from torch.nn.functional import linear from transformers.models.t5.modeling_t5 import T5Attention +from modelopt.torch.kernels.quantization.gemm import IS_AVAILABLE as IS_TRITON_AVAILABLE from modelopt.torch.opt.dynamic import DynamicModule from modelopt.torch.utils.distributed import ParallelState @@ -37,7 +38,6 @@ from ..conversion import register from ..nn import QuantInputBase, QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import _QuantLinear -from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE from ..utils import replace_function, sync_moe_expert_amax from ..utils.layerwise_calib import LayerActivationCollector from .attention import register_attention_for_kv_quant @@ -58,7 +58,7 @@ kitchen = None if IS_TRITON_AVAILABLE: - from ..triton import weight_dequant + from modelopt.torch.kernels.quantization.gemm import weight_dequant else: weight_dequant = None diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py index 6ff31424c7..fe30e283c2 100644 --- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py @@ -346,7 +346,7 @@ def _unpack_tensor(input: torch.Tensor): ) from e if fast: - from ..triton.fp4_kernel import fp4_dequantize + from modelopt.torch.kernels.quantization.gemm.fp4_kernel import fp4_dequantize return fp4_dequantize( self._quantized_data, diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index 16b9d32997..15d782c4a7 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -21,7 +21,7 @@ from torch.autograd import Function from torch.onnx import symbolic_helper -import modelopt.torch.quantization.triton as triton_kernel +import modelopt.torch.kernels.quantization.gemm as triton_kernel from .config import QuantizerAttributeConfig from .extensions import get_cuda_ext, get_cuda_ext_fp8, get_cuda_ext_mx diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index cc92819850..f0c33520c4 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -79,7 +79,7 @@ def _set_attn_implementation(model: nn.Module, config: SparseAttentionConfig) -> ) if "triton" in backends: - from .kernels import register_triton_attention + from modelopt.torch.kernels.sparsity.attention import register_triton_attention if register_triton_attention is None: raise ImportError( @@ -128,7 +128,9 @@ def _register_diffusers_backends_if_needed(model: nn.Module) -> None: from diffusers.models.modeling_utils import ModelMixin if isinstance(model, ModelMixin): - from .kernels import register_diffusers_triton_attention + from modelopt.torch.kernels.sparsity.attention import ( + register_diffusers_triton_attention, + ) if register_diffusers_triton_attention is not None: register_diffusers_triton_attention() @@ -137,7 +139,7 @@ def _register_diffusers_backends_if_needed(model: nn.Module) -> None: # Patch 
ltx_core Attention modules if present (independent of diffusers) try: - from .kernels import register_ltx_triton_attention + from modelopt.torch.kernels.sparsity.attention import register_ltx_triton_attention except (ImportError, RuntimeError): return diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index 117e337809..c1d6465ba6 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -384,7 +384,7 @@ def sparse_softmax(input, dim=-1, *args, **kwargs): input = self.apply_sparsity(input, sparse_mask) return original_softmax(input, dim, *args, **kwargs) - from ..kernels import set_skip_softmax_context + from modelopt.torch.kernels.sparsity.attention import set_skip_softmax_context stack = ExitStack() set_skip_softmax_context(True) diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py index 1e2f3905e7..ff74d13fae 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py @@ -168,7 +168,9 @@ def _get_scale_factor(self) -> float | None: def _get_diffusers_backend_context(): """Activate the modelopt_triton diffusers backend if registered.""" try: - from ..kernels.diffusers_triton_attention import get_triton_attention_backend + from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( + get_triton_attention_backend, + ) with get_triton_attention_backend(): yield @@ -178,13 +180,17 @@ def _get_diffusers_backend_context(): def _set_triton_backends(self, **kwargs): """Set config on both diffusers and LTX Triton backends.""" try: - from ..kernels.diffusers_triton_attention import set_triton_skip_softmax_config + from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( + set_triton_skip_softmax_config, + ) set_triton_skip_softmax_config(**kwargs) except ImportError: pass try: - from ..kernels.ltx_triton_attention import set_ltx_triton_context + from modelopt.torch.kernels.sparsity.attention.ltx_triton_attention import ( + set_ltx_triton_context, + ) set_ltx_triton_context(active=True, **kwargs) except ImportError: @@ -193,13 +199,17 @@ def _set_triton_backends(self, **kwargs): def _clear_triton_backends(self): """Clear config on both Triton backends.""" try: - from ..kernels.diffusers_triton_attention import clear_triton_skip_softmax_config + from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( + clear_triton_skip_softmax_config, + ) clear_triton_skip_softmax_config() except ImportError: pass try: - from ..kernels.ltx_triton_attention import clear_ltx_triton_context + from modelopt.torch.kernels.sparsity.attention.ltx_triton_attention import ( + clear_ltx_triton_context, + ) clear_ltx_triton_context() except ImportError: @@ -211,7 +221,7 @@ def _collect_calibration_stats(self, module): seq_k = None try: - from ..kernels.diffusers_triton_attention import ( + from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( get_calibration_counters, get_calibration_seq_k, ) @@ -223,7 +233,7 @@ def _collect_calibration_stats(self, module): if counters is None: try: - from ..kernels.ltx_triton_attention import ( + from 
modelopt.torch.kernels.sparsity.attention.ltx_triton_attention import ( get_calibration_counters, get_calibration_seq_k, ) @@ -288,7 +298,9 @@ def get_sparsity_counters(self) -> tuple[int, int]: def _collect_sparsity_counters(self) -> None: """Read runtime sparsity counters from the backend and accumulate.""" try: - from ..kernels.diffusers_triton_attention import get_sparsity_counters + from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( + get_sparsity_counters, + ) total, skipped = get_sparsity_counters() self._sparsity_total += total diff --git a/pyproject.toml b/pyproject.toml index fdd60b5193..bace52dff9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -227,8 +227,8 @@ extend-ignore = [ "SIM", "UP", ] # TODO: Disabled for now, will enable later, once all puzzletron code is migrated -"modelopt/torch/quantization/triton/*" = ["N803", "N806", "E731"] # triton style -"modelopt/torch/sparsity/attention_sparsity/kernels/*" = [ +"modelopt/torch/kernels/quantization/gemm/*" = ["N803", "N806", "E731"] # triton style +"modelopt/torch/kernels/sparsity/attention/*" = [ "N803", "N806", ] # triton kernel style diff --git a/tests/gpu/torch/quantization/kernels/test_implicit_gemm.py b/tests/gpu/torch/quantization/kernels/test_implicit_gemm.py index 56ceaacc01..96cc24c2b9 100644 --- a/tests/gpu/torch/quantization/kernels/test_implicit_gemm.py +++ b/tests/gpu/torch/quantization/kernels/test_implicit_gemm.py @@ -28,7 +28,9 @@ @pytest.fixture(scope="module") def cuda_conv3d(): """Import and return the CUDA implicit GEMM conv3d function.""" - from modelopt.torch.quantization.src.conv.implicit_gemm_cuda import conv3d_implicit_gemm_cuda + from modelopt.torch.kernels.quantization.conv.implicit_gemm_cuda import ( + conv3d_implicit_gemm_cuda, + ) return conv3d_implicit_gemm_cuda @@ -36,7 +38,7 @@ def cuda_conv3d(): def _triton_fp4_available(): """Check if the Triton FP4 fake quant kernel is available (requires compute >= 8.9).""" try: - import modelopt.torch.quantization.triton as triton_kernel + import modelopt.torch.kernels.quantization.gemm as triton_kernel return hasattr(triton_kernel, "fp4_fake_quant_block") except ImportError: @@ -305,7 +307,7 @@ def test_deterministic(self, cuda_conv3d): @pytest.fixture(scope="module") def cuda_fp4(): """Import and return the CUDA FP4 fake quant function.""" - from modelopt.torch.quantization.src.conv.implicit_gemm_cuda import fp4_fake_quant + from modelopt.torch.kernels.quantization.conv.implicit_gemm_cuda import fp4_fake_quant return fp4_fake_quant @@ -780,7 +782,7 @@ class TestFP4FakeQuantVsTriton: @pytest.mark.parametrize("num_blocks", [4, 16, 64]) def test_vs_triton(self, cuda_fp4, block_size, num_blocks): """CUDA kernel should match the Triton fp4_fake_quant_block.""" - from modelopt.torch.quantization.triton import fp4_fake_quant_block + from modelopt.torch.kernels.quantization.gemm import fp4_fake_quant_block torch.manual_seed(42) x = torch.randn(num_blocks, block_size, device="cuda", dtype=torch.float32) * 10 @@ -865,7 +867,7 @@ class TestFP4FakeQuantVsModelopt: @pytest.mark.parametrize("seed", [42, 123, 999]) def test_vs_triton_fp4_fake_quant_block(self, cuda_fp4, block_size, seed): """Compare against modelopt Triton fp4_fake_quant_block.""" - from modelopt.torch.quantization.triton import fp4_fake_quant_block + from modelopt.torch.kernels.quantization.gemm import fp4_fake_quant_block torch.manual_seed(seed) num_blocks = 16 @@ -957,7 +959,7 @@ def test_vs_triton_realistic_shape(self, cuda_fp4): x = torch.randn(num_blocks, 
block_size, device="cuda", dtype=torch.float32) * 5 global_amax = x.abs().max() - from modelopt.torch.quantization.triton import fp4_fake_quant_block + from modelopt.torch.kernels.quantization.gemm import fp4_fake_quant_block ours = cuda_fp4(x, global_amax.unsqueeze(0), block_size) theirs = fp4_fake_quant_block( @@ -983,7 +985,7 @@ def test_vs_triton_input_dtypes(self, cuda_fp4, dtype): Our kernel casts to float32 internally, so the result should match Triton's output when both receive the same dtype input. """ - from modelopt.torch.quantization.triton import fp4_fake_quant_block + from modelopt.torch.kernels.quantization.gemm import fp4_fake_quant_block torch.manual_seed(42) block_size = 16 diff --git a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py index 1a28d229f4..d2503669ac 100644 --- a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py +++ b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py @@ -20,7 +20,7 @@ from _test_utils.torch.quantization.quant_utils import quant from _test_utils.torch.quantization.tensor_quant_common import FakeTensorQuantTester -import modelopt.torch.quantization.triton as triton_kernel +import modelopt.torch.kernels.quantization.gemm as triton_kernel import modelopt.torch.quantization.utils as quant_utils from modelopt.torch.quantization import tensor_quant from modelopt.torch.quantization.extensions import get_cuda_ext, get_cuda_ext_mx diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_diffusers_triton_attention.py b/tests/gpu/torch/sparsity/attention_sparsity/test_diffusers_triton_attention.py index f479b8883f..dd8a265bd4 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_diffusers_triton_attention.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_diffusers_triton_attention.py @@ -26,11 +26,9 @@ diffusers = pytest.importorskip("diffusers") -from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE -from modelopt.torch.sparsity.attention_sparsity.kernels import ( - diffusers_triton_attention as diffusers_mod, -) -from modelopt.torch.sparsity.attention_sparsity.kernels import ltx_triton_attention as ltx_mod +from modelopt.torch.kernels.common import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.sparsity.attention import diffusers_triton_attention as diffusers_mod +from modelopt.torch.kernels.sparsity.attention import ltx_triton_attention as ltx_mod @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py index a5174496cf..2d0798ade7 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py @@ -26,10 +26,10 @@ pytest.mark.filterwarnings("ignore::DeprecationWarning"), ] -from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE if TRITON_KERNEL_AVAILABLE: - from modelopt.torch.kernels import attention, register_triton_attention + from modelopt.torch.kernels.common import attention, register_triton_attention if register_triton_attention is not None: register_triton_attention() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_calibrate.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_calibrate.py index 37c4da9969..babf7283bd 100644 --- 
a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_calibrate.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_calibrate.py @@ -29,10 +29,10 @@ pytest.mark.filterwarnings("ignore::DeprecationWarning"), ] -from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE if TRITON_KERNEL_AVAILABLE: - from modelopt.torch.kernels import attention, attention_calibrate + from modelopt.torch.kernels.common import attention, attention_calibrate @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_skip_softmax.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_skip_softmax.py index 21b2a12ca7..7f0d1edf72 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_skip_softmax.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_skip_softmax.py @@ -25,10 +25,10 @@ pytest.mark.filterwarnings("ignore::DeprecationWarning"), ] -from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE if TRITON_KERNEL_AVAILABLE: - from modelopt.torch.kernels import attention, register_triton_attention + from modelopt.torch.kernels.common import attention, register_triton_attention if register_triton_attention is not None: register_triton_attention() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py index 4eec5799a5..dab67f28f8 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py @@ -27,14 +27,14 @@ pytest.mark.filterwarnings("ignore::DeprecationWarning"), ] -from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE if TRITON_KERNEL_AVAILABLE: import triton import triton.language as tl - from modelopt.torch.kernels import attention - from modelopt.torch.kernels.triton_fa import _apply_sparse_nm_to_qk_tile + from modelopt.torch.kernels.common import attention + from modelopt.torch.kernels.common.triton_fa import _apply_sparse_nm_to_qk_tile @triton.jit def _test_apply_sparse_nm( diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py b/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py index 72b20df932..cbc38210fb 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py @@ -33,7 +33,7 @@ diffusers = pytest.importorskip("diffusers") -from modelopt.torch.kernels import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE if TRITON_KERNEL_AVAILABLE: import modelopt.torch.sparsity.attention_sparsity as mtsa diff --git a/tests/unit/torch/kernels/test_triton_fa.py b/tests/unit/torch/kernels/test_triton_fa.py index ac054e10e6..1770b103e7 100644 --- a/tests/unit/torch/kernels/test_triton_fa.py +++ b/tests/unit/torch/kernels/test_triton_fa.py @@ -33,7 +33,7 @@ def test_triton_fa_importable_on_cpu(): except ImportError: pytest.skip("triton is not installed") - from modelopt.torch.kernels import triton_fa + from modelopt.torch.kernels.common import 
triton_fa assert "attention" in triton_fa.__all__ assert "attention_calibrate" in triton_fa.__all__ diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py b/tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py index 775723e66c..f997c98be6 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py @@ -34,12 +34,12 @@ class TestSkipSoftmaxContext: def test_default_is_false(self): - from modelopt.torch.sparsity.attention_sparsity.kernels import get_skip_softmax_context + from modelopt.torch.kernels.sparsity.attention import get_skip_softmax_context assert get_skip_softmax_context() is False def test_set_and_get(self): - from modelopt.torch.sparsity.attention_sparsity.kernels import ( + from modelopt.torch.kernels.sparsity.attention import ( get_skip_softmax_context, set_skip_softmax_context, ) @@ -60,9 +60,7 @@ class TestDiffusersTritonBackend: @pytest.fixture(autouse=True) def _reset(self): - from modelopt.torch.sparsity.attention_sparsity.kernels import ( - diffusers_triton_attention as mod, - ) + from modelopt.torch.kernels.sparsity.attention import diffusers_triton_attention as mod mod._BACKEND_REGISTERED = False mod.clear_triton_skip_softmax_config() @@ -70,7 +68,7 @@ def _reset(self): mod.clear_triton_skip_softmax_config() def test_set_clear_config(self): - from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention import ( + from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( clear_triton_skip_softmax_config, set_triton_skip_softmax_config, ) @@ -79,7 +77,7 @@ def test_set_clear_config(self): clear_triton_skip_softmax_config() def test_register_idempotent(self): - from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention import ( + from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( register_diffusers_triton_attention, ) @@ -87,7 +85,7 @@ def test_register_idempotent(self): register_diffusers_triton_attention() # Should be a no-op def test_get_backend_before_register_raises(self): - from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention import ( + from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import ( get_triton_attention_backend, ) @@ -113,12 +111,10 @@ def test_with_diffusers_model(self): """A ModelMixin subclass triggers diffusers backend registration.""" from diffusers.models.modeling_utils import ModelMixin + from modelopt.torch.kernels.sparsity.attention import diffusers_triton_attention as mod from modelopt.torch.sparsity.attention_sparsity.conversion import ( _register_diffusers_backends_if_needed, ) - from modelopt.torch.sparsity.attention_sparsity.kernels import ( - diffusers_triton_attention as mod, - ) mod._BACKEND_REGISTERED = False diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_ltx_triton_attention.py b/tests/unit/torch/sparsity/attention_sparsity/test_ltx_triton_attention.py index 4751fbae35..6ed2b3f20b 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_ltx_triton_attention.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_ltx_triton_attention.py @@ -32,7 +32,7 @@ @pytest.fixture def ltx_mod(): """Import ltx_triton_attention and ensure thread-local state is reset.""" - from modelopt.torch.sparsity.attention_sparsity.kernels import ltx_triton_attention as mod + from modelopt.torch.kernels.sparsity.attention import 
ltx_triton_attention as mod mod.clear_ltx_triton_context() try: @@ -119,7 +119,7 @@ def __init__(self): parent = Parent() ltx_mod.register_ltx_triton_attention(parent) - from modelopt.torch.sparsity.attention_sparsity.kernels.ltx_triton_attention import ( + from modelopt.torch.kernels.sparsity.attention.ltx_triton_attention import ( _TritonLTXAttentionWrapper, ) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py index 93389a4610..8e68c28e19 100644 --- a/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py +++ b/tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py @@ -270,7 +270,7 @@ def test_triton_backend_sets_attn_impl(self): "*": {"method": "triton_skip_softmax", "backend": "triton"}, } with patch( - "modelopt.torch.sparsity.attention_sparsity.kernels.register_triton_attention", + "modelopt.torch.kernels.sparsity.attention.register_triton_attention", MagicMock(return_value=True), ): _set_attn_implementation(model, config) @@ -287,7 +287,7 @@ def test_triton_backend_register_failure_raises(self): config.sparse_cfg = {"*": {"method": "triton_skip_softmax", "backend": "triton"}} with ( patch( - "modelopt.torch.sparsity.attention_sparsity.kernels.register_triton_attention", + "modelopt.torch.kernels.sparsity.attention.register_triton_attention", MagicMock(return_value=False), ), pytest.raises(RuntimeError, match="Failed to register"), @@ -305,7 +305,7 @@ def test_triton_backend_no_triton_raises(self): config.sparse_cfg = {"*": {"method": "triton_skip_softmax", "backend": "triton"}} with ( patch( - "modelopt.torch.sparsity.attention_sparsity.kernels.register_triton_attention", + "modelopt.torch.kernels.sparsity.attention.register_triton_attention", None, ), pytest.raises(ImportError, match="Triton backend requires"), From bab50225a8d344aeadab265702dac7c262534f3d Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Tue, 21 Apr 2026 18:24:55 +0000 Subject: [PATCH 2/9] Reorg the kernel Signed-off-by: Jingyu Xin --- modelopt/torch/kernels/common/__init__.py | 38 +- .../kernels/common/attention/__init__.py | 57 ++ .../{ => attention}/hf_triton_attention.py | 2 +- .../common/{ => attention}/triton_fa.py | 501 ++---------------- .../quantization/attention/__init__.py | 16 + .../kernels/sparsity/attention/__init__.py | 6 +- .../kernels/sparsity/attention/calibrate.py | 293 ++++++++++ .../attention/diffusers_triton_attention.py | 9 +- .../attention/ltx_triton_attention.py | 9 +- .../attention/skip_softmax_helpers.py | 208 ++++++++ .../test_diffusers_triton_attention.py | 2 +- .../attention_sparsity/test_triton_fa.py | 4 +- .../test_triton_fa_calibrate.py | 4 +- .../test_triton_fa_skip_softmax.py | 4 +- .../test_triton_fa_sparse_nm.py | 8 +- .../test_wan22_skip_softmax.py | 2 +- tests/unit/torch/kernels/test_triton_fa.py | 5 +- 17 files changed, 665 insertions(+), 503 deletions(-) create mode 100644 modelopt/torch/kernels/common/attention/__init__.py rename modelopt/torch/kernels/common/{ => attention}/hf_triton_attention.py (98%) rename modelopt/torch/kernels/common/{ => attention}/triton_fa.py (66%) create mode 100644 modelopt/torch/kernels/quantization/attention/__init__.py create mode 100644 modelopt/torch/kernels/sparsity/attention/calibrate.py create mode 100644 modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py diff --git a/modelopt/torch/kernels/common/__init__.py 
b/modelopt/torch/kernels/common/__init__.py index fa07b06e20..f5c9e562d3 100644 --- a/modelopt/torch/kernels/common/__init__.py +++ b/modelopt/torch/kernels/common/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,38 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Shared Triton kernels for modelopt (attention, quantization, etc.).""" - -import torch - -from modelopt.torch.utils import import_plugin - -IS_AVAILABLE = False -attention = None -attention_calibrate = None -register_triton_attention = None - -if torch.cuda.is_available(): - with import_plugin( - "triton", - msg_if_missing=( - "Your device is potentially capable of using the triton attention " - "kernel. Try to install triton with `pip install triton`." - ), - ): - from .triton_fa import attention as _attention - from .triton_fa import attention_calibrate as _attention_calibrate - - attention = _attention - attention_calibrate = _attention_calibrate - IS_AVAILABLE = True - from .hf_triton_attention import register_triton_attention as _register_triton_attention - - register_triton_attention = _register_triton_attention - -__all__ = [ - "IS_AVAILABLE", - "attention", - "attention_calibrate", - "register_triton_attention", -] +"""Common (non-domain-specific) kernels. Base FA lives in ``common/attention``.""" diff --git a/modelopt/torch/kernels/common/attention/__init__.py b/modelopt/torch/kernels/common/attention/__init__.py new file mode 100644 index 0000000000..caf319a765 --- /dev/null +++ b/modelopt/torch/kernels/common/attention/__init__.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared Triton kernels for modelopt (attention, quantization, etc.).""" + +import torch + +from modelopt.torch.utils import import_plugin + +IS_AVAILABLE = False +attention = None +attention_calibrate = None +register_triton_attention = None + +if torch.cuda.is_available(): + with import_plugin( + "triton", + msg_if_missing=( + "Your device is potentially capable of using the triton attention " + "kernel. Try to install triton with `pip install triton`." + ), + ): + from .triton_fa import attention as _attention + + attention = _attention + IS_AVAILABLE = True + from .hf_triton_attention import register_triton_attention as _register_triton_attention + + register_triton_attention = _register_triton_attention + + # Calibration lives in the sparsity subpackage (skip-softmax specific). + # Imported here so ``from modelopt.torch.kernels.common.attention import + # attention_calibrate`` keeps working. 
+ from modelopt.torch.kernels.sparsity.attention.calibrate import ( + attention_calibrate as _attention_calibrate, + ) + + attention_calibrate = _attention_calibrate + +__all__ = [ + "IS_AVAILABLE", + "attention", + "attention_calibrate", + "register_triton_attention", +] diff --git a/modelopt/torch/kernels/common/hf_triton_attention.py b/modelopt/torch/kernels/common/attention/hf_triton_attention.py similarity index 98% rename from modelopt/torch/kernels/common/hf_triton_attention.py rename to modelopt/torch/kernels/common/attention/hf_triton_attention.py index d73f281129..235487462f 100644 --- a/modelopt/torch/kernels/common/hf_triton_attention.py +++ b/modelopt/torch/kernels/common/attention/hf_triton_attention.py @@ -25,7 +25,7 @@ import torch import torch.nn as nn -from modelopt.torch.kernels.common.triton_fa import attention +from modelopt.torch.kernels.common.attention.triton_fa import attention def _seq_lens_from_mask( diff --git a/modelopt/torch/kernels/common/triton_fa.py b/modelopt/torch/kernels/common/attention/triton_fa.py similarity index 66% rename from modelopt/torch/kernels/common/triton_fa.py rename to modelopt/torch/kernels/common/attention/triton_fa.py index 8044383889..a4b3cc90e3 100644 --- a/modelopt/torch/kernels/common/triton_fa.py +++ b/modelopt/torch/kernels/common/attention/triton_fa.py @@ -24,11 +24,44 @@ """ import math +from typing import Any import torch import triton import triton.language as tl +# Helpers for optional N:M sparsity and sink/window-aware dense regions live +# in the sparsity package. The baseline forward kernel below calls them +# conditionally under constexpr guards, so the unified single-kernel design +# stays intact while keeping feature-specific logic in its own subpackage. +# +# Lazy import: Triton resolves @triton.jit names at kernel compile time (first +# call), not at definition time, so populating the module globals before the +# first ``attention()`` call is sufficient. Deferring avoids a circular import +# (common.attention/__init__.py ↔ sparsity.attention/__init__.py via this file). +_apply_sparse_nm_to_qk_tile: Any = None +_is_dense_region: Any = None +_skip_softmax_decision: Any = None + + +def _load_sparsity_helpers() -> None: + global _apply_sparse_nm_to_qk_tile, _is_dense_region, _skip_softmax_decision + if _apply_sparse_nm_to_qk_tile is None: + from modelopt.torch.kernels.sparsity.attention.skip_softmax_helpers import ( + _apply_sparse_nm_to_qk_tile as _nm, + ) + from modelopt.torch.kernels.sparsity.attention.skip_softmax_helpers import ( + _is_dense_region as _dense, + ) + from modelopt.torch.kernels.sparsity.attention.skip_softmax_helpers import ( + _skip_softmax_decision as _skip, + ) + + _apply_sparse_nm_to_qk_tile = _nm + _is_dense_region = _dense + _skip_softmax_decision = _skip + + LOG2E: float = 1.44269504088896 # --------------------------------------------------------------------------- @@ -47,145 +80,6 @@ _FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)] -# --------------------------------------------------------------------------- -# N:M sparse softmax helpers -# --------------------------------------------------------------------------- -@triton.jit -def _sparse_nm_masks_m4(x0, x1, x2, x3, N: tl.constexpr): - """Top-N of 4 selection via pure boolean logic (6 comparisons, no int casts). - - Uses ``>=`` so that ties are broken by index (lower index wins). - Guarantees exactly N masks are True for any input including all-equal. 
- - Boolean formulas for "at least K of 3 wins": - K=3 (N=1): AND of all — must beat all 3 others - K=2 (N=2): majority — must beat at least 2 (sorting network) - K=1 (N=3): OR of all — must beat at least 1 - """ - c01 = x0 >= x1 - c02 = x0 >= x2 - c03 = x0 >= x3 - c12 = x1 >= x2 - c13 = x1 >= x3 - c23 = x2 >= x3 - - nc01 = ~c01 - nc02 = ~c02 - nc03 = ~c03 - nc12 = ~c12 - nc13 = ~c13 - nc23 = ~c23 - - if N == 1: - # Keep max only: must beat all 3 - m0 = c01 & c02 & c03 - m1 = nc01 & c12 & c13 - m2 = nc02 & nc12 & c23 - m3 = nc03 & nc13 & nc23 - elif N == 2: - # Majority vote: must beat at least 2 of 3 - m0 = (c01 & c02) | (c01 & c03) | (c02 & c03) - m1 = (nc01 & c12) | (nc01 & c13) | (c12 & c13) - m2 = (nc02 & nc12) | (nc02 & c23) | (nc12 & c23) - m3 = (nc03 & nc13) | (nc03 & nc23) | (nc13 & nc23) - elif N == 3: - # Keep all but min: must beat at least 1 - m0 = c01 | c02 | c03 - m1 = nc01 | c12 | c13 - m2 = nc02 | nc12 | c23 - m3 = nc03 | nc13 | nc23 - else: - tl.static_assert(False, "N must be 1, 2, or 3 for M=4") - - return m0, m1, m2, m3 - - -@triton.jit -def _apply_sparse_nm_to_qk_tile( - qk, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - SPARSITY_N: tl.constexpr, - SPARSITY_M: tl.constexpr, -): - """Apply N:M sparse softmax to a QK score tile. - - For every ``SPARSITY_M`` consecutive elements along the N (key) dimension, - keeps the top ``SPARSITY_N`` values and sets the rest to ``-inf``. - ``BLOCK_N`` must be divisible by ``SPARSITY_M``. - - For M=4, exactly N values are retained (ties broken by position). - For M=8, a threshold-based approach (``tl.sort``) may retain more - than N values when ties straddle the threshold boundary. - """ - tl.static_assert(SPARSITY_M == 4 or SPARSITY_M == 8, "SPARSITY_M must be 4 or 8") # noqa: PLR1714 - MASK_VAL: tl.constexpr = float("-inf") - - if SPARSITY_M == 4: - tl.static_assert(BLOCK_N % 4 == 0, "BLOCK_N must be divisible by 4") - reshaped = tl.reshape(qk, (BLOCK_M, BLOCK_N // 4, 4)) - cols = tl.arange(0, 4)[None, None, :] - x0 = tl.sum(tl.where(cols == 0, reshaped, 0.0), axis=2) - x1 = tl.sum(tl.where(cols == 1, reshaped, 0.0), axis=2) - x2 = tl.sum(tl.where(cols == 2, reshaped, 0.0), axis=2) - x3 = tl.sum(tl.where(cols == 3, reshaped, 0.0), axis=2) - - m0, m1, m2, m3 = _sparse_nm_masks_m4(x0, x1, x2, x3, SPARSITY_N) - - out = tl.full((BLOCK_M, BLOCK_N // 4, 4), 0.0, dtype=qk.dtype) - out = tl.where(cols == 0, tl.expand_dims(tl.where(m0, x0, MASK_VAL), 2), out) - out = tl.where(cols == 1, tl.expand_dims(tl.where(m1, x1, MASK_VAL), 2), out) - out = tl.where(cols == 2, tl.expand_dims(tl.where(m2, x2, MASK_VAL), 2), out) - out = tl.where(cols == 3, tl.expand_dims(tl.where(m3, x3, MASK_VAL), 2), out) - return tl.reshape(out, (BLOCK_M, BLOCK_N)) - - else: # SPARSITY_M == 8 - tl.static_assert(BLOCK_N % 8 == 0, "BLOCK_N must be divisible by 8") - reshaped = tl.reshape(qk, (BLOCK_M, BLOCK_N // 8, 8)) - - # Sort each group of 8 ascending; N-th largest is at index (8 - N) - sorted_vals = tl.sort(reshaped, dim=2) - KTH_IDX: tl.constexpr = SPARSITY_M - SPARSITY_N # index of N-th largest in ascending order - - # Extract the threshold value at KTH_IDX via masked sum - # Use 0.0 as fill (not -inf) so sum equals just the KTH element - cols = tl.arange(0, 8)[None, None, :] - threshold = tl.sum(tl.where(cols == KTH_IDX, sorted_vals, 0.0), axis=2) - - # Mask: keep elements >= threshold (may keep >N on ties — acceptable) - mask = reshaped >= tl.expand_dims(threshold, 2) - return tl.reshape(tl.where(mask, reshaped, MASK_VAL), (BLOCK_M, BLOCK_N)) - - -# 
--------------------------------------------------------------------------- -# Sink/window dense-region check -# --------------------------------------------------------------------------- -@triton.jit -def _is_dense_region( - kv_start, - tile_q, - seq_len_q, - seq_len_kv, - BLOCK_M: tl.constexpr, - NUM_SINK_TOKENS: tl.constexpr, - DENSE_WINDOW_SIZE: tl.constexpr, -): - """Check if a KV tile falls in a dense region (sink tokens or local window). - - Uses absolute token positions so the result is BLOCK_N-independent, - ensuring forward and backward (which may use different BLOCK_N) agree. - - Returns: - True if the tile should be kept dense (skip N:M sparsification). - """ - is_sink = kv_start < NUM_SINK_TOKENS - causal_offset = seq_len_kv - seq_len_q - q_abs_pos = tile_q * BLOCK_M + causal_offset - token_distance = q_abs_pos - kv_start - is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE) - return is_sink or is_local - - # --------------------------------------------------------------------------- # Masking helper # --------------------------------------------------------------------------- @@ -327,55 +221,22 @@ def _attn_fwd( scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M ) + # Optional skip-softmax decision — the decision logic (and optional + # atomic counter updates) lives in sparsity/attention; this kernel + # just consults it under its constexpr guard. + skip_tile = False if APPLY_SKIP_SOFTMAX: - # --- Skip-softmax (BLASST, https://arxiv.org/pdf/2512.12087) --- - # - # Algorithm: During FlashAttention's block-wise computation, we - # maintain a running maximum m_i^(j) across blocks. If a block's - # local maximum ~m_i^(j) is significantly smaller than the running - # maximum m_i^(j): - # - # ~m_i^(j) - m_i^(j) < ln(lambda) - # - # then exp(~m_i^(j) - m_i^(j)) < lambda ≈ 0, meaning the block's - # contribution to the final output is negligible. We skip the - # softmax computation, V load, and BMM2 computation entirely. - # - # The threshold is pre-scaled by qk_scale in the Python wrapper so - # it can be compared directly against scaled scores (matching the - # BLASST reference semantics on unscaled scores). 
- tile_row_max = tl.max(scores, 1) # [BLOCK_M] — ~m_i^(j) (scaled) - # Per-row: True if row's tile max is negligible vs running max - can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2) - # Per-tile: skip entire tile only if ALL rows are negligible - skip_tile = tl.min(can_skip.to(tl.int32)) == 1 - - # Optional runtime sparsity measurement via atomic counters - if MEASURE_SPARSITY: - tl.atomic_add(Sparsity_total, 1) # count every tile - if skip_tile: - tl.atomic_add(Sparsity_skipped, 1) # count skipped tiles - - if not skip_tile: - m_new = tl.maximum(row_max, tile_row_max) - p = tl.math.exp2(scores - m_new[:, None]) - l_new = tl.sum(p, 1) - correction = tl.math.exp2(row_max - m_new) - row_sum = row_sum * correction + l_new - acc = acc * correction[:, None] - - v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] - v = tl.load( - v_base + v_offs, - mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], - other=0.0, - ) - acc = tl.dot(p.to(v.dtype), v, acc) - row_max = m_new - # else: tile skipped: no softmax computation, V load, and BMM2 computation - else: - # --- Standard path: no skip check --- - # Online softmax update + skip_tile = _skip_softmax_decision( + scores, + row_max, + SKIP_THRESHOLD_LOG2, + Sparsity_total, + Sparsity_skipped, + MEASURE_SPARSITY, + ) + + if not skip_tile: + # --- Online softmax update --- m_new = tl.maximum(row_max, tl.max(scores, 1)) p = tl.math.exp2(scores - m_new[:, None]) l_new = tl.sum(p, 1) @@ -392,6 +253,7 @@ def _attn_fwd( ) acc = tl.dot(p.to(v.dtype), v, acc) row_max = m_new + # else: tile skipped — no softmax, no V load, no BMM2 for this tile # --- Final normalization: output = acc / row_sum --- # Clamp denominator to avoid 0/0 NaN when skip-softmax skips all KV tiles. @@ -1092,6 +954,7 @@ def attention( Returns: Output tensor [total_q_tokens, num_q_heads, head_dim]. """ + _load_sparsity_helpers() sm_scale = 1.0 / (q.shape[2] ** 0.5) if softmax_scale is None else softmax_scale return _Attention.apply( q, @@ -1115,266 +978,4 @@ def attention( ) -# --------------------------------------------------------------------------- -# Calibration kernel: collect multi-threshold skip-softmax sparsity stats -# --------------------------------------------------------------------------- -@triton.jit -def _attn_fwd_calibrate( - Q, - K, - V, - qk_scale, - b_start_loc, - b_seq_len, - b_start_loc_k, - b_seq_len_k, - Out, - stride_qbs, - stride_qh, - stride_kbs, - stride_kh, - stride_vbs, - stride_vh, - stride_obs, - stride_oh, - Threshold_trials, # [NUM_THRESHOLDS] float32 — pre-scaled to log2 space - Per_program_totals, # [num_programs * NUM_THRESHOLDS] int32 — per-program tile counts - Per_program_skipped, # [num_programs * NUM_THRESHOLDS] int32 — per-program skip counts - kv_group_num: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_D: tl.constexpr, - BLOCK_N: tl.constexpr, - IS_CAUSAL: tl.constexpr, - HEAD_DIM: tl.constexpr, - NUM_THRESHOLDS: tl.constexpr, - PADDED_THRESHOLDS: tl.constexpr, # next_power_of_2(NUM_THRESHOLDS) for tl.arange -): - """Forward kernel with multi-threshold sparsity measurement. - - Computes full attention (no skipping) while counting how many KV tiles - would be skipped at each threshold. Each program writes its local counts - to ``Per_program_totals`` and ``Per_program_skipped``; the Python wrapper - sums across programs afterward. This avoids global atomic contention. 
- """ - batch_idx = tl.program_id(0) - head_idx = tl.program_id(1) - tile_q = tl.program_id(2) - kv_head_idx = head_idx // kv_group_num - - seq_len_q = tl.load(b_seq_len + batch_idx) - seq_len_kv = tl.load(b_seq_len_k + batch_idx) - q_offset = tl.load(b_start_loc + batch_idx) - kv_offset = tl.load(b_start_loc_k + batch_idx) - - if tile_q * BLOCK_M >= seq_len_q: - return - - q_pos = tile_q * BLOCK_M + tl.arange(0, BLOCK_M) - kv_pos = tl.arange(0, BLOCK_N) - dim_pos = tl.arange(0, BLOCK_D) - d_mask = dim_pos < HEAD_DIM - - q_ptrs = (q_offset + q_pos[:, None]) * stride_qbs + head_idx * stride_qh + dim_pos[None, :] - q = tl.load(Q + q_ptrs, mask=(q_pos[:, None] < seq_len_q) & d_mask[None, :], other=0.0) - - k_base = K + kv_head_idx * stride_kh - v_base = V + kv_head_idx * stride_vh - - row_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - row_sum = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32) - - # Pre-load all thresholds once (vectorized, stays in registers). - # tl.arange requires power-of-2 size, so use PADDED_THRESHOLDS with masking. - thresh_offs = tl.arange(0, PADDED_THRESHOLDS) - thresh_mask = thresh_offs < NUM_THRESHOLDS - thresholds = tl.load(Threshold_trials + thresh_offs, mask=thresh_mask, other=float("inf")) - - # Per-program local counters: avoid global atomic contention in inner loop. - # Each program accumulates locally, then writes once to Per_program buffers. - local_skipped = tl.zeros([PADDED_THRESHOLDS], dtype=tl.int32) - num_tiles = 0 - - kv_bound = seq_len_kv if not IS_CAUSAL else tl.minimum((tile_q + 1) * BLOCK_M, seq_len_kv) - - for kv_start in range(0, kv_bound, BLOCK_N): - kv_start = tl.multiple_of(kv_start, BLOCK_N) - - k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] - k = tl.load( - k_base + k_offs, - mask=((kv_start + kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], - other=0.0, - ) - - scores = tl.dot(q, k) * qk_scale - scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) - - tile_row_max = tl.max(scores, 1) - - # --- Vectorized multi-threshold sparsity measurement --- - # A tile is skipped iff ALL Q rows satisfy: tile_row_max < row_max + thresh. - # Equivalently: max(tile_row_max - row_max) < thresh (worst-case row - # must still be below threshold for the tile to be skippable). 
- max_gap = tl.max(tile_row_max - row_max) # scalar - skip_mask = (max_gap < thresholds).to(tl.int32) # [PADDED_THRESHOLDS] - local_skipped += skip_mask - num_tiles += 1 - - # --- Always compute full attention (no skipping) --- - m_new = tl.maximum(row_max, tile_row_max) - p = tl.math.exp2(scores - m_new[:, None]) - l_new = tl.sum(p, 1) - correction = tl.math.exp2(row_max - m_new) - row_sum = row_sum * correction + l_new - acc = acc * correction[:, None] - - v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] - v = tl.load( - v_base + v_offs, - mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], - other=0.0, - ) - acc = tl.dot(p.to(v.dtype), v, acc) - row_max = m_new - - # --- Write per-program counters (no atomics, just stores) --- - # Compute unique flat program index for this (batch, head, q_tile) - num_q_tiles = tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M) # conservative upper bound - num_heads = tl.num_programs(1) - prog_idx = batch_idx * num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q - base = prog_idx * NUM_THRESHOLDS - tl.store( - Per_program_totals + base + thresh_offs, - tl.full([PADDED_THRESHOLDS], num_tiles, dtype=tl.int32), - mask=thresh_mask, - ) - tl.store( - Per_program_skipped + base + thresh_offs, - local_skipped, - mask=thresh_mask, - ) - - acc = acc / tl.maximum(row_sum[:, None], 1e-6) - o_ptrs = (q_offset + q_pos[:, None]) * stride_obs + head_idx * stride_oh + dim_pos[None, :] - tl.store(Out + o_ptrs, acc, mask=(q_pos[:, None] < seq_len_q) & d_mask[None, :]) - - -def attention_calibrate( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - b_start_loc: torch.Tensor, - b_seq_len: torch.Tensor, - max_input_len: int, - is_causal: bool = True, - softmax_scale: float | None = None, - b_start_loc_k: torch.Tensor | None = None, - b_seq_len_k: torch.Tensor | None = None, - max_input_len_k: int | None = None, - *, - threshold_trials: list[float] | None = None, -) -> tuple[torch.Tensor, torch.Tensor]: - """Flash attention with multi-threshold skip-softmax sparsity measurement. - - Computes full attention (identical output to dense attention) while - measuring how many KV tiles would be skipped at each threshold in - ``threshold_trials``. No autograd — forward only. - - Args: - q, k, v, b_start_loc, b_seq_len, max_input_len, is_causal, - softmax_scale, b_start_loc_k, b_seq_len_k, max_input_len_k: - Same as :func:`attention`. - threshold_trials: List of threshold values to measure sparsity for. - Each value is converted to log2-scaled space for the kernel. - - Returns: - Tuple of (output, sparsity_counters): - - output: ``[total_q_tokens, num_q_heads, head_dim]`` - - sparsity_counters: ``[num_thresholds, 2]`` int64 tensor where - ``[:, 0]`` = total tile evaluations, ``[:, 1]`` = skipped tiles. - Sparsity per threshold = ``counters[:, 1] / counters[:, 0]``. 
- """ - if threshold_trials is None or len(threshold_trials) == 0: - raise ValueError("threshold_trials must be a non-empty list") - - HEAD_DIM = q.shape[2] - num_q_heads = q.shape[1] - num_kv_heads = k.shape[1] - kv_group_num = num_q_heads // num_kv_heads - batch = b_seq_len.shape[0] - sm_scale = 1.0 / (HEAD_DIM**0.5) if softmax_scale is None else softmax_scale - qk_scale = sm_scale * LOG2E - BLOCK_D = triton.next_power_of_2(HEAD_DIM) - BLOCK_M = 128 - BLOCK_N = 64 - - if b_seq_len_k is None: - b_seq_len_k = b_seq_len - b_start_loc_k = b_start_loc - - num_thresholds = len(threshold_trials) - - # Convert thresholds to log2-scaled space: log2(lambda) * sm_scale - threshold_tensor = torch.tensor( - [math.log2(t) * sm_scale for t in threshold_trials], - dtype=torch.float32, - device=q.device, - ) - - o = torch.empty_like(q) - - num_q_tiles = triton.cdiv(max_input_len, BLOCK_M) - grid = (batch, num_q_heads, num_q_tiles) - num_programs = batch * num_q_heads * num_q_tiles - - # Per-program output buffers (no atomics needed — each program writes its own row) - per_program_totals = torch.zeros( - num_programs * num_thresholds, dtype=torch.int32, device=q.device - ) - per_program_skipped = torch.zeros( - num_programs * num_thresholds, dtype=torch.int32, device=q.device - ) - - _attn_fwd_calibrate[grid]( - q, - k, - v, - qk_scale, - b_start_loc, - b_seq_len, - b_start_loc_k, - b_seq_len_k, - o, - q.stride(0), - q.stride(1), - k.stride(0), - k.stride(1), - v.stride(0), - v.stride(1), - o.stride(0), - o.stride(1), - threshold_tensor, - per_program_totals, - per_program_skipped, - kv_group_num=kv_group_num, - BLOCK_M=BLOCK_M, - BLOCK_D=BLOCK_D, - BLOCK_N=BLOCK_N, - IS_CAUSAL=is_causal, - HEAD_DIM=HEAD_DIM, - NUM_THRESHOLDS=num_thresholds, - PADDED_THRESHOLDS=triton.next_power_of_2(num_thresholds), - num_warps=4, - num_stages=1, - ) - - # Reduce across programs: sum per-program counts → [num_thresholds] - totals = per_program_totals.view(num_programs, num_thresholds).sum(dim=0).to(torch.int64) - skipped = per_program_skipped.view(num_programs, num_thresholds).sum(dim=0).to(torch.int64) - sparsity_counters = torch.stack([totals, skipped], dim=1) # [num_thresholds, 2] - - return o, sparsity_counters - - -__all__ = ["attention", "attention_calibrate"] +__all__ = ["LOG2E", "_apply_mask", "attention"] diff --git a/modelopt/torch/kernels/quantization/attention/__init__.py b/modelopt/torch/kernels/quantization/attention/__init__.py new file mode 100644 index 0000000000..ee64a4dd67 --- /dev/null +++ b/modelopt/torch/kernels/quantization/attention/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Quantization-specific attention kernel pieces (placeholder for combined sparse+quant path).""" diff --git a/modelopt/torch/kernels/sparsity/attention/__init__.py b/modelopt/torch/kernels/sparsity/attention/__init__.py index 561d8470e3..b45f4f27ae 100644 --- a/modelopt/torch/kernels/sparsity/attention/__init__.py +++ b/modelopt/torch/kernels/sparsity/attention/__init__.py @@ -18,7 +18,11 @@ import contextlib import threading -from modelopt.torch.kernels.common import IS_AVAILABLE, attention, register_triton_attention +from modelopt.torch.kernels.common.attention import ( + IS_AVAILABLE, + attention, + register_triton_attention, +) # --------------------------------------------------------------------------- # Optional backend registrations (depend on diffusers / ltx_core) diff --git a/modelopt/torch/kernels/sparsity/attention/calibrate.py b/modelopt/torch/kernels/sparsity/attention/calibrate.py new file mode 100644 index 0000000000..8e7ef144c7 --- /dev/null +++ b/modelopt/torch/kernels/sparsity/attention/calibrate.py @@ -0,0 +1,293 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Skip-softmax multi-threshold calibration kernel and Python API. + +Runs a full attention forward (identical to dense attention) while measuring +how many KV tiles would be skipped at each candidate threshold. Used by the +sparse-attention calibration workflow in +``modelopt.torch.sparsity.attention_sparsity`` to fit a skip threshold. +""" + +import math + +import torch +import triton +import triton.language as tl + +from modelopt.torch.kernels.common.attention.triton_fa import LOG2E, _apply_mask + + +# --------------------------------------------------------------------------- +# Calibration kernel: collect multi-threshold skip-softmax sparsity stats +# --------------------------------------------------------------------------- +@triton.jit +def _attn_fwd_calibrate( + Q, + K, + V, + qk_scale, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + Out, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + Threshold_trials, # [NUM_THRESHOLDS] float32 — pre-scaled to log2 space + Per_program_totals, # [num_programs * NUM_THRESHOLDS] int32 — per-program tile counts + Per_program_skipped, # [num_programs * NUM_THRESHOLDS] int32 — per-program skip counts + kv_group_num: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + HEAD_DIM: tl.constexpr, + NUM_THRESHOLDS: tl.constexpr, + PADDED_THRESHOLDS: tl.constexpr, # next_power_of_2(NUM_THRESHOLDS) for tl.arange +): + """Forward kernel with multi-threshold sparsity measurement. + + Computes full attention (no skipping) while counting how many KV tiles + would be skipped at each threshold. 
Each program writes its local counts + to ``Per_program_totals`` and ``Per_program_skipped``; the Python wrapper + sums across programs afterward. This avoids global atomic contention. + """ + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + tile_q = tl.program_id(2) + kv_head_idx = head_idx // kv_group_num + + seq_len_q = tl.load(b_seq_len + batch_idx) + seq_len_kv = tl.load(b_seq_len_k + batch_idx) + q_offset = tl.load(b_start_loc + batch_idx) + kv_offset = tl.load(b_start_loc_k + batch_idx) + + if tile_q * BLOCK_M >= seq_len_q: + return + + q_pos = tile_q * BLOCK_M + tl.arange(0, BLOCK_M) + kv_pos = tl.arange(0, BLOCK_N) + dim_pos = tl.arange(0, BLOCK_D) + d_mask = dim_pos < HEAD_DIM + + q_ptrs = (q_offset + q_pos[:, None]) * stride_qbs + head_idx * stride_qh + dim_pos[None, :] + q = tl.load(Q + q_ptrs, mask=(q_pos[:, None] < seq_len_q) & d_mask[None, :], other=0.0) + + k_base = K + kv_head_idx * stride_kh + v_base = V + kv_head_idx * stride_vh + + row_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + row_sum = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32) + + # Pre-load all thresholds once (vectorized, stays in registers). + # tl.arange requires power-of-2 size, so use PADDED_THRESHOLDS with masking. + thresh_offs = tl.arange(0, PADDED_THRESHOLDS) + thresh_mask = thresh_offs < NUM_THRESHOLDS + thresholds = tl.load(Threshold_trials + thresh_offs, mask=thresh_mask, other=float("inf")) + + # Per-program local counters: avoid global atomic contention in inner loop. + # Each program accumulates locally, then writes once to Per_program buffers. + local_skipped = tl.zeros([PADDED_THRESHOLDS], dtype=tl.int32) + num_tiles = 0 + + kv_bound = seq_len_kv if not IS_CAUSAL else tl.minimum((tile_q + 1) * BLOCK_M, seq_len_kv) + + for kv_start in range(0, kv_bound, BLOCK_N): + kv_start = tl.multiple_of(kv_start, BLOCK_N) + + k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] + k = tl.load( + k_base + k_offs, + mask=((kv_start + kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], + other=0.0, + ) + + scores = tl.dot(q, k) * qk_scale + scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) + + tile_row_max = tl.max(scores, 1) + + # --- Vectorized multi-threshold sparsity measurement --- + # A tile is skipped iff ALL Q rows satisfy: tile_row_max < row_max + thresh. + # Equivalently: max(tile_row_max - row_max) < thresh (worst-case row + # must still be below threshold for the tile to be skippable). 
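+        # Illustrative numbers (not from a real run): per-row gaps of
+        # [-0.9, -0.2, -0.5] give max_gap = -0.2, so a threshold of -0.1
+        # counts the tile as skippable (-0.2 < -0.1) while -0.3 does not.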
+ max_gap = tl.max(tile_row_max - row_max) # scalar + skip_mask = (max_gap < thresholds).to(tl.int32) # [PADDED_THRESHOLDS] + local_skipped += skip_mask + num_tiles += 1 + + # --- Always compute full attention (no skipping) --- + m_new = tl.maximum(row_max, tile_row_max) + p = tl.math.exp2(scores - m_new[:, None]) + l_new = tl.sum(p, 1) + correction = tl.math.exp2(row_max - m_new) + row_sum = row_sum * correction + l_new + acc = acc * correction[:, None] + + v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] + v = tl.load( + v_base + v_offs, + mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], + other=0.0, + ) + acc = tl.dot(p.to(v.dtype), v, acc) + row_max = m_new + + # --- Write per-program counters (no atomics, just stores) --- + # Compute unique flat program index for this (batch, head, q_tile) + num_q_tiles = tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M) # conservative upper bound + num_heads = tl.num_programs(1) + prog_idx = batch_idx * num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q + base = prog_idx * NUM_THRESHOLDS + tl.store( + Per_program_totals + base + thresh_offs, + tl.full([PADDED_THRESHOLDS], num_tiles, dtype=tl.int32), + mask=thresh_mask, + ) + tl.store( + Per_program_skipped + base + thresh_offs, + local_skipped, + mask=thresh_mask, + ) + + acc = acc / tl.maximum(row_sum[:, None], 1e-6) + o_ptrs = (q_offset + q_pos[:, None]) * stride_obs + head_idx * stride_oh + dim_pos[None, :] + tl.store(Out + o_ptrs, acc, mask=(q_pos[:, None] < seq_len_q) & d_mask[None, :]) + + +def attention_calibrate( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + b_start_loc: torch.Tensor, + b_seq_len: torch.Tensor, + max_input_len: int, + is_causal: bool = True, + softmax_scale: float | None = None, + b_start_loc_k: torch.Tensor | None = None, + b_seq_len_k: torch.Tensor | None = None, + max_input_len_k: int | None = None, + *, + threshold_trials: list[float] | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Flash attention with multi-threshold skip-softmax sparsity measurement. + + Computes full attention (identical output to dense attention) while + measuring how many KV tiles would be skipped at each threshold in + ``threshold_trials``. No autograd — forward only. + + Args: + q, k, v, b_start_loc, b_seq_len, max_input_len, is_causal, + softmax_scale, b_start_loc_k, b_seq_len_k, max_input_len_k: + Same as :func:`modelopt.torch.kernels.common.attention.attention`. + threshold_trials: List of threshold values to measure sparsity for. + Each value is converted to log2-scaled space for the kernel. + + Returns: + Tuple of (output, sparsity_counters): + - output: ``[total_q_tokens, num_q_heads, head_dim]`` + - sparsity_counters: ``[num_thresholds, 2]`` int64 tensor where + ``[:, 0]`` = total tile evaluations, ``[:, 1]`` = skipped tiles. + Sparsity per threshold = ``counters[:, 1] / counters[:, 0]``. 
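+
+    Example (illustrative; tensors follow the packed layout described above)::
+
+        out, counters = attention_calibrate(
+            q, k, v, b_start_loc, b_seq_len, max_input_len,
+            threshold_trials=[0.05, 0.1, 0.2],
+        )
+        sparsity = counters[:, 1] / counters[:, 0]  # per-threshold skip ratio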
+ """ + if threshold_trials is None or len(threshold_trials) == 0: + raise ValueError("threshold_trials must be a non-empty list") + + HEAD_DIM = q.shape[2] + num_q_heads = q.shape[1] + num_kv_heads = k.shape[1] + kv_group_num = num_q_heads // num_kv_heads + batch = b_seq_len.shape[0] + sm_scale = 1.0 / (HEAD_DIM**0.5) if softmax_scale is None else softmax_scale + qk_scale = sm_scale * LOG2E + BLOCK_D = triton.next_power_of_2(HEAD_DIM) + BLOCK_M = 128 + BLOCK_N = 64 + + if b_seq_len_k is None: + b_seq_len_k = b_seq_len + b_start_loc_k = b_start_loc + + num_thresholds = len(threshold_trials) + + # Convert thresholds to log2-scaled space: log2(lambda) * sm_scale + threshold_tensor = torch.tensor( + [math.log2(t) * sm_scale for t in threshold_trials], + dtype=torch.float32, + device=q.device, + ) + + o = torch.empty_like(q) + + num_q_tiles = triton.cdiv(max_input_len, BLOCK_M) + grid = (batch, num_q_heads, num_q_tiles) + num_programs = batch * num_q_heads * num_q_tiles + + # Per-program output buffers (no atomics needed — each program writes its own row) + per_program_totals = torch.zeros( + num_programs * num_thresholds, dtype=torch.int32, device=q.device + ) + per_program_skipped = torch.zeros( + num_programs * num_thresholds, dtype=torch.int32, device=q.device + ) + + _attn_fwd_calibrate[grid]( + q, + k, + v, + qk_scale, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + o, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + threshold_tensor, + per_program_totals, + per_program_skipped, + kv_group_num=kv_group_num, + BLOCK_M=BLOCK_M, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK_N, + IS_CAUSAL=is_causal, + HEAD_DIM=HEAD_DIM, + NUM_THRESHOLDS=num_thresholds, + PADDED_THRESHOLDS=triton.next_power_of_2(num_thresholds), + num_warps=4, + num_stages=1, + ) + + # Reduce across programs: sum per-program counts → [num_thresholds] + totals = per_program_totals.view(num_programs, num_thresholds).sum(dim=0).to(torch.int64) + skipped = per_program_skipped.view(num_programs, num_thresholds).sum(dim=0).to(torch.int64) + sparsity_counters = torch.stack([totals, skipped], dim=1) # [num_thresholds, 2] + + return o, sparsity_counters diff --git a/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py b/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py index bfc2f73b2a..434c4824f8 100644 --- a/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py +++ b/modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py @@ -36,7 +36,10 @@ attention_backend, ) -from modelopt.torch.kernels.common import attention, attention_calibrate +# ``attention`` and ``attention_calibrate`` are resolved lazily inside the +# call-site functions below. Capturing them at module top-level would fetch +# ``None`` from the partially-loaded ``common.attention`` package during the +# sparsity↔common circular import chain. 
_BACKEND_NAME = "modelopt_triton" _BACKEND_REGISTERED = False @@ -166,6 +169,8 @@ def _diffusers_triton_attention( calib_mode = getattr(_thread_local, "calibration_mode", False) if calib_mode: trials = getattr(_thread_local, "threshold_trials", None) + from modelopt.torch.kernels.common.attention import attention_calibrate + if trials and attention_calibrate is not None: o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials) @@ -196,6 +201,8 @@ def _diffusers_triton_attention( if threshold is not None and threshold > 0.0: kw["skip_softmax_threshold"] = threshold + from modelopt.torch.kernels.common.attention import attention + assert attention is not None, "Triton attention kernel not available (requires CUDA + triton)" do_measure = getattr(_thread_local, "measure_sparsity", False) if do_measure: diff --git a/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py b/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py index 6eba2004d8..90601dc2ca 100644 --- a/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py +++ b/modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py @@ -25,7 +25,10 @@ import torch -from modelopt.torch.kernels.common import attention, attention_calibrate +# ``attention`` and ``attention_calibrate`` are resolved lazily inside the +# call-site functions below. Capturing them at module top-level would fetch +# ``None`` from the partially-loaded ``common.attention`` package during the +# sparsity↔common circular import chain. from modelopt.torch.utils.logging import warn_rank_0 # Thread-local storage for skip-softmax configuration @@ -126,6 +129,8 @@ def _ltx_triton_attention( calib_mode = getattr(_thread_local, "calibration_mode", False) if calib_mode: trials = getattr(_thread_local, "threshold_trials", None) + from modelopt.torch.kernels.common.attention import attention_calibrate + if trials and attention_calibrate is not None: o, counters = attention_calibrate(q_flat, k_flat, v_flat, **kw, threshold_trials=trials) @@ -150,6 +155,8 @@ def _ltx_triton_attention( elif threshold is not None and threshold > 0.0: kw["skip_softmax_threshold"] = threshold + from modelopt.torch.kernels.common.attention import attention + assert attention is not None, "Triton attention kernel not available (requires CUDA + triton)" o = attention(q_flat, k_flat, v_flat, **kw) return o.view(b, seq_q, heads * dim_head) diff --git a/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py new file mode 100644 index 0000000000..f066f9c4b7 --- /dev/null +++ b/modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Skip-softmax / N:M sparse attention helpers. 
+ +These ``@triton.jit`` helpers are called conditionally from the baseline +flash-attention forward kernel in ``common/attention/triton_fa.py`` when the +user requests N:M sparsity or sink/window-aware dense regions. +""" + +import triton +import triton.language as tl + + +# --------------------------------------------------------------------------- +# N:M sparse softmax helpers +# --------------------------------------------------------------------------- +@triton.jit +def _sparse_nm_masks_m4(x0, x1, x2, x3, N: tl.constexpr): + """Top-N of 4 selection via pure boolean logic (6 comparisons, no int casts). + + Uses ``>=`` so that ties are broken by index (lower index wins). + Guarantees exactly N masks are True for any input including all-equal. + + Boolean formulas for "at least K of 3 wins": + K=3 (N=1): AND of all — must beat all 3 others + K=2 (N=2): majority — must beat at least 2 (sorting network) + K=1 (N=3): OR of all — must beat at least 1 + """ + c01 = x0 >= x1 + c02 = x0 >= x2 + c03 = x0 >= x3 + c12 = x1 >= x2 + c13 = x1 >= x3 + c23 = x2 >= x3 + + nc01 = ~c01 + nc02 = ~c02 + nc03 = ~c03 + nc12 = ~c12 + nc13 = ~c13 + nc23 = ~c23 + + if N == 1: + # Keep max only: must beat all 3 + m0 = c01 & c02 & c03 + m1 = nc01 & c12 & c13 + m2 = nc02 & nc12 & c23 + m3 = nc03 & nc13 & nc23 + elif N == 2: + # Majority vote: must beat at least 2 of 3 + m0 = (c01 & c02) | (c01 & c03) | (c02 & c03) + m1 = (nc01 & c12) | (nc01 & c13) | (c12 & c13) + m2 = (nc02 & nc12) | (nc02 & c23) | (nc12 & c23) + m3 = (nc03 & nc13) | (nc03 & nc23) | (nc13 & nc23) + elif N == 3: + # Keep all but min: must beat at least 1 + m0 = c01 | c02 | c03 + m1 = nc01 | c12 | c13 + m2 = nc02 | nc12 | c23 + m3 = nc03 | nc13 | nc23 + else: + tl.static_assert(False, "N must be 1, 2, or 3 for M=4") + + return m0, m1, m2, m3 + + +@triton.jit +def _apply_sparse_nm_to_qk_tile( + qk, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + SPARSITY_N: tl.constexpr, + SPARSITY_M: tl.constexpr, +): + """Apply N:M sparse softmax to a QK score tile. + + For every ``SPARSITY_M`` consecutive elements along the N (key) dimension, + keeps the top ``SPARSITY_N`` values and sets the rest to ``-inf``. + ``BLOCK_N`` must be divisible by ``SPARSITY_M``. + + For M=4, exactly N values are retained (ties broken by position). + For M=8, a threshold-based approach (``tl.sort``) may retain more + than N values when ties straddle the threshold boundary. 
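+
+    Illustrative example (not from a real run): with ``SPARSITY_N=2`` and
+    ``SPARSITY_M=4``, a group of scores ``[3.0, 1.0, 4.0, 2.0]`` becomes
+    ``[3.0, -inf, 4.0, -inf]``; only the two largest values in each group of
+    four survive the softmax.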
+ """ + tl.static_assert(SPARSITY_M == 4 or SPARSITY_M == 8, "SPARSITY_M must be 4 or 8") # noqa: PLR1714 + MASK_VAL: tl.constexpr = float("-inf") + + if SPARSITY_M == 4: + tl.static_assert(BLOCK_N % 4 == 0, "BLOCK_N must be divisible by 4") + reshaped = tl.reshape(qk, (BLOCK_M, BLOCK_N // 4, 4)) + cols = tl.arange(0, 4)[None, None, :] + x0 = tl.sum(tl.where(cols == 0, reshaped, 0.0), axis=2) + x1 = tl.sum(tl.where(cols == 1, reshaped, 0.0), axis=2) + x2 = tl.sum(tl.where(cols == 2, reshaped, 0.0), axis=2) + x3 = tl.sum(tl.where(cols == 3, reshaped, 0.0), axis=2) + + m0, m1, m2, m3 = _sparse_nm_masks_m4(x0, x1, x2, x3, SPARSITY_N) + + out = tl.full((BLOCK_M, BLOCK_N // 4, 4), 0.0, dtype=qk.dtype) + out = tl.where(cols == 0, tl.expand_dims(tl.where(m0, x0, MASK_VAL), 2), out) + out = tl.where(cols == 1, tl.expand_dims(tl.where(m1, x1, MASK_VAL), 2), out) + out = tl.where(cols == 2, tl.expand_dims(tl.where(m2, x2, MASK_VAL), 2), out) + out = tl.where(cols == 3, tl.expand_dims(tl.where(m3, x3, MASK_VAL), 2), out) + return tl.reshape(out, (BLOCK_M, BLOCK_N)) + + else: # SPARSITY_M == 8 + tl.static_assert(BLOCK_N % 8 == 0, "BLOCK_N must be divisible by 8") + reshaped = tl.reshape(qk, (BLOCK_M, BLOCK_N // 8, 8)) + + # Sort each group of 8 ascending; N-th largest is at index (8 - N) + sorted_vals = tl.sort(reshaped, dim=2) + KTH_IDX: tl.constexpr = SPARSITY_M - SPARSITY_N # index of N-th largest in ascending order + + # Extract the threshold value at KTH_IDX via masked sum + # Use 0.0 as fill (not -inf) so sum equals just the KTH element + cols = tl.arange(0, 8)[None, None, :] + threshold = tl.sum(tl.where(cols == KTH_IDX, sorted_vals, 0.0), axis=2) + + # Mask: keep elements >= threshold (may keep >N on ties — acceptable) + mask = reshaped >= tl.expand_dims(threshold, 2) + return tl.reshape(tl.where(mask, reshaped, MASK_VAL), (BLOCK_M, BLOCK_N)) + + +# --------------------------------------------------------------------------- +# BLASST skip-softmax per-tile decision +# --------------------------------------------------------------------------- +@triton.jit +def _skip_softmax_decision( + scores, + row_max, + SKIP_THRESHOLD_LOG2: tl.constexpr, + Sparsity_total, + Sparsity_skipped, + MEASURE_SPARSITY: tl.constexpr, +): + """BLASST skip-softmax per-tile decision (https://arxiv.org/pdf/2512.12087). + + During FlashAttention's block-wise computation we maintain a running + maximum ``m_i^(j)`` across blocks. If a block's local maximum + ``~m_i^(j)`` is significantly smaller than the running maximum + (``~m_i^(j) - m_i^(j) < ln(lambda)``), then ``exp(~m_i^(j) - m_i^(j)) + < lambda ~= 0`` and the block's contribution to the output is negligible. + The caller may then skip the softmax computation, V load, and BMM2. + + The threshold is pre-scaled to log2 space by the Python wrapper so it can + be compared directly against the already-scaled scores. + + Returns: + True when *all* Q rows in the tile satisfy the skip criterion. + + When ``MEASURE_SPARSITY`` is set, also records total/skipped tile counts + via atomic adds on ``Sparsity_total`` / ``Sparsity_skipped``. 
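+
+    Worked example (illustrative numbers): with ``lambda = 0.1`` the cutoff is
+    ``ln(0.1) ≈ -2.3``. If every row's tile maximum sits more than 2.3 below its
+    running maximum, each unnormalized softmax weight in the tile is below 0.1,
+    so the tile is treated as negligible and skipped.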
+ """ + tile_row_max = tl.max(scores, 1) # [BLOCK_M] — ~m_i^(j) (scaled) + # Per-row: True if row's tile max is negligible vs running max + can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2) + # Per-tile: skip entire tile only if ALL rows are negligible + skip_tile = tl.min(can_skip.to(tl.int32)) == 1 + + if MEASURE_SPARSITY: + tl.atomic_add(Sparsity_total, 1) # count every tile + if skip_tile: + tl.atomic_add(Sparsity_skipped, 1) # count skipped tiles + + return skip_tile + + +# --------------------------------------------------------------------------- +# Sink/window dense-region check +# --------------------------------------------------------------------------- +@triton.jit +def _is_dense_region( + kv_start, + tile_q, + seq_len_q, + seq_len_kv, + BLOCK_M: tl.constexpr, + NUM_SINK_TOKENS: tl.constexpr, + DENSE_WINDOW_SIZE: tl.constexpr, +): + """Check if a KV tile falls in a dense region (sink tokens or local window). + + Uses absolute token positions so the result is BLOCK_N-independent, + ensuring forward and backward (which may use different BLOCK_N) agree. + + Returns: + True if the tile should be kept dense (skip N:M sparsification). + """ + is_sink = kv_start < NUM_SINK_TOKENS + causal_offset = seq_len_kv - seq_len_q + q_abs_pos = tile_q * BLOCK_M + causal_offset + token_distance = q_abs_pos - kv_start + is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE) + return is_sink or is_local diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_diffusers_triton_attention.py b/tests/gpu/torch/sparsity/attention_sparsity/test_diffusers_triton_attention.py index dd8a265bd4..54f66a279e 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_diffusers_triton_attention.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_diffusers_triton_attention.py @@ -26,7 +26,7 @@ diffusers = pytest.importorskip("diffusers") -from modelopt.torch.kernels.common import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE from modelopt.torch.kernels.sparsity.attention import diffusers_triton_attention as diffusers_mod from modelopt.torch.kernels.sparsity.attention import ltx_triton_attention as ltx_mod diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py index 2d0798ade7..7fc3a554c7 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py @@ -26,10 +26,10 @@ pytest.mark.filterwarnings("ignore::DeprecationWarning"), ] -from modelopt.torch.kernels.common import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE if TRITON_KERNEL_AVAILABLE: - from modelopt.torch.kernels.common import attention, register_triton_attention + from modelopt.torch.kernels.common.attention import attention, register_triton_attention if register_triton_attention is not None: register_triton_attention() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_calibrate.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_calibrate.py index babf7283bd..eaa1f5e325 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_calibrate.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_calibrate.py @@ -29,10 +29,10 @@ pytest.mark.filterwarnings("ignore::DeprecationWarning"), ] -from modelopt.torch.kernels.common 
import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE if TRITON_KERNEL_AVAILABLE: - from modelopt.torch.kernels.common import attention, attention_calibrate + from modelopt.torch.kernels.common.attention import attention, attention_calibrate @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_skip_softmax.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_skip_softmax.py index 7f0d1edf72..56f0a9e9d8 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_skip_softmax.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_skip_softmax.py @@ -25,10 +25,10 @@ pytest.mark.filterwarnings("ignore::DeprecationWarning"), ] -from modelopt.torch.kernels.common import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE if TRITON_KERNEL_AVAILABLE: - from modelopt.torch.kernels.common import attention, register_triton_attention + from modelopt.torch.kernels.common.attention import attention, register_triton_attention if register_triton_attention is not None: register_triton_attention() diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py index dab67f28f8..ff215a6ff8 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py @@ -27,14 +27,16 @@ pytest.mark.filterwarnings("ignore::DeprecationWarning"), ] -from modelopt.torch.kernels.common import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE if TRITON_KERNEL_AVAILABLE: import triton import triton.language as tl - from modelopt.torch.kernels.common import attention - from modelopt.torch.kernels.common.triton_fa import _apply_sparse_nm_to_qk_tile + from modelopt.torch.kernels.common.attention import attention + from modelopt.torch.kernels.sparsity.attention.skip_softmax_helpers import ( + _apply_sparse_nm_to_qk_tile, + ) @triton.jit def _test_apply_sparse_nm( diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py b/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py index cbc38210fb..0c267ee212 100644 --- a/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py @@ -33,7 +33,7 @@ diffusers = pytest.importorskip("diffusers") -from modelopt.torch.kernels.common import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE +from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE if TRITON_KERNEL_AVAILABLE: import modelopt.torch.sparsity.attention_sparsity as mtsa diff --git a/tests/unit/torch/kernels/test_triton_fa.py b/tests/unit/torch/kernels/test_triton_fa.py index 1770b103e7..6969ae0a0e 100644 --- a/tests/unit/torch/kernels/test_triton_fa.py +++ b/tests/unit/torch/kernels/test_triton_fa.py @@ -33,7 +33,8 @@ def test_triton_fa_importable_on_cpu(): except ImportError: pytest.skip("triton is not installed") - from modelopt.torch.kernels.common import triton_fa + from modelopt.torch.kernels.common.attention import triton_fa + from modelopt.torch.kernels.sparsity.attention import calibrate assert "attention" 
in triton_fa.__all__ - assert "attention_calibrate" in triton_fa.__all__ + assert callable(calibrate.attention_calibrate) From bc81615ed3c1d892ac4b1d98539f1fde51d8717e Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Tue, 21 Apr 2026 21:31:52 +0000 Subject: [PATCH 3/9] Update test case path Signed-off-by: Jingyu Xin --- .../common/attention}/test_triton_fa.py | 0 .../torch/{sparsity/attention_sparsity => kernels}/conftest.py | 0 .../kernels => kernels/quantization/conv}/test_implicit_gemm.py | 0 .../sparsity/attention}/test_diffusers_triton_attention.py | 0 .../sparsity/attention}/test_triton_fa_calibrate.py | 0 .../sparsity/attention}/test_triton_fa_skip_softmax.py | 0 .../sparsity/attention}/test_triton_fa_sparse_nm.py | 0 tests/unit/torch/kernels/{ => common/attention}/test_triton_fa.py | 0 .../sparsity/attention}/test_kernel_backends.py | 0 .../sparsity/attention}/test_ltx_triton_attention.py | 0 10 files changed, 0 insertions(+), 0 deletions(-) rename tests/gpu/torch/{sparsity/attention_sparsity => kernels/common/attention}/test_triton_fa.py (100%) rename tests/gpu/torch/{sparsity/attention_sparsity => kernels}/conftest.py (100%) rename tests/gpu/torch/{quantization/kernels => kernels/quantization/conv}/test_implicit_gemm.py (100%) rename tests/gpu/torch/{sparsity/attention_sparsity => kernels/sparsity/attention}/test_diffusers_triton_attention.py (100%) rename tests/gpu/torch/{sparsity/attention_sparsity => kernels/sparsity/attention}/test_triton_fa_calibrate.py (100%) rename tests/gpu/torch/{sparsity/attention_sparsity => kernels/sparsity/attention}/test_triton_fa_skip_softmax.py (100%) rename tests/gpu/torch/{sparsity/attention_sparsity => kernels/sparsity/attention}/test_triton_fa_sparse_nm.py (100%) rename tests/unit/torch/kernels/{ => common/attention}/test_triton_fa.py (100%) rename tests/unit/torch/{sparsity/attention_sparsity => kernels/sparsity/attention}/test_kernel_backends.py (100%) rename tests/unit/torch/{sparsity/attention_sparsity => kernels/sparsity/attention}/test_ltx_triton_attention.py (100%) diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py b/tests/gpu/torch/kernels/common/attention/test_triton_fa.py similarity index 100% rename from tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py rename to tests/gpu/torch/kernels/common/attention/test_triton_fa.py diff --git a/tests/gpu/torch/sparsity/attention_sparsity/conftest.py b/tests/gpu/torch/kernels/conftest.py similarity index 100% rename from tests/gpu/torch/sparsity/attention_sparsity/conftest.py rename to tests/gpu/torch/kernels/conftest.py diff --git a/tests/gpu/torch/quantization/kernels/test_implicit_gemm.py b/tests/gpu/torch/kernels/quantization/conv/test_implicit_gemm.py similarity index 100% rename from tests/gpu/torch/quantization/kernels/test_implicit_gemm.py rename to tests/gpu/torch/kernels/quantization/conv/test_implicit_gemm.py diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_diffusers_triton_attention.py b/tests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.py similarity index 100% rename from tests/gpu/torch/sparsity/attention_sparsity/test_diffusers_triton_attention.py rename to tests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.py diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_calibrate.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py similarity index 100% rename from tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_calibrate.py rename to 
tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_skip_softmax.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py similarity index 100% rename from tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_skip_softmax.py rename to tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py b/tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_sparse_nm.py similarity index 100% rename from tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py rename to tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_sparse_nm.py diff --git a/tests/unit/torch/kernels/test_triton_fa.py b/tests/unit/torch/kernels/common/attention/test_triton_fa.py similarity index 100% rename from tests/unit/torch/kernels/test_triton_fa.py rename to tests/unit/torch/kernels/common/attention/test_triton_fa.py diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py b/tests/unit/torch/kernels/sparsity/attention/test_kernel_backends.py similarity index 100% rename from tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py rename to tests/unit/torch/kernels/sparsity/attention/test_kernel_backends.py diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_ltx_triton_attention.py b/tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py similarity index 100% rename from tests/unit/torch/sparsity/attention_sparsity/test_ltx_triton_attention.py rename to tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py From a7125fdea1705da0ae8a02812f9e403e4a874892 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Tue, 21 Apr 2026 22:56:33 +0000 Subject: [PATCH 4/9] Fix the CI/CD Signed-off-by: Jingyu Xin --- tests/gpu/torch/quantization/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gpu/torch/quantization/conftest.py b/tests/gpu/torch/quantization/conftest.py index 9e34e5ef68..c6f9d23b6c 100644 --- a/tests/gpu/torch/quantization/conftest.py +++ b/tests/gpu/torch/quantization/conftest.py @@ -16,7 +16,7 @@ import pytest -from modelopt.torch.quantization import triton as triton_kernel +from modelopt.torch.kernels.quantization import gemm as triton_kernel @pytest.fixture(autouse=True) From 004f966738ce2233f1e2497216b968d0db3ba219 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Tue, 21 Apr 2026 23:19:46 +0000 Subject: [PATCH 5/9] minor update on the doc Signed-off-by: Jingyu Xin --- CHANGELOG.rst | 6 +++--- CLAUDE.md | 1 + examples/diffusers/README.md | 2 +- modelopt/torch/kernels/quantization/gemm/__init__.py | 1 - 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 20a677d0a0..e1a5aacca1 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,14 +9,14 @@ Changelog - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics. - Add Puzzletron - a new algorithm for heterogeneous pruning of LLM and VLM models. See `examples/puzzletron/README.md `_ for more details. - Added iterator interface using CalibrationDataReader in ONNX quantization workflow. 
-- Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. -- Add skip-softmax skipping to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. +- Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.common.attention.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. +- Add skip-softmax skipping to the Triton flash attention kernel (``modelopt.torch.kernels.common.attention.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage. - Add Video Sparse Attention (VSA) method for video diffusion models (``modelopt.torch.sparsity.attention_sparsity``). VSA uses 3D block tiling with a two-branch architecture for attention speedup. - Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml `_ for more details. - Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md `_ for more details. - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution. - Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml `_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml `_ for usage. -- Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.quantization.src.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning. +- Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.kernels.quantization.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning. 
**Backward Breaking Changes** diff --git a/CLAUDE.md b/CLAUDE.md index 4af3858678..d0b47148c3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -65,6 +65,7 @@ ModelOpt code base is organized into four top-level namespaces: | `nas` | `modelopt/torch/nas/` | Neural architecture search | | `export` | `modelopt/torch/export/` | Checkpoint export for TRT-LLM / Megatron | | `peft` | `modelopt/torch/peft/` | QLoRA and PEFT integration | +| `kernels` | `modelopt/torch/kernels/` | Custom CUDA/Triton kernels grouped by role: `common/attention` (baseline Triton FA), `quantization/{conv,gemm}` (implicit-GEMM CUDA + tensor-quant C++/CUDA + fp4/fp8 Triton), `sparsity/attention` (skip-softmax / N:M / diffusers+LTX backends) | | `_deploy` | `modelopt/torch/_deploy/` | Internal deployment utilities | | `utils` | `modelopt/torch/utils/` | Shared utilities and plugin infrastructure | diff --git a/examples/diffusers/README.md b/examples/diffusers/README.md index ac14d98227..8e0f7cdef9 100644 --- a/examples/diffusers/README.md +++ b/examples/diffusers/README.md @@ -119,7 +119,7 @@ python quantize.py \ #### Wan 2.2 VAE NVFP4 (Conv3D Implicit GEMM) -The Wan 2.2 VAE (`AutoencoderKLWan`, shared between the 5B and 14B pipelines) is built from 3D convolutions. When quantizing the VAE with NVFP4, the `Conv3d` layers are automatically dispatched through a custom BF16 WMMA implicit-GEMM kernel with fused FP4 activation quantization. Requires SM80+ (Ampere or newer). See [`modelopt/torch/quantization/src/conv/README.md`](../../modelopt/torch/quantization/src/conv/README.md) for kernel details. +The Wan 2.2 VAE (`AutoencoderKLWan`, shared between the 5B and 14B pipelines) is built from 3D convolutions. When quantizing the VAE with NVFP4, the `Conv3d` layers are automatically dispatched through a custom BF16 WMMA implicit-GEMM kernel with fused FP4 activation quantization. Requires SM80+ (Ampere or newer). See [`modelopt/torch/kernels/quantization/conv/README.md`](../../modelopt/torch/kernels/quantization/conv/README.md) for kernel details. ```sh python quantize.py \ diff --git a/modelopt/torch/kernels/quantization/gemm/__init__.py b/modelopt/torch/kernels/quantization/gemm/__init__.py index def70e5914..39b07b4faa 100644 --- a/modelopt/torch/kernels/quantization/gemm/__init__.py +++ b/modelopt/torch/kernels/quantization/gemm/__init__.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # From a849d88ab66bf446c05e4e28f3b91202f90f41b2 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 22 Apr 2026 07:27:03 +0000 Subject: [PATCH 6/9] Update the VSA support for WAN and LTX Signed-off-by: Jingyu Xin --- examples/diffusers/sparsity/README.md | 216 ++++++++- examples/diffusers/sparsity/ltx2_vsa.py | 276 ++++++++++++ ...2_skip_softmax.py => wan22_sparse_attn.py} | 326 ++++++++++---- .../attention_sparsity/methods/vsa.py | 18 +- .../attention_sparsity/plugins/__init__.py | 21 +- .../attention_sparsity/plugins/huggingface.py | 2 +- .../attention_sparsity/plugins/ltx2.py | 413 ++++++++++++++++++ .../attention_sparsity/plugins/wan22.py | 180 ++++++++ 8 files changed, 1328 insertions(+), 124 deletions(-) create mode 100644 examples/diffusers/sparsity/ltx2_vsa.py rename examples/diffusers/sparsity/{wan22_skip_softmax.py => wan22_sparse_attn.py} (57%) create mode 100644 modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py create mode 100644 modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py diff --git a/examples/diffusers/sparsity/README.md b/examples/diffusers/sparsity/README.md index dc44fbcd17..31620ec7d7 100644 --- a/examples/diffusers/sparsity/README.md +++ b/examples/diffusers/sparsity/README.md @@ -1,4 +1,4 @@ -# Skip-Softmax Sparse Attention for Diffusion Models +# Sparse Attention for Diffusion Models > [!WARNING] > **Third-Party License Notice — LTX-2** @@ -13,11 +13,32 @@ > fine-tuned weights produced from LTX-2 (including quantized, distilled, or sparsified > checkpoints) remain subject to the LTX Community License Agreement, not Apache 2.0. -Skip-softmax sparse attention (BLASST, ) skips KV -tiles whose attention scores are negligible during the FlashAttention computation, -reducing FLOPs without retraining. +Two sparse-attention methods are supported under +`modelopt.torch.sparsity.attention_sparsity` (`mtsa`): + +| Method | When to use | Calibration | +|--------|-------------|-------------| +| **Skip-Softmax** (BLASST) | Drop low-impact KV tiles inside FlashAttention. Works on any transformer with bidirectional attention. | Optional (exponential model) | +| **VSA** (Video Sparse Attention) | Block-level two-branch attention tuned for video models with long 3D token sequences. | None (fixed `top_k_ratio`) | + +Switching between methods is a CLI/config change — the pipelines, APIs, +and plugins are shared. + +| Model | Script | Methods | +|-------|--------|---------| +| Wan 2.2 5B / 14B | `wan22_sparse_attn.py` | `--method skip_softmax` (default), `--method vsa` | +| LTX-2 | `ltx2_vsa.py` | VSA only (LTX-2 uses a custom attention module; skip-softmax backend in progress) | + +--- + +## Skip-Softmax Sparse Attention + +Skip-softmax (BLASST, ) skips KV tiles whose attention +scores are negligible during the FlashAttention computation, reducing FLOPs without +retraining. + +Two threshold modes are supported: -Two modes are supported: - **Fixed raw threshold** — pass a log2-space threshold directly to the Triton kernel. No calibration needed. Good for quick testing and sweeps. - **Calibrated threshold** — an exponential model @@ -26,43 +47,38 @@ Two modes are supported: without recalibration. Log-space fitting (`fit_logspace=True`) is recommended for diffusion models where scale_factors span many orders of magnitude. 
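+In code, the calibrated path reduces to a single conversion chain. A minimal
+sketch, assuming `a` and `b` are the fitted exponential-model coefficients
+(the function name is illustrative, not part of the modelopt API):
+
+```python
+import math
+
+
+def raw_threshold_log2(
+    a: float, b: float, target_sparsity: float, seq_k: int, sm_scale: float
+) -> float:
+    scale_factor = a * math.exp(b * target_sparsity)  # fitted once during calibration
+    threshold = scale_factor / seq_k                  # backend: adapt to the KV length
+    return math.log2(threshold) * sm_scale            # kernel: log2-space skip threshold
+```
+
+Changing `target_sparsity` at runtime only re-evaluates this chain, so no
+recalibration is needed; this is what `--target-sparsity` exposes.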
-## Supported Models - -| Model | Script | Notes | -|-------|--------|-------| -| WAN 2.2 5B | `wan22_skip_softmax.py` | Single transformer, self-attention only | -| WAN 2.2 14B | `wan22_skip_softmax.py` | Dual transformer (auto-detected) | -| LTX-2 | (coming soon) | Via `ltx_triton_attention.py` backend | - -## Quick Start +### Quick Start (Wan 2.2) ```bash # Fixed raw threshold (no calibration, fast) -python wan22_skip_softmax.py \ +python wan22_sparse_attn.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ --raw-threshold -0.7 \ --prompt "A cat playing piano" --output out.mp4 # With calibration -python wan22_skip_softmax.py \ +python wan22_sparse_attn.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ --calibrate --target-sparsity 0.5 \ --prompt "A cat playing piano" --output out.mp4 # Dense baseline (no sparsity, for comparison) -python wan22_skip_softmax.py \ +python wan22_sparse_attn.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ --baseline \ --prompt "A cat playing piano" --output baseline.mp4 # Report runtime sparsity (per-layer tile skip ratios) -python wan22_skip_softmax.py \ +python wan22_sparse_attn.py \ --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ --raw-threshold -0.7 --report-avg-sparsity \ --prompt "A cat playing piano" --output out.mp4 ``` -## Threshold Modes +`--method skip_softmax` is the default, so it doesn't need to be passed +explicitly when using skip-softmax flags. + +### Threshold Modes | Mode | How threshold reaches the kernel | Use case | |------|----------------------------------|----------| @@ -70,7 +86,163 @@ python wan22_skip_softmax.py \ | **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`, then backend computes `threshold = scale_factor / seq_k`, then kernel converts `log2(threshold) * sm_scale` | Production use with automatic seqlen adaptation | | **Static lambda** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback when neither raw nor calibrated | -## Known Issues +### Known Issues + +- **14B dual transformer calibration**: Transformers are calibrated sequentially — + transformer_2's calibration runs while transformer_1 is already sparsified, + introducing asymmetric calibration conditions. +- **Minimum achievable sparsity**: Even the strictest threshold may yield 30–40% + sparsity on diffusion models (many tiles are inherently negligible). Targets + below this floor cause extrapolation; an inference-time warning is emitted. + +--- + +## Video Sparse Attention (VSA) + +VSA is a two-branch sparse attention architecture tailored for video diffusion +models: + +1. **Compression branch** — averages tokens within a 3D block (default `4,4,4` = + 64 tokens) and computes coarse-grained block-level attention for global context. +2. **Sparse branch** — selects the top-K most important blocks by the compression + branch's attention scores and computes fine-grained attention only on those. +3. **Gate blend** — `output = compression * gate_compress + sparse`. On models + without a learned `gate_compress` (Wan 2.2, and LTX-2 until fine-tuned), VSA + passes a zero tensor so `output = 0 * compression + sparse = sparse`. This + makes VSA at `top_k_ratio=1.0` (keep all blocks) mathematically equivalent to + dense attention, modulo `bfloat16` kernel rounding (~10⁻⁵ per call on a 75k + token sequence). + +VSA is **calibration-free** — sparsity is controlled by a fixed `top_k_ratio` +(`0.5` keeps 50% of blocks, `0.3` keeps 30%). 
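+The two-branch flow can be sketched in plain PyTorch. This is a conceptual
+reference only, assuming flattened 1D blocks of 64 tokens; the real
+implementation is the Triton kernel in `fastvideo_kernel`, which handles the
+3D `(T, H, W)` tiling and padding, and the helper name below is hypothetical:
+
+```python
+import torch
+
+
+def vsa_reference(q, k, v, gate_compress=None, block=64, top_k_ratio=0.5):
+    """q, k, v: [batch, heads, seq, head_dim]; seq must be a multiple of `block`."""
+    b, h, s, d = q.shape
+    n = s // block
+    # Compression branch: mean-pool each block, then coarse block-level attention.
+    qc, kc, vc = (t.view(b, h, n, block, d).mean(dim=3) for t in (q, k, v))
+    block_scores = torch.softmax(qc @ kc.transpose(-1, -2) / d**0.5, dim=-1)
+    out_c = (block_scores @ vc).repeat_interleave(block, dim=2)  # back to token level
+    # Sparse branch: per query block, keep only the top-K KV blocks.
+    keep = max(1, int(top_k_ratio * n))
+    idx = block_scores.topk(keep, dim=-1).indices
+    block_mask = torch.zeros_like(block_scores, dtype=torch.bool)
+    block_mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))
+    mask = block_mask.repeat_interleave(block, dim=-1).repeat_interleave(block, dim=-2)
+    scores = (q @ k.transpose(-1, -2) / d**0.5).masked_fill(~mask, float("-inf"))
+    out_s = torch.softmax(scores, dim=-1) @ v
+    # Gate blend: with no learned gate the compression branch contributes nothing,
+    # so keeping all blocks (top_k_ratio=1.0) falls back to dense attention.
+    gate = gate_compress if gate_compress is not None else torch.zeros_like(out_c)
+    return gate * out_c + out_s
+```
+
+The zero-gate default is what makes the `top_k_ratio=1.0` sanity check against
+dense attention (described for Wan 2.2 below) meaningful.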
+ +### Quick Start + +```bash +# Wan 2.2 — VSA with default 50% top-K ratio (video_shape auto-derived) +python wan22_sparse_attn.py --method vsa \ + --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ + --top-k-ratio 0.5 \ + --prompt "A cat playing piano" --output vsa.mp4 + +# Wan 2.2 — aggressive 30% top-K (70% sparsity), keep first/last 2 layers dense +python wan22_sparse_attn.py --method vsa \ + --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ + --top-k-ratio 0.3 --skip-first-last 2 \ + --prompt "A cat playing piano" --output vsa.mp4 + +# Wan 2.2 — 720p+ / 81+ frames can OOM during VAE decode since VSA reserves +# ~15 GB of GPU memory for its tile buffers. Enable VAE tiling to recover. +python wan22_sparse_attn.py --method vsa \ + --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ + --top-k-ratio 0.5 --enable-vae-tiling \ + --num-frames 81 --height 720 --width 1280 \ + --prompt "A cat playing piano" --output vsa.mp4 + +# LTX-2 — VSA +python ltx2_vsa.py \ + --checkpoint /path/to/ltx2.safetensors \ + --text-encoder-path /path/to/gemma \ + --top-k-ratio 0.5 \ + --prompt "A cat playing piano" --output vsa.mp4 + +# LTX-2 — baseline (no VSA) +python ltx2_vsa.py \ + --checkpoint /path/to/ltx2.safetensors \ + --text-encoder-path /path/to/gemma \ + --no-vsa --output baseline.mp4 +``` + +### Requirements + +- `fastvideo_kernel` at runtime (the Triton VSA kernel). Install with + `pip install fastvideo_kernel`. VSA imports this lazily, so the modelopt + sparsity API loads without it, but a VSA forward will raise a clear + `ImportError` if missing. +- For LTX-2 only: `ltx_core`, `ltx_trainer`, `ltx_pipelines` (see LICENSE + notice above). + +### Programmatic API + +```python +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.config import VSA_DEFAULT + +# Apply with the pre-built default (50% top-K, self-attention only) +transformer = mtsa.sparsify(transformer, VSA_DEFAULT) + +# Or with a custom config +config = { + "sparse_cfg": { + "*.attn1*": { + "method": "vsa", + "block_size_3d": (4, 4, 4), # 3D tile (T, H, W); 64 tokens per block + "top_k_ratio": 0.3, # 70% sparsity + "enable": True, + # "video_shape": (T, H, W), # optional; auto-derived by the plugin + }, + "*.attn2*": {"enable": False}, # skip cross-attention + "default": {"enable": False}, + }, +} +transformer = mtsa.sparsify(transformer, config) +``` + +### Configuration Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `block_size_3d` | `(4, 4, 4)` | Video tile dims (T, H, W) — default creates 64-token blocks | +| `top_k_ratio` | `0.5` | Fraction of blocks kept in the sparse branch (0 < ratio ≤ 1). `1.0` keeps all blocks = degenerate dense mode | +| `video_shape` | `None` | Post-patchify video shape (T, H, W). Plugins auto-derive it — set explicitly only to override. | +| `enable` | `True` | Per-layer toggle | + +### How VSA Routes Through the Sparse-Attention API + +- **Wan 2.2** uses diffusers' `WanAttention` whose processor calls + `F.scaled_dot_product_attention` — VSA's SDPA patch in + `SparseAttentionModule._forward_with_vsa_sdpa_patch` intercepts that call and + replaces the computation with the Triton VSA kernel. The Wan 2.2 plugin + registers a forward pre-hook that reads `hidden_states.shape = (B, C, T, H, W)` + and sets `video_shape = (T // p_t, H // p_h, W // p_w)` on each VSA method + instance before the transformer runs. 
+- **LTX-2** uses its native `LTXSelfAttention` whose forward signature is + `(x, context, pe, k_pe)` and does **not** call `F.scaled_dot_product_attention`. + The LTX-2 plugin installs a `_LTX2SparseAttention` wrapper that computes + Q/K/V (with LTX-2's RMSNorm and `ltx_core` RoPE), an optional trainable + `gate_compress` (zero-init), and then calls `VSA.forward_attention` directly. + A forward pre-hook on the root `LTXModel` extracts `video_shape` from + `Modality.positions`. +- Cross-attention is detected via Q/K sequence-length mismatch and falls + through to the original attention path (no behaviour change). + +### Verifying the Setup on Wan 2.2 + +A good sanity check is to compare `top_k_ratio=1.0` to the dense baseline — +since VSA without a learned gate becomes pure sparse attention and a full +mask is mathematically equivalent to dense, the two outputs should be close. +On a Wan 2.2 14B run at 720×1280 / 81 frames / 40 steps we measured: + +| Comparison | First-frame PSNR | +|---|---| +| baseline vs baseline w/ VAE tiling | 40.5 dB | +| baseline vs VSA `top_k_ratio=1.0` | 23.9 dB | +| baseline vs VSA `top_k_ratio=0.5` | 13.1 dB | + +The ~24 dB degrade at `top_k=1.0` is error accumulation (6400 attention +calls × bf16 rounding through the denoising loop) — a single-call PSNR vs +dense SDPA is 50 dB on random inputs. + +### Known Limits -- **14B dual transformer calibration**: Transformers are calibrated sequentially — transformer_2's calibration runs while transformer_1 is already sparsified, introducing asymmetric calibration conditions. -- **Minimum achievable sparsity**: Even the strictest threshold may yield 30-40% sparsity on diffusion models (many tiles are inherently negligible). Targets below this floor cause extrapolation; an inference-time warning is emitted. +- **Peak memory on 720p+**: VSA's tile buffers reserve ~15 GB of GPU memory + on top of the transformer, which can OOM the one-shot VAE decode at 720p / + 81 frames. Pass `--enable-vae-tiling` (or set + `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`) to recover. +- **Token count ≥ 16 tiles (≈1024 tokens)**: VSA's setup overhead dominates for + tiny sequences. For LTX-2, use ≥121 frames at ≥512×768 for meaningful speedups. +- **Mixing with skip-softmax**: VSA patches SDPA globally per module, while + skip-softmax needs `attn_implementation="eager"`. `conversion.py` rejects + configs that mix the two — run them separately. +- **Training**: `to_gate_compress` is zero-initialised and trainable, but no + training loop is wired up yet. This example covers inference only. diff --git a/examples/diffusers/sparsity/ltx2_vsa.py b/examples/diffusers/sparsity/ltx2_vsa.py new file mode 100644 index 0000000000..87db1faf01 --- /dev/null +++ b/examples/diffusers/sparsity/ltx2_vsa.py @@ -0,0 +1,276 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""LTX-2 inference with Video Sparse Attention (VSA). + +Applies VSA to LTX-2's self-attention modules. VSA is calibration-free — +sparsity is controlled via ``top_k_ratio`` (fraction of 3D blocks kept in +the sparse branch). + +The LTX-2 plugin under ``modelopt.torch.sparsity.attention_sparsity.plugins.ltx2`` +handles the specifics: + +- Detects ``LTXSelfAttention`` modules by class name. +- Computes ``(T, H, W)`` from ``Modality.positions`` at each forward. +- Wraps each attention module in ``_LTX2SparseAttention``, which computes + Q/K/V, RoPE, and the optional (zero-initialised, trainable) + ``gate_compress`` before calling ``VSA.forward_attention``. + +Requirements: +- ``fastvideo_kernel`` (Triton VSA kernel). +- ``ltx_core``, ``ltx_trainer``, ``ltx_pipelines`` (third-party LTX-2 packages + from Lightricks — see the LICENSE notice in the top-level sparsity README). + +Example:: + + # VSA at 50% top-K ratio + python ltx2_vsa.py --checkpoint path/to/model.safetensors \\ + --text-encoder-path path/to/gemma --top-k-ratio 0.5 \\ + --prompt "A cat playing piano" --output vsa.mp4 + + # Baseline (no VSA) + python ltx2_vsa.py --checkpoint path/to/model.safetensors \\ + --text-encoder-path path/to/gemma --no-vsa --output baseline.mp4 +""" + +import argparse +import copy +import time +from pathlib import Path + +import torch + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.config import VSA_DEFAULT + +# LTX-2 is optional; import lazily so --help works even without it. +try: + from ltx_trainer.model_loader import load_model + from ltx_trainer.progress import StandaloneSamplingProgress + from ltx_trainer.validation_sampler import GenerationConfig, ValidationSampler + from ltx_trainer.video_utils import save_video + + _LTX_AVAILABLE = True +except ImportError as _exc: + _LTX_IMPORT_ERROR = _exc + _LTX_AVAILABLE = False + + +# LTX-2 uses a 1:8192 pixels-to-tokens compression ratio +LTX2_PIXEL_TO_TOKEN_RATIO = 8192 + +# VSA 3D block size: 4x4x4 = 64 tokens per block +VSA_BLOCK_ELEMENTS = 64 + + +def calculate_expected_tokens(num_frames: int, height: int, width: int) -> int: + return num_frames * height * width // LTX2_PIXEL_TO_TOKEN_RATIO + + +def is_vsa_compatible(num_frames: int, height: int, width: int) -> tuple[bool, str]: + """Check whether the requested input size is large enough for VSA to help.""" + tokens = calculate_expected_tokens(num_frames, height, width) + tiles = tokens // VSA_BLOCK_ELEMENTS + if tiles >= 90: + return True, f"Excellent: {tokens} tokens ({tiles} tiles)" + if tiles >= 16: + return True, f"Marginal: {tokens} tokens ({tiles} tiles)" + return False, f"Too small: {tokens} tokens ({tiles} tiles, need 16+ for VSA)" + + +def apply_vsa( + transformer: torch.nn.Module, + num_frames: int, + height: int, + width: int, + top_k_ratio: float, +) -> torch.nn.Module: + """Apply VSA to the LTX-2 transformer.""" + compatible, reason = is_vsa_compatible(num_frames, height, width) + print(f" VSA compatibility: {reason}") + if not compatible: + print(" [WARNING] Input size may be too small for VSA to help.") + + config = copy.deepcopy(VSA_DEFAULT) + # Override top_k_ratio on the attention pattern + for cfg in config["sparse_cfg"].values(): + if isinstance(cfg, dict) and cfg.get("method") == "vsa": + cfg["top_k_ratio"] = top_k_ratio + + print(" Applying VSA to attention modules...") + return mtsa.sparsify(transformer, config) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="LTX-2 
video generation with Video Sparse Attention (VSA)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--checkpoint", type=str, required=True, help="LTX-2 model checkpoint") + parser.add_argument( + "--text-encoder-path", type=str, required=True, help="Gemma text encoder directory" + ) + parser.add_argument( + "--prompt", + type=str, + default="A serene mountain landscape with a flowing river, golden hour lighting", + ) + parser.add_argument("--negative-prompt", type=str, default="") + parser.add_argument("--height", type=int, default=512, help="Video height (multiple of 32)") + parser.add_argument("--width", type=int, default=768, help="Video width (multiple of 32)") + parser.add_argument("--num-frames", type=int, default=121, help="Must be k*8 + 1") + parser.add_argument("--frame-rate", type=float, default=25.0) + parser.add_argument("--num-inference-steps", type=int, default=30) + parser.add_argument("--guidance-scale", type=float, default=4.0, help="CFG scale") + parser.add_argument("--seed", type=int, default=42) + + parser.add_argument( + "--no-vsa", + action="store_true", + help="Disable VSA (baseline run, for timing comparison)", + ) + parser.add_argument( + "--top-k-ratio", + type=float, + default=0.5, + help="VSA sparsity ratio (0.5 ⇒ 50%% sparsity, 0.3 ⇒ 70%%)", + ) + + parser.add_argument("--skip-audio", action="store_true", help="Skip audio generation") + parser.add_argument("--output", type=str, default="output_vsa.mp4") + parser.add_argument("--device", type=str, default="cuda") + return parser.parse_args() + + +def run_generation( + sampler, + config, + device: str, + num_inference_steps: int, + label: str = "", +) -> tuple[torch.Tensor, torch.Tensor | None, float]: + if label: + print(f"\n{label}") + print(f"Generating video ({num_inference_steps} steps)...") + t0 = time.time() + with StandaloneSamplingProgress(num_steps=num_inference_steps) as progress: + sampler.sampling_context = progress + video, audio = sampler.generate(config=config, device=device) + elapsed = time.time() - t0 + print(f"Generation completed in {elapsed:.2f}s") + return video, audio, elapsed + + +def main() -> None: + if not _LTX_AVAILABLE: + raise ImportError( + "LTX-2 packages are required for this example. Install with: " + "pip install ltx-core ltx-trainer ltx-pipelines. 
" + f"(original error: {_LTX_IMPORT_ERROR})" + ) + + args = parse_args() + generate_audio = not args.skip_audio + + print("=" * 72) + print("LTX-2 + VSA") + print("=" * 72) + + tokens = calculate_expected_tokens(args.num_frames, args.height, args.width) + tiles = tokens // VSA_BLOCK_ELEMENTS + _, reason = is_vsa_compatible(args.num_frames, args.height, args.width) + print("\nInput Configuration:") + print(f" Resolution: {args.width}x{args.height}") + print(f" Frames: {args.num_frames} @ {args.frame_rate} fps") + print(f" Tokens: {tokens} ({tiles} tiles)") + print(f" VSA: {reason}") + + print("\nLoading LTX-2 model components...") + components = load_model( + checkpoint_path=args.checkpoint, + device="cpu", + dtype=torch.bfloat16, + with_video_vae_encoder=False, + with_video_vae_decoder=True, + with_audio_vae_decoder=generate_audio, + with_vocoder=generate_audio, + with_text_encoder=True, + text_encoder_path=args.text_encoder_path, + ) + print("Model loaded") + + transformer = components.transformer + + if not args.no_vsa: + transformer = apply_vsa( + transformer, + args.num_frames, + args.height, + args.width, + top_k_ratio=args.top_k_ratio, + ) + components.transformer = transformer + + gen_config = GenerationConfig( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, + frame_rate=args.frame_rate, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + seed=args.seed, + condition_image=None, + reference_video=None, + generate_audio=generate_audio, + include_reference_in_output=False, + ) + + sampler = ValidationSampler( + transformer=components.transformer, + vae_decoder=components.video_vae_decoder, + vae_encoder=components.video_vae_encoder, + text_encoder=components.text_encoder, + audio_decoder=components.audio_vae_decoder if generate_audio else None, + vocoder=components.vocoder if generate_audio else None, + ) + + label = "BASELINE (no VSA)" if args.no_vsa else f"WITH VSA (top_k_ratio={args.top_k_ratio})" + video, audio, elapsed = run_generation( + sampler, gen_config, args.device, args.num_inference_steps, label=label + ) + + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + audio_sample_rate = None + if audio is not None and components.vocoder is not None: + audio_sample_rate = components.vocoder.output_sample_rate + save_video( + video_tensor=video, + output_path=out_path, + fps=args.frame_rate, + audio=audio, + audio_sample_rate=audio_sample_rate, + ) + print(f"Saved: {args.output}") + + print("\n" + "=" * 72) + print(f"Done in {elapsed:.2f}s") + print("=" * 72) + + +if __name__ == "__main__": + main() diff --git a/examples/diffusers/sparsity/wan22_skip_softmax.py b/examples/diffusers/sparsity/wan22_sparse_attn.py similarity index 57% rename from examples/diffusers/sparsity/wan22_skip_softmax.py rename to examples/diffusers/sparsity/wan22_sparse_attn.py index e335451e2b..65a5618fb0 100644 --- a/examples/diffusers/sparsity/wan22_skip_softmax.py +++ b/examples/diffusers/sparsity/wan22_sparse_attn.py @@ -13,40 +13,52 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Wan 2.2 inference with skip-softmax sparse attention. +"""Wan 2.2 inference with sparse attention (skip-softmax or VSA). -This example applies skip-softmax sparse attention to the Wan 2.2 video -generation model (text-to-video). 
Four modes are supported: +Two sparse-attention methods are supported via ``--method``: -1. **Baseline** — pass ``--baseline`` for dense inference (default diffusers backend). -2. **Triton baseline** — pass ``--triton-baseline`` for dense Triton FA kernel - (no skip-softmax, same kernel as sparse runs for apples-to-apples comparison). -3. **Fixed raw threshold** — pass ``--raw-threshold`` to supply a log2-space - threshold directly to the Triton kernel. No calibration data is needed. -4. **Calibrated threshold** — pass ``--calibrate`` to run exponential-model - calibration (``scale_factor = a * exp(b * target_sparsity)``). +- **skip_softmax** (default, BLASST) — drops KV tiles whose attention + scores are negligible during the FlashAttention computation. Reduces FLOPs + without retraining. Supports ``--raw-threshold`` (log2-space, no + calibration) and ``--calibrate`` (exponential model fitted once, target + sparsity tunable at runtime). +- **vsa** (Video Sparse Attention) — two-branch (compression + sparse) + block-level attention tuned for video models. Calibration-free — sparsity + is controlled by a fixed ``top_k_ratio``. The Wan 2.2 plugin auto-derives + ``video_shape`` from each forward's ``hidden_states``. -During calibration, ``triton_skip_softmax`` with the Triton calibration kernel -collects sparsity statistics across multiple threshold trials. The fitted -exponential model then allows runtime control of the target sparsity ratio -without recalibration. +Run modes (method-agnostic): + +- ``--baseline`` — dense inference, no sparsity (default diffusers backend). +- ``--triton-baseline`` — dense Triton FA kernel, no skip-softmax + (apples-to-apples comparison with skip-softmax runs; skip_softmax only). The Wan 2.2 5B model has 40 transformer blocks with self-attention (attn1) -and cross-attention (attn2). Only self-attention is sparsified. +and cross-attention (attn2); the 14B model has two transformers. Only +self-attention is sparsified — cross-attention is always left dense. 
Usage:: - # Baseline (dense, no sparsity) - python wan22_skip_softmax.py --baseline --prompt "A cat playing piano" \\ - --output baseline.mp4 - - # Fixed raw threshold (no calibration needed) - python wan22_skip_softmax.py --raw-threshold -5.0 --report-avg-sparsity \\ + # Skip-softmax with fixed raw threshold (default method, no calibration) + python wan22_sparse_attn.py --raw-threshold -5.0 --report-avg-sparsity \\ --prompt "A cat playing piano" --output out.mp4 - # With calibration - python wan22_skip_softmax.py --calibrate --target-sparsity 0.25 \\ + # Skip-softmax with calibration + python wan22_sparse_attn.py --calibrate --target-sparsity 0.25 \\ --report-avg-sparsity --prompt "A cat playing piano" --output out.mp4 + + # VSA with 50% top-K (50% sparsity) + python wan22_sparse_attn.py --method vsa --top-k-ratio 0.5 \\ + --prompt "A cat playing piano" --output vsa.mp4 + + # VSA with aggressive 30% top-K (70% sparsity), keep first/last 2 layers dense + python wan22_sparse_attn.py --method vsa --top-k-ratio 0.3 \\ + --skip-first-last 2 --report-avg-sparsity \\ + --prompt "A cat playing piano" --output vsa.mp4 + + # Dense baseline (any method) + python wan22_sparse_attn.py --baseline --prompt "A cat playing piano" \\ + --output baseline.mp4 """ import argparse @@ -73,7 +85,7 @@ ) # fmt: on -# Default threshold trials for calibration +# Default threshold trials for calibration (skip_softmax only) DEFAULT_THRESHOLD_TRIALS = [ 1e-12, 1e-10, @@ -102,7 +114,7 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Wan 2.2 video generation with skip-softmax sparse attention" + description=("Wan 2.2 video generation with sparse attention (skip-softmax or VSA)") ) parser.add_argument( "--prompt", @@ -137,7 +149,15 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--seed", type=int, default=42, help="Random seed") - # Sparse attention options + # ---- Method selection ---- + parser.add_argument( + "--method", + choices=["skip_softmax", "vsa"], + default="skip_softmax", + help="Sparse attention method (default: skip_softmax)", + ) + + # ---- Run-mode flags (method-agnostic) ---- parser.add_argument( "--baseline", action="store_true", @@ -147,15 +167,7 @@ def parse_args() -> argparse.Namespace: "--triton-baseline", action="store_true", help="Run dense inference with Triton FA kernel (no skip-softmax, " - "apples-to-apples comparison with sparse runs)", - ) - parser.add_argument( - "--raw-threshold", - type=float, - default=None, - help="Raw skip_threshold_log2 value passed directly to the Triton kernel. " - "Negative values (e.g., -5.0 means tile must be within 5 units of running max). " - "Bypasses calibration and lambda conversion. Typical range: -1 to -30.", + "apples-to-apples comparison with sparse runs). skip_softmax only.", ) parser.add_argument( "--skip-first-last", @@ -166,40 +178,94 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--report-avg-sparsity", action="store_true", - help="Report per-layer and overall average tile sparsity after generation", + help="[skip_softmax] Report per-layer and overall average tile sparsity " + "measured via the Triton kernel's atomic counters. " + "No-op for VSA (sparsity is deterministic from --top-k-ratio).", + ) + parser.add_argument( + "--enable-vae-tiling", + action="store_true", + help="Enable VAE tiling in the Wan pipeline to reduce peak memory during " + "decode. 
Recommended at 720p+ when VSA is active, since VSA leaves less " + "GPU memory free than the dense baseline.", ) - # Calibration options + # ---- Skip-softmax options ---- + parser.add_argument( + "--raw-threshold", + type=float, + default=None, + help="[skip_softmax] Raw skip_threshold_log2 value passed directly to the Triton kernel. " + "Negative values (e.g., -5.0 means tile must be within 5 units of running max). " + "Bypasses calibration and lambda conversion. Typical range: -1 to -30.", + ) parser.add_argument( "--calibrate", action="store_true", - help="Calibrate threshold via exponential model (recommended)", + help="[skip_softmax] Calibrate threshold via exponential model (recommended)", ) parser.add_argument( "--target-sparsity", type=float, default=0.5, - help="Target sparsity ratio for calibration (0.0-1.0)", + help="[skip_softmax] Target sparsity ratio for calibration (0.0-1.0)", ) parser.add_argument( "--calib-steps", type=int, default=40, - help="Inference steps for calibration", + help="[skip_softmax] Inference steps for calibration", ) parser.add_argument( "--calib-frames", type=int, default=151, - help="Number of frames for calibration", + help="[skip_softmax] Number of frames for calibration", ) parser.add_argument( "--calib-size", type=int, default=4, - help="Number of calibration prompts from OpenVid-1M dataset", + help="[skip_softmax] Number of calibration prompts from OpenVid-1M dataset", + ) + + # ---- VSA options ---- + parser.add_argument( + "--top-k-ratio", + type=float, + default=0.5, + help="[vsa] Ratio of blocks kept in the sparse branch (0 < ratio ≤ 1). " + "Lower = more sparsity. 0.5 → 50%% sparsity, 0.3 → 70%%.", + ) + parser.add_argument( + "--block-size", + type=str, + default="4,4,4", + help="[vsa] VSA 3D block size as 'T,H,W' (default 4,4,4 → 64-token blocks)", + ) + parser.add_argument( + "--video-shape", + type=str, + default=None, + help="[vsa] Override post-patchify video shape as 'T,H,W'. " + "If unset, the Wan 2.2 plugin derives it automatically from hidden_states.", ) - return parser.parse_args() + + args = parser.parse_args() + + # Cross-method validation + if args.triton_baseline and args.method != "skip_softmax": + parser.error("--triton-baseline is only valid with --method skip_softmax") + + return args + + +def _parse_int_triple(spec: str) -> tuple[int, int, int]: + """Parse 'T,H,W' into a triple of positive ints.""" + parts = [int(p.strip()) for p in spec.split(",")] + if len(parts) != 3 or any(p <= 0 for p in parts): + raise ValueError(f"expected 3 positive integers T,H,W — got {spec!r}") + return (parts[0], parts[1], parts[2]) def build_pipeline(model_path: str) -> WanPipeline: @@ -210,8 +276,8 @@ def build_pipeline(model_path: str) -> WanPipeline: return pipe -def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: - """Build sparse attention config from CLI args. +def build_skip_softmax_config(args: argparse.Namespace, num_blocks: int) -> dict: + """Build a skip-softmax config from CLI args. Two modes: - **Raw threshold**: ``--raw-threshold`` sets ``skip_softmax_raw_threshold`` @@ -257,6 +323,37 @@ def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: return config +def build_vsa_config(args: argparse.Namespace, num_blocks: int) -> dict: + """Build a VSA sparse-attention config from CLI args. + + Applies VSA to self-attention (``attn1``) only — cross-attention + (``attn2``) is disabled because VSA's 3D-tile structure does not apply + to text KV. 
Optionally keeps the first/last N transformer layers dense. + """ + block_size = _parse_int_triple(args.block_size) + + attn_cfg: dict = { + "method": "vsa", + "block_size_3d": block_size, + "top_k_ratio": args.top_k_ratio, + "enable": True, + } + if args.video_shape is not None: + attn_cfg["video_shape"] = _parse_int_triple(args.video_shape) + + sparse_cfg: dict = { + "*.attn1*": attn_cfg, # Self-attention only + "*.attn2*": {"enable": False}, # Text cross-attention + "default": {"enable": False}, + } + + for i in range(args.skip_first_last): + sparse_cfg[f"*blocks.{i}.attn*"] = {"enable": False} + sparse_cfg[f"*blocks.{num_blocks - 1 - i}.attn*"] = {"enable": False} + + return {"sparse_cfg": sparse_cfg} + + def load_calib_prompts(calib_size: int) -> list[str]: """Load calibration prompts from OpenVid-1M dataset.""" dataset = load_dataset("nkp37/OpenVid-1M", split="train") @@ -277,7 +374,7 @@ def build_calibration_forward_loop( guidance_scale_2: float | None = 3.0, negative_prompt: str = "", ): - """Build a forward loop for exponential model calibration. + """Build a forward loop for exponential model calibration (skip_softmax). Uses prompts from OpenVid-1M dataset (same as quantization examples). Each prompt is run individually (batch_size=1). @@ -305,7 +402,12 @@ def forward_loop(model): def enable_sparsity_measurement(model: torch.nn.Module) -> None: - """Enable runtime sparsity measurement on all sparse attention modules.""" + """Enable runtime sparsity measurement on skip-softmax modules. + + Only applies to methods that expose ``enable_measure_sparsity`` (i.e. + the Triton skip-softmax kernel). VSA reports stats via its stats manager + instead — see ``print_vsa_runtime_stats``. + """ for _name, module in model.named_modules(): if isinstance(module, SparseAttentionModule) and module.is_enabled: method = module._sparse_method_instance @@ -315,7 +417,7 @@ def enable_sparsity_measurement(model: torch.nn.Module) -> None: def print_sparsity_summary(model: torch.nn.Module) -> None: - """Print per-module sparsity statistics including runtime kernel counters.""" + """Print per-module sparsity configuration (method-agnostic).""" enabled, disabled = [], [] for name, module in model.named_modules(): if isinstance(module, SparseAttentionModule): @@ -330,8 +432,8 @@ def print_sparsity_summary(model: torch.nn.Module) -> None: print(f" {name}: {info}") -def print_runtime_sparsity(model: torch.nn.Module) -> None: - """Print runtime tile sparsity measured via kernel atomic counters.""" +def print_skip_softmax_runtime_sparsity(model: torch.nn.Module) -> None: + """Print per-layer tile sparsity measured via the Triton kernel's atomic counters.""" total_all = 0 skipped_all = 0 per_module: list[tuple[str, int, int]] = [] @@ -378,6 +480,65 @@ def _get_num_blocks(transformer: torch.nn.Module) -> int: return max_idx + 1 +def _apply_skip_softmax( + pipe: WanPipeline, + transformers: list[tuple[str, torch.nn.Module]], + args: argparse.Namespace, + is_14b: bool, +): + """Sparsify every transformer with skip-softmax. + + Returns the calibration ``forward_loop`` (or None) so the caller can + free memory after calibration completes. 
+ """ + forward_loop = None + if args.triton_baseline: + print("Triton baseline: dense Triton FA kernel (no skip-softmax)") + elif args.raw_threshold is not None: + print(f"Skip-softmax: fixed raw threshold {args.raw_threshold} (no calibration)") + if args.calibrate: + print("Warning: --calibrate is ignored when --raw-threshold is set") + elif args.calibrate: + forward_loop = build_calibration_forward_loop( + pipe, + calib_size=args.calib_size, + num_steps=args.calib_steps, + num_frames=args.calib_frames, + height=args.height, + width=args.width, + seed=args.seed, + guidance_scale=args.guidance_scale, + guidance_scale_2=args.guidance_scale_2 if is_14b else None, + negative_prompt=args.negative_prompt, + ) + else: + print( + "Warning: skip_softmax requested without --raw-threshold or --calibrate; " + "falling back to static skip_softmax_threshold=0.1" + ) + + for name, transformer in transformers: + num_blocks = _get_num_blocks(transformer) + label = "Triton backend" if args.triton_baseline else "skip-softmax" + print(f"Applying {label} to {name} ({num_blocks} blocks)...") + config = build_skip_softmax_config(args, num_blocks=num_blocks) + mtsa.sparsify(transformer, config, forward_loop=forward_loop) + + return forward_loop + + +def _apply_vsa( + transformers: list[tuple[str, torch.nn.Module]], + args: argparse.Namespace, +): + """Sparsify every transformer with VSA. No calibration needed.""" + for name, transformer in transformers: + num_blocks = _get_num_blocks(transformer) + print(f"Applying VSA to {name} ({num_blocks} blocks)...") + config = build_vsa_config(args, num_blocks=num_blocks) + mtsa.sparsify(transformer, config) + + def main() -> None: args = parse_args() @@ -385,9 +546,16 @@ def main() -> None: print(f"Loading Wan 2.2 from {args.model_path}...") pipe = build_pipeline(args.model_path) + if args.enable_vae_tiling: + # VAE tiling decodes latents in tiles instead of one shot — essential at + # 720p+ when VSA is active (VSA's tile buffers leave ~15 GB less free GPU + # memory vs. dense, which can OOM the one-shot VAE decode). 
+ pipe.vae.enable_tiling() + print("Enabled VAE tiling (reduces peak memory during decode)") + # ---- Collect transformers ---- # Wan 2.2 5B has one transformer; 14B has two (transformer + transformer_2) - transformers = [] + transformers: list[tuple[str, torch.nn.Module]] = [] if pipe.transformer is not None: transformers.append(("transformer", pipe.transformer)) if getattr(pipe, "transformer_2", None) is not None: @@ -395,57 +563,24 @@ def main() -> None: is_14b = len(transformers) > 1 # ---- Sparsify (unless baseline) ---- + forward_loop = None if args.baseline: print("Baseline mode: running dense inference (default diffusers backend)") - elif args.triton_baseline: - print("Triton baseline: dense Triton FA kernel (no skip-softmax)") - for name, transformer in transformers: - num_blocks = _get_num_blocks(transformer) - print(f"Applying Triton backend to {name} ({num_blocks} blocks)...") - config = build_sparse_config(args, num_blocks=num_blocks) - mtsa.sparsify(transformer, config, forward_loop=None) - else: - # Build calibration forward loop if needed - forward_loop = None - if args.raw_threshold is not None: - print(f"Using fixed raw threshold: {args.raw_threshold} (skipping calibration)") - if args.calibrate: - print("Warning: --calibrate is ignored when --raw-threshold is set") - elif args.calibrate: - forward_loop = build_calibration_forward_loop( - pipe, - calib_size=args.calib_size, - num_steps=args.calib_steps, - num_frames=args.calib_frames, - height=args.height, - width=args.width, - seed=args.seed, - guidance_scale=args.guidance_scale, - guidance_scale_2=args.guidance_scale_2 if is_14b else None, - negative_prompt=args.negative_prompt, - ) - else: - print( - "Warning: neither --baseline, --raw-threshold, nor --calibrate specified; " - "using default static threshold" - ) - - for name, transformer in transformers: - num_blocks = _get_num_blocks(transformer) - print(f"Applying skip-softmax to {name} ({num_blocks} blocks)...") - config = build_sparse_config(args, num_blocks=num_blocks) - mtsa.sparsify(transformer, config, forward_loop=forward_loop) + elif args.method == "skip_softmax": + forward_loop = _apply_skip_softmax(pipe, transformers, args, is_14b) + elif args.method == "vsa": + _apply_vsa(transformers, args) # ---- Free calibration memory before inference ---- - if not args.baseline and not args.triton_baseline and forward_loop is not None: + if forward_loop is not None: gc.collect() torch.cuda.empty_cache() print("Cleared CUDA cache after calibration") # ---- Generate (optional) ---- if args.prompt: - # Enable runtime sparsity measurement before generation - if args.report_avg_sparsity and not args.baseline: + # Enable runtime sparsity measurement before generation (skip_softmax only) + if args.report_avg_sparsity and not args.baseline and args.method == "skip_softmax": for _name, transformer in transformers: enable_sparsity_measurement(transformer) @@ -477,8 +612,11 @@ def main() -> None: for name, transformer in transformers: print(f"\n{name}:") print_sparsity_summary(transformer) - if args.report_avg_sparsity: - print_runtime_sparsity(transformer) + # Runtime sparsity is meaningful only for skip-softmax (data-dependent). + # VSA sparsity is deterministic from top_k_ratio — the per-module + # summary above already reports it via get_threshold_info(). 
+ if args.report_avg_sparsity and args.method == "skip_softmax": + print_skip_softmax_runtime_sparsity(transformer) if __name__ == "__main__": diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py b/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py index 66acfb510c..c6bf33b63b 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/vsa.py @@ -286,9 +286,21 @@ def forward_attention( query_tiled = self._tile_tensor(query, metadata) key_tiled = self._tile_tensor(key, metadata) value_tiled = self._tile_tensor(value, metadata) - gate_tiled = ( - self._tile_tensor(gate_compress, metadata) if gate_compress is not None else None - ) + if gate_compress is not None: + gate_tiled = self._tile_tensor(gate_compress, metadata) + else: + # The fastvideo kernel's default behaviour when + # ``compress_attn_weight is None`` is ``out_c + out_s`` — i.e. it + # *adds* the compression branch at full strength on top of the + # sparse branch. For models without a learned ``gate_compress`` + # (e.g. Wan 2.2), this doubles the attention signal and corrupts + # the output. The intended "no gate" semantics is + # ``gate_compress = 0`` → ``out = 0 * out_c + out_s = out_s``, + # which (a) matches an untrained LTX-2 whose ``to_gate_compress`` + # is zero-initialised, and (b) makes VSA at ``top_k_ratio=1.0`` + # reduce to dense attention (since ``out_s`` with all blocks + # selected is mathematically equivalent to dense SDPA). + gate_tiled = torch.zeros((), dtype=query_tiled.dtype, device=query_tiled.device) # ========== TRITON VSA KERNEL ========== # Kernel operates on tiled tensors in [batch, heads, padded_seq, dim] format diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py index 434fc18214..9c99e42c07 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py @@ -15,19 +15,32 @@ """Plugins for sparse attention integration with various frameworks.""" -# List of model plugins that are called during conversion -# Each plugin is a callable that takes (model) and performs validation/setup -CUSTOM_MODEL_PLUGINS: list = [] +from modelopt.torch.utils import import_plugin + +# Set of model plugins called during conversion. A set (rather than a list) +# keeps re-imports idempotent — the same callback inserted multiple times +# stays registered once. Matches the convention used by quantization and peft. +CUSTOM_MODEL_PLUGINS: set = set() def register_custom_model_plugins_on_the_fly(model): - """Applies all registered custom model plugins.""" + """Apply every registered custom model plugin to ``model``.""" for callback in CUSTOM_MODEL_PLUGINS: callback(model) +# Built-in plugins from . import huggingface # noqa: E402 +# Model-specific plugins for VSA. Guarded by ``import_plugin`` so the +# module-level imports stay soft — a missing dependency in one plugin must +# not break the core sparse-attention API. +with import_plugin("ltx2"): + from . import ltx2 + +with import_plugin("wan22"): + from . 
import wan22 + __all__ = [ "CUSTOM_MODEL_PLUGINS", "register_custom_model_plugins_on_the_fly", diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index d26b73f0b4..f4e43a40de 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -132,4 +132,4 @@ def _is_supported_model(model: nn.Module) -> bool: # Register plugins -CUSTOM_MODEL_PLUGINS.append(register_sparse_attention_on_the_fly) +CUSTOM_MODEL_PLUGINS.add(register_sparse_attention_on_the_fly) diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py b/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py new file mode 100644 index 0000000000..27bea131ae --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py @@ -0,0 +1,413 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Plugin for LTX-2 video diffusion models with VSA support. + +LTX-2 uses a native ``LTXSelfAttention`` module whose forward signature is +``(x, context, pe, k_pe)`` and which does not call +``F.scaled_dot_product_attention``. VSA's default SDPA patching in +``SparseAttentionModule`` therefore has no effect on it, so this plugin +installs a model-specific wrapper that: + +1. Projects Q/K/V from ``x`` (and ``context`` for self-attention: ``context = x``) +2. Applies LTX-2's ``q_norm`` / ``k_norm`` RMSNorms and RoPE via ``ltx_core`` +3. Computes an optional ``gate_compress`` from a trainable zero-initialised + projection (used by VSA's compression branch, trained later) +4. Calls ``VSA.forward_attention()`` directly, bypassing SDPA +5. Applies the original module's ``to_out`` projection + +A forward pre-hook on the root ``LTXModel`` extracts the ``(T, H, W)`` +shape from ``Modality.positions`` (same source FastVideo uses) and stores it +on the model, so the wrapper can read it per-step without module-level global +state. +""" + +import logging +import weakref + +import torch +import torch.nn as nn + +from modelopt.torch.utils.logging import warn_rank_0 + +from ..sparse_attention import SparseAttentionModule, SparseAttentionRegistry +from . import CUSTOM_MODEL_PLUGINS + +logger = logging.getLogger(__name__) + +_LTX2_LICENSE_WARNING = ( + "LTX-2 packages (ltx-core, ltx-pipelines, ltx-trainer) are provided by " + "Lightricks and are NOT covered by the Apache 2.0 license governing NVIDIA " + "Model Optimizer. You MUST comply with the LTX Community License Agreement " + "when installing and using LTX-2 with NVIDIA Model Optimizer. Any derivative " + "models or fine-tuned weights from LTX-2 (including quantized or distilled " + "checkpoints) remain subject to the LTX Community License Agreement, not " + "Apache 2.0. 
See: https://github.com/Lightricks/LTX-2/blob/main/LICENSE" +) + + +def _extract_video_shape_hook(module: nn.Module, args: tuple) -> None: + """Forward pre-hook on LTXModel to extract ``dit_seq_shape`` from Modality.positions. + + Mirrors FastVideo's ``VideoSparseAttentionMetadataBuilder.build()`` which + computes ``dit_seq_shape = raw_latent_shape // patch_size``. Here we + derive the same shape by counting unique position values per dimension in + ``Modality.positions``, which is available at the LTXModel entry point + (before ``TransformerArgsPreprocessor`` converts it to RoPE embeddings). + + The result is stored on the model instance as ``module._vsa_video_shape`` + so ``_LTX2SparseAttention._resolve_video_shape()`` can read it via its + weak reference to the root model. Using an instance attribute (not a + global) makes this safe for concurrent models. + """ + # LTXModel.forward(self, video: Modality | None, audio, perturbations) + video = args[0] if len(args) > 0 else None + if video is None or not hasattr(video, "positions") or video.positions is None: + return + + positions = video.positions # (B, 3, T) or (B, 3, T, 2) + + try: + if positions.ndim == 4: + # (B, 3, T, 2) -- take start coordinates + pos_per_dim = positions[0, :, :, 0] # (3, T) + elif positions.ndim == 3: + # (B, 3, T) + pos_per_dim = positions[0] # (3, T) + else: + return + + t_dim = pos_per_dim[0].unique().numel() + h_dim = pos_per_dim[1].unique().numel() + w_dim = pos_per_dim[2].unique().numel() + seq_len = positions.shape[2] + + if t_dim * h_dim * w_dim == seq_len: + module._vsa_video_shape = (t_dim, h_dim, w_dim) + logger.debug( + f"Extracted dit_seq_shape={module._vsa_video_shape} from " + f"Modality.positions (seq_len={seq_len})" + ) + else: + logger.debug( + f"Position-derived shape {(t_dim, h_dim, w_dim)} product " + f"({t_dim * h_dim * w_dim}) != seq_len ({seq_len}), skipping" + ) + except Exception: + logger.debug("Failed to extract video_shape from Modality.positions", exc_info=True) + + +def _is_ltx2_model(model: nn.Module) -> bool: + """Check if model is an LTX-2 model. + + Uses ``LTXModel`` / ``LTXSelfAttention`` class names to avoid false + positives from other DiTs (e.g., LongCat) that share similar attribute + patterns. + """ + if type(model).__name__ == "LTXModel": + return True + return any(type(m).__name__ == "LTXSelfAttention" for m in model.modules()) + + +def _is_ltx2_attention_module(module: nn.Module, name: str = "") -> bool: + """Check if a module is an LTX-2 Attention module by class name or structure. + + Primary: class name is ``LTXSelfAttention``. Fallback: has ``to_q/k/v``, + ``q_norm``, ``k_norm``, and ``rope_type`` (unique to LTX-2 among DiTs we + support). + """ + class_name = type(module).__name__ + if class_name == "LTXSelfAttention": + return True + return ( + hasattr(module, "to_q") + and hasattr(module, "to_k") + and hasattr(module, "to_v") + and hasattr(module, "q_norm") + and hasattr(module, "k_norm") + and hasattr(module, "rope_type") + ) + + +class _LTX2SparseAttention(SparseAttentionModule): + """Sparse-attention wrapper for LTX-2 ``LTXSelfAttention`` modules. + + Handles LTX-2 specifics (native forward args, RMSNorm, RoPE, trainable + ``gate_compress``) and delegates the actual attention computation to + ``VSA.forward_attention``. Falls back to the original module forward + for cross-attention / incompatible sequence lengths / missing video + shape, matching how the core SDPA patch falls through to original SDPA. 
+ """ + + def _setup(self): + super()._setup() + + # Add trainable gate_compress projection if not already present. + # Zero-init so its initial contribution is 0 — matches VSA's behaviour + # when gate_compress is None but leaves room for fine-tuning. + if not hasattr(self, "to_gate_compress"): + to_q = self.to_q + in_features = to_q.in_features + out_features = to_q.out_features + + self.to_gate_compress = nn.Linear(in_features, out_features, bias=True) + nn.init.zeros_(self.to_gate_compress.weight) + nn.init.zeros_(self.to_gate_compress.bias) + + self.to_gate_compress = self.to_gate_compress.to( + device=to_q.weight.device, + dtype=to_q.weight.dtype, + ) + + def _compute_qkv( + self, + x: torch.Tensor, + context: torch.Tensor | None, + pe: torch.Tensor | None = None, + k_pe: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute Q/K/V with LTX-2 norms and RoPE. + + Inputs are ``[batch, seq, hidden_dim]``; output tensors share the same + layout and are reshaped to heads later in ``forward``. + """ + context = context if context is not None else x + + query = self.to_q(x) + key = self.to_k(context) + value = self.to_v(context) + + if hasattr(self, "q_norm"): + query = self.q_norm(query) + if hasattr(self, "k_norm"): + key = self.k_norm(key) + + if pe is not None and hasattr(self, "rope_type"): + try: + from ltx_core.model.transformer.rope import apply_rotary_emb + except ModuleNotFoundError: + raise ModuleNotFoundError( + "LTX-2 VSA plugin requires the 'ltx_core' package for RoPE " + "support. The plugin registered successfully, but 'ltx_core' " + "is needed at runtime. Install with: pip install ltx-core" + ) from None + + query = apply_rotary_emb(query, pe, self.rope_type) + key = apply_rotary_emb(key, pe if k_pe is None else k_pe, self.rope_type) + + return query, key, value + + @staticmethod + def _reshape_for_vsa(tensor: torch.Tensor, num_heads: int) -> torch.Tensor: + """``[batch, seq, hidden]`` → ``[batch, heads, seq, head_dim]``.""" + batch, seq_len, hidden_dim = tensor.shape + head_dim = hidden_dim // num_heads + return tensor.view(batch, seq_len, num_heads, head_dim).transpose(1, 2) + + @staticmethod + def _reshape_from_vsa(tensor: torch.Tensor) -> torch.Tensor: + """``[batch, heads, seq, head_dim]`` → ``[batch, seq, hidden]``.""" + batch, heads, seq_len, head_dim = tensor.shape + return tensor.transpose(1, 2).contiguous().view(batch, seq_len, heads * head_dim) + + def _resolve_video_shape(self, seq_len: int) -> tuple[int, int, int] | None: + """Resolve video_shape for the current forward pass. + + Resolution order (mirrors FastVideo's metadata flow): + 1. ``root_model._vsa_video_shape`` -- set by the forward pre-hook from + ``Modality.positions`` + 2. ``method.video_shape`` -- explicitly set via the sparsify config + """ + root_ref = getattr(self, "_vsa_root_model_ref", None) + root = root_ref() if root_ref is not None else None + if root is not None: + shape = getattr(root, "_vsa_video_shape", None) + if shape is not None: + t, h, w = shape + if t * h * w == seq_len: + return shape + + method = getattr(self, "_sparse_method_instance", None) + if method is not None and method.video_shape is not None: + t, h, w = method.video_shape + if t * h * w == seq_len: + return method.video_shape + + return None + + def forward(self, *args, **kwargs): + """Run the LTX-2 attention forward through VSA. 
+ + Consumes LTX-2's native call signature (``x``, ``context``, ``pe``, + ``k_pe``) and dispatches to ``VSA.forward_attention``; falls through + to the original module for cross-attention or incompatible inputs. + """ + if not self.is_enabled: + return self._call_original_forward(*args, **kwargs) + + x = kwargs.get("x") + if x is None and len(args) > 0: + x = args[0] + + if x is None: + return self._call_original_forward(*args, **kwargs) + + context = kwargs.get("context") + pe = kwargs.get("pe") + k_pe = kwargs.get("k_pe") + + # Cross-attention: fall through to the original module + if context is not None and x.shape[1] != context.shape[1]: + return self._call_original_forward(*args, **kwargs) + + method = getattr(self, "_sparse_method_instance", None) + if method is None: + return self._call_original_forward(*args, **kwargs) + + query, key, value = self._compute_qkv(x, context, pe, k_pe) + + # Incompatible seq_len (e.g., audio attention with seq=32) + seq_len = query.shape[1] + block_size_3d = method.block_size_3d + block_elements = block_size_3d[0] * block_size_3d[1] * block_size_3d[2] + if seq_len < block_elements: + logger.debug(f"VSA skipped: seq_len={seq_len} < block_elements={block_elements}") + return self._call_original_forward(*args, **kwargs) + + video_shape = self._resolve_video_shape(seq_len) + if video_shape is None: + logger.debug(f"VSA skipped: no matching video_shape for seq_len={seq_len}") + return self._call_original_forward(*args, **kwargs) + + gate_compress = None + if hasattr(self, "to_gate_compress"): + gate_compress = self.to_gate_compress(x) + + # Reshape to [batch, heads, seq, head_dim] + query = self._reshape_for_vsa(query, self.heads) + key = self._reshape_for_vsa(key, self.heads) + value = self._reshape_for_vsa(value, self.heads) + if gate_compress is not None: + gate_compress = self._reshape_for_vsa(gate_compress, self.heads) + + output, stats = method.forward_attention( + query=query, + key=key, + value=value, + gate_compress=gate_compress, + video_shape=video_shape, + ) + + # Bubble stats up through SparseAttentionModule's stats path + self._last_stats = stats + if self._stats_manager is not None: + self._stats_manager.collect(stats) + self._last_stats = None + + output = self._reshape_from_vsa(output) + + if hasattr(self, "to_out"): + output = self.to_out(output) + + return output + + def _call_original_forward(self, *args, **kwargs): + """Invoke the original module's forward, bypassing VSA. + + ``SparseAttentionModule.forward`` passes through to the original + module when ``is_enabled`` is False — exploit that to avoid + reimplementing the fallback path. + """ + was_enabled = getattr(self, "_enabled", True) + self._enabled = False + try: + result = SparseAttentionModule.forward(self, *args, **kwargs) + finally: + self._enabled = was_enabled + return result + + def get_gate_compress_parameters(self): + """Return trainable ``gate_compress`` parameters for later fine-tuning.""" + if hasattr(self, "to_gate_compress"): + return self.to_gate_compress.parameters() + return iter([]) + + +def register_ltx2_attention(model: nn.Module) -> int: + """Register LTX-2 Attention modules for VSA wrapping. + + Replaces any existing generic wrapper in ``SparseAttentionRegistry`` + with ``_LTX2SparseAttention`` for each LTX-2 attention type found, wires + a weakref back to the root model on every attention instance, and + installs the ``Modality.positions`` extraction pre-hook. 
+ """ + if not _is_ltx2_model(model): + return 0 + + # Third-party-license notice: emit once per LTX-2 model detection, + # matching the pattern used by modelopt's quantization and kernel LTX-2 + # plugins. The wrapper touches ``ltx_core`` (RoPE) at forward time, so + # users must comply with the LTX Community License Agreement. + warn_rank_0(_LTX2_LICENSE_WARNING, UserWarning, stacklevel=2) + + registered_types = set() + num_modules = 0 + + for name, module in model.named_modules(): + if not _is_ltx2_attention_module(module, name): + continue + + num_modules += 1 + module_type = type(module) + + if module_type in registered_types: + continue + + if module_type in SparseAttentionRegistry: + logger.debug(f"Unregistering generic wrapper for {module_type.__name__}") + SparseAttentionRegistry.unregister(module_type) + + SparseAttentionRegistry.register({module_type: module_type.__name__})(_LTX2SparseAttention) + registered_types.add(module_type) + logger.info(f"Registered LTX-2 attention: {module_type.__name__}") + + if num_modules > 0: + logger.info(f"Found {num_modules} LTX-2 Attention modules in model") + + # Weakref avoids the circular-submodule problem (nn.Module.__setattr__ + # would otherwise register the root model as a submodule of every + # attention, causing infinite recursion in named_children()). + root_ref = weakref.ref(model) + for _, module in model.named_modules(): + if _is_ltx2_attention_module(module): + object.__setattr__(module, "_vsa_root_model_ref", root_ref) + + model.register_forward_pre_hook(_extract_video_shape_hook) + logger.debug("Registered VSA video_shape extraction hook on model") + + return len(registered_types) + + +def register_ltx2_on_the_fly(model: nn.Module) -> bool: + """Plugin entry point: wire up LTX-2 VSA if this is an LTX-2 model.""" + num_registered = register_ltx2_attention(model) + if num_registered > 0: + logger.info(f"Registered {num_registered} LTX-2 attention types for VSA") + return True + return False + + +# Idempotent: plugins/__init__.py stores plugins in a set so re-imports are safe. +CUSTOM_MODEL_PLUGINS.add(register_ltx2_on_the_fly) diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py b/modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py new file mode 100644 index 0000000000..f343f58e63 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Plugin for Wan 2.2 video diffusion models with VSA support. + +Wan 2.2 (``WanTransformer3DModel`` from diffusers) uses standard diffusers +``Attention`` modules whose ``AttnProcessor2_0`` calls +``F.scaled_dot_product_attention``. VSA's default SDPA patch in +``SparseAttentionModule._forward_with_vsa_sdpa_patch`` therefore intercepts +the right call — we only need to tell VSA the post-patchify ``(T, H, W)``. 
+ +This plugin installs a forward pre-hook on every ``WanTransformer3DModel`` +that: + +1. Reads ``hidden_states`` shape ``(B, C, T, H, W)`` from the transformer + input. +2. Divides by ``model.config.patch_size = (p_t, p_h, p_w)`` — same + computation diffusers does internally (see + ``WanTransformer3DModel.forward``: ``post_patch_num_frames = num_frames // p_t`` + etc.). +3. Propagates the resulting shape to every ``SparseAttentionModule`` in + the transformer whose method is VSA, via ``method.set_video_shape()``. + +Self-attention layers (``attn1``) then see a valid ``video_shape`` when the +SDPA patch fires. Cross-attention (``attn2``) is skipped by VSA's +``can_apply_vsa`` guard since Q/K lengths differ. +""" + +import logging + +import torch.nn as nn + +from ..sparse_attention import SparseAttentionModule +from . import CUSTOM_MODEL_PLUGINS + +logger = logging.getLogger(__name__) + + +def _is_wan22_model(model: nn.Module) -> bool: + """Detect a Wan 2.2 transformer by class name. + + Wan 2.1 / 2.2 both use ``WanTransformer3DModel`` in diffusers — matching + by name keeps the plugin decoupled from the diffusers import. + """ + if type(model).__name__ == "WanTransformer3DModel": + return True + return any(type(m).__name__ == "WanTransformer3DModel" for m in model.modules()) + + +def _find_wan22_transformers(model: nn.Module) -> list[nn.Module]: + """Return every ``WanTransformer3DModel`` reachable from ``model``. + + The 14B model is a ``WanPipeline`` with ``transformer`` and + ``transformer_2``, so we return every match. + """ + if type(model).__name__ == "WanTransformer3DModel": + return [model] + return [m for m in model.modules() if type(m).__name__ == "WanTransformer3DModel"] + + +def _get_patch_size(transformer: nn.Module) -> tuple[int, int, int] | None: + """Read ``patch_size`` from the transformer's config.""" + config = getattr(transformer, "config", None) + if config is None: + return None + patch_size = getattr(config, "patch_size", None) + if patch_size is None: + return None + try: + p_t, p_h, p_w = patch_size + except (TypeError, ValueError): + return None + return (int(p_t), int(p_h), int(p_w)) + + +def _extract_hidden_states(args: tuple, kwargs: dict): + """Pick out the ``hidden_states`` argument regardless of call style.""" + if "hidden_states" in kwargs: + return kwargs["hidden_states"] + return args[0] if len(args) > 0 else None + + +def _make_wan22_video_shape_hook(transformer: nn.Module): + """Create the per-transformer forward pre-hook. + + Closes over the specific ``transformer`` so it can walk its own + submodules, independent of other Wan 2.2 transformers in the same + pipeline. + """ + patch_size = _get_patch_size(transformer) + if patch_size is None: + logger.debug("Wan 2.2 transformer has no config.patch_size; hook inert") + + def _noop(module, args, kwargs): + return None + + return _noop + + p_t, p_h, p_w = patch_size + + def _hook(module: nn.Module, args: tuple, kwargs: dict) -> None: + hidden_states = _extract_hidden_states(args, kwargs) + if hidden_states is None or hidden_states.ndim != 5: + return + + _, _, num_frames, height, width = hidden_states.shape + video_shape = (num_frames // p_t, height // p_h, width // p_w) + if any(d <= 0 for d in video_shape): + logger.debug( + f"Wan 2.2 VSA hook: invalid video_shape {video_shape} for " + f"input {(num_frames, height, width)} / patch {patch_size}; skipping" + ) + return + + # Also expose on the transformer for debugging / external inspection. 
+ module._vsa_video_shape = video_shape
+
+ # Propagate to every VSA method instance in this transformer.
+ for sub in module.modules():
+ if not isinstance(sub, SparseAttentionModule):
+ continue
+ method = getattr(sub, "_sparse_method_instance", None)
+ if method is None:
+ continue
+ if getattr(method, "name", None) != "vsa":
+ continue
+ method.set_video_shape(video_shape)
+
+ return _hook
+
+
+def register_wan22_vsa(model: nn.Module) -> int:
+ """Install a VSA ``video_shape`` pre-hook on every Wan 2.2 transformer.
+
+ Idempotent: ``plugins/__init__.py`` stores plugin callbacks in a set, so
+ this function may be invoked more than once (e.g. on repeated
+ ``mtsa.sparsify`` calls). Each transformer is tagged with
+ ``_vsa_hook_registered`` and skipped on later calls, so the hook is never installed twice.
+ """
+ transformers = _find_wan22_transformers(model)
+ if not transformers:
+ return 0
+
+ registered = 0
+ for transformer in transformers:
+ if getattr(transformer, "_vsa_hook_registered", False):
+ continue
+ hook = _make_wan22_video_shape_hook(transformer)
+ transformer.register_forward_pre_hook(hook, with_kwargs=True)
+ transformer._vsa_hook_registered = True
+ registered += 1
+ logger.info(f"Registered Wan 2.2 VSA video_shape hook on {type(transformer).__name__}")
+
+ return registered
+
+
+def register_wan22_on_the_fly(model: nn.Module) -> bool:
+ """Plugin entry point: install the Wan 2.2 VSA hook if applicable."""
+ if not _is_wan22_model(model):
+ return False
+ num = register_wan22_vsa(model)
+ if num > 0:
+ logger.info(f"Installed VSA video_shape hook on {num} Wan 2.2 transformer(s)")
+ return True
+ return False
+
+
+CUSTOM_MODEL_PLUGINS.add(register_wan22_on_the_fly)

From 8b9aed1c8ce718385be8f1ab7e90f03985e05fdd Mon Sep 17 00:00:00 2001
From: Jingyu Xin
Date: Wed, 22 Apr 2026 20:34:58 +0000
Subject: [PATCH 7/9] Update the changelog

Signed-off-by: Jingyu Xin
---
 CHANGELOG.rst | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index e1a5aacca1..4e14f50325 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,6 +1,13 @@
 Changelog
 =========
 
+0.45 (2026-06-xx)
+^^^^^^^^^^^^^^^^^
+
+**Backward Breaking Changes**
+
+- Reorganize custom CUDA / Triton kernels under ``modelopt.torch.kernels`` into ``common/attention``, ``quantization/{conv,gemm}``, and ``sparsity/attention``. Direct imports from the old paths (``quantization.conv_gemm``, ``quantization.src``, ``sparsity.attention_sparsity.kernels``, flat ``kernels.triton_fa`` / ``kernels.hf_triton_attention``) must be updated; high-level APIs are unchanged.
+
 0.44 (2026-05-xx)
 ^^^^^^^^^^^^^^^^^
 

From 852572f0e41b0660659d40765d94e66e8f7b4372 Mon Sep 17 00:00:00 2001
From: Jingyu Xin
Date: Wed, 22 Apr 2026 20:43:30 +0000
Subject: [PATCH 8/9] Fixed the bug

Signed-off-by: Jingyu Xin
---
 .../kernels/sparsity/attention/calibrate.py | 34 +++++++++++++++++--
 1 file changed, 32 insertions(+), 2 deletions(-)

diff --git a/modelopt/torch/kernels/sparsity/attention/calibrate.py b/modelopt/torch/kernels/sparsity/attention/calibrate.py
index 8e7ef144c7..37c5fccd6b 100644
--- a/modelopt/torch/kernels/sparsity/attention/calibrate.py
+++ b/modelopt/torch/kernels/sparsity/attention/calibrate.py
@@ -155,8 +155,11 @@ def _attn_fwd_calibrate(
 row_max = m_new
 
 # --- Write per-program counters (no atomics, just stores) ---
- # Compute unique flat program index for this (batch, head, q_tile)
- num_q_tiles = tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M) # conservative upper bound
+ # Compute unique flat program index for this (batch, head, q_tile). 
+ # Use tl.num_programs(2) (grid z dim = cdiv(max_input_len, BLOCK_M)) so the + # stride matches the wrapper's buffer layout for any batch order. Loading + # b_seq_len[0] would collide with later batches when batch 0 is shorter. + num_q_tiles = tl.num_programs(2) num_heads = tl.num_programs(1) prog_idx = batch_idx * num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q base = prog_idx * NUM_THRESHOLDS @@ -214,6 +217,33 @@ def attention_calibrate( if threshold_trials is None or len(threshold_trials) == 0: raise ValueError("threshold_trials must be a non-empty list") + # Calibration has only been validated with uniform-length batches (current + # diffusion + RULER paths). Varlen inputs would exercise code paths in the + # kernel that have not been tested — fail loudly rather than silently + # produce wrong sparsity counts. + if b_seq_len.numel() > 1 and not torch.all(b_seq_len == b_seq_len[0]).item(): + raise NotImplementedError( + "attention_calibrate currently supports only uniform-length batches. " + f"Got b_seq_len={b_seq_len.tolist()}. Varlen calibration is untested — " + "validate the kernel against a reference before removing this guard." + ) + if int(b_seq_len[0].item()) != max_input_len: + raise ValueError( + "attention_calibrate expects max_input_len to equal b_seq_len[0] " + f"(uniform batching). Got max_input_len={max_input_len}, " + f"b_seq_len[0]={int(b_seq_len[0].item())}." + ) + if ( + b_seq_len_k is not None + and b_seq_len_k.data_ptr() != b_seq_len.data_ptr() + and b_seq_len_k.numel() > 1 + and not torch.all(b_seq_len_k == b_seq_len_k[0]).item() + ): + raise NotImplementedError( + "attention_calibrate currently supports only uniform-length batches. " + f"Got b_seq_len_k={b_seq_len_k.tolist()}. Varlen calibration is untested." + ) + HEAD_DIM = q.shape[2] num_q_heads = q.shape[1] num_kv_heads = k.shape[1] From b6ad34005595cbc42a48c38f0a9d96856bb6903c Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 23 Apr 2026 00:10:31 +0000 Subject: [PATCH 9/9] Removed the ltx2 example Signed-off-by: Jingyu Xin --- examples/diffusers/sparsity/ltx2_vsa.py | 276 ------------------------ 1 file changed, 276 deletions(-) delete mode 100644 examples/diffusers/sparsity/ltx2_vsa.py diff --git a/examples/diffusers/sparsity/ltx2_vsa.py b/examples/diffusers/sparsity/ltx2_vsa.py deleted file mode 100644 index 87db1faf01..0000000000 --- a/examples/diffusers/sparsity/ltx2_vsa.py +++ /dev/null @@ -1,276 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""LTX-2 inference with Video Sparse Attention (VSA). - -Applies VSA to LTX-2's self-attention modules. VSA is calibration-free — -sparsity is controlled via ``top_k_ratio`` (fraction of 3D blocks kept in -the sparse branch). - -The LTX-2 plugin under ``modelopt.torch.sparsity.attention_sparsity.plugins.ltx2`` -handles the specifics: - -- Detects ``LTXSelfAttention`` modules by class name. 
-- Computes ``(T, H, W)`` from ``Modality.positions`` at each forward. -- Wraps each attention module in ``_LTX2SparseAttention``, which computes - Q/K/V, RoPE, and the optional (zero-initialised, trainable) - ``gate_compress`` before calling ``VSA.forward_attention``. - -Requirements: -- ``fastvideo_kernel`` (Triton VSA kernel). -- ``ltx_core``, ``ltx_trainer``, ``ltx_pipelines`` (third-party LTX-2 packages - from Lightricks — see the LICENSE notice in the top-level sparsity README). - -Example:: - - # VSA at 50% top-K ratio - python ltx2_vsa.py --checkpoint path/to/model.safetensors \\ - --text-encoder-path path/to/gemma --top-k-ratio 0.5 \\ - --prompt "A cat playing piano" --output vsa.mp4 - - # Baseline (no VSA) - python ltx2_vsa.py --checkpoint path/to/model.safetensors \\ - --text-encoder-path path/to/gemma --no-vsa --output baseline.mp4 -""" - -import argparse -import copy -import time -from pathlib import Path - -import torch - -import modelopt.torch.sparsity.attention_sparsity as mtsa -from modelopt.torch.sparsity.attention_sparsity.config import VSA_DEFAULT - -# LTX-2 is optional; import lazily so --help works even without it. -try: - from ltx_trainer.model_loader import load_model - from ltx_trainer.progress import StandaloneSamplingProgress - from ltx_trainer.validation_sampler import GenerationConfig, ValidationSampler - from ltx_trainer.video_utils import save_video - - _LTX_AVAILABLE = True -except ImportError as _exc: - _LTX_IMPORT_ERROR = _exc - _LTX_AVAILABLE = False - - -# LTX-2 uses a 1:8192 pixels-to-tokens compression ratio -LTX2_PIXEL_TO_TOKEN_RATIO = 8192 - -# VSA 3D block size: 4x4x4 = 64 tokens per block -VSA_BLOCK_ELEMENTS = 64 - - -def calculate_expected_tokens(num_frames: int, height: int, width: int) -> int: - return num_frames * height * width // LTX2_PIXEL_TO_TOKEN_RATIO - - -def is_vsa_compatible(num_frames: int, height: int, width: int) -> tuple[bool, str]: - """Check whether the requested input size is large enough for VSA to help.""" - tokens = calculate_expected_tokens(num_frames, height, width) - tiles = tokens // VSA_BLOCK_ELEMENTS - if tiles >= 90: - return True, f"Excellent: {tokens} tokens ({tiles} tiles)" - if tiles >= 16: - return True, f"Marginal: {tokens} tokens ({tiles} tiles)" - return False, f"Too small: {tokens} tokens ({tiles} tiles, need 16+ for VSA)" - - -def apply_vsa( - transformer: torch.nn.Module, - num_frames: int, - height: int, - width: int, - top_k_ratio: float, -) -> torch.nn.Module: - """Apply VSA to the LTX-2 transformer.""" - compatible, reason = is_vsa_compatible(num_frames, height, width) - print(f" VSA compatibility: {reason}") - if not compatible: - print(" [WARNING] Input size may be too small for VSA to help.") - - config = copy.deepcopy(VSA_DEFAULT) - # Override top_k_ratio on the attention pattern - for cfg in config["sparse_cfg"].values(): - if isinstance(cfg, dict) and cfg.get("method") == "vsa": - cfg["top_k_ratio"] = top_k_ratio - - print(" Applying VSA to attention modules...") - return mtsa.sparsify(transformer, config) - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="LTX-2 video generation with Video Sparse Attention (VSA)", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument("--checkpoint", type=str, required=True, help="LTX-2 model checkpoint") - parser.add_argument( - "--text-encoder-path", type=str, required=True, help="Gemma text encoder directory" - ) - parser.add_argument( - "--prompt", - type=str, - default="A serene mountain 
landscape with a flowing river, golden hour lighting", - ) - parser.add_argument("--negative-prompt", type=str, default="") - parser.add_argument("--height", type=int, default=512, help="Video height (multiple of 32)") - parser.add_argument("--width", type=int, default=768, help="Video width (multiple of 32)") - parser.add_argument("--num-frames", type=int, default=121, help="Must be k*8 + 1") - parser.add_argument("--frame-rate", type=float, default=25.0) - parser.add_argument("--num-inference-steps", type=int, default=30) - parser.add_argument("--guidance-scale", type=float, default=4.0, help="CFG scale") - parser.add_argument("--seed", type=int, default=42) - - parser.add_argument( - "--no-vsa", - action="store_true", - help="Disable VSA (baseline run, for timing comparison)", - ) - parser.add_argument( - "--top-k-ratio", - type=float, - default=0.5, - help="VSA sparsity ratio (0.5 ⇒ 50%% sparsity, 0.3 ⇒ 70%%)", - ) - - parser.add_argument("--skip-audio", action="store_true", help="Skip audio generation") - parser.add_argument("--output", type=str, default="output_vsa.mp4") - parser.add_argument("--device", type=str, default="cuda") - return parser.parse_args() - - -def run_generation( - sampler, - config, - device: str, - num_inference_steps: int, - label: str = "", -) -> tuple[torch.Tensor, torch.Tensor | None, float]: - if label: - print(f"\n{label}") - print(f"Generating video ({num_inference_steps} steps)...") - t0 = time.time() - with StandaloneSamplingProgress(num_steps=num_inference_steps) as progress: - sampler.sampling_context = progress - video, audio = sampler.generate(config=config, device=device) - elapsed = time.time() - t0 - print(f"Generation completed in {elapsed:.2f}s") - return video, audio, elapsed - - -def main() -> None: - if not _LTX_AVAILABLE: - raise ImportError( - "LTX-2 packages are required for this example. Install with: " - "pip install ltx-core ltx-trainer ltx-pipelines. 
" - f"(original error: {_LTX_IMPORT_ERROR})" - ) - - args = parse_args() - generate_audio = not args.skip_audio - - print("=" * 72) - print("LTX-2 + VSA") - print("=" * 72) - - tokens = calculate_expected_tokens(args.num_frames, args.height, args.width) - tiles = tokens // VSA_BLOCK_ELEMENTS - _, reason = is_vsa_compatible(args.num_frames, args.height, args.width) - print("\nInput Configuration:") - print(f" Resolution: {args.width}x{args.height}") - print(f" Frames: {args.num_frames} @ {args.frame_rate} fps") - print(f" Tokens: {tokens} ({tiles} tiles)") - print(f" VSA: {reason}") - - print("\nLoading LTX-2 model components...") - components = load_model( - checkpoint_path=args.checkpoint, - device="cpu", - dtype=torch.bfloat16, - with_video_vae_encoder=False, - with_video_vae_decoder=True, - with_audio_vae_decoder=generate_audio, - with_vocoder=generate_audio, - with_text_encoder=True, - text_encoder_path=args.text_encoder_path, - ) - print("Model loaded") - - transformer = components.transformer - - if not args.no_vsa: - transformer = apply_vsa( - transformer, - args.num_frames, - args.height, - args.width, - top_k_ratio=args.top_k_ratio, - ) - components.transformer = transformer - - gen_config = GenerationConfig( - prompt=args.prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, - frame_rate=args.frame_rate, - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - seed=args.seed, - condition_image=None, - reference_video=None, - generate_audio=generate_audio, - include_reference_in_output=False, - ) - - sampler = ValidationSampler( - transformer=components.transformer, - vae_decoder=components.video_vae_decoder, - vae_encoder=components.video_vae_encoder, - text_encoder=components.text_encoder, - audio_decoder=components.audio_vae_decoder if generate_audio else None, - vocoder=components.vocoder if generate_audio else None, - ) - - label = "BASELINE (no VSA)" if args.no_vsa else f"WITH VSA (top_k_ratio={args.top_k_ratio})" - video, audio, elapsed = run_generation( - sampler, gen_config, args.device, args.num_inference_steps, label=label - ) - - out_path = Path(args.output) - out_path.parent.mkdir(parents=True, exist_ok=True) - audio_sample_rate = None - if audio is not None and components.vocoder is not None: - audio_sample_rate = components.vocoder.output_sample_rate - save_video( - video_tensor=video, - output_path=out_path, - fps=args.frame_rate, - audio=audio, - audio_sample_rate=audio_sample_rate, - ) - print(f"Saved: {args.output}") - - print("\n" + "=" * 72) - print(f"Done in {elapsed:.2f}s") - print("=" * 72) - - -if __name__ == "__main__": - main()