 
 
 class ApplyShardingInterpreter(torch.fx.Interpreter):
-    def __init__(self, module, sharding_placement, decomp_table=None):
+    def __init__(self, module, sharding_placement):
         super().__init__(module, garbage_collect_values=True, graph=None)
         self.sharding_placement = sharding_placement
-        if decomp_table is None:
-            decomp_table = {}
-        self.decomp_table = decomp_table
         param_placement_order = {}
         if _ENABLE_ORDERED_SHARDING_OPTIMIZATION:
             param_placement_order = compute_optimal_placement_order_for_parameters(
@@ -170,36 +167,11 @@ def call_function(self, target, args, kwargs):
         # TODO: see if we can remove this contiguous properly
         new_args[0] = new_args[0].contiguous()
 
-        if target in self.decomp_table:
-            new_target = self.decomp_table[target]
-            out = super().call_function(new_target, tuple(new_args), kwargs)
-            # NOTE: is there a canonical way of handling this?
-            if out is not NotImplemented:
-                out = tree_map_only(DTensor, lambda x: x.to_local(), out)
-                return out
         out = super().call_function(target, tuple(new_args), kwargs)
         out = tree_map_only(DTensor, lambda x: x.to_local(), out)
         return out
 
 
-class ApplyDecompInterpreter(torch.fx.Interpreter):
-    def __init__(self, module, decomp_table=None):
-        super().__init__(module, garbage_collect_values=True, graph=None)
-        if decomp_table is None:
-            decomp_table = {}
-        self.decomp_table = decomp_table
-
-    def call_function(self, target, args, kwargs):
-        if target in self.decomp_table:
-            new_target = self.decomp_table[target]
-            out = super().call_function(new_target, args, kwargs)
-            # NOTE: is there a canonical way of handling this?
-            if out is not NotImplemented:
-                return out
-        out = super().call_function(target, args, kwargs)
-        return out
-
-
 def shard_node_given_placements(node, sharding_placement, *, meta: bool):
     # TODO: not sure if we actually guarantee sharding_placement has ever
     # input node lol
@@ -264,16 +236,18 @@ def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
 
     decomp_table = _get_inductor_decomp_table()
     # run with DTensor to apply the collectives given the graph
-    interp = ApplyShardingInterpreter(gm, sharding_placement, decomp_table)
+    interp = ApplyShardingInterpreter(gm, sharding_placement)
 
     # TODO: make_fx here is suspicious in case of dynamic shapes
     with fx_traceback.preserve_node_meta():
         parallel_gm0 = make_fx(interp.run)(*local_args)
 
     cleanup_graph(parallel_gm0)
-    interp2 = ApplyDecompInterpreter(parallel_gm0, decomp_table)
+    interp2 = torch.fx.Interpreter(parallel_gm0)
     with fx_traceback.preserve_node_meta():
-        parallel_gm = make_fx(interp2.run)(*local_args)
+        parallel_gm = make_fx(interp2.run, decomposition_table=decomp_table)(
+            *local_args
+        )
     cleanup_graph(parallel_gm)
 
     # Copy descriptors over to new graph
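The net effect of the diff above is that decompositions are no longer applied op-by-op inside custom interpreters; the sharded graph is re-executed through a plain `torch.fx.Interpreter`, and `make_fx` applies the decomposition table while tracing that execution. A minimal sketch of that pattern follows; it is illustrative only: `retrace_with_decomps` is a made-up helper name, and `core_aten_decompositions()` stands in for the repo-internal `_get_inductor_decomp_table()`.

```python
# Illustrative sketch, not code from this PR: retrace a GraphModule through a
# plain torch.fx.Interpreter and let make_fx apply a decomposition table,
# instead of decomposing ops manually inside a custom Interpreter subclass.
import torch
from torch._decomp import core_aten_decompositions  # stand-in decomp table
from torch.fx.experimental.proxy_tensor import make_fx


def retrace_with_decomps(gm: torch.fx.GraphModule, *example_args):
    # Interpreter.run re-executes the traced graph node by node; make_fx
    # traces that execution and substitutes decompositions from the table.
    interp = torch.fx.Interpreter(gm)
    return make_fx(interp.run, decomposition_table=core_aten_decompositions())(
        *example_args
    )


if __name__ == "__main__":
    gm = make_fx(lambda x: torch.nn.functional.layer_norm(x, [8]))(torch.randn(2, 8))
    retrace_with_decomps(gm, torch.randn(2, 8)).print_readable()
```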