@@ -60,6 +60,22 @@ def _get_padded_iota_original_length(
     return None
 
 
+def _has_padded_iota_index(state: CodegenState | None, num_indices: int) -> bool:
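+    # Padded hl.arange indices need an explicit bounds mask, so fast paths check for them and bail out.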
+    if state is None:
+        return False
+    for idx in range(num_indices):
+        if _get_padded_iota_original_length(state, idx) is not None:
+            return True
+    return False
+
+
+def _has_multidim_tensor_index(index: list[object]) -> bool:
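+    # True when any index tensor has ndim > 1, i.e. NumPy-style advanced indexing with broadcasting.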
+    for k in index:
+        if isinstance(k, torch.Tensor) and k.ndim > 1:
+            return True
+    return False
+
+
 def _get_tile_with_offset_info(
     k: object, state: CodegenState, k_index: int
 ) -> tuple[int, int | torch.SymInt] | None:
@@ -102,6 +118,7 @@ def _get_tile_with_offset_info(
             return (meta["block_id"], meta["offset"])
 
         return None
+    return None
 
 
 class IndexingStrategy:
@@ -554,6 +571,17 @@ def codegen_store(
     )
 
 
+def _try_python_index_shape(
+    tensor: torch.Tensor, index: list[object]
+) -> list[int | torch.SymInt] | None:
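+    # Index the fake tensor with PyTorch's own advanced-indexing semantics; on failure the caller falls back to manual shape computation.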
+    try:
+        tuple_index = tuple(index)
+        result = tensor[tuple_index]  # pyright: ignore[reportGeneralTypeIssues]
+    except Exception:
+        return None
+    return list(result.size())
+
+
 class SubscriptIndexing(NamedTuple):
     index_expr: ast.AST
     mask_expr: ast.AST
@@ -567,6 +595,17 @@ def has_mask(self) -> bool:
     def compute_shape(
         tensor: torch.Tensor, index: list[object], state: CodegenState | None = None
     ) -> list[int | torch.SymInt]:
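+        # NumPy-style advanced indexing: every dim is indexed by a tensor, at least one index is multi-dim, and no padded iota is involved.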
+        advanced_mode = (
+            isinstance(tensor, torch.Tensor)
+            and len(index) == tensor.ndim
+            and index
+            and all(isinstance(k, torch.Tensor) for k in index)
+            and _has_multidim_tensor_index(index)
+            and not _has_padded_iota_index(state, len(index))
+        )
+        if advanced_mode:
+            if (shape := _try_python_index_shape(tensor, index)) is not None:
+                return shape
         assert isinstance(tensor, torch.Tensor)
         assert isinstance(index, (list, tuple)), index
         input_size = collections.deque(tensor.size())
@@ -605,18 +644,28 @@ def compute_shape(
                 k_index += 1
             elif isinstance(k, slice):
                 size = input_size.popleft()
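+                # A bare ":" keeps the existing (possibly symbolic) dim instead of allocating a new reduction dim.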
-                # Handle slices with steps
-                slice_size = compute_slice_size(k, size)
-
-                if slice_size != 1:
-                    rdim = env.allocate_reduction_dimension(slice_size)
-                    output_size.append(rdim.var)
+                is_full_slice = (
+                    (k.start is None or k.start == 0)
+                    and k.stop is None
+                    and (k.step is None or k.step == 1)
+                )
+
+                if is_full_slice:
+                    if env.known_equal(size, 1):
+                        output_size.append(1)
+                    else:
+                        output_size.append(size)
                 else:
-                    output_size.append(1)
+                    # Handle slices with steps or bounded ranges
+                    slice_size = compute_slice_size(k, size)
+
+                    if slice_size != 1:
+                        rdim = env.allocate_reduction_dimension(slice_size)
+                        output_size.append(rdim.var)
+                    else:
+                        output_size.append(1)
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and (
-                k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1)
-            ):
+            elif isinstance(k, torch.Tensor):
                 input_size.popleft()
                 output_size.extend(k.size())
                 k_index += 1
@@ -664,6 +713,14 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index, state)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
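+        # Keep in sync with SubscriptIndexing.compute_shape so index emission matches the computed output shape.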
+        advanced_mode = (
+            isinstance(fake_value, torch.Tensor)
+            and len(index) == fake_value.ndim
+            and bool(index)
+            and all(isinstance(k, torch.Tensor) for k in index)
+            and _has_multidim_tensor_index(index)
+            and not _has_padded_iota_index(state, len(index))
+        )
         if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
             raise exc.IndexOffsetOutOfRangeForInt32(env.settings.index_dtype)
 
@@ -737,8 +794,26 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
                 else:
                     index_values.append(f"{start}{expand}")
             else:
-                # Full slice or slice without step
-                if not _is_size_one(size):
+                is_full_slice = (
+                    (k.start is None or k.start == 0)
+                    and k.stop is None
+                    and (k.step is None or k.step == 1)
+                )
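+                # For a full slice over a known block, reuse that block's index/mask vars instead of allocating a new reduction dim.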
+                if is_full_slice and not _is_size_one(size):
+                    block_idx = env.get_block_id(size)
+                    if block_idx is not None:
+                        index_var = state.codegen.index_var(block_idx)
+                        index_values.append(f"({index_var}){expand}")
+                        if mask := state.codegen.mask_var(block_idx):
+                            mask_values.setdefault(f"({mask}){expand}")
+                    else:
+                        rdim = env.allocate_reduction_dimension(size)
+                        block_idx = rdim.block_id
+                        index_var = state.codegen.index_var(block_idx)
+                        index_values.append(f"({index_var}){expand}")
+                        if mask := state.codegen.mask_var(block_idx):
+                            mask_values.setdefault(f"({mask}){expand}")
+                elif not _is_size_one(size):
                     rdim = env.allocate_reduction_dimension(size)
                     block_idx = rdim.block_id
                     index_var = state.codegen.index_var(block_idx)
@@ -749,22 +824,31 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
                     index_values.append(f"tl.zeros([1], {dtype}){expand}")
                 output_idx += 1
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and k.ndim == 1:
-                expand = tile_strategy.expand_str(output_size, output_idx)
+            elif isinstance(k, torch.Tensor) and not (
+                len(index) == 1 and fake_value.ndim == 1
+            ):
                 ast_index = state.ast_args[1]
                 assert isinstance(ast_index, (list, tuple))
                 assert len(ast_index) == len(index)
                 index_var = state.codegen.lift(ast_index[n], prefix="index").id
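+                # In advanced mode the index tensors already broadcast to the output shape, so no expand_str suffix is applied.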
-                index_values.append(f"({index_var}){expand}")
-                if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
-                    if mask := state.codegen.mask_var(block_idx):
-                        mask_values.setdefault(f"({mask}){expand}")
+                if advanced_mode:
+                    index_values.append(index_var)
+                else:
+                    expand = tile_strategy.expand_str(output_size, output_idx)
+                    index_values.append(f"({index_var}){expand}")
+                    if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
+                        if mask := state.codegen.mask_var(block_idx):
+                            mask_values.setdefault(f"({mask}){expand}")
                 # Check if this index comes from a padded hl.arange and generate mask
                 if (
                     original_length := _get_padded_iota_original_length(state, n)
                 ) is not None:
-                    mask_values.setdefault(f"({index_var} < {original_length}){expand}")
-                output_idx += 1
+                    if advanced_mode:
+                        mask_values.setdefault(f"({index_var} < {original_length})")
+                    else:
+                        mask_values.setdefault(f"({index_var} < {original_length}){expand}")
+                if not advanced_mode:
+                    output_idx += 1
                 k_index += 1
             elif (
                 isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
@@ -786,6 +870,8 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(type(k))
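+        # advanced_mode skips the per-index output_idx increments above, so reconcile the count once after the loop.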
+        if advanced_mode:
+            output_idx = len(output_size)
         assert len(output_size) == output_idx
         assert len(index_values) == fake_value.ndim
         index_expr = []
@@ -885,7 +971,11 @@ def need_reshape(self, node: ast.AST) -> bool:
             return True
         env = CompileEnvironment.current()
         for a, b in zip(self.reshaped_size, self.block_shape, strict=True):
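+            # Sizes backed by the same block id are equal by construction, even when symbolic comparison cannot prove it.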
-            if not env.known_equal(a, b):
+            block_id_a = env.resolve_block_id(a)
+            block_id_b = env.resolve_block_id(b)
+            if block_id_a != block_id_b:
+                return True
+            if block_id_a is None and not env.known_equal(a, b):
                 return True
         return False
 
@@ -1035,7 +1125,13 @@ def create(
                 # Full slice or slice without step
                 if size != 1:
                     rdim = env.allocate_reduction_dimension(size)
-                    res.offsets.append(state.codegen.offset_var(rdim.block_id))
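+                    # With no live device loop for this block, there is no loop offset to apply; use 0.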
+                    active_loops = state.codegen.active_device_loops.get(
+                        rdim.block_id
+                    )
+                    if active_loops:
+                        res.offsets.append(state.codegen.offset_var(rdim.block_id))
+                    else:
+                        res.offsets.append("0")
                     res.block_shape.append(rdim.var)
                 else:
                     res.offsets.append("0")