|
13 | 13 | from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs |
14 | 14 |
15 | 15 |
16 | | -def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): |
| 16 | +def _get_meta_tensors_for_op(op, user_args, user_kwargs): |
17 | 17 | out_t = op(*user_args, **user_kwargs) |
18 | 18 |
19 | 19 | if isinstance(out_t, torch.Tensor): |
20 | | - new_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype) |
| 20 | + out_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype) |
21 | 21 | else: |
22 | | - new_tensor_meta = tree_map_only( |
| 22 | + out_tensor_meta = tree_map_only( |
23 | 23 | torch.Tensor, lambda x: TensorMeta(x.shape, x.stride(), x.dtype), out_t |
24 | 24 | ) |
25 | 25 |
26 | | - tensor_metas = tree_flatten(user_args)[0] |
27 | | - tensor_metas = tree_map_only( |
28 | | - torch.Tensor, lambda x: TensorMeta(x.shape, x.stride(), x.dtype), tensor_metas |
| 26 | + input_tensor_metas = tree_flatten(user_args)[0] |
| 27 | + input_tensor_metas = tree_map_only( |
| 28 | + torch.Tensor, |
| 29 | + lambda x: TensorMeta(x.shape, x.stride(), x.dtype), |
| 30 | + input_tensor_metas, |
| 31 | + ) |
| 32 | + input_tensor_metas = tuple( |
| 33 | + x for x in input_tensor_metas if isinstance(x, TensorMeta) |
29 | 34 | ) |
30 | | - tensor_metas = tuple(x for x in tensor_metas if isinstance(x, TensorMeta)) |
| 35 | + return out_tensor_meta, input_tensor_metas |
| 36 | + |
| 37 | + |
| 38 | +def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): |
| 39 | + new_tensor_meta, tensor_metas = _get_meta_tensors_for_op(op, user_args, user_kwargs) |
31 | 40 |
32 | 41 | for strat in out_strat.strategies: |
33 | 42 | if isinstance(new_tensor_meta, TensorMeta): |
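For context (not part of the diff): a minimal sketch of what the extracted helper returns, assuming `_get_meta_tensors_for_op` and `TensorMeta` are in scope as defined above. The inputs are hypothetical and live on the "meta" device, so calling the op allocates no real memory.

```python
import torch
# Assumption: TensorMeta is the same named tuple the module above imports.
from torch.distributed.tensor._dtensor_spec import TensorMeta

# Hypothetical inputs: only shapes, strides and dtypes matter here.
a = torch.empty(4, 8, device="meta")
b = torch.empty(8, device="meta")

out_meta, in_metas = _get_meta_tensors_for_op(torch.ops.aten.add.Tensor, (a, b), {})
assert isinstance(out_meta, TensorMeta)   # single-tensor output case
assert tuple(out_meta.shape) == (4, 8)    # broadcasted output shape
assert len(in_metas) == 2                 # one TensorMeta per tensor input
```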
@@ -85,22 +94,44 @@ def fill_missing_redistribute_cost(op, specs, out_strat): |
85 | 94 | strat.redistribute_cost = redistribute_costs |
86 | 95 |
87 | 96 |
88 | | -def _generate_dummy_strategy(mesh, tensor_meta, num_input_args, num_input_strategies): |
| 97 | +def _generate_dummy_strategy(mesh, op, user_args, user_kwargs, input_strategies): |
89 | 98 | from torch.distributed.tensor._dtensor_spec import DTensorSpec |
90 | 99 | from torch.distributed.tensor._op_schema import OpSpec |
91 | 100 | from torch.distributed.tensor.placement_types import Replicate |
92 | 101 |
93 | 102 | placements = (Replicate(),) * mesh.ndim |
| 103 | + |
| 104 | + out_tensor_meta, input_tensor_metas = _get_meta_tensors_for_op( |
| 105 | + op, user_args, user_kwargs |
| 106 | + ) |
| 107 | + |
94 | 108 | input_specs = [ |
95 | | - DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) |
96 | | - for _ in range(num_input_args) |
| 109 | + DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tm) |
| 110 | + for tm in input_tensor_metas |
97 | 111 | ] |
98 | | - output_spec = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) |
| 112 | + if isinstance(out_tensor_meta, TensorMeta): |
| 113 | + output_spec = DTensorSpec( |
| 114 | + mesh=mesh, placements=placements, tensor_meta=out_tensor_meta |
| 115 | + ) |
| 116 | + else: |
| 117 | + output_spec = tuple( |
| 118 | + DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tm) |
| 119 | + for tm in out_tensor_meta |
| 120 | + ) |
99 | 121 |
100 | 122 | out_strat = OpSpec(output_specs=output_spec, input_specs=input_specs) |
| 123 | + num_input_args = len(input_tensor_metas) |
| 124 | + input_strategies_flat = [ |
| 125 | + x for x in tree_flatten(input_strategies)[0] if isinstance(x, OpStrategy) |
| 126 | + ] |
| 127 | + assert num_input_args == len( |
| 128 | + input_strategies_flat |
| 129 | + ), f"{op}, {num_input_args}, {len(input_strategies_flat)}" |
| 130 | + # TODO: fix redistribute cost |
101 | 131 | out_strat.redistribute_cost = [ |
102 | | - [0.0] * num_input_strategies, |
103 | | - ] * num_input_args |
| 132 | + [0.0] * len(x.strategies) for x in input_strategies_flat |
| 133 | + ] |
| 134 | + assert len(out_strat.redistribute_cost) == num_input_args |
104 | 135 | out_strat = OpStrategy([out_strat]) |
105 | 136 | return out_strat |
106 | 137 |
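A quick illustration (with made-up strategy counts, not taken from the PR) of the redistribute_cost layout the rewritten `_generate_dummy_strategy` produces: one row per flattened tensor input, with one zero entry per candidate strategy of that input, instead of the old fixed num_input_args x num_input_strategies matrix.

```python
# Suppose an op has two tensor inputs whose incoming OpStrategy objects carry
# 3 and 5 candidate OpSpecs respectively (hypothetical numbers).
input_strategy_counts = [3, 5]

# New layout: row i holds one (currently zero) cost per strategy of input i.
redistribute_cost = [[0.0] * n for n in input_strategy_counts]

assert redistribute_cost == [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
assert len(redistribute_cost) == len(input_strategy_counts)
```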
@@ -147,29 +178,7 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): |
147 | 178 | ) |
148 | 179 | else: |
149 | 180 | print(f"Ops that need to be implemented {op}") |
150 | | - from .propagation_rules import _create_all_options |
151 | | - |
152 | | - tensor_meta = strat[0].strategies[0].output_spec.tensor_meta |
153 | | - num_strats = len(strat[0].strategies) |
154 | | - if op == torch.ops.aten.sort.stable: |
155 | | - out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ |
156 | | - torch.ops.aten.topk.default |
157 | | - ]( |
158 | | - op_schema |
159 | | - ) |
160 | | - elif op in { |
161 | | - torch.ops.autoparallel.fill_indices_wrapper.default, |
162 | | - torch.ops.aten.scatter_add.default, |
163 | | - torch.ops.prims.fma.default, |
164 | | - }: |
165 | | - out_strat = _generate_dummy_strategy(mesh, tensor_meta, 3, num_strats) |
166 | | - elif op in { |
167 | | - torch.ops.aten.slice_scatter.default, |
168 | | - torch.ops.aten._softmax_backward_data.default, |
169 | | - }: |
170 | | - out_strat = _generate_dummy_strategy(mesh, tensor_meta, 2, num_strats) |
171 | | - else: |
172 | | - out_strat = _create_all_options(mesh, tensor_meta.shape, tensor_meta) |
| 181 | + out_strat = _generate_dummy_strategy(mesh, op, user_args, user_kwargs, strat) |
173 | 182 |
174 | 183 | propagate_tensor_meta(op, user_args, user_kwargs, out_strat) |
175 | 184 | fill_missing_redistribute_cost(op, specs, out_strat) |