Skip to content

Commit 9f9f00a

Browse files
committed
Add proper redistribute_cost to dummy strategies
1 parent 2df8772 commit 9f9f00a

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

autoparallel/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,12 @@ def _generate_dummy_strategy(mesh, op, user_args, user_kwargs, input_strategies)
127127
assert num_input_args == len(
128128
input_strategies_flat
129129
), f"{op}, {num_input_args}, {len(input_strategies_flat)}"
130-
# TODO: fix redistribute cost
131-
out_strat.redistribute_cost = [
132-
[0.0] * len(x.strategies) for x in input_strategies_flat
130+
redistribute_cost = [
131+
generate_redistribute_costs(input_strategies_flat[i], input_specs[i])
132+
for i in range(num_input_args)
133133
]
134+
out_strat.redistribute_cost = redistribute_cost
135+
134136
assert len(out_strat.redistribute_cost) == num_input_args
135137
out_strat = OpStrategy([out_strat])
136138
return out_strat

0 commit comments

Comments
 (0)