Commit 4391c38

Add scatter propagation rule (#169)
PyTorch only ships a Replicate strategy for this op. Add sharding rules as well, adapted from PyTorch's scatter_add strategy.
1 parent 313a597 commit 4391c38
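
The intuition behind the new rule: aten.scatter.src writes independently along every dimension other than the scatter dim, so sharding output, input, index, and src along any other dimension with matching sizes commutes with the op. A minimal standalone check of that property (shapes, seed, and variable names are illustrative only, not part of the commit):

    import torch

    torch.manual_seed(0)
    dim = 1        # scatter dimension
    shard_dim = 0  # candidate sharding dimension (!= dim, matching sizes)

    inp = torch.zeros(4, 6)
    # Unique indices per row keep the comparison deterministic.
    index = torch.stack([torch.randperm(6) for _ in range(4)])
    src = torch.randn(4, 6)

    # Scatter the full tensors, then scatter each row-shard separately.
    full = inp.scatter(dim, index, src)
    shards = [
        inp[i : i + 2].scatter(dim, index[i : i + 2], src[i : i + 2])
        for i in (0, 2)
    ]
    # Concatenating the sharded results reproduces the full result, which is
    # why Shard(d) on a non-scatter dim is a valid placement for all operands.
    assert torch.equal(torch.cat(shards, dim=shard_dim), full)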

File tree

1 file changed, 39 insertions(+), 0 deletions(-)

autoparallel/propagation_rules.py

Lines changed: 39 additions & 0 deletions
@@ -783,3 +783,42 @@ def einsum_rule(mesh, op_schema):
         kwargs_schema={},
     )
     return _mm_like_strategy(mm_equation, mesh, new_op_schema)
+
+
+@register_opschema_rule(torch.ops.aten.scatter.src)
+def scatter_strategy(mesh, op_schema: OpSchema):
+    # Adapted from the scatter_add strategy in PyTorch.
+    from torch.distributed.tensor._ops._tensor_ops import (
+        PlacementList,
+        expand_to_full_mesh_op_strategy,
+        normalize_dim,
+    )
+
+    input_strategy = op_schema.args_schema[0]
+    dim = op_schema.args_schema[1]
+    index_strategy = op_schema.args_schema[2]
+
+    assert isinstance(input_strategy, OpStrategy)
+    assert isinstance(index_strategy, OpStrategy)
+    assert isinstance(dim, int)
+    dim = normalize_dim(dim, input_strategy.ndim)
+    mesh = input_strategy.mesh
+    input_shape = input_strategy.shape
+    index_shape = index_strategy.shape
+
+    single_mesh_dim_strategies = []
+
+    # The placement list stores placements for [output, input, index, src].
+    # Replicating all inputs and the output is always valid, so add it first.
+    all_replicate: PlacementList = [Replicate()] * 4
+    single_mesh_dim_strategies.append(all_replicate)
+
+    if len(input_shape) == len(index_shape):
+        for d in range(len(input_shape)):
+            if d != dim and input_shape[d] == index_shape[d]:
+                sharding: PlacementList = [Shard(d), Shard(d), Shard(d), Shard(d)]
+                single_mesh_dim_strategies.append(sharding)
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=1
+    )
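
For context on the final call: each entry of single_mesh_dim_strategies is a candidate placement list for one mesh dimension, and expand_to_full_mesh_op_strategy (a helper in PyTorch's DTensor internals) combines those per-dimension choices across the whole mesh, with input_index=1 indicating that the first entry of each list is the single output. A rough sketch of the presumed cross-product expansion for a 2-D mesh, with plain strings standing in for placements (an illustration under that assumption, not the real helper, which builds DTensorSpecs and filters unshardable combinations):

    from itertools import product

    # Candidate per-mesh-dim placements for [output, input, index, src].
    single_mesh_dim_strategies = [
        ["R", "R", "R", "R"],              # replicate everything
        ["S(0)", "S(0)", "S(0)", "S(0)"],  # shard a matching non-scatter dim
    ]

    mesh_ndim = 2  # hypothetical 2-D device mesh
    for combo in product(single_mesh_dim_strategies, repeat=mesh_ndim):
        # combo[i] is the placement list chosen along mesh dimension i; a full
        # strategy gives each tensor one placement per mesh dimension.
        per_tensor = [
            tuple(combo[i][t] for i in range(mesh_ndim)) for t in range(4)
        ]
        print(per_tensor)

With two single-dim candidates and a 2-D mesh this enumerates four full-mesh strategies: fully replicated, sharded on one mesh dimension and replicated on the other (both ways), or sharded on both.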
