@@ -783,3 +783,42 @@ def einsum_rule(mesh, op_schema):
783783 kwargs_schema = {},
784784 )
785785 return _mm_like_strategy (mm_equation , mesh , new_op_schema )
786+
787+
@register_opschema_rule(torch.ops.aten.scatter.src)
def scatter_strategy(mesh, op_schema: OpSchema):
    """Sharding strategy for ``aten.scatter.src``.

    Adapted from the ``scatter_add`` strategy in PyTorch's DTensor op
    registry.  Produces, per mesh dimension, the set of single-mesh-dim
    placement combinations that are valid for scatter, then expands them
    to the full mesh via ``expand_to_full_mesh_op_strategy``.

    Args:
        mesh: device mesh supplied by the rule registry.  NOTE(review):
            it is immediately shadowed by ``input_strategy.mesh`` below,
            mirroring the upstream scatter_add implementation.
        op_schema: schema whose ``args_schema`` is
            ``(input_strategy, dim, index_strategy, src_strategy)``.

    Returns:
        An ``OpStrategy`` covering the full mesh, with ``input_index=1``
        (a single tensor output).
    """
    # Imported lazily to avoid a hard import-time dependency on DTensor
    # internals (private module; keep the import local to the rule).
    from torch.distributed.tensor._ops._tensor_ops import (
        PlacementList,
        expand_to_full_mesh_op_strategy,
        normalize_dim,
    )

    input_strategy = op_schema.args_schema[0]
    dim = op_schema.args_schema[1]
    index_strategy = op_schema.args_schema[2]

    assert isinstance(input_strategy, OpStrategy)
    assert isinstance(index_strategy, OpStrategy)
    assert isinstance(dim, int)
    # Canonicalize a possibly-negative dim against the input rank.
    dim = normalize_dim(dim, input_strategy.ndim)
    mesh = input_strategy.mesh
    input_shape = input_strategy.shape
    index_shape = index_strategy.shape

    single_mesh_dim_strategies = []

    # placement list stores placements of [output, input, index, src]
    # first we always have replicate all for inputs and output
    all_replicate: PlacementList = [Replicate()] * 4
    single_mesh_dim_strategies.append(all_replicate)

    # If input and index have the same rank, any dim other than the
    # scatter dim whose sizes match can be sharded across all four
    # tensors (output, input, index, src) without communication.
    if len(input_shape) == len(index_shape):
        for d in range(len(input_shape)):
            if d != dim and input_shape[d] == index_shape[d]:
                sharding: PlacementList = [Shard(d), Shard(d), Shard(d), Shard(d)]
                single_mesh_dim_strategies.append(sharding)

    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=1
    )