feat: fuse add and rmsnorm #1368

Open
blueswhen wants to merge 1 commit into main from add_rmsnorm

@blueswhen

Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a fused add-RMSNorm Triton kernel and integrates it into the transformer layer inference template and the Qwen3Next model to optimize token forward passes. The review feedback highlights two critical runtime issues in the autotuning setup for the new kernel: both the static key function and the run key lambda do not accept all arguments passed by the autotuner decorator, which will result in TypeError exceptions. Code suggestions are provided to handle extra arguments robustly using *args, **kwargs.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +157 to +159
def _get_add_rmsnorm_static_key(
    x_arg: torch.Tensor, residual_arg: torch.Tensor, y_arg: torch.Tensor, weight: torch.Tensor
):

high

The _get_add_rmsnorm_static_key function is used as the static_key_func for the @autotune decorator on _add_rmsnorm_forward. However, _add_rmsnorm_forward accepts 6 arguments (x_arg, residual_arg, y_arg, weight, eps, run_config), while _get_add_rmsnorm_static_key only accepts 4. When the autotuner invokes this function with all arguments, it will raise a TypeError at runtime. Adding *args, **kwargs to the signature will make it robust against extra arguments.

Suggested change
- def _get_add_rmsnorm_static_key(
-     x_arg: torch.Tensor, residual_arg: torch.Tensor, y_arg: torch.Tensor, weight: torch.Tensor
- ):
+ def _get_add_rmsnorm_static_key(
+     x_arg: torch.Tensor, residual_arg: torch.Tensor, y_arg: torch.Tensor, weight: torch.Tensor, *args, **kwargs
+ ):
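
The failure mode described in this comment can be reproduced without Triton or torch. The sketch below is only an illustration: `toy_autotune` is a hypothetical stand-in, not the project's real `@autotune` decorator, and it assumes (as the review does) that every call argument is forwarded to the key functions. Whether the real decorator behaves this way is not shown on this page. Plain Python lists stand in for tensors, and `narrow_static_key` / `robust_static_key` are made-up names for the original and suggested signatures.

```python
import functools


def toy_autotune(static_key_func, run_key_func):
    """Hypothetical decorator: forwards ALL call arguments to both key functions."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Assumption taken from the review: key functions see every argument.
            static_key = static_key_func(*args, **kwargs)
            run_key = run_key_func(*args, **kwargs)
            return static_key, run_key, fn(*args, **kwargs)

        return wrapper

    return decorator


def narrow_static_key(x, residual, y, weight):
    # Mirrors the original 4-parameter signature.
    return (len(x), len(weight))


def robust_static_key(x, residual, y, weight, *args, **kwargs):
    # Mirrors the suggested signature: extra arguments are accepted and ignored.
    return (len(x), len(weight))


def forward(x, residual, y, weight, eps, run_config=None):
    # Placeholder body: only the residual add, no normalization.
    return [a + b for a, b in zip(x, residual)]


def call_with(static_key_func):
    tuned = toy_autotune(static_key_func, lambda x, *a, **k: len(x))(forward)
    try:
        return tuned([1.0, 2.0], [0.5, 0.5], [0.0, 0.0], [1.0, 1.0], 1e-6, run_config=None)
    except TypeError as exc:
        return f"TypeError: {exc}"


narrow_result = call_with(narrow_static_key)  # a "TypeError: ..." string
robust_result = call_with(robust_static_key)  # (static_key, run_key, output)
```

Under that forwarding assumption the 4-parameter function fails as soon as `eps` and `run_config` are passed along, while the `*args, **kwargs` variant ignores them. If the decorator instead selects arguments by parameter name, the original signature may already work.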

kernel_name="add_rmsnorm_forward:v1",
configs_gen_func=_get_add_rmsnorm_configs,
static_key_func=_get_add_rmsnorm_static_key,
run_key_func=lambda x_arg: x_arg.shape[0],

high

The run_key_func lambda (lambda x_arg: x_arg.shape[0]) accepts only 1 argument, but the autotuner will pass all arguments of the decorated _add_rmsnorm_forward function to it. This will cause a TypeError at runtime. Updating the lambda to accept *args, **kwargs will prevent this crash.

Suggested change
- run_key_func=lambda x_arg: x_arg.shape[0],
+ run_key_func=lambda x_arg, *args, **kwargs: x_arg.shape[0],
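
One way to check a key function against a given call shape, without actually invoking it, is the standard library's `inspect.signature(...).bind(...)`, which raises `TypeError` when the arguments do not fit the signature. The following is a minimal sketch under stated assumptions: `accepts_call` is a hypothetical helper, lists replace tensors (so `len(x_arg)` stands in for `x_arg.shape[0]`), and the argument tuple simply mirrors the six parameters of `_add_rmsnorm_forward` named in the review.

```python
import inspect


def accepts_call(func, *args, **kwargs):
    """Return True if func's signature can bind the given arguments."""
    try:
        inspect.signature(func).bind(*args, **kwargs)
    except TypeError:
        return False
    return True


# Original and suggested run-key lambdas, with len() standing in for .shape[0].
narrow_run_key = lambda x_arg: len(x_arg)
robust_run_key = lambda x_arg, *args, **kwargs: len(x_arg)

# Stand-ins for (x_arg, residual_arg, y_arg, weight, eps) plus run_config.
call_args = ([1.0, 2.0, 3.0], [0.0] * 3, [0.0] * 3, [1.0] * 3, 1e-6)
call_kwargs = {"run_config": None}

narrow_ok = accepts_call(narrow_run_key, *call_args, **call_kwargs)
robust_ok = accepts_call(robust_run_key, *call_args, **call_kwargs)
robust_key = robust_run_key(*call_args, **call_kwargs)
```

A check like this could sit in a unit test next to the kernel wrapper, so a signature mismatch surfaces before the first tuned call rather than during inference.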
