
Commit 5726d7c

Remove more invalid / uneven shardings (#23)
Previously we only removed shardings that were invalid for the inputs of an op. Now we also remove those that are invalid for the output. With that, we can drop the solver constraint that removed invalid views, since those shardings no longer appear.
1 parent 2bc148e commit 5726d7c
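
The "invalid / uneven" shardings in question are those where a sharded tensor dimension does not divide evenly across the mesh dimension it is sharded over (for dimensions smaller than the mesh this can even produce empty shards). Below is a minimal, self-contained sketch of that divisibility rule; the is_even_sharding helper and its tuple inputs are hypothetical stand-ins, not the repo's DTensorSpec-based code.

    def is_even_sharding(shape, mesh_shape, shard_dims):
        # shard_dims[i] is the tensor dim sharded over mesh dim i, or None
        # for Replicate/Partial placements on that mesh dim.
        shape = list(shape)
        for mesh_size, dim in zip(mesh_shape, shard_dims):
            if dim is None:
                continue
            if shape[dim] % mesh_size != 0:
                return False              # uneven (possibly empty) shard
            shape[dim] //= mesh_size      # local size after this level of sharding
        return True

    assert is_even_sharding((6, 8), (2, 4), (0, 1))      # 6/2=3, 8/4=2 -> kept
    assert not is_even_sharding((6, 8), (2, 4), (0, 0))  # 3 % 4 != 0 -> dropped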

File tree

3 files changed: +37 -81 lines

autoparallel/optimize_sharding.py
autoparallel/propagation_rules.py
autoparallel/utils.py

autoparallel/optimize_sharding.py
Lines changed: 0 additions & 59 deletions

@@ -206,67 +206,8 @@ def add_default_constraints(self):
         self.add_output_input_consistent_constraint()
         self.add_inf_cost_constraint()
 
-        self.remove_invalid_configurations()
         self.penalize_inefficient_collectives()
 
-    def remove_invalid_configurations(self):
-        """
-        Remove shardings that could yield invalid configurations,
-        for example, when sharding a view on a dimension that would yield
-        an empty size. Maybe this should be fixed in the returned specs from PyTorch
-        though, but removing those invalid cases here for now
-        """
-        for s_i, node in enumerate(self.graph.nodes):
-            if node.op != "call_function":
-                continue
-            # only targetting view for now
-            if node.target != torch.ops.aten.view.default:
-                continue
-            orig_shape = node.args[0].meta["val"].shape
-            shape = list(node.args[1])
-            if len(orig_shape) > len(shape):
-                # TODO: FIXME as I think we should also handle this case
-                continue
-            # print("in heeeererereer", orig_shape, shape)
-            tgt_op_strat = self.strats[node]
-            for counter, parent in enumerate(node.all_input_nodes):
-                curr_op_strat = self.strats[parent]
-
-                for oi, tgt_strat in enumerate(tgt_op_strat.strategies):
-                    spec = tgt_strat.input_specs[counter]
-                    if not isinstance(spec, DTensorSpec):
-                        # TODO: check if this is correct
-                        continue
-
-                    for ii, curr_strat in enumerate(curr_op_strat.strategies):
-                        curr_spec = curr_strat.output_specs
-                        if not isinstance(curr_spec, DTensorSpec):
-                            continue
-                        shape = list(node.args[1])
-                        if -1 in shape:
-                            # handle cases where we need to infer the size
-                            numel = math.prod(orig_shape)
-                            index_loc = shape.index(-1)
-                            # this works because the shape we infer is -1
-                            # and there is a single one
-                            visible_numel = -math.prod(shape)
-                            shape[index_loc] = numel // visible_numel
-                        for mesh_shape, tgt_plc, curr_plc in zip(
-                            spec.mesh.shape, spec.placements, curr_spec.placements
-                        ):
-                            # only keep view shardings that don't yield empty shapes
-                            # which could happen with S(0)S(0) on a dimension whose shape
-                            # is smaller than world_size
-                            if tgt_plc.is_shard():
-                                dim = tgt_plc.dim
-                                if shape[dim] % mesh_shape == 0:
-                                    shape[dim] /= mesh_shape
-                                else:
-                                    self.prob += (
-                                        self.ds[(s_i, counter, oi, ii)]["va"] == 0,
-                                        _get_next_name("invalid_view"),
-                                    )
-
     def penalize_inefficient_collectives(self):
         """
         When performing shard_{n} -> replicate (for n != 0), there is additional
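
One non-obvious step in the removed method is how it resolved a -1 in the view's target shape before running the divisibility check. A stand-alone sketch of that arithmetic, with hypothetical shapes:

    import math

    # Resolving a -1 in a view target, as the removed code did: since the
    # target contains exactly one -1, negating the product of its entries
    # gives the product of the known dims.
    orig_shape, view_shape = (4, 6), [8, -1]   # hypothetical aten.view args
    numel = math.prod(orig_shape)              # 24
    visible_numel = -math.prod(view_shape)     # -(8 * -1) = 8
    view_shape[view_shape.index(-1)] = numel // visible_numel
    print(view_shape)                          # [8, 3]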

autoparallel/propagation_rules.py
Lines changed: 32 additions & 2 deletions

@@ -76,6 +76,32 @@ def _build_meta_tensor(tensor_meta):
     )
 
 
+def remove_invalid_configs(out_strat, mesh):
+    kept = []
+    for strategy in out_strat.strategies:
+        is_valid = True
+        output_specs = strategy.output_specs
+        if isinstance(output_specs, DTensorSpec):
+            output_specs = [output_specs]
+        specs = list(strategy.input_specs) + list(output_specs)
+        for spec in specs:
+            if spec is None:
+                continue
+            shape = list(spec.tensor_meta.shape)
+            for mesh_shape, plc in zip(mesh.shape, spec.placements):
+                if plc.is_shard():
+                    dim = plc.dim
+                    if shape[dim] % mesh_shape == 0:
+                        shape[dim] //= mesh_shape
+                    else:
+                        is_valid = False
+                        break
+        if is_valid:
+            kept.append(strategy)
+
+    return OpStrategy(kept)
+
+
 def _create_all_options_no_nested_sharding(mesh, shape, tensor_meta=None):
     if tensor_meta is None:
         tensor_meta = _gen_tensor_meta(shape)
@@ -94,7 +120,9 @@ def _create_all_options_no_nested_sharding(mesh, shape, tensor_meta=None):
             continue
         spec = DTensorSpec.from_dim_map(mesh, op, [], tensor_meta)
         strats.append(OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]]))
-    return OpStrategy(strats)
+    out_strats = OpStrategy(strats)
+    out_strats = remove_invalid_configs(out_strats, mesh)
+    return out_strats
 
 
 def _create_all_options(mesh, shape, tensor_meta=None, tensor=None):
@@ -112,7 +140,9 @@ def _create_all_options(mesh, shape, tensor_meta=None, tensor=None):
     for placement in all_options:
         spec = DTensorSpec(mesh, placement, tensor_meta=tensor_meta)
         strats.append(OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]]))
-    return OpStrategy(strats)
+    out_strats = OpStrategy(strats)
+    out_strats = remove_invalid_configs(out_strats, mesh)
+    return out_strats
 
 
 @register_rule(operator.getitem)
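
The reason remove_invalid_configs walks output_specs as well as input_specs is the point made in the commit message: an op can have evenly shardable inputs but an unevenly shardable output. A hedged illustration with hypothetical shapes (plain arithmetic, not the repo's API):

    # aten.view from (4, 3) to (2, 6) on a 1-D mesh of size 4, with S(0)
    # requested on both the input and the output:
    mesh_size = 4
    input_dim0, output_dim0 = 4, 2
    print(input_dim0 % mesh_size == 0)   # True  -> an input-only check passes
    print(output_dim0 % mesh_size == 0)  # False -> the output shard would be
                                         # uneven, so this strategy is dropped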

autoparallel/utils.py
Lines changed: 5 additions & 20 deletions

@@ -10,7 +10,7 @@
 from torch.distributed.tensor._ops.utils import generate_redistribute_costs
 from torch.utils._pytree import tree_flatten, tree_map_only
 
-from .propagation_rules import _op_partial_rules, _op_rules
+from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs
 
 
 def propagate_tensor_meta(op, user_args, out_strat):
@@ -90,7 +90,9 @@ def get_placement_options(mesh, op, specs, user_args):
     # print(op)
 
     if op in _op_rules:
-        return _op_rules[op](mesh, specs)
+        out_strat = _op_rules[op](mesh, specs)
+        out_strat = remove_invalid_configs(out_strat, mesh)
+        return out_strat
 
     strat = []
     for spec in specs:
@@ -119,24 +121,7 @@ def get_placement_options(mesh, op, specs, user_args):
 
     propagate_tensor_meta(op, user_args, out_strat)
     fill_missing_redistribute_cost(op, specs, out_strat)
-
-    kept = []
-    for strategy in out_strat.strategies:
-        is_valid = True
-        for input_spec in strategy.input_specs:
-            shape = list(input_spec.tensor_meta.shape)
-            for mesh_shape, plc in zip(mesh.shape, input_spec.placements):
-                if plc.is_shard():
-                    dim = plc.dim
-                    if shape[dim] % mesh_shape == 0:
-                        shape[dim] /= mesh_shape
-                    else:
-                        is_valid = False
-                        break
-        if is_valid:
-            kept.append(strategy)
-
-    out_strat = OpStrategy(kept)
+    out_strat = remove_invalid_configs(out_strat, mesh)
 
     return out_strat
 
