GGEMM+srelu kernels for MxFP8 Nemotron #2981

Open
sraman-rgb wants to merge 12 commits into NVIDIA:main from sraman-rgb:fc1-srelu-main
Conversation

@sraman-rgb

@sraman-rgb sraman-rgb commented May 12, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ksivaman
Member

/te-ci pytorch

@ksivaman
Member

Please sign-off your commits @sraman-rgb

@greptile-apps
Contributor

greptile-apps Bot commented May 12, 2026

Greptile Summary

This PR refactors the MXFP8 fused grouped-MLP kernel infrastructure to support both GLU-style activations (SwiGLU, QGeGLU) and unary activations (SReLU), then adds the new ScaledSReLU op and corresponding ForwardGroupedMLP_CuTeGEMMUnary_MXFP8 / BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8 fused-kernel classes backed by cuDNN FE ≥ 1.24.0.

  • The concrete ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 class is renamed to ForwardGroupedMLP_CuTeGEMMGLU_MXFP8 and extracted into a base class that both the GLU and new Unary subclasses share; the same pattern applies to the backward.
  • ScaledSReLU is a new BasicOperation for squared-ReLU with per-row post-scaling supporting an activation_recompute_in_mlp knob.
  • fuse_grouped_mlp_ops is generalised to accept an activation_op_types tuple, and validate_grouped_mlp_dims is extended to handle both GLU (2× out-features) and unary (1× out-features) cases.
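The "squared-ReLU with per-row post-scaling" described for ScaledSReLU can be written out as a small reference. This is a sketch under assumptions, not the Transformer Engine implementation: it assumes the scale multiplies the squared activation, that there is one scale per row, and it uses plain nested lists instead of tensors. The function names are hypothetical. The backward pass recomputes relu(x) from the saved input, which is the general idea behind an activation-recompute option.

```python
def scaled_srelu_forward(x, scales):
    """Assumed formula: y[i][j] = scales[i] * max(x[i][j], 0) ** 2."""
    return [[s * max(v, 0.0) ** 2 for v in row] for row, s in zip(x, scales)]


def scaled_srelu_backward(x, scales, grad_y):
    """Return (grad_x, grad_scales), recomputing relu(x) from the saved input."""
    grad_x, grad_scales = [], []
    for row, s, g_row in zip(x, scales, grad_y):
        relu = [max(v, 0.0) for v in row]
        # d/dx [s * relu(x)^2] = 2 * s * relu(x)
        grad_x.append([g * 2.0 * s * r for g, r in zip(g_row, relu)])
        # d/ds [s * relu(x)^2] = relu(x)^2, accumulated over the row
        grad_scales.append(sum(g * r * r for g, r in zip(g_row, relu)))
    return grad_x, grad_scales
```

A reference like this is mainly useful for checking a fused kernel against known values on tiny inputs; note that the scale gradient is a per-row reduction, so it has one entry per row rather than one per element.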

Confidence Score: 4/5

The new SReLU fused kernel path is logically sound but carries unresolved concerns from prior review rounds in the fused forward and backward files.

The base-class refactoring is clean and ScaledSReLU is correctly wired through both unfused and fused paths. The activation_recompute_in_mlp mechanism is self-consistent. The unreachable grad_scales guard is the main new finding — it cannot cause wrong results today but would silently discard the scale gradient if the kernel interface changed. Three prior-thread concerns remain open.

transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py and forward_grouped_mlp.py deserve a closer look before merging.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Introduces base class + GLU/Unary subclasses for fused forward; adds SReLU recompute path and fuse_forward_srelu_ops; has a dead _grouped_gemm_dsrelu_backward_supported helper (flagged in prior review). |
| transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | Mirrors forward refactoring for backward; adds BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8; fuser_backward reconstructs grouped_fc2_x from the DSReLU kernel when recompute is enabled; contains an unreachable `if grad_scales is not None` guard. |
| transformer_engine/pytorch/ops/basic/activation.py | Adds ScaledSReLU op with correct fuser_forward/fuser_backward overrides, activation-recompute guard, and proper ctx.save_for_backward layout consistent with the fused path. |
| transformer_engine/pytorch/ops/_common.py | Generalises validate_grouped_mlp_dims, fuse_grouped_mlp_ops, and version-check helpers; extracts is_glu_activation. |
| transformer_engine/pytorch/ops/basic/swiglu.py | Adds activation_recompute_in_mlp parameter to _ScaledGLU and ScaledClampedQGeGLU; raises informatively in non-fused paths when the flag is set. |
| tests/pytorch/test_fusible_ops.py | Adds test_scaled_srelu, test_scaled_activation_recompute_in_mlp_config, and extends test_grouped_mlp for scaled_srelu; correctly skips when the unary fusion is unavailable. |
| transformer_engine/pytorch/ops/fused/__init__.py | Re-exports renamed GLU and new Unary fused-op classes; straightforward change. |
| transformer_engine/pytorch/ops/basic/__init__.py | Exports the new ScaledSReLU symbol; one-line change, no issues. |
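The GLU versus unary distinction that validate_grouped_mlp_dims is said to handle (2× out-features for GLU, 1× for unary) can be illustrated with a short check. This is a hypothetical sketch: the activation name strings, the `check_grouped_mlp_dims` helper, and its signature are assumptions for illustration, and only the `is_glu_activation` name and the 2×/1× multipliers come from the summary.

```python
# Hypothetical activation names; the 2x / 1x multipliers come from the summary.
GLU_ACTIVATIONS = {"swiglu", "qgeglu"}
UNARY_ACTIVATIONS = {"srelu"}


def is_glu_activation(name):
    """GLU activations split FC1's output into a gate half and a value half."""
    return name in GLU_ACTIVATIONS


def check_grouped_mlp_dims(activation, fc1_out_features, fc2_in_features):
    """Return the FC1 width multiplier, or raise if the shapes do not line up."""
    if activation not in GLU_ACTIVATIONS | UNARY_ACTIVATIONS:
        raise ValueError(f"unsupported activation: {activation!r}")
    multiplier = 2 if is_glu_activation(activation) else 1
    expected = multiplier * fc2_in_features
    if fc1_out_features != expected:
        raise ValueError(
            f"{activation}: FC1 out_features={fc1_out_features}, "
            f"expected {expected} ({multiplier}x FC2 in_features)"
        )
    return multiplier
```

Keeping the multiplier in one place means a new unary activation only needs to be added to the set of known names, rather than to every shape check.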

Reviews (12): Last reviewed commit: "Merge branch 'main' into fc1-srelu-main"

Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/basic/activation.py
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Signed-off-by: sraman-rgb <sraman@nvidia.com>
Member

@timmoon10 timmoon10 left a comment

Overall looks good, but we've gotten to the point where we need to start thinking about how to gracefully handle adding new activations. It seems that every model has a different activation function.

Comment thread transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread tests/pytorch/test_fusible_ops.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Siddhartha Raman S and others added 5 commits May 18, 2026 14:46
Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Collaborator

@vthumbe1503 vthumbe1503 left a comment

LGTM. We might want to wait until the cuDNN release is out and appropriate cuDNN guards are added.
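A version guard of the kind mentioned here could look roughly like the following. It is a sketch only: the threshold is the "cuDNN FE ≥ 1.24.0" figure from the Greptile summary, while `parse_version`, `unary_fusion_available`, and the way the version string is obtained are assumptions and not the helpers in this PR.

```python
import re

# Assumed threshold, taken from the "cuDNN FE >= 1.24.0" note in the summary.
MIN_CUDNN_FE_FOR_UNARY = (1, 24, 0)


def parse_version(text):
    """Turn '1.24.0' (or '1.24', or '1.24.0+cu12') into a tuple of ints."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", text)
    if match is None:
        raise ValueError(f"unrecognised version string: {text!r}")
    major, minor, patch = match.groups()
    return (int(major), int(minor), int(patch or 0))


def unary_fusion_available(cudnn_fe_version):
    """True when the given frontend version is new enough for the unary kernels."""
    return parse_version(cudnn_fe_version) >= MIN_CUDNN_FE_FOR_UNARY
```

Comparing integer tuples rather than raw strings avoids the usual pitfall where "1.9.0" sorts after "1.24.0" lexicographically.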

Comment thread transformer_engine/pytorch/ops/_common.py
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Siddhartha Raman S added 3 commits May 21, 2026 08:40
Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>
Comment thread transformer_engine/pytorch/ops/basic/activation.py Outdated
Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 21, 2026
Siddhartha Raman S and others added 2 commits May 21, 2026 12:48
Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>
@vthumbe1503
Collaborator

/te-ci pytorch

4 participants