@@ -34,12 +34,9 @@


 class ApplyShardingInterpreter(torch.fx.Interpreter):
-    def __init__(self, module, sharding_placement, decomp_table=None):
+    def __init__(self, module, sharding_placement):
         super().__init__(module, garbage_collect_values=True, graph=None)
         self.sharding_placement = sharding_placement
-        if decomp_table is None:
-            decomp_table = {}
-        self.decomp_table = decomp_table
         param_placement_order = {}
         if _ENABLE_ORDERED_SHARDING_OPTIMIZATION:
             param_placement_order = compute_optimal_placement_order_for_parameters(
@@ -170,13 +167,6 @@ def call_function(self, target, args, kwargs):
         # TODO: see if we can remove this contiguous properly
         new_args[0] = new_args[0].contiguous()

-        if target in self.decomp_table:
-            new_target = self.decomp_table[target]
-            out = super().call_function(new_target, tuple(new_args), kwargs)
-            # NOTE: is there a canonical way of handling this?
-            if out is not NotImplemented:
-                out = tree_map_only(DTensor, lambda x: x.to_local(), out)
-                return out
         out = super().call_function(target, tuple(new_args), kwargs)
         out = tree_map_only(DTensor, lambda x: x.to_local(), out)
         return out
@@ -246,7 +236,7 @@ def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):

     decomp_table = _get_inductor_decomp_table()
     # run with DTensor to apply the collectives given the graph
-    interp = ApplyShardingInterpreter(gm, sharding_placement, decomp_table)
+    interp = ApplyShardingInterpreter(gm, sharding_placement)

     # TODO: make_fx here is suspicious in case of dynamic shapes
     with fx_traceback.preserve_node_meta():
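
For readers skimming the change: after this diff, call_function no longer consults a decomposition table; it executes each node directly and unwraps any DTensor outputs to local tensors. Below is a minimal, self-contained sketch (not part of this PR) of that interpreter pattern. The names LocalizingInterpreter and ToyModel are made up for illustration; only the call_function override structure and the tree_map_only(DTensor, ...) unwrapping mirror the code above, and the DTensor import assumes a distributed-enabled PyTorch build.

# Illustrative sketch only -- LocalizingInterpreter and ToyModel are hypothetical
# names; the call_function override mirrors the simplified ApplyShardingInterpreter.
import torch
import torch.fx
from torch.distributed.tensor import DTensor  # assumes a distributed-enabled build
from torch.utils._pytree import tree_map_only


class LocalizingInterpreter(torch.fx.Interpreter):
    """Execute every call_function node directly and convert DTensor outputs
    to local tensors, with no decomposition-table indirection."""

    def call_function(self, target, args, kwargs):
        out = super().call_function(target, args, kwargs)
        return tree_map_only(DTensor, lambda x: x.to_local(), out)


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0


gm = torch.fx.symbolic_trace(ToyModel())
result = LocalizingInterpreter(gm).run(torch.randn(4, 4))
print(result.shape)  # torch.Size([4, 4]); any DTensor outputs would be local tensors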
|