Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions modelopt/onnx/autocast/precisionconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ def _clear_types_and_shapes_recursive(
"""Recursively clear type/shape information for a graph and all its subgraphs.

This is necessary for control flow operators (Scan, If, Loop) which have subgraphs.
For subgraphs, preserve value_info for outer scope variables (not produced by nodes in subgraph).
For main graph, clear all value_info.

Args:
graph: The ONNX graph to clear types and shapes for.
Expand All @@ -303,13 +305,27 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) ->
if d.dim_value:
inp.type.tensor_type.shape.dim[idx].dim_param = "unk"

# Clear type/shape information for intermediates and outputs
for vi in g.value_info:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
if d.dim_value:
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
if is_sub:
# Identify which tensors are produced by nodes in this subgraph
subgraph_outputs = set()
for node in g.node:
subgraph_outputs.update(node.output)

# Clear value_info only for intermediates produced by nodes in this subgraph
for vi in g.value_info:
if vi.name in subgraph_outputs:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
if d.dim_value:
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
else:
for vi in g.value_info:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
if d.dim_value:
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"

# Clear outputs for both main graph and subgraphs
for out in g.output:
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(out.type.tensor_type.shape.dim):
Expand Down
76 changes: 76 additions & 0 deletions tests/unit/onnx/autocast/test_precisionconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,3 +1520,79 @@ def test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precis
# Verify a cast was inserted between If output and Add input
cast_nodes = [n for n in converted_model.graph.node if n.op_type == "Cast"]
assert len(cast_nodes) > 0, "Should have cast nodes for mixed precision"


@pytest.fixture
def model_with_if_outer_scope_reference():
    """Build a small If model whose branches read a tensor from the enclosing graph.

    Both the "then" and "else" subgraphs apply Identity to the main graph input ``X``
    and carry a value_info entry for ``X``. That entry is what must survive type clearing.

    Returns:
        Tuple of (model, value_info_map, initializer_map, node_to_init_map), where the
        model has been passed through shape inference.
    """

    def _float_info(tensor_name):
        # Every float tensor in this model shares the same [2, 4] shape.
        return helper.make_tensor_value_info(tensor_name, TensorProto.FLOAT, [2, 4])

    def _identity_branch(prefix):
        # A branch is a single Identity node that consumes "X" from the outer scope.
        branch_output = f"{prefix}_y"
        identity = helper.make_node(
            "Identity", ["X"], [branch_output], name=f"{prefix}_identity"
        )
        branch = helper.make_graph(
            [identity], f"{prefix}_branch", [], [_float_info(branch_output)]
        )
        # Record X in the branch's value_info - this is what needs to be preserved
        branch.value_info.extend([_float_info("X")])
        return branch

    then_graph = _identity_branch("then")
    else_graph = _identity_branch("else")

    # If node selecting between the two branches, wrapped in the main graph
    if_node = helper.make_node(
        "If",
        ["condition"],
        ["Y"],
        name="if_node",
        then_branch=then_graph,
        else_branch=else_graph,
    )
    graph_inputs = [
        _float_info("X"),
        helper.make_tensor_value_info("condition", TensorProto.BOOL, []),
    ]
    graph_outputs = [_float_info("Y")]
    main_graph = helper.make_graph(
        [if_node], "model_with_outer_scope", graph_inputs, graph_outputs
    )

    model = helper.make_model(main_graph, producer_name="model_with_outer_scope")
    model.opset_import[0].version = 20
    onnx.checker.check_model(model)

    model = onnx_utils.infer_shapes(model)
    vi_map, init_map, node_init_map = utils.setup_mappings(model)
    return model, vi_map, init_map, node_init_map


@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
def test_if_subgraph_outer_scope_type_preservation(
    model_with_if_outer_scope_reference, low_precision_type
):
    """Check that If-branch value_info for an outer scope tensor survives conversion.

    If X's value_info were not preserved in the subgraphs, shape inference would fail
    with "Element type of input 0 unknown".
    """
    model, value_info_map, initializer_map, node_to_init_map = model_with_if_outer_scope_reference

    converted_model = PrecisionConverter(
        model,
        value_info_map,
        initializer_map,
        node_to_init_map,
        keep_io_types=True,
        low_precision_type=low_precision_type,
    ).convert(high_precision_nodes=["if_node"], low_precision_nodes=[])
    onnx.checker.check_model(converted_model)

    # Collect the value_info entries named "X" from each branch of the If node
    if_node = next(n for n in converted_model.graph.node if n.op_type == "If")
    x_infos_by_branch = {}
    for label in ("then", "else"):
        attr_name = f"{label}_branch"
        subgraph = next(attr.g for attr in if_node.attribute if attr.name == attr_name)
        x_infos_by_branch[label] = [vi for vi in subgraph.value_info if vi.name == "X"]

    # Presence is checked for both branches before any element type is inspected
    for label, x_infos in x_infos_by_branch.items():
        assert len(x_infos) > 0, f"X value_info should be preserved in {label} branch"
    for x_infos in x_infos_by_branch.values():
        assert x_infos[0].type.tensor_type.elem_type != onnx.TensorProto.UNDEFINED