diff --git a/examples/blackwell_attention.py b/examples/blackwell_attention.py
index d96697071..eeab4fe18 100644
--- a/examples/blackwell_attention.py
+++ b/examples/blackwell_attention.py
@@ -13,9 +13,11 @@
 from __future__ import annotations
 
 import math
+import os
 from typing import Callable
 
 import torch
+import triton
 from triton.testing import do_bench
 
 import helion
@@ -70,6 +72,55 @@ def _fma_f32x2(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tenso
     )
 
 
+HAS_ACC = os.getenv("WITH_ACC", "0") == "1"
+if HAS_ACC:
+    ACC_OPTIONS = [0, 11, 12, 13, 14, 15]
+
+    def make_acc_option(option):
+        return {"advanced_compiler_configuration": option}
+else:
+    ACC_OPTIONS = [None]
+    make_acc_option = lambda _: {}
+
+
+def _supports_reg_auto_ws():
+    """Check if the current Triton version supports minRegAutoWS/maxRegAutoWS."""
+    try:
+        # Try to create a Config with minRegAutoWS/maxRegAutoWS to test support
+        test_config = triton.Config({}, minRegAutoWS=24, maxRegAutoWS=152)
+        return True
+    except (TypeError, AttributeError):
+        # Parameter not supported in this Triton version
+        return False
+
+
+HAS_REG_AUTO_WS = _supports_reg_auto_ws()
+print(f"{HAS_REG_AUTO_WS=}")
+if HAS_REG_AUTO_WS:
+    REG_AUTO_WS_OPTIONS = [152, 192]
+    M_OPTIONS = [256]
+
+    def make_reg_auto_ws_option(maxreg):
+        OUTER_LOOP = True
+        return dict(
+            _triton_range_id_data_partition_factor=0,
+            _triton_range_value_data_partition_factor=2,
+            _triton_config_maxRegAutoWS=maxreg,
+            range_warp_specializes=[OUTER_LOOP or None, None if OUTER_LOOP else True],
+            range_multi_buffers=[None, False],
+        )
+else:
+    REG_AUTO_WS_OPTIONS = [None]
+    M_OPTIONS = [128]
+
+    def make_reg_auto_ws_option(maxreg):
+        return dict(
+            _triton_range_id_data_partition_factor=-1,
+            _triton_range_value_data_partition_factor=-1,
+            _triton_config_maxRegAutoWS=-1,
+        )
+
+
 # %%
 # Attention Kernel Implementation
 # -------------------------------
@@ -79,20 +130,18 @@ def _fma_f32x2(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tenso
 @helion.kernel(
     configs=[
         helion.Config(
-            block_sizes=[256, N],
-            range_warp_specializes=[OUTER_LOOP or None, None if OUTER_LOOP else True],
-            range_multi_buffers=[None, False],
+            block_sizes=[M, N],
             pid_type="persistent_interleaved",
             indexing="tensor_descriptor",
             num_warps=4,
             num_stages=3,
-            _triton_range_id_data_partition_factor=0,
-            _triton_range_value_data_partition_factor=2,
-            _triton_config_maxRegAutoWS=maxreg,
+            **make_acc_option(ACC_OPTION),
+            **make_reg_auto_ws_option(REG_AUTO_WS_OPTION),
         )
+        for M in M_OPTIONS
         for N in [64, 128]
-        for OUTER_LOOP in [True]
-        for maxreg in [152, 192]
+        for ACC_OPTION in ACC_OPTIONS
+        for REG_AUTO_WS_OPTION in REG_AUTO_WS_OPTIONS
     ],
     static_shapes=True,
     autotune_accuracy_check=False,
@@ -136,12 +185,14 @@ def blackwell_attention_kernel(
     assert M % block_m == 0
     assert N % block_n == 0
     hl.register_tunable(
-        "_triton_range_id_data_partition_factor", EnumFragment(choices=(0,))
+        "_triton_range_id_data_partition_factor", EnumFragment(choices=(-1, 0))
+    )
+    hl.register_tunable(
+        "_triton_range_value_data_partition_factor", EnumFragment(choices=(-1, 2))
     )
     hl.register_tunable(
-        "_triton_range_value_data_partition_factor", EnumFragment(choices=(2,))
+        "_triton_config_maxRegAutoWS", EnumFragment(choices=(-1, 152, 192))
     )
-    hl.register_tunable("_triton_config_maxRegAutoWS", EnumFragment(choices=(152, 192)))
     SUBTILING = True
     VECT_MUL = 1
     qk_scale = qk_scale * 1.44269504  # 1/log(2)
diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py
index 4d60c339e..3ed57e86c 100644
--- a/helion/_compiler/device_function.py
+++ b/helion/_compiler/device_function.py
@@ -228,7 +228,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
             + [
                 x.removeprefix("_triton_config_")
                 for x in config
-                if x.startswith("_triton_config_")
+                if x.startswith("_triton_config_") and config[x] != -1
             ]
         )
         self._variable_renames: dict[str, list[str]] = {}
@@ -622,7 +622,7 @@ def codegen_function_call(self) -> ast.AST:
             + [
                 f"{x.removeprefix('_triton_config_')}={self.config[x]}"
                 for x in self.config
-                if x.startswith("_triton_config_")
+                if x.startswith("_triton_config_") and self.config[x] != -1
             ]
         )
         advanced_compiler_configuration = self.config.advanced_compiler_configuration
diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index 2982120d8..904841303 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -295,7 +295,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
         self.log.debug(lambda: f"Running {config} at {datetime.datetime.now()}")
         t0 = time.perf_counter()
         # HACK: run checks multiple times to detect data races
-        for _ in range(5):
+        for _ in range(1):
             if self._kernel_mutates_args:
                 self.args = self._clone_args(self._original_args)
                 torch.accelerator.synchronize()
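
Below the patch, a minimal standalone sketch of the -1-sentinel filtering that the device_function.py hunks add; the helper name and the plain-dict input are illustrative assumptions, not Helion's actual API.

def triton_config_kwargs(config: dict) -> dict:
    # Keys spelled "_triton_config_<name>" are forwarded as <name>=value kwargs
    # for triton.Config(...); a value of -1 is the "disabled" sentinel used when
    # the current Triton build does not support the knob, so it is skipped.
    return {
        key.removeprefix("_triton_config_"): value
        for key, value in config.items()
        if key.startswith("_triton_config_") and value != -1
    }


# Example: {"_triton_config_maxRegAutoWS": 152} -> {"maxRegAutoWS": 152}
#          {"_triton_config_maxRegAutoWS": -1}  -> {}  (knob left unset)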