diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py
index 1e22504c6..d708c890f 100644
--- a/modelopt/onnx/autocast/precisionconverter.py
+++ b/modelopt/onnx/autocast/precisionconverter.py
@@ -283,6 +283,8 @@ def _clear_types_and_shapes_recursive(
         """Recursively clear type/shape information for a graph and all its subgraphs.
 
         This is necessary for control flow operators (Scan, If, Loop) which have subgraphs.
+        For subgraphs, value_info for outer-scope variables (tensors not produced by any node in the subgraph) is preserved.
+        For the main graph, all value_info is cleared.
 
         Args:
             graph: The ONNX graph to clear types and shapes for.
@@ -303,13 +305,27 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) ->
                     if d.dim_value:
                         inp.type.tensor_type.shape.dim[idx].dim_param = "unk"
 
-            # Clear type/shape information for intermediates and outputs
-            for vi in g.value_info:
-                vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
-                for idx, d in enumerate(vi.type.tensor_type.shape.dim):
-                    if d.dim_value:
-                        vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
+            if is_sub:
+                # Identify which tensors are produced by nodes in this subgraph
+                subgraph_outputs = set()
+                for node in g.node:
+                    subgraph_outputs.update(node.output)
+
+                # Clear value_info only for intermediates produced by nodes in this subgraph;
+                # entries for outer-scope references keep their types
+                for vi in g.value_info:
+                    if vi.name in subgraph_outputs:
+                        vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+                        for idx, d in enumerate(vi.type.tensor_type.shape.dim):
+                            if d.dim_value:
+                                vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
+            else:
+                for vi in g.value_info:
+                    vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+                    for idx, d in enumerate(vi.type.tensor_type.shape.dim):
+                        if d.dim_value:
+                            vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
 
+            # Clear outputs for both main graph and subgraphs
             for out in g.output:
                 out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
                 for idx, d in enumerate(out.type.tensor_type.shape.dim):
diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py
index bf87d8058..4fb02a230 100644
--- a/tests/unit/onnx/autocast/test_precisionconverter.py
+++ b/tests/unit/onnx/autocast/test_precisionconverter.py
@@ -1520,3 +1520,79 @@ def test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precis
     # Verify a cast was inserted between If output and Add input
     cast_nodes = [n for n in converted_model.graph.node if n.op_type == "Cast"]
     assert len(cast_nodes) > 0, "Should have cast nodes for mixed precision"
+
+
+@pytest.fixture
+def model_with_if_outer_scope_reference():
+    """Create a minimal model where If subgraphs reference outer-scope variables.
+
+    Verifies that subgraph value_info for outer-scope variables survives type clearing.
+    """
+    # Main graph inputs/outputs
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4])
+    condition = helper.make_tensor_value_info("condition", TensorProto.BOOL, [])
+    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])
+
+    # Create "then" branch: Identity on X from the outer scope
+    then_y = helper.make_tensor_value_info("then_y", TensorProto.FLOAT, [2, 4])
+    then_identity = helper.make_node("Identity", ["X"], ["then_y"], name="then_identity")
+    then_graph = helper.make_graph([then_identity], "then_branch", [], [then_y])
+    # Add X to value_info; this is the entry that must be preserved
+    then_graph.value_info.extend([helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4])])
+
+    # Create "else" branch: Identity on X from the outer scope
+    else_y = helper.make_tensor_value_info("else_y", TensorProto.FLOAT, [2, 4])
+    else_identity = helper.make_node("Identity", ["X"], ["else_y"], name="else_identity")
+    else_graph = helper.make_graph([else_identity], "else_branch", [], [else_y])
+    else_graph.value_info.extend([helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4])])
+
+    # Create the If node and the main graph
+    if_node = helper.make_node(
+        "If", ["condition"], ["Y"], name="if_node", then_branch=then_graph, else_branch=else_graph
+    )
+    main_graph = helper.make_graph([if_node], "model_with_outer_scope", [x, condition], [y])
+
+    model = helper.make_model(main_graph, producer_name="model_with_outer_scope")
+    model.opset_import[0].version = 20
+    onnx.checker.check_model(model)
+
+    model = onnx_utils.infer_shapes(model)
+    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
+    return model, value_info_map, initializer_map, node_to_init_map
+
+
+@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
+def test_if_subgraph_outer_scope_type_preservation(
+    model_with_if_outer_scope_reference, low_precision_type
+):
+    """Test that outer-scope variable types are preserved in If subgraphs during conversion.
+
+    Without preserving X's value_info in the subgraphs, shape inference fails with
+    "Element type of input 0 unknown".
+    """
+    model, value_info_map, initializer_map, node_to_init_map = model_with_if_outer_scope_reference
+
+    converter = PrecisionConverter(
+        model,
+        value_info_map,
+        initializer_map,
+        node_to_init_map,
+        keep_io_types=True,
+        low_precision_type=low_precision_type,
+    )
+
+    converted_model = converter.convert(high_precision_nodes=["if_node"], low_precision_nodes=[])
+    onnx.checker.check_model(converted_model)
+
+    # Verify X's value_info is preserved in both subgraphs
+    if_node = next(n for n in converted_model.graph.node if n.op_type == "If")
+    then_branch = next(attr.g for attr in if_node.attribute if attr.name == "then_branch")
+    else_branch = next(attr.g for attr in if_node.attribute if attr.name == "else_branch")
+
+    then_x_info = [vi for vi in then_branch.value_info if vi.name == "X"]
+    else_x_info = [vi for vi in else_branch.value_info if vi.name == "X"]
+
+    assert len(then_x_info) > 0, "X value_info should be preserved in the then branch"
+    assert len(else_x_info) > 0, "X value_info should be preserved in the else branch"
+    assert then_x_info[0].type.tensor_type.elem_type != onnx.TensorProto.UNDEFINED
+    assert else_x_info[0].type.tensor_type.elem_type != onnx.TensorProto.UNDEFINED