Commit a803a7a

Add unit test for existing ac behavior
ghstack-source-id: 6bb849b
Pull-Request: #244
1 parent b1c4909 commit a803a7a

File tree: 1 file changed, +133 -0 lines

@@ -0,0 +1,133 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Tests for activation checkpointing functionality.
"""

import pytest
import torch
from torch.utils.checkpoint import CheckpointPolicy

from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
from autoparallel.activation_checkpointing import _apply_ac_policy


@pytest.fixture(scope="module")
def llama3_model():
    """Create a small Llama3 model for testing."""
    torch.manual_seed(1999)
    model_args = TransformerModelArgs(
        dim=64, n_layers=2, n_heads=4, vocab_size=256, rope_theta=500000
    )
    return Transformer(model_args)


def create_joint_graph_from_model(model, input_args):
    """Create a joint graph from a model for testing activation checkpointing functions."""
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.proxy_tensor import make_fx

    def simple_fwd_fn(*inputs):
        return model(*inputs)

    # Create fake tensor mode with consistent device handling
    with FakeTensorMode(allow_non_fake_inputs=True) as fake_mode:
        # Create fake inputs that match the input structure
        fake_input_args = tuple(fake_mode.from_tensor(arg) for arg in input_args)

        # Create a simple forward graph first
        fwd_graph = make_fx(simple_fwd_fn)(*fake_input_args)

        # Create a mock joint graph with forward and backward sections
        joint_graph = torch.fx.Graph()

        # Copy forward nodes
        value_remap = {}
        for node in fwd_graph.graph.nodes:
            if node.op == "placeholder":
                new_node = joint_graph.placeholder(node.target)
                new_node.meta.update(node.meta)
                value_remap[node] = new_node
            elif node.op == "call_function":
                new_args = tuple(value_remap.get(arg, arg) for arg in node.args)
                new_node = joint_graph.call_function(node.target, new_args, node.kwargs)
                new_node.meta.update(node.meta)
                value_remap[node] = new_node
            elif node.op == "output":
                # Manually add backward nodes for testing purposes (marked as backward)
                output_node = value_remap[node.args[0]]

                # Add a sum operation for the loss
                sum_node = joint_graph.call_function(
                    torch.ops.aten.sum.default, (output_node,)
                )
                sum_node.meta["val"] = torch.tensor(1.0)

                # Add backward nodes
                bw_node = joint_graph.call_function(
                    torch.ops.aten.mul.Tensor, (sum_node, 1.0)
                )
                bw_node.meta["partitioner_tag"] = "is_backward"
                bw_node.meta["val"] = torch.tensor(1.0)

                # Add tangent placeholder
                tangent_node = joint_graph.placeholder("tangents_1")
                tangent_node.meta["val"] = output_node.meta.get(
                    "val", torch.randn(2, 8, 64)
                )

                # Create output
                joint_graph.output([output_node, bw_node])
                break

    return joint_graph


def create_joint_graph_llama3(llama3_model):
    """Create a joint graph from the Llama3 model."""
    batch_size = 2
    seq_len = 8
    vocab_size = llama3_model.model_args.vocab_size

    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long)
    return create_joint_graph_from_model(llama3_model, (input_ids,))


class TestACPolicy:
    """Test AC policy application."""

    def test_apply_ac_policy(self, llama3_model):
        """Test the _apply_ac_policy function."""
        graph = create_joint_graph_llama3(llama3_model)

        # Define a save list with operations that might be in the graph
        save_list = {
            torch.ops.aten.mm.default,
            torch.ops.aten.addmm.default,
        }

        _apply_ac_policy(graph, save_list)

        marked_nodes_to_save = [
            node
            for node in graph.nodes
            if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE
        ]

        # Count total mm.default nodes in the graph to verify the every-other-node policy
        total_mm_nodes = len(
            [node for node in graph.nodes if node.target == torch.ops.aten.mm.default]
        )

        # The policy should save every other mm.default node
        expected_saved_nodes = (
            total_mm_nodes + 1
        ) // 2  # ceiling division for odd counts

        # Verify the every-other-node policy is working correctly
        assert (
            len(marked_nodes_to_save) == expected_saved_nodes
        ), f"Expected {expected_saved_nodes} nodes to be saved, but got {len(marked_nodes_to_save)}"
