Commit 7d9ff4e

Update (base update)
[ghstack-poisoned]
2 parents: 403795c + 986c922

2 files changed (+10, -32 lines)

autoparallel/apply_sharding.py (6 additions, 32 deletions)

@@ -34,12 +34,9 @@


 class ApplyShardingInterpreter(torch.fx.Interpreter):
-    def __init__(self, module, sharding_placement, decomp_table=None):
+    def __init__(self, module, sharding_placement):
         super().__init__(module, garbage_collect_values=True, graph=None)
         self.sharding_placement = sharding_placement
-        if decomp_table is None:
-            decomp_table = {}
-        self.decomp_table = decomp_table
         param_placement_order = {}
         if _ENABLE_ORDERED_SHARDING_OPTIMIZATION:
             param_placement_order = compute_optimal_placement_order_for_parameters(
@@ -170,36 +167,11 @@ def call_function(self, target, args, kwargs):
             # TODO: see if we can remove this contiguous properly
             new_args[0] = new_args[0].contiguous()

-        if target in self.decomp_table:
-            new_target = self.decomp_table[target]
-            out = super().call_function(new_target, tuple(new_args), kwargs)
-            # NOTE: is there a canonical way of handling this?
-            if out is not NotImplemented:
-                out = tree_map_only(DTensor, lambda x: x.to_local(), out)
-                return out
         out = super().call_function(target, tuple(new_args), kwargs)
         out = tree_map_only(DTensor, lambda x: x.to_local(), out)
         return out


-class ApplyDecompInterpreter(torch.fx.Interpreter):
-    def __init__(self, module, decomp_table=None):
-        super().__init__(module, garbage_collect_values=True, graph=None)
-        if decomp_table is None:
-            decomp_table = {}
-        self.decomp_table = decomp_table
-
-    def call_function(self, target, args, kwargs):
-        if target in self.decomp_table:
-            new_target = self.decomp_table[target]
-            out = super().call_function(new_target, args, kwargs)
-            # NOTE: is there a canonical way of handling this?
-            if out is not NotImplemented:
-                return out
-        out = super().call_function(target, args, kwargs)
-        return out
-
-
 def shard_node_given_placements(node, sharding_placement, *, meta: bool):
     # TODO: not sure if we actually guarantee sharding_placement has ever
     # input node lol
@@ -264,16 +236,18 @@ def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):

     decomp_table = _get_inductor_decomp_table()
     # run with DTensor to apply the collectives given the graph
-    interp = ApplyShardingInterpreter(gm, sharding_placement, decomp_table)
+    interp = ApplyShardingInterpreter(gm, sharding_placement)

     # TODO: make_fx here is suspicious in case of dynamic shapes
     with fx_traceback.preserve_node_meta():
         parallel_gm0 = make_fx(interp.run)(*local_args)

     cleanup_graph(parallel_gm0)
-    interp2 = ApplyDecompInterpreter(parallel_gm0, decomp_table)
+    interp2 = torch.fx.Interpreter(parallel_gm0)
     with fx_traceback.preserve_node_meta():
-        parallel_gm = make_fx(interp2.run)(*local_args)
+        parallel_gm = make_fx(interp2.run, decomposition_table=decomp_table)(
+            *local_args
+        )
     cleanup_graph(parallel_gm)

     # Copy descriptors over to new graph
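
The net effect of this file's change: operator decomposition is no longer dispatched by hand inside the interpreters (the decomp_table lookup in ApplyShardingInterpreter and the whole ApplyDecompInterpreter class are gone); instead the inductor decomp table is handed to make_fx through its decomposition_table argument during the second retrace, with a stock torch.fx.Interpreter driving execution. Below is a minimal, self-contained sketch of that pattern; the toy function f, the choice of silu, and the input shape are illustrative assumptions, not the autoparallel code.

import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx


def f(x):
    # stand-in for the traced model; silu is picked only because it has a
    # registered decomposition (roughly x * sigmoid(x))
    return torch.nn.functional.silu(x) + 1.0


decomp_table = get_decompositions([torch.ops.aten.silu.default])

# first trace: no decompositions applied
gm0 = make_fx(f)(torch.randn(4))

# second trace mirrors the new code path: run the traced graph through a
# plain fx.Interpreter and let make_fx apply the decomposition table
interp2 = torch.fx.Interpreter(gm0)
gm = make_fx(interp2.run, decomposition_table=decomp_table)(torch.randn(4))

print(gm.graph)  # silu should now appear in decomposed form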

autoparallel/autobucketing_util/estimation.py (4 additions, 0 deletions)

@@ -75,6 +75,10 @@ def benchmark_and_sync_runtime(
     _, memories_at_nodes = memory.estimate_peak_memory(
         snodes, name_to_freeable_input_buf, graph_outputs
     )
+    # ensure memory offset is always positive
+    if min(memories_at_nodes) < 0:
+        shift_value = abs(min(memories_at_nodes))
+        memories_at_nodes = [x + shift_value for x in memories_at_nodes]

     for idx, snode in enumerate(snodes):
         if is_collective(
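
The estimation.py addition normalizes the estimated memory trace: if any per-node value is negative, every entry is shifted up by the magnitude of the minimum, so later bucketing logic sees non-negative offsets while the differences between nodes stay the same. A tiny standalone illustration (the numbers are made up, not real estimator output):

# toy per-node memory estimates (illustrative values only)
memories_at_nodes = [-4, 10, -2, 7]

# same shift as in the diff above: raise everything by |min| when it is negative
if min(memories_at_nodes) < 0:
    shift_value = abs(min(memories_at_nodes))
    memories_at_nodes = [x + shift_value for x in memories_at_nodes]

print(memories_at_nodes)  # [0, 14, 2, 11]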
