161 changes: 142 additions & 19 deletions helion/_compiler/compile_environment.py
@@ -190,30 +190,39 @@ def allocate_block_size(
         return idx
 
     def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo:
-        # Check if this size is already a registered block size
-        if isinstance(size, torch.SymInt):
-            from .host_function import HostFunction
+        # Quick return for existing reduction with same size
+        for info in self.block_sizes:
+            if info.reduction and info.size == size:
+                return info
 
-            expr = size._sympy_()
-            origin_info = HostFunction.current().expr_to_origin.get(expr)
-            if origin_info and isinstance(origin_info.origin, BlockSizeOrigin):
-                block_idx = origin_info.origin.block_id
-                # Return the existing block size if it's a reduction dimension
-                if self.block_sizes[block_idx].reduction:
-                    return self.block_sizes[block_idx]
-
-        # Check for existing reduction dimensions with the same size
-        for rdim in self.block_sizes:
-            if rdim.reduction and rdim.size == size:
-                return rdim
+        # For SymInt, check if we can reuse existing block for symbolic equality
+        if isinstance(size, torch.SymInt):
+            sym = size._sympy_()
+            block_id = self.get_block_id(size)
+
+            # Return existing reduction by block_id
+            if block_id is not None and self.block_sizes[block_id].reduction:
+                return self.block_sizes[block_id]
+
+            # Clone non-reduction block as reduction for symbolic equality
+            for idx, info in enumerate(self.block_sizes):
+                if not info.reduction and (idx == block_id or sym == info.symbol()):
+                    reduction_loop_index = sum(int(b.reduction) for b in self.block_sizes)
+                    rdim_idx = self.allocate_block_size(
+                        size,
+                        reduction=True,
+                        source=ReductionLoopBlockSizeSource(reduction_loop_index),
+                        hint=next_power_of_2(self.size_hint(size)),
+                    )
+                    self.block_sizes[rdim_idx].var = info.var
+                    return self.block_sizes[rdim_idx]
 
-        # Allocate a new reduction dimension
+        # Allocate new reduction dimension
+        reduction_loop_index = sum(int(info.reduction) for info in self.block_sizes)
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
-            source=ReductionLoopBlockSizeSource(
-                sum([int(bs.reduction) for bs in self.block_sizes])
-            ),
+            source=ReductionLoopBlockSizeSource(reduction_loop_index),
             hint=next_power_of_2(self.size_hint(size)),
         )
         return self.block_sizes[rdim_idx]
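
The reuse strategy above, in isolation: prefer an equal-sized existing reduction dimension, then clone a matching non-reduction block so the two share a variable, and only then allocate a fresh reduction dimension. A minimal self-contained sketch (the BlockInfo, Env, and allocate_reduction names are illustrative stand-ins, not Helion APIs):

from dataclasses import dataclass, field

@dataclass
class BlockInfo:
    size: int
    reduction: bool
    var: str | None = None

@dataclass
class Env:
    block_sizes: list[BlockInfo] = field(default_factory=list)

    def allocate_reduction(self, size: int) -> BlockInfo:
        # 1. Reuse an existing reduction dimension of equal size.
        for info in self.block_sizes:
            if info.reduction and info.size == size:
                return info
        # 2. Clone a matching non-reduction block so the reduction
        #    dimension reuses its variable.
        for info in self.block_sizes:
            if not info.reduction and info.size == size:
                rdim = BlockInfo(size, reduction=True, var=info.var)
                self.block_sizes.append(rdim)
                return rdim
        # 3. Otherwise allocate a brand-new reduction dimension.
        rdim = BlockInfo(size, reduction=True, var=f"rdim_{len(self.block_sizes)}")
        self.block_sizes.append(rdim)
        return rdim

env = Env([BlockInfo(64, reduction=False, var="block_0")])
r1 = env.allocate_reduction(64)  # cloned from the non-reduction block, var == "block_0"
r2 = env.allocate_reduction(64)  # early-returned: the same object as r1
assert r1 is r2 and r1.var == "block_0"
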
@@ -271,6 +280,108 @@ def cached_create_unbacked_symint(
self._symint_cache[key] = result
return result


def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
"""Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
tensor._tile_index_block_id = block_id # type: ignore[attr-defined]

def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
"""Return the originating ``tile.index`` block id if present."""
return getattr(tensor, "_tile_index_block_id", None)

def get_indexer_output_dims(
self,
indexer_tensor: torch.Tensor,
base_dim_size: int | torch.SymInt | None,
) -> list[int | torch.SymInt]:
"""Map a tensor indexer's shape to the output dimensions for advanced indexing."""
dims = list(indexer_tensor.size())
non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]

# Multi-dimensional indexer - return full shape
if len(non_broadcast_dims) > 1:
return dims

        # Try to find a block_id from the available sources; block_id 0 is valid,
        # so compare against None instead of relying on truthiness.
        block_id = self.get_tile_index_tensor_block_id(indexer_tensor)
        if block_id is None and base_dim_size is not None:
            block_id = self.get_block_id(base_dim_size)
        if block_id is None and non_broadcast_dims:
            block_id = self.get_block_id(non_broadcast_dims[0])

        if block_id is not None:
            return [self.block_sizes[block_id].var]
        return non_broadcast_dims or [1]

def tensor_indexer_broadcast_shape(
self, tensors: typing.Sequence[torch.Tensor]
) -> list[int | torch.SymInt] | None:
"""Compute a shared broadcast shape for tensor indexers when needed."""
tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
        if not tensor_list or all(
            self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list
        ):
return None

shapes = [list(t.size()) for t in tensor_list]

# Inline compute_broadcast_shape_for_tensor_indexers logic
if not shapes:
return []

# Special case: multiple 1D tensors form a Cartesian product
if all(len(s) == 1 for s in shapes) and len(shapes) > 1:
return [s[0] for s in shapes]

# General broadcasting case
max_ndim = max(len(s) for s in shapes)
padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
return [
next((d for d in dims if self.size_hint(d) != 1), 1)
for dims in zip(*padded, strict=True)
]
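
The broadcast-shape rule above has two cases: several purely 1-D indexers stack into a Cartesian product, otherwise shapes are right-aligned and the first non-size-1 entry wins per position. A rough standalone illustration (plain ints in place of SymInt; broadcast_shape is a hypothetical helper, not part of this change):

def broadcast_shape(shapes: list[list[int]]) -> list[int]:
    if not shapes:
        return []
    # Several 1-D indexers index independent dimensions: their lengths stack.
    if len(shapes) > 1 and all(len(s) == 1 for s in shapes):
        return [s[0] for s in shapes]
    # General case: right-align, pad with 1s, keep the first non-broadcast size per column.
    max_ndim = max(len(s) for s in shapes)
    padded = [[1] * (max_ndim - len(s)) + s for s in shapes]
    return [next((d for d in dims if d != 1), 1) for dims in zip(*padded, strict=True)]

assert broadcast_shape([[3], [5]]) == [3, 5]        # Cartesian product of 1-D indexers
assert broadcast_shape([[1, 4], [3, 1]]) == [3, 4]  # ordinary right-aligned broadcasting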

def tensor_indexer_broadcast(
        self, tensor_list: typing.Sequence[torch.Tensor] | None
) -> TensorIndexerBroadcast:
"""Initialize broadcast tracking for tensor indexers."""
if not tensor_list:
return TensorIndexerBroadcast(shape=None)
return TensorIndexerBroadcast(
shape=self.tensor_indexer_broadcast_shape(tensor_list)
)

def tensor_indexer_dims(
self,
indexer_tensor: torch.Tensor,
base_dim_size: int | torch.SymInt,
broadcast: TensorIndexerBroadcast,
) -> list[int | torch.SymInt]:
"""Return dims contributed by a tensor indexer, honoring shared broadcast."""
if broadcast.shape is None:
return self.get_indexer_output_dims(indexer_tensor, base_dim_size)
if broadcast.used:
return []
broadcast.used = True
return list(broadcast.shape)

def new_index_result(
self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
) -> torch.Tensor:
"""Create a new tensor for indexing/view ops while preserving tile index provenance."""
        # Inline resolve_tile_index_shape logic; block_id 0 is valid, so check for None.
        block_id = self.get_tile_index_tensor_block_id(tensor)
        resolved_shape = list(output_shape)
        if block_id is not None:
            non_broadcast = [
                i for i, s in enumerate(resolved_shape) if self.size_hint(s) != 1
            ]
            if len(non_broadcast) == 1:
                resolved_shape[non_broadcast[0]] = self.block_sizes[block_id].var
            elif len(non_broadcast) > 1:
                block_id = None

result = tensor.new_empty(resolved_shape)
if block_id is not None:
self.register_tile_index_tensor_block_id(result, block_id)
return result
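
A minimal sketch of the provenance rule new_index_result applies: the result keeps its tile.index block id only while at most one of its dimensions is a real (non size-1) dimension (resolve_block_id below is a hypothetical helper using plain ints, not part of this change):

def resolve_block_id(output_shape: list[int], block_id: int | None) -> int | None:
    if block_id is None:
        return None
    non_broadcast = [i for i, s in enumerate(output_shape) if s != 1]
    # At most one real dimension: the result still maps onto a single block.
    return block_id if len(non_broadcast) <= 1 else None

assert resolve_block_id([1, 64], block_id=3) == 3      # e.g. a [None, :]-style view keeps provenance
assert resolve_block_id([64, 32], block_id=3) is None  # a 2-D result no longer maps to one block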

def to_fake(self, obj: object, origin: Origin) -> object:
if obj is None:
return None
@@ -354,6 +465,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
self.fake_mode, tensor, shape_env=self.shape_env, source=source
)
self.input_sources[result] = source
        tile_block_id = self.get_tile_index_tensor_block_id(tensor)
        if tile_block_id is not None:
            self.register_tile_index_tensor_block_id(result, tile_block_id)
if isinstance(source, LocalSource):
for i, s in enumerate(result.size()):
if isinstance(s, torch.SymInt) and isinstance(
@@ -493,6 +608,14 @@ class AutoSize:
"""A marker used to delay setting the size of a block until it is known."""


@dataclasses.dataclass
class TensorIndexerBroadcast:
"""Tracks shared broadcast state for tensor indexers."""

shape: list[int | torch.SymInt] | None
used: bool = False
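
For context, the used flag gives tensor_indexer_dims a single-use contract: the first tensor indexer contributes the shared broadcast shape, later ones contribute nothing. A self-contained sketch with simplified local stand-ins (Broadcast and indexer_dims below are illustrative, not Helion code):

from dataclasses import dataclass

@dataclass
class Broadcast:
    shape: list[int] | None
    used: bool = False

def indexer_dims(own_dims: list[int], broadcast: Broadcast) -> list[int]:
    if broadcast.shape is None:
        return own_dims           # no shared broadcast: each indexer keeps its own dims
    if broadcast.used:
        return []                 # shape already contributed by an earlier indexer
    broadcast.used = True
    return list(broadcast.shape)  # first tensor indexer emits the shared shape once

bc = Broadcast(shape=[3, 5])
assert indexer_dims([3], bc) == [3, 5]  # first indexer contributes the broadcast shape
assert indexer_dims([5], bc) == []      # subsequent indexers add nothing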


@dataclasses.dataclass
class BlockSizeInfo:
"""