
Commit 9e86bcc

integrate dI/dW graph pass into graph_pp_runner.py (#225) (#232)
* integrate dI/dW graph pass into graph_pp_runner.py [ghstack-poisoned]
* Update on "integrate dI/dW graph pass into graph_pp_runner.py" [ghstack-poisoned]
* Update on "integrate dI/dW graph pass into graph_pp_runner.py" [ghstack-poisoned]
1 parent c583870 commit 9e86bcc
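
Background for this change (a summary, not text from the commit itself): the pass splits a pipeline stage's backward graph into a dI graph, which produces the gradients with respect to the stage's inputs that the previous pipeline stage is waiting on, and a dW graph, which produces the gradients with respect to the weights and can be deferred to fill pipeline bubbles. A toy sketch of that split for a single matmul; every tensor below is a hypothetical stand-in:

import torch

x = torch.randn(4, 8)       # activation saved from the forward pass
W = torch.randn(8, 3)       # weight of this stage, with y = x @ W
grad_y = torch.randn(4, 3)  # gradient arriving from the next stage

grad_x = grad_y @ W.t()  # dI: needed immediately by the previous pipeline stage
grad_W = x.t() @ grad_y  # dW: can run later, once the stage has idle time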

File tree

4 files changed: +313, -26 lines

autoparallel/_passes/split_di_dw_graph.py

Lines changed: 211 additions & 9 deletions

@@ -4,9 +4,29 @@
 # LICENSE file in the root directory of this source tree.

 import copy
+import itertools
+import operator

+import sympy
+import torch
 import torch.fx as fx
-from functorch.compile import default_partition
+from torch._functorch.partitioners import (
+    SavedForBackwardsAOTOutput,
+    _extract_fwd_bwd_outputs,
+    _extract_graph_with_inputs_outputs,
+    _is_backward_state,
+    _is_bwd_seed_offset,
+    _is_fwd_seed_offset,
+    _is_primal,
+    _remove_by_name,
+    find_symbol_binding_fx_nodes,
+    free_symbols,
+    is_sym_node,
+    is_symbol_binding_fx_node,
+)
+from torch.utils._ordered_set import OrderedSet
+
+from autoparallel.apply_sharding import rename_placeholder_node

 # we are running the default partitioner on the bw graph, which requires AC tags being removed.
 # At this stage we have already finished running AC anyway, since we have a bw graph
@@ -44,21 +64,203 @@ def reorder_output_grads(bw_gm, num_weight_gradients):
     return len(grad_inputs)


-# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
+# This is a copy of the function used by the default partitioner,
+# which does *not* reorder symint activations.
+# This is reordering is needed by the custom autograd.Function in AOTDispatcher,
+# but isn't needed in our dI/dW splitting since there is no autograd in the loop.
+# TODO: provide a way to gt this behavior automatically out of the default partitioner
+def _extract_fwd_bwd_modules(
+    joint_module: fx.GraphModule,
+    saved_values: list[fx.Node],
+    saved_sym_nodes: list[fx.Node],
+    *,
+    num_fwd_outputs: int,
+) -> tuple[fx.GraphModule, fx.GraphModule]:
+    (
+        fwd_outputs,
+        bwd_outputs,
+        fwd_outputs_descs,
+        bwd_outputs_descs,
+    ) = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
+    placeholders = joint_module.graph.find_nodes(op="placeholder")
+    primal_inputs = [*filter(_is_primal, placeholders)]
+    fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)]
+    bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)]
+    backward_state_inputs = [*filter(_is_backward_state, placeholders)]
+
+    bwd_graph = _extract_graph_with_inputs_outputs(
+        joint_module.graph,
+        saved_values + saved_sym_nodes + bwd_seed_offset_inputs,
+        bwd_outputs,
+        bwd_outputs_descs,
+        "backward",
+        ignore_must_be_in_fw_bw=True,
+    )
+
+    distributed_enabled = torch.distributed.is_available()
+
+    for node in bwd_graph.find_nodes(op="placeholder"):
+        # This is to filter out saved values that don't actually end up being used by the backwards pass
+        if not node.users:
+            _remove_by_name(saved_values, node.name)
+            _remove_by_name(saved_sym_nodes, node.name)
+        # wait_tensor is a bit special: if we have a "dead activation" that is not used in the bw,
+        # but this dead activation is actually a collective,
+        # then the collective will generally by followed by a wait_tensor() call.
+        # we need to peak one node further to see if this wait_tensor is dead as well.
+        elif distributed_enabled and all(
+            n.target is torch.ops._c10d_functional.wait_tensor.default
+            and len(n.users) == 0
+            for n in node.users
+        ):
+            _remove_by_name(saved_values, node.name)
+            _remove_by_name(saved_sym_nodes, node.name)
+        elif _is_backward_state(node):
+            # BackwardState is saved directly
+            _remove_by_name(saved_values, node.name)
+            assert backward_state_inputs
+
+    # Now that we have the finalized list of saved values, we need to ensure
+    # we propagate all symbols which are referenced by backwards inputs.
+    # These are not directly used in the graph but are required for downstream
+    # sizevar assignment
+    saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
+    saved_sym_nodes_binding = []
+    saved_sym_nodes_derived = []
+
+    # Some symbols may already be bound in the directly saved_sym_nodes,
+    # keep track of them so we don't re-bind them
+    for node in saved_sym_nodes:
+        symbol = is_symbol_binding_fx_node(node)
+        if symbol:
+            saved_symbols.add(symbol)
+            saved_sym_nodes_binding.append(node)
+        else:
+            saved_sym_nodes_derived.append(node)
+
+    # Now go through all of the prospective backward inputs and track any
+    # other symbols we need to bind
+    symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph)
+    for node in itertools.chain(saved_sym_nodes_derived, saved_values):
+        if "val" not in node.meta:
+            continue
+        new_symbols = free_symbols(node.meta["val"]) - saved_symbols
+        # NB: Deterministic order please!
+        for s in sorted(new_symbols, key=lambda s: s.name):
+            # NB: For well formed graphs, the symbol should always be present,
+            # but we also have ways to produce ill-formed graphs, e.g., direct
+            # make_fx usages, so don't choke in this case
+            if s not in symbol_bindings:
+                continue
+            saved_sym_nodes_binding.append(symbol_bindings[s])
+        saved_symbols |= new_symbols
+
+    # Update saved_sym_nodes that are now reordered to have all bindings at
+    # front. This can also be used later on to figure out the position of saved
+    # sym nodes in the output of fwd graph.
+    saved_sym_nodes.clear()
+    saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)

+    # Now, we re-generate the fwd/bwd graphs.
+    # NB: This might increase compilation time, but I doubt it matters
+    fwd_graph = _extract_graph_with_inputs_outputs(
+        joint_module.graph,
+        primal_inputs + fwd_seed_offset_inputs,
+        fwd_outputs + saved_values + saved_sym_nodes,
+        fwd_outputs_descs
+        + [
+            SavedForBackwardsAOTOutput(i)
+            for i in range(len(saved_values) + len(saved_sym_nodes))
+        ],
+        "forward",
+        ignore_must_be_in_fw_bw=True,
+    )
+    bwd_graph = _extract_graph_with_inputs_outputs(
+        joint_module.graph,
+        saved_values + saved_sym_nodes + bwd_seed_offset_inputs + backward_state_inputs,
+        bwd_outputs,
+        bwd_outputs_descs,
+        "backward",
+        ignore_must_be_in_fw_bw=True,
+    )
+
+    fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph)
+    bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph)
+    return fwd_module, bwd_module

+
+# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
 def split_di_dw_graph(
-    bw_gm: fx.GraphModule, *, num_weight_gradients
-) -> tuple[fx.GraphModule, fx.GraphModule]:
+    bw_gm_old: fx.GraphModule, *, num_weight_gradients
+) -> tuple[fx.GraphModule, fx.GraphModule, int]:
     # we could consider doing this is a non-mutating way
-    bw_gm = copy.deepcopy(bw_gm)
+    bw_gm = copy.deepcopy(bw_gm_old)
+    placeholders = bw_gm.graph.find_nodes(op="placeholder")
+    for p in placeholders:
+        if p.name.startswith("tangent"):
+            name_suffix = p.name[8:]
+            rename_placeholder_node(bw_gm, p, f"not_tngnt{name_suffix}")
+
     remove_recompute_tags(bw_gm)
     num_input_gradients = reorder_output_grads(bw_gm, num_weight_gradients)
     bw_gm.recompile()

-    args = [x.meta["val"] for x in bw_gm.graph.find_nodes(op="placeholder")]
+    args = list(bw_gm.graph.find_nodes(op="placeholder"))
+
+    # bw_inputs, bw_weights = default_partition(bw_gm, args, num_fwd_outputs=num_input_gradients)
+    # return bw_inputs, bw_weights, num_input_gradients
+
+    (
+        grad_inps,
+        grad_weights,
+        grad_inp_descs,
+        grad_weight_descs,
+    ) = _extract_fwd_bwd_outputs(bw_gm, num_fwd_outputs=num_input_gradients)
+    bw_inputs_gm = _extract_graph_with_inputs_outputs(
+        bw_gm.graph,
+        args,
+        grad_inps,
+        grad_inp_descs,
+        "forward",
+        ignore_must_be_in_fw_bw=True,
+    )
+    bw_inputs_gm_node_names = OrderedSet(
+        node.name for node in bw_inputs_gm.nodes if node.op != "output"
+    )
+    saved_values = []
+    saved_sym_nodes = []

-    bw_inputs, bw_weights = default_partition(
-        bw_gm, args, num_fwd_outputs=num_input_gradients
+    for node in bw_gm.graph.nodes:
+        if node.name not in bw_inputs_gm_node_names:
+            # Not handling mutations for now,
+            # we can try to re-use more of and/or consolidate with default partitioner
+            continue
+        if is_sym_node(node):
+            saved_sym_nodes.append(node)
+        elif (
+            "tensor_meta" not in node.meta
+            and node.op == "call_function"
+            and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
+        ):
+            users = node.users
+            assert all(user.target == operator.getitem for user in users)
+            saved_values.extend(users)
+        else:
+            backward_usages = [
+                n for n in node.users if n.name not in bw_inputs_gm_node_names
+            ]
+            if "tensor_meta" in node.meta and all(
+                is_sym_node(n) for n in backward_usages
+            ):
+                saved_sym_nodes.extend(backward_usages)
+            else:
+                saved_values.append(node)
+    saved_values = list(dict.fromkeys(saved_values).keys())
+    saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
+    bw_inputs, bw_weights = _extract_fwd_bwd_modules(
+        bw_gm,
+        saved_values,
+        saved_sym_nodes=saved_sym_nodes,
+        num_fwd_outputs=num_input_gradients,
     )
-    return bw_inputs, bw_weights
+    return bw_inputs, bw_weights, num_input_gradients
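
A minimal usage sketch of the updated pass (an illustration, not code from the commit): bw_module is assumed to be a full backward fx.GraphModule produced by the joint-graph partition, and num_weight_gradients the number of parameter/buffer gradients it emits. The dI graph's outputs are ordered as the input gradients followed by the saved activations that the dW graph consumes.

import torch.fx as fx

from autoparallel._passes.split_di_dw_graph import split_di_dw_graph


def split_backward(bw_module: fx.GraphModule, num_weight_gradients: int):
    # Updated return signature: dI graph, dW graph, and the count of input
    # gradients at the front of the dI graph's outputs.
    bw_dI, bw_dW, num_input_grads = split_di_dw_graph(
        bw_module, num_weight_gradients=num_weight_gradients
    )
    return bw_dI, bw_dW, num_input_grads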

autoparallel/api.py

Lines changed: 46 additions & 3 deletions

@@ -587,7 +587,9 @@ def forward(self, *args):


 class AutoParallelPP(AutoParallel):
-    def apply_placement_pp(self, sharding_placement=None) -> dict[str, Any]:
+    def apply_placement_pp(
+        self, sharding_placement=None, generate_di_dw_split_graphs=False
+    ) -> dict[str, Any]:
         sharded_param_dict, sharded_buffer_dict = self._apply_placement_common(
             sharding_placement
         )
@@ -629,19 +631,60 @@ def apply_placement_pp(self, sharding_placement=None) -> dict[str, Any]:
                 print_output=False, include_stride=True, include_device=True
             ),
         )
+        if generate_di_dw_split_graphs:
+            from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
+
+            num_weight_gradients = (
+                self.joint_with_descriptors._aot_state.aot_config.num_params_buffers
+            )
+            bw_dI_module, bw_dW_module, num_input_grads = split_di_dw_graph(
+                bw_module, num_weight_gradients=num_weight_gradients
+            )
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_bw_dI_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: bw_dI_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
+            )
+            trace_structured(
+                "artifact",
+                metadata_fn=lambda: {
+                    "name": "autoparallel_pp_bw_dW_graph",
+                    "encoding": "string",
+                },
+                payload_fn=lambda: bw_dW_module.print_readable(
+                    print_output=False, include_stride=True, include_device=True
+                ),
+            )
+            if all(
+                x is None
+                for x in bw_dI_module.graph.find_nodes(op="output")[0].args[0][
+                    :num_input_grads
+                ]
+            ):
+                raise RuntimeError(
+                    "attempted to run split dI/dW pass on a graph that has no input gradients"
+                )
+        else:
+            bw_dI_module, bw_dW_module, num_input_grads = None, None, -1

         graph_meta: dict[str, int] = {
             "num_mutate_inputs": num_mutate_inputs,
             "num_user_outputs": num_user_outputs,
             "num_symints_saved_for_bw": num_symints_saved_for_bw,
             "num_weight_buffer_grads": len(sharded_param_dict)
             + len(sharded_buffer_dict),
+            "num_input_grads": num_input_grads,
         }
         graph_modules: dict[str, Optional[torch.fx.GraphModule]] = {
             "fw": fw_module,
             "full_bw": bw_module,
-            "bw_dI": None,
-            "bw_dW": None,
+            "bw_dI": bw_dI_module,
+            "bw_dW": bw_dW_module,
             "unshard": None,
             "reduce_grad": None,
         }
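
A hedged sketch of how the new flag might be exercised; the AutoParallelPP setup and the key names of the returned dict are assumptions (the diff shows how graph_meta and graph_modules are built, not how they are returned):

# `autop` is assumed to be an AutoParallelPP instance with a sharding already chosen.
out = autop.apply_placement_pp(sharding_placement, generate_di_dw_split_graphs=True)
graph_modules = out["graph_modules"]  # hypothetical key name
graph_meta = out["graph_meta"]        # hypothetical key name
bw_dI = graph_modules["bw_dI"]        # populated only when the flag is True
bw_dW = graph_modules["bw_dW"]
num_input_grads = graph_meta["num_input_grads"]  # -1 when the split was skipped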

autoparallel/graph_pp_runner.py

Lines changed: 15 additions & 0 deletions

@@ -37,6 +37,7 @@ class GraphMeta:
     num_user_outputs: int
     num_symints_saved_for_bw: int
     num_weight_buffer_grads: int
+    num_input_grads: int


 class GraphPipelineStage(PipelineStage):
@@ -107,6 +108,20 @@ def _run_full_bw_module(
     return input_grads, param_buffer_grads


+def _run_split_bw_module(
+    bw_dI_gm: fx.GraphModule, bw_dW_gm: fx.GraphModule, graph_meta: GraphMeta, bw_args
+) -> tuple[Any, list[Any]]:
+    assert len([n for n in bw_dI_gm.graph.nodes if n.op == "placeholder"]) == len(
+        bw_args
+    ), "Mismatched number of inputs to bwd"
+    inp_grads_and_activations = torch.fx.Interpreter(bw_dI_gm).boxed_run(bw_args)
+    inp_grads, activations = inp_grads_and_activations[
+        : graph_meta.num_input_grads
+    ], list(inp_grads_and_activations[graph_meta.num_input_grads :])
+    weight_grads = torch.fx.Interpreter(bw_dW_gm).boxed_run(activations)
+    return inp_grads, weight_grads
+
+
 def _run_forward_microbatch(stage: GraphPipelineStage, *args) -> tuple[Any, Any]:
     fw_args = [
         *stage.state["unsharded_params"],
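
A hedged sketch of a caller for _run_split_bw_module; how bw_args is assembled and where the stage keeps its graphs are assumptions, not shown in this hunk. One relevant detail: fx.Interpreter.boxed_run clears the argument list it is handed so inputs can be freed eagerly, so a fresh list is built per microbatch.

# Assumed composition of the backward inputs for one microbatch: saved
# activations followed by the gradients received from the next stage.
bw_args = [*saved_activations, *received_output_grads]
inp_grads, weight_grads = _run_split_bw_module(
    stage.graph_modules["bw_dI"],  # hypothetical layout of the stage's graphs
    stage.graph_modules["bw_dW"],
    stage.graph_meta,
    bw_args,  # consumed (cleared) by boxed_run inside
)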
