Commit 986c922
Remove decomp_table from ApplyShardingInterpreter (#137)
Now that we recursively apply decomps following #136, we no longer need to apply an extra decomp inside ApplyShardingInterpreter (it existed only to allow two passes of decompositions).
1 parent eaac0a7 commit 986c922

autoparallel/apply_sharding.py

Lines changed: 2 additions & 12 deletions
@@ -34,12 +34,9 @@
 
 
 class ApplyShardingInterpreter(torch.fx.Interpreter):
-    def __init__(self, module, sharding_placement, decomp_table=None):
+    def __init__(self, module, sharding_placement):
         super().__init__(module, garbage_collect_values=True, graph=None)
         self.sharding_placement = sharding_placement
-        if decomp_table is None:
-            decomp_table = {}
-        self.decomp_table = decomp_table
         param_placement_order = {}
         if _ENABLE_ORDERED_SHARDING_OPTIMIZATION:
             param_placement_order = compute_optimal_placement_order_for_parameters(
@@ -170,13 +167,6 @@ def call_function(self, target, args, kwargs):
         # TODO: see if we can remove this contiguous properly
         new_args[0] = new_args[0].contiguous()
 
-        if target in self.decomp_table:
-            new_target = self.decomp_table[target]
-            out = super().call_function(new_target, tuple(new_args), kwargs)
-            # NOTE: is there a canonical way of handling this?
-            if out is not NotImplemented:
-                out = tree_map_only(DTensor, lambda x: x.to_local(), out)
-                return out
         out = super().call_function(target, tuple(new_args), kwargs)
         out = tree_map_only(DTensor, lambda x: x.to_local(), out)
         return out
@@ -246,7 +236,7 @@ def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
 
     decomp_table = _get_inductor_decomp_table()
     # run with DTensor to apply the collectives given the graph
-    interp = ApplyShardingInterpreter(gm, sharding_placement, decomp_table)
+    interp = ApplyShardingInterpreter(gm, sharding_placement)
 
     # TODO: make_fx here is suspicious in case of dynamic shapes
     with fx_traceback.preserve_node_meta():
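For context: the decomp_table that remains in apply_sharding_to_model presumably feeds the make_fx retrace visible in the last hunk, and make_fx expands table entries recursively while tracing, which is why the interpreter no longer needs its own pass. A minimal sketch of that mechanism, using the generic get_decompositions rather than this repo's _get_inductor_decomp_table (the op choice and shapes here are illustrative, not from the commit):

# A minimal sketch (assumed example, not this repository's code) of how a
# decomposition table is applied during a make_fx retrace. Ops with an
# entry in the table are expanded while tracing.
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# Request a decomposition for aten.addmm; autoparallel's real table
# (_get_inductor_decomp_table) is larger.
decomp_table = get_decompositions([torch.ops.aten.addmm])

def fn(x, w, b):
    return torch.addmm(b, x, w)

gm = make_fx(fn, decomposition_table=decomp_table)(
    torch.randn(4, 8), torch.randn(8, 16), torch.randn(16)
)
# The traced graph contains mm/add nodes instead of addmm.
print(gm.graph)

Because the decomposed ops are traced through the same dispatch, any table entries they trigger are expanded in turn, so a single retrace covers what the removed two-pass setup did.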
