[PyTorch] Make modules.GroupedLinear graph-safe #3038
Signed-off-by: Xin Yao <xiny@nvidia.com>
Greptile Summary
This PR introduces a new cublasLt grouped GEMM path for
Confidence Score: 3/5
The grouped-tensor forward/backward logic is functionally correct for the intended usage, but the standard path is broken for any caller that passes a CUDA tensor when the grouped-tensor route is bypassed.
When _is_grouped_tensor_path_supported returns False at runtime on SM100+ (output_quantizers non-None, backward_override set, or env var disabled) while the caller has passed m_splits as a CUDA tensor, the standard path calls torch.split with a CUDA tensor, raising a TypeError. This breaks valid training configurations on SM100+ hardware when fallback conditions are hit.
transformer_engine/pytorch/module/grouped_linear.py: specifically the transition from the grouped-tensor early-return into the standard forward and backward paths when m_splits is a CUDA tensor.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["GroupedLinear.forward(inp, m_splits)"] --> B["_GroupedLinear.forward()"]
B --> C{"_is_grouped_tensor_path_supported?"}
C -- "SM>=10.0 + MXFP8/BF16 + no output quantizers + env var enabled" --> D["_forward_grouped_tensor()"]
C -- "otherwise" --> E["Standard path (existing logic)"]
D --> D1["torch.as_tensor(m_splits) / group_quantize / cublasLt grouped GEMM"]
D1 --> D2["save: grouped_x, weights, split_sizes, base_offsets / ctx.use_grouped_tensor_path=True"]
E --> E1["split_quantize / torch.split / N separate GEMMs"]
E1 --> E2["save: inputmats, weights_fp8, biases, ctx.m_splits"]
D2 --> F["loss.backward()"]
E2 --> F
F --> G{"ctx.use_grouped_tensor_path?"}
G -- True --> H["_backward_grouped_tensor() / dgrad+wgrad grouped GEMMs / bias via compute_grouped_dbias"]
G -- False --> I["Standard backward (existing logic)"]
H --> J{"delay_wgrad_compute?"}
J -- True --> K["wgrad_store.put() / bias grads returned via autograd"]
J -- False --> L["compute wgrads immediately"]
K --> M["backward_dw() / wgrad_store.pop() / fill weight.grad / bias skipped"]
    if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
        pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
    if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute:
        pytest.skip("Delayed wgrad compute is not supported in debug mode.")

    config = model_configs[model]
    if config.max_seqlen_q % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    if recipe is not None and recipe.nvfp4():
        if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
            pytest.skip(
                f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] is set via direct mutation and cleaned up with os.environ.pop at the very end of the function. If test_grouped_linear_accuracy(...) raises an assertion error (or any exception), the pop is never reached, leaving NVTE_USE_CUTLASS_GROUPED_GEMM=1 for all subsequent tests in the session. Tests that assume the cuBLAS path would then silently run with the CUTLASS path, which has looser tolerances (rtol=1e-3), and may either fail spuriously or — worse — pass when they shouldn't. The monkeypatch fixture (injected via the _reset_fp8_state autouse fixture) should be used for this env var as well, or a try/finally block should protect the cleanup.
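The try/finally option mentioned in this comment could be sketched as a small context manager. This is only an illustration: `scoped_env` and `run_cutlass_case` are hypothetical names, not helpers from the PR or from Transformer Engine, and the callable passed in stands in for the real test body.

```python
import os
from contextlib import contextmanager


@contextmanager
def scoped_env(name, value):
    """Set an env var and restore its previous state on exit, even if the body raises."""
    previous = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        # Runs on normal exit and on exceptions such as failed assertions.
        if previous is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous


def run_cutlass_case(test_body):
    """Hypothetical wrapper: run test_body with the CUTLASS path requested."""
    with scoped_env("NVTE_USE_CUTLASS_GROUPED_GEMM", "1"):
        test_body()
```

With this shape, an assertion failure inside `test_body` still propagates to the test runner, but the variable no longer leaks into later tests.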

    def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
        data = grouped_tensor.rowwise_data
        if data is None:
            data = grouped_tensor.columnwise_data
        if data is None:
            raise ValueError("GroupedTensor has no data buffers to pack.")
        offset = 0
        for tensor in tensors:
            numel = tensor.numel()
            data[offset : offset + numel].copy_(tensor.reshape(-1))
            offset += numel


    def _make_grouped_tensor_from_splits(
        m_sizes: List[int],
        last_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> GroupedTensor:
        first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64)
        return GroupedTensor.make_grouped_tensor(
            num_tensors=len(m_sizes),
            first_dims=first_dims,
            last_dims=None,
            logical_first_dim=sum(m_sizes),
            logical_last_dim=last_dim,
            quantizer=None,
            device=device,
            dtype=dtype,
        )


    def _make_grouped_tensor_uniform(
        num_tensors: int,
        first_dim: int,
        last_dim: int,
        device: torch.device,
Env var leak on test failure (same pattern)
os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" at line 832 is cleaned up by os.environ.pop(...) at the final line of the function. Any assertion failure or exception in the body of the test (e.g., torch.testing.assert_close on line 863) will bypass that cleanup, leaving the env var set for the rest of the test session. The same concern applies here as in test_grouped_linear_accuracy_cutlass: subsequent tests that expect cuBLAS semantics will run against the CUTLASS kernel instead. Use monkeypatch or try/finally to ensure cleanup.
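In a pytest test, `monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1")` undoes the change at teardown automatically. The same restore-on-exit behaviour can be illustrated with only the standard library using `unittest.mock.patch.dict`, as in the sketch below. The wrapper name `run_with_cutlass_enabled` is an assumption made for illustration and is not part of the PR.

```python
import os
from unittest import mock

ENV_VAR = "NVTE_USE_CUTLASS_GROUPED_GEMM"


def run_with_cutlass_enabled(check):
    """Run `check` with ENV_VAR set to "1".

    patch.dict restores os.environ to its prior contents when the block
    exits, whether `check` returns normally or raises.
    """
    with mock.patch.dict(os.environ, {ENV_VAR: "1"}):
        return check()
```

Because cleanup is tied to leaving the `with` block rather than to a trailing `os.environ.pop`, a failing `torch.testing.assert_close` inside `check` cannot skip it.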
    def grouped_gemm_wgrad(inputmats, grad_output_mats, grad_weights):
        general_grouped_gemm_for_grouped_tensor(
            inputmats,
            grad_output_mats,
            grad_weights,
            layout="NT",
            use_split_accumulator=wgrad_gemm_use_split_accumulator,
            accumulate=accumulate,
        )
        return None, [None] * N, None
backward_dw will crash when single_grouped_bias=True and delay_wgrad_compute=True on SM100+
grouped_gemm_wgrad explicitly returns (None, [None] * N, None) — the middle element (bias-grad slot) is a list of Nones. When delay_wgrad_compute=True, this return value is stored in wgrad_store. Later, backward_dw calls wgrad_store.pop(), unpacks grad_biases_ as [None]*N, and then, when single_grouped_bias=True, executes torch.stack(grad_biases_, dim=0), which raises TypeError: expected Tensor, got NoneType.
For the default (single_grouped_bias=False) case this is safe because the per-GEMM if bias_params[i].grad is None guard prevents the None.to() call. However the single_grouped_bias=True code path has no such guard. The fix is to either include actual bias grads in the return tuple of grouped_gemm_wgrad, or check for None before stacking in backward_dw.
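The second option (checking for None before stacking) might look roughly like the sketch below. `stack_bias_grads` is a hypothetical helper, not code from the PR. It assumes that an all-None list means the bias gradients were already returned through autograd, as the flowchart describes for the delayed-wgrad path, so there is nothing to stack.

```python
from typing import Optional, Sequence


def stack_bias_grads(grad_biases: Sequence[Optional["torch.Tensor"]]):
    """Stack per-group bias grads, or return None if wgrad produced none.

    Hypothetical guard for the single_grouped_bias=True branch.
    """
    if all(g is None for g in grad_biases):
        # Nothing to stack: the caller should skip the bias-grad update.
        return None
    if any(g is None for g in grad_biases):
        # A partial list indicates a logic error upstream; fail loudly.
        raise ValueError("Expected bias grads for every group or for none.")
    import torch  # deferred so the None path can be exercised without PyTorch

    return torch.stack(list(grad_biases), dim=0)
```

A caller in `backward_dw` would then only assign the grouped bias gradient when the helper returns a tensor.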
    x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
    ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
    m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided
    if bool(int(os.getenv("NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM", "0"))):
The benchmark guards the m_splits tensor conversion with default="0", but _is_grouped_tensor_path_supported uses default="1". On SM100+ hardware without the env var set, the module will silently take the new grouped-tensor path while the benchmark still passes a Python list. This is functionally correct (the code handles it via torch.as_tensor), but it means the benchmark does not reflect the recommended graph-safe usage (passing a CUDA tensor). Aligning both defaults avoids confusion.
    - if bool(int(os.getenv("NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM", "0"))):
    + if bool(int(os.getenv("NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM", "1"))):
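One way to keep the two call sites from drifting apart again is to read the variable through a single helper. The sketch below is illustrative only: `fused_grouped_gemm_enabled` is a hypothetical name, and the default of "1" is taken from what this comment reports for _is_grouped_tensor_path_supported.

```python
import os


def fused_grouped_gemm_enabled(default: str = "1") -> bool:
    """Return True when the grouped-tensor path is requested via the env var.

    The default is assumed to mirror the module-side default described in
    the review comment above.
    """
    return bool(int(os.getenv("NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM", default)))
```

The benchmark and the module could both call this helper instead of repeating the `os.getenv` expression with separate defaults.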
        delay_wgrad_compute,
    )

    # Should be bit-wise match
Description
Enable grouped quantization and cuBLASLt grouped GEMM for modules.GroupedLinear to benefit cases where the cuteDSL fused grouped GEMM is not available. Move grouped GEMM and grouped linear related tests to a standalone file.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: