Commit ef7226e
wip
1 parent 8326eae

8 files changed, +500 -82 lines changed
helion/_compiler/ast_extension.py (21 additions, 0 deletions)

@@ -11,6 +11,8 @@
 from typing import TYPE_CHECKING
 from typing import TypeVar
 
+import torch
+
 from .. import exc
 from .output_lines import OutputLines
 from .source_location import SourceLocation
@@ -87,10 +89,29 @@ def __repr__(self) -> str:
 
     def update_type_info(self, type_info: TypeInfo) -> TypeInfo:
         if self._type_info is not None and type_info != self._type_info:
+            prev_rank = self._tensor_rank(self._type_info)
+            new_rank = self._tensor_rank(type_info)
+            if (
+                prev_rank is not None
+                and new_rank is not None
+                and prev_rank != new_rank
+            ):
+                self._type_info = type_info
+                return self._type_info
             type_info = self._type_info.merge(type_info)
         self._type_info = type_info
         return self._type_info
 
+    @staticmethod
+    def _tensor_rank(type_info: "TypeInfo") -> int | None:
+        for attr in ["fake_value", "tensor"]:
+            obj = getattr(type_info, attr, None)
+            if attr == "tensor" and obj is not None:
+                obj = getattr(obj, "fake_value", None)
+            if isinstance(obj, torch.Tensor):
+                return obj.dim()
+        return None
+
     def debug_annotations(self) -> list[str]:
         result = []
         if self._type_info:
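The rank check added above short-circuits the usual merge: when both the stored and the incoming TypeInfo carry fake tensors of different rank, the new type replaces the old one outright instead of being merged. A standalone sketch of the same extraction logic, using SimpleNamespace objects as illustrative stand-ins for Helion's TypeInfo (the fake_value layout here is an assumption for the example, not the real API):

import torch
from types import SimpleNamespace

def tensor_rank(type_info: object) -> int | None:
    # Mirrors _tensor_rank: look for a fake tensor directly on the type object,
    # or nested one level down under a "tensor" attribute.
    for attr in ("fake_value", "tensor"):
        obj = getattr(type_info, attr, None)
        if attr == "tensor" and obj is not None:
            obj = getattr(obj, "fake_value", None)
        if isinstance(obj, torch.Tensor):
            return obj.dim()
    return None

# Hypothetical stand-ins: the stored type holds a 2-D fake value, the new one 1-D.
old_info = SimpleNamespace(fake_value=torch.empty(4, 8))
new_info = SimpleNamespace(fake_value=torch.empty(4))

# Ranks differ (2 vs. 1), so update_type_info would adopt the new type directly
# rather than calling merge() on two incompatible tensor types.
assert tensor_rank(old_info) == 2
assert tensor_rank(new_info) == 1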

helion/_compiler/compile_environment.py (147 additions, 1 deletion)
@@ -206,17 +206,47 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo:
         if rdim.reduction and rdim.size == size:
             return rdim
 
+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles (e.g. via
+        # slicing) with sizes coming directly from tile block vars, we want them
+        # to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and
+        # avoids spurious "size mismatch" issues during fake-tensor broadcasting
+        # and arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            block_idx = self.get_block_id(size)
+            if block_idx is not None and not self.block_sizes[block_idx].reduction:
+                return self._clone_block_size_as_reduction(block_idx, size)
+
+            sym = size._sympy_()
+            for block_idx, block_info in enumerate(self.block_sizes):
+                if not block_info.reduction and sym == block_info.symbol():
+                    return self._clone_block_size_as_reduction(block_idx, size)
+
         # Allocate a new reduction dimension
+        return self._allocate_new_reduction(size)
+
+    def _clone_block_size_as_reduction(
+        self, block_idx: int, size: torch.SymInt | int
+    ) -> BlockSizeInfo:
+        rdim = self._allocate_new_reduction(size)
+        rdim.var = self.block_sizes[block_idx].var
+        return rdim
+
+    def _allocate_new_reduction(self, size: torch.SymInt | int) -> BlockSizeInfo:
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
             source=ReductionLoopBlockSizeSource(
-                sum([int(bs.reduction) for bs in self.block_sizes])
+                self._next_reduction_loop_index()
             ),
             hint=next_power_of_2(self.size_hint(size)),
         )
         return self.block_sizes[rdim_idx]
 
+    def _next_reduction_loop_index(self) -> int:
+        return sum(int(info.reduction) for info in self.block_sizes)
+
     def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
         source = _current_symbol_source()
         with self.shape_env.ignore_fresh_unbacked_symbols():
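The comment in allocate_reduction_dimension is about symbol identity: a reduction dimension that is equal by construction to an existing tile dimension should reuse that tile's SymInt variable rather than receive a fresh one. A minimal sympy-only sketch of why (symbol names are illustrative; this is not Helion's ShapeEnv machinery):

import sympy

# A tile's block size and a size derived from it (e.g. through slicing).
block_n = sympy.Symbol("block_n", positive=True, integer=True)

# If the derived size were given its own fresh symbol, equality could not be
# proven and fake-tensor broadcasting would see a "size mismatch".
fresh_rdim = sympy.Symbol("rdim_0", positive=True, integer=True)
print(sympy.simplify(block_n - fresh_rdim) == 0)   # False

# _clone_block_size_as_reduction reuses the existing block var instead, so both
# dimensions share one symbol and stay equal in the shape environment.
shared_rdim = block_n
print(sympy.simplify(block_n - shared_rdim) == 0)  # True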
@@ -269,6 +299,90 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        # Try to find block_id from various sources
+        block_id = (
+            self.get_tile_index_tensor_block_id(indexer_tensor)
+            or (self.get_block_id(base_dim_size) if base_dim_size is not None else None)
+            or (self.get_block_id(non_broadcast_dims[0]) if non_broadcast_dims else None)
+        )
+
+        if block_id is not None:
+            return [self.block_sizes[block_id].var]
+        return [non_broadcast_dims[0]] if non_broadcast_dims else [1]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list:
+            return None
+
+        if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``.
+
+        Returns a copy of ``output_shape`` where the single non-broadcast
+        dimension is replaced with the canonical block-symbol and the associated
+        block_id to register on the new tensor. If the tensor is not a tile
+        indexer or it introduces more than one non-broadcast dimension, the
+        original shape and ``None`` are returned.
+        """
+
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if block_id is None:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) <= 1:
+            if non_broadcast:
+                resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if obj is None:
             return None
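The provenance helpers above work by tagging a plain Python attribute onto the (fake) tensor. A self-contained sketch of the idea, with module-level stand-ins for the CompileEnvironment methods (names reused for clarity, but these are not the real bound methods):

import torch

def register_tile_index_tensor_block_id(tensor: torch.Tensor, block_id: int) -> None:
    # Provenance is just an attribute on the tensor object itself.
    tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]

def get_tile_index_tensor_block_id(tensor: torch.Tensor) -> int | None:
    return getattr(tensor, "_tile_index_block_id", None)

idx = torch.arange(8)                      # stand-in for a tile.index tensor
register_tile_index_tensor_block_id(idx, block_id=3)
assert get_tile_index_tensor_block_id(idx) == 3

# Indexing/view ops produce new tensor objects, so the attribute is lost unless
# it is re-registered; that is why new_index_result and _to_fake_tensor copy it.
view = idx.reshape(8, 1)
assert get_tile_index_tensor_block_id(view) is None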
@@ -351,6 +465,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
@@ -643,6 +761,34 @@ def _has_unbacked(expr: sympy.Expr) -> bool:
     return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
 
 
+def compute_broadcast_shape_for_tensor_indexers(
+    shapes: list[list[int | torch.SymInt]],
+    env: "CompileEnvironment"
+) -> list[int | torch.SymInt]:
+    """Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting.
+
+    For multiple 1D tensors, this should return a shape that represents their Cartesian product.
+    For example, two tensors of shape [8] and [8] should broadcast to shape [8, 8].
+    """
+    if not shapes:
+        return []
+
+    # Special case: multiple 1D tensors form a Cartesian product
+    all_1d = all(len(shape) == 1 for shape in shapes)
+    if all_1d and len(shapes) > 1:
+        # Return the Cartesian product shape
+        return [shape[0] for shape in shapes]
+
+    # General broadcasting case
+    max_ndim = max(len(s) for s in shapes)
+    padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+
+    return [
+        next((d for d in dims if env.size_hint(d) != 1), 1)
+        for dims in zip(*padded, strict=True)
+    ]
+
+
 def format_shape(shape: tuple[object, ...]) -> str:
     def _format_dim(dim: object) -> str:
         if isinstance(dim, torch.SymInt):
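compute_broadcast_shape_for_tensor_indexers treats several 1-D indexers as independent axes (a Cartesian product) and otherwise falls back to right-aligned broadcasting, keeping the first non-size-1 dimension at each position. A simplified, environment-free restatement (plain int() stands in for CompileEnvironment.size_hint; the helper name is illustrative):

def broadcast_shape_for_indexers(shapes: list[list[int]], size_hint=int) -> list[int]:
    if not shapes:
        return []
    # Several 1-D indexers index independent axes: Cartesian product of lengths.
    if len(shapes) > 1 and all(len(s) == 1 for s in shapes):
        return [s[0] for s in shapes]
    # Otherwise, right-align and keep the first non-broadcast (size != 1) dim.
    max_ndim = max(len(s) for s in shapes)
    padded = [[1] * (max_ndim - len(s)) + list(s) for s in shapes]
    return [next((d for d in dims if size_hint(d) != 1), 1) for dims in zip(*padded)]

print(broadcast_shape_for_indexers([[8], [8]]))        # [8, 8]  (Cartesian product)
print(broadcast_shape_for_indexers([[4, 1], [1, 8]]))  # [4, 8]  (ordinary broadcasting)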
