Skip to content

Commit 252f6e0

Browse files
committed
wip
1 parent 830fbfb commit 252f6e0

File tree

5 files changed

+301
-63
lines changed

5 files changed

+301
-63
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 118 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,22 @@ def _get_padded_iota_original_length(
6060
return None
6161

6262

63+
def _has_padded_iota_index(state: CodegenState | None, num_indices: int) -> bool:
64+
if state is None:
65+
return False
66+
for idx in range(num_indices):
67+
if _get_padded_iota_original_length(state, idx) is not None:
68+
return True
69+
return False
70+
71+
72+
def _has_multidim_tensor_index(index: list[object]) -> bool:
73+
for k in index:
74+
if isinstance(k, torch.Tensor) and k.ndim > 1:
75+
return True
76+
return False
77+
78+
6379
def _get_tile_with_offset_info(
6480
k: object, state: CodegenState, k_index: int
6581
) -> tuple[int, int | torch.SymInt] | None:
@@ -102,6 +118,7 @@ def _get_tile_with_offset_info(
102118
return (meta["block_id"], meta["offset"])
103119

104120
return None
121+
return None
105122

106123

107124
class IndexingStrategy:
@@ -554,6 +571,17 @@ def codegen_store(
554571
)
555572

556573

574+
def _try_python_index_shape(
575+
tensor: torch.Tensor, index: list[object]
576+
) -> list[int | torch.SymInt] | None:
577+
try:
578+
tuple_index = tuple(index)
579+
result = tensor[tuple_index] # pyright: ignore[reportGeneralTypeIssues]
580+
except Exception:
581+
return None
582+
return list(result.size())
583+
584+
557585
class SubscriptIndexing(NamedTuple):
558586
index_expr: ast.AST
559587
mask_expr: ast.AST
@@ -567,6 +595,17 @@ def has_mask(self) -> bool:
567595
def compute_shape(
568596
tensor: torch.Tensor, index: list[object], state: CodegenState | None = None
569597
) -> list[int | torch.SymInt]:
598+
advanced_mode = (
599+
isinstance(tensor, torch.Tensor)
600+
and len(index) == tensor.ndim
601+
and index
602+
and all(isinstance(k, torch.Tensor) for k in index)
603+
and _has_multidim_tensor_index(index)
604+
and not _has_padded_iota_index(state, len(index))
605+
)
606+
if advanced_mode:
607+
if (shape := _try_python_index_shape(tensor, index)) is not None:
608+
return shape
570609
assert isinstance(tensor, torch.Tensor)
571610
assert isinstance(index, (list, tuple)), index
572611
input_size = collections.deque(tensor.size())
@@ -605,18 +644,28 @@ def compute_shape(
605644
k_index += 1
606645
elif isinstance(k, slice):
607646
size = input_size.popleft()
608-
# Handle slices with steps
609-
slice_size = compute_slice_size(k, size)
610-
611-
if slice_size != 1:
612-
rdim = env.allocate_reduction_dimension(slice_size)
613-
output_size.append(rdim.var)
647+
is_full_slice = (
648+
(k.start is None or k.start == 0)
649+
and k.stop is None
650+
and (k.step is None or k.step == 1)
651+
)
652+
653+
if is_full_slice:
654+
if env.known_equal(size, 1):
655+
output_size.append(1)
656+
else:
657+
output_size.append(size)
614658
else:
615-
output_size.append(1)
659+
# Handle slices with steps or bounded ranges
660+
slice_size = compute_slice_size(k, size)
661+
662+
if slice_size != 1:
663+
rdim = env.allocate_reduction_dimension(slice_size)
664+
output_size.append(rdim.var)
665+
else:
666+
output_size.append(1)
616667
k_index += 1
617-
elif isinstance(k, torch.Tensor) and (
618-
k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1)
619-
):
668+
elif isinstance(k, torch.Tensor):
620669
input_size.popleft()
621670
output_size.extend(k.size())
622671
k_index += 1
@@ -664,6 +713,14 @@ def create(
664713
output_size = SubscriptIndexing.compute_shape(fake_value, index, state)
665714
env = CompileEnvironment.current()
666715
dtype = env.triton_index_type()
716+
advanced_mode = (
717+
isinstance(fake_value, torch.Tensor)
718+
and len(index) == fake_value.ndim
719+
and bool(index)
720+
and all(isinstance(k, torch.Tensor) for k in index)
721+
and _has_multidim_tensor_index(index)
722+
and not _has_padded_iota_index(state, len(index))
723+
)
667724
if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
668725
raise exc.IndexOffsetOutOfRangeForInt32(env.settings.index_dtype)
669726

@@ -737,8 +794,26 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
737794
else:
738795
index_values.append(f"{start}{expand}")
739796
else:
740-
# Full slice or slice without step
741-
if not _is_size_one(size):
797+
is_full_slice = (
798+
(k.start is None or k.start == 0)
799+
and k.stop is None
800+
and (k.step is None or k.step == 1)
801+
)
802+
if is_full_slice and not _is_size_one(size):
803+
block_idx = env.get_block_id(size)
804+
if block_idx is not None:
805+
index_var = state.codegen.index_var(block_idx)
806+
index_values.append(f"({index_var}){expand}")
807+
if mask := state.codegen.mask_var(block_idx):
808+
mask_values.setdefault(f"({mask}){expand}")
809+
else:
810+
rdim = env.allocate_reduction_dimension(size)
811+
block_idx = rdim.block_id
812+
index_var = state.codegen.index_var(block_idx)
813+
index_values.append(f"({index_var}){expand}")
814+
if mask := state.codegen.mask_var(block_idx):
815+
mask_values.setdefault(f"({mask}){expand}")
816+
elif not _is_size_one(size):
742817
rdim = env.allocate_reduction_dimension(size)
743818
block_idx = rdim.block_id
744819
index_var = state.codegen.index_var(block_idx)
@@ -749,22 +824,31 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
749824
index_values.append(f"tl.zeros([1], {dtype}){expand}")
750825
output_idx += 1
751826
k_index += 1
752-
elif isinstance(k, torch.Tensor) and k.ndim == 1:
753-
expand = tile_strategy.expand_str(output_size, output_idx)
827+
elif isinstance(k, torch.Tensor) and not (
828+
len(index) == 1 and fake_value.ndim == 1
829+
):
754830
ast_index = state.ast_args[1]
755831
assert isinstance(ast_index, (list, tuple))
756832
assert len(ast_index) == len(index)
757833
index_var = state.codegen.lift(ast_index[n], prefix="index").id
758-
index_values.append(f"({index_var}){expand}")
759-
if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
760-
if mask := state.codegen.mask_var(block_idx):
761-
mask_values.setdefault(f"({mask}){expand}")
834+
if advanced_mode:
835+
index_values.append(index_var)
836+
else:
837+
expand = tile_strategy.expand_str(output_size, output_idx)
838+
index_values.append(f"({index_var}){expand}")
839+
if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
840+
if mask := state.codegen.mask_var(block_idx):
841+
mask_values.setdefault(f"({mask}){expand}")
762842
# Check if this index comes from a padded hl.arange and generate mask
763843
if (
764844
original_length := _get_padded_iota_original_length(state, n)
765845
) is not None:
766-
mask_values.setdefault(f"({index_var} < {original_length}){expand}")
767-
output_idx += 1
846+
if advanced_mode:
847+
mask_values.setdefault(f"({index_var} < {original_length})")
848+
else:
849+
mask_values.setdefault(f"({index_var} < {original_length}){expand}")
850+
if not advanced_mode:
851+
output_idx += 1
768852
k_index += 1
769853
elif (
770854
isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
@@ -786,6 +870,8 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
786870
k_index += 1
787871
else:
788872
raise exc.InvalidIndexingType(type(k))
873+
if advanced_mode:
874+
output_idx = len(output_size)
789875
assert len(output_size) == output_idx
790876
assert len(index_values) == fake_value.ndim
791877
index_expr = []
@@ -885,7 +971,11 @@ def need_reshape(self, node: ast.AST) -> bool:
885971
return True
886972
env = CompileEnvironment.current()
887973
for a, b in zip(self.reshaped_size, self.block_shape, strict=True):
888-
if not env.known_equal(a, b):
974+
block_id_a = env.resolve_block_id(a)
975+
block_id_b = env.resolve_block_id(b)
976+
if block_id_a != block_id_b:
977+
return True
978+
if block_id_a is None and not env.known_equal(a, b):
889979
return True
890980
return False
891981

@@ -1035,7 +1125,13 @@ def create(
10351125
# Full slice or slice without step
10361126
if size != 1:
10371127
rdim = env.allocate_reduction_dimension(size)
1038-
res.offsets.append(state.codegen.offset_var(rdim.block_id))
1128+
active_loops = state.codegen.active_device_loops.get(
1129+
rdim.block_id
1130+
)
1131+
if active_loops:
1132+
res.offsets.append(state.codegen.offset_var(rdim.block_id))
1133+
else:
1134+
res.offsets.append("0")
10391135
res.block_shape.append(rdim.var)
10401136
else:
10411137
res.offsets.append("0")

helion/_compiler/tile_dispatch.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,25 @@ def _add_loop_strategy(
7878
loop_order=loop_order,
7979
)
8080
else:
81+
block_sizes = [bs.from_config_assert(config) for bs in block_size_infos]
82+
if len(block_ids) == 1 and not block_size_infos[0].reduction:
83+
max_reduction_block = 0
84+
for info in env.block_sizes:
85+
if not info.reduction:
86+
continue
87+
configured = info.from_config(config)
88+
if isinstance(configured, int):
89+
max_reduction_block = max(max_reduction_block, configured)
90+
if (
91+
isinstance(block_sizes[0], int)
92+
and block_sizes[0] > 1
93+
and max_reduction_block >= 1024
94+
):
95+
block_sizes[0] = 1
8196
strategy = NDTileStrategy(
8297
fn,
8398
block_ids,
84-
block_size=[bs.from_config_assert(config) for bs in block_size_infos],
99+
block_size=block_sizes,
85100
loop_order=loop_order,
86101
l2_grouping=l2_grouping,
87102
)
@@ -118,8 +133,22 @@ def codegen_device_loop(
118133

119134
def _compact_shape(self, shapes: ShapeLike) -> list[CompactedShape]:
120135
compacted_shapes = []
136+
env = CompileEnvironment.current()
137+
fn = DeviceFunction.current()
121138
for idx, shape in enumerate(shapes):
122-
block_idx = CompileEnvironment.current().get_block_id(shape)
139+
block_idx = env.get_block_id(shape)
140+
if block_idx is None and isinstance(shape, int):
141+
for info in env.block_sizes:
142+
if not info.reduction:
143+
continue
144+
configured = info.from_config(fn.config)
145+
if not isinstance(configured, int) or configured != shape:
146+
continue
147+
static_size = info.size if isinstance(info.size, int) else None
148+
if static_size is not None and static_size == configured:
149+
continue
150+
block_idx = info.block_id
151+
break
123152
if block_idx is None:
124153
# Check if this is a symbolic expression with block sizes
125154
shape_str = self._get_shape_string(shape)

helion/_compiler/type_propagation.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,19 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
453453
keys = key.unpack()
454454
else:
455455
keys = [key]
456+
advanced_mode = (
457+
len(keys) == self.fake_value.ndim
458+
and bool(keys)
459+
and all(isinstance(k, TensorType) for k in keys)
460+
)
461+
if advanced_mode:
462+
try:
463+
tuple_index = tuple(k.proxy() for k in keys)
464+
result = self.fake_value[tuple_index] # pyright: ignore[reportArgumentType]
465+
except Exception:
466+
pass
467+
else:
468+
return list(result.size())
456469
inputs_consumed = 0
457470
output_sizes = []
458471
env = CompileEnvironment.current()
@@ -501,9 +514,9 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
501514
raise exc.DataDependentOutputShapeNotSupported(
502515
op_desc="Boolean mask indexing (tensor[boolean_mask])"
503516
)
504-
elif isinstance(k, TensorType) and k.fake_value.ndim == 1:
517+
elif isinstance(k, TensorType):
505518
inputs_consumed += 1
506-
output_sizes.append(k.fake_value.size(0))
519+
output_sizes.extend(k.fake_value.size())
507520
elif k.contains_type(TileIndexType):
508521
raise exc.OverpackedTile(k)
509522
else:
@@ -1495,7 +1508,22 @@ def _eval_unary(op: ast.unaryop, value: object) -> object:
14951508

14961509
def _eval_binary(op: ast.operator, left: object, right: object) -> object:
14971510
if isinstance(op, ast.Add):
1498-
return left + right # pyright: ignore[reportOperatorIssue]
1511+
try:
1512+
return left + right # pyright: ignore[reportOperatorIssue]
1513+
except Exception as exc: # pragma: no cover - debug
1514+
import sys
1515+
1516+
def _fmt(val: object) -> str:
1517+
if hasattr(val, "shape"):
1518+
shape = getattr(val, "shape")
1519+
return f"{shape}"
1520+
return repr(val)
1521+
1522+
print(
1523+
f"_eval_binary Add failure: left={_fmt(left)}, right={_fmt(right)}",
1524+
file=sys.stderr,
1525+
)
1526+
raise
14991527
if isinstance(op, ast.Sub):
15001528
return left - right # pyright: ignore[reportOperatorIssue]
15011529
if isinstance(op, ast.Mult):

helion/language/memory_ops.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from torch.fx import has_side_effect
88

99
from .. import exc
10+
from .._compiler.ast_extension import expr_from_string
11+
from .._compiler.host_function import HostFunction
1012
from .._compiler.indexing_strategy import SubscriptIndexing
1113
from . import _decorators
1214
from .stack_tensor import StackTensor
@@ -24,6 +26,24 @@
2426
}
2527

2628

29+
def _codegen_host_tensor_subscript(state: "CodegenState") -> ast.AST:
30+
indices: list[str] = []
31+
for val in state.proxy_arg(1): # type: ignore[reportGeneralTypeIssues]
32+
if val is None:
33+
indices.append("None")
34+
elif isinstance(val, slice) and val.start is None and val.stop is None and val.step is None:
35+
indices.append(":")
36+
else:
37+
raise exc.InvalidIndexingType(
38+
f"Host tensor indexing only supports None/':' entries, got {val!r}"
39+
)
40+
index_expr = ", ".join(indices)
41+
return expr_from_string(
42+
f"{{base}}[{index_expr}]" if index_expr else "{base}",
43+
base=state.ast_arg(0),
44+
)
45+
46+
2747
@has_side_effect
2848
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
2949
def store(
@@ -270,6 +290,10 @@ def _(state: CodegenState) -> ast.AST:
270290
eviction_policy = ast.Constant(value=eviction_policy)
271291

272292
if isinstance(tensor, torch.Tensor):
293+
host_fn = HostFunction.current()
294+
if tensor not in host_fn.tensor_to_origin:
295+
return _codegen_host_tensor_subscript(state)
296+
273297
# Use the shared memory op index for indexing strategy
274298
indexing_idx = device_fn.device_memory_op_index
275299
device_fn.device_memory_op_index += 1

0 commit comments

Comments
 (0)