-
Notifications
You must be signed in to change notification settings - Fork 9
Add unit test for existing ac API behavior #244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ailzhang
wants to merge
1
commit into
gh/ailzhang/3/base
Choose a base branch
from
gh/ailzhang/3/head
base: gh/ailzhang/3/base
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| """ | ||
| Tests for activation checkpointing functionality. | ||
| """ | ||
|
|
||
| import pytest | ||
| import torch | ||
| from torch.utils.checkpoint import CheckpointPolicy | ||
|
|
||
| from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs | ||
| from autoparallel.activation_checkpointing import _apply_ac_policy | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def llama3_model(): | ||
| """Create a small Llama3 model for testing.""" | ||
| torch.manual_seed(1999) | ||
| model_args = TransformerModelArgs( | ||
| dim=64, n_layers=2, n_heads=4, vocab_size=256, rope_theta=500000 | ||
| ) | ||
| return Transformer(model_args) | ||
|
|
||
|
|
||
| def create_joint_graph_from_model(model, input_args): | ||
| """Create a joint graph from a model for testing activation checkpointing functions.""" | ||
| from torch._subclasses.fake_tensor import FakeTensorMode | ||
| from torch.fx.experimental.proxy_tensor import make_fx | ||
|
|
||
| def simple_fwd_fn(*inputs): | ||
| return model(*inputs) | ||
|
|
||
| # Create fake tensor mode with consistent device handling | ||
| with FakeTensorMode(allow_non_fake_inputs=True) as fake_mode: | ||
| # Create fake inputs that match the input structure | ||
| fake_input_args = tuple(fake_mode.from_tensor(arg) for arg in input_args) | ||
|
|
||
| # Create a simple forward graph first | ||
| fwd_graph = make_fx(simple_fwd_fn)(*fake_input_args) | ||
|
|
||
| # Create a mock joint graph with forward and backward sections | ||
| joint_graph = torch.fx.Graph() | ||
|
|
||
| # Copy forward nodes | ||
| value_remap = {} | ||
| for node in fwd_graph.graph.nodes: | ||
| if node.op == "placeholder": | ||
| new_node = joint_graph.placeholder(node.target) | ||
| new_node.meta.update(node.meta) | ||
| value_remap[node] = new_node | ||
| elif node.op == "call_function": | ||
| new_args = tuple(value_remap.get(arg, arg) for arg in node.args) | ||
| new_node = joint_graph.call_function(node.target, new_args, node.kwargs) | ||
| new_node.meta.update(node.meta) | ||
| value_remap[node] = new_node | ||
| elif node.op == "output": | ||
| # Add backward nodes just manually for testing purpose(marked as backward) | ||
| output_node = value_remap[node.args[0]] | ||
|
|
||
| # Add a sum operation for loss | ||
| sum_node = joint_graph.call_function( | ||
| torch.ops.aten.sum.default, (output_node,) | ||
| ) | ||
| sum_node.meta["val"] = torch.tensor(1.0) | ||
|
|
||
| # Add backward nodes | ||
| bw_node = joint_graph.call_function( | ||
| torch.ops.aten.mul.Tensor, (sum_node, 1.0) | ||
| ) | ||
| bw_node.meta["partitioner_tag"] = "is_backward" | ||
| bw_node.meta["val"] = torch.tensor(1.0) | ||
|
|
||
| # Add tangent placeholder | ||
| tangent_node = joint_graph.placeholder("tangents_1") | ||
| tangent_node.meta["val"] = output_node.meta.get( | ||
| "val", torch.randn(2, 8, 64) | ||
| ) | ||
|
|
||
| # Create output | ||
| joint_graph.output([output_node, bw_node]) | ||
| break | ||
|
|
||
| return joint_graph | ||
|
|
||
|
|
||
| def create_joint_graph_llama3(llama3_model): | ||
| """Create a joint graph from Llama3 model.""" | ||
| batch_size = 2 | ||
| seq_len = 8 | ||
| vocab_size = llama3_model.model_args.vocab_size | ||
|
|
||
| input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long) | ||
| return create_joint_graph_from_model(llama3_model, (input_ids,)) | ||
|
|
||
|
|
||
| class TestACPolicy: | ||
| """Test AC policy application.""" | ||
|
|
||
| def test_apply_ac_policy(self, llama3_model): | ||
| """Test _apply_ac_policy function.""" | ||
| graph = create_joint_graph_llama3(llama3_model) | ||
|
|
||
| # Define save list with operations that might be in the graph | ||
| save_list = { | ||
| torch.ops.aten.mm.default, | ||
| torch.ops.aten.addmm.default, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think this always gets decomposed away. and if it doesn't, it will mess up your saved node count lol |
||
| } | ||
|
|
||
| _apply_ac_policy(graph, save_list) | ||
|
|
||
| marked_nodes_to_save = [ | ||
| node | ||
| for node in graph.nodes | ||
| if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE | ||
| ] | ||
|
|
||
| # Count total mm.default nodes in the graph to verify every-other-node policy | ||
| total_mm_nodes = len( | ||
| [node for node in graph.nodes if node.target == torch.ops.aten.mm.default] | ||
| ) | ||
|
|
||
| # The policy should save every other mm.default node | ||
| expected_saved_nodes = ( | ||
| total_mm_nodes + 1 | ||
| ) // 2 # ceiling division for odd counts | ||
|
|
||
| # Verify the every-other-node policy is working correctly | ||
| assert ( | ||
| len(marked_nodes_to_save) == expected_saved_nodes | ||
| ), f"Expected {expected_saved_nodes} nodes to be saved, but got {len(marked_nodes_to_save)}" | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this implementation is strange, any issues just using the same joint graph capture frontend as the rest of the repo?
autoparallel/autoparallel/api.py
Lines 301 to 310 in b1c4909