
Commit ff637c1

Support fx_traceback.annotate w/ local_map and add tests (#203)
stack-info: PR: #203, branch: xmfan/stack/13
1 parent 5ab4e0b commit ff637c1
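
For context: fx_traceback.annotate(...) is a context manager that tags every FX node traced inside it, and the tags end up in node.meta["custom"] when tracing runs under fx_traceback.preserve_node_meta(). The sketch below is not part of this commit; it only illustrates the mechanism, with make_fx standing in for the AOT tracing that AutoParallel performs internally. Whether make_fx propagates the annotation exactly this way in a given PyTorch build is an assumption.

# Minimal sketch, assuming annotate() metadata propagates to node.meta["custom"]
# when tracing under preserve_node_meta(); make_fx is a stand-in, not what this
# commit uses directly.
import torch
import torch.fx.traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import make_fx


def fn(x):
    with fx_traceback.annotate({"inside_local_map": 1}):
        y = torch.matmul(x, x.t())  # nodes traced here should carry the tag
    return torch.relu(y)            # nodes traced here should not


with fx_traceback.preserve_node_meta():
    gm = make_fx(fn)(torch.randn(4, 4))

for node in gm.graph.nodes:
    print(node.name, node.meta.get("custom"))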

File tree

3 files changed: +247 -29 lines changed

  autoparallel/activation_checkpointing.py
  examples/example_local_map.py
  tests/test_api.py


autoparallel/activation_checkpointing.py

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ def _mark_nodes_as_must_save(must_save_nodes: list[torch.fx.Node]) -> None:
     for node in must_save_nodes:
         if (
             node.meta.get("recompute", None) is not None
-            and node.meta["ac_graph_id"] != AP_AC_GRAPH_ID
+            and node.meta.get("ac_graph_id", -1) != AP_AC_GRAPH_ID
         ):
            # Let user annotations take precedence
            skipped_nodes[node] = node.meta["recompute"]
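
The change swaps direct indexing for a defaulted lookup: a node can carry a user-set "recompute" tag without ever having been assigned an "ac_graph_id", and the old code raised KeyError on such nodes. A toy illustration of the guard follows; the values are hypothetical and AP_AC_GRAPH_ID stands in for the real module-level constant of the same name, which is presumably chosen to differ from the -1 default.

# Hypothetical values, only to show why .get() with a default matters here.
AP_AC_GRAPH_ID = 0  # stand-in for the real module-level constant

meta = {"recompute": "user-annotation"}  # no "ac_graph_id" key at all

# Old check: meta["ac_graph_id"] would raise KeyError for this node.
# New check: the missing key defaults to -1, so the comparison still runs and
# the node is treated as user-annotated; its recompute tag takes precedence.
if meta.get("recompute", None) is not None and meta.get("ac_graph_id", -1) != AP_AC_GRAPH_ID:
    print("skip: user annotation takes precedence")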

examples/example_local_map.py

Lines changed: 54 additions & 27 deletions
@@ -6,6 +6,7 @@
 import functools
 
 import torch
+import torch.fx.traceback as fx_traceback
 from torch import nn
 from torch.distributed._tensor.experimental import local_map
 from torch.distributed.fsdp import MixedPrecisionPolicy
@@ -57,7 +58,8 @@ def policy_fn(ctx, op, *args, **kwargs):
     device_mesh=mesh,
 )
 def replicate_linear(w, x):
-    return torch.matmul(x, w.t())
+    with fx_traceback.annotate({"inside_local_map": 1}):
+        return torch.matmul(x, w.t())
 
 
 @local_map(
@@ -68,7 +70,8 @@ def replicate_linear(w, x):
     device_mesh=mesh,
 )
 def sharded_pointwise(x):
-    return x + 10
+    with fx_traceback.annotate({"inside_local_map": 0}):
+        return x + 10
 
 
 @local_map(
@@ -83,10 +86,11 @@ def sharded_pointwise(x):
     device_mesh=mesh,
 )
 def context_parallel_attention(query, key, value):
-    out = nn.functional.scaled_dot_product_attention(
-        query=query, key=key, value=value, is_causal=False
-    )
-    return out
+    with fx_traceback.annotate({"inside_local_map": 2}):
+        out = nn.functional.scaled_dot_product_attention(
+            query=query, key=key, value=value, is_causal=False
+        )
+        return out
 
 
 class Block(nn.Module):
@@ -108,35 +112,37 @@ def init_weights(self):
         torch.nn.init.normal_(lin.bias)
 
     def _compute_attention(self, x):
-        boosted_weight = sharded_pointwise(self.wq.weight)
-        q = replicate_linear(boosted_weight, x)
-        k = self.wk(x)
-        v = self.wv(x)
+        with fx_traceback.annotate({"inside_checkpoint": 0}):
+            boosted_weight = sharded_pointwise(self.wq.weight)
+            q = replicate_linear(boosted_weight, x)
+            k = self.wk(x)
+            v = self.wv(x)
 
-        q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
-        k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
-        v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
+            q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
+            k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
+            v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
 
-        o = context_parallel_attention(q, k, v)
-        o = o.permute(0, 2, 1, 3).flatten(-2)
+            o = context_parallel_attention(q, k, v)
+            o = o.permute(0, 2, 1, 3).flatten(-2)
 
-        o = self.wo(o)
-        return o
+            o = self.wo(o)
+            return o
 
     def forward(self, x):
-        o = torch.utils.checkpoint.checkpoint(
-            self._compute_attention, x, use_reentrant=False, context_fn=context_fn
-        )
+        with fx_traceback.annotate({"outside_checkpoint": 0}):
+            o = torch.utils.checkpoint.checkpoint(
+                self._compute_attention, x, use_reentrant=False, context_fn=context_fn
+            )
 
-        o0 = o + x
+            o0 = o + x
 
-        o = self.w1(o0)
-        o = torch.nn.functional.relu(o)
-        o = self.w2(o)
+            o = self.w1(o0)
+            o = torch.nn.functional.relu(o)
+            o = self.w2(o)
 
-        o = o0 + o
+            o = o0 + o
 
-        return o
+            return o
 
 
 bs = 8 * mesh.shape[0]
@@ -160,7 +166,9 @@ def input_fn():
 mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
 # mp_policy = None
 
-with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
+with torch.fx.traceback.preserve_node_meta(), AutoParallel(
+    model, input_fn, mesh, mp_policy, compile=True
+) as autop:
     assert any(n.meta.get("nn_module_stack") for n in autop.gm.graph.nodes)
     assert any(n.meta.get("fwd_nn_module_stack") for n in autop.gm.graph.nodes)
     autop.add_parameter_memory_constraint(low=None, high=None)
@@ -208,4 +216,23 @@ def input_fn():
         op="call_function", target=torch.ops.aten.mm.default
     )
 
+    metas = [n.meta.get("custom", None) for n in autop.parallel_gm.graph.nodes]
+    fwd_sdpa, bwd_sdpa = [
+        n
+        for n in autop.parallel_gm.graph.nodes
+        if "_scaled_dot_product_flash_attention" in n.name
+    ]
+    # TODO: Dynamo HOP body is not preserving the fx_traceback.annotate
+    # We should expect to also see the "inside_local_map" annotation
+    assert fwd_sdpa.meta["custom"] == {
+        "inside_checkpoint": 0,
+        "inside_local_map": 2,
+        "outside_checkpoint": 0,
+    }
+    assert bwd_sdpa.meta["custom"] == {
+        "inside_checkpoint": 0,
+        "inside_local_map": 2,
+        "outside_checkpoint": 0,
+    }
+
 print("All good!")

tests/test_api.py

Lines changed: 192 additions & 1 deletion
@@ -5,8 +5,9 @@
 
 import pytest
 import torch
+import torch.fx.traceback as fx_traceback
 from torch import nn
-from torch.distributed.tensor.placement_types import Shard
+from torch.distributed.tensor.placement_types import Replicate, Shard
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
 from autoparallel.api import AutoParallel
@@ -114,3 +115,193 @@ def input_fn():
     assert torch.equal(
         parallel_mod.get_buffer("buf").full_tensor(), torch.arange(dim, device="cuda")
     )
+
+
+def test_fx_graph_annotate(device_mesh_1d):
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.a = nn.Linear(dim, dim, bias=False)
+            self.b = nn.Linear(dim, dim, bias=False)
+            self.c = nn.Linear(dim, dim, bias=False)
+            self.d = nn.Linear(dim, dim, bias=False)
+
+        def forward(self, x):
+            with fx_traceback.annotate({"outer": 0}):
+                with fx_traceback.annotate({"inner": 0}):
+                    a = self.a(x)
+                with fx_traceback.annotate({"inner": 1}):
+                    b = self.b(a)
+                with fx_traceback.annotate({"inner": 2}):
+                    c = self.c(b)
+                with fx_traceback.annotate({"inner": 3}):
+                    d = self.d(c)
+            return d
+
+    def input_fn():
+        b = 512
+        inputs = (torch.rand(b, dim, device="cuda"),)
+        return inputs
+
+    with torch.device("meta"):
+        model = Model(dim)
+
+    with fx_traceback.preserve_node_meta(), AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+    ) as autop:
+        x_sharding = (Shard(0),)
+        autop.add_input_constraints([x_sharding])
+        sharding_placement = autop.optimize_placement()
+
+        # AutoParallel produces a module with meta-DTensor parameters that need to be initialized
+        _ = autop.apply_placement(sharding_placement)
+
+    graph = autop.parallel_gm.graph
+
+    # 4 linear -> 4 mm ops
+    fw_seen_annotations = set()
+    bw_seen_annotations = set()
+    for mm in [n for n in graph.nodes if "mm" in n.name]:
+        assert mm.meta["custom"]["outer"] == 0
+        assert "inner" in mm.meta["custom"]
+        if mm.meta.get("partitioner_tag", "") == "is_backward":
+            bw_seen_annotations.add(mm.meta["custom"]["inner"])
+        else:
+            fw_seen_annotations.add(mm.meta["custom"]["inner"])
+    assert fw_seen_annotations == bw_seen_annotations == {0, 1, 2, 3}
+
+    for ph in graph.find_nodes(op="placeholder"):
+        assert (
+            "custom" not in ph.meta
+        ), "Placeholders didn't have custom metadata before"
+    for out in graph.find_nodes(op="output"):
+        assert (
+            "custom" not in out.meta
+        ), "Output didn't have custom metadata before"
+
+    # NOTE: The tests below are just to prevent semantics from changing silently.
+    # Currently, custom metadata is not set for:
+    # - graph inputs
+    # - graph outputs
+    # - collectives/waits added by AP
+    for node in graph.nodes:
+        if node.meta.get("custom", None) is None:
+            assert (
+                node.op == "placeholder"
+                or node.op == "output"
+                or node.target.namespace == "_c10d_functional"
+            )
+
+
+def test_fx_graph_annotate_overlap_pass(device_mesh_1d):
+    class DummyOp(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x, scalar):
+            ctx.save_for_backward(x)
+            return x + scalar
+
+        @staticmethod
+        def backward(ctx, grad_out):
+            return grad_out, None
+
+    def mock_fw_compute(x):
+        with fx_traceback.annotate({"compute": 0}):
+            return DummyOp.apply(x, 10)
+
+    def mock_bw_comm(x):
+        with fx_traceback.annotate({"comm": 0}):
+            return DummyOp.apply(x, 20)
+
+    def mock_bw_compute(x):
+        return DummyOp.apply(x, 30)
+
+    class Model(nn.Module):
+        def forward(self, fw_in, bw_in):
+            fw_out = mock_fw_compute(fw_in)
+            # bw_in blocks bw_out
+            bw_in = mock_bw_comm(bw_in)
+            bw_out = mock_bw_compute(bw_in)
+            return fw_out, bw_out
+
+    def input_fn():
+        inputs = (torch.rand(2, 128, device="cuda", requires_grad=True),)
+        grad_ins = (torch.rand(2, 128, device="cuda"),)
+        return (
+            *inputs,
+            *grad_ins,
+        )
+
+    with torch.device("meta"):
+        model = Model()
+
+    with fx_traceback.preserve_node_meta(), AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+    ) as autop:
+        autop.add_input_constraints(
+            [
+                (Replicate(),),
+                (Replicate(),),
+            ]
+        )
+        autop.add_output_constraints(
+            [
+                (Replicate(),),
+                (Replicate(),),
+            ]
+        )
+        sharding_placement = autop.optimize_placement()
+
+        # AutoParallel produces a module with meta-DTensor parameters that need to be initialized
+        _ = autop.apply_placement(sharding_placement)
+
+    graph = autop.parallel_gm.graph
+
+    # At this point, the graph looks like:
+    # graph():
+    #     %primals_1 : [num_users=1] = placeholder[target=primals_1]
+    #     %primals_2 : [num_users=1] = placeholder[target=primals_2]
+    #     %tangents_1 : [num_users=1] = placeholder[target=tangents_1]
+    #     %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_1, 10), kwargs = {})
+    #     %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_2, 20), kwargs = {})
+    #     %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 30), kwargs = {})
+    #     return ((add, add_2), (tangents_1, None))
+
+    compute_nodes = {
+        n for n in graph.nodes if n.meta.get("custom", {}).get("compute", None) == 0
+    }
+    comm_nodes = [
+        n for n in graph.nodes if n.meta.get("custom", {}).get("comm", None) == 0
+    ]
+    assert len(compute_nodes) == 1
+    assert len(comm_nodes) == 1
+
+    # move comm nodes before compute nodes
+    first_compute_node = None
+    for n in graph.nodes:
+        if n in compute_nodes:
+            first_compute_node = n
+            break
+
+    assert first_compute_node is not None
+    for node in reversed(comm_nodes):
+        first_compute_node.prepend(node)
+
+    # After pass, add_1 (comm) should be before add (compute)
+    node_names = [n.name for n in graph.nodes]
+    assert node_names.index("add_1") == node_names.index("add") - 1
+
+    # The graph looks like:
+    # graph():
+    #     %primals_1 : [num_users=1] = placeholder[target=primals_1]
+    #     %primals_2 : [num_users=1] = placeholder[target=primals_2]
+    #     %tangents_1 : [num_users=1] = placeholder[target=tangents_1]
+    #     %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_2, 20), kwargs = {})
+    #     %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_1, 10), kwargs = {})
+    #     %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 30), kwargs = {})
+    #     return ((add, add_2), (tangents_1, None))
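
The reordering at the end of the overlap test relies on torch.fx.Node.prepend(x), which moves x so it sits immediately before the node it is called on. Below is a small self-contained demo of the same manipulation on a plain symbolically traced graph; it is independent of AutoParallel, and the node-matching logic is ad hoc, purely for the demo.

# Standalone demo of Node.prepend-based reordering, independent of AutoParallel.
import operator

import torch
from torch.fx import symbolic_trace


def fn(a, b):
    compute = a + 10  # plays the role of the "compute" node
    comm = b + 20     # plays the role of the "comm" node
    return compute, comm + 30


gm = symbolic_trace(fn)
adds = [
    n for n in gm.graph.nodes if n.op == "call_function" and n.target is operator.add
]
compute_node = next(n for n in adds if n.args[1] == 10)
comm_node = next(n for n in adds if n.args[1] == 20)

# Move the comm node so it executes before the compute node.
compute_node.prepend(comm_node)
gm.graph.lint()
gm.recompile()
print([n.name for n in gm.graph.nodes])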
