@@ -206,17 +206,47 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInf
             if rdim.reduction and rdim.size == size:
                 return rdim
 
+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles (e.g. via
+        # slicing) with sizes coming directly from tile block vars, we want them
+        # to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and
+        # avoids spurious "size mismatch" issues during fake-tensor broadcasting
+        # and arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            block_idx = self.get_block_id(size)
+            if block_idx is not None and not self.block_sizes[block_idx].reduction:
+                return self._clone_block_size_as_reduction(block_idx, size)
+
+            sym = size._sympy_()
+            for block_idx, block_info in enumerate(self.block_sizes):
+                if not block_info.reduction and sym == block_info.symbol():
+                    return self._clone_block_size_as_reduction(block_idx, size)
+
         # Allocate a new reduction dimension
+        return self._allocate_new_reduction(size)
+
+    def _clone_block_size_as_reduction(
+        self, block_idx: int, size: torch.SymInt | int
+    ) -> BlockSizeInfo:
+        rdim = self._allocate_new_reduction(size)
+        rdim.var = self.block_sizes[block_idx].var
+        return rdim
+
+    def _allocate_new_reduction(self, size: torch.SymInt | int) -> BlockSizeInfo:
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
             source=ReductionLoopBlockSizeSource(
-                sum([int(bs.reduction) for bs in self.block_sizes])
+                self._next_reduction_loop_index()
             ),
             hint=next_power_of_2(self.size_hint(size)),
         )
         return self.block_sizes[rdim_idx]
 
+    def _next_reduction_loop_index(self) -> int:
+        return sum(int(info.reduction) for info in self.block_sizes)
+
     def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
         source = _current_symbol_source()
         with self.shape_env.ignore_fresh_unbacked_symbols():
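
To make the reuse logic in the hunk above concrete, here is a minimal standalone sketch of the same interning idea, independent of this diff: when a requested reduction size is equal by construction to an existing non-reduction block size, the new reduction dimension shares that block's variable instead of minting a fresh symbol. `Block` and `Registry` are illustrative stand-ins, not names from this codebase.

from dataclasses import dataclass

@dataclass
class Block:
    size: int          # stands in for a SymInt in the real code
    var: object        # the shared symbolic variable
    reduction: bool = False

class Registry:
    def __init__(self) -> None:
        self.blocks: list[Block] = []

    def allocate_reduction(self, size: int) -> Block:
        # 1) exact reuse: an existing reduction dim of the same size
        for b in self.blocks:
            if b.reduction and b.size == size:
                return b
        # 2) clone: a non-reduction block whose size is provably equal;
        #    the new reduction dim shares its `var`, preserving equality
        for b in self.blocks:
            if not b.reduction and b.size == size:
                clone = Block(size, b.var, reduction=True)
                self.blocks.append(clone)
                return clone
        # 3) otherwise allocate a brand-new variable
        new = Block(size, object(), reduction=True)
        self.blocks.append(new)
        return new

reg = Registry()
tile = Block(8, object())
reg.blocks.append(tile)
rdim = reg.allocate_reduction(8)
assert rdim.var is tile.var  # equal sizes share one variable
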
@@ -269,6 +299,90 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer: return the full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        # Try to find a block_id from the available sources
+        block_id = (
+            self.get_tile_index_tensor_block_id(indexer_tensor)
+            or (self.get_block_id(base_dim_size) if base_dim_size is not None else None)
+            or (self.get_block_id(non_broadcast_dims[0]) if non_broadcast_dims else None)
+        )
+
+        if block_id is not None:
+            return [self.block_sizes[block_id].var]
+        return [non_broadcast_dims[0]] if non_broadcast_dims else [1]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list:
+            return None
+
+        if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``.
+
+        Returns a copy of ``output_shape`` in which the single non-broadcast
+        dimension is replaced with the canonical block symbol, together with
+        the block_id to register on the new tensor. If the tensor is not a
+        tile indexer, or if it introduces more than one non-broadcast
+        dimension, the original shape and ``None`` are returned.
+        """
+
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if block_id is None:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) <= 1:
+            if non_broadcast:
+                resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile-index provenance."""
+
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if obj is None:
             return None
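
The `register_tile_index_tensor_block_id` / `get_tile_index_tensor_block_id` pair above implements attribute-based provenance. Here is a hedged standalone sketch of that pattern; `tag`, `tag_of`, and `derive` are invented names for illustration, and only the attribute `_tile_index_block_id` comes from this diff.

import torch

_TAG = "_tile_index_block_id"

def tag(t: torch.Tensor, block_id: int) -> torch.Tensor:
    setattr(t, _TAG, block_id)
    return t

def tag_of(t: torch.Tensor) -> int | None:
    return getattr(t, _TAG, None)

def derive(t: torch.Tensor) -> torch.Tensor:
    # A same-shape op; provenance must be copied explicitly because new
    # tensors do not inherit Python attributes from their inputs.
    out = torch.empty_like(t)
    bid = tag_of(t)
    if bid is not None:
        tag(out, bid)
    return out

idx = tag(torch.arange(8), block_id=0)
assert tag_of(derive(idx)) == 0  # the tag survives the derived op
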
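And a worked toy version of the rule `resolve_tile_index_shape` applies, with plain ints standing in for SymInts: exactly one non-broadcast dimension is rewritten to the canonical block size; anything else is left untouched. `BLOCK_VAR` and `resolve` are illustrative names, not part of the diff.

BLOCK_VAR = 16  # stands in for self.block_sizes[block_id].var

def resolve(shape: list[int]) -> list[int]:
    out = list(shape)
    non_broadcast = [i for i, s in enumerate(out) if s != 1]
    if len(non_broadcast) == 1:
        out[non_broadcast[0]] = BLOCK_VAR
    return out

assert resolve([1, 8, 1]) == [1, 16, 1]   # single non-broadcast dim: canonicalized
assert resolve([8, 8]) == [8, 8]          # multiple dims: unchanged
assert resolve([1, 1]) == [1, 1]          # pure broadcast: unchanged
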
@@ -351,6 +465,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
@@ -643,6 +761,34 @@ def _has_unbacked(expr: sympy.Expr) -> bool:
     return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
 
 
+def compute_broadcast_shape_for_tensor_indexers(
+    shapes: list[list[int | torch.SymInt]],
+    env: "CompileEnvironment",
+) -> list[int | torch.SymInt]:
+    """Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting.
+
+    For multiple 1D tensors, this should return a shape that represents their Cartesian product.
+    For example, two tensors of shape [8] and [8] should broadcast to shape [8, 8].
+    """
+    if not shapes:
+        return []
+
+    # Special case: multiple 1D tensors form a Cartesian product
+    all_1d = all(len(shape) == 1 for shape in shapes)
+    if all_1d and len(shapes) > 1:
+        # Return the Cartesian product shape
+        return [shape[0] for shape in shapes]
+
+    # General broadcasting case
+    max_ndim = max(len(s) for s in shapes)
+    padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+
+    return [
+        next((d for d in dims if env.size_hint(d) != 1), 1)
+        for dims in zip(*padded, strict=True)
+    ]
+
+
 def format_shape(shape: tuple[object, ...]) -> str:
     def _format_dim(dim: object) -> str:
         if isinstance(dim, torch.SymInt):
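
Finally, a runnable toy check of the broadcasting rule `compute_broadcast_shape_for_tensor_indexers` encodes, with plain ints standing in for SymInts and the `env.size_hint(d) != 1` test reduced to `d != 1`: several 1D indexers combine into their Cartesian product; otherwise shapes are right-aligned and the first non-1 size wins per column. `broadcast` is an illustrative name.

def broadcast(shapes: list[list[int]]) -> list[int]:
    if not shapes:
        return []
    if len(shapes) > 1 and all(len(s) == 1 for s in shapes):
        return [s[0] for s in shapes]  # Cartesian product of 1D indexers
    ndim = max(len(s) for s in shapes)
    padded = [[1] * (ndim - len(s)) + s for s in shapes]
    return [next((d for d in dims if d != 1), 1) for dims in zip(*padded)]

assert broadcast([[8], [8]]) == [8, 8]       # two 1D indexers: outer product
assert broadcast([[4, 1], [1, 5]]) == [4, 5]  # right-aligned broadcasting
assert broadcast([[3]]) == [3]                # single indexer: unchanged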