Commit bf39515
Migrate to strict mode export (dynamo) to support AC tags and HOPs (#93)
* Directly use _export_to_torch_ir strict_mode to support AC tags
* clean
* rebase
* lint
* works with pytorch wip prs
* monkey patch strict mode verifier to accept dtype cast ops
* mp fp32
* comments on verifier patch
1 parent: 2cbfdc4

2 files changed: +84, -11 lines
autoparallel/api.py

Lines changed: 42 additions & 2 deletions
@@ -5,7 +5,8 @@

 import copy
 import itertools
-from contextlib import ExitStack
+import warnings
+from contextlib import ExitStack, contextmanager
 from types import MethodType
 from typing import Optional, Union

@@ -106,6 +107,42 @@ def _move_to_fake(module, k, device, parameter=True):
     return model


+# Export runs some asserts on the exported program to ensure that it is serializable,
+# and some safety checks e.g. whether the graph metadata is consistent with what's been traced.
+#
+# In autoparallel, we don't care about the serializability of this initial
+# trace, but we do want those same safety checks. In the short term, we
+# can patch the verification logic.
+@contextmanager
+def monkey_patch_export_verifier():
+    from torch._export.verifier import SpecViolationError, Verifier, final
+
+    prior = Verifier._check_graph_module
+
+    def expected_error(e: Exception):
+        okay = ["Operator 'autoparallel.dtype_cast' is not an allowed operator type"]
+        e_str = str(e)
+        for msg in okay:
+            if msg in e_str:
+                return True
+        return False
+
+    @final
+    def _try_check_graph_module(self: Verifier, gm: torch.fx.GraphModule) -> None:
+        try:
+            return prior(self, gm)
+        except SpecViolationError as e:
+            if not expected_error(e):
+                raise
+            warnings.warn(f"Ignoring strict-mode export verifier error: {e}")
+
+    try:
+        Verifier._check_graph_module = _try_check_graph_module
+        yield
+    finally:
+        Verifier._check_graph_module = prior
+
+
 class AutoParallel:
     """
     Args:
@@ -220,7 +257,10 @@ def build_model_graph(self):
            inputs = (inputs,)

        with set_dtype_cast(True):
-            ep = torch.export.export(self.model, inputs)
+            with torch._dynamo.config.patch(
+                install_free_tensors=True
+            ), monkey_patch_export_verifier():
+                ep = torch.export.export(self.model, inputs, strict=True)
        self.joint_with_descriptors = aot_export_joint_with_descriptors(
            self.stack,
            ep.module(),
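
For readers unfamiliar with the idiom used in monkey_patch_export_verifier above, here is a minimal, self-contained sketch of the same contextmanager-based monkey-patch pattern: save the original method, swap in a wrapper that downgrades a known, expected failure to a warning, and always restore the original on exit. The Checker class and its check method are hypothetical stand-ins for illustration; they are not part of this commit.

import warnings
from contextlib import contextmanager


class Checker:
    # Hypothetical stand-in for the patched verifier class.
    def check(self, value):
        if value < 0:
            raise ValueError("negative value")


@contextmanager
def tolerate_negative_values():
    prior = Checker.check  # save the original method

    def lenient_check(self, value):
        try:
            return prior(self, value)
        except ValueError as e:
            # Downgrade the known, expected failure to a warning.
            warnings.warn(f"Ignoring checker error: {e}")

    try:
        Checker.check = lenient_check  # install the patched method
        yield
    finally:
        Checker.check = prior  # always restore, even if the body raised


with tolerate_negative_values():
    Checker().check(-1)  # warns instead of raising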

examples/example_autoparallel.py

Lines changed: 42 additions & 9 deletions
@@ -4,15 +4,31 @@
 # LICENSE file in the root directory of this source tree.


+import functools
+
 import torch
 from torch import nn
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor.placement_types import Replicate, Shard
 from torch.testing._internal.distributed.fake_pg import FakeStore
+from torch.utils.checkpoint import create_selective_checkpoint_contexts

 from autoparallel.api import AutoParallel


+def policy_fn(ctx, op, *args, **kwargs):
+    if (
+        op == torch.ops.aten._scaled_dot_product_flash_attention.default
+        or op == torch.ops.aten._scaled_dot_product_efficient_attention.default
+    ):
+        # NOTE: we can't save nondeterministic_seeded ops, the run with rng wrapper is not traceable yet
+        return torch.utils.checkpoint.CheckpointPolicy.PREFER_SAVE
+    return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE
+
+
+context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
+
+
 class Block(nn.Module):
     def __init__(self, nheads, dim1, dim2):
         super().__init__()
@@ -48,7 +64,7 @@ def _compute_attention(self, x):

     def forward(self, x):
         o = torch.utils.checkpoint.checkpoint(
-            self._compute_attention, x, use_reentrant=False
+            self._compute_attention, x, use_reentrant=False, context_fn=context_fn
         )

         o0 = o + x
@@ -103,7 +119,6 @@ def input_fn():

 mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
 # mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
-# mp_policy = None

 with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
     assert any(n.meta.get("nn_module_stack") for n in autop.gm.graph.nodes)
@@ -128,16 +143,34 @@ def input_fn():
 out = parallel_mod(*x)
 out.backward(torch.randn_like(out))

-print("All good!")
+# Validate
+seqs = set()
+for n in autop.gm.graph.nodes:
+    if "checkpoint" in n.meta.get(
+        "stack_trace", ""
+    ):  # placeholders don't have stack trace
+        is_bwd = n.meta.get("partitioner_tag", "") == "is_backward"
+        if not is_bwd:
+            if "getitem" in str(n.target):
+                # getitem nodes are tagged same as their parent
+                expected = policy_fn(None, n.args[0].target, (), ())
+            else:
+                expected = policy_fn(None, n.target, (), ())
+            actual = n.meta.get("recompute")
+            # NOTE: this assert only supports policy_fns on op alone
+            assert actual == expected
+            seqs.add(n.meta["seq_nr"])
+        else:
+            # fwd counterpart should have already populated seqs
+            assert n.meta["seq_nr"] in seqs

 mm_nodes = autop.gm.graph.find_nodes(
     op="call_function", target=torch.ops.aten.mm.default
 )

-# assert (
-#     mm_nodes[0].meta.get("recompute")
-#     == torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE
-# )
+assert (
+    mm_nodes[0].meta.get("recompute")
+    == torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE
+)

-# TODO: change this assert once we fix AC
-assert mm_nodes[0].meta.get("recompute") is None
+print("All good!")
