 # LICENSE file in the root directory of this source tree.


+import functools
+
 import torch
 from torch import nn
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor.placement_types import Replicate, Shard
 from torch.testing._internal.distributed.fake_pg import FakeStore
+from torch.utils.checkpoint import create_selective_checkpoint_contexts

 from autoparallel.api import AutoParallel


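+# Selective activation checkpointing (SAC) policy: keep the attention outputs
+# from the checkpointed region and recompute everything else in backward.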
+def policy_fn(ctx, op, *args, **kwargs):
+    if (
+        op == torch.ops.aten._scaled_dot_product_flash_attention.default
+        or op == torch.ops.aten._scaled_dot_product_efficient_attention.default
+    ):
+        # NOTE: we can't save nondeterministic_seeded ops; the run-with-rng wrapper is not traceable yet
+        return torch.utils.checkpoint.CheckpointPolicy.PREFER_SAVE
+    return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE
+
+
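+# create_selective_checkpoint_contexts builds the context-manager pair that
+# torch.utils.checkpoint uses to apply policy_fn during forward and recompute.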
+context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
+
+
 class Block(nn.Module):
     def __init__(self, nheads, dim1, dim2):
         super().__init__()
@@ -48,7 +64,7 @@ def _compute_attention(self, x):

     def forward(self, x):
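+        # context_fn applies the SAC policy above inside this checkpointed region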
         o = torch.utils.checkpoint.checkpoint(
-            self._compute_attention, x, use_reentrant=False
+            self._compute_attention, x, use_reentrant=False, context_fn=context_fn
         )

         o0 = o + x
@@ -103,7 +119,6 @@ def input_fn():

 mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
 # mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
-# mp_policy = None

 with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
     assert any(n.meta.get("nn_module_stack") for n in autop.gm.graph.nodes)
@@ -128,16 +143,34 @@ def input_fn():
 out = parallel_mod(*x)
 out.backward(torch.randn_like(out))

-print("All good!")
+# Validate that the recompute tags recorded in the traced graph match policy_fn
+seqs = set()
+for n in autop.gm.graph.nodes:
+    if "checkpoint" in n.meta.get(
+        "stack_trace", ""
+    ):  # placeholders don't have a stack trace
+        is_bwd = n.meta.get("partitioner_tag", "") == "is_backward"
+        if not is_bwd:
+            if "getitem" in str(n.target):
+                # getitem nodes are tagged the same as their parent
+                expected = policy_fn(None, n.args[0].target, (), ())
+            else:
+                expected = policy_fn(None, n.target, (), ())
+            actual = n.meta.get("recompute")
+            # NOTE: this assert only supports policy_fns that depend on the op alone
+            assert actual == expected
+            seqs.add(n.meta["seq_nr"])
+        else:
+            # the forward counterpart should already have populated seqs
+            assert n.meta["seq_nr"] in seqs

 mm_nodes = autop.gm.graph.find_nodes(
     op="call_function", target=torch.ops.aten.mm.default
 )

-# assert (
-#     mm_nodes[0].meta.get("recompute")
-#     == torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE
-# )
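+# mm ops don't match the SDPA check in policy_fn, so they fall through to
+# PREFER_RECOMPUTE; the first mm node's recompute tag should reflect that.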
+assert (
+    mm_nodes[0].meta.get("recompute")
+    == torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE
+)

-# TODO: change this assert once we fix AC
-assert mm_nodes[0].meta.get("recompute") is None
+print("All good!")