
[PyTorch] Make modules.GroupedLinear graph-safe#3038

Open
yaox12 wants to merge 4 commits into
NVIDIA:mainfrom
yaox12:xiny/enable-grouped-quantize-cublaslt

@yaox12
Member

@yaox12 yaox12 commented May 22, 2026

Description

  • Enable grouped quantization and cuBLASLt grouped gemm for modules.GroupedLinear to benefit cases where cuteDSL fused grouped gemm is not available.

    1. Reduce CPU overhead by reducing number of kernels.
    2. Be CUDA-Graph-safe.
    3. Improve kernel performance.
  • Move grouped gemm and grouped linear related tests to a standalone file.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Xin Yao <xiny@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR introduces a new cublasLt grouped GEMM path for GroupedLinear that operates through GroupedTensor metadata, enabling CUDA graph safety on SM100+ by allowing m_splits to be a static CUDA tensor rather than a Python list. The implementation adds a routing function that selects the new path when the device is SM100+, quantizers are MXFP8 or absent, and no unsupported options (debug, cpu-offload, output quantizers) are active. Tests for grouped linear and grouped gemm operations are also moved into a dedicated test_grouped_linear.py file.

  • _forward_grouped_tensor / _backward_grouped_tensor in grouped_linear.py implement the new path end-to-end: grouped quantisation of the input, per-weight quantisation, a single cublasLt grouped GEMM for fprop/dgrad/wgrad, with a delay_wgrad_compute variant routed through wgrad_store.
  • backward_dw gains a has_grad_biases guard that prevents torch.stack(None) when single_grouped_bias=True and the wgrad store returns [None]*N bias slots (as the grouped-tensor path always does).
  • The benchmark converts m_splits to a CUDA tensor when the fused path is active, and the CI script adds the new test file to the run.
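The routing conditions listed in the summary can be pictured as a single predicate. The following is only a minimal sketch, not the PR's implementation: the `RoutingInputs` dataclass, its field names, and the function `grouped_tensor_path_supported` are hypothetical. Only the conditions themselves (SM100+, MXFP8 or no quantizer, no debug mode, CPU offload, or output quantizers, and an environment variable that defaults to enabled) come from the review text.

```python
import os
from dataclasses import dataclass
from typing import Optional, Tuple

ENV_FLAG = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM"


@dataclass
class RoutingInputs:
    """Hypothetical bundle of the facts the router needs to look at."""

    compute_capability: Tuple[int, int]  # e.g. (10, 0) for SM100
    quantizer_kind: Optional[str]  # "mxfp8", another recipe name, or None
    debug: bool = False
    cpu_offload: bool = False
    has_output_quantizers: bool = False


def grouped_tensor_path_supported(cfg: RoutingInputs) -> bool:
    """Return True only when every condition named in the review holds."""
    # The review states that the module treats the flag as enabled by default.
    if not bool(int(os.getenv(ENV_FLAG, "1"))):
        return False
    if cfg.compute_capability < (10, 0):
        return False
    if cfg.quantizer_kind not in (None, "mxfp8"):
        return False
    return not (cfg.debug or cfg.cpu_offload or cfg.has_output_quantizers)
```

Keeping every condition in one predicate makes it easier to see which runtime settings silently send a call back to the standard path.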

Confidence Score: 3/5

The grouped-tensor forward/backward logic is functionally correct for the intended usage, but the standard path is broken for any caller that passes a CUDA tensor when the grouped-tensor route is bypassed.

When _is_grouped_tensor_path_supported returns False at runtime on SM100+ (output_quantizers non-None, backward_override set, or env var disabled) while the caller has passed m_splits as a CUDA tensor, the standard path calls torch.split with a CUDA tensor, raising a TypeError. This breaks valid training configurations on SM100+ hardware when fallback conditions are hit.

transformer_engine/pytorch/module/grouped_linear.py — specifically the transition from the grouped-tensor early-return into the standard forward and backward paths when m_splits is a CUDA tensor.
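One possible guard for the fallback case is to normalize `m_splits` to a list of Python ints before the standard path uses it. This is a sketch under assumptions and is not taken from the PR: the helper name `m_splits_as_int_list` is hypothetical, and it relies only on the input either being a sequence of ints or exposing a `tolist()` method, as `torch.Tensor` does.

```python
from typing import Any, List


def m_splits_as_int_list(m_splits: Any) -> List[int]:
    """Return m_splits as a plain list of Python ints."""
    if hasattr(m_splits, "tolist"):
        # For a CUDA tensor this copies the values to the host and
        # synchronizes, so it only suits the standard (non-graph-safe) path.
        m_splits = m_splits.tolist()
    return [int(m) for m in m_splits]
```

The standard path could then call `torch.split(inp, m_splits_as_int_list(m_splits))`. The host copy defeats CUDA graph capture, so this would be a correctness fallback rather than a replacement for the grouped-tensor path.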

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/module/grouped_linear.py | Adds _forward_grouped_tensor / _backward_grouped_tensor paired methods and a _is_grouped_tensor_path_supported router enabling cublasLt grouped GEMM via GroupedTensor metadata on SM100+. P1 found: if the grouped-tensor path is bypassed at runtime while the caller has passed a CUDA tensor for m_splits, torch.split in the standard path raises a TypeError. |
| tests/pytorch/test_grouped_linear.py | New standalone test file extracting grouped-linear tests from test_numerics.py; adds coverage for the grouped-tensor path, CUDA graph capture, and the single_grouped_bias + delay_wgrad combination. Env-var cleanup now correctly uses monkeypatch.setenv throughout. |
| tests/pytorch/test_numerics.py | Grouped-gemm and grouped-linear tests removed and moved to the new test_grouped_linear.py; no functional changes to remaining tests. |
| benchmarks/linear/benchmark_grouped_linear.py | Converts the benchmark's m_splits list to a CUDA tensor when the fused grouped-GEMM path is enabled, aligning the benchmark with the recommended graph-safe calling convention. |
| qa/L0_pytorch_unittest/test.sh | Adds test_grouped_linear.py to the CI test suite. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["GroupedLinear.forward(inp, m_splits)"] --> B["_GroupedLinear.forward()"]
    B --> C{"_is_grouped_tensor_path_supported?"}
    C -- "SM>=10.0 + MXFP8/BF16 + no output quantizers + env var enabled" --> D["_forward_grouped_tensor()"]
    C -- "otherwise" --> E["Standard path (existing logic)"]
    D --> D1["torch.as_tensor(m_splits) / group_quantize / cublasLt grouped GEMM"]
    D1 --> D2["save: grouped_x, weights, split_sizes, base_offsets / ctx.use_grouped_tensor_path=True"]
    E --> E1["split_quantize / torch.split / N separate GEMMs"]
    E1 --> E2["save: inputmats, weights_fp8, biases, ctx.m_splits"]
    D2 --> F["loss.backward()"]
    E2 --> F
    F --> G{"ctx.use_grouped_tensor_path?"}
    G -- True --> H["_backward_grouped_tensor() / dgrad+wgrad grouped GEMMs / bias via compute_grouped_dbias"]
    G -- False --> I["Standard backward (existing logic)"]
    H --> J{"delay_wgrad_compute?"}
    J -- True --> K["wgrad_store.put() / bias grads returned via autograd"]
    J -- False --> L["compute wgrads immediately"]
    K --> M["backward_dw() / wgrad_store.pop() / fill weight.grad / bias skipped"]

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +399 to +413
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
    pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.delayed():
    pytest.skip("DelayedScaling recipe is not supported with save_original_input")
if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute:
    pytest.skip("Delayed wgrad compute is not supported in debug mode.")

config = model_configs[model]
if config.max_seqlen_q % 16 != 0 and fp8:
    pytest.skip("FP8 requires sequence length to be divisible by 16.")

if recipe is not None and recipe.nvfp4():
    if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
        pytest.skip(
            f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"

P1 Env var leak on test failure

os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] is set via direct mutation and cleaned up with os.environ.pop at the very end of the function. If test_grouped_linear_accuracy(...) raises an assertion error (or any exception), the pop is never reached, leaving NVTE_USE_CUTLASS_GROUPED_GEMM=1 for all subsequent tests in the session. Tests that assume the cuBLAS path would then silently run with the CUTLASS path, which has looser tolerances (rtol=1e-3), and may either fail spuriously or — worse — pass when they shouldn't. The monkeypatch fixture (injected via the _reset_fp8_state autouse fixture) should be used for this env var as well, or a try/finally block should protect the cleanup.
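The cleanup pattern the comment asks for can be illustrated with the standard library alone. This is a minimal sketch rather than the test code from the PR: `run_with_cutlass_flag` and `failing_body` are hypothetical names, and `unittest.mock.patch.dict` stands in for the `monkeypatch` fixture, which gives the same guarantee inside pytest.

```python
import os
from unittest import mock

ENV_VAR = "NVTE_USE_CUTLASS_GROUPED_GEMM"


def run_with_cutlass_flag(test_body):
    """Run test_body with the flag set and restore the environment afterwards."""
    # patch.dict undoes its changes when the block exits, including when
    # test_body raises, so the flag cannot leak into later tests.
    with mock.patch.dict(os.environ, {ENV_VAR: "1"}):
        return test_body()


def failing_body():
    """Stand-in for a test body whose assertion fails part-way through."""
    assert os.environ[ENV_VAR] == "1"
    raise AssertionError("simulated tolerance failure")
```

In a pytest test, calling `monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1")` at the top of the function would achieve the same effect without an explicit `with` block or a trailing `os.environ.pop`.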

Comment on lines +831 to +868

def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
    data = grouped_tensor.rowwise_data
    if data is None:
        data = grouped_tensor.columnwise_data
    if data is None:
        raise ValueError("GroupedTensor has no data buffers to pack.")
    offset = 0
    for tensor in tensors:
        numel = tensor.numel()
        data[offset : offset + numel].copy_(tensor.reshape(-1))
        offset += numel


def _make_grouped_tensor_from_splits(
    m_sizes: List[int],
    last_dim: int,
    device: torch.device,
    dtype: torch.dtype,
) -> GroupedTensor:
    first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64)
    return GroupedTensor.make_grouped_tensor(
        num_tensors=len(m_sizes),
        first_dims=first_dims,
        last_dims=None,
        logical_first_dim=sum(m_sizes),
        logical_last_dim=last_dim,
        quantizer=None,
        device=device,
        dtype=dtype,
    )


def _make_grouped_tensor_uniform(
    num_tensors: int,
    first_dim: int,
    last_dim: int,
    device: torch.device,

P1 Env var leak on test failure (same pattern)

os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" at line 832 is cleaned up by os.environ.pop(...) at the final line of the function. Any assertion failure or exception in the body of the test (e.g., torch.testing.assert_close on line 863) will bypass that cleanup, leaving the env var set for the rest of the test session. The same concern applies here as in test_grouped_linear_accuracy_cutlass: subsequent tests that expect cuBLAS semantics will run against the CUTLASS kernel instead. Use monkeypatch or try/finally to ensure cleanup.

Comment on lines +829 to +838
def grouped_gemm_wgrad(inputmats, grad_output_mats, grad_weights):
    general_grouped_gemm_for_grouped_tensor(
        inputmats,
        grad_output_mats,
        grad_weights,
        layout="NT",
        use_split_accumulator=wgrad_gemm_use_split_accumulator,
        accumulate=accumulate,
    )
    return None, [None] * N, None

P1 backward_dw will crash when single_grouped_bias=True and delay_wgrad_compute=True on SM100+

grouped_gemm_wgrad explicitly returns (None, [None] * N, None) — the middle element (bias-grad slot) is a list of Nones. When delay_wgrad_compute=True, this return value is stored in wgrad_store. Later, backward_dw calls wgrad_store.pop(), unpacks grad_biases_ as [None]*N, and then, when single_grouped_bias=True, executes torch.stack(grad_biases_, dim=0), which raises TypeError: expected Tensor, got NoneType.

For the default (single_grouped_bias=False) case this is safe because the per-GEMM if bias_params[i].grad is None guard prevents the None.to() call. However the single_grouped_bias=True code path has no such guard. The fix is to either include actual bias grads in the return tuple of grouped_gemm_wgrad, or check for None before stacking in backward_dw.
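The second fix option, checking for None before stacking, might look like the sketch below. It is an illustration under assumptions, not the code in the PR: the helper `stack_bias_grads` is hypothetical, the `stack` callable stands in for `torch.stack(..., dim=0)` so the example runs without PyTorch, and treating a partially filled list as "no bias grads" is an assumption. Only the name `has_grad_biases` comes from the review summary.

```python
from typing import Any, Callable, List, Optional, Sequence


def stack_bias_grads(
    grad_biases: Sequence[Optional[Any]],
    stack: Callable[[List[Any]], Any],
) -> Optional[Any]:
    """Stack per-GEMM bias grads, or return None when the slots are empty."""
    # The grouped-tensor wgrad closure returns [None] * N for the bias slot,
    # so stacking must be skipped unless every entry is a real gradient.
    has_grad_biases = len(grad_biases) > 0 and all(g is not None for g in grad_biases)
    if not has_grad_biases:
        return None
    return stack(list(grad_biases))
```

With PyTorch, the caller would pass something like `lambda gs: torch.stack(gs, dim=0)` as `stack` and leave the grouped bias gradient untouched when the helper returns None.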

x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided
if bool(int(os.getenv("NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM", "0"))):

P2 The benchmark guards the m_splits tensor conversion with default="0", but _is_grouped_tensor_path_supported uses default="1". On SM100+ hardware without the env var set the module will silently take the new grouped-tensor path while the benchmark still passes a Python list. This is functionally correct (the code handles it via torch.as_tensor), but it means the benchmark does not reflect the recommended graph-safe usage (passing a CUDA tensor). Aligning both defaults avoids confusion.

Suggested change
- if bool(int(os.getenv("NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM", "0"))):
+ if bool(int(os.getenv("NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM", "1"))):
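The effect of the mismatched defaults is easy to reproduce in isolation. The snippet below is a small sketch, with `flag_enabled` as a hypothetical helper: when the variable is unset, two call sites that read the same flag with different defaults disagree, and they only agree once the variable is set explicitly.

```python
import os

ENV_FLAG = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM"


def flag_enabled(default: str) -> bool:
    """Parse the flag the same way the quoted benchmark line does."""
    return bool(int(os.getenv(ENV_FLAG, default)))


def defaults_agree() -> bool:
    """True when a default of "0" and a default of "1" yield the same answer."""
    return flag_enabled("0") == flag_enabled("1")
```

Reading the flag through one shared helper, or a shared default constant, would remove the need to keep the two literals in sync by hand.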

Comment thread tests/pytorch/test_grouped_linear.py Outdated
delay_wgrad_compute,
)

# Shoule be bit-wise match

P2 Typo in comment

"Shoule be bit-wise match" → "Should be bit-wise match". The same typo appears on lines 662 and 739.
