Commits (32)
a51afd4
Major: Used a zero-nonzero specialization in which 0 will be dealt wi…
Itssshikhar Oct 30, 2025
a81fb12
Merge branch 'pytorch:main' into feature/disable_0_1_specialization
Itssshikhar Oct 31, 2025
eb13f0c
Fix the type 'str' is not assignable lint error by using a helper fun…
Itssshikhar Nov 1, 2025
7723054
Merge branch 'pytorch:main' into feature/disable_0_1_specialization
Itssshikhar Nov 5, 2025
ad805e8
Add `settings.autotune_baseline_fn` to allow passing in custom baseli…
yf225 Oct 30, 2025
20234fd
Fix the shape-bucketing error where FakeTensors coming from ShapeEnv …
Itssshikhar Nov 5, 2025
734e1a4
Merge branch 'main' into feature/disable_0_1_specialization
yf225 Dec 5, 2025
c41cb66
up
yf225 Dec 5, 2025
8c7bb16
up
yf225 Dec 5, 2025
99d2339
up
yf225 Dec 6, 2025
111dcac
up
yf225 Dec 6, 2025
78c9198
update tests
yf225 Dec 6, 2025
56e2359
up
yf225 Dec 6, 2025
a6442f5
up
yf225 Dec 6, 2025
c9cb2c8
up
yf225 Dec 7, 2025
33e558a
up
yf225 Dec 7, 2025
17462e3
Merge branch 'main' into feature/disable_0_1_specialization
yf225 Dec 7, 2025
bb7af3e
Merge branch 'main' into feature/disable_0_1_specialization
yf225 Dec 7, 2025
3a56f8f
fix expected
yf225 Dec 7, 2025
42895cc
up
yf225 Dec 7, 2025
e6734c8
Merge branch 'main' into feature/disable_0_1_specialization
yf225 Dec 8, 2025
c8be6c1
update tests
yf225 Dec 8, 2025
cd9eed8
update tests
yf225 Dec 8, 2025
3426efa
updated BC test
yf225 Dec 8, 2025
031e352
wip fix bugs
yf225 Dec 9, 2025
a5a49c7
wip test
yf225 Dec 9, 2025
587550e
try simplify
yf225 Dec 10, 2025
1072fcf
update expect
yf225 Dec 10, 2025
5f22ae4
Merge branch 'main' into feature/disable_0_1_specialization
yf225 Dec 12, 2025
11bc150
lint
yf225 Dec 12, 2025
b430dff
Merge branch 'main' into feature/disable_0_1_specialization
yf225 Dec 12, 2025
de7c4d1
fix expected
yf225 Dec 12, 2025
15 changes: 13 additions & 2 deletions helion/runtime/kernel.py
@@ -392,6 +392,13 @@ def configs(self) -> list[Config]:

    def format_kernel_decorator(self, config: Config, settings: Settings) -> str:
        """Return the @helion.kernel decorator snippet capturing configs and settings that influence Triton code generation."""
        # Include shape_bucketing only when non-default to keep logs compact
        if getattr(settings, "shape_bucketing", "min2") != "min2":
Contributor: Why getattr?

            return (
                f"@helion.kernel(config={config.__repr__()}, "
                f"static_shapes={settings.static_shapes}, "
                f"shape_bucketing='{settings.shape_bucketing}')"
            )
        return f"@helion.kernel(config={config.__repr__()}, static_shapes={settings.static_shapes})"

    def to_triton_code(
@@ -817,11 +824,15 @@ def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
            (*obj.size(),),
            (*obj.stride(),),
        )
    # Non-static path: bucket sizes for specialization. Default is 0/1/>=2 (as 2).
    vals = tuple([min(s, 2) for s in obj.size()])
    if getattr(fn.settings, "shape_bucketing", "min2") == "zero_nonzero":
Contributor: Same question as above (why getattr?).

        # Keep zero distinct; unify 1 with >=2 to reduce variant churn
        vals = tuple(0 if v == 0 else 2 for v in vals)
    return (
        obj.dtype,
        obj.device.type,
        # 0, 1, or >=2 specialization
        tuple([min(s, 2) for s in obj.size()]),
        vals,
    )


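To make the two bucketing policies in _tensor_key concrete, here is a small standalone sketch; bucket_sizes is a hypothetical name used only for illustration, since the PR inlines this logic directly in _tensor_key.

def bucket_sizes(sizes: tuple[int, ...], shape_bucketing: str = "min2") -> tuple[int, ...]:
    # "min2" (default): each dimension is bucketed to 0, 1, or 2 (2 meaning ">=2").
    vals = tuple(min(s, 2) for s in sizes)
    if shape_bucketing == "zero_nonzero":
        # Keep 0 distinct; fold 1 into the ">=2" bucket so size-1 dims no longer
        # trigger a separate specialization.
        vals = tuple(0 if v == 0 else 2 for v in vals)
    return vals


assert bucket_sizes((0, 3)) == (0, 2)
assert bucket_sizes((1, 3)) == (1, 2)
assert bucket_sizes((7, 3)) == (2, 2)
assert bucket_sizes((1, 3), "zero_nonzero") == (2, 2)  # 1 now shares a key with >=2
assert bucket_sizes((0, 3), "zero_nonzero") == (0, 2)  # 0 stays distinct

With the non-default policy active, format_kernel_decorator above would then emit something like @helion.kernel(config=..., static_shapes=False, shape_bucketing='zero_nonzero').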
27 changes: 27 additions & 0 deletions helion/runtime/settings.py
@@ -232,6 +232,21 @@ def _get_autotune_random_seed() -> int:
    return int(time.time() * 1000) % 2**32


def _get_shape_bucketing() -> Literal["min2", "zero_nonzero"]:
    val = _env_get_literal(
        "HELION_SHAPE_BUCKETING",
        "min2",
        mapping={
            "min2": "min2",
            "zero_nonzero": "zero_nonzero",
        },
    )
    # Narrow to Literal explicitly
    if val == "zero_nonzero":
        return "zero_nonzero"
    return "min2"


def _get_ref_mode() -> RefMode:
    interpret = _env_get_bool("HELION_INTERPRET", False)
    return RefMode.EAGER if interpret else RefMode.OFF
@@ -347,6 +362,12 @@ class _Settings:
            _env_get_bool, "HELION_DEBUG_DTYPE_ASSERTS", False
        )
    )
    # Controls non-static shape specialization bucketing. When "min2" (default),
    # we bucket dynamic sizes per-dimension into 0, 1, or >=2 (represented as 2).
    # When "zero_nonzero", we keep 0 distinct and unify 1 with >=2 to reduce churn.
    shape_bucketing: Literal["min2", "zero_nonzero"] = dataclasses.field(
        default_factory=_get_shape_bucketing
    )
Contributor: After some thought, perhaps instead of adding a new config we should make static_shapes an enum of "all", "ones", "none", since setting static_shapes=True makes this new option do nothing.

We will need backcompat for True/False, but that might result in a cleaner config.

Author: Okay, so I was thinking we can do something like this:

  • static_shapes = "all" would be equivalent to setting static_shapes=True
  • static_shapes = "ones" would be the "min2" case, meaning specialize on 0/1.
  • static_shapes = "none" would be this "zero_nonzero" case, basically disabling 0/1 specialization.

To keep backcompat for True/False, we can map True -> "all" and False -> "none", and then HELION_STATIC_SHAPES can accept "all", "ones", "none". A rough sketch of that coercion is included below.
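For concreteness, a minimal sketch of what that coercion could look like; the helper name _coerce_static_shapes and the exact False -> "none" mapping come from the proposal above and are not part of this PR.

from typing import Literal

StaticShapesMode = Literal["all", "ones", "none"]


def _coerce_static_shapes(value: bool | str) -> StaticShapesMode:
    # Hypothetical helper illustrating the proposed enum with True/False backcompat.
    if value is True:
        return "all"  # old static_shapes=True keeps full static specialization
    if value is False:
        return "none"  # mapping proposed above for the old static_shapes=False
    if value in ("all", "ones", "none"):
        return value  # HELION_STATIC_SHAPES could pass these strings through
    raise ValueError(f"invalid static_shapes value: {value!r}")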

    ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode)
    autotuner_fn: AutotunerFunction = default_autotuner_fn
    autotune_baseline_fn: Callable[..., object] | None = None
@@ -401,6 +422,12 @@ class Settings(_Settings):
        ),
        "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
        "debug_dtype_asserts": "If True, emit tl.static_assert checks for dtype after each device node.",
        "shape_bucketing": (
            "Dynamic-shape specialization policy when static_shapes=False. "
            "'min2' buckets each dimension into 0,1,>=2 (current behavior). "
            "'zero_nonzero' keeps 0 distinct and unifies 1 with >=2 to reduce variants. "
            "Override with HELION_SHAPE_BUCKETING=min2|zero_nonzero."
        ),
        "ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.",
        "autotuner_fn": (
            "Function to create an autotuner. "
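As a usage note, both opt-in paths below follow the Settings field and the HELION_SHAPE_BUCKETING handling added in this diff; _identity is a stand-in function analogous to _dummy in the new test file, and the snippet is a sketch rather than anything beyond what the diff and tests already show.

import os

import torch

from helion.runtime.kernel import kernel
from helion.runtime.settings import Settings


def _identity(x: torch.Tensor) -> torch.Tensor:
    return x


# 1) Select the policy explicitly on the Settings object.
k = kernel(
    _identity,
    settings=Settings(static_shapes=False, shape_bucketing="zero_nonzero"),
)

# 2) Or via the environment variable, read by _get_shape_bucketing() when the
#    Settings object is constructed.
os.environ["HELION_SHAPE_BUCKETING"] = "zero_nonzero"
k_env = kernel(_identity, settings=Settings(static_shapes=False))

# Under "zero_nonzero", a size-1 and a size>=2 input share a specialization key.
assert k.specialization_key([torch.empty(1, 3)]) == k.specialization_key(
    [torch.empty(8, 3)]
)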
54 changes: 54 additions & 0 deletions test/test_shape_bucketing.py
Contributor: Please add a reduction test (using the sum example kernel).

Have you tried manually running some of the examples with this flag set to shake out any other bugs?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to find out what all bugs this change might introduce. Will do more tests, along with adding reduction kernel.

@@ -0,0 +1,54 @@
from __future__ import annotations

import unittest

import torch

from helion.runtime.kernel import kernel
from helion.runtime.settings import Settings


def _dummy(x: torch.Tensor) -> torch.Tensor:
    return x


class TestShapeBucketing(unittest.TestCase):
    def test_min2_bucketing_default(self) -> None:
        k = kernel(_dummy, settings=Settings(static_shapes=False))

        t0 = torch.empty(0, 3)
        t1 = torch.empty(1, 3)
        t2 = torch.empty(2, 3)
        t7 = torch.empty(7, 3)

        key_0 = k.specialization_key([t0])
        key_1 = k.specialization_key([t1])
        key_2 = k.specialization_key([t2])
        key_7 = k.specialization_key([t7])

        # min2: 0,1,>=2 (as 2)
        self.assertNotEqual(key_0, key_2)
        self.assertNotEqual(key_1, key_2)
        self.assertEqual(key_2, key_7)

    def test_zero_nonzero_bucketing(self) -> None:
        k = kernel(
            _dummy,
            settings=Settings(static_shapes=False, shape_bucketing="zero_nonzero"),
        )

        t0 = torch.empty(0, 3)
        t1 = torch.empty(1, 3)
        t2 = torch.empty(2, 3)

        key_0 = k.specialization_key([t0])
        key_1 = k.specialization_key([t1])
        key_2 = k.specialization_key([t2])

        # zero_nonzero: keep 0 distinct; unify 1 with >=2
        self.assertNotEqual(key_0, key_2)
        self.assertEqual(key_1, key_2)


if __name__ == "__main__":
    unittest.main()