feat: fuse add and rmsnorm #1368
Code Review
This pull request introduces a fused add-RMSNorm Triton kernel and integrates it into the transformer layer inference template and the Qwen3Next model to optimize token forward passes. The review feedback highlights two critical runtime issues in the autotuning setup for the new kernel: neither the static key function (_get_add_rmsnorm_static_key) nor the run_key_func lambda accepts all of the arguments passed by the autotuner decorator, which will result in TypeError exceptions. Code suggestions are provided to handle the extra arguments robustly using *args, **kwargs.
def _get_add_rmsnorm_static_key(
    x_arg: torch.Tensor, residual_arg: torch.Tensor, y_arg: torch.Tensor, weight: torch.Tensor
):
The _get_add_rmsnorm_static_key function is used as the static_key_func for the @autotune decorator on _add_rmsnorm_forward. However, _add_rmsnorm_forward accepts 6 arguments (x_arg, residual_arg, y_arg, weight, eps, run_config), while _get_add_rmsnorm_static_key only accepts 4. When the autotuner invokes this function with all arguments, it will raise a TypeError at runtime. Adding *args, **kwargs to the signature will make it robust against extra arguments.
Suggested change
- def _get_add_rmsnorm_static_key(
-     x_arg: torch.Tensor, residual_arg: torch.Tensor, y_arg: torch.Tensor, weight: torch.Tensor
- ):
+ def _get_add_rmsnorm_static_key(
+     x_arg: torch.Tensor, residual_arg: torch.Tensor, y_arg: torch.Tensor, weight: torch.Tensor, *args, **kwargs
+ ):
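The failure mode described above can be reproduced without Triton or torch. The sketch below is only an illustration: `call_key_func` is a hypothetical stand-in that forwards every argument to the key function, which is the calling convention this comment assumes for the autotuner, and the string values are placeholders for tensors.

```python
# Hypothetical stand-in for an autotuner that forwards every call argument
# to its key function; the real decorator's calling convention is assumed here.
def call_key_func(key_func, *args, **kwargs):
    return key_func(*args, **kwargs)


def strict_key(x_arg, residual_arg, y_arg, weight):
    # Declares only four parameters, like the original static key function.
    return (x_arg, weight)


def tolerant_key(x_arg, residual_arg, y_arg, weight, *args, **kwargs):
    # Extra positional and keyword arguments are accepted and ignored.
    return (x_arg, weight)


# Placeholder values in the order (x_arg, residual_arg, y_arg, weight, eps).
call_args = ("x", "residual", "y", "w", 1e-6)

try:
    call_key_func(strict_key, *call_args, run_config=None)
    strict_error = None
except TypeError as exc:
    # Raised because a fifth positional argument and run_config are not declared.
    strict_error = exc

tolerant_result = call_key_func(tolerant_key, *call_args, run_config=None)
```

Under that assumption, `strict_key` fails with a TypeError while `tolerant_key` returns normally, which is the behavior the suggested signature change relies on.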
    kernel_name="add_rmsnorm_forward:v1",
    configs_gen_func=_get_add_rmsnorm_configs,
    static_key_func=_get_add_rmsnorm_static_key,
    run_key_func=lambda x_arg: x_arg.shape[0],
The run_key_func lambda, lambda x_arg: x_arg.shape[0], accepts only one argument, but the autotuner will pass all arguments of the decorated _add_rmsnorm_forward function to it. This will cause a TypeError at runtime. Updating the lambda to accept *args, **kwargs will prevent this crash.
Suggested change
-     run_key_func=lambda x_arg: x_arg.shape[0],
+     run_key_func=lambda x_arg, *args, **kwargs: x_arg.shape[0],
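Whether the narrower lambda actually crashes depends on how the decorator invokes key functions. A decorator can avoid this class of error by passing each key function only the parameters it declares, matched by name. The sketch below shows that alternative using `inspect.signature`; `call_with_matching_args` is a hypothetical helper written for illustration, and nothing here is claimed to describe the project's autotuner.

```python
import inspect


def call_with_matching_args(func, param_names, args, kwargs):
    """Call func with only the arguments it declares, matched by name."""
    # Map each positional argument to the decorated function's parameter name.
    bound = dict(zip(param_names, args))
    bound.update(kwargs)
    wanted = inspect.signature(func).parameters
    return func(**{name: bound[name] for name in wanted if name in bound})


# Parameter names of the decorated kernel launcher, as listed in the review.
forward_params = ["x_arg", "residual_arg", "y_arg", "weight", "eps", "run_config"]

# Stand-in for `lambda x_arg: x_arg.shape[0]`, using a list instead of a tensor.
run_key = lambda x_arg: len(x_arg)

result = call_with_matching_args(
    run_key,
    forward_params,
    (["t0", "t1", "t2"], "residual", "y", "w", 1e-6),
    {"run_config": None},
)
```

With name-based filtering, the one-parameter lambda works unchanged. If the decorator instead forwards everything, the `*args, **kwargs` form suggested above is the safer choice.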
No description provided.