From 294b3db5dda69874b69a769baf6a903e67dc575f Mon Sep 17 00:00:00 2001
From: Markus Hoehnerbach
Date: Mon, 20 Oct 2025 11:07:34 -0700
Subject: [PATCH 1/3] add acc for blackwell

---
 examples/blackwell_attention.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/examples/blackwell_attention.py b/examples/blackwell_attention.py
index d96697071..62b670900 100644
--- a/examples/blackwell_attention.py
+++ b/examples/blackwell_attention.py
@@ -13,6 +13,7 @@
 from __future__ import annotations
 
 import math
+import os
 from typing import Callable
 
 import torch
@@ -70,11 +71,20 @@ def _fma_f32x2(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tenso
     )
 
 
+HAS_ACC = os.getenv("WITH_ACC", "0") == "1"
+if HAS_ACC:
+    ACC_OPTIONS = [0, 11, 12, 13, 14, 15]
+
+    def make_acc_option(option):
+        return {"advanced_compiler_configuration": option}
+else:
+    ACC_OPTIONS = [None]
+    make_acc_option = lambda _: {}
+
 # %%
 # Attention Kernel Implementation
 # -------------------------------
-
 # %%
 @helion.kernel(
     configs=[
         helion.Config(
@@ -89,10 +99,12 @@ def _fma_f32x2(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tenso
             _triton_range_id_data_partition_factor=0,
             _triton_range_value_data_partition_factor=2,
             _triton_config_maxRegAutoWS=maxreg,
+            **make_acc_option(ACC_OPTION),
         )
         for N in [64, 128]
         for OUTER_LOOP in [True]
         for maxreg in [152, 192]
+        for ACC_OPTION in ACC_OPTIONS
     ],
     static_shapes=True,
     autotune_accuracy_check=False,

From 736d02b5e28c132405f9bf22d1b3446ddd02161b Mon Sep 17 00:00:00 2001
From: Markus Hoehnerbach
Date: Mon, 20 Oct 2025 14:28:24 -0700
Subject: [PATCH 2/3] hack: check once

---
 helion/autotuner/base_search.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index 2982120d8..904841303 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -295,7 +295,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
         self.log.debug(lambda: f"Running {config} at {datetime.datetime.now()}")
         t0 = time.perf_counter()
         # HACK: run checks multiple times to detect data races
-        for _ in range(5):
+        for _ in range(1):
             if self._kernel_mutates_args:
                 self.args = self._clone_args(self._original_args)
             torch.accelerator.synchronize()

From 1990fb8eaddabd4ce8d180aecc4421d1847c6046 Mon Sep 17 00:00:00 2001
From: Markus Hoehnerbach
Date: Mon, 20 Oct 2025 15:19:13 -0700
Subject: [PATCH 3/3] allow stock to run blackwell_example

---
 examples/blackwell_attention.py     | 61 +++++++++++++++++++++++------
 helion/_compiler/device_function.py |  8 +++-
 2 files changed, 56 insertions(+), 13 deletions(-)

diff --git a/examples/blackwell_attention.py b/examples/blackwell_attention.py
index 62b670900..eeab4fe18 100644
--- a/examples/blackwell_attention.py
+++ b/examples/blackwell_attention.py
@@ -17,6 +17,7 @@
 from typing import Callable
 
 import torch
+import triton
 from triton.testing import do_bench
 
 import helion
@@ -81,30 +82,66 @@ def make_acc_option(option):
     ACC_OPTIONS = [None]
     make_acc_option = lambda _: {}
 
+
+def _supports_reg_auto_ws():
+    """Check if the current Triton version supports minRegAutoWS/maxRegAutoWS"""
+    try:
+        # Try to create a Config with minRegAutoWS to test support
+        test_config = triton.Config({}, minRegAutoWS=24, maxRegAutoWS=152)
+        return True
+    except (TypeError, AttributeError):
+        # Parameter not supported in this Triton version
+        return False
+
+
+HAS_REG_AUTO_WS = _supports_reg_auto_ws()
+print(f"!!!!!!!!!! {HAS_REG_AUTO_WS=} !!!!!!!!!!!!!")
{HAS_REG_AUTO_WS=} !!!!!!!!!!!!!") +if HAS_REG_AUTO_WS: + REG_AUTO_WS_OPTIONS = [152, 192] + M_OPTIONS = [256] + + def make_reg_auto_ws_option(maxreg): + OUTER_LOOP = True + return dict( + _triton_range_id_data_partition_factor=0, + _triton_range_value_data_partition_factor=2, + _triton_config_maxRegAutoWS=maxreg, + range_warp_specializes=[OUTER_LOOP or None, None if OUTER_LOOP else True], + range_multi_buffers=[None, False], + ) +else: + REG_AUTO_WS_OPTIONS = [None] + M_OPTIONS = [128] + + def make_reg_auto_ws_option(maxreg): + return dict( + _triton_range_id_data_partition_factor=-1, + _triton_range_value_data_partition_factor=-1, + _triton_config_maxRegAutoWS=-1, + ) + + # %% # Attention Kernel Implementation # ------------------------------- + # %% @helion.kernel( configs=[ helion.Config( - block_sizes=[256, N], - range_warp_specializes=[OUTER_LOOP or None, None if OUTER_LOOP else True], - range_multi_buffers=[None, False], + block_sizes=[M, N], pid_type="persistent_interleaved", indexing="tensor_descriptor", num_warps=4, num_stages=3, - _triton_range_id_data_partition_factor=0, - _triton_range_value_data_partition_factor=2, - _triton_config_maxRegAutoWS=maxreg, **make_acc_option(ACC_OPTION), + **make_reg_auto_ws_option(REG_AUTO_WS_OPTION), ) + for M in M_OPTIONS for N in [64, 128] - for OUTER_LOOP in [True] - for maxreg in [152, 192] for ACC_OPTION in ACC_OPTIONS + for REG_AUTO_WS_OPTION in REG_AUTO_WS_OPTIONS ], static_shapes=True, autotune_accuracy_check=False, @@ -148,12 +185,14 @@ def blackwell_attention_kernel( assert M % block_m == 0 assert N % block_n == 0 hl.register_tunable( - "_triton_range_id_data_partition_factor", EnumFragment(choices=(0,)) + "_triton_range_id_data_partition_factor", EnumFragment(choices=(-1, 0)) + ) + hl.register_tunable( + "_triton_range_value_data_partition_factor", EnumFragment(choices=(-1, 2)) ) hl.register_tunable( - "_triton_range_value_data_partition_factor", EnumFragment(choices=(2,)) + "_triton_config_maxRegAutoWS", EnumFragment(choices=(-1, 152, 192)) ) - hl.register_tunable("_triton_config_maxRegAutoWS", EnumFragment(choices=(152, 192))) SUBTILING = True VECT_MUL = 1 qk_scale = qk_scale * 1.44269504 # 1/log(2) diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 4d60c339e..3ed57e86c 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -228,7 +228,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None: + [ x.removeprefix("_triton_config_") for x in config - if x.startswith("_triton_config_") + if x.startswith("_triton_config_") and config[x] != -1 ] ) self._variable_renames: dict[str, list[str]] = {} @@ -614,6 +614,10 @@ def codegen_function_call(self) -> ast.AST: if any(self.config.range_warp_specializes): num_warps = max(4, num_warps) + print( + type(self.config["_triton_config_maxRegAutoWS"]), + self.config["_triton_config_maxRegAutoWS"], + ) args.extend( [ f"num_warps={num_warps}", @@ -622,7 +626,7 @@ def codegen_function_call(self) -> ast.AST: + [ f"{x.removeprefix('_triton_config_')}={self.config[x]}" for x in self.config - if x.startswith("_triton_config_") + if x.startswith("_triton_config_") and self.config[x] != -1 ] ) advanced_compiler_configuration = self.config.advanced_compiler_configuration