diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index 0d7876149..9c79fe703 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -425,6 +425,27 @@ def is_mx_format(self):
             and self.block_sizes.get("scale_bits", None) == (8, 0)
         )
 
+    @property
+    def is_mxfp4(self):
+        """Check if is MXFP4."""
+        return (
+            self.is_mx_format and self.num_bits == (2, 1) and self.block_sizes.get(-1, None) == 32
+        )
+
+    @property
+    def is_mxfp6(self):
+        """Check if is MXFP6."""
+        return (
+            self.is_mx_format and self.num_bits == (3, 2) and self.block_sizes.get(-1, None) == 32
+        )
+
+    @property
+    def is_mxfp8(self):
+        """Check if is MXFP8."""
+        return (
+            self.is_mx_format and self.num_bits == (4, 3) and self.block_sizes.get(-1, None) == 32
+        )
+
     @property
     def is_static_block_quant(self):
         """Check if is static block quantization."""
diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index a3fa6ef1a..5150a879d 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -28,6 +28,13 @@
 except ImportError:
     Shard = None
 
+try:
+    import kitchen
+    from kitchen.fa import KitchenFlashAttentionModule
+    from kitchen.triton_module import triton_fa_params
+except ImportError:
+    kitchen = None
+
 import torch.nn as nn
 import transformers
 from transformers.models.t5.modeling_t5 import T5Attention
@@ -56,17 +63,94 @@ def _setup(self):
         self.q_bmm_quantizer = TensorQuantizer()
         self.k_bmm_quantizer = TensorQuantizer()
         self.v_bmm_quantizer = TensorQuantizer()
+        self.softmax_quantizer = TensorQuantizer()
+        self.kitchen_attn_fn = None
+        self.use_kitchen = False
+
+    def _init_kitchen_attn_fn(self):
+        if not self.softmax_quantizer.is_enabled:
+            self.kitchen_attn_fn = "disabled"
+            return
+        self.use_kitchen = True
+        if self.softmax_quantizer.is_mxfp8:
+            qfa_params = triton_fa_params.QTritonFAParams(
+                backend="triton",
+                qk_dot_precisions="bf16@bf16",
+                pv_dot_precisions="mxfp8_e4m3_emulation@bf16",
+                dp_v_x_do_dot_precisions="bf16@bf16",
+                dp_do_x_v_dot_precisions="bf16@bf16",
+                dq_ds_x_k_dot_precisions="bf16@bf16",
+                dk_ds_x_q_dot_precisions="bf16@bf16",
+                dv_p_x_do_dot_precisions="bf16@bf16",
+                use_natural_transcendental_func=False,  # Differs from the default.
+            )
+        else:
+            raise NotImplementedError(f"softmax_quantizer not supported: {self.softmax_quantizer}")
+
+        self.kitchen_attn_fn = KitchenFlashAttentionModule(
+            num_attention_heads=self.config.num_attention_heads,
+            kv_channels=self.config.head_dim,
+            num_gqa_groups=None,  # Would be self.config.num_key_value_heads, but kitchen does not support GQA.
+            attention_dropout=self.config.attention_dropout,
+            qkv_format="sbhd",  # Unused here, but "sbhd" is the only format forward() supports.
+            attn_mask_type="causal",
+            window_size=getattr(self.config, "sliding_window", None),
+            sequence_parallel=False,
+            get_rng_state_tracker=None,
+            layer_number=None,
+            attention_type="self",
+            softmax_scale=None,  # Converted to the same default as SDPA: 1 / sqrt(dim_q).
+            qfa_params=qfa_params,
+        )
 
     @staticmethod
     def _quantized_attention(
-        original_attention_interface, self, query_states, key_states, value_states, *args, **kwargs
+        original_attention_interface,
+        self,
+        query_states,
+        key_states,
+        value_states,
+        *args,
+        **kwargs,
     ):
+        if kitchen is not None and self.kitchen_attn_fn is None:
+            self._init_kitchen_attn_fn()
+
         query_states = self.q_bmm_quantizer(query_states)
         key_states = self.k_bmm_quantizer(key_states)
         value_states = self.v_bmm_quantizer(value_states)
-        return original_attention_interface(
-            self, query_states, key_states, value_states, *args, **kwargs
-        )
+        if not self.use_kitchen:
+            return original_attention_interface(
+                self, query_states, key_states, value_states, *args, **kwargs
+            )
+
+        query_sequence_length = query_states.shape[2]
+        if query_states.shape[2] < key_states.shape[2]:  # Pad queries for the decoding stage.
+            shape = list(query_states.shape)
+            shape[2] = key_states.shape[2] - query_states.shape[2]
+            query_states = torch.cat(
+                [
+                    torch.empty(shape, dtype=query_states.dtype, device=query_states.device),
+                    query_states,
+                ],
+                dim=2,
+            )
+
+        n_repeat = self.config.num_attention_heads // self.config.num_key_value_heads
+        if n_repeat > 1:
+            key_states = key_states.repeat_interleave(n_repeat, dim=1)
+            value_states = value_states.repeat_interleave(n_repeat, dim=1)
+        # kitchen only supports sbhd; we have bhsd.
+        query_states = query_states.permute(2, 0, 1, 3)
+        key_states = key_states.permute(2, 0, 1, 3)
+        value_states = value_states.permute(2, 0, 1, 3)
+        attn_out = self.kitchen_attn_fn(query_states, key_states, value_states)
+        attn_out = attn_out[-query_sequence_length:, :, :]
+        # Output is sb(h*d); we need bshd.
+        attn_out = attn_out.reshape(
+            (attn_out.shape[0], attn_out.shape[1], query_states.shape[2], -1)
+        ).permute(1, 0, 2, 3)
+        return attn_out.contiguous(), None
 
     def forward(self, *args, **kwargs):
         """Forward method for KV cache quantization compatible with new_attention_interface in transformers >= 4.48.0.
diff --git a/tests/unit/torch/quantization/plugins/test_attention_quant.py b/tests/unit/torch/quantization/plugins/test_attention_quant.py
index 832bc8a2e..9526f80ac 100644
--- a/tests/unit/torch/quantization/plugins/test_attention_quant.py
+++ b/tests/unit/torch/quantization/plugins/test_attention_quant.py
@@ -13,11 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
+
 import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from _test_utils.torch.transformers_models import get_tiny_bert, get_tiny_llama, get_tiny_t5
+from transformers import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaAttention
+
+try:
+    import kitchen
+except ImportError:
+    kitchen = None
 
 import modelopt.torch.quantization as mtq
 from modelopt.torch.quantization.plugins.huggingface import _QuantAttention
@@ -54,6 +63,7 @@ def forward(self, hidden_states, **kwargs):
 kv_cache_config = {
     "quant_cfg": {
         "*[kv]_bmm_quantizer": {"num_bits": 4, "enable": True},
+        "*softmax_quantizer": {"enable": False},
     },
     "algorithm": "max",
 }
@@ -147,3 +157,77 @@ def test_kv_quant_bert():
     assert output is not None
     assert output.start_logits is not None
     assert output.end_logits is not None
+
+
+@pytest.mark.skipif(kitchen is None, reason="kitchen is not installed.")
+def test_kitchen_fa():
+    batch_size = 2
+    num_q_heads = 4
+    num_kv_heads = 2
+    seqlen = 8
+    hidden_size = 128
+
+    config = LlamaConfig(
+        hidden_size=hidden_size,
+        num_attention_heads=num_q_heads,
+        num_key_value_heads=num_kv_heads,
+    )
+    original_attention = LlamaAttention(config, layer_idx=0)
+
+    q_states = torch.randn(
+        batch_size, num_q_heads, seqlen, hidden_size, dtype=torch.bfloat16, device="cuda"
+    )
+    k_states = torch.randn(
+        batch_size, num_kv_heads, seqlen, hidden_size, dtype=torch.bfloat16, device="cuda"
+    )
+    v_states = torch.randn(
+        batch_size, num_kv_heads, seqlen, hidden_size, dtype=torch.bfloat16, device="cuda"
+    )
+
+    # Convert it to _QuantAttention using the convert() class method
+    quant_attention = _QuantAttention.convert(original_attention)
+    quant_attention.config._attn_implementation = "sdpa"
+    assert hasattr(quant_attention, "q_bmm_quantizer")
+    assert hasattr(quant_attention, "k_bmm_quantizer")
+    assert hasattr(quant_attention, "v_bmm_quantizer")
+    assert hasattr(quant_attention, "softmax_quantizer")
+    quant_attention.softmax_quantizer.disable()
+    module = inspect.getmodule(quant_attention.get_attn_type(quant_attention))
+    orig_attn_fn = module.ALL_ATTENTION_FUNCTIONS["sdpa"]
+
+    output = quant_attention._quantized_attention(
+        orig_attn_fn,
+        quant_attention,
+        q_states,
+        k_states,
+        v_states,
+        attention_mask=None,
+    )
+    expected = output[0]
+
+    config = LlamaConfig(
+        hidden_size=hidden_size,
+        num_attention_heads=num_q_heads,
+        num_key_value_heads=num_kv_heads,
+    )
+    original_attention = LlamaAttention(config, layer_idx=0)
+    quant_attention = _QuantAttention.convert(original_attention)
+    quant_attention.config._attn_implementation = "sdpa"
+    quant_attention.softmax_quantizer.num_bits = (4, 3)
+    quant_attention.softmax_quantizer.block_sizes = {
+        -1: 32,
+        "type": "dynamic",
+        "scale_bits": (8, 0),
+    }
+    output = quant_attention._quantized_attention(
+        None,
+        quant_attention,
+        q_states,
+        k_states,
+        v_states,
+        attention_mask=None,
+    )
+    diff = (expected - output[0]).abs()
+    assert torch.allclose(expected, output[0], atol=0.75, rtol=0.75), (
+        f"{diff.max().item(), diff.mean().item(), diff.std().item()}"
+    )
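
For reviewers who want to exercise the new path end to end: below is a minimal, hypothetical sketch (not part of this patch) of a quantization config that would make `softmax_quantizer.is_mxfp8` evaluate to True and route `_quantized_attention` through the kitchen flash-attention module. The attribute values mirror what `test_kitchen_fa` sets directly on the quantizer, and the wildcard-pattern layout mirrors `kv_cache_config` above; the config name and the assumption that `mtq.quantize` plumbs a `"*softmax_quantizer"` pattern through to attention modules are illustrative, not verified by this diff.

```python
import modelopt.torch.quantization as mtq

# Hypothetical config sketch: FP8 E4M3 elements (num_bits=(4, 3)) in blocks of 32
# along the last dimension, with dynamic E8M0 scales (scale_bits=(8, 0)). This is
# the combination the new `is_mxfp8` property checks for.
MXFP8_SOFTMAX_CFG = {
    "quant_cfg": {
        "*softmax_quantizer": {
            "num_bits": (4, 3),
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
            "enable": True,
        },
        # The q/k/v bmm quantizers can be configured separately or left disabled.
        "*[qkv]_bmm_quantizer": {"enable": False},
    },
    "algorithm": "max",
}

# Usage would follow the same pattern as the existing KV-cache tests, e.g.:
# model = mtq.quantize(model, MXFP8_SOFTMAX_CFG, forward_loop=calibrate_fn)
```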