diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 815e367f71..acb4e49d3e 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -3,6 +3,7 @@ # See LICENSE for license information. import argparse +import os import torch import torch.utils.benchmark as benchmark import pandas as pd @@ -185,6 +186,8 @@ def run_benchmark_linear( x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True) ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)] m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided + if bool(int(os.getenv("NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM", "1"))): + m_splits = torch.tensor(m_splits, dtype=torch.int64, device=device) # Bias is not supported for GroupedLinear benchmark bias = None diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index c35dc4c063..422d6c2f6b 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -29,6 +29,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_P python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_custom_recipe.xml $TE_PATH/tests/pytorch/test_custom_recipe.py || test_fail "test_custom_recipe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_grouped_linear.xml $TE_PATH/tests/pytorch/test_grouped_linear.py || test_fail "test_grouped_linear.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py new file mode 100644 index 0000000000..2f9b35beeb --- /dev/null +++ b/tests/pytorch/test_grouped_linear.py @@ -0,0 +1,1699 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import random +from typing import Dict, List, Optional + +import pytest +import torch +import torch.nn as nn +from torch.nn import Parameter + +import transformer_engine.pytorch as te +from transformer_engine.common import recipe +from transformer_engine.pytorch import ( + Float8Quantizer, + Fp8Padding, + Fp8Unpadding, + GroupedLinear, + Linear, + MXFP8Quantizer, + autocast, + is_bf16_available, + quantized_model_init, +) +from transformer_engine.pytorch.cpp_extensions import ( + general_gemm, + general_grouped_gemm, + general_grouped_gemm_for_grouped_tensor, +) +from transformer_engine.pytorch.quantization import ( + FP8GlobalStateManager, + get_align_size_for_quantization, +) +from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor +import transformer_engine_torch as tex +from utils import ModelConfig, reset_rng_states + +# Only run FP8 tests on supported devices. +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, _ = te.is_nvfp4_available(return_reason=True) + +seed = 1234 +reset_rng_states() + +NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0")) + +if NVTE_TEST_NVINSPECT_ENABLED: + import nvdlfw_inspect.api as debug_api + + debug_api.initialize( + os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"], + feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"], + ) + + +model_configs = { + "126m": ModelConfig(1, 2048, 12, 64, num_layers=12), +} + + +def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + +def check_rht_usage(recipe: recipe.Recipe) -> bool: + if recipe.nvfp4(): + if ( + recipe.fp4_quant_fwd_inp.random_hadamard_transform + or recipe.fp4_quant_fwd_weight.random_hadamard_transform + or recipe.fp4_quant_bwd_grad.random_hadamard_transform + ): + return True + return False + + +def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> bool: + supported_input_dtypes = [] + if recipe.nvfp4(): + supported_input_dtypes.append(torch.bfloat16) + if not check_rht_usage(recipe): + supported_input_dtypes.append(torch.float32) + return supported_input_dtypes + + +def dtype_tols(dtype: torch.dtype) -> Dict[str, float]: + if dtype == torch.float32: + return dict(rtol=1.3e-6, atol=1e-5) + if dtype == torch.float16: + return dict(rtol=1e-3, atol=1e-5) + if dtype == torch.bfloat16: + return dict(rtol=1.6e-2, atol=1e-5) + raise ValueError(f"Unsupported dtype ({dtype})") + + +param_types = [torch.float32, torch.float16] +if is_bf16_available(): + param_types.append(torch.bfloat16) + +batch_sizes = [1, 2] +all_boolean = [True, False] + +fp8_recipes = [] +if mxfp8_available: + fp8_recipes.append(recipe.MXFP8BlockScaling()) +if fp8_block_scaling_available: + fp8_recipes.append(recipe.Float8BlockScaling()) +if fp8_available: + fp8_recipes.append(recipe.Float8CurrentScaling()) + fp8_recipes.append(recipe.DelayedScaling()) +if nvfp4_available: + fp8_recipes.append(nvfp4_rht_and_2d_quantization()) + +use_cutlass_grouped_gemm = [False] +if torch.cuda.get_device_capability() == (9, 0): + use_cutlass_grouped_gemm.append(True) + + +class TorchGroupedLinearWithPadding(nn.Module): + + def __init__( + self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8 + ) -> None: + super().__init__() + + self.padding = Fp8Padding(num_gemms) + self.linear_fn = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=bias, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + device="cuda", + ) + self.unpadding = Fp8Unpadding(num_gemms) + + self.fp8 = fp8 + + def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor: + if self.fp8: + orig_m_splits = m_splits + inp, m_splits = self.padding(inp, m_splits) + + out = self.linear_fn(inp, m_splits) + + if self.fp8: + out = self.unpadding(out, orig_m_splits) + + return out + + +def _test_grouped_linear_accuracy( + block, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute=False, +): + reset_rng_states() + if fp8: + FP8GlobalStateManager.reset() + + inp_hidden_states = torch.randn( + (config.max_seqlen_q, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + inp_hidden_states.retain_grad() + + if num_gemms > 1: + split_size = 1 + if fp8: + split_size = get_align_size_for_quantization(recipe) + m = config.max_seqlen_q // split_size + dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() + dist.append(dist[-1]) # Manually add a zero + m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) + m_splits = m_splits * split_size + assert m_splits.sum() == config.max_seqlen_q and len(m_splits) == num_gemms + else: + m_splits = torch.tensor([config.max_seqlen_q]) + + with autocast(enabled=fp8, recipe=recipe): + if isinstance(block, GroupedLinear): + m_splits = m_splits * bs + out = block(inp_hidden_states, m_splits.tolist()) + else: + out = torch.cat( + [ + block[i](inp) + for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist())) + ] + ) + loss = out.sum() + loss.backward() + if delay_wgrad_compute: + if isinstance(block, GroupedLinear): + block.backward_dw() + else: + for i in range(num_gemms): + block[i].backward_dw() + + torch.cuda.synchronize() + outputs = [out, inp_hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + if getattr(p, "main_grad", None) is not None: + outputs.append(p.main_grad) + assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True + else: + outputs.append(p.grad) + return outputs + + +@pytest.mark.parametrize("dtype", param_types, ids=str) +@pytest.mark.parametrize("num_gemms", [3, 6]) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("fp8_model_params", all_boolean) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +@pytest.mark.parametrize("bias", all_boolean) +@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) +def test_grouped_linear_accuracy( + dtype, + num_gemms, + bs, + model, + recipe, + fp8_model_params, + fuse_wgrad_accumulation, + bias, + delay_wgrad_compute, + parallel_mode=None, + use_cutlass=False, +): + fp8 = recipe is not None + if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") + if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: + pytest.skip("Delayed wgrad compute is not supported in debug mode.") + + config = model_configs[model] + if config.max_seqlen_q % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") + + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + grouped_linear = GroupedLinear( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + delay_wgrad_compute=delay_wgrad_compute, + save_original_input=False, + ).eval() + sequential_linear = torch.nn.ModuleList( + [ + Linear( + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + for _ in range(num_gemms) + ] + ) + + # Share params + with torch.no_grad(): + for i in range(num_gemms): + sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) + if bias: + sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + if fuse_wgrad_accumulation: + weight_i = getattr(grouped_linear, f"weight{i}") + weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) + sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() + + outputs_ref = _test_grouped_linear_accuracy( + sequential_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, + ) + outputs = _test_grouped_linear_accuracy( + grouped_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, + ) + + for o, o_ref in zip(outputs, outputs_ref): + if use_cutlass: + torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + else: + # cuBLAS implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + +@pytest.mark.skipif( + torch.cuda.get_device_capability() != (9, 0), + reason="Only enable CUTLASS grouped gemm on Hopper", +) +@pytest.mark.parametrize("dtype", param_types, ids=str) +@pytest.mark.parametrize("num_gemms", [3, 6]) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) +def test_grouped_linear_accuracy_cutlass( + dtype, + num_gemms, + bs, + model, + fuse_wgrad_accumulation, + delay_wgrad_compute, + monkeypatch, +): + monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1") + test_grouped_linear_accuracy( + dtype, + num_gemms, + bs, + model, + None, + False, + fuse_wgrad_accumulation, + False, + delay_wgrad_compute, + None, + use_cutlass=True, + ) + + +@pytest.mark.parametrize("dtype", param_types, ids=str) +@pytest.mark.parametrize("num_gemms", [3]) +@pytest.mark.parametrize("bs", [1]) +@pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("fp8_model_params", [False]) +@pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) +@pytest.mark.parametrize("bias", [False]) +@pytest.mark.parametrize("delay_wgrad_compute", [True]) +def test_grouped_linear_accuracy_save_original_input( + dtype, + num_gemms, + bs, + model, + recipe, + fp8_model_params, + fuse_wgrad_accumulation, + bias, + delay_wgrad_compute, + parallel_mode=None, +): + fp8 = recipe is not None + if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 and recipe.delayed(): + pytest.skip("DelayedScaling recipe is not supported with save_original_input") + if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: + pytest.skip("Delayed wgrad compute is not supported in debug mode.") + + config = model_configs[model] + if config.max_seqlen_q % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") + + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + grouped_linear = GroupedLinear( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + delay_wgrad_compute=delay_wgrad_compute, + save_original_input=True, + ).eval() + sequential_linear = torch.nn.ModuleList( + [ + Linear( + config.hidden_size, + 4 * config.hidden_size, + bias=bias, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ).eval() + for _ in range(num_gemms) + ] + ) + + # Share params + with torch.no_grad(): + for i in range(num_gemms): + sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) + if bias: + sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + if fuse_wgrad_accumulation: + weight_i = getattr(grouped_linear, f"weight{i}") + weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) + sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() + + outputs_ref = _test_grouped_linear_accuracy( + sequential_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, + ) + outputs = _test_grouped_linear_accuracy( + grouped_linear, + num_gemms, + bs, + dtype, + config, + recipe, + fp8, + fuse_wgrad_accumulation, + delay_wgrad_compute, + ) + + # Should be bit-wise match + for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +def test_grouped_linear_accuracy_single_gemm(recipe): + """Split the tests to save CI time""" + test_grouped_linear_accuracy( + dtype=torch.float32, + num_gemms=1, + bs=2, + model="126m", + recipe=recipe, + fp8_model_params=True, + fuse_wgrad_accumulation=True, + bias=True, + delay_wgrad_compute=False, + ) + + +def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): + + def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): + align_size = get_align_size_for_quantization(recipe) + padded_tokens_per_expert = [ + (num_tokens + align_size - 1) // align_size * align_size + for num_tokens in tokens_per_expert + ] + hidden_states = torch.split(hidden_states, tokens_per_expert) + padded_hidden_states = [] + for hidden_state, actual_num_tokens, padded_num_tokens in zip( + hidden_states, tokens_per_expert, padded_tokens_per_expert + ): + padded_hidden_states.append(hidden_state) + if padded_num_tokens > actual_num_tokens: + pad_tensor = torch.zeros( + padded_num_tokens - actual_num_tokens, + hidden_state.shape[1], + dtype=hidden_state.dtype, + device=hidden_state.device, + ) + padded_hidden_states.append(pad_tensor) + padded_hidden_states = torch.cat(padded_hidden_states, dim=0) + return padded_hidden_states, padded_tokens_per_expert + + def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert): + inputmats = torch.split( + padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert + ) + hidden_states = torch.cat( + [ + grad_output_mat[: actual_tokens_per_expert[i]] + for i, grad_output_mat in enumerate(inputmats) + ], + dim=0, + ) + + return hidden_states + + def _generate_random_numbers(n, total_sum): + if n <= 0: + return [] + + # reset seed + random.seed(seed) + + breaks = sorted(random.sample(range(1, total_sum), n - 1)) + random_numbers = ( + [breaks[0]] + + [breaks[i] - breaks[i - 1] for i in range(1, n - 1)] + + [total_sum - breaks[-1]] + ) + + return random_numbers + + reset_rng_states() + if fp8: + FP8GlobalStateManager.reset() + + inp_hidden_states = torch.randn( + (config.max_seqlen_q * bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + inp_hidden_states.retain_grad() + + m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs) + + with autocast(enabled=fp8, recipe=recipe): + if isinstance(block, TorchGroupedLinearWithPadding): + out = block(inp_hidden_states, m_splits) + else: + if fp8: + padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8( + inp_hidden_states, m_splits + ) + padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits) + out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits) + else: + out = block(inp_hidden_states, m_splits) + + loss = out.sum() + loss.backward() + + torch.cuda.synchronize() + outputs = [out, inp_hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) + return outputs + + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("num_gemms", [3, 6]) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("fp8", [True]) +@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_model_params", all_boolean) +def test_padding_grouped_linear_accuracy( + dtype, + num_gemms, + bs, + model, + fp8, + recipe, + fp8_model_params, + parallel_mode=None, +): + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") + + config = model_configs[model] + if config.max_seqlen_q % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") + + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + grouped_linear = TorchGroupedLinearWithPadding( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + fp8=fp8, + ).eval() + + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + ref_grouped_linear = GroupedLinear( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + save_original_input=False, + ).eval() + + # Share params + with torch.no_grad(): + inner_grouped_linear = grouped_linear.linear_fn + for i in range(num_gemms): + setattr( + ref_grouped_linear, + f"weight{i}", + Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), + ) + + outputs = _test_padding_grouped_linear_accuracy( + grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 + ) + outputs_ref = _test_padding_grouped_linear_accuracy( + ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 + ) + + # Should be bit-wise match + for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("num_gemms", [3]) +@pytest.mark.parametrize("bs", [1]) +@pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("fp8", [True]) +@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_model_params", [False]) +def test_padding_grouped_linear_accuracy_save_original_input( + dtype, + num_gemms, + bs, + model, + fp8, + recipe, + fp8_model_params, + parallel_mode=None, +): + if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 and recipe.delayed(): + pytest.skip("DelayedScaling recipe is not supported with save_original_input") + + config = model_configs[model] + if config.max_seqlen_q % 16 != 0 and fp8: + pytest.skip("FP8 requires sequence length to be divisible by 16.") + + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + grouped_linear = TorchGroupedLinearWithPadding( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + fp8=fp8, + ).eval() + + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + ref_grouped_linear = GroupedLinear( + num_gemms, + config.hidden_size, + 4 * config.hidden_size, + bias=False, + params_dtype=dtype, + parallel_mode=parallel_mode, + device="cuda", + save_original_input=True, + ).eval() + + # Share params + with torch.no_grad(): + inner_grouped_linear = grouped_linear.linear_fn + for i in range(num_gemms): + setattr( + ref_grouped_linear, + f"weight{i}", + Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), + ) + + outputs = _test_padding_grouped_linear_accuracy( + grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 + ) + outputs_ref = _test_padding_grouped_linear_accuracy( + ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 + ) + + # Should be bit-wise match + for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + +@pytest.mark.parametrize( + "shape", + [ + (1, 127, 128, 512), + (8, 15, 128, 512), + (8, 1027, 128, 512), + (16, 10027, 128, 512), + ], +) +@pytest.mark.parametrize("dtype", param_types, ids=str) +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +@pytest.mark.parametrize("accumulate", [False, True]) +@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm) +def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass, monkeypatch): + torch.manual_seed(0) + z, m, k, n = shape + + dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist() + m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) + assert m_splits.sum() == m and len(m_splits) == z + m_splits = m_splits.tolist() + + if layout == "TN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input + out = [torch.randn(m, n, dtype=dtype, device="cuda")] # output + out_ref = [o.clone() for o in torch.split(out[0], m_splits)] + grad = False + single_output = True + elif layout == "NN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = list( + torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) + ) # grad_output + out = [torch.randn(m, k, dtype=dtype, device="cuda")] # dgrad + out_ref = [o.clone() for o in torch.split(out[0], m_splits)] + grad = True + single_output = True + else: # layout == "NT" + A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input + B = list( + torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) + ) # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + out_ref = [o.clone() for o in out] + grad = True + single_output = False + + if use_cutlass: + monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1") + + for i in range(z): + general_gemm( + A[i], + B[i], + dtype, + grad=grad, + accumulate=accumulate, + layout=layout, + out=out_ref[i], + ) + if single_output: + out_ref = [torch.cat(out_ref)] + + general_grouped_gemm( + A, + B, + out, + [None] * z, + dtype, + m_splits=m_splits, + grad=grad, + accumulate=accumulate, + layout=layout, + single_output=single_output, + ) + + for o, o_ref in zip(out, out_ref): + if not use_cutlass: + # cublas implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + else: + torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) + + +def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: + data = grouped_tensor.rowwise_data + if data is None: + data = grouped_tensor.columnwise_data + if data is None: + raise ValueError("GroupedTensor has no data buffers to pack.") + offset = 0 + for tensor in tensors: + numel = tensor.numel() + data[offset : offset + numel].copy_(tensor.reshape(-1)) + offset += numel + + +def _make_grouped_tensor_from_splits( + m_sizes: List[int], + last_dim: int, + device: torch.device, + dtype: torch.dtype, +) -> GroupedTensor: + first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64) + return GroupedTensor.make_grouped_tensor( + num_tensors=len(m_sizes), + first_dims=first_dims, + last_dims=None, + logical_first_dim=sum(m_sizes), + logical_last_dim=last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + +def _make_grouped_tensor_uniform( + num_tensors: int, + first_dim: int, + last_dim: int, + device: torch.device, + dtype: torch.dtype, +) -> GroupedTensor: + return GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + first_dims=None, + last_dims=None, + logical_first_dim=num_tensors * first_dim, + logical_last_dim=last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + +def _apply_grouped_bias_ref( + base_outs: List[torch.Tensor], + bias: Optional[List[torch.Tensor]], + bias_scale: Optional[torch.Tensor], + m_sizes: List[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + """Reference: add (optionally per-row scaled) bias to each group's output, cast to ``dtype``.""" + if bias is None: + return list(base_outs) + if bias_scale is None: + return [(o.float() + b.float()).to(dtype) for o, b in zip(base_outs, bias)] + out = [] + offset = 0 + for i, ms in enumerate(m_sizes): + s = bias_scale[offset : offset + ms].unsqueeze(-1) + out.append((base_outs[i].float() + bias[i].float() * s).to(dtype)) + offset += ms + return out + + +@pytest.mark.parametrize( + "z, m, n, k", + [ + (4, 256, 256, 256), + (4, 512, 256, 512), + (4, 512, 512, 256), + (8, 512, 256, 512), + ], +) +@pytest.mark.parametrize("case", ["no_discrete", "discrete_in", "discrete_out"]) +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +@pytest.mark.parametrize("accumulate", [False, True]) +@pytest.mark.parametrize("use_bias_scale", [False, True]) +def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate, use_bias_scale) -> None: + if tex.get_cublasLt_version() < 130300: + pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + + torch.manual_seed(0) + + dtype = torch.bfloat16 + + split_points = torch.randperm(m - 1)[: z - 1] + 1 + split_points = torch.sort(split_points).values.tolist() + m_sizes = [split_points[0]] + m_sizes += [b - a for a, b in zip(split_points[:-1], split_points[1:])] + m_sizes.append(m - split_points[-1]) + assert sum(m_sizes) == m and len(m_sizes) == z + + if layout == "NT": + A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + out_ref = [torch.matmul(B[i].transpose(0, 1).float(), A[i].float()) for i in range(z)] + else: + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [ + torch.randn(ms, k if layout == "TN" else n, dtype=dtype, device="cuda") + for ms in m_sizes + ] # TN --> input, NN --> grad_output + out = [ + torch.randn(ms, n if layout == "TN" else k, dtype=dtype, device="cuda") + for ms in m_sizes + ] # TN --> output, NN --> dgrad + if layout == "NN": + out_ref = [torch.matmul(B[i].float(), A[i].float()) for i in range(z)] + else: # layout == "TN" + out_ref = [torch.matmul(B[i].float(), A[i].transpose(0, 1).float()) for i in range(z)] + + if accumulate: + out_ref = [out[i].float() + o for i, o in enumerate(out_ref)] + + # Bias is applied after GEMM (broadcasted along rows) + # Match kernel behavior: GEMM output is already in output dtype when bias is added. + out_ref_no_bias = [o.to(dtype) for o in out_ref] + if layout == "TN": + bias_last_dim = n + else: # layout == "NT" or "NN" + bias_last_dim = k + bias = ( + [torch.randn(1, bias_last_dim, dtype=dtype, device="cuda") for _ in range(z)] + if case != "discrete_out" + else None + ) + bias_scale = None + if use_bias_scale and bias is not None and layout != "NT": + bias_scale = torch.randn(m, device="cuda", dtype=torch.float32) + # Bias add in grouped kernel accumulates in FP32 for BF16/FP16. + out_ref = _apply_grouped_bias_ref(out_ref_no_bias, bias, bias_scale, m_sizes, dtype) + # Create grouped tensors based on case + device = A[0].device + grouped_A = A + grouped_out = out + grouped_out_bias = [o.clone() for o in out] + grouped_out_no_bias = [o.clone() for o in out] + grouped_bias = None + if layout == "TN": + grouped_A = ( + _make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A + ) # weight + grouped_B = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) # input + if case != "discrete_out": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # output + grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + elif layout == "NN": + grouped_A = ( + _make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A + ) # weight + grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output + if case != "discrete_out": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + else: # layout == "NT" + grouped_A = ( + _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + if case != "discrete_in" + else A + ) # input + grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output + if case != "discrete_out": + grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) # wgrad + grouped_out_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype) + grouped_out_no_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype) + _pack_grouped_tensor(grouped_B, B) + if case != "discrete_out": + _pack_grouped_tensor(grouped_out, out) + _pack_grouped_tensor(grouped_out_bias, out) + _pack_grouped_tensor(grouped_out_no_bias, out) + if case != "discrete_in": + _pack_grouped_tensor(grouped_A, A) + + if bias is not None: + grouped_bias = _make_grouped_tensor_uniform(z, 1, bias_last_dim, device, dtype) + _pack_grouped_tensor(grouped_bias, bias) + + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out_no_bias, + layout=layout, + accumulate=accumulate, + bias=None, + ) + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out_bias, + layout=layout, + accumulate=accumulate, + bias=grouped_bias, + bias_scale=bias_scale, + ) + out_grouped_no_bias = ( + grouped_out_no_bias + if isinstance(grouped_out_no_bias, list) + else grouped_out_no_bias.split_into_quantized_tensors() + ) + out_grouped_bias = ( + grouped_out_bias + if isinstance(grouped_out_bias, list) + else grouped_out_bias.split_into_quantized_tensors() + ) + + out_grouped_manual_bias = _apply_grouped_bias_ref( + out_grouped_no_bias, bias, bias_scale, m_sizes, dtype + ) + tols = dtype_tols(dtype) + for o, o_ref in zip(out_grouped_no_bias, out_ref_no_bias): + torch.testing.assert_close(o, o_ref, **tols) + if bias is not None: + for o, o_ref in zip(out_grouped_bias, out_grouped_manual_bias): + torch.testing.assert_close(o, o_ref, **tols) + + +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +@pytest.mark.parametrize("accumulate", [False, True]) +@pytest.mark.parametrize("quant_type", ["bf16", "mxfp8"]) +def test_grouped_gemm_grouped_tensor_zero_work(layout, accumulate, quant_type) -> None: + """Grouped GEMM with all-zero split sizes (zero total work). + + For wgrad (NT layout) the output should be zero when not accumulating, + or unchanged when accumulating with beta=1. + """ + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + if quant_type == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + z = 4 + k, n = 256, 256 + dtype = torch.bfloat16 + device = torch.device("cuda") + use_mxfp8 = quant_type == "mxfp8" + + transa = layout[0] == "T" + transb = layout[1] == "T" + zero_first_dims = torch.zeros(z, dtype=torch.int64, device=device) + + def _make_zero_tokens_grouped_tensor(logical_last_dim, is_a): + """Create a GroupedTensor with non-zero logical_shape but zero first_dims.""" + buf = torch.randn(0, logical_last_dim, dtype=dtype, device=device) + if use_mxfp8: + if is_a: + rowwise, columnwise = transa, not transa + else: + rowwise, columnwise = not transb, transb + quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=rowwise, + columnwise=columnwise, + ) + quantizer.optimize_for_gemm = True + return tex.group_quantize(buf, quantizer, z, zero_first_dims) + return GroupedTensor.make_grouped_tensor( + num_tensors=z, + first_dims=zero_first_dims, + last_dims=None, + logical_first_dim=k, + logical_last_dim=logical_last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + if layout in ("TN", "NN"): + weight_tensors = [torch.randn(n, k, dtype=dtype, device=device) for _ in range(z)] + if use_mxfp8: + grouped_A = _make_grouped_tensor_quantized_mxfp8( + weight_tensors, + rowwise=transa, + columnwise=not transa, + device=device, + ) + else: + grouped_A = _make_grouped_tensor_uniform(z, n, k, device, dtype) + _pack_grouped_tensor(grouped_A, weight_tensors) + else: # NT + grouped_A = _make_zero_tokens_grouped_tensor(k, is_a=True) + + b_last_dim = k if layout == "TN" else n + grouped_B = _make_zero_tokens_grouped_tensor(b_last_dim, is_a=False) + + if layout == "NT": + out = [torch.randn(n, k, dtype=dtype, device=device) for _ in range(z)] + grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) + _pack_grouped_tensor(grouped_out, out) + else: + out = [torch.zeros(0, dtype=dtype, device=device) for _ in range(z)] + out_last_dim = n if layout == "TN" else k + grouped_out = GroupedTensor.make_grouped_tensor( + num_tensors=z, + first_dims=zero_first_dims, + last_dims=None, + logical_first_dim=k, + logical_last_dim=out_last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + out_before = [o.clone() for o in out] + + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out, + layout=layout, + accumulate=accumulate, + ) + + out_result = ( + grouped_out if isinstance(grouped_out, list) else grouped_out.split_into_quantized_tensors() + ) + for i in range(z): + if out_result[i].numel() == 0: + continue + if accumulate: + torch.testing.assert_close(out_result[i], out_before[i]) + else: + torch.testing.assert_close(out_result[i], torch.zeros_like(out_result[i])) + + +def _make_grouped_tensor_quantized_mxfp8( + tensors: List[torch.Tensor], + *, + rowwise: bool, + columnwise: bool, + device: torch.device, + is_weight: bool = False, +) -> GroupedTensor: + """Create a quantized MXFP8 GroupedTensor from a list of per-expert tensors. + + For weights (uniform per-expert shape), we generally won't keep it swizzled since we + might need for future dequantize operations. Swizzling is done internally within + general_grouped_gemm_for_grouped_tensor call. + + For non-weight tensors (inputs / grad_outputs), we still pass + ``first_dims`` and keep ``optimize_for_gemm=True``; so the kernel must emit the + already-swizzled layout up front. + """ + if not tensors: + raise ValueError("Expected non-empty tensor list for grouped quantization.") + quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=rowwise, + columnwise=columnwise, + ) + quantizer.optimize_for_gemm = not is_weight + grouped_input = torch.cat(tensors, dim=0) + if is_weight: + first_dims = None + else: + first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device) + return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims) + + +def _per_tensor_quantize_mxfp8( + tensors: List[torch.Tensor], + *, + rowwise: bool, + columnwise: bool, +) -> List: + """Quantize each tensor individually with MXFP8. + Used to build reference discrete inputs for grouped GEMM. + """ + quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=rowwise, + columnwise=columnwise, + ) + return [quantizer(t) for t in tensors] + + +@pytest.mark.parametrize( + "shape", + [ + (1, 128, 128, 512), + (8, 1024, 128, 512), + (16, 4096, 128, 512), + (2, 256, 2880, 2880), + ], +) +@pytest.mark.parametrize("accumulate", [False, True]) +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +@pytest.mark.parametrize("case", ["no_discrete", "discrete_in", "discrete_out"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_grouped_gemm_grouped_tensor_mxfp8( + shape, accumulate, layout: str, case: str, dtype: torch.dtype +) -> None: + if tex.get_cublasLt_version() < 130300: + pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if dtype == torch.bfloat16 and not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + + torch.manual_seed(0) + z, m, k, n = shape + m_sizes = [m // z] * z + + if layout == "TN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + out = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # output + grad = False + elif layout == "NN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # dgrad + grad = True + else: # layout == "NT" + A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + grad = True + + out_ref = [o.clone() for o in out] + + transa = layout[0] == "T" + transb = layout[1] == "T" + a_is_weight = all(t.shape == A[0].shape for t in A) + a_rowwise, a_columnwise = transa, not transa + b_rowwise, b_columnwise = not transb, transb + grouped_A = _make_grouped_tensor_quantized_mxfp8( + A, + rowwise=a_rowwise, + columnwise=a_columnwise, + device="cuda", + is_weight=a_is_weight, + ) + grouped_B = _make_grouped_tensor_quantized_mxfp8( + B, rowwise=b_rowwise, columnwise=b_columnwise, device="cuda" + ) + A_fp8 = _per_tensor_quantize_mxfp8(A, rowwise=a_rowwise, columnwise=a_columnwise) + B_fp8 = _per_tensor_quantize_mxfp8(B, rowwise=b_rowwise, columnwise=b_columnwise) + + general_grouped_gemm( + A_fp8, + B_fp8, + out_ref, + [None] * z, + dtype, + m_splits=m_sizes, + grad=grad, + accumulate=accumulate, + layout=layout, + single_output=False, + ) + + device = A[0].device + + grouped_out = None + if case != "discrete_out": + if layout == "TN": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + elif layout == "NN": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + else: # layout == "NT" + grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) + _pack_grouped_tensor(grouped_out, out) + + grouped_out_input = out if case == "discrete_out" else grouped_out + grouped_A_input = A_fp8 if case == "discrete_in" else grouped_A + general_grouped_gemm_for_grouped_tensor( + grouped_A_input, + grouped_B, + grouped_out_input, + layout=layout, + accumulate=accumulate, + ) + + out_grouped = out if case == "discrete_out" else grouped_out.split_into_quantized_tensors() + tols = dict(rtol=0.125, atol=0.0675) # mxfp8 tolerance + + for o, o_ref in zip(out_grouped, out_ref): + torch.testing.assert_close(o, o_ref, **tols) + + +@pytest.mark.parametrize( + "shape", + [ + (1, 128, 128, 512), + (8, 1024, 128, 512), + (16, 4096, 128, 512), + ], +) +@pytest.mark.parametrize("accumulate", [False, True]) +def test_fp8_grouped_gemm(shape, accumulate): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + + z, m, k, n = shape + m_splits = [m // z] * z + + dtype = torch.bfloat16 + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input + out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output + out_ref = [o.clone() for o in out] + + # fp8 should be robust enough to this fake scale + scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze() + amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda") + + a_quantizers = [ + Float8Quantizer( + scale.clone(), + amax.clone(), + tex.DType.kFloat8E4M3, + ) + for _ in range(z) + ] + b_quantizers = [ + Float8Quantizer( + scale.clone(), + amax.clone(), + tex.DType.kFloat8E4M3, + ) + for _ in range(z) + ] + + A_fp8 = [] + B_fp8 = [] + + for i in range(z): + A_fp8.append(a_quantizers[i](A[i])) + B_fp8.append(b_quantizers[i](B[i])) + + # baseline + for i in range(z): + general_gemm( + A_fp8[i], + B_fp8[i], + dtype, + out=out_ref[i], + accumulate=accumulate, + ) + general_grouped_gemm( + A_fp8, + B_fp8, + out, + [None] * z, + dtype, + m_splits=m_splits, + accumulate=accumulate, + ) + + # should be bit-wise match + for o, o_ref in zip(out, out_ref): + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + +_FUSED_GROUPED_GEMM_ENV = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM" +_ALL_BOOLEAN = all_boolean +_mxfp8_available, _reason_for_no_mxfp8 = mxfp8_available, reason_for_no_mxfp8 + + +@pytest.fixture(autouse=True) +def _reset_fp8_state(monkeypatch): + monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "0") + yield + FP8GlobalStateManager.reset() + monkeypatch.delenv(_FUSED_GROUPED_GEMM_ENV, raising=False) + + +def _clone_outputs(outputs): + return [None if out is None else out.detach().clone() for out in outputs] + + +def _run_grouped_linear_path( + *, + enable_grouped_tensor_path: bool, + fp8_recipe, + bias: bool, + fp8_model_params: bool, + delay_wgrad_compute: bool, + x_base: torch.Tensor, + dy: torch.Tensor, + weights, + biases, + m_splits, + monkeypatch, +): + FP8GlobalStateManager.reset() + monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1" if enable_grouped_tensor_path else "0") + + dtype = x_base.dtype + num_gemms = len(m_splits) + in_features = weights[0].size(1) + out_features = weights[0].size(0) + use_fp8 = fp8_recipe is not None + + x = x_base.detach().clone().requires_grad_(True) + with quantized_model_init(enabled=fp8_model_params, recipe=fp8_recipe): + grouped_linear = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + delay_wgrad_compute=delay_wgrad_compute, + ) + with torch.no_grad(): + for i in range(num_gemms): + getattr(grouped_linear, f"weight{i}").copy_(weights[i]) + if bias: + getattr(grouped_linear, f"bias{i}").copy_(biases[i]) + + # The fused path is the graph-safe path and accepts a CUDA tensor for split metadata. + # The legacy path still expects Python split sections in several places. + m_splits_arg = ( + torch.tensor(m_splits, dtype=torch.int64, device="cuda") + if enable_grouped_tensor_path + else m_splits + ) + with autocast(enabled=use_fp8, recipe=fp8_recipe): + y = grouped_linear(x, m_splits_arg) + y.backward(dy) + if delay_wgrad_compute: + grouped_linear.backward_dw() + + outputs = [y, x.grad] + for i in range(num_gemms): + outputs.append(getattr(grouped_linear, f"weight{i}").grad) + if bias: + outputs.append(getattr(grouped_linear, f"bias{i}").grad) + return _clone_outputs(outputs) + + +@pytest.mark.parametrize( + "fp8_recipe", + [ + None, + pytest.param( + recipe.MXFP8BlockScaling(), + marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8), + ), + ], + ids=["bf16", "mxfp8"], +) +@pytest.mark.parametrize("bias", _ALL_BOOLEAN) +@pytest.mark.parametrize("fp8_model_params", _ALL_BOOLEAN) +@pytest.mark.parametrize("delay_wgrad_compute", _ALL_BOOLEAN) +def test_grouped_linear_grouped_tensor_path_matches_legacy( + fp8_recipe, bias, fp8_model_params, delay_wgrad_compute, monkeypatch +): + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("GroupedTensor grouped GEMM path requires SM100+") + + use_fp8 = fp8_recipe is not None + if fp8_model_params and not use_fp8: + pytest.skip("fp8_model_params requires FP8") + + dtype = torch.bfloat16 + num_gemms = 3 + in_features = 64 + out_features = 64 + m_splits = [128, 256, 384] + total_tokens = sum(m_splits) + + torch.manual_seed(1234) + x_base = (0.1 * torch.randn(total_tokens, in_features, device="cuda")).to(dtype) + dy = (0.1 * torch.randn(total_tokens, out_features, device="cuda")).to(dtype) + weights = [ + (0.1 * torch.randn(out_features, in_features, device="cuda")).to(dtype) + for _ in range(num_gemms) + ] + biases = None + if bias: + biases = [ + (0.1 * torch.randn(out_features, device="cuda")).to(dtype) for _ in range(num_gemms) + ] + + outputs_legacy = _run_grouped_linear_path( + enable_grouped_tensor_path=False, + fp8_recipe=fp8_recipe, + bias=bias, + fp8_model_params=fp8_model_params, + delay_wgrad_compute=delay_wgrad_compute, + x_base=x_base, + dy=dy, + weights=weights, + biases=biases, + m_splits=m_splits, + monkeypatch=monkeypatch, + ) + outputs_grouped_tensor = _run_grouped_linear_path( + enable_grouped_tensor_path=True, + fp8_recipe=fp8_recipe, + bias=bias, + fp8_model_params=fp8_model_params, + delay_wgrad_compute=delay_wgrad_compute, + x_base=x_base, + dy=dy, + weights=weights, + biases=biases, + m_splits=m_splits, + monkeypatch=monkeypatch, + ) + + tols = dict(rtol=1e-2, atol=5e-3) + if use_fp8: + tols = dict(rtol=0.05, atol=0.05) + for grouped_tensor_out, legacy_out in zip(outputs_grouped_tensor, outputs_legacy): + assert grouped_tensor_out is not None + assert legacy_out is not None + torch.testing.assert_close(grouped_tensor_out.float(), legacy_out.float(), **tols) + + +def test_grouped_linear_grouped_tensor_path_single_grouped_bias_delay_wgrad(monkeypatch): + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("GroupedTensor grouped GEMM path requires SM100+") + + monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1") + + dtype = torch.bfloat16 + num_gemms = 3 + in_features = 64 + out_features = 64 + total_tokens = 64 + 96 + 128 + m_splits = torch.tensor([64, 96, 128], dtype=torch.int64, device="cuda") + x = torch.randn(total_tokens, in_features, dtype=dtype, device="cuda").requires_grad_() + dy = torch.randn(x.size(0), out_features, dtype=dtype, device="cuda") + + grouped_linear = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=True, + params_dtype=dtype, + device="cuda", + delay_wgrad_compute=True, + single_grouped_bias=True, + ) + + y = grouped_linear(x, m_splits) + y.backward(dy) + grouped_linear.backward_dw() + + +@pytest.mark.parametrize( + "fp8_recipe", + [ + None, + pytest.param( + recipe.MXFP8BlockScaling(), + marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8), + ), + ], + ids=["bf16", "mxfp8"], +) +@pytest.mark.parametrize("bias", _ALL_BOOLEAN) +def test_grouped_linear_fused_path_cuda_graph_safe(fp8_recipe, bias, monkeypatch): + """Fused GroupedTensor GEMM path should be CUDA graph capturable.""" + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("GroupedTensor grouped GEMM path requires SM100+") + + monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1") + FP8GlobalStateManager.reset() + + use_fp8 = fp8_recipe is not None + dtype = torch.bfloat16 + device = "cuda" + num_gemms = 3 + in_features = 128 + out_features = 128 + split_sizes = [128, 256, 384] + total_tokens = sum(split_sizes) + static_m_splits = torch.tensor(split_sizes, dtype=torch.int64, device=device) + + grouped_linear = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device=device, + ) + + static_x = torch.randn(total_tokens, in_features, dtype=dtype, device=device) + static_x.requires_grad_(True) + static_dy = torch.randn(total_tokens, out_features, dtype=dtype, device=device) + static_out_buf = torch.empty(total_tokens, out_features, dtype=dtype, device=device) + + def _zero_grads(): + if static_x.grad is not None: + static_x.grad.zero_() + for param in grouped_linear.parameters(): + if param.grad is None: + param.grad = torch.zeros_like(param) + else: + param.grad.zero_() + + def _clone_param_grads(): + return [param.grad.detach().clone() for param in grouped_linear.parameters()] + + def _train_step(x, dy, out_buf, *, use_graphed): + with autocast(enabled=use_fp8, recipe=fp8_recipe): + out = ( + graphed_grouped_linear(x, static_m_splits) + if use_graphed + else grouped_linear(x, static_m_splits) + ) + out.backward(dy) + out_buf.copy_(out) + return out_buf + + graphed_grouped_linear = te.make_graphed_callables( + grouped_linear, + (static_x, static_m_splits), + num_warmup_iters=3, + enabled=use_fp8, + recipe=fp8_recipe, + ) + + fresh_x = torch.randn_like(static_x) + fresh_dy = torch.randn_like(static_dy) + with torch.no_grad(): + static_x.copy_(fresh_x) + static_dy.copy_(fresh_dy) + + _zero_grads() + graph_out = ( + _train_step( + static_x, + static_dy, + static_out_buf, + use_graphed=True, + ) + .detach() + .clone() + ) + torch.cuda.synchronize() + graph_dx = static_x.grad.detach().clone() + graph_param_grads = _clone_param_grads() + + _zero_grads() + expected_x = fresh_x.detach().clone().requires_grad_(True) + expected_dy = fresh_dy.detach().clone() + with autocast(enabled=use_fp8, recipe=fp8_recipe): + expected_out = grouped_linear(expected_x, static_m_splits) + expected_out.backward(expected_dy) + + tols = dict(rtol=1e-2, atol=5e-3) + if use_fp8: + tols = dict(rtol=0.05, atol=0.05) + torch.testing.assert_close(graph_out.float(), expected_out.float(), **tols) + torch.testing.assert_close(graph_dx.float(), expected_x.grad.float(), **tols) + for graph_grad, param in zip(graph_param_grads, grouped_linear.parameters()): + assert param.grad is not None + torch.testing.assert_close(graph_grad.float(), param.grad.float(), **tols) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a718ea2a8a..7d0579f5ac 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -6,7 +6,6 @@ import os from typing import Dict, List, Tuple, Optional import pytest -import random import torch import torch.nn as nn @@ -14,7 +13,6 @@ from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, - get_align_size_for_quantization, ) from transformer_engine.pytorch.utils import ( init_method_normal, @@ -28,13 +26,10 @@ LayerNormLinear, LayerNormMLP, Linear, - GroupedLinear, MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, - Fp8Padding, - Fp8Unpadding, Float8Quantizer, Float8CurrentScalingQuantizer, MXFP8Quantizer, @@ -46,17 +41,11 @@ is_nvfp4_available, ) from transformer_engine.pytorch import checkpoint as te_checkpoint -from transformer_engine.pytorch.cpp_extensions import ( - general_gemm, - general_grouped_gemm, - general_grouped_gemm_for_grouped_tensor, -) -from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor +from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.common import recipe import transformer_engine_torch as tex from utils import ModelConfig, reset_rng_states - # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) @@ -448,40 +437,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return (input > 0) * input * input -class TorchGroupedLinearWithPadding(nn.Module): - - def __init__( - self, num_gemms, in_features, out_features, bias, params_dtype, parallel_mode, fp8 - ) -> None: - super().__init__() - - self.padding = Fp8Padding(num_gemms) - self.linear_fn = GroupedLinear( - num_gemms, - in_features, - out_features, - bias=bias, - params_dtype=params_dtype, - parallel_mode=parallel_mode, - device="cuda", - ) - self.unpadding = Fp8Unpadding(num_gemms) - - self.fp8 = fp8 - - def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor: - if self.fp8: - orig_m_splits = m_splits - inp, m_splits = self.padding(inp, m_splits) - - out = self.linear_fn(inp, m_splits) - - if self.fp8: - out = self.unpadding(out, orig_m_splits) - - return out - - _supported_act = { "gelu": nn.GELU(approximate="tanh"), "geglu": nn.GELU(approximate="tanh"), @@ -1822,584 +1777,6 @@ def test_layernorm_mlp_accuracy_checkpoint( torch.testing.assert_close(o, o_ref, rtol=0, atol=0) -def _test_grouped_linear_accuracy( - block, - num_gemms, - bs, - dtype, - config, - recipe, - fp8, - fuse_wgrad_accumulation, - delay_wgrad_compute=False, -): - reset_rng_states() - if fp8: - FP8GlobalStateManager.reset() - - inp_hidden_states = torch.randn( - (config.max_seqlen_q, bs, config.hidden_size), - dtype=dtype, - device="cuda", - requires_grad=True, - ) - inp_hidden_states.retain_grad() - - if num_gemms > 1: - split_size = 1 - if fp8: - split_size = get_align_size_for_quantization(recipe) - m = config.max_seqlen_q // split_size - dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() - dist.append(dist[-1]) # Manually add a zero - m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) - m_splits = m_splits * split_size - assert m_splits.sum() == config.max_seqlen_q and len(m_splits) == num_gemms - else: - m_splits = torch.tensor([config.max_seqlen_q]) - - with autocast(enabled=fp8, recipe=recipe): - if isinstance(block, GroupedLinear): - m_splits = m_splits * bs - out = block(inp_hidden_states, m_splits.tolist()) - else: - out = torch.cat( - [ - block[i](inp) - for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist())) - ] - ) - loss = out.sum() - loss.backward() - if delay_wgrad_compute: - if isinstance(block, GroupedLinear): - block.backward_dw() - else: - for i in range(num_gemms): - block[i].backward_dw() - - torch.cuda.synchronize() - outputs = [out, inp_hidden_states.grad] - for p in block.parameters(): - if p.requires_grad: - if getattr(p, "main_grad", None) is not None: - outputs.append(p.main_grad) - assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True - else: - outputs.append(p.grad) - return outputs - - -@pytest.mark.parametrize("dtype", param_types, ids=str) -@pytest.mark.parametrize("num_gemms", [3, 6]) -@pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) -@pytest.mark.parametrize("fp8_model_params", all_boolean) -@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) -@pytest.mark.parametrize("bias", all_boolean) -@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) -def test_grouped_linear_accuracy( - dtype, - num_gemms, - bs, - model, - recipe, - fp8_model_params, - fuse_wgrad_accumulation, - bias, - delay_wgrad_compute, - parallel_mode=None, - use_cutlass=False, -): - fp8 = recipe is not None - if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: - pytest.skip("FP8 parameters are not supported in debug mode.") - if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: - pytest.skip("Delayed wgrad compute is not supported in debug mode.") - - config = model_configs[model] - if config.max_seqlen_q % 16 != 0 and fp8: - pytest.skip("FP8 requires sequence length to be divisible by 16.") - - if recipe is not None and recipe.nvfp4(): - if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): - pytest.skip( - f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" - ) - - with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): - grouped_linear = GroupedLinear( - num_gemms, - config.hidden_size, - 4 * config.hidden_size, - bias=bias, - params_dtype=dtype, - parallel_mode=parallel_mode, - device="cuda", - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - delay_wgrad_compute=delay_wgrad_compute, - save_original_input=False, - ).eval() - sequential_linear = torch.nn.ModuleList( - [ - Linear( - config.hidden_size, - 4 * config.hidden_size, - bias=bias, - params_dtype=dtype, - parallel_mode=parallel_mode, - device="cuda", - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - ).eval() - for _ in range(num_gemms) - ] - ) - - # Share params - with torch.no_grad(): - for i in range(num_gemms): - sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) - if bias: - sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) - if fuse_wgrad_accumulation: - weight_i = getattr(grouped_linear, f"weight{i}") - weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) - sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() - - outputs_ref = _test_grouped_linear_accuracy( - sequential_linear, - num_gemms, - bs, - dtype, - config, - recipe, - fp8, - fuse_wgrad_accumulation, - delay_wgrad_compute, - ) - outputs = _test_grouped_linear_accuracy( - grouped_linear, - num_gemms, - bs, - dtype, - config, - recipe, - fp8, - fuse_wgrad_accumulation, - delay_wgrad_compute, - ) - - for o, o_ref in zip(outputs, outputs_ref): - if use_cutlass: - torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) - else: - # cuBLAS implementation should be bit-wise match - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - - -@pytest.mark.skipif( - torch.cuda.get_device_capability() != (9, 0), - reason="Only enable CUTLASS grouped gemm on Hopper", -) -@pytest.mark.parametrize("dtype", param_types, ids=str) -@pytest.mark.parametrize("num_gemms", [3, 6]) -@pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) -@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) -def test_grouped_linear_accuracy_cutlass( - dtype, - num_gemms, - bs, - model, - fuse_wgrad_accumulation, - delay_wgrad_compute, -): - os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" - test_grouped_linear_accuracy( - dtype, - num_gemms, - bs, - model, - None, - False, - fuse_wgrad_accumulation, - False, - delay_wgrad_compute, - None, - use_cutlass=True, - ) - os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) - - -@pytest.mark.parametrize("dtype", param_types, ids=str) -@pytest.mark.parametrize("num_gemms", [3]) -@pytest.mark.parametrize("bs", [1]) -@pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) -@pytest.mark.parametrize("fp8_model_params", [False]) -@pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) -@pytest.mark.parametrize("bias", [False]) -@pytest.mark.parametrize("delay_wgrad_compute", [True]) -def test_grouped_linear_accuracy_save_original_input( - dtype, - num_gemms, - bs, - model, - recipe, - fp8_model_params, - fuse_wgrad_accumulation, - bias, - delay_wgrad_compute, - parallel_mode=None, -): - fp8 = recipe is not None - if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: - pytest.skip("FP8 parameters are not supported in debug mode.") - if fp8 and recipe.delayed(): - pytest.skip("DelayedScaling recipe is not supported with save_original_input") - if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: - pytest.skip("Delayed wgrad compute is not supported in debug mode.") - - config = model_configs[model] - if config.max_seqlen_q % 16 != 0 and fp8: - pytest.skip("FP8 requires sequence length to be divisible by 16.") - - if recipe is not None and recipe.nvfp4(): - if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): - pytest.skip( - f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" - ) - - with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): - grouped_linear = GroupedLinear( - num_gemms, - config.hidden_size, - 4 * config.hidden_size, - bias=bias, - params_dtype=dtype, - parallel_mode=parallel_mode, - device="cuda", - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - delay_wgrad_compute=delay_wgrad_compute, - save_original_input=True, - ).eval() - sequential_linear = torch.nn.ModuleList( - [ - Linear( - config.hidden_size, - 4 * config.hidden_size, - bias=bias, - params_dtype=dtype, - parallel_mode=parallel_mode, - device="cuda", - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - ).eval() - for _ in range(num_gemms) - ] - ) - - # Share params - with torch.no_grad(): - for i in range(num_gemms): - sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) - if bias: - sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) - if fuse_wgrad_accumulation: - weight_i = getattr(grouped_linear, f"weight{i}") - weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) - sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() - - outputs_ref = _test_grouped_linear_accuracy( - sequential_linear, - num_gemms, - bs, - dtype, - config, - recipe, - fp8, - fuse_wgrad_accumulation, - delay_wgrad_compute, - ) - outputs = _test_grouped_linear_accuracy( - grouped_linear, - num_gemms, - bs, - dtype, - config, - recipe, - fp8, - fuse_wgrad_accumulation, - delay_wgrad_compute, - ) - - # Shoule be bit-wise match - for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - - -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) -def test_grouped_linear_accuracy_single_gemm(recipe): - """Split the tests to save CI time""" - test_grouped_linear_accuracy( - dtype=torch.float32, - num_gemms=1, - bs=2, - model="126m", - recipe=recipe, - fp8_model_params=True, - fuse_wgrad_accumulation=True, - bias=True, - delay_wgrad_compute=False, - ) - - -def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): - - def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): - align_size = get_align_size_for_quantization(recipe) - padded_tokens_per_expert = [ - (num_tokens + align_size - 1) // align_size * align_size - for num_tokens in tokens_per_expert - ] - hidden_states = torch.split(hidden_states, tokens_per_expert) - padded_hidden_states = [] - for hidden_state, actual_num_tokens, padded_num_tokens in zip( - hidden_states, tokens_per_expert, padded_tokens_per_expert - ): - padded_hidden_states.append(hidden_state) - if padded_num_tokens > actual_num_tokens: - pad_tensor = torch.zeros( - padded_num_tokens - actual_num_tokens, - hidden_state.shape[1], - dtype=hidden_state.dtype, - device=hidden_state.device, - ) - padded_hidden_states.append(pad_tensor) - padded_hidden_states = torch.cat(padded_hidden_states, dim=0) - return padded_hidden_states, padded_tokens_per_expert - - def _unpad_tensor_for_fp8(padded_hidden_states, actual_tokens_per_expert, tokens_per_expert): - inputmats = torch.split( - padded_hidden_states.view(-1, padded_hidden_states.shape[-1]), tokens_per_expert - ) - hidden_states = torch.cat( - [ - grad_output_mat[: actual_tokens_per_expert[i]] - for i, grad_output_mat in enumerate(inputmats) - ], - dim=0, - ) - - return hidden_states - - def _generate_random_numbers(n, total_sum): - if n <= 0: - return [] - - # reset seed - random.seed(seed) - - breaks = sorted(random.sample(range(1, total_sum), n - 1)) - random_numbers = ( - [breaks[0]] - + [breaks[i] - breaks[i - 1] for i in range(1, n - 1)] - + [total_sum - breaks[-1]] - ) - - return random_numbers - - reset_rng_states() - if fp8: - FP8GlobalStateManager.reset() - - inp_hidden_states = torch.randn( - (config.max_seqlen_q * bs, config.hidden_size), - dtype=dtype, - device="cuda", - requires_grad=True, - ) - inp_hidden_states.retain_grad() - - m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs) - - with autocast(enabled=fp8, recipe=recipe): - if isinstance(block, TorchGroupedLinearWithPadding): - out = block(inp_hidden_states, m_splits) - else: - if fp8: - padded_inp_hidden_states, padding_m_splits = _pad_tensor_for_fp8( - inp_hidden_states, m_splits - ) - padded_inp_hidden_states = block(padded_inp_hidden_states, padding_m_splits) - out = _unpad_tensor_for_fp8(padded_inp_hidden_states, m_splits, padding_m_splits) - else: - out = block(inp_hidden_states, m_splits) - - loss = out.sum() - loss.backward() - - torch.cuda.synchronize() - outputs = [out, inp_hidden_states.grad] - for p in block.parameters(): - if p.requires_grad: - outputs.append(p.grad) - return outputs - - -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("num_gemms", [3, 6]) -@pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("fp8", [True]) -@pytest.mark.parametrize("recipe", fp8_recipes) -@pytest.mark.parametrize("fp8_model_params", all_boolean) -def test_padding_grouped_linear_accuracy( - dtype, - num_gemms, - bs, - model, - fp8, - recipe, - fp8_model_params, - parallel_mode=None, -): - if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: - pytest.skip("FP8 parameters are not supported in debug mode.") - - config = model_configs[model] - if config.max_seqlen_q % 16 != 0 and fp8: - pytest.skip("FP8 requires sequence length to be divisible by 16.") - - if recipe is not None and recipe.nvfp4(): - if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): - pytest.skip( - f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" - ) - - with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): - grouped_linear = TorchGroupedLinearWithPadding( - num_gemms, - config.hidden_size, - 4 * config.hidden_size, - bias=False, - params_dtype=dtype, - parallel_mode=parallel_mode, - fp8=fp8, - ).eval() - - with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): - ref_grouped_linear = GroupedLinear( - num_gemms, - config.hidden_size, - 4 * config.hidden_size, - bias=False, - params_dtype=dtype, - parallel_mode=parallel_mode, - device="cuda", - save_original_input=False, - ).eval() - - # Share params - with torch.no_grad(): - inner_grouped_linear = grouped_linear.linear_fn - for i in range(num_gemms): - setattr( - ref_grouped_linear, - f"weight{i}", - Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), - ) - - outputs = _test_padding_grouped_linear_accuracy( - grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 - ) - outputs_ref = _test_padding_grouped_linear_accuracy( - ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 - ) - - # Shoule be bit-wise match - for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - - -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("num_gemms", [3]) -@pytest.mark.parametrize("bs", [1]) -@pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("fp8", [True]) -@pytest.mark.parametrize("recipe", fp8_recipes) -@pytest.mark.parametrize("fp8_model_params", [False]) -def test_padding_grouped_linear_accuracy_save_original_input( - dtype, - num_gemms, - bs, - model, - fp8, - recipe, - fp8_model_params, - parallel_mode=None, -): - if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: - pytest.skip("FP8 parameters are not supported in debug mode.") - if fp8 and recipe.delayed(): - pytest.skip("DelayedScaling recipe is not supported with save_original_input") - - config = model_configs[model] - if config.max_seqlen_q % 16 != 0 and fp8: - pytest.skip("FP8 requires sequence length to be divisible by 16.") - - if recipe is not None and recipe.nvfp4(): - if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): - pytest.skip( - f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" - ) - - with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): - grouped_linear = TorchGroupedLinearWithPadding( - num_gemms, - config.hidden_size, - 4 * config.hidden_size, - bias=False, - params_dtype=dtype, - parallel_mode=parallel_mode, - fp8=fp8, - ).eval() - - with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): - ref_grouped_linear = GroupedLinear( - num_gemms, - config.hidden_size, - 4 * config.hidden_size, - bias=False, - params_dtype=dtype, - parallel_mode=parallel_mode, - device="cuda", - save_original_input=True, - ).eval() - - # Share params - with torch.no_grad(): - inner_grouped_linear = grouped_linear.linear_fn - for i in range(num_gemms): - setattr( - ref_grouped_linear, - f"weight{i}", - Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), - ) - - outputs = _test_padding_grouped_linear_accuracy( - grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 - ) - outputs_ref = _test_padding_grouped_linear_accuracy( - ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 - ) - - # Shoule be bit-wise match - for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - - def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): reset_rng_states() @@ -2721,579 +2098,6 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): @pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) @pytest.mark.parametrize("accumulate", [False, True]) -@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm) -def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): - torch.manual_seed(0) - z, m, k, n = shape - - dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist() - m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) - assert m_splits.sum() == m and len(m_splits) == z - m_splits = m_splits.tolist() - - if layout == "TN": - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input - out = [torch.randn(m, n, dtype=dtype, device="cuda")] # output - out_ref = [o.clone() for o in torch.split(out[0], m_splits)] - grad = False - single_output = True - elif layout == "NN": - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = list( - torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) - ) # grad_output - out = [torch.randn(m, k, dtype=dtype, device="cuda")] # dgrad - out_ref = [o.clone() for o in torch.split(out[0], m_splits)] - grad = True - single_output = True - else: # layout == "NT" - A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input - B = list( - torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) - ) # grad_output - out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad - out_ref = [o.clone() for o in out] - grad = True - single_output = False - - if use_cutlass: - os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" - - for i in range(z): - general_gemm( - A[i], - B[i], - dtype, - grad=grad, - accumulate=accumulate, - layout=layout, - out=out_ref[i], - ) - if single_output: - out_ref = [torch.cat(out_ref)] - - general_grouped_gemm( - A, - B, - out, - [None] * z, - dtype, - m_splits=m_splits, - grad=grad, - accumulate=accumulate, - layout=layout, - single_output=single_output, - ) - - for o, o_ref in zip(out, out_ref): - if not use_cutlass: - # cublas implementation should be bit-wise match - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - else: - torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) - - if use_cutlass: - os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) - - -def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: - data = grouped_tensor.rowwise_data - if data is None: - data = grouped_tensor.columnwise_data - if data is None: - raise ValueError("GroupedTensor has no data buffers to pack.") - offset = 0 - for tensor in tensors: - numel = tensor.numel() - data[offset : offset + numel].copy_(tensor.reshape(-1)) - offset += numel - - -def _make_grouped_tensor_from_splits( - m_sizes: List[int], - last_dim: int, - device: torch.device, - dtype: torch.dtype, -) -> GroupedTensor: - first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64) - return GroupedTensor.make_grouped_tensor( - num_tensors=len(m_sizes), - first_dims=first_dims, - last_dims=None, - logical_first_dim=sum(m_sizes), - logical_last_dim=last_dim, - quantizer=None, - device=device, - dtype=dtype, - ) - - -def _make_grouped_tensor_uniform( - num_tensors: int, - first_dim: int, - last_dim: int, - device: torch.device, - dtype: torch.dtype, -) -> GroupedTensor: - return GroupedTensor.make_grouped_tensor( - num_tensors=num_tensors, - first_dims=None, - last_dims=None, - logical_first_dim=num_tensors * first_dim, - logical_last_dim=last_dim, - quantizer=None, - device=device, - dtype=dtype, - ) - - -def _apply_grouped_bias_ref( - base_outs: List[torch.Tensor], - bias: Optional[List[torch.Tensor]], - bias_scale: Optional[torch.Tensor], - m_sizes: List[int], - dtype: torch.dtype, -) -> List[torch.Tensor]: - """Reference: add (optionally per-row scaled) bias to each group's output, cast to ``dtype``.""" - if bias is None: - return list(base_outs) - if bias_scale is None: - return [(o.float() + b.float()).to(dtype) for o, b in zip(base_outs, bias)] - out = [] - offset = 0 - for i, ms in enumerate(m_sizes): - s = bias_scale[offset : offset + ms].unsqueeze(-1) - out.append((base_outs[i].float() + bias[i].float() * s).to(dtype)) - offset += ms - return out - - -@pytest.mark.parametrize( - "z, m, n, k", - [ - (4, 256, 256, 256), - (4, 512, 256, 512), - (4, 512, 512, 256), - (8, 512, 256, 512), - ], -) -@pytest.mark.parametrize("case", ["no_discrete", "discrete_in", "discrete_out"]) -@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) -@pytest.mark.parametrize("accumulate", [False, True]) -@pytest.mark.parametrize("use_bias_scale", [False, True]) -def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate, use_bias_scale) -> None: - if tex.get_cublasLt_version() < 130300: - pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") - if not is_bf16_available(): - pytest.skip("bfloat16 is required for grouped GEMM test.") - - torch.manual_seed(0) - - dtype = torch.bfloat16 - - split_points = torch.randperm(m - 1)[: z - 1] + 1 - split_points = torch.sort(split_points).values.tolist() - m_sizes = [split_points[0]] - m_sizes += [b - a for a, b in zip(split_points[:-1], split_points[1:])] - m_sizes.append(m - split_points[-1]) - assert sum(m_sizes) == m and len(m_sizes) == z - - if layout == "NT": - A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input - B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output - out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad - out_ref = [torch.matmul(B[i].transpose(0, 1).float(), A[i].float()) for i in range(z)] - else: - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = [ - torch.randn(ms, k if layout == "TN" else n, dtype=dtype, device="cuda") - for ms in m_sizes - ] # TN --> input, NN --> grad_output - out = [ - torch.randn(ms, n if layout == "TN" else k, dtype=dtype, device="cuda") - for ms in m_sizes - ] # TN --> output, NN --> dgrad - if layout == "NN": - out_ref = [torch.matmul(B[i].float(), A[i].float()) for i in range(z)] - else: # layout == "TN" - out_ref = [torch.matmul(B[i].float(), A[i].transpose(0, 1).float()) for i in range(z)] - - if accumulate: - out_ref = [out[i].float() + o for i, o in enumerate(out_ref)] - - # Bias is applied after GEMM (broadcasted along rows) - # Match kernel behavior: GEMM output is already in output dtype when bias is added. - out_ref_no_bias = [o.to(dtype) for o in out_ref] - if layout == "TN": - bias_last_dim = n - else: # layout == "NT" or "NN" - bias_last_dim = k - bias = ( - [torch.randn(1, bias_last_dim, dtype=dtype, device="cuda") for _ in range(z)] - if case != "discrete_out" - else None - ) - bias_scale = None - if use_bias_scale and bias is not None and layout != "NT": - bias_scale = torch.randn(m, device="cuda", dtype=torch.float32) - # Bias add in grouped kernel accumulates in FP32 for BF16/FP16. - out_ref = _apply_grouped_bias_ref(out_ref_no_bias, bias, bias_scale, m_sizes, dtype) - # Create grouped tensors based on case - device = A[0].device - grouped_A = A - grouped_out = out - grouped_out_bias = [o.clone() for o in out] - grouped_out_no_bias = [o.clone() for o in out] - grouped_bias = None - if layout == "TN": - grouped_A = ( - _make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A - ) # weight - grouped_B = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) # input - if case != "discrete_out": - grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # output - grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) - grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) - elif layout == "NN": - grouped_A = ( - _make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A - ) # weight - grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output - if case != "discrete_out": - grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) - grouped_out_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) - grouped_out_no_bias = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) - else: # layout == "NT" - grouped_A = ( - _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) - if case != "discrete_in" - else A - ) # input - grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output - if case != "discrete_out": - grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) # wgrad - grouped_out_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype) - grouped_out_no_bias = _make_grouped_tensor_uniform(z, n, k, device, dtype) - _pack_grouped_tensor(grouped_B, B) - if case != "discrete_out": - _pack_grouped_tensor(grouped_out, out) - _pack_grouped_tensor(grouped_out_bias, out) - _pack_grouped_tensor(grouped_out_no_bias, out) - if case != "discrete_in": - _pack_grouped_tensor(grouped_A, A) - - if bias is not None: - grouped_bias = _make_grouped_tensor_uniform(z, 1, bias_last_dim, device, dtype) - _pack_grouped_tensor(grouped_bias, bias) - - general_grouped_gemm_for_grouped_tensor( - grouped_A, - grouped_B, - grouped_out_no_bias, - layout=layout, - accumulate=accumulate, - bias=None, - ) - general_grouped_gemm_for_grouped_tensor( - grouped_A, - grouped_B, - grouped_out_bias, - layout=layout, - accumulate=accumulate, - bias=grouped_bias, - bias_scale=bias_scale, - ) - out_grouped_no_bias = ( - grouped_out_no_bias - if isinstance(grouped_out_no_bias, list) - else grouped_out_no_bias.split_into_quantized_tensors() - ) - out_grouped_bias = ( - grouped_out_bias - if isinstance(grouped_out_bias, list) - else grouped_out_bias.split_into_quantized_tensors() - ) - - out_grouped_manual_bias = _apply_grouped_bias_ref( - out_grouped_no_bias, bias, bias_scale, m_sizes, dtype - ) - tols = dtype_tols(dtype) - for o, o_ref in zip(out_grouped_no_bias, out_ref_no_bias): - torch.testing.assert_close(o, o_ref, **tols) - if bias is not None: - for o, o_ref in zip(out_grouped_bias, out_grouped_manual_bias): - torch.testing.assert_close(o, o_ref, **tols) - - -@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) -@pytest.mark.parametrize("accumulate", [False, True]) -@pytest.mark.parametrize("quant_type", ["bf16", "mxfp8"]) -def test_grouped_gemm_grouped_tensor_zero_work(layout, accumulate, quant_type) -> None: - """Grouped GEMM with all-zero split sizes (zero total work). - - For wgrad (NT layout) the output should be zero when not accumulating, - or unchanged when accumulating with beta=1. - """ - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") - if not is_bf16_available(): - pytest.skip("bfloat16 is required for grouped GEMM test.") - if quant_type == "mxfp8" and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - - z = 4 - k, n = 256, 256 - dtype = torch.bfloat16 - device = torch.device("cuda") - use_mxfp8 = quant_type == "mxfp8" - - transa = layout[0] == "T" - transb = layout[1] == "T" - zero_first_dims = torch.zeros(z, dtype=torch.int64, device=device) - - def _make_zero_tokens_grouped_tensor(logical_last_dim, is_a): - """Create a GroupedTensor with non-zero logical_shape but zero first_dims.""" - buf = torch.randn(0, logical_last_dim, dtype=dtype, device=device) - if use_mxfp8: - if is_a: - rowwise, columnwise = transa, not transa - else: - rowwise, columnwise = not transb, transb - quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=rowwise, - columnwise=columnwise, - ) - quantizer.optimize_for_gemm = True - return tex.group_quantize(buf, quantizer, z, zero_first_dims) - return GroupedTensor.make_grouped_tensor( - num_tensors=z, - first_dims=zero_first_dims, - last_dims=None, - logical_first_dim=k, - logical_last_dim=logical_last_dim, - quantizer=None, - device=device, - dtype=dtype, - ) - - if layout in ("TN", "NN"): - weight_tensors = [torch.randn(n, k, dtype=dtype, device=device) for _ in range(z)] - if use_mxfp8: - grouped_A = _make_grouped_tensor_quantized_mxfp8( - weight_tensors, - rowwise=transa, - columnwise=not transa, - device=device, - ) - else: - grouped_A = _make_grouped_tensor_uniform(z, n, k, device, dtype) - _pack_grouped_tensor(grouped_A, weight_tensors) - else: # NT - grouped_A = _make_zero_tokens_grouped_tensor(k, is_a=True) - - b_last_dim = k if layout == "TN" else n - grouped_B = _make_zero_tokens_grouped_tensor(b_last_dim, is_a=False) - - if layout == "NT": - out = [torch.randn(n, k, dtype=dtype, device=device) for _ in range(z)] - grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) - _pack_grouped_tensor(grouped_out, out) - else: - out = [torch.zeros(0, dtype=dtype, device=device) for _ in range(z)] - out_last_dim = n if layout == "TN" else k - grouped_out = GroupedTensor.make_grouped_tensor( - num_tensors=z, - first_dims=zero_first_dims, - last_dims=None, - logical_first_dim=k, - logical_last_dim=out_last_dim, - quantizer=None, - device=device, - dtype=dtype, - ) - - out_before = [o.clone() for o in out] - - general_grouped_gemm_for_grouped_tensor( - grouped_A, - grouped_B, - grouped_out, - layout=layout, - accumulate=accumulate, - ) - - out_result = ( - grouped_out if isinstance(grouped_out, list) else grouped_out.split_into_quantized_tensors() - ) - for i in range(z): - if out_result[i].numel() == 0: - continue - if accumulate: - torch.testing.assert_close(out_result[i], out_before[i]) - else: - torch.testing.assert_close(out_result[i], torch.zeros_like(out_result[i])) - - -def _make_grouped_tensor_quantized_mxfp8( - tensors: List[torch.Tensor], - *, - rowwise: bool, - columnwise: bool, - device: torch.device, - is_weight: bool = False, -) -> GroupedTensor: - """Create a quantized MXFP8 GroupedTensor from a list of per-expert tensors. - - For weights (uniform per-expert shape), we generally won't keep it swizzled since we - might need for future dequantize operations. Swizzling is done internally within - general_grouped_gemm_for_grouped_tensor call. - - For non-weight tensors (inputs / grad_outputs), we still pass - ``first_dims`` and keep ``optimize_for_gemm=True``; so the kernel must emit the - already-swizzled layout up front. - """ - if not tensors: - raise ValueError("Expected non-empty tensor list for grouped quantization.") - quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=rowwise, - columnwise=columnwise, - ) - quantizer.optimize_for_gemm = not is_weight - grouped_input = torch.cat(tensors, dim=0) - if is_weight: - first_dims = None - else: - first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device) - return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims) - - -def _per_tensor_quantize_mxfp8( - tensors: List[torch.Tensor], - *, - rowwise: bool, - columnwise: bool, -) -> List: - """Quantize each tensor individually with MXFP8. - Used to build reference discrete inputs for grouped GEMM. - """ - quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=rowwise, - columnwise=columnwise, - ) - return [quantizer(t) for t in tensors] - - -@pytest.mark.parametrize( - "shape", - [ - (1, 128, 128, 512), - (8, 1024, 128, 512), - (16, 4096, 128, 512), - (2, 256, 2880, 2880), - ], -) -@pytest.mark.parametrize("accumulate", [False, True]) -@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) -@pytest.mark.parametrize("case", ["no_discrete", "discrete_in", "discrete_out"]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_grouped_gemm_grouped_tensor_mxfp8( - shape, accumulate, layout: str, case: str, dtype: torch.dtype -) -> None: - if tex.get_cublasLt_version() < 130300: - pytest.skip("Grouped GEMM requires cuBLAS 13.3+.") - if torch.cuda.get_device_capability() < (10, 0): - pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") - if dtype == torch.bfloat16 and not is_bf16_available(): - pytest.skip("bfloat16 is required for grouped GEMM test.") - - torch.manual_seed(0) - z, m, k, n = shape - m_sizes = [m // z] * z - - if layout == "TN": - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input - out = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # output - grad = False - elif layout == "NN": - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output - out = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # dgrad - grad = True - else: # layout == "NT" - A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input - B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output - out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad - grad = True - - out_ref = [o.clone() for o in out] - - transa = layout[0] == "T" - transb = layout[1] == "T" - a_is_weight = all(t.shape == A[0].shape for t in A) - a_rowwise, a_columnwise = transa, not transa - b_rowwise, b_columnwise = not transb, transb - grouped_A = _make_grouped_tensor_quantized_mxfp8( - A, - rowwise=a_rowwise, - columnwise=a_columnwise, - device="cuda", - is_weight=a_is_weight, - ) - grouped_B = _make_grouped_tensor_quantized_mxfp8( - B, rowwise=b_rowwise, columnwise=b_columnwise, device="cuda" - ) - A_fp8 = _per_tensor_quantize_mxfp8(A, rowwise=a_rowwise, columnwise=a_columnwise) - B_fp8 = _per_tensor_quantize_mxfp8(B, rowwise=b_rowwise, columnwise=b_columnwise) - - general_grouped_gemm( - A_fp8, - B_fp8, - out_ref, - [None] * z, - dtype, - m_splits=m_sizes, - grad=grad, - accumulate=accumulate, - layout=layout, - single_output=False, - ) - - device = A[0].device - - grouped_out = None - if case != "discrete_out": - if layout == "TN": - grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) - elif layout == "NN": - grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) - else: # layout == "NT" - grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) - _pack_grouped_tensor(grouped_out, out) - - grouped_out_input = out if case == "discrete_out" else grouped_out - grouped_A_input = A_fp8 if case == "discrete_in" else grouped_A - general_grouped_gemm_for_grouped_tensor( - grouped_A_input, - grouped_B, - grouped_out_input, - layout=layout, - accumulate=accumulate, - ) - - out_grouped = out if case == "discrete_out" else grouped_out.split_into_quantized_tensors() - tols = dict(rtol=0.125, atol=0.0675) # mxfp8 tolerance - - for o, o_ref in zip(out_grouped, out_ref): - torch.testing.assert_close(o, o_ref, **tols) - - @pytest.mark.parametrize("N", [32]) @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( @@ -3367,72 +2171,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua (16, 4096, 128, 512), ], ) -@pytest.mark.parametrize("accumulate", [False, True]) -def test_fp8_grouped_gemm(shape, accumulate): - if not fp8_available: - pytest.skip(reason_for_no_fp8) - - z, m, k, n = shape - m_splits = [m // z] * z - - dtype = torch.bfloat16 - A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input - out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output - out_ref = [o.clone() for o in out] - - # fp8 should be robust enough to this fake scale - scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze() - amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda") - - a_quantizers = [ - Float8Quantizer( - scale.clone(), - amax.clone(), - tex.DType.kFloat8E4M3, - ) - for _ in range(z) - ] - b_quantizers = [ - Float8Quantizer( - scale.clone(), - amax.clone(), - tex.DType.kFloat8E4M3, - ) - for _ in range(z) - ] - - A_fp8 = [] - B_fp8 = [] - - for i in range(z): - A_fp8.append(a_quantizers[i](A[i])) - B_fp8.append(b_quantizers[i](B[i])) - - # baseline - for i in range(z): - general_gemm( - A_fp8[i], - B_fp8[i], - dtype, - out=out_ref[i], - accumulate=accumulate, - ) - general_grouped_gemm( - A_fp8, - B_fp8, - out, - [None] * z, - dtype, - m_splits=m_splits, - accumulate=accumulate, - ) - - # should be bit-wise match - for o, o_ref in zip(out, out_ref): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) - - def test_noncontiguous(): def _create2modules(m, params): mod1 = m(*params) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 627144345c..e76b92c60d 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -3,8 +3,10 @@ # See LICENSE for license information. """GroupedLinear API""" + from typing import Union, Optional, Callable, Tuple, List from itertools import chain +import os import warnings import weakref @@ -29,6 +31,7 @@ divide, cast_if_needed, clear_tensor_data, + get_device_compute_capability, init_method_constant, requires_grad, resolve_grouped_linear_single_param_flags, @@ -42,12 +45,15 @@ ) from ..cpp_extensions import ( general_grouped_gemm, + general_grouped_gemm_for_grouped_tensor, ) from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload +from ..triton.grouped_dbias_dscales import compute_grouped_dbias from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer +from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..quantized_tensor import ( QuantizedTensorStorage, Quantizer, @@ -65,6 +71,307 @@ class _GroupedLinear(torch.autograd.Function): Calls custom cuda extensions. """ + @staticmethod + def _maybe_dequantize( + tensor: Union[torch.Tensor, QuantizedTensorStorage], + dtype: torch.dtype, + ) -> torch.Tensor: + """Dequantize quantized tensors or cast regular tensors to ``dtype``.""" + if isinstance(tensor, QuantizedTensorStorage): + return tensor.dequantize(dtype=dtype) + return cast_if_needed(tensor, dtype) + + @staticmethod + def _is_grouped_tensor_path_supported( + *, + fp8: bool, + fp8_calibration: bool, + debug: bool, + cpu_offloading: bool, + backward_override: Optional[str], + save_original_input: bool, + activation_dtype: torch.dtype, + input_quantizers: List[Optional[Quantizer]], + weight_quantizers: List[Optional[Quantizer]], + output_quantizers: List[Optional[Quantizer]], + grad_output_quantizers: List[Optional[Quantizer]], + ) -> bool: + """Whether to use cublasLt grouped GEMM through GroupedTensor metadata.""" + if not bool(int(os.getenv("NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM", "1"))): + return False + if ( + debug + or cpu_offloading + or fp8_calibration + or backward_override is not None + or save_original_input + ): + return False + if get_device_compute_capability() < (10, 0): + return False + if any(q is not None for q in output_quantizers): + return False + # Graph-safe callers may provide m_splits as a CUDA tensor. Avoid Python + # value checks here; split validity and MXFP8 alignment are left to the + # grouped quantize/GEMM kernels and their metadata validation. + if fp8: + return ( + activation_dtype in (torch.bfloat16, torch.float16) + and all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) + and all(isinstance(q, MXFP8Quantizer) for q in weight_quantizers) + and all(q is None or isinstance(q, MXFP8Quantizer) for q in grad_output_quantizers) + ) + return activation_dtype in (torch.bfloat16, torch.float16) + + @staticmethod + def _make_grouped_tensor( + data: torch.Tensor, + *, + num_gemms: int, + split_sizes: torch.Tensor, + base_split_offsets: torch.Tensor, + last_dim: int, + dtype: torch.dtype, + ) -> GroupedTensor: + """Wrap a packed 2D buffer as a varying-first-dimension GroupedTensor.""" + return GroupedTensor( + shape=(data.size(0), last_dim), + dtype=dtype, + num_tensors=num_gemms, + quantizer=None, + data=data.reshape(-1), + first_dims=split_sizes, + tensor_offsets=base_split_offsets * last_dim, + ) + + @staticmethod + def _make_grouped_bias( + biases: Tuple[torch.Tensor, ...], + *, + num_gemms: int, + out_features: int, + dtype: torch.dtype, + ) -> GroupedTensor: + """Pack per-GEMM biases into the grouped GEMM bias format.""" + bias_data = torch.stack( + [_GroupedLinear._maybe_dequantize(bias, dtype) for bias in biases], + dim=0, + ).contiguous() + return GroupedTensor( + shape=(num_gemms, out_features), + dtype=dtype, + num_tensors=num_gemms, + shapes=[(1, out_features)] * num_gemms, + quantizer=None, + data=bias_data.reshape(-1), + ) + + @staticmethod + def _prepare_weights_for_grouped_tensor_gemm( + weights: Tuple[torch.Tensor, ...], + weight_quantizers: List[Optional[Quantizer]], + weight_workspaces: List[Optional[QuantizedTensorStorage]], + *, + with_quantized_compute: bool, + columnwise_usage: bool, + activation_dtype: torch.dtype, + is_first_microbatch: Optional[bool], + skip_fp8_weight_update: Optional[torch.Tensor], + cache_weight: bool, + ) -> Tuple[List[torch.Tensor], List[Optional[QuantizedTensorStorage]]]: + """Prepare discrete weight tensors for GroupedTensor GEMM.""" + weights_for_gemm: List[torch.Tensor] = [] + new_workspaces: List[Optional[QuantizedTensorStorage]] = [None] * len(weights) + if not with_quantized_compute: + return ( + [_GroupedLinear._maybe_dequantize(weight, activation_dtype) for weight in weights], + new_workspaces, + ) + + update_ws = is_first_microbatch is None or is_first_microbatch + for idx, weight in enumerate(weights): + weight_quantizer = weight_quantizers[idx] + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + weight_fp8, new_workspaces[idx] = quantize_weight( + tensor=weight, + quantizer=weight_quantizer, + workspace=weight_workspaces[idx] if weight_workspaces else None, + update_workspace=update_ws, + skip_update_flag=skip_fp8_weight_update, + workspace_dtype=activation_dtype, + cache=cache_weight, + ) + weights_for_gemm.append(weight_fp8) + return weights_for_gemm, new_workspaces + + @staticmethod + def _forward_grouped_tensor( + ctx, + *, + inp: torch.Tensor, + m_splits: Union[List[int], torch.Tensor], + use_bias: bool, + is_first_microbatch: Optional[bool], + fp8: bool, + wgrad_store: Optional[WeightGradStore], + input_quantizers: List[Optional[Quantizer]], + weight_quantizers: List[Optional[Quantizer]], + grad_input_quantizers: List[Optional[Quantizer]], + grad_weight_quantizers: List[Optional[Quantizer]], + grad_output_quantizers: List[Optional[Quantizer]], + fuse_wgrad_accumulation: bool, + activation_dtype: torch.dtype, + is_grad_enabled: bool, + weight_workspaces: List[Optional[QuantizedTensorStorage]], + cache_weight: bool, + skip_fp8_weight_update: Optional[torch.Tensor], + weights: Tuple[torch.Tensor, ...], + biases: Tuple[torch.Tensor, ...], + ) -> Tuple[torch.Tensor, list]: + """Forward path backed by GroupedTensor + cublasLt grouped GEMM.""" + num_gemms = len(m_splits) + device = inp.device + in_features = weights[0].size(-1) + out_features = weights[0].size(0) + weight_requires_grad = weights[0].requires_grad + + split_sizes = torch.as_tensor(m_splits, dtype=torch.int64, device=device) + base_split_offsets = tex.splits_to_offsets(split_sizes, 1) + + inp_view = inp.reshape(-1, in_features) + x = cast_if_needed(inp_view, activation_dtype) + if fp8: + input_quantizer = input_quantizers[0] + input_quantizer.set_usage( + rowwise=True, + columnwise=is_grad_enabled and weight_requires_grad, + ) + input_quantizer.optimize_for_gemm = True + grouped_x = tex.group_quantize(x, input_quantizer, num_gemms, split_sizes) + else: + grouped_x = _GroupedLinear._make_grouped_tensor( + x, + num_gemms=num_gemms, + split_sizes=split_sizes, + base_split_offsets=base_split_offsets, + last_dim=in_features, + dtype=activation_dtype, + ) + + columnwise_usage = is_grad_enabled and inp.requires_grad + weights_for_gemm, new_workspaces = _GroupedLinear._prepare_weights_for_grouped_tensor_gemm( + weights, + weight_quantizers, + weight_workspaces, + with_quantized_compute=fp8, + columnwise_usage=columnwise_usage, + activation_dtype=activation_dtype, + is_first_microbatch=is_first_microbatch, + skip_fp8_weight_update=skip_fp8_weight_update, + cache_weight=cache_weight, + ) + + out = torch.empty( + [x.size(0), out_features], + dtype=activation_dtype, + device=device, + ) + grouped_out = _GroupedLinear._make_grouped_tensor( + out, + num_gemms=num_gemms, + split_sizes=split_sizes, + base_split_offsets=base_split_offsets, + last_dim=out_features, + dtype=activation_dtype, + ) + + grouped_bias = None + if use_bias: + grouped_bias = _GroupedLinear._make_grouped_bias( + biases, + num_gemms=num_gemms, + out_features=out_features, + dtype=activation_dtype, + ) + + use_split_accumulator = _2X_ACC_FPROP + if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(recipe, "fp8_gemm_fprop"): + use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator + + general_grouped_gemm_for_grouped_tensor( + weights_for_gemm, + grouped_x, + grouped_out, + layout="TN", + bias=grouped_bias, + use_split_accumulator=use_split_accumulator, + ) + + if is_grad_enabled: + if weight_requires_grad: + if fp8: + grouped_x.rowwise_data = None + grouped_x.scale_inv = None + else: + grouped_x = None + + weights_to_save = weights_for_gemm if inp.requires_grad else [None] * num_gemms + tensors_to_save, tensor_objects = prepare_for_saving( + grouped_x, + *weights_to_save, + split_sizes, + base_split_offsets, + ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.use_grouped_tensor_path = True + ctx.weight_quantizers = weight_quantizers + ctx.weights_shape_0 = out_features + ctx.weights_shape_1 = in_features + ctx.grad_input_quantizers = grad_input_quantizers + ctx.grad_output_quantizers = grad_output_quantizers + ctx.grad_weight_quantizers = grad_weight_quantizers + ctx.weights_requires_grad = weight_requires_grad + if fuse_wgrad_accumulation and ctx.weights_requires_grad: + ctx.origin_weight_refs = [weakref.ref(w) for w in weights] + ctx.origin_weights_overwrite_main_grad = getattr( + weights[0], "overwrite_main_grad", False + ) + if hasattr(weights[0], "__fsdp_param__"): + ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)] + else: + ctx.main_grad_funcs = [ + lambda j=i: weights[j].main_grad for i in range(num_gemms) + ] + ctx.device = device + ctx.m_splits = m_splits + ctx.num_gemms = num_gemms + ctx.activation_dtype = activation_dtype + ctx.fp8 = fp8 + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.backward_override = None + ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + ctx.cpu_offloading = False + ctx.is_first_microbatch = is_first_microbatch + ctx.use_bias = use_bias + ctx.inp_shape = inp.shape + ctx.requires_dgrad = inp.requires_grad + ctx.reduce_and_update_bwd_fp8_tensors = False + if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): + ctx.reduce_and_update_bwd_fp8_tensors = ( + ctx.reduce_and_update_bwd_fp8_tensors + or FP8GlobalStateManager.is_first_fp8_module() + ) + ctx.wgrad_store = wgrad_store + ctx.debug = False + ctx.save_original_input = False + ctx.input_quantizers = input_quantizers + + return out.view(-1, *inp.shape[1:-1], out.shape[-1]), new_workspaces + # pylint: disable=keyword-arg-before-vararg @staticmethod def forward( @@ -168,6 +475,43 @@ def forward( f"Input tensor (shape={tuple(inp.size())}) is not compatible with " f"weight tensor (shape={tuple(weights[0].size())})" ) + + if _GroupedLinear._is_grouped_tensor_path_supported( + fp8=fp8, + fp8_calibration=fp8_calibration, + debug=debug, + cpu_offloading=cpu_offloading, + backward_override=backward_override, + save_original_input=save_original_input, + activation_dtype=activation_dtype, + input_quantizers=input_quantizers, + weight_quantizers=weight_quantizers, + output_quantizers=output_quantizers, + grad_output_quantizers=grad_output_quantizers, + ): + return _GroupedLinear._forward_grouped_tensor( + ctx, + inp=inp, + m_splits=m_splits, + use_bias=use_bias, + is_first_microbatch=is_first_microbatch, + fp8=fp8, + wgrad_store=wgrad_store, + input_quantizers=input_quantizers, + weight_quantizers=weight_quantizers, + grad_input_quantizers=grad_input_quantizers, + grad_weight_quantizers=grad_weight_quantizers, + grad_output_quantizers=grad_output_quantizers, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + activation_dtype=activation_dtype, + is_grad_enabled=is_grad_enabled, + weight_workspaces=weight_workspaces, + cache_weight=cache_weight, + skip_fp8_weight_update=skip_fp8_weight_update, + weights=weights, + biases=biases, + ) + inp_view = inp.reshape(-1, in_features) inputmats: list if fp8 and not debug: @@ -256,6 +600,7 @@ def forward( mark_not_offload(*weights_fp8, *weights) if is_grad_enabled: + ctx.use_grouped_tensor_path = False ctx.weight_quantizers = weight_quantizers ctx.weights_shape_1 = weights[0].shape[1] @@ -276,10 +621,18 @@ def forward( else: inputmats = [None] * num_gemms + # Original weights are only needed by high_precision dgrad. The weakrefs + # used for fused wgrad accumulation serve a different purpose: restoring + # Python parameter attributes without keeping the parameter alive here. + saved_weights = ( + weights + if backward_override == "high_precision" and inp.requires_grad + else [None] * num_gemms + ) tensors_to_save, tensor_objects = prepare_for_saving( *inputmats, *weights_fp8, - *weights, + *saved_weights, *biases, ) ctx.save_for_backward(*tensors_to_save) @@ -349,12 +702,199 @@ def forward( # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]), new_workspaces + @staticmethod + def _backward_grouped_tensor( + ctx, + grad_output: torch.Tensor, + ) -> Tuple[Union[torch.Tensor, None], ...]: + """Backward path paired with ``_forward_grouped_tensor``.""" + saved_tensors = restore_from_func_ctx(ctx) + N = ctx.num_gemms + grouped_x = saved_tensors[0] + weights = saved_tensors[1 : 1 + N] + split_sizes = saved_tensors[1 + N] + base_split_offsets = saved_tensors[2 + N] + + origin_weights = [None] * N + main_grads = [None] * N + if ctx.fuse_wgrad_accumulation and ctx.weights_requires_grad: + origin_weight_refs = ctx.origin_weight_refs + ctx.origin_weight_refs = None + origin_weights = [ref() if ref is not None else None for ref in origin_weight_refs] + assert all( + w is not None for w in origin_weights + ), "weight was removed while fuse_wgrad_accumulation=True" + main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + for origin_weight, main_grad in zip(origin_weights, main_grads): + if main_grad is not None: + origin_weight.main_grad = main_grad + + grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) + dy_2d = cast_if_needed(grad_output_view, ctx.activation_dtype) + dbias_packed = None + if ctx.fp8: + grad_output_quantizer = ctx.grad_output_quantizers[0] + grad_output_quantizer.set_usage( + rowwise=ctx.requires_dgrad, + columnwise=ctx.weights_requires_grad, + ) + grad_output_quantizer.optimize_for_gemm = True + if ctx.use_bias: + grouped_dy, dbias_packed = tex.bgrad_group_quantize( + dy_2d, + grad_output_quantizer, + N, + split_sizes, + ) + else: + grouped_dy = tex.group_quantize( + dy_2d, + grad_output_quantizer, + N, + split_sizes, + ) + else: + grouped_dy = _GroupedLinear._make_grouped_tensor( + dy_2d, + num_gemms=N, + split_sizes=split_sizes, + base_split_offsets=base_split_offsets, + last_dim=ctx.weights_shape_0, + dtype=ctx.activation_dtype, + ) + + grad_biases = [None] * N + if ctx.use_bias: + if dbias_packed is None: + dbias_packed = compute_grouped_dbias(dy_2d, base_split_offsets, N) + grad_biases = [dbias_packed[i].to(dtype=ctx.activation_dtype) for i in range(N)] + + dgrad = None + if ctx.requires_dgrad: + dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_dgrad"): + dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + for weight in weights: + if isinstance(weight, QuantizedTensorStorage): + weight.update_usage(columnwise_usage=True) + dgrad = torch.empty( + (dy_2d.size(0), ctx.weights_shape_1), + dtype=ctx.activation_dtype, + device=ctx.device, + ) + grouped_dgrad = _GroupedLinear._make_grouped_tensor( + dgrad, + num_gemms=N, + split_sizes=split_sizes, + base_split_offsets=base_split_offsets, + last_dim=ctx.weights_shape_1, + dtype=ctx.activation_dtype, + ) + general_grouped_gemm_for_grouped_tensor( + weights, + grouped_dy, + grouped_dgrad, + layout="NN", + use_split_accumulator=dgrad_gemm_use_split_accumulator, + ) + + if ctx.is_first_microbatch is not None: + accumulate_wgrad_into_param_main_grad = ( + ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch + ) + else: + accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + + if ctx.weights_requires_grad: + wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_wgrad"): + wgrad_gemm_use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator + if ctx.fuse_wgrad_accumulation: + wgrad_list = main_grads + else: + weight_shape = [ctx.weights_shape_0, ctx.weights_shape_1] + wgrad_list = tex.bulk_allocate( + [weight_shape] * N, + [ctx.activation_dtype] * N, + ctx.device, + [256] * N, + ) + + accumulate = ( + accumulate_wgrad_into_param_main_grad + if not getattr(ctx, "origin_weights_overwrite_main_grad", False) + else False + ) + + def grouped_gemm_wgrad(inputmats, grad_output_mats, grad_weights): + general_grouped_gemm_for_grouped_tensor( + inputmats, + grad_output_mats, + grad_weights, + layout="NT", + use_split_accumulator=wgrad_gemm_use_split_accumulator, + accumulate=accumulate, + ) + return None, [None] * N, None + + if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): + ctx.wgrad_store.put([grouped_x, grouped_dy, wgrad_list], grouped_gemm_wgrad) + else: + grouped_gemm_wgrad(grouped_x, grouped_dy, wgrad_list) + + def handle_custom_ddp_from_mcore(weight, main_grad, wgrad): + if ctx.weights_requires_grad: + if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"): + weight.grad_added_to_main_grad = True + if getattr(weight, "zero_out_wgrad", False): + wgrad = get_dummy_wgrad( + list(main_grad.shape), + weight.dtype, + zero=True, + ) + else: + wgrad = get_dummy_wgrad( + list(main_grad.shape), + weight.dtype, + ) + elif ctx.fuse_wgrad_accumulation: + wgrad = None + else: + wgrad = None + return wgrad + + wgrad_list = [ + handle_custom_ddp_from_mcore(weight, main_grad, wgrad) + for weight, main_grad, wgrad in zip(origin_weights, main_grads, wgrad_list) + ] + else: + wgrad_list = [None] * N + + if not ctx.use_bias: + grad_biases = [None] * N + + if ctx.reduce_and_update_bwd_fp8_tensors: + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + return ( + dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, + None, + *wgrad_list, + *grad_biases, + ) + @staticmethod def backward( ctx, grad_output: torch.Tensor, _grad_workspaces ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with get_nvtx_range_context("_GroupedLinear_backward"): + if getattr(ctx, "use_grouped_tensor_path", False): + return _GroupedLinear._backward_grouped_tensor(ctx, grad_output) + saved_tensors = restore_from_func_ctx(ctx) N = ctx.num_gemms inputmats = saved_tensors[:N] @@ -1105,7 +1645,7 @@ def _load_from_state_dict( def forward( self, inp: torch.Tensor, - m_splits: List[int], + m_splits: Union[List[int], torch.Tensor], is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ @@ -1115,8 +1655,9 @@ def forward( ---------- inp : torch.Tensor Input tensor. - m_splits : List[int] - List of integers representing the split of the input tensor. + m_splits : Union[List[int], torch.Tensor] + Split sizes for the input tensor. The grouped tensor path accepts + a CUDA tensor here to avoid Python value checks during graph capture. is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split @@ -1237,9 +1778,14 @@ def backward_dw(self): if not self.fuse_wgrad_accumulation: for i in range(self.num_gemms): weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype) - if self.use_bias: + has_grad_biases = [ + grad_bias is not None and grad_bias.numel() != 0 for grad_bias in grad_biases_ + ] + if self.use_bias and any(has_grad_biases): grouped_bias = getattr(self, "bias", None) if grouped_bias is not None: + if not all(has_grad_biases): + raise RuntimeError("Expected all grouped bias gradients to be present.") gstack = torch.stack(grad_biases_, dim=0).to(grouped_bias.dtype) if grouped_bias.grad is None: grouped_bias.grad = gstack @@ -1248,7 +1794,7 @@ def backward_dw(self): else: bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] for i in range(self.num_gemms): - if bias_params[i].grad is None: + if has_grad_biases[i] and bias_params[i].grad is None: bias_params[i].grad = grad_biases_[i].to(bias_params[i].dtype) del grad_biases_ del wgrad_list