Commit 2df8772

Make _generate_dummy_strategy more generic
It should now handle all ops properly, with correct shapes.
1 parent ab4ab37 commit 2df8772


autoparallel/utils.py

Lines changed: 45 additions & 36 deletions
@@ -13,21 +13,30 @@
 from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs
 
 
-def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
+def _get_meta_tensors_for_op(op, user_args, user_kwargs):
     out_t = op(*user_args, **user_kwargs)
 
     if isinstance(out_t, torch.Tensor):
-        new_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype)
+        out_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype)
     else:
-        new_tensor_meta = tree_map_only(
+        out_tensor_meta = tree_map_only(
             torch.Tensor, lambda x: TensorMeta(x.shape, x.stride(), x.dtype), out_t
         )
 
-    tensor_metas = tree_flatten(user_args)[0]
-    tensor_metas = tree_map_only(
-        torch.Tensor, lambda x: TensorMeta(x.shape, x.stride(), x.dtype), tensor_metas
+    input_tensor_metas = tree_flatten(user_args)[0]
+    input_tensor_metas = tree_map_only(
+        torch.Tensor,
+        lambda x: TensorMeta(x.shape, x.stride(), x.dtype),
+        input_tensor_metas,
+    )
+    input_tensor_metas = tuple(
+        x for x in input_tensor_metas if isinstance(x, TensorMeta)
     )
-    tensor_metas = tuple(x for x in tensor_metas if isinstance(x, TensorMeta))
+    return out_tensor_meta, input_tensor_metas
+
+
+def propagate_tensor_meta(op, user_args, user_kwargs, out_strat):
+    new_tensor_meta, tensor_metas = _get_meta_tensors_for_op(op, user_args, user_kwargs)
 
     for strat in out_strat.strategies:
         if isinstance(new_tensor_meta, TensorMeta):
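
The key change in this hunk: input and output TensorMetas are now derived by actually invoking the op, so multi-output ops get one meta per output instead of reusing a single tensor_meta, and non-tensor arguments are filtered out of the input metas. A minimal standalone sketch of the same derivation (not part of the commit; it assumes a recent PyTorch where TensorMeta lives in torch.distributed.tensor._dtensor_spec, and uses meta tensors so no real compute runs):

import torch
from torch.distributed.tensor._dtensor_spec import TensorMeta
from torch.utils._pytree import tree_flatten, tree_map_only

# aten.sort returns (values, indices); a single-TensorMeta path
# cannot describe both outputs.
x = torch.empty(4, 8, device="meta")
out = tuple(torch.ops.aten.sort.default(x))

def to_meta(t):
    return TensorMeta(t.shape, t.stride(), t.dtype)

out_metas = tree_map_only(torch.Tensor, to_meta, out)
in_metas = tuple(
    m
    for m in tree_map_only(torch.Tensor, to_meta, tree_flatten((x,))[0])
    if isinstance(m, TensorMeta)
)
print(len(out_metas), len(in_metas))  # 2 outputs, 1 input
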
@@ -85,22 +94,44 @@ def fill_missing_redistribute_cost(op, specs, out_strat):
         strat.redistribute_cost = redistribute_costs
 
 
-def _generate_dummy_strategy(mesh, tensor_meta, num_input_args, num_input_strategies):
+def _generate_dummy_strategy(mesh, op, user_args, user_kwargs, input_strategies):
     from torch.distributed.tensor._dtensor_spec import DTensorSpec
     from torch.distributed.tensor._op_schema import OpSpec
     from torch.distributed.tensor.placement_types import Replicate
 
     placements = (Replicate(),) * mesh.ndim
+
+    out_tensor_meta, input_tensor_metas = _get_meta_tensors_for_op(
+        op, user_args, user_kwargs
+    )
+
     input_specs = [
-        DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta)
-        for _ in range(num_input_args)
+        DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tm)
+        for tm in input_tensor_metas
     ]
-    output_spec = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta)
+    if isinstance(out_tensor_meta, TensorMeta):
+        output_spec = DTensorSpec(
+            mesh=mesh, placements=placements, tensor_meta=out_tensor_meta
+        )
+    else:
+        output_spec = tuple(
+            DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tm)
+            for tm in out_tensor_meta
+        )
 
     out_strat = OpSpec(output_specs=output_spec, input_specs=input_specs)
+    num_input_args = len(input_tensor_metas)
+    input_strategies_flat = [
+        x for x in tree_flatten(input_strategies)[0] if isinstance(x, OpStrategy)
+    ]
+    assert num_input_args == len(
+        input_strategies_flat
+    ), f"{op}, {num_input_args}, {len(input_strategies_flat)}"
+    # TODO: fix redistribute cost
     out_strat.redistribute_cost = [
-        [0.0] * num_input_strategies,
-    ] * num_input_args
+        [0.0] * len(x.strategies) for x in input_strategies_flat
+    ]
+    assert len(out_strat.redistribute_cost) == num_input_args
     out_strat = OpStrategy([out_strat])
     return out_strat
 
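
Two things change in the placeholder cost matrix at the end of this hunk: each row is now sized by the strategy count of the corresponding input rather than a single shared num_input_strategies, and the old [...] * num_input_args construction, which aliased one inner list object across every row, is gone. A tiny standalone shape check of the new construction (FakeStrategy is a hypothetical stand-in for OpStrategy, not a DTensor class):

class FakeStrategy:
    """Hypothetical stand-in for OpStrategy; only .strategies matters here."""
    def __init__(self, n):
        self.strategies = [object()] * n

# Two tensor inputs with 3 and 5 candidate strategies respectively.
inputs = [FakeStrategy(3), FakeStrategy(5)]
cost = [[0.0] * len(s.strategies) for s in inputs]
assert [len(row) for row in cost] == [3, 5]  # one row per input
assert cost[0] is not cost[1]  # rows are independent lists, no aliasing
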

@@ -147,29 +178,7 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs):
         )
     else:
         print(f"Ops that need to be implemented {op}")
-        from .propagation_rules import _create_all_options
-
-        tensor_meta = strat[0].strategies[0].output_spec.tensor_meta
-        num_strats = len(strat[0].strategies)
-        if op == torch.ops.aten.sort.stable:
-            out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[
-                torch.ops.aten.topk.default
-            ](
-                op_schema
-            )
-        elif op in {
-            torch.ops.autoparallel.fill_indices_wrapper.default,
-            torch.ops.aten.scatter_add.default,
-            torch.ops.prims.fma.default,
-        }:
-            out_strat = _generate_dummy_strategy(mesh, tensor_meta, 3, num_strats)
-        elif op in {
-            torch.ops.aten.slice_scatter.default,
-            torch.ops.aten._softmax_backward_data.default,
-        }:
-            out_strat = _generate_dummy_strategy(mesh, tensor_meta, 2, num_strats)
-        else:
-            out_strat = _create_all_options(mesh, tensor_meta.shape, tensor_meta)
+        out_strat = _generate_dummy_strategy(mesh, op, user_args, user_kwargs, strat)
 
     propagate_tensor_meta(op, user_args, user_kwargs, out_strat)
     fill_missing_redistribute_cost(op, specs, out_strat)
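
The payoff of the last hunk: the hand-maintained special cases (routing sort.stable through topk's strategy function, and hard-coding 3 or 2 tensor inputs for the ops listed above) collapse into one uniform fallback that replicates everything. A minimal sketch of the placement that fallback produces (_FakeMesh is a hypothetical stand-in for a DeviceMesh so no process group is needed; the Replicate import is the same one the diff uses):

from torch.distributed.tensor.placement_types import Replicate

class _FakeMesh:
    """Hypothetical stand-in for a 2-D DeviceMesh; only .ndim is used."""
    ndim = 2

# Every input and output gets the same placement: replicated on each
# mesh dimension.
placements = (Replicate(),) * _FakeMesh.ndim
print(placements)  # (Replicate(), Replicate())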
