
Commit bf4c477

kylesayrs and dsikka authored
[Autowrapper] Support Gemma3n, autowrapper improvements (#1693)
## Purpose ##

* Support Gemma3n and models which have similar implementations in their model definitions
* Resolves #1711
* Add better debugging messages when failures happen inside autowrapped code

```
  File "/home/kylesayrs/llm-compressor/src/llmcompressor/pipelines/sequential/ast_helpers.py", line 105, in append_autowrap_source_on_fail
    raise RuntimeError(message) from exception
RuntimeError: name 'explode_the_bomb' is not defined

--- <Autowrapped LlamaModel 140503462592384>:43 ---
...
@torch.fx.wrap
def wrapped_1(input_ids, inputs_embeds):
    if inputs_embeds is None:
        inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
    return (inputs_embeds,)

@torch.fx.wrap
def wrapped_0(input_ids, inputs_embeds):
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError('You must specify exactly one of input_ids or inputs_embeds')
    return ()

def forward(self, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, position_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Cache]=None, inputs_embeds: Optional[torch.FloatTensor]=None, cache_position: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, **kwargs: Unpack[TransformersKwargs]) -> BaseModelOutputWithPast:
    () = wrapped_0(input_ids, inputs_embeds)
    inputs_embeds, = wrapped_1(input_ids, inputs_embeds)
    past_key_values, = wrapped_2(past_key_values, use_cache)
    cache_position, past_seen_tokens = wrapped_3(cache_position, inputs_embeds, past_key_values)
    position_ids, = wrapped_4(cache_position, position_ids)
>   explode_the_bomb()
    causal_mask = wrapped_5(attention_mask, cache_position, inputs_embeds, past_key_values, position_ids)
    hidden_states = inputs_embeds
    position_embeddings = self.rotary_emb(hidden_states, position_ids)
    for decoder_layer in self.layers[:self.config.num_hidden_layers]:
        hidden_states = decoder_layer(hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs)
    hidden_states = self.norm(hidden_states)
    return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=past_key_values)
```

## Changes ##

* Handle source code which uses the walrus operator within if statements
  * Do not attempt to statically evaluate conditions containing walrus operators, as this is the only case where an if statement has a necessary side effect (the declaration of a new variable)
  * Account for walrus operators when tracking names, similar to #1616
* Keep function decorators which affect call behavior. Previously, all function decorators were removed from definitions; however, HF uses some function wrappers, such as `can_return_tuple`, which have very subtle effects on kwarg passing. These decorators must be kept
* Improve debugging for autowrapped code
  * Fix tracebacks, so failures within autowrapped code now report which line they failed on
  * Add the util function `append_autowrap_source_on_fail`, which appends the entire autowrapped source upon failure
* Add `project_per_layer_inputs` to the list of functions ignored during tracing

## Testing ##

* Added gemma3n text and vision tests
* Added a walrus operator test to the autowrapping tests

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
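To illustrate the walrus-operator point above: a condition containing `:=` binds a name that outlives the test, so the `if` statement cannot be constant-folded away even when its value is statically known. A minimal sketch using only the standard `ast` module (`compute` and `use` are hypothetical placeholder names inside the parsed string, never executed):

```python
import ast

source = """
if (x := compute()) is not None:
    use(x)
"""

test = ast.parse(source).body[0].test

# ast.NamedExpr is the walrus node. Even if `test` could be evaluated
# statically, deleting the `if` would also delete the binding of `x`,
# which later code may read; so the statement must be wrapped, not folded.
has_walrus = any(isinstance(n, ast.NamedExpr) for n in ast.walk(test))
print(has_walrus)  # True
```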
1 parent 2507b87 commit bf4c477

8 files changed (+112, −33 lines)

src/llmcompressor/args/dataset_arguments.py
Lines changed: 1 addition & 0 deletions

```diff
@@ -204,6 +204,7 @@ class DatasetArguments(CustomDatasetArguments):
             "_prepare_fsmt_decoder_inputs",
             "_prepare_4d_causal_attention_mask_with_cache_position",
             "_update_linear_attn_mask",
+            "project_per_layer_inputs",
         ],
         metadata={
             "help": "List of functions to ignore during tracing, either "
```
src/llmcompressor/args/utils.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -74,7 +74,7 @@ def parse_args(
 
     # raise depreciation warnings
     if dataset_args.remove_columns is not None:
-        logger.warn(
+        logger.warning(
             "`remove_columns` argument is depreciated. When tokenizing datasets, all "
             "columns which are invalid inputs the tokenizer will be removed",
             DeprecationWarning,
```
src/llmcompressor/pipelines/sequential/ast_helpers.py
Lines changed: 34 additions & 6 deletions

```diff
@@ -4,14 +4,15 @@
 import linecache
 import sys
 import textwrap
+import traceback
 from typing import List
 
 import torch
 
 from llmcompressor.pipelines.sequential.ast_utils.auto_wrapper import AutoWrapper
 from llmcompressor.utils import patch_attr
 
-__all__ = ["autowrap_forwards"]
+__all__ = ["autowrap_forwards", "append_autowrap_source_on_fail"]
 
 
 @contextlib.contextmanager
@@ -58,22 +59,49 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
     # autowrap untraceable code
     auto_wrapper = AutoWrapper(namespace, ignore)
     tree = auto_wrapper.auto_wrap(tree)
+    source = ast.unparse(tree)
 
     # compile new forward function from autowrapped code
-    filename = f"{module.__class__.__name__}_{hash(module)}_autowrapped"
-    code = compile(tree, filename=filename, mode="exec")
+    filename = f"<Autowrapped {module.__class__.__name__} {id(module)}>"
+    code = compile(source, filename=filename, mode="exec")
     exec(code, namespace)  # ensure ns of functions is the same ns as torch.fx.wrap
 
     # enable better tracebacks if autowrapped code fails
-    source_str = ast.unparse(tree)
     linecache.cache[filename] = (
-        len(source_str),
+        len(source),
         None,
-        [line + "\n" for line in source_str.splitlines()],
+        [line + "\n" for line in source.splitlines()],
         filename,
     )
 
     # patch forward with autowrapped forward
     new_forward = namespace["forward"].__get__(module)
     with patch_attr(module, "forward", new_forward):
         yield
+
+
+@contextlib.contextmanager
+def append_autowrap_source_on_fail():
+    try:
+        yield
+    except Exception as exception:
+        _exc_type, _exc_value, exc_tb = sys.exc_info()
+        tb_list = traceback.extract_tb(exc_tb)
+
+        for frame in reversed(tb_list):
+            if "Autowrapped" in frame.filename:
+                source_lines = linecache.getlines(frame.filename)
+                lineno = frame.lineno
+
+                # annotate failing line
+                source_lines = [
+                    ("> " if i + 1 == lineno else "  ") + line
+                    for i, line in enumerate(source_lines)
+                ]
+
+                message = f"{exception}\n\n"
+                message += f"\n--- {frame.filename}:{lineno} ---\n"
+                message += "".join(source_lines)
+                raise RuntimeError(message) from exception
+
+        raise exception
```
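The `linecache.cache` registration above is what makes the annotated tracebacks possible: Python's traceback machinery looks up source lines by filename through `linecache`, and entries whose mtime field is `None` survive `linecache.checkcache`, so `exec`-compiled code becomes debuggable once its source is cached under the synthetic filename. A self-contained sketch of the same pattern (the filename and function here are illustrative, not from the diff):

```python
import linecache
import traceback

filename = "<generated example>"
source = "def f():\n    raise ValueError('boom')\n"

# register the dynamic source so tracebacks can display it
linecache.cache[filename] = (
    len(source),
    None,  # mtime of None marks the entry as non-file-backed, so it is kept
    [line + "\n" for line in source.splitlines()],
    filename,
)

namespace = {}
exec(compile(source, filename, "exec"), namespace)
try:
    namespace["f"]()
except ValueError:
    traceback.print_exc()  # the frame now shows: raise ValueError('boom')
```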
src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
Lines changed: 13 additions & 3 deletions

```diff
@@ -53,7 +53,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
         :param node: function definition whose decorators will be stripped
         :return: function definition without decorators
         """
-        node.decorator_list = []
+        node.decorator_list = [
+            decorator_name
+            for decorator_name in node.decorator_list
+            if isinstance(decorator_name, ast.Name)
+            and decorator_name.id in ("can_return_tuple",)  # modifies func signature
+        ]
+
         if node.name == "forward":
             for arg in node.args.args:
                 self._local_names.add(arg.arg)
@@ -104,6 +110,11 @@ def visit_If(self, node: ast.If) -> Union[ast.If, ast.Assign]:
         try:
             value = bool(self._eval_expr(node.test))
 
+            # force a wrap if any assignments occur within the if statement
+            for expr in ast.walk(node):
+                if isinstance(expr, ast.NamedExpr):
+                    raise Exception("If statement contains assignment")
+
         except Exception:
             return self._wrap_if_possible(node)
 
@@ -165,8 +176,7 @@ def _can_wrap(self, node: ast.AST) -> bool:
         without its original context. In the future, we can add more checks for module
         calls (see `visit_If`)
         """
-        analyzer = ControlFlowAnalyzer()
-        return analyzer.is_valid(node)
+        return ControlFlowAnalyzer().is_valid(node)
 
     def _wrap_if_possible(self, node: ast.AST) -> Union[ast.AST, ast.Assign, ast.Call]:
         """
```
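The new `decorator_list` filter keeps only bare-name decorators on an explicit allowlist; decorator calls such as `@some_decorator_factory()` parse as `ast.Call` rather than `ast.Name` and are still dropped. A standalone sketch of the behavior (the decorator names in the source string are illustrative):

```python
import ast

source = """
@can_return_tuple
@some_decorator_factory()
def forward(self):
    pass
"""

func = ast.parse(source).body[0]
# keep only allowlisted bare names; ast.Call decorators fail the isinstance check
func.decorator_list = [
    d
    for d in func.decorator_list
    if isinstance(d, ast.Name) and d.id in ("can_return_tuple",)
]
print(ast.unparse(func))
# @can_return_tuple
# def forward(self):
#     pass
```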
src/llmcompressor/pipelines/sequential/ast_utils/name_analyzer.py
Lines changed: 7 additions & 0 deletions

```diff
@@ -74,6 +74,13 @@ def visit_Assign(self, node: ast.Assign):
         for target in node.targets:
             self.visit(target)
 
+    def visit_NamedExpr(self, node: ast.NamedExpr):
+        # Visit the right side of the assignment first
+        self.visit(node.value)
+
+        # Now visit the left side of the assignment
+        self.visit(node.target)
+
     def visit_If(self, node: ast.If):
         self.visit(node.test)
```
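The `visit_NamedExpr` override makes the name tracker see the value expression before the target is bound, matching Python's runtime order for `x := y + 1` (evaluate `y + 1`, then bind `x`); the default `generic_visit` would walk the node's fields in declaration order, target first. A toy visitor demonstrating the ordering (illustration only, not the project's `NameAnalyzer`):

```python
import ast

class NameOrder(ast.NodeVisitor):
    """Records the order in which names are encountered."""

    def __init__(self):
        self.order = []

    def visit_Name(self, node: ast.Name):
        self.order.append(node.id)

    def visit_NamedExpr(self, node: ast.NamedExpr):
        self.visit(node.value)   # right side first: `y` is read...
        self.visit(node.target)  # ...before `x` is bound

visitor = NameOrder()
visitor.visit(ast.parse("(x := y + 1)", mode="eval"))
print(visitor.order)  # ['y', 'x']
```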
src/llmcompressor/pipelines/sequential/helpers.py
Lines changed: 22 additions & 21 deletions

```diff
@@ -2,6 +2,7 @@
 import inspect
 from collections import deque
 from dataclasses import dataclass
+from types import FunctionType, MethodType
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple
 
 import torch
@@ -26,7 +27,7 @@
 from llmcompressor.utils.helpers import calibration_forward_context, patch_attr
 from llmcompressor.utils.pytorch.module import get_no_split_params
 
-from .ast_helpers import autowrap_forwards
+from .ast_helpers import append_autowrap_source_on_fail, autowrap_forwards
 
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -69,15 +70,8 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]:
 
         forward_fn = self._code.globals.get("forward")
 
-        try:
-            outputs = forward_fn(*args, **kwargs)
-        except Exception as exception:
-            raise RuntimeError(
-                "Raised an exception during execution of the following code:\n"
-                f"```\n{add_line_numbers(self._code.src)}\n```"
-            ) from exception
-
-        return outputs
+        with append_autowrap_source_on_fail():
+            return forward_fn(*args, **kwargs)
 
     def submodules(self, model: Module, recurse: bool = False) -> Set[Module]:
         nodes = self.graph.find_nodes(op="call_module")
@@ -126,19 +120,26 @@ def trace_subgraphs(
 
     # autowrap forwards
     stack.enter_context(autowrap_forwards(ancestors, ignore))
-    stack.enter_context(patch_attr(type(model), "forward", model.forward.__func__))
 
-    graph = GraphModule(
-        model,
-        tracer.trace(
+    # avoid bug where pytorch cannot handle wrapped root functions
+    unwrapped = inspect.unwrap(model.forward).__get__(model)
+    stack.enter_context(patch_attr(model, "forward", unwrapped))
+    stack.enter_context(patch_attr(type(model), "forward", unwrapped.__func__))
+    assert isinstance(model.forward, MethodType)
+    assert isinstance(type(model).forward, FunctionType)
+
+    with append_autowrap_source_on_fail():
+        graph = GraphModule(
             model,
-            dummy_inputs=sample_input,
-            concrete_args=concrete_args,
-            complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
-            # bug in trace throws an error for variadic
-            # args and kwargs in function signature
-        ),
-    )
+            tracer.trace(
+                model,
+                dummy_inputs=sample_input,
+                concrete_args=concrete_args,
+                complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
+                # bug in trace throws an error for variadic
+                # args and kwargs in function signature
+            ),
+        )
 
     # copy metadata
     graph.config = model.config
```
tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py
Lines changed: 25 additions & 2 deletions

```diff
@@ -1,3 +1,4 @@
+# flake8: noqa
 import ast
 import textwrap
 from types import SimpleNamespace
@@ -21,13 +22,14 @@ def check_wrapping(
 
     wrapped_lines = ast.unparse(wrapped).splitlines()
     output_lines = textwrap.dedent(output).splitlines()[1:]
+    lines = ("\n".join(wrapped_lines), "\n".join(output_lines))
 
-    assert len(wrapped_lines) == len(output_lines)
+    assert len(wrapped_lines) == len(output_lines), lines
     for wrapped_line, output_line in zip(wrapped_lines, output_lines):
         if "# skip" in output:
             continue
 
-        assert wrapped_line == output_line
+        assert wrapped_line == output_line, lines
 
 
 def test_static_if():
@@ -189,3 +191,24 @@ def forward(a, *b, c=5, **d):
         () = wrapped_0(a, b, c, d)
     """
     check_wrapping(source, output)
+
+
+def test_walrus():
+    """Checks for handling variadic names created via function def"""
+
+    source = """
+    def forward():
+        if (x := (1 + 2)):
+            pass
+    """
+    output = """
+    @torch.fx.wrap
+    def wrapped_0():
+        if (x := (1 + 2)):
+            pass
+        return (x,)
+
+    def forward():
+        (x,) = wrapped_0()  # skip: some envs use "(x,)" -> "x,"
+    """
+    check_wrapping(source, output)
```
tests/llmcompressor/transformers/tracing/test_models.py
Lines changed: 9 additions & 0 deletions

```diff
@@ -4,6 +4,7 @@
 from transformers import (
     AutoModelForCausalLM,
     Gemma3ForConditionalGeneration,
+    Gemma3nForConditionalGeneration,
     Idefics3ForConditionalGeneration,
     Llama4ForConditionalGeneration,
     LlavaForConditionalGeneration,
@@ -49,6 +50,7 @@
         "text",
         [],
     ),
+    ("google/gemma-3n-E2B-it", AutoModelForCausalLM, None, "text", ["timm"]),
     ("unsloth/DeepSeek-R1-0528-BF16", AutoModelForCausalLM, None, "text", []),
     # --- vision ---
     (
@@ -122,6 +124,13 @@
         "vision",
         [],
     ),
+    (
+        "google/gemma-3n-E2B-it",
+        Gemma3nForConditionalGeneration,
+        None,
+        "vision",
+        ["timm"],
+    ),
     # --- audio ---
     (
         "openai/whisper-large-v3",
```