# file: modelopt/onnx/graph_surgery/__init__.py
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Graph surgery utilities for post-processing exported ONNX models.

The helpers in this package apply graph-level transformations to ONNX models
after export:

- :func:`replace_attention_with_gqa` — swap standard attention subgraphs for
  GroupQueryAttention (GQA) nodes in LLMs.
- :func:`add_cross_kv_to_encoder` — add cross-attention K/V cache outputs to an
  encoder model (e.g. a Whisper encoder) for ONNX Runtime GenAI compatibility.
- :func:`convert_fp16_to_bf16` — convert model precision from FP16 to BF16.
- :func:`transpose_dequantize_linear_weights` — transpose DequantizeLinear
  weights for column-major storage backends.

Example:
    >>> from modelopt.onnx.graph_surgery import (
    ...     replace_attention_with_gqa,
    ...     convert_fp16_to_bf16,
    ...     transpose_dequantize_linear_weights,
    ...     add_cross_kv_to_encoder,
    ... )
    >>> # Replace attention with GQA (io_dtype="bfloat16" also converts FP16 -> BF16)
    >>> replace_attention_with_gqa(
    ...     model_path="model_fp16.onnx",
    ...     output_path="model_gqa.onnx",
    ...     hf_model_id="meta-llama/Llama-2-7b-hf",
    ...     io_dtype="float16",
    ... )
    >>> # Add cross-attention KV cache outputs to an encoder (GenAI compatible)
    >>> add_cross_kv_to_encoder(
    ...     encoder_path="encoder_model.onnx",
    ...     output_path="encoder_with_kv.onnx",
    ...     hf_model_id="openai/whisper-large-v3-turbo",
    ... )
    >>> # Standalone FP16 -> BF16 conversion
    >>> convert_fp16_to_bf16(
    ...     input_path="model_fp16.onnx",
    ...     output_path="model_bf16.onnx",
    ... )
    >>> # Transpose DequantizeLinear weights for column-major storage
    >>> transpose_dequantize_linear_weights(
    ...     model_path="model_quantized.onnx",
    ...     output_path="model_quantized_transposed.onnx",
    ... )
"""

from .dq_transpose import transpose_dequantize_linear_weights
from .encoder_cross_kv import add_cross_kv_to_encoder
from .gqa_replacement import replace_attention_with_gqa
from .utils.dtype_conversion import convert_fp16_to_bf16

__all__ = [
    "add_cross_kv_to_encoder",
    "convert_fp16_to_bf16",
    "replace_attention_with_gqa",
    "transpose_dequantize_linear_weights",
]
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Command-line entry point for the graph surgery tools.

Usage::

    python -m modelopt.onnx.graph_surgery <command> [options]

Commands:

- ``replace-gqa``  — replace standard attention with GroupQueryAttention
  (works for FP16/BF16 as well as INT4/AWQ-quantized LLMs)
- ``add-cross-kv`` — add cross-attention KV cache outputs to an encoder
- ``convert-bf16`` — convert an FP16 model to BF16
- ``transpose-dq`` — transpose DequantizeLinear weights (column-major optimization)
- ``analyze``      — analyze the attention pattern of a model

Examples::

    python -m modelopt.onnx.graph_surgery replace-gqa \
        --input model.onnx --output model_gqa.onnx \
        --model-id meta-llama/Llama-2-7b-hf

    python -m modelopt.onnx.graph_surgery add-cross-kv \
        --input encoder_model.onnx --output encoder_with_kv.onnx \
        --model-id openai/whisper-large-v3-turbo

    python -m modelopt.onnx.graph_surgery convert-bf16 \
        --input model_fp16.onnx --output model_bf16.onnx

    python -m modelopt.onnx.graph_surgery transpose-dq \
        --input model_quantized.onnx --output model_quantized_transposed.onnx

    python -m modelopt.onnx.graph_surgery analyze --input model.onnx --layer 0
"""

import argparse
import sys

# Shown verbatim below the auto-generated help (RawDescriptionHelpFormatter).
_EPILOG = """
Examples:
  Replace attention with GQA (FP16/BF16 LLMs):
    python -m modelopt.onnx.graph_surgery replace-gqa -i model.onnx -o model_gqa.onnx -m meta-llama/Llama-2-7b-hf

  Replace attention with GQA (INT4/AWQ quantized LLMs):
    python -m modelopt.onnx.graph_surgery replace-gqa -i model.onnx -o model_gqa.onnx -m meta-llama/Llama-3.1-8B

  Add cross-attention KV to encoder:
    python -m modelopt.onnx.graph_surgery add-cross-kv \\
      -i encoder.onnx -o encoder_kv.onnx -m openai/whisper-large-v3-turbo

  Convert FP16 to BF16:
    python -m modelopt.onnx.graph_surgery convert-bf16 -i model_fp16.onnx -o model_bf16.onnx

  Transpose DequantizeLinear weights:
    python -m modelopt.onnx.graph_surgery transpose-dq -i model_quantized.onnx -o model_transposed.onnx

  Analyze attention pattern:
    python -m modelopt.onnx.graph_surgery analyze -i model.onnx --layer 0
    """


def _add_replace_gqa_command(commands) -> None:
    """Register the ``replace-gqa`` subcommand and its options."""
    sub = commands.add_parser(
        "replace-gqa",
        help="Replace attention with GroupQueryAttention",
        description="Replace standard attention subgraphs with GroupQueryAttention (GQA).",
    )
    sub.add_argument("-i", "--input", required=True, help="Input ONNX model path")
    sub.add_argument("-o", "--output", required=True, help="Output ONNX model path")
    sub.add_argument("-m", "--model-id", required=True, help="HuggingFace model ID for config")
    sub.add_argument("--max-seq-len", type=int, default=4096, help="Maximum sequence length")
    sub.add_argument(
        "--dtype",
        default="float16",
        choices=["float16", "float32", "bfloat16"],
        help="I/O data type",
    )
    sub.add_argument(
        "--no-external-data",
        action="store_true",
        help="Embed weights in the model file (disables external data)",
    )
    sub.add_argument(
        "--external-data-name",
        type=str,
        default=None,
        help="Name for external data file (default: model.onnx_data)",
    )
    sub.add_argument(
        "--ir-version",
        type=int,
        default=None,
        help="Set ONNX IR version for compatibility (e.g., 9 for older ORT versions)",
    )
    sub.add_argument(
        "--pack-qkv",
        action="store_true",
        help=(
            "For quantized models: concatenate Q/K/V outputs into a single packed"
            " QKV tensor for GQA input (default: separate Q/K/V inputs)"
        ),
    )
    sub.add_argument("-q", "--quiet", action="store_true", help="Suppress progress messages")
    sub.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code in HuggingFace model config",
    )


def _add_cross_kv_command(commands) -> None:
    """Register the ``add-cross-kv`` subcommand and its options."""
    sub = commands.add_parser(
        "add-cross-kv",
        help="Add cross-attention KV cache outputs to encoder",
        description="Add cross-attention K/V projection outputs to encoder for GenAI compatibility.",
    )
    sub.add_argument("-i", "--input", required=True, help="Input encoder ONNX model path")
    sub.add_argument("-o", "--output", required=True, help="Output ONNX model path")
    sub.add_argument(
        "-m", "--model-id", required=True, help="HuggingFace model ID for cross-attention weights"
    )
    sub.add_argument(
        "--hidden-state-name",
        default="last_hidden_state",
        help="Name of encoder hidden state output",
    )
    sub.add_argument(
        "--no-rename-input",
        action="store_true",
        help="Don't rename input_features to audio_features",
    )
    sub.add_argument(
        "--no-external-data",
        action="store_true",
        help="Don't save weights as external data",
    )
    sub.add_argument(
        "--decoder-filename",
        default="decoder_with_past_model.onnx",
        help="Decoder ONNX filename for genai_config.json (default: decoder_with_past_model.onnx)",
    )
    sub.add_argument(
        "--no-genai-config",
        action="store_true",
        help="Don't generate genai_config.json",
    )
    sub.add_argument(
        "--provider",
        default="cuda",
        choices=["cuda", "cpu", "NvTensorRtRtx"],
        help="Execution provider for genai_config.json",
    )
    sub.add_argument("-q", "--quiet", action="store_true", help="Suppress progress messages")
    sub.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code in HuggingFace model",
    )


def _add_convert_bf16_command(commands) -> None:
    """Register the ``convert-bf16`` subcommand and its options."""
    sub = commands.add_parser(
        "convert-bf16",
        help="Convert FP16 model to BF16",
        description="Convert an ONNX model from FP16 to BF16 precision.",
    )
    sub.add_argument("-i", "--input", required=True, help="Input FP16 ONNX model path")
    sub.add_argument("-o", "--output", required=True, help="Output BF16 ONNX model path")
    sub.add_argument(
        "--no-external-data",
        action="store_true",
        help="Don't save weights as external data",
    )
    sub.add_argument("-q", "--quiet", action="store_true", help="Suppress progress messages")


def _add_transpose_dq_command(commands) -> None:
    """Register the ``transpose-dq`` subcommand and its options."""
    sub = commands.add_parser(
        "transpose-dq",
        help="Transpose DequantizeLinear weights for column-major storage",
        description="Transpose weights in DequantizeLinear nodes for column-major storage optimization.",
    )
    sub.add_argument("-i", "--input", required=True, help="Input quantized ONNX model path")
    sub.add_argument("-o", "--output", required=True, help="Output ONNX model path")
    sub.add_argument(
        "--no-external-data",
        action="store_true",
        help="Don't save weights as external data",
    )
    sub.add_argument(
        "--external-data-name",
        type=str,
        default=None,
        help="Name for external data file",
    )
    sub.add_argument("-q", "--quiet", action="store_true", help="Suppress progress messages")


def _add_analyze_command(commands) -> None:
    """Register the ``analyze`` subcommand and its options."""
    sub = commands.add_parser(
        "analyze",
        help="Analyze attention pattern in model",
        description="Analyze the attention pattern in an existing model for debugging.",
    )
    sub.add_argument("-i", "--input", required=True, help="Input ONNX model path")
    sub.add_argument("--layer", type=int, default=0, help="Layer to analyze")


def _build_parser() -> argparse.ArgumentParser:
    """Construct the top-level argument parser with all subcommands."""
    parser = argparse.ArgumentParser(
        description="ONNX Graph Surgery Tools",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=_EPILOG,
    )
    commands = parser.add_subparsers(dest="command", help="Available commands")
    _add_replace_gqa_command(commands)
    _add_cross_kv_command(commands)
    _add_convert_bf16_command(commands)
    _add_transpose_dq_command(commands)
    _add_analyze_command(commands)
    return parser


def _run(args: argparse.Namespace) -> None:
    """Execute the selected subcommand.

    The heavyweight transformation modules are imported lazily so that CLI
    startup (and ``--help``) stays fast and does not require their deps.
    """
    if args.command == "replace-gqa":
        from .gqa_replacement import replace_attention_with_gqa

        replace_attention_with_gqa(
            model_path=args.input,
            output_path=args.output,
            hf_model_id=args.model_id,
            max_seq_len=args.max_seq_len,
            io_dtype=args.dtype,
            use_external_data=not args.no_external_data,
            external_data_name=args.external_data_name,
            ir_version=args.ir_version,
            pack_qkv=args.pack_qkv,
            verbose=not args.quiet,
            trust_remote_code=args.trust_remote_code,
        )
    elif args.command == "add-cross-kv":
        from .encoder_cross_kv import add_cross_kv_to_encoder

        add_cross_kv_to_encoder(
            encoder_path=args.input,
            output_path=args.output,
            hf_model_id=args.model_id,
            hidden_state_output_name=args.hidden_state_name,
            rename_input_features=not args.no_rename_input,
            use_external_data=not args.no_external_data,
            decoder_filename=args.decoder_filename,
            generate_genai_config=not args.no_genai_config,
            provider=args.provider,
            verbose=not args.quiet,
            trust_remote_code=args.trust_remote_code,
        )
    elif args.command == "convert-bf16":
        from .utils.dtype_conversion import convert_fp16_to_bf16

        convert_fp16_to_bf16(
            input_path=args.input,
            output_path=args.output,
            external_data=not args.no_external_data,
            verbose=not args.quiet,
        )
    elif args.command == "transpose-dq":
        from .dq_transpose import transpose_dequantize_linear_weights

        transpose_dequantize_linear_weights(
            model_path=args.input,
            output_path=args.output,
            use_external_data=not args.no_external_data,
            external_data_name=args.external_data_name,
            verbose=not args.quiet,
        )
    elif args.command == "analyze":
        from .gqa_replacement import analyze_attention_pattern

        analyze_attention_pattern(args.input, args.layer)


def main():
    """Main entry point for graph surgery CLI.

    Parses ``sys.argv``; prints the top-level help and exits with status 1
    when no subcommand is given, otherwise dispatches to the subcommand.
    """
    parser = _build_parser()
    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        sys.exit(1)

    _run(args)


if __name__ == "__main__":
    main()
# file: modelopt/onnx/graph_surgery/dq_transpose.py
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transpose DequantizeLinear weights for column-major storage optimization.

This module provides functionality to transform quantized ONNX models by
transposing the weights and scales in DequantizeLinear nodes and adding
corresponding Transpose nodes after them. This is useful for optimizing
inference with backends that prefer column-major weight storage (e.g., NvTensorRtRtx).

The transformation:
1. For each DequantizeLinear node feeding into MatMul/Gemm:
   - Transpose the quantized weights (input 0)
   - Transpose the scales (input 1)
   - Transpose zero points if present (input 2)
   - Update the axis attribute (0 -> 1 for 2D tensors)
   - Add a Transpose node after DequantizeLinear to recover original shape
"""

import os

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

from ..logging_config import logger


def _is_int4_type(data_type: int) -> bool:
    """Return True when ``data_type`` is ``TensorProto.INT4`` or ``TensorProto.UINT4``."""
    return data_type in [TensorProto.INT4, TensorProto.UINT4]


def _unpack_int4(packed_data: bytes, shape: tuple, signed: bool = True) -> np.ndarray:
    """Unpack INT4/UINT4 packed bytes to an int8 array.

    INT4 is packed with 2 values per byte (low nibble first); a trailing pad
    nibble (odd element counts) is dropped via the final slice.

    Args:
        packed_data: Raw packed bytes (e.g. ``TensorProto.raw_data``).
        shape: Target unpacked tensor shape.
        signed: Interpret nibbles as two's-complement INT4 when True.

    Returns:
        int8 array of ``shape`` with values in [-8, 7] (signed) or [0, 15] (unsigned).
    """
    packed_arr = np.frombuffer(packed_data, dtype=np.uint8)

    # Extract low and high nibbles
    low_nibbles = packed_arr & 0x0F
    high_nibbles = (packed_arr >> 4) & 0x0F

    # Interleave: low nibble comes first
    unpacked = np.empty(len(packed_arr) * 2, dtype=np.int8)
    unpacked[0::2] = low_nibbles
    unpacked[1::2] = high_nibbles

    # Handle sign extension for INT4
    if signed:
        # Values >= 8 are negative (two's complement)
        unpacked = np.where(unpacked >= 8, unpacked - 16, unpacked).astype(np.int8)

    # int() cast: np.prod(()) returns a float (1.0), and slicing with a float
    # raises TypeError — guard the 0-d / empty-shape edge case.
    total_elements = int(np.prod(shape))
    return unpacked[:total_elements].reshape(shape)


def _pack_int4(arr: np.ndarray, signed: bool = True) -> bytes:
    """Pack an int8 array back to INT4/UINT4 packed bytes.

    INT4 is packed with 2 values per byte (low nibble first). Odd element
    counts are padded with a zero nibble, matching the unpack convention.

    Args:
        arr: Array of nibble values; flattened in C order before packing.
        signed: Treat negative values as two's-complement INT4 when True.

    Returns:
        Packed bytes (ceil(len/2) bytes).
    """
    flat = arr.flatten().astype(np.int8)

    # Handle negative values for signed INT4
    if signed:
        flat = np.where(flat < 0, flat + 16, flat).astype(np.uint8)
    else:
        flat = flat.astype(np.uint8)

    # Ensure we have even number of elements (pad if needed)
    if len(flat) % 2 != 0:
        flat = np.append(flat, np.uint8(0))

    # Pack: low nibble first, then high nibble
    low_nibbles = flat[0::2] & 0x0F
    high_nibbles = flat[1::2] & 0x0F
    packed = low_nibbles | (high_nibbles << 4)

    return packed.astype(np.uint8).tobytes()


def _transpose_tensor_proto(tensor: onnx.TensorProto, perm: list[int]) -> onnx.TensorProto:
    """Return a transposed copy of ``tensor`` named ``<name>_transposed``.

    INT4/UINT4 tensors are unpacked, transposed, and re-packed so the packed
    byte layout matches the new (transposed) flattened element order.

    NOTE(review): the INT4 path assumes the data lives in ``tensor.raw_data``;
    int4 stored in other TensorProto fields would yield empty bytes — confirm
    against the models this is run on.
    """
    original_shape = list(tensor.dims)
    data_type = tensor.data_type

    if _is_int4_type(data_type):
        # Handle INT4/UINT4 specially: numpy has no 4-bit dtype.
        signed = data_type == TensorProto.INT4
        unpacked = _unpack_int4(tensor.raw_data, tuple(original_shape), signed=signed)
        transposed = np.transpose(unpacked, perm)
        packed_data = _pack_int4(transposed, signed=signed)

        new_shape = [original_shape[p] for p in perm]

        new_tensor = onnx.TensorProto()
        new_tensor.name = tensor.name + "_transposed"
        new_tensor.data_type = data_type
        new_tensor.dims.extend(new_shape)
        new_tensor.raw_data = packed_data
        return new_tensor
    else:
        # Standard handling for other types
        arr = numpy_helper.to_array(tensor)
        transposed = np.transpose(arr, perm)
        new_tensor = numpy_helper.from_array(transposed, name=tensor.name + "_transposed")
        return new_tensor


def _find_initializer(model: onnx.ModelProto, name: str) -> onnx.TensorProto | None:
    """Return the graph initializer called ``name``, or None if absent."""
    for init in model.graph.initializer:
        if init.name == name:
            return init
    return None


def _find_node_by_output(model: onnx.ModelProto, output_name: str) -> onnx.NodeProto | None:
    """Return the node that produces ``output_name``, or None if no node does."""
    for node in model.graph.node:
        if output_name in node.output:
            return node
    return None
def _get_consumers(model: onnx.ModelProto, tensor_name: str) -> list[onnx.NodeProto]:
    """Return every node in ``model`` that reads ``tensor_name`` as an input."""
    return [node for node in model.graph.node if tensor_name in node.input]


def transpose_dequantize_linear_weights(
    model_path: str,
    output_path: str,
    use_external_data: bool = True,
    external_data_name: str | None = None,
    verbose: bool = True,
) -> onnx.ModelProto:
    """Transpose weights in DequantizeLinear nodes for column-major storage.

    This function transforms a quantized ONNX model by:
    1. Finding all DequantizeLinear nodes that feed into MatMul/Gemm
    2. Transposing the quantized weights, scales, and zero points
    3. Updating the axis attribute (0 -> 1)
    4. Adding Transpose nodes after DequantizeLinear to recover original shape

    This optimization is useful for backends that prefer column-major weight
    storage, such as NvTensorRtRtx.

    Args:
        model_path: Path to input quantized ONNX model.
        output_path: Path to save modified model.
        use_external_data: Whether to save weights as external data.
        external_data_name: Name for external data file.
        verbose: Whether to print progress messages.

    Returns:
        Modified ONNX model.

    Example:
        >>> from modelopt.onnx.graph_surgery import transpose_dequantize_linear_weights
        >>> model = transpose_dequantize_linear_weights(
        ...     model_path="model_quantized.onnx",
        ...     output_path="model_quantized_transposed.onnx",
        ... )
    """
    if verbose:
        logger.info(f"Loading model from: {model_path}")

    model = onnx.load(model_path, load_external_data=True)
    graph = model.graph

    # Statistics
    stats = {
        "dq_nodes_processed": 0,
        "transpose_nodes_added": 0,
        "weights_transposed": 0,
        "scales_transposed": 0,
        "zero_points_transposed": 0,
    }

    # Find all DequantizeLinear nodes
    dq_nodes = [node for node in graph.node if node.op_type == "DequantizeLinear"]

    if verbose:
        logger.info(f"Found {len(dq_nodes)} DequantizeLinear nodes")

    # Track which DQ nodes feed into MatMul/Gemm as the weight input (input[1]).
    dq_nodes_to_process = []
    for dq_node in dq_nodes:
        if len(dq_node.output) == 0:
            continue

        dq_output = dq_node.output[0]
        for consumer in _get_consumers(model, dq_output):
            if consumer.op_type in ["MatMul", "Gemm"]:
                if len(consumer.input) > 1 and consumer.input[1] == dq_output:
                    dq_nodes_to_process.append((dq_node, consumer))
                    break

    if verbose:
        logger.info(
            f"Found {len(dq_nodes_to_process)} DequantizeLinear nodes feeding into MatMul/Gemm"
        )

    # Transposed tensors cached by ORIGINAL initializer name. Using name-keyed
    # dicts (instead of plain lists) makes the pass robust when an initializer
    # is shared by several DQ nodes: each tensor is transposed once, added to
    # the graph once, and removed once — the list-based version would raise
    # ValueError on the second `graph.initializer.remove()` of the same object.
    transposed_by_name: dict[str, onnx.TensorProto] = {}
    originals_by_name: dict[str, onnx.TensorProto] = {}
    nodes_to_add = []
    # Track processed nodes by identity: ONNX node names may be empty or
    # duplicated, and a name-keyed set would wrongly skip unnamed nodes.
    processed_dq_ids: set[int] = set()

    def _transposed_for(init: onnx.TensorProto) -> tuple[onnx.TensorProto, bool]:
        """Transpose ``init`` (once), scheduling the original for removal.

        Returns the cached transposed tensor and whether it was newly created
        (so per-tensor statistics are counted once for shared initializers).
        """
        fresh = init.name not in transposed_by_name
        if fresh:
            transposed_by_name[init.name] = _transpose_tensor_proto(init, [1, 0])
            originals_by_name[init.name] = init
        return transposed_by_name[init.name], fresh

    for node_idx, (dq_node, _consumer) in enumerate(dq_nodes_to_process):
        if id(dq_node) in processed_dq_ids:
            continue
        processed_dq_ids.add(id(dq_node))

        if len(dq_node.input) < 2:
            if verbose:
                logger.warning(f"Skipping {dq_node.name}: insufficient inputs")
            continue

        weight_name = dq_node.input[0]
        scale_name = dq_node.input[1]

        # Find initializers
        weight_init = _find_initializer(model, weight_name)
        scale_init = _find_initializer(model, scale_name)

        if weight_init is None or scale_init is None:
            if verbose:
                logger.debug(f"Skipping {dq_node.name}: weights or scale not constant")
            continue

        # Only 2D weights are handled (perm [1, 0]).
        if len(weight_init.dims) != 2:
            if verbose:
                logger.debug(
                    f"Skipping {dq_node.name}: weights not 2D (shape: {list(weight_init.dims)})"
                )
            continue

        original_shape = list(weight_init.dims)

        if verbose:
            is_int4 = _is_int4_type(weight_init.data_type)
            logger.debug(f"Processing {dq_node.name}: shape={original_shape}, INT4={is_int4}")

        # Transpose weights and rewire the DQ node input.
        transposed_weight, fresh = _transposed_for(weight_init)
        for i, inp in enumerate(dq_node.input):
            if inp == weight_name:
                dq_node.input[i] = transposed_weight.name
                break
        if fresh:
            stats["weights_transposed"] += 1

        # Transpose scale if 2D (per-axis / blocked quantization).
        if len(scale_init.dims) == 2:
            transposed_scale, fresh = _transposed_for(scale_init)
            for i, inp in enumerate(dq_node.input):
                if inp == scale_name:
                    dq_node.input[i] = transposed_scale.name
                    break
            if fresh:
                stats["scales_transposed"] += 1

        # Transpose zero point if present and 2D.
        if len(dq_node.input) > 2:
            zp_name = dq_node.input[2]
            zp_init = _find_initializer(model, zp_name)
            if zp_init is not None and len(zp_init.dims) == 2:
                transposed_zp, fresh = _transposed_for(zp_init)
                for i, inp in enumerate(dq_node.input):
                    if inp == zp_name:
                        dq_node.input[i] = transposed_zp.name
                        break
                if fresh:
                    stats["zero_points_transposed"] += 1

        # Update axis attribute: the quantization axis refers to the weight
        # layout, which has just been transposed (0 <-> 1).
        for attr in dq_node.attribute:
            if attr.name == "axis":
                old_axis = attr.i
                if old_axis == 0:
                    attr.i = 1
                elif old_axis == 1:
                    attr.i = 0
                if verbose:
                    logger.debug(f"  Updated axis: {old_axis} -> {attr.i}")
                break

        # Re-route the DQ output through a Transpose that restores the
        # original shape, keeping the original output name for consumers.
        dq_output_name = dq_node.output[0]
        new_dq_output_name = f"{dq_output_name}_before_transpose"
        dq_node.output[0] = new_dq_output_name

        # Fall back to an index-based base name so unnamed DQ nodes still
        # produce unique Transpose node names.
        base_name = dq_node.name or f"dq_{node_idx}"
        transpose_node = helper.make_node(
            "Transpose",
            inputs=[new_dq_output_name],
            outputs=[dq_output_name],  # Use original output name
            name=f"{base_name}_transpose_back",
            perm=[1, 0],
        )
        nodes_to_add.append(transpose_node)
        stats["transpose_nodes_added"] += 1
        stats["dq_nodes_processed"] += 1

        if verbose:
            transposed_shape = [original_shape[1], original_shape[0]]
            logger.debug(f"  Transposed: {original_shape} -> {transposed_shape}")

    # Apply changes to graph: swap initializers, then append the new nodes.
    for init in originals_by_name.values():
        graph.initializer.remove(init)
    graph.initializer.extend(transposed_by_name.values())
    graph.node.extend(nodes_to_add)

    if verbose:
        logger.info("\nTransformation statistics:")
        logger.info(f"  DequantizeLinear nodes processed: {stats['dq_nodes_processed']}")
        logger.info(f"  Transpose nodes added: {stats['transpose_nodes_added']}")
        logger.info(f"  Weights transposed: {stats['weights_transposed']}")
        logger.info(f"  Scales transposed: {stats['scales_transposed']}")
        logger.info(f"  Zero points transposed: {stats['zero_points_transposed']}")

    # Save model
    if verbose:
        logger.info(f"\nSaving modified model to: {output_path}")

    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if use_external_data:
        if external_data_name is None:
            external_data_name = os.path.basename(output_path) + "_data"

        if verbose:
            logger.info(f"  Saving weights to external file: {external_data_name}")

        onnx.save(
            model,
            output_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=external_data_name,
            size_threshold=1024,
        )
    else:
        onnx.save(model, output_path)

    if verbose:
        logger.info("Done!")

    return model
b/modelopt/onnx/graph_surgery/encoder_cross_kv.py new file mode 100644 index 000000000..32be99185 --- /dev/null +++ b/modelopt/onnx/graph_surgery/encoder_cross_kv.py @@ -0,0 +1,495 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Add cross-attention KV cache outputs to encoder model. + +This module provides functionality to transform Optimum-exported encoder models +(e.g., Whisper encoder) by adding cross-attention Key/Value projection outputs. +This is required for ONNX Runtime GenAI compatibility where the decoder expects +pre-computed encoder K/V caches. + +The transformation: +1. Loads cross-attention K/V projection weights from HuggingFace model +2. Adds MatMul -> Reshape -> Transpose nodes to encoder graph +3. Adds new outputs: present_key_cross_0, present_value_cross_0, etc. +""" + +import os +from pathlib import Path + +import numpy as np +import onnx +from onnx import TensorProto, helper, numpy_helper + +from ..logging_config import logger +from .utils.graph_utils import detect_model_dtype + + +def _get_cross_attn_weights_from_hf( + model_id: str, trust_remote_code: bool = False +) -> tuple[dict, int, int, int]: + """Extract cross-attention K and V projection weights from HuggingFace model. + + Args: + model_id: HuggingFace model ID (e.g., "openai/whisper-large-v3-turbo"). 
+ trust_remote_code: Whether to trust remote code in HuggingFace model. + + Returns: + Tuple of (weights_dict, num_heads, head_size, num_layers). + """ + from transformers import WhisperForConditionalGeneration + + logger.info(f"Loading PyTorch model: {model_id}") + model = WhisperForConditionalGeneration.from_pretrained( + model_id, trust_remote_code=trust_remote_code + ) + + weights = {} + num_layers = model.config.decoder_layers + num_heads = model.config.decoder_attention_heads + hidden_size = model.config.d_model + head_size = hidden_size // num_heads + + logger.info(f" num_layers: {num_layers}") + logger.info(f" num_heads: {num_heads}") + logger.info(f" hidden_size: {hidden_size}") + logger.info(f" head_size: {head_size}") + + for i in range(num_layers): + layer = model.model.decoder.layers[i] + + # encoder_attn is the cross-attention layer + k_proj = layer.encoder_attn.k_proj + v_proj = layer.encoder_attn.v_proj + + weights[i] = { + "k_weight": k_proj.weight.detach().cpu().numpy(), # (out_features, in_features) + "v_weight": v_proj.weight.detach().cpu().numpy(), + "k_bias": k_proj.bias.detach().cpu().numpy() if k_proj.bias is not None else None, + "v_bias": v_proj.bias.detach().cpu().numpy() if v_proj.bias is not None else None, + } + + logger.debug( + f" Layer {i}: K weight {weights[i]['k_weight'].shape}, " + f"V weight {weights[i]['v_weight'].shape}" + ) + + return weights, num_heads, head_size, num_layers + + +def _rename_input(graph: onnx.GraphProto, old_name: str, new_name: str) -> None: + """Rename an input in the graph.""" + for inp in graph.input: + if inp.name == old_name: + inp.name = new_name + logger.debug(f" Renamed input: {old_name} -> {new_name}") + break + + # Rename in all nodes that use this input + for node in graph.node: + for i, inp in enumerate(node.input): + if inp == old_name: + node.input[i] = new_name + + +def _add_cross_kv_outputs( + encoder_model: onnx.ModelProto, + cross_attn_weights: dict, + hidden_state_output_name: str, + 
num_heads: int, + head_size: int, + num_layers: int, + rename_input_features: bool = True, + onnx_dtype: int = TensorProto.FLOAT, + np_dtype: np.dtype = np.float32, +) -> onnx.ModelProto: + """Add cross-attention KV cache computation to encoder model. + + Args: + encoder_model: ONNX encoder model to modify. + cross_attn_weights: Dictionary of cross-attention K/V weights per layer. + hidden_state_output_name: Name of the encoder hidden state output. + num_heads: Number of attention heads. + head_size: Size of each attention head. + num_layers: Number of decoder layers. + rename_input_features: Whether to rename input_features to audio_features. + onnx_dtype: ONNX tensor data type (TensorProto.FLOAT or TensorProto.FLOAT16). + np_dtype: NumPy dtype for weight conversion. + + Returns: + Modified ONNX model with cross-attention KV outputs. + """ + logger.info(f"Adding cross KV outputs with dtype: {np_dtype}") + graph = encoder_model.graph + + # Rename input_features to audio_features if requested + if rename_input_features: + _rename_input(graph, "input_features", "audio_features") + + # Rename output to encoder_hidden_states + encoder_hidden_states_name = "encoder_hidden_states" + + # Find the node that produces this output and rename it directly + for node in graph.node: + for i, out in enumerate(node.output): + if out == hidden_state_output_name: + node.output[i] = encoder_hidden_states_name + logger.debug( + f" Renamed output: {hidden_state_output_name} -> {encoder_hidden_states_name}" + ) + break + + # Update the graph output + for output in list(graph.output): + if output.name == hidden_state_output_name: + dims = [ + d.dim_param if d.dim_param else d.dim_value + for d in output.type.tensor_type.shape.dim + ] + new_output = helper.make_tensor_value_info( + encoder_hidden_states_name, + output.type.tensor_type.elem_type, + dims, + ) + graph.output.remove(output) + graph.output.append(new_output) + break + + logger.info(f"Adding cross KV cache outputs for 
{num_layers} layers") + + new_nodes = [] + new_outputs = [] + new_initializers = [] + + # Shape constant for reshape: (batch, seq, num_heads, head_size) + reshape_shape_name = "cross_kv_reshape_shape" + reshape_shape = np.array([0, -1, num_heads, head_size], dtype=np.int64) + new_initializers.append(numpy_helper.from_array(reshape_shape, name=reshape_shape_name)) + + for layer_idx in range(num_layers): + layer_weights = cross_attn_weights[layer_idx] + + k_weight = layer_weights["k_weight"] + v_weight = layer_weights["v_weight"] + k_bias = layer_weights["k_bias"] + v_bias = layer_weights["v_bias"] + + # Transpose weights for ONNX MatMul + # Use detected model dtype for weight conversion + k_weight_t = k_weight.T.astype(np_dtype) + v_weight_t = v_weight.T.astype(np_dtype) + + # Add weight initializers + k_weight_name = f"encoder.cross_attn_k_weight.{layer_idx}" + v_weight_name = f"encoder.cross_attn_v_weight.{layer_idx}" + + new_initializers.append(numpy_helper.from_array(k_weight_t, name=k_weight_name)) + new_initializers.append(numpy_helper.from_array(v_weight_t, name=v_weight_name)) + + # MatMul: encoder_hidden_states @ k_weight + k_matmul_out = f"cross_k_matmul_{layer_idx}" + new_nodes.append( + helper.make_node( + "MatMul", + inputs=[encoder_hidden_states_name, k_weight_name], + outputs=[k_matmul_out], + name=f"CrossK_MatMul_{layer_idx}", + ) + ) + + v_matmul_out = f"cross_v_matmul_{layer_idx}" + new_nodes.append( + helper.make_node( + "MatMul", + inputs=[encoder_hidden_states_name, v_weight_name], + outputs=[v_matmul_out], + name=f"CrossV_MatMul_{layer_idx}", + ) + ) + + # Add bias if present + if k_bias is not None: + k_bias_name = f"encoder.cross_attn_k_bias.{layer_idx}" + new_initializers.append( + numpy_helper.from_array(k_bias.astype(np_dtype), name=k_bias_name) + ) + k_add_out = f"cross_k_add_{layer_idx}" + new_nodes.append( + helper.make_node( + "Add", + inputs=[k_matmul_out, k_bias_name], + outputs=[k_add_out], + name=f"CrossK_Add_{layer_idx}", + ) + ) 
+ k_matmul_out = k_add_out + + if v_bias is not None: + v_bias_name = f"encoder.cross_attn_v_bias.{layer_idx}" + new_initializers.append( + numpy_helper.from_array(v_bias.astype(np_dtype), name=v_bias_name) + ) + v_add_out = f"cross_v_add_{layer_idx}" + new_nodes.append( + helper.make_node( + "Add", + inputs=[v_matmul_out, v_bias_name], + outputs=[v_add_out], + name=f"CrossV_Add_{layer_idx}", + ) + ) + v_matmul_out = v_add_out + + # Reshape: (batch, seq, hidden) -> (batch, seq, num_heads, head_size) + k_reshape_out = f"cross_k_reshape_{layer_idx}" + new_nodes.append( + helper.make_node( + "Reshape", + inputs=[k_matmul_out, reshape_shape_name], + outputs=[k_reshape_out], + name=f"CrossK_Reshape_{layer_idx}", + ) + ) + + v_reshape_out = f"cross_v_reshape_{layer_idx}" + new_nodes.append( + helper.make_node( + "Reshape", + inputs=[v_matmul_out, reshape_shape_name], + outputs=[v_reshape_out], + name=f"CrossV_Reshape_{layer_idx}", + ) + ) + + # Transpose: (batch, seq, num_heads, head_size) -> (batch, num_heads, seq, head_size) + k_output_name = f"present_key_cross_{layer_idx}" + new_nodes.append( + helper.make_node( + "Transpose", + inputs=[k_reshape_out], + outputs=[k_output_name], + perm=[0, 2, 1, 3], + name=f"CrossK_Transpose_{layer_idx}", + ) + ) + + v_output_name = f"present_value_cross_{layer_idx}" + new_nodes.append( + helper.make_node( + "Transpose", + inputs=[v_reshape_out], + outputs=[v_output_name], + perm=[0, 2, 1, 3], + name=f"CrossV_Transpose_{layer_idx}", + ) + ) + + # Add outputs with shape: (batch_size, num_heads, seq_len, head_size) + # Use detected model dtype for output tensor type + k_output = helper.make_tensor_value_info( + k_output_name, + onnx_dtype, + ["batch_size", num_heads, "encoder_sequence_length", head_size], + ) + v_output = helper.make_tensor_value_info( + v_output_name, + onnx_dtype, + ["batch_size", num_heads, "encoder_sequence_length", head_size], + ) + new_outputs.append(k_output) + new_outputs.append(v_output) + + # Add new nodes, 
def add_cross_kv_to_encoder(
    encoder_path: str,
    output_path: str,
    hf_model_id: str,
    hidden_state_output_name: str = "last_hidden_state",
    rename_input_features: bool = True,
    use_external_data: bool = True,
    external_data_name: str | None = None,
    decoder_filename: str = "decoder_with_past_model.onnx",
    generate_genai_config: bool = True,
    provider: str = "cuda",
    verbose: bool = True,
    trust_remote_code: bool = False,
) -> onnx.ModelProto:
    """Add cross-attention KV cache outputs to encoder model.

    This function transforms an Optimum-exported encoder model by adding
    cross-attention Key/Value projection outputs. This is required for
    ONNX Runtime GenAI compatibility where the decoder expects pre-computed
    encoder K/V caches.

    The transformation:
    1. Renames input_features -> audio_features (optional)
    2. Renames last_hidden_state -> encoder_hidden_states
    3. Adds K/V projection weights from HuggingFace model
    4. Adds MatMul -> Reshape -> Transpose subgraph for each layer
    5. Adds outputs: present_key_cross_X, present_value_cross_X
    6. Generates genai_config.json and audio_processor_config.json (optional)

    Args:
        encoder_path: Path to encoder ONNX model.
        output_path: Path to save modified encoder.
        hf_model_id: HuggingFace model ID for loading cross-attention weights.
        hidden_state_output_name: Name of encoder hidden state output.
        rename_input_features: Whether to rename input_features to audio_features.
        use_external_data: Whether to save weights as external data.
        external_data_name: Name for external data file.
        decoder_filename: Filename for decoder model in genai_config.json.
            Default is "decoder_with_past_model.onnx".
        generate_genai_config: Whether to generate genai_config.json.
        provider: Execution provider for genai_config.json ("cuda", "cpu", "NvTensorRtRtx").
        verbose: Whether to print progress messages.
        trust_remote_code: Whether to trust remote code in HuggingFace model.

    Returns:
        Modified encoder model with cross-attention KV cache outputs.

    Example:
        >>> from modelopt.onnx.graph_surgery import add_cross_kv_to_encoder
        >>> model = add_cross_kv_to_encoder(
        ...     encoder_path="encoder_model.onnx",
        ...     output_path="encoder_model_with_kv.onnx",
        ...     hf_model_id="openai/whisper-large-v3-turbo",
        ... )
    """
    # Load cross-attention weights from HuggingFace model.
    # Done first so a bad model ID fails fast, before the (potentially large)
    # ONNX model is read from disk.
    cross_attn_weights, num_heads, head_size, num_layers = _get_cross_attn_weights_from_hf(
        hf_model_id, trust_remote_code=trust_remote_code
    )

    if verbose:
        logger.info(f"Loading encoder model from: {encoder_path}")

    # load_external_data=True pulls any .onnx_data weights into memory so
    # the rewritten model can be re-saved self-contained or re-externalized.
    encoder_model = onnx.load(encoder_path, load_external_data=True)

    # Detect model dtype (FP32/FP16/BF16) so the new K/V projection weights
    # and outputs match the existing graph.
    onnx_dtype, np_dtype = detect_model_dtype(encoder_model)
    if verbose:
        dtype_names = {
            TensorProto.FLOAT: "FP32",
            TensorProto.FLOAT16: "FP16",
            TensorProto.BFLOAT16: "BF16",
        }
        logger.info(f"Detected model dtype: {dtype_names.get(onnx_dtype, 'unknown')}")

    if verbose:
        logger.info("Adding cross KV cache outputs to encoder...")

    modified_encoder = _add_cross_kv_outputs(
        encoder_model,
        cross_attn_weights,
        hidden_state_output_name,
        num_heads,
        head_size,
        num_layers,
        rename_input_features,
        onnx_dtype=onnx_dtype,
        np_dtype=np_dtype,
    )

    # Save model
    if verbose:
        logger.info(f"Saving modified encoder to: {output_path}")

    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if use_external_data:
        if external_data_name is None:
            # Default external data file sits next to the model: X.onnx -> X.onnx_data
            external_data_name = Path(output_path).name.replace(".onnx", ".onnx_data")

        if verbose:
            logger.info(f"  Saving weights to external file: {external_data_name}")

        onnx.save_model(
            modified_encoder,
            output_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=external_data_name,
            size_threshold=1024,
            convert_attribute=False,
        )
    else:
        onnx.save(modified_encoder, output_path)

    if verbose:
        logger.info("Done!")
        logger.info("\nEncoder inputs:")
        for inp in modified_encoder.graph.input:
            logger.info(f"  {inp.name}")
        logger.info("\nEncoder outputs:")
        for output in modified_encoder.graph.output:
            logger.info(f"  {output.name}")

        logger.info("\n" + "=" * 60)
        logger.info("UPDATE genai_config.json with:")
        logger.info("=" * 60)
        logger.info(
            """
"encoder": {
    "filename": ".onnx",
    "inputs": {
        "audio_features": "audio_features"
    },
    "outputs": {
        "encoder_hidden_states": "encoder_hidden_states",
        "cross_present_key_names": "present_key_cross_%d",
        "cross_present_value_names": "present_value_cross_%d"
    }
}
"""
        )

    # Generate config files if output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir:
        # Save audio processor config (no-op if the file already exists,
        # since overwrite=False).
        from .utils.whisper_utils import save_audio_processor_config

        save_audio_processor_config(
            output_dir,
            hf_model_id=hf_model_id,
            overwrite=False,
            trust_remote_code=trust_remote_code,
        )

        # Generate genai_config.json with encoder pointing to this output
        if generate_genai_config:
            from .utils.whisper_utils import save_genai_config as _save_genai_config

            encoder_filename = os.path.basename(output_path)
            _save_genai_config(
                output_dir=output_dir,
                encoder_filename=encoder_filename,
                decoder_filename=decoder_filename,
                hf_model_id=hf_model_id,
                provider=provider,
                trust_remote_code=trust_remote_code,
                overwrite=False,  # Don't overwrite if exists
            )

    return modified_encoder
Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Replace standard attention subgraph with GroupQueryAttention (GQA). + +This module provides functionality to transform ONNX models exported from +HuggingFace/Optimum to use Microsoft's GroupQueryAttention operator, +which is optimized for inference with ONNX Runtime. + +The transformation includes: +1. Converting weights to target dtype (FP16/BF16) +2. Removing unnecessary Cast nodes in layers +3. Adding Gemma-specific casts if needed +4. Computing and adding RoPE cos/sin caches +5. Adding attention mask reformatting subgraph +6. Replacing attention pattern with GQA for all layers +7. Fusing Q/K/V projections into single MatMul +8. 
def _remove_layer_cast_nodes(graph: onnx.GraphProto, verbose: bool = True) -> int:
    """Remove unnecessary /model/layers.{i}/Cast and /Cast_1 nodes.

    Each matching single-input/single-output Cast is bypassed: every
    consumer of its output (and any graph output bearing that name) is
    rewired to read the Cast's input directly, then the node is dropped.

    Args:
        graph: ONNX graph to modify.
        verbose: Whether to print progress.

    Returns:
        Number of Cast nodes removed.
    """
    layer_cast_re = re.compile(r"^/model/layers\.(\d+)/Cast(_1)?$")
    removed = 0

    # Snapshot matches first; we mutate graph.node while iterating the copy.
    candidates = [n for n in graph.node if n.op_type == "Cast" and layer_cast_re.match(n.name)]

    for cast in candidates:
        # Only bypass trivial one-in/one-out casts.
        if len(cast.input) != 1 or len(cast.output) != 1:
            continue

        src = cast.input[0]
        dst = cast.output[0]

        # Rewire every consumer of the cast's output to its input.
        for consumer in graph.node:
            for idx, tensor_name in enumerate(consumer.input):
                if tensor_name == dst:
                    consumer.input[idx] = src

        # A graph output fed by the cast is renamed to the upstream tensor.
        for out in graph.output:
            if out.name == dst:
                out.name = src

        graph.node.remove(cast)
        removed += 1
        if verbose:
            logger.info(f"  Removed: {cast.name}")

    return removed
def _add_gemma_cast_nodes(
    graph: onnx.GraphProto,
    hf_model_id: str,
    io_dtype: str,
    onnx_dtype: int,
    verbose: bool = True,
    trust_remote_code: bool = False,
) -> int:
    """Add Cast to target dtype after layernorm Mul nodes for Gemma models.

    Inserts a Cast after each layer's ``input_layernorm/Mul`` and
    ``post_attention_layernorm/Mul``, plus the final ``/model/norm/Mul``,
    rewiring consumers (or the graph output) to read the Cast's output.

    Args:
        graph: ONNX graph to modify.
        hf_model_id: HuggingFace model ID (used to read ``num_hidden_layers``).
        io_dtype: Target IO dtype string ("float16", "bfloat16", or other -> fp32 suffix).
        onnx_dtype: ONNX dtype constant matching ``io_dtype``.
        verbose: Whether to print progress.
        trust_remote_code: Whether to trust remote code in HuggingFace model config.

    Returns:
        Number of Cast nodes added.
    """
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(hf_model_id, trust_remote_code=trust_remote_code)
    num_layers = config.num_hidden_layers
    cast_nodes_added = 0

    # Loop-invariant: suffix used in every new Cast node's name.
    dtype_suffix = (
        "fp16" if io_dtype == "float16" else ("bf16" if io_dtype == "bfloat16" else "fp32")
    )

    # Build list of target Mul nodes for all layers
    gemma_cast_targets = []
    for layer_id in range(num_layers):
        gemma_cast_targets.append(f"/model/layers.{layer_id}/input_layernorm/Mul")
        gemma_cast_targets.append(f"/model/layers.{layer_id}/post_attention_layernorm/Mul")

    # Also add the final norm/Mul
    gemma_cast_targets.append("/model/norm/Mul")

    for target_mul_name in gemma_cast_targets:
        # Find the target Mul node; silently skip targets absent from the graph.
        target_mul_node = next((n for n in graph.node if n.name == target_mul_name), None)

        if target_mul_node and len(target_mul_node.output) > 0:
            mul_output = target_mul_node.output[0]

            # Check if output is a graph output
            is_graph_output = False
            original_output_name = None
            for out in graph.output:
                if out.name == mul_output:
                    is_graph_output = True
                    original_output_name = mul_output
                    break

            if is_graph_output:
                # Keep the graph-output name on the Cast; rename the Mul's output.
                new_mul_output = f"{target_mul_name}_output_before_cast"
                target_mul_node.output[0] = new_mul_output
                cast_output = original_output_name
                cast_input = new_mul_output
            else:
                # NOTE(review): intermediate tensor name says "fp16" regardless of
                # the actual target dtype; preserved since downstream tooling may
                # match on this exact name — confirm before renaming.
                cast_output = f"{target_mul_name}/Cast_to_fp16/output_0"
                cast_input = mul_output

            # Create Cast node
            cast_node = helper.make_node(
                "Cast",
                inputs=[cast_input],
                outputs=[cast_output],
                name=f"{target_mul_name}/Cast_to_{dtype_suffix}",
                to=onnx_dtype,
            )
            graph.node.append(cast_node)

            if not is_graph_output:
                # Rewire consumers of the old Mul output to read from the Cast,
                # skipping the Cast itself to avoid a self-loop.
                for node in graph.node:
                    if node.name == cast_node.name:
                        continue
                    for i, inp in enumerate(node.input):
                        if inp == mul_output:
                            node.input[i] = cast_output

            # Update graph output type if applicable
            for out in graph.output:
                if out.name == cast_output:
                    out.type.tensor_type.elem_type = onnx_dtype

            # Add value_info so shape inference knows the Cast output layout.
            cast_info = helper.make_tensor_value_info(
                cast_output, onnx_dtype, ["batch_size", "sequence_length", "hidden_size"]
            )
            graph.value_info.append(cast_info)

            cast_nodes_added += 1

    if verbose:
        logger.info(f"  Added {cast_nodes_added} Cast nodes for Gemma model")

    return cast_nodes_added
def _add_bf16_lm_head_cast(graph: onnx.GraphProto, verbose: bool = True) -> bool:
    """Add Cast to FP32 after /lm_head/MatMul for bfloat16 TensorRT compatibility.

    The lm_head MatMul keeps producing BF16 into a renamed intermediate
    tensor; a new Cast node converts it to FP32 and takes over the
    "logits" graph output, whose declared type is updated accordingly.

    Args:
        graph: ONNX graph to modify.
        verbose: Whether to print progress.

    Returns:
        True if cast was added, False otherwise.
    """
    lm_head_matmul_name = "/lm_head/MatMul"
    matmul = next((n for n in graph.node if n.name == lm_head_matmul_name), None)

    # Guard clause: nothing to do if the MatMul is missing or has no output.
    if matmul is None or len(matmul.output) == 0:
        if verbose:
            logger.info(f"  Warning: Could not find {lm_head_matmul_name} node")
        return False

    # Redirect the MatMul to a fresh BF16 intermediate tensor.
    bf16_output = f"{lm_head_matmul_name}_output_bf16"
    matmul.output[0] = bf16_output

    # The Cast now produces the "logits" graph output in FP32.
    graph.node.append(
        helper.make_node(
            "Cast",
            inputs=[bf16_output],
            outputs=["logits"],
            name=f"{lm_head_matmul_name}/Cast_to_fp32",
            to=TensorProto.FLOAT,
        )
    )

    # Declare the BF16 intermediate for shape inference.
    graph.value_info.append(
        helper.make_tensor_value_info(
            bf16_output, TensorProto.BFLOAT16, ["batch_size", "sequence_length", "vocab_size"]
        )
    )

    # The logits graph output is now FP32.
    for out in graph.output:
        if out.name == "logits":
            out.type.tensor_type.elem_type = TensorProto.FLOAT

    if verbose:
        logger.info(f"  Added Cast to FP32 after {lm_head_matmul_name}")
    return True
def _create_attention_mask_subgraph(
    graph: onnx.GraphProto,
    onnx_dtype: int,
) -> tuple[str, str]:
    """Create attention mask reformatting subgraph for GQA.

    Builds the six-node chain GQA needs from ``attention_mask``:
    seqlens_k = Cast(ReduceSum(mask, axis=1) - 1, int32) and
    total_seq_len = Cast(Shape(mask)[1], int32).

    Args:
        graph: ONNX graph to modify.
        onnx_dtype: ONNX dtype constant. NOTE(review): currently unused —
            the mask subgraph works entirely in INT64/INT32; kept for
            signature stability.

    Returns:
        Tuple of (seqlens_k_output, total_seq_len_output) tensor names.
    """
    base = "/model/attn_mask_reformat/attn_mask_subgraph"

    seqlens_k_output = f"{base}/Sub/Cast/output_0"
    total_seq_len_output = f"{base}/Gather/Cast/output_0"

    # Assemble the whole chain, then append in the original order:
    # ReduceSum -> Sub -> Cast(int32), and Shape -> Gather -> Cast(int32).
    nodes = [
        # Sum attention_mask along axis 1 (number of valid tokens per batch row).
        helper.make_node(
            "ReduceSum",
            inputs=["attention_mask", "/model/constants/INT64/[1]"],
            outputs=[f"{base}/ReduceSum/output_0"],
            name=f"{base}/ReduceSum",
            keepdims=0,
        ),
        # seqlens_k = token count - 1.
        helper.make_node(
            "Sub",
            inputs=[f"{base}/ReduceSum/output_0", "/model/constants/INT64/1"],
            outputs=[f"{base}/Sub/output_0"],
            name=f"{base}/Sub",
        ),
        # GQA expects int32 seqlens_k.
        helper.make_node(
            "Cast",
            inputs=[f"{base}/Sub/output_0"],
            outputs=[seqlens_k_output],
            name=f"{base}/Sub/Cast",
            to=TensorProto.INT32,
        ),
        # Shape of attention_mask.
        helper.make_node(
            "Shape",
            inputs=["attention_mask"],
            outputs=[f"{base}/Shape/output_0"],
            name=f"{base}/Shape",
        ),
        # Pick dimension 1 (the sequence-length axis).
        helper.make_node(
            "Gather",
            inputs=[f"{base}/Shape/output_0", "/model/constants/INT64/1"],
            outputs=[f"{base}/Gather/output_0"],
            name=f"{base}/Gather",
            axis=0,
        ),
        # GQA expects int32 total_sequence_length.
        helper.make_node(
            "Cast",
            inputs=[f"{base}/Gather/output_0"],
            outputs=[total_seq_len_output],
            name=f"{base}/Gather/Cast",
            to=TensorProto.INT32,
        ),
    ]
    graph.node.extend(nodes)

    # Declare the two subgraph outputs: per-batch vector and a scalar.
    graph.value_info.extend(
        [
            helper.make_tensor_value_info(seqlens_k_output, TensorProto.INT32, ["batch_size"]),
            helper.make_tensor_value_info(total_seq_len_output, TensorProto.INT32, []),
        ]
    )

    return seqlens_k_output, total_seq_len_output
def _fuse_qkv_and_create_gqa(
    graph: onnx.GraphProto,
    layer_id: int,
    num_attention_heads: int,
    num_kv_heads: int,
    head_dim: int,
    hidden_size: int,
    seqlens_k_output: str,
    total_seq_len_output: str,
    onnx_dtype: int,
    attn_prefix: str = "self_attn",
    pack_qkv: bool = False,
    verbose: bool = True,
) -> tuple[list[onnx.NodeProto], list[onnx.NodeProto], list[onnx.NodeProto]]:
    """Fuse Q/K/V MatMuls and create GQA node for a single layer.

    Handles three model shapes:
    1. Combined qkv_proj MatMul already present (e.g. TinyLlama) -> packed GQA.
    2. Quantized models (Q/K/V weights behind DequantizeLinear, so not plain
       initializers) -> keep separate projections; optionally Concat-pack them.
    3. Plain FP models -> concatenate Q/K/V weight (and bias) initializers into
       a single fused qkv_proj MatMul (+Add) feeding a packed GQA.

    Args:
        graph: ONNX graph to modify.
        layer_id: Layer index.
        num_attention_heads: Number of attention heads.
        num_kv_heads: Number of key-value heads.
        head_dim: Dimension per head.
        hidden_size: Hidden dimension.
        seqlens_k_output: Name of seqlens_k tensor.
        total_seq_len_output: Name of total_seq_len tensor.
        onnx_dtype: ONNX dtype constant.
        attn_prefix: Attention namespace prefix ("self_attn" or "attn").
        pack_qkv: For the quantized path only: when True, Concat Q/K/V outputs
            (axis=-1) and feed GQA a single packed input; when False, pass
            them to GQA as separate query/key/value inputs.
        verbose: Whether to print progress.

    Returns:
        Tuple of (qkv_matmul_nodes, gqa_nodes, nodes_to_remove). The caller is
        responsible for inserting qkv_matmul_nodes/gqa_nodes into the graph and
        removing nodes_to_remove.
    """
    qkv_matmul_nodes = []
    gqa_nodes = []
    qkv_nodes_to_remove = []

    # Helper to find node by name in graph
    def _find_node(name: str) -> onnx.NodeProto | None:
        for node in graph.node:
            if node.name == name:
                return node
        return None

    # Helper to find node by name pattern (partial match), optionally filtered by op_type
    def _find_node_by_pattern(pattern: str, op_type: str | None = None) -> onnx.NodeProto | None:
        for node in graph.node:
            if pattern in node.name:
                if op_type is None or node.op_type == op_type:
                    return node
        return None

    # Try separate Q, K, V MatMul nodes first
    q_matmul_name = f"/model/layers.{layer_id}/{attn_prefix}/q_proj/MatMul"
    k_matmul_name = f"/model/layers.{layer_id}/{attn_prefix}/k_proj/MatMul"
    v_matmul_name = f"/model/layers.{layer_id}/{attn_prefix}/v_proj/MatMul"

    q_matmul = _find_node(q_matmul_name)
    k_matmul = _find_node(k_matmul_name)
    v_matmul = _find_node(v_matmul_name)

    if not q_matmul or not k_matmul or not v_matmul:
        # Check for combined qkv_proj pattern (e.g. TinyLlama)
        qkv_proj_pattern = f"/model/layers.{layer_id}/{attn_prefix}/qkv_proj/MatMul"
        qkv_matmul = _find_node_by_pattern(qkv_proj_pattern, op_type="MatMul")

        if qkv_matmul is not None:
            if verbose:
                logger.info(f"  Layer {layer_id}: Combined qkv_proj detected: {qkv_matmul.name}")
                logger.info("    Using packed GQA mode (QKV already combined)")

            qkv_output = qkv_matmul.output[0]

            # Check for bias Add after combined qkv_proj MatMul
            qkv_add = _find_node(f"/model/layers.{layer_id}/{attn_prefix}/qkv_proj/Add")
            if qkv_add is not None:
                qkv_output = qkv_add.output[0]
                if verbose:
                    logger.info(f"    Found qkv_proj/Add, using Add output: {qkv_output}")

            # Create GQA node with packed QKV input
            past_key = f"past_key_values.{layer_id}.key"
            past_value = f"past_key_values.{layer_id}.value"
            present_key = f"present.{layer_id}.key"
            present_value = f"present.{layer_id}.value"
            gqa_output = f"/model/layers.{layer_id}/{attn_prefix}/GQA/output_0"

            # Packed mode: key/value inputs are empty strings per the
            # com.microsoft.GroupQueryAttention contract.
            gqa_node = helper.make_node(
                "GroupQueryAttention",
                inputs=[
                    qkv_output,  # packed QKV
                    "",  # key (empty for packed mode)
                    "",  # value (empty for packed mode)
                    past_key,
                    past_value,
                    seqlens_k_output,
                    total_seq_len_output,
                    "cos_cache",
                    "sin_cache",
                    "",  # position_ids
                    "",  # attention_bias
                ],
                outputs=[
                    gqa_output,
                    present_key,
                    present_value,
                ],
                name=f"/model/layers.{layer_id}/{attn_prefix}/GQA",
                domain="com.microsoft",
                num_heads=num_attention_heads,
                kv_num_heads=num_kv_heads,
                scale=1.0 / (head_dim**0.5),
                do_rotary=1,
                rotary_interleaved=0,
                local_window_size=-1,
            )
            gqa_nodes.append(gqa_node)

            # Add value_info for GQA output
            gqa_output_info = helper.make_tensor_value_info(
                gqa_output, onnx_dtype, ["batch_size", "sequence_length", hidden_size]
            )
            graph.value_info.append(gqa_output_info)

            if verbose:
                logger.info(f"    Created GQA with packed input: {qkv_output}")

            return qkv_matmul_nodes, gqa_nodes, qkv_nodes_to_remove

        if verbose:
            logger.info(f"  Warning: Could not find Q/K/V MatMul nodes for layer {layer_id}")
        return [], [], []

    # Narrow types for the remainder of the function (all three exist here).
    assert q_matmul is not None
    assert k_matmul is not None
    assert v_matmul is not None

    # Get weight initializer names
    q_weight_name = q_matmul.input[1]
    k_weight_name = k_matmul.input[1]
    v_weight_name = v_matmul.input[1]

    # Get the input to Q MatMul
    qkv_input = q_matmul.input[0]

    # Find weight initializers
    q_weight = find_initializer(graph, q_weight_name)
    k_weight = find_initializer(graph, k_weight_name)
    v_weight = find_initializer(graph, v_weight_name)

    if not all([q_weight, k_weight, v_weight]):
        # Quantized model path: keep separate Q/K/V projections
        # (weights come from DequantizeLinear outputs, not initializers,
        # so they cannot be concatenated offline).
        if verbose:
            logger.info(
                f"  Layer {layer_id}: Quantized model detected (weights behind DequantizeLinear)"
            )

        q_output = q_matmul.output[0]
        k_output = k_matmul.output[0]
        v_output = v_matmul.output[0]

        # Check for bias Add nodes after Q/K/V MatMuls
        q_add = _find_node(f"/model/layers.{layer_id}/{attn_prefix}/q_proj/Add")
        k_add = _find_node(f"/model/layers.{layer_id}/{attn_prefix}/k_proj/Add")
        v_add = _find_node(f"/model/layers.{layer_id}/{attn_prefix}/v_proj/Add")
        if q_add and k_add and v_add:
            # Use Add outputs (after bias) instead of raw MatMul outputs
            q_output = q_add.output[0]
            k_output = k_add.output[0]
            v_output = v_add.output[0]
            if verbose:
                logger.info("    Found bias Add nodes, using Add outputs for Q/K/V")

        past_key = f"past_key_values.{layer_id}.key"
        past_value = f"past_key_values.{layer_id}.value"
        present_key = f"present.{layer_id}.key"
        present_value = f"present.{layer_id}.value"
        gqa_output = f"/model/layers.{layer_id}/{attn_prefix}/GQA/output_0"

        if pack_qkv:
            # Packed QKV mode: Concat Q, K, V horizontally (axis=-1) then feed as single input
            if verbose:
                logger.info("    Packing Q/K/V with Concat (axis=-1) for packed GQA mode")

            concat_output = f"/model/layers.{layer_id}/{attn_prefix}/qkv_concat/output_0"

            concat_node = helper.make_node(
                "Concat",
                inputs=[q_output, k_output, v_output],
                outputs=[concat_output],
                name=f"/model/layers.{layer_id}/{attn_prefix}/qkv_concat",
                axis=-1,
            )
            qkv_matmul_nodes.append(concat_node)

            # Add value_info for concat output
            # Q dim = num_heads * head_dim, K dim = kv_heads * head_dim, V dim = kv_heads * head_dim
            qkv_dim = (num_attention_heads + 2 * num_kv_heads) * head_dim
            concat_info = helper.make_tensor_value_info(
                concat_output, onnx_dtype, ["batch_size", "sequence_length", qkv_dim]
            )
            graph.value_info.append(concat_info)

            gqa_node = helper.make_node(
                "GroupQueryAttention",
                inputs=[
                    concat_output,  # packed QKV
                    "",  # key (empty for packed mode)
                    "",  # value (empty for packed mode)
                    past_key,
                    past_value,
                    seqlens_k_output,
                    total_seq_len_output,
                    "cos_cache",
                    "sin_cache",
                    "",  # position_ids
                    "",  # attention_bias
                ],
                outputs=[
                    gqa_output,
                    present_key,
                    present_value,
                ],
                name=f"/model/layers.{layer_id}/{attn_prefix}/GQA",
                domain="com.microsoft",
                num_heads=num_attention_heads,
                kv_num_heads=num_kv_heads,
                scale=1.0 / (head_dim**0.5),
                do_rotary=1,
                rotary_interleaved=0,
                local_window_size=-1,
            )
        else:
            # Unpacked mode: separate Q/K/V inputs to GQA
            if verbose:
                logger.info(
                    "    Keeping separate Q/K/V projections, creating GQA with unpacked inputs"
                )

            gqa_node = helper.make_node(
                "GroupQueryAttention",
                inputs=[
                    q_output,  # query
                    k_output,  # key
                    v_output,  # value
                    past_key,
                    past_value,
                    seqlens_k_output,
                    total_seq_len_output,
                    "cos_cache",
                    "sin_cache",
                    "",  # position_ids
                    "",  # attention_bias
                ],
                outputs=[
                    gqa_output,
                    present_key,
                    present_value,
                ],
                name=f"/model/layers.{layer_id}/{attn_prefix}/GQA",
                domain="com.microsoft",
                num_heads=num_attention_heads,
                kv_num_heads=num_kv_heads,
                scale=1.0 / (head_dim**0.5),
                do_rotary=1,
                rotary_interleaved=0,
                local_window_size=-1,
            )

        gqa_nodes.append(gqa_node)

        # Add value_info for GQA output
        gqa_output_info = helper.make_tensor_value_info(
            gqa_output, onnx_dtype, ["batch_size", "sequence_length", hidden_size]
        )
        graph.value_info.append(gqa_output_info)

        # Don't remove any projection nodes, don't create fused MatMul
        return qkv_matmul_nodes, gqa_nodes, qkv_nodes_to_remove

    # ---- Non-quantized path: offline weight fusion. ----
    # Convert weights to numpy arrays (second return flags bfloat16 storage).
    q_arr, q_bf16 = initializer_to_array(q_weight)
    k_arr, _k_bf16 = initializer_to_array(k_weight)
    v_arr, _v_bf16 = initializer_to_array(v_weight)
    is_bfloat16 = q_bf16 == "bfloat16"

    # Concatenate weights along the output (column) axis.
    qkv_weight_arr = np.concatenate([q_arr, k_arr, v_arr], axis=1)

    # Create fused QKV weight initializer
    qkv_weight_name = f"/model/layers.{layer_id}/{attn_prefix}/qkv_proj/weight"
    qkv_weight_tensor = array_to_initializer(qkv_weight_arr, qkv_weight_name, is_bfloat16)
    graph.initializer.append(qkv_weight_tensor)

    # Create fused QKV MatMul node
    qkv_matmul_name = f"/model/layers.{layer_id}/{attn_prefix}/qkv_proj/MatMul"
    qkv_matmul_output = f"{qkv_matmul_name}_output_0"

    qkv_matmul_node = helper.make_node(
        "MatMul",
        inputs=[qkv_input, qkv_weight_name],
        outputs=[qkv_matmul_output],
        name=qkv_matmul_name,
    )
    qkv_matmul_nodes.append(qkv_matmul_node)

    # Add value_info for fused QKV output
    packed_qkv_dim = (num_attention_heads + 2 * num_kv_heads) * head_dim
    qkv_output_info = helper.make_tensor_value_info(
        qkv_matmul_output, onnx_dtype, ["batch_size", "sequence_length", packed_qkv_dim]
    )
    graph.value_info.append(qkv_output_info)

    # Mark old Q/K/V MatMul nodes for removal
    qkv_nodes_to_remove.extend([q_matmul, k_matmul, v_matmul])

    # Remove old weight initializers
    graph.initializer.remove(q_weight)
    graph.initializer.remove(k_weight)
    graph.initializer.remove(v_weight)

    # Check for bias Add nodes (e.g. Qwen models)
    q_add_name = f"/model/layers.{layer_id}/{attn_prefix}/q_proj/Add"
    k_add_name = f"/model/layers.{layer_id}/{attn_prefix}/k_proj/Add"
    v_add_name = f"/model/layers.{layer_id}/{attn_prefix}/v_proj/Add"

    q_add = _find_node(q_add_name)
    k_add = _find_node(k_add_name)
    v_add = _find_node(v_add_name)

    gqa_input = qkv_matmul_output

    if q_add and k_add and v_add:
        # Fuse bias Add operations
        if verbose:
            logger.info(f"  Layer {layer_id}: Found bias Add nodes, fusing biases...")

        # Get bias initializer names: the bias may be either Add input,
        # so pick whichever one resolves to an initializer.
        q_bias_name = q_add.input[1] if find_initializer(graph, q_add.input[1]) else q_add.input[0]
        k_bias_name = k_add.input[1] if find_initializer(graph, k_add.input[1]) else k_add.input[0]
        v_bias_name = v_add.input[1] if find_initializer(graph, v_add.input[1]) else v_add.input[0]

        q_bias = find_initializer(graph, q_bias_name)
        k_bias = find_initializer(graph, k_bias_name)
        v_bias = find_initializer(graph, v_bias_name)

        if all([q_bias, k_bias, v_bias]):
            # Concatenate biases
            q_bias_arr, qb_bf16 = initializer_to_array(q_bias)
            k_bias_arr, _ = initializer_to_array(k_bias)
            v_bias_arr, _ = initializer_to_array(v_bias)
            bias_is_bfloat16 = qb_bf16 == "bfloat16"

            qkv_bias_arr = np.concatenate([q_bias_arr, k_bias_arr, v_bias_arr], axis=0)

            # Create fused bias initializer
            qkv_bias_name = f"/model/layers.{layer_id}/{attn_prefix}/qkv_proj/bias"
            qkv_bias_tensor = array_to_initializer(qkv_bias_arr, qkv_bias_name, bias_is_bfloat16)
            graph.initializer.append(qkv_bias_tensor)

            # Create fused Add node
            qkv_add_name = f"/model/layers.{layer_id}/{attn_prefix}/qkv_proj/Add"
            qkv_add_output = f"{qkv_add_name}_output_0"

            qkv_add_node = helper.make_node(
                "Add",
                inputs=[qkv_matmul_output, qkv_bias_name],
                outputs=[qkv_add_output],
                name=qkv_add_name,
            )
            graph.node.append(qkv_add_node)

            # Add value_info
            qkv_add_info = helper.make_tensor_value_info(
                qkv_add_output, onnx_dtype, ["batch_size", "sequence_length", packed_qkv_dim]
            )
            graph.value_info.append(qkv_add_info)

            # Update GQA input
            gqa_input = qkv_add_output

            # Mark old Add nodes for removal
            qkv_nodes_to_remove.extend([q_add, k_add, v_add])

            # Remove old bias initializers
            graph.initializer.remove(q_bias)
            graph.initializer.remove(k_bias)
            graph.initializer.remove(v_bias)

            if verbose:
                logger.info(
                    f"  Layer {layer_id}: Fused biases {q_bias_arr.shape} + "
                    f"{k_bias_arr.shape} + {v_bias_arr.shape} -> {qkv_bias_arr.shape}"
                )

    # Create GQA node
    past_key = f"past_key_values.{layer_id}.key"
    past_value = f"past_key_values.{layer_id}.value"
    present_key = f"present.{layer_id}.key"
    present_value = f"present.{layer_id}.value"
    gqa_output = f"/model/layers.{layer_id}/{attn_prefix}/GQA/output_0"

    gqa_node = helper.make_node(
        "GroupQueryAttention",
        inputs=[
            gqa_input,
            "",  # key (empty for packed mode)
            "",  # value (empty for packed mode)
            past_key,
            past_value,
            seqlens_k_output,
            total_seq_len_output,
            "cos_cache",
            "sin_cache",
            "",  # position_ids
            "",  # attention_bias
        ],
        outputs=[
            gqa_output,
            present_key,
            present_value,
        ],
        name=f"/model/layers.{layer_id}/{attn_prefix}/GQA",
        domain="com.microsoft",
        num_heads=num_attention_heads,
        kv_num_heads=num_kv_heads,
        scale=1.0 / (head_dim**0.5),
        do_rotary=1,
        rotary_interleaved=0,
        local_window_size=-1,
    )
    gqa_nodes.append(gqa_node)

    # Add value_info for GQA output
    gqa_output_info = helper.make_tensor_value_info(
        gqa_output, onnx_dtype, ["batch_size", "sequence_length", hidden_size]
    )
    graph.value_info.append(gqa_output_info)

    if verbose:
        logger.info(
            f"  Layer {layer_id}: Fused Q/K/V weights {q_arr.shape} + "
            f"{k_arr.shape} + {v_arr.shape} -> {qkv_weight_arr.shape}"
        )

    return qkv_matmul_nodes, gqa_nodes, qkv_nodes_to_remove
use_external_data: bool = True, + external_data_name: str | None = None, + ir_version: int | None = None, + pack_qkv: bool = False, + verbose: bool = True, + trust_remote_code: bool = False, +) -> onnx.ModelProto: + """Replace attention subgraphs with GroupQueryAttention (GQA) in an ONNX model. + + This function transforms an ONNX model exported from HuggingFace/Optimum + to use Microsoft's GroupQueryAttention operator, which is optimized for + inference with ONNX Runtime. + + The transformation includes: + - Converting weights to target dtype (FP16/BF16) [non-quantized models only] + - Adding RoPE cos/sin caches + - Replacing attention patterns with GQA for all layers + - Fusing Q/K/V projections into single MatMul [non-quantized models only] + - Concatenating Q/K/V outputs for GQA [quantized models only] + - Adding past/present KV cache inputs/outputs + + Args: + model_path: Path to input ONNX model. + output_path: Path to save modified model. + hf_model_id: HuggingFace model ID for config. + max_seq_len: Maximum sequence length for caches. + io_dtype: Data type for I/O tensors ("float16", "float32", or "bfloat16"). + If the model has FP16 initializers and "bfloat16" is specified, + they are automatically converted to BF16. + use_external_data: Save weights as external data file. + external_data_name: Name for external data file (default: model.onnx_data). + ir_version: If specified, set the ONNX IR version to this value. Useful for + compatibility with older ONNX Runtime versions (e.g., set to 9 for ORT 1.16). + verbose: Whether to print progress messages. + trust_remote_code: Whether to trust remote code in HuggingFace model config. + + Returns: + Modified ONNX model. + + Example: + >>> from modelopt.onnx.graph_surgery import replace_attention_with_gqa + >>> model = replace_attention_with_gqa( + ... model_path="model_fp16.onnx", + ... output_path="model_gqa.onnx", + ... hf_model_id="meta-llama/Llama-2-7b-hf", + ... max_seq_len=4096, + ... 
io_dtype="float16", + ... ) + """ + if verbose: + logger.info(f"Loading model from: {model_path}") + model = onnx.load(model_path) + graph = model.graph + + onnx_dtype = get_onnx_dtype(io_dtype) + + # Early detection: check if model is quantized (has DequantizeLinear nodes) + has_dequantize = any(n.op_type == "DequantizeLinear" for n in graph.node) + if has_dequantize and verbose: + logger.info("Quantized model detected (DequantizeLinear nodes found)") + logger.info(" Skipping dtype conversion and Cast removal to preserve quantization graph") + + if not has_dequantize: + # Step 0: Convert float32 weights to target dtype (non-quantized models only) + if verbose: + logger.info(f"\nConverting float32 initializers to {io_dtype}...") + converted_count = convert_initializers_to_dtype(graph, io_dtype) + if verbose: + logger.info(f" Converted {converted_count} initializers to {io_dtype}") + + # Step 0.1: If target is bfloat16, also convert all FP16 elements to BF16 + if io_dtype == "bfloat16": + if verbose: + logger.info("\nConverting FP16 elements to BF16 (io_dtype=bfloat16)...") + convert_model_fp16_to_bf16(graph, verbose=verbose) + + # Step 0.5: Remove unnecessary Cast nodes in layers + if verbose: + logger.info("\nRemoving unnecessary /model/layers.{i}/Cast and /Cast_1 nodes...") + cast_nodes_removed = _remove_layer_cast_nodes(graph, verbose) + if verbose: + logger.info(f" Total Cast nodes removed: {cast_nodes_removed}") + + if not has_dequantize: + # Step 0.6: Gemma-specific casts + is_gemma = "gemma" in hf_model_id.lower() + if is_gemma: + if verbose: + logger.info( + "\nGemma model detected - adding Cast to fp16 after layernorm Mul nodes..." + ) + _add_gemma_cast_nodes( + graph, hf_model_id, io_dtype, onnx_dtype, verbose, trust_remote_code + ) + + # Step 0.7: BF16 lm_head cast for TensorRT compatibility + if io_dtype == "bfloat16": + if verbose: + logger.info( + "\nAdding Cast to FP32 after /lm_head/MatMul for bfloat16 TensorRT compatibility..." 
+ ) + _add_bf16_lm_head_cast(graph, verbose) + + # Get config and compute caches + if verbose: + logger.info(f"\nComputing RoPE caches from: {hf_model_id}") + cos_cache, sin_cache, config = get_rope_caches( + hf_model_id, max_seq_len, io_dtype, trust_remote_code=trust_remote_code + ) + + num_layers = config.num_hidden_layers + num_attention_heads = config.num_attention_heads + num_kv_heads = getattr(config, "num_key_value_heads", num_attention_heads) + head_dim = config.hidden_size // num_attention_heads + hidden_size = config.hidden_size + + # Auto-detect attention namespace: "self_attn" vs "attn" + attn_prefix = "self_attn" + for node in graph.node: + if "/layers.0/attn/" in node.name and "/self_attn/" not in node.name: + attn_prefix = "attn" + break + + # Auto-detect combined QKV pattern + has_combined_qkv = any(f"/layers.0/{attn_prefix}/qkv_proj/" in node.name for node in graph.node) + + if verbose: + logger.info("Model config:") + logger.info(f" num_layers: {num_layers}") + logger.info(f" num_attention_heads: {num_attention_heads}") + logger.info(f" num_kv_heads: {num_kv_heads}") + logger.info(f" head_dim: {head_dim}") + logger.info(f" hidden_size: {hidden_size}") + logger.info(f" cos_cache shape: {cos_cache.shape}") + logger.info(f" sin_cache shape: {sin_cache.shape}") + logger.info(f" attn_prefix: {attn_prefix}") + logger.info(f" has_combined_qkv: {has_combined_qkv}") + + # Step 1: Add cos/sin cache initializers + if verbose: + logger.info("\nAdding cos/sin cache initializers...") + add_initializer(graph, "cos_cache", cos_cache, onnx_dtype) + add_initializer(graph, "sin_cache", sin_cache, onnx_dtype) + + # Add value_info for cos/sin caches + cos_cache_info = helper.make_tensor_value_info("cos_cache", onnx_dtype, list(cos_cache.shape)) + sin_cache_info = helper.make_tensor_value_info("sin_cache", onnx_dtype, list(sin_cache.shape)) + graph.value_info.extend([cos_cache_info, sin_cache_info]) + + # Step 2: Add constant initializers + if verbose: + 
logger.info("Adding constant initializers...") + add_initializer( + graph, "/model/constants/INT64/1", np.array(1, dtype=np.int64), TensorProto.INT64 + ) + add_initializer( + graph, "/model/constants/INT64/[1]", np.array([1], dtype=np.int64), TensorProto.INT64 + ) + + # Step 2.5: Rename suffixed I/O names to standard names + # Some models have suffixed names like input_ids_318, attention_mask_337 etc. + # Rename them to standard names for compatibility with genai_config. + io_renames = {} + for inp in graph.input: + if inp.name.startswith("input_ids") and inp.name != "input_ids": + io_renames[inp.name] = "input_ids" + elif inp.name.startswith("attention_mask") and inp.name != "attention_mask": + io_renames[inp.name] = "attention_mask" + for out in graph.output: + if out.name.startswith("logits") and out.name != "logits": + io_renames[out.name] = "logits" + + if io_renames and verbose: + logger.info("Renaming suffixed I/O names to standard names...") + + for old_name, new_name in io_renames.items(): + # Rename graph input or output + for inp in graph.input: + if inp.name == old_name: + inp.name = new_name + if verbose: + logger.info(f" Renamed input: {old_name} -> {new_name}") + break + for out in graph.output: + if out.name == old_name: + out.name = new_name + if verbose: + logger.info(f" Renamed output: {old_name} -> {new_name}") + break + # Rename in all node inputs/outputs + for node in graph.node: + for i, inp in enumerate(node.input): + if inp == old_name: + node.input[i] = new_name + for i, out in enumerate(node.output): + if out == old_name: + node.output[i] = new_name + # Rename in value_info + for vi in graph.value_info: + if vi.name == old_name: + vi.name = new_name + + # Step 3: Ensure attention_mask input exists with dynamic shape + existing_inputs = [inp.name for inp in graph.input] + if "attention_mask" not in existing_inputs: + if verbose: + logger.info("Adding attention_mask input...") + attn_mask_input = helper.make_tensor_value_info( + 
"attention_mask", TensorProto.INT64, ["batch_size", "total_sequence_length"] + ) + graph.input.append(attn_mask_input) + else: + # Update existing attention_mask to have dynamic shape + for inp in graph.input: + if inp.name == "attention_mask": + inp.CopyFrom( + helper.make_tensor_value_info( + "attention_mask", TensorProto.INT64, ["batch_size", "total_sequence_length"] + ) + ) + if verbose: + logger.info("Updated attention_mask input to dynamic shape") + break + + # Also ensure input_ids has dynamic shape + for inp in graph.input: + if inp.name == "input_ids": + inp.CopyFrom( + helper.make_tensor_value_info( + "input_ids", TensorProto.INT64, ["batch_size", "sequence_length"] + ) + ) + if verbose: + logger.info("Updated input_ids input to dynamic shape") + break + + # Step 4: Remove existing past_key_values inputs and present outputs + if verbose: + logger.info("Handling past_key_values inputs and present outputs...") + + # Remove existing past_key_values inputs + inputs_to_remove = [] + for inp in graph.input: + if "past_key_values" in inp.name or "past_key" in inp.name or "past_value" in inp.name: + inputs_to_remove.append(inp) + if verbose: + logger.info(f" Removing existing input: {inp.name}") + for inp in inputs_to_remove: + graph.input.remove(inp) + + # Remove existing present outputs + outputs_to_remove = [] + for out in graph.output: + if "present" in out.name: + outputs_to_remove.append(out) + if verbose: + logger.info(f" Removing existing output: {out.name}") + for out in outputs_to_remove: + graph.output.remove(out) + + # Clean up value_info and initializers + value_info_to_remove = [ + vi for vi in graph.value_info if "past_key_values" in vi.name or "present" in vi.name + ] + for vi in value_info_to_remove: + graph.value_info.remove(vi) + + initializers_to_remove = [ + init + for init in graph.initializer + if "past_key_values" in init.name or "present" in init.name + ] + for init in initializers_to_remove: + graph.initializer.remove(init) + + if 
verbose: + logger.info(f" Removed {len(inputs_to_remove)} existing past_key_values inputs") + logger.info(f" Removed {len(outputs_to_remove)} existing present outputs") + logger.info(" Adding new past_key_values inputs and present outputs...") + + # Add new past_key_values inputs and present outputs + kv_cache_shape = ["batch_size", num_kv_heads, "past_sequence_length", head_dim] + present_shape = ["batch_size", num_kv_heads, "total_sequence_length", head_dim] + + for layer_id in range(num_layers): + # Past key/value inputs + past_key_name = f"past_key_values.{layer_id}.key" + past_value_name = f"past_key_values.{layer_id}.value" + graph.input.append(helper.make_tensor_value_info(past_key_name, onnx_dtype, kv_cache_shape)) + graph.input.append( + helper.make_tensor_value_info(past_value_name, onnx_dtype, kv_cache_shape) + ) + + # Present key/value outputs + present_key_name = f"present.{layer_id}.key" + present_value_name = f"present.{layer_id}.value" + graph.output.append( + helper.make_tensor_value_info(present_key_name, onnx_dtype, present_shape) + ) + graph.output.append( + helper.make_tensor_value_info(present_value_name, onnx_dtype, present_shape) + ) + + # Step 5: Create attention mask reformatting subgraph + if verbose: + logger.info("Creating attention mask reformatting subgraph...") + seqlens_k_output, total_seq_len_output = _create_attention_mask_subgraph(graph, onnx_dtype) + + # Step 6: Process Q/K/V projections and create GQA nodes + all_qkv_matmul_nodes = [] + all_gqa_nodes = [] + all_qkv_nodes_to_remove = [] + + # Fuse Q/K/V weights (or keep separate for quantized models) + if verbose: + logger.info("Processing Q/K/V projections and creating GQA nodes for each layer...") + + for layer_id in range(num_layers): + qkv_matmul_nodes, gqa_nodes, qkv_nodes_to_remove = _fuse_qkv_and_create_gqa( + graph, + layer_id, + num_attention_heads, + num_kv_heads, + head_dim, + hidden_size, + seqlens_k_output, + total_seq_len_output, + onnx_dtype, + 
attn_prefix=attn_prefix, + pack_qkv=pack_qkv, + verbose=verbose, + ) + all_qkv_matmul_nodes.extend(qkv_matmul_nodes) + all_gqa_nodes.extend(gqa_nodes) + all_qkv_nodes_to_remove.extend(qkv_nodes_to_remove) + + # Detect if model is quantized: use the early detection flag (has_dequantize) + # We can't rely on all_qkv_matmul_nodes being empty because pack_qkv adds Concat nodes there + is_quantized = has_dequantize + + # Step 7: Identify attention nodes to remove + if verbose: + logger.info( + f"\nIdentifying attention subgraphs to replace " + f"(quantized={is_quantized}, combined_qkv={has_combined_qkv}, " + f"attn_prefix={attn_prefix})..." + ) + + nodes_to_remove = [] + for layer_id in range(num_layers): + layer_prefix = f"/model/layers.{layer_id}/{attn_prefix}" + + for node in graph.node: + if layer_prefix in node.name: + if has_combined_qkv: + # Combined QKV model: keep qkv_proj/ and o_proj/ chains + if any(x in node.name for x in ["/qkv_proj/", "/o_proj/"]): + continue + elif is_quantized: + # Quantized model: keep entire q_proj/, k_proj/, v_proj/, o_proj/ chains + # Also keep AWQ pre_quant_scale nodes (activation scaling before o_proj) + if any( + x in node.name + for x in ["/q_proj/", "/k_proj/", "/v_proj/", "/o_proj/", "pre_quant_scale"] + ): + continue + # Non-quantized: only keep the 4 projection MatMul nodes + elif any( + x in node.name + for x in [ + "/q_proj/MatMul", + "/k_proj/MatMul", + "/v_proj/MatMul", + "/o_proj/MatMul", + ] + ): + continue + nodes_to_remove.append(node) + + if verbose: + logger.info(f"Found {len(nodes_to_remove)} nodes to remove") + + # Remove old Q/K/V nodes + if verbose: + logger.info(f"Removing {len(all_qkv_nodes_to_remove)} old Q/K/V MatMul nodes...") + for node in all_qkv_nodes_to_remove: + with contextlib.suppress(ValueError): + graph.node.remove(node) + + # Add fused QKV MatMul nodes + if verbose: + logger.info(f"Adding {len(all_qkv_matmul_nodes)} fused QKV MatMul nodes...") + for node in all_qkv_matmul_nodes: + 
graph.node.append(node) + + # Step 8: Remove old attention nodes + if verbose: + logger.info("Removing old attention nodes...") + for node in nodes_to_remove: + with contextlib.suppress(ValueError): + graph.node.remove(node) + + # Step 9: Add GQA nodes + if verbose: + logger.info("Adding GQA nodes...") + for gqa_node in all_gqa_nodes: + graph.node.append(gqa_node) + + # Step 10: Reconnect o_proj to GQA output + if verbose: + logger.info( + f"Reconnecting o_proj inputs to GQA outputs " + f"(quantized={is_quantized}, combined_qkv={has_combined_qkv})..." + ) + + for layer_id in range(num_layers): + gqa_output = f"/model/layers.{layer_id}/{attn_prefix}/GQA/output_0" + + if has_combined_qkv: + # Combined QKV model: find o_proj MatMul by pattern match + o_proj_pattern = f"/model/layers.{layer_id}/{attn_prefix}/o_proj/MatMul" + connected = False + for node in graph.node: + if o_proj_pattern in node.name and node.op_type == "MatMul": + node.input[0] = gqa_output + if verbose: + logger.info( + f" Layer {layer_id}: Connected {node.name} input[0] to {gqa_output}" + ) + connected = True + break + if not connected and verbose: + logger.info(f" Warning: Could not find o_proj MatMul for layer {layer_id}") + elif is_quantized: + # Quantized model: connect GQA output to the first node in the o_proj quantization chain. 
+ # AWQ pattern: [attn output] -> pre_quant_scale_mul -> o_proj/MatMul + # INT8 pattern: [attn output] -> o_proj/input_quantizer/Mul -> Cast -> o_proj/MatMul + # FP4 pattern: [attn output] -> o_proj/input_quantizer/Cast -> + # TRT_FP4DynamicQuantize -> DQL -> DQL_1 -> Cast(_f16) -> o_proj/MatMul + connected = False + + # Try AWQ pattern first: pre_quant_scale_mul before o_proj + layer_prefix_local = f"/model/layers.{layer_id}/{attn_prefix}" + for node in graph.node: + if ( + layer_prefix_local in node.name + and "pre_quant_scale" in node.name + and node.op_type == "Mul" + ): + node.input[0] = gqa_output + if verbose: + logger.info( + f" Layer {layer_id}: Connected {node.name} input[0] to {gqa_output} (AWQ)" + ) + connected = True + break + + # Try INT8 pattern: o_proj/input_quantizer/Mul + if not connected: + o_proj_quant_mul = ( + f"/model/layers.{layer_id}/{attn_prefix}/o_proj/input_quantizer/Mul" + ) + for node in graph.node: + if node.name == o_proj_quant_mul: + node.input[0] = gqa_output + if verbose: + logger.info( + f" Layer {layer_id}: Connected {o_proj_quant_mul} input[0] to {gqa_output}" + ) + connected = True + break + + # Try FP4 pattern: o_proj/input_quantizer/Cast + if not connected: + o_proj_quant_cast = ( + f"/model/layers.{layer_id}/{attn_prefix}/o_proj/input_quantizer/Cast" + ) + for node in graph.node: + if node.name == o_proj_quant_cast: + node.input[0] = gqa_output + if verbose: + logger.info( + f" Layer {layer_id}: Connected {o_proj_quant_cast} input[0] to {gqa_output}" + ) + connected = True + break + + # Fallback: connect directly to o_proj/MatMul + if not connected: + o_proj_name = f"/model/layers.{layer_id}/{attn_prefix}/o_proj/MatMul" + for node in graph.node: + if node.name == o_proj_name: + node.input[0] = gqa_output + if verbose: + logger.info( + f" Layer {layer_id}: Connected {o_proj_name} to {gqa_output} (fallback)" + ) + connected = True + break + if not connected and verbose: + logger.info(f" Warning: Could not connect o_proj for 
layer {layer_id}") + else: + # Non-quantized: connect directly to o_proj/MatMul + o_proj_name = f"/model/layers.{layer_id}/{attn_prefix}/o_proj/MatMul" + for node in graph.node: + if node.name == o_proj_name: + node.input[0] = gqa_output + if verbose: + logger.info(f" Layer {layer_id}: Connected {o_proj_name} to {gqa_output}") + break + + # Step 10.5: Fix o_proj input_quantizer dtype for quantized models + # GQA outputs float16, but the input_quantizer/Mul scale is float32 and there's + # a redundant Cast node. Convert scale to fp16, remove Cast, rewire directly. + if is_quantized and not has_combined_qkv: + if verbose: + logger.info("Fixing o_proj input_quantizer dtypes for quantized model...") + for layer_id in range(num_layers): + quant_mul_name = f"/model/layers.{layer_id}/{attn_prefix}/o_proj/input_quantizer/Mul" + cast_name = f"/model/layers.{layer_id}/{attn_prefix}/o_proj/MatMul_act_cast_fp16" + + # 1) Convert scale initializer to float16 + for node in graph.node: + if node.name == quant_mul_name: + scale_name = node.input[1] + for init in graph.initializer: + if init.name == scale_name and init.data_type == TensorProto.FLOAT: + from onnx import numpy_helper + + arr_fp16 = numpy_helper.to_array(init).astype(np.float16) + converted_init = numpy_helper.from_array(arr_fp16, name=init.name) + init.CopyFrom(converted_init) + if verbose: + logger.info( + f" Layer {layer_id}: Converted {scale_name} to float16" + ) + break + break + + # 2) Remove Cast node and rewire: Mul output goes directly to MatMul + cast_node = None + for node in graph.node: + if node.name == cast_name: + cast_node = node + break + + if cast_node is not None: + cast_input = cast_node.input[0] # input_quantizer/Mul_output_0 + cast_output = cast_node.output[0] # input_quantizer/Mul_output_0_cast_fp16 + # Rewire all consumers of cast_output to use cast_input + for node in graph.node: + for i, inp in enumerate(node.input): + if inp == cast_output: + node.input[i] = cast_input + 
graph.node.remove(cast_node) + # Update value_info for the Mul output to float16 + for vi in graph.value_info: + if vi.name == cast_input: + vi.type.tensor_type.elem_type = TensorProto.FLOAT16 + if verbose: + logger.info( + f" Layer {layer_id}: Updated {cast_input} value_info to float16" + ) + break + if verbose: + logger.info(f" Layer {layer_id}: Removed {cast_name}") + + # Step 11: Add opset import for com.microsoft domain + if verbose: + logger.info("Adding com.microsoft opset import...") + has_ms_domain = any(opset.domain == "com.microsoft" for opset in model.opset_import) + + if not has_ms_domain: + ms_opset = helper.make_opsetid("com.microsoft", 1) + model.opset_import.append(ms_opset) + + # Step 12: Clean up unused I/Os + if verbose: + logger.info("\nCleaning up unused I/Os and orphaned nodes...") + cleanup_stats = cleanup_unused_ios(graph) + if verbose: + logger.info(f" Removed {cleanup_stats['nodes_removed']} orphaned nodes") + logger.info(f" Removed {cleanup_stats['inputs_removed']} unused inputs") + logger.info(f" Removed {cleanup_stats['outputs_removed']} unused outputs") + logger.info(f" Removed {cleanup_stats['initializers_removed']} unused initializers") + logger.info(f" Removed {cleanup_stats['value_info_removed']} unused value_info entries") + + # Step 12.5: Adjust IR version if specified + if ir_version is not None: + # Check if opset 21 is being used (requires IR version >= 10) + current_opset = 0 + for opset in model.opset_import: + if opset.domain in {"", "ai.onnx"}: + current_opset = opset.version + break + + if current_opset >= 21 and ir_version < 10: + if verbose: + logger.info( + f"\nWarning: opset {current_opset} requires IR version >= 10," + f" but --ir-version {ir_version} was requested." 
+ ) + logger.info(" Setting IR version to 10 (minimum for opset 21)") + ir_version = 10 + + old_ir = model.ir_version + model.ir_version = ir_version + if verbose: + logger.info(f"\nSetting IR version: {old_ir} -> {ir_version}") + + # Step 13: Save the modified model + if verbose: + logger.info(f"\nSaving modified model to: {output_path}") + + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + + if external_data_name is None: + external_data_name = os.path.basename(output_path) + "_data" + + if use_external_data: + if verbose: + logger.info(f" Saving weights to external file: {external_data_name}") + + convert_model_to_external_data( + model, + all_tensors_to_one_file=True, + location=external_data_name, + size_threshold=1024, + convert_attribute=False, + ) + + onnx.save( + model, + output_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=external_data_name, + size_threshold=1024, + ) + else: + onnx.save(model, output_path) + + # Run shape inference (file-to-file, works with external data) + if verbose: + logger.info("\nRunning shape inference (file-to-file)...") + try: + onnx.shape_inference.infer_shapes_path( + output_path, output_path, check_type=False, strict_mode=False, data_prop=False + ) + if verbose: + logger.info(" Shape inference completed") + except Exception as e: + if verbose: + logger.info(f" Shape inference failed (non-fatal, model already saved): {e}") + + if verbose: + logger.info("\n" + "=" * 60) + logger.info("DONE! 
Model has been modified with GQA attention.") + logger.info("=" * 60) + logger.info("\nSummary:") + if is_quantized: + logger.info(" - Mode: Quantized (INT4/INT8) - separate Q/K/V projections") + else: + logger.info(" - Mode: Standard (FP16/BF16) - fused Q/K/V weights") + logger.info(f" - Added {len(all_qkv_matmul_nodes)} fused QKV MatMul nodes") + logger.info(f" - Replaced {num_layers} attention subgraphs with GQA") + logger.info(f" - Added cos_cache shape: {cos_cache.shape}") + logger.info(f" - Added sin_cache shape: {sin_cache.shape}") + logger.info(f" - Added {num_layers * 2} past_key_values inputs") + logger.info(f" - Added {num_layers * 2} present outputs") + logger.info(" - Added attention mask reformatting subgraph") + logger.info(f" - Cleaned up {sum(cleanup_stats.values())} unused graph elements") + if use_external_data: + logger.info(f" - Weights saved to: {external_data_name}") + + return model + + +def analyze_attention_pattern(model_path: str, layer_id: int = 0) -> list[onnx.NodeProto]: + """Analyze the attention pattern in an existing model. + + This is useful for debugging before running the full replacement. + + Args: + model_path: Path to ONNX model. + layer_id: Layer to analyze. + + Returns: + List of attention nodes in the specified layer. 
+ """ + logger.info(f"Analyzing attention pattern for layer {layer_id}...") + model = onnx.load(model_path) + graph = model.graph + + layer_prefix = f"/model/layers.{layer_id}/self_attn" + + logger.info(f"\nNodes in {layer_prefix}:") + logger.info("-" * 80) + + attn_nodes = [] + for node in graph.node: + if layer_prefix in node.name: + attn_nodes.append(node) + logger.info(f" {node.op_type:20} | {node.name}") + logger.info(f" inputs: {list(node.input)}") + logger.info(f" outputs: {list(node.output)}") + + logger.info(f"Total attention nodes in layer {layer_id}: {len(attn_nodes)}") + + # Find layer norm output + layernorm_pattern = f"/model/layers.{layer_id}/input_layernorm" + logger.info(f"\nLayer norm nodes ({layernorm_pattern}):") + for node in graph.node: + if layernorm_pattern in node.name: + logger.info(f" {node.op_type:20} | {node.name}") + logger.info(f" outputs: {list(node.output)}") + + return attn_nodes diff --git a/modelopt/onnx/graph_surgery/utils/__init__.py b/modelopt/onnx/graph_surgery/utils/__init__.py new file mode 100644 index 000000000..20c0ff020 --- /dev/null +++ b/modelopt/onnx/graph_surgery/utils/__init__.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utility functions for graph surgery operations.""" + +from .dtype_conversion import convert_fp16_to_bf16, fp16_to_bf16_array +from .graph_utils import ( + add_initializer, + array_to_initializer, + cleanup_unused_ios, + convert_model_fp16_to_bf16, + detect_model_dtype, + find_initializer, + find_node_by_name, + find_node_by_output, + find_nodes_by_pattern, + get_all_tensors_used, + get_consumers, + get_onnx_dtype, + initializer_to_array, + remove_node, + topological_sort_nodes, + uses_external_data, +) +from .rope_cache import get_rope_caches +from .whisper_utils import ( + generate_audio_processor_config, + generate_genai_config, + save_audio_processor_config, + save_genai_config, + update_genai_config_decoder, + update_genai_config_encoder, +) + +__all__ = [ + "add_initializer", + "array_to_initializer", + "cleanup_unused_ios", + "convert_fp16_to_bf16", + "convert_model_fp16_to_bf16", + "detect_model_dtype", + "find_initializer", + "find_node_by_name", + "find_node_by_output", + "find_nodes_by_pattern", + "fp16_to_bf16_array", + "generate_audio_processor_config", + "generate_genai_config", + "get_all_tensors_used", + "get_consumers", + "get_onnx_dtype", + "get_rope_caches", + "initializer_to_array", + "remove_node", + "save_audio_processor_config", + "save_genai_config", + "topological_sort_nodes", + "update_genai_config_decoder", + "update_genai_config_encoder", + "uses_external_data", +] diff --git a/modelopt/onnx/graph_surgery/utils/dtype_conversion.py b/modelopt/onnx/graph_surgery/utils/dtype_conversion.py new file mode 100644 index 000000000..920678c1c --- /dev/null +++ b/modelopt/onnx/graph_surgery/utils/dtype_conversion.py @@ -0,0 +1,277 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data type conversion utilities for ONNX models. + +This module provides functions for converting ONNX models between different +floating-point precisions, particularly FP16 to BF16 conversion. +""" + +import os + +import numpy as np +import onnx +from onnx import TensorProto, numpy_helper + +from ...logging_config import logger + + +def fp16_to_bf16_array(fp16_array: np.ndarray) -> np.ndarray: + """Convert a float16 numpy array to bfloat16. + + BF16 has 1 sign bit, 8 exponent bits, 7 mantissa bits. + FP16 has 1 sign bit, 5 exponent bits, 10 mantissa bits. + + We go FP16 -> FP32 -> BF16 to avoid precision loss. + + Args: + fp16_array: Input float16 numpy array. + + Returns: + BF16 data as uint16 numpy array (raw bit representation). + """ + # Convert FP16 to FP32 + fp32_array = fp16_array.astype(np.float32) + + # View as uint32 to manipulate bits + uint32_view = fp32_array.view(np.uint32) + + # BF16 is just the upper 16 bits of FP32 + # Round to nearest even + rounding = (uint32_view >> 16) & 1 + uint32_view = uint32_view + 0x7FFF + rounding + + # Shift right by 16 to get BF16 as uint16 + bf16_uint16 = (uint32_view >> 16).astype(np.uint16) + + return bf16_uint16 + + +def _convert_initializer_to_bf16(initializer: onnx.TensorProto) -> onnx.TensorProto: + """Convert an FP16 initializer to BF16. + + Args: + initializer: ONNX TensorProto initializer. + + Returns: + New initializer with BF16 data type. 
+    """
+    if initializer.data_type != TensorProto.FLOAT16:
+        return initializer
+
+    # Get the FP16 data
+    fp16_array = numpy_helper.to_array(initializer)
+
+    # Convert to BF16 (stored as uint16)
+    bf16_uint16 = fp16_to_bf16_array(fp16_array)
+
+    # Create new initializer with BF16 type
+    new_initializer = onnx.TensorProto()
+    new_initializer.name = initializer.name
+    new_initializer.data_type = TensorProto.BFLOAT16
+    new_initializer.dims.extend(initializer.dims)
+
+    # Store BF16 data as raw bytes
+    new_initializer.raw_data = bf16_uint16.tobytes()
+
+    return new_initializer
+
+
+def _convert_constant_node_to_bf16(node: onnx.NodeProto) -> bool:
+    """Convert a Constant node's FP16 value to BF16.
+
+    Args:
+        node: ONNX NodeProto to convert.
+
+    Returns:
+        True if conversion was performed, False otherwise.
+    """
+    if node.op_type != "Constant":
+        return False
+
+    for attr in node.attribute:
+        if attr.name == "value" and attr.t.data_type == TensorProto.FLOAT16:
+            # Get the FP16 tensor
+            fp16_array = numpy_helper.to_array(attr.t)
+
+            # Convert to BF16
+            bf16_uint16 = fp16_to_bf16_array(fp16_array)
+
+            # Update the tensor in place
+            attr.t.data_type = TensorProto.BFLOAT16
+            attr.t.ClearField("raw_data")
+            attr.t.ClearField("float_data")
+            attr.t.ClearField("int32_data")
+            attr.t.raw_data = bf16_uint16.tobytes()
+
+            return True
+
+        # Handle value_float attribute
+        if attr.name == "value_float":
+            fp32_val = np.array([attr.f], dtype=np.float32)
+            bf16_uint16 = fp16_to_bf16_array(fp32_val)  # fp32 in: avoids fp16 range/precision loss
+            new_tensor = onnx.TensorProto()
+            new_tensor.data_type = TensorProto.BFLOAT16
+            new_tensor.raw_data = bf16_uint16.tobytes()
+            attr.t.CopyFrom(new_tensor)
+            attr.name, attr.type = "value", onnx.AttributeProto.TENSOR
+            return True
+
+    return False
+
+
+def convert_fp16_to_bf16(
+    input_path: str,
+    output_path: str,
+    external_data: bool = True,
+    verbose: bool = True,
+) -> dict[str, int]:
+    """Convert an ONNX model from FP16 to BF16.
+
+    This function converts:
+    1. All FP16 initializers (weights) to BF16
+    2. 
All FP16 value_info (intermediate tensors) to BF16 + 3. All FP16 graph inputs/outputs to BF16 + 4. All Cast nodes that target FP16 to target BF16 + + Args: + input_path: Path to input FP16 ONNX model. + output_path: Path to output BF16 ONNX model. + external_data: Whether to save weights as external data. + verbose: Whether to print progress messages. + + Returns: + Dictionary with conversion statistics. + + Example: + >>> stats = convert_fp16_to_bf16( + ... input_path="model_fp16.onnx", + ... output_path="model_bf16.onnx", + ... ) + >>> logger.info(f"Converted {stats['initializers_converted']} initializers") + """ + if verbose: + logger.info(f"Loading model from: {input_path}") + + # Load model with external data + model = onnx.load(input_path, load_external_data=True) + graph = model.graph + + # Statistics + stats = { + "initializers_converted": 0, + "constants_converted": 0, + "casts_converted": 0, + "value_info_converted": 0, + "inputs_converted": 0, + "outputs_converted": 0, + } + + # 1. Convert initializers (weights) + if verbose: + logger.info("Converting initializers...") + new_initializers = [] + for init in graph.initializer: + if init.data_type == TensorProto.FLOAT16: + new_init = _convert_initializer_to_bf16(init) + new_initializers.append(new_init) + stats["initializers_converted"] += 1 + else: + new_initializers.append(init) + + # Clear and replace initializers + while len(graph.initializer) > 0: + graph.initializer.pop() + graph.initializer.extend(new_initializers) + + # 2. Convert Constant nodes and Cast nodes + if verbose: + logger.info("Converting Constant nodes and Cast nodes...") + for node in graph.node: + if _convert_constant_node_to_bf16(node): + stats["constants_converted"] += 1 + + # Convert Cast nodes that cast to FP16 + if node.op_type == "Cast": + for attr in node.attribute: + if attr.name == "to" and attr.i == TensorProto.FLOAT16: + attr.i = TensorProto.BFLOAT16 + stats["casts_converted"] += 1 + + # 3. 
Convert value_info (intermediate tensors) + if verbose: + logger.info("Converting value_info...") + for vi in graph.value_info: + if vi.type.HasField("tensor_type"): + if vi.type.tensor_type.elem_type == TensorProto.FLOAT16: + vi.type.tensor_type.elem_type = TensorProto.BFLOAT16 + stats["value_info_converted"] += 1 + + # 4. Convert graph inputs + if verbose: + logger.info("Converting graph inputs...") + for inp in graph.input: + if inp.type.HasField("tensor_type"): + if inp.type.tensor_type.elem_type == TensorProto.FLOAT16: + inp.type.tensor_type.elem_type = TensorProto.BFLOAT16 + stats["inputs_converted"] += 1 + + # 5. Convert graph outputs + if verbose: + logger.info("Converting graph outputs...") + for out in graph.output: + if out.type.HasField("tensor_type"): + if out.type.tensor_type.elem_type == TensorProto.FLOAT16: + out.type.tensor_type.elem_type = TensorProto.BFLOAT16 + stats["outputs_converted"] += 1 + + # Print statistics + if verbose: + logger.info("\nConversion statistics:") + logger.info(f" Initializers converted: {stats['initializers_converted']}") + logger.info(f" Constants converted: {stats['constants_converted']}") + logger.info(f" Cast nodes converted: {stats['casts_converted']}") + logger.info(f" Value_info converted: {stats['value_info_converted']}") + logger.info(f" Inputs converted: {stats['inputs_converted']}") + logger.info(f" Outputs converted: {stats['outputs_converted']}") + + # Save model + if verbose: + logger.info(f"\nSaving model to: {output_path}") + + # Ensure output directory exists + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + + if external_data: + external_data_path = os.path.basename(output_path) + ".data" + onnx.save( + model, + output_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=external_data_path, + size_threshold=0, + ) + if verbose: + logger.info(f" External data saved to: {external_data_path}") + else: + onnx.save(model, 
def uses_external_data(model_path: str) -> bool:
    """Return True when external-data side files accompany an ONNX model.

    ONNX can store large tensors in separate files whose names are derived
    from the model filename; this probes the common naming conventions.

    Args:
        model_path: Path to the ONNX model file.

    Returns:
        True if a matching external data file exists next to the model.
    """
    directory = os.path.dirname(model_path) or "."
    base = os.path.basename(model_path)

    # Candidate side-file names produced by the common exporters.
    candidates = (
        base + "_data",
        base + ".data",
        base.replace(".onnx", ".onnx_data"),
        base.replace(".onnx", "_data"),
    )
    for candidate in candidates:
        if os.path.exists(os.path.join(directory, candidate)):
            return True
    return False
def topological_sort_nodes(graph: "onnx.GraphProto") -> "list[onnx.NodeProto]":
    """Topologically sort graph nodes using Kahn's algorithm.

    Bug fix versus the previous version: nodes were keyed by ``node.name``,
    but ONNX node names may be empty or duplicated, which made
    ``node_to_idx`` collide and attach edges/in-degrees to the wrong node.
    Nodes are now tracked purely by their list index.

    Args:
        graph: ONNX graph to sort.

    Returns:
        List of nodes in a valid topological order; nodes involved in a
        cycle (or otherwise unsortable) are appended at the end in their
        original order.
    """
    # Tensors available before any node runs: initializers and graph inputs.
    available = {init.name for init in graph.initializer}
    available.update(inp.name for inp in graph.input)

    # Producer map: output tensor name -> index of the producing node.
    producer_idx: dict[str, int] = {}
    for i, node in enumerate(graph.node):
        for out in node.output:
            if out:
                producer_idx[out] = i

    # Build adjacency / in-degree per input edge (parallel edges count twice,
    # matching the symmetric increments below).
    in_degree = dict.fromkeys(range(len(graph.node)), 0)
    adj: dict[int, list[int]] = defaultdict(list)
    for i, node in enumerate(graph.node):
        for inp in node.input:
            if inp and inp not in available:
                j = producer_idx.get(inp)
                if j is not None and j != i:
                    adj[j].append(i)
                    in_degree[i] += 1

    # Kahn's algorithm.
    queue: deque[int] = deque(i for i in range(len(graph.node)) if in_degree[i] == 0)
    order: list[int] = []
    while queue:
        i = queue.popleft()
        order.append(i)
        for j in adj[i]:
            in_degree[j] -= 1
            if in_degree[j] == 0:
                queue.append(j)

    # Cycle or disconnected leftovers: append deterministically in original order.
    if len(order) < len(graph.node):
        seen = set(order)
        order.extend(i for i in range(len(graph.node)) if i not in seen)

    return [graph.node[i] for i in order]
def detect_model_dtype(model: onnx.ModelProto) -> tuple[int, np.dtype]:
    """Detect the dominant floating-point dtype among a model's initializers.

    Args:
        model: ONNX model to analyze.

    Returns:
        Tuple of (onnx_dtype, numpy_dtype). Defaults to FLOAT/float32 when
        no float initializers exist; BFLOAT16 maps to np.float32 because
        numpy has no native bfloat16.
    """
    float_types = (TensorProto.FLOAT, TensorProto.FLOAT16, TensorProto.BFLOAT16)
    counts: dict[int, int] = defaultdict(int)
    for tensor in model.graph.initializer:
        if tensor.data_type in float_types:
            counts[tensor.data_type] += 1

    if not counts:
        return TensorProto.FLOAT, np.float32

    dominant = max(counts, key=counts.get)
    numpy_equivalents = {
        TensorProto.FLOAT: np.float32,
        TensorProto.FLOAT16: np.float16,
        TensorProto.BFLOAT16: np.float32,  # numpy doesn't support bfloat16
    }
    return dominant, numpy_equivalents.get(dominant, np.float32)


def get_onnx_dtype(io_dtype: str) -> int:
    """Map a dtype string to its ONNX TensorProto enum value.

    Args:
        io_dtype: "float32", "float16", or "bfloat16".

    Returns:
        Matching TensorProto constant; unrecognized names fall back to FLOAT16.
    """
    if io_dtype == "float32":
        return TensorProto.FLOAT
    if io_dtype == "bfloat16":
        return TensorProto.BFLOAT16
    # "float16" and any unknown string both resolve to FP16 (historic default).
    return TensorProto.FLOAT16


def find_nodes_by_pattern(graph: onnx.GraphProto, pattern: str) -> list[onnx.NodeProto]:
    """Return every node whose name contains *pattern* as a substring.

    Args:
        graph: ONNX graph to search.
        pattern: Substring to match against node names.

    Returns:
        List of matching nodes, in graph order.
    """
    matches = []
    for node in graph.node:
        if pattern in node.name:
            matches.append(node)
    return matches
def find_node_by_output(graph: onnx.GraphProto, output_name: str) -> onnx.NodeProto | None:
    """Return the node that produces *output_name*, or None if absent.

    Args:
        graph: ONNX graph to search.
        output_name: Name of the output tensor.
    """
    return next((node for node in graph.node if output_name in node.output), None)


def find_node_by_name(graph: onnx.GraphProto, name: str) -> onnx.NodeProto | None:
    """Return the node whose name equals *name* exactly, or None.

    Args:
        graph: ONNX graph to search.
        name: Exact node name.
    """
    return next((node for node in graph.node if node.name == name), None)


def find_initializer(graph: onnx.GraphProto, name: str) -> onnx.TensorProto | None:
    """Return the initializer named *name*, or None if absent.

    Args:
        graph: ONNX graph to search.
        name: Initializer name.
    """
    return next((init for init in graph.initializer if init.name == name), None)


def get_consumers(graph: onnx.GraphProto, tensor_name: str) -> list[onnx.NodeProto]:
    """Return every node that reads *tensor_name* as an input.

    Args:
        graph: ONNX graph to search.
        tensor_name: Name of the tensor.
    """
    consumers = []
    for node in graph.node:
        if tensor_name in node.input:
            consumers.append(node)
    return consumers


def remove_node(graph: onnx.GraphProto, node: onnx.NodeProto) -> None:
    """Delete *node* from the graph's node list.

    Args:
        graph: ONNX graph to modify.
        node: Node to remove.
    """
    graph.node.remove(node)


def get_all_tensors_used(graph: onnx.GraphProto) -> set[str]:
    """Collect every tensor name referenced by nodes or exposed as a graph output.

    Args:
        graph: ONNX graph to analyze.

    Returns:
        Set of referenced tensor names (empty-string slots are skipped).
    """
    referenced: set[str] = set()
    for node in graph.node:
        referenced.update(name for name in node.input if name)
        referenced.update(name for name in node.output if name)
    referenced.update(out.name for out in graph.output)
    return referenced
+ """ + used_tensors = set() + + # Tensors produced by nodes + for node in graph.node: + for inp in node.input: + if inp: # Skip empty strings + used_tensors.add(inp) + for out in node.output: + if out: + used_tensors.add(out) + + # Tensors that are graph outputs + for out in graph.output: + used_tensors.add(out.name) + + return used_tensors + + +def cleanup_unused_ios(graph: onnx.GraphProto) -> dict[str, int]: + """Remove unused inputs, outputs, initializers, value_info, and orphaned nodes. + + This function iteratively removes: + 1. Orphaned nodes (nodes whose outputs are not consumed) + 2. Unused graph inputs + 3. Unused graph outputs + 4. Unused initializers + 5. Unused value_info entries + + Args: + graph: ONNX graph to clean up. + + Returns: + Dictionary with counts of removed items. + """ + # First pass: Remove orphaned nodes + total_nodes_removed = 0 + + while True: + # Get all tensor names consumed by nodes + consumed_by_nodes = set() + for node in graph.node: + for inp in node.input: + if inp: + consumed_by_nodes.add(inp) + + # Also include graph outputs as "consumed" + for out in graph.output: + consumed_by_nodes.add(out.name) + + # Find nodes whose outputs are not consumed by anyone + nodes_to_remove = [] + for node in graph.node: + # Skip if node has no outputs (side-effect nodes) + if not node.output: + continue + + # Check if ANY of the node's outputs are consumed + any_output_consumed = False + for out in node.output: + if out and out in consumed_by_nodes: + any_output_consumed = True + break + + # If none of the outputs are consumed, mark for removal + if not any_output_consumed: + nodes_to_remove.append(node) + + if not nodes_to_remove: + break + + for node in nodes_to_remove: + graph.node.remove(node) + + total_nodes_removed += len(nodes_to_remove) + + # Rebuild consumed_by_nodes after node cleanup + consumed_by_nodes = set() + for node in graph.node: + for inp in node.input: + if inp: + consumed_by_nodes.add(inp) + + for out in graph.output: + 
def initializer_to_array(init: onnx.TensorProto) -> tuple[np.ndarray, str | None]:
    """Decode an ONNX initializer into a numpy array, tolerating bfloat16.

    Args:
        init: ONNX initializer tensor.

    Returns:
        Tuple of (array, dtype tag). The tag is "bfloat16" when the payload
        is BF16 (exposed as raw int16 words, since numpy lacks bfloat16),
        otherwise None.
    """
    if init.data_type != TensorProto.BFLOAT16:
        return numpy_helper.to_array(init), None
    # BF16 payloads live in raw_data; surface the 16-bit words for manipulation.
    words = np.frombuffer(init.raw_data, dtype=np.int16)
    return words.reshape(init.dims), "bfloat16"
+ """ + if init.data_type == TensorProto.BFLOAT16: + # bfloat16 stored as raw bytes - read as int16 for manipulation + arr = np.frombuffer(init.raw_data, dtype=np.int16).reshape(init.dims) + return arr, "bfloat16" + else: + return numpy_helper.to_array(init), None + + +def array_to_initializer(arr: np.ndarray, name: str, is_bfloat16: bool = False) -> onnx.TensorProto: + """Convert numpy array to ONNX initializer, handling bfloat16. + + Args: + arr: Numpy array to convert. + name: Name for the initializer. + is_bfloat16: Whether to store as bfloat16. + + Returns: + ONNX TensorProto initializer. + """ + if is_bfloat16: + tensor = onnx.TensorProto() + tensor.name = name + tensor.data_type = TensorProto.BFLOAT16 + tensor.dims.extend(arr.shape) + tensor.raw_data = arr.tobytes() + return tensor + else: + return numpy_helper.from_array(arr, name=name) + + +def add_initializer( + graph: onnx.GraphProto, + name: str, + data: np.ndarray, + dtype: int = TensorProto.FLOAT16, +) -> None: + """Add an initializer (constant tensor) to the graph. + + Args: + graph: ONNX graph to modify. + name: Name for the initializer. + data: Numpy array data. + dtype: ONNX data type for the tensor. + """ + if dtype == TensorProto.BFLOAT16: + # For bfloat16, data comes as int16 view - create raw tensor + tensor = onnx.TensorProto() + tensor.name = name + tensor.data_type = TensorProto.BFLOAT16 + tensor.dims.extend(data.shape) + tensor.raw_data = data.tobytes() + else: + tensor = numpy_helper.from_array(data, name=name) + graph.initializer.append(tensor) + # Also add to value_info for shape inference + value_info = helper.make_tensor_value_info(name, dtype, list(data.shape)) + graph.value_info.append(value_info) + + +def convert_initializers_to_dtype(graph: onnx.GraphProto, target_dtype_str: str = "float16") -> int: + """Convert all float32 initializers to the target dtype. + + This ensures weight matrices match the model's I/O precision. + + Args: + graph: ONNX graph to modify. 
def convert_model_fp16_to_bf16(graph: onnx.GraphProto, verbose: bool = True) -> dict[str, int]:
    """Rewrite every FP16 element of *graph* as BF16, in place.

    Touches, in order: initializers (weights), Constant node payloads,
    Cast node targets, value_info entries, graph inputs, and graph outputs.

    Args:
        graph: ONNX graph to modify in-place.
        verbose: Whether to log a summary of what was converted.

    Returns:
        Dictionary with per-category conversion counts.
    """
    stats = dict.fromkeys(
        (
            "initializers_converted",
            "constants_converted",
            "casts_converted",
            "value_info_converted",
            "inputs_converted",
            "outputs_converted",
        ),
        0,
    )

    # 1. Initializers: re-encode FP16 payloads as raw BF16 bytes.
    rebuilt = []
    for init in graph.initializer:
        if init.data_type != TensorProto.FLOAT16:
            rebuilt.append(init)
            continue
        bits = fp16_to_bf16_array(numpy_helper.to_array(init))
        replacement = onnx.TensorProto()
        replacement.name = init.name
        replacement.data_type = TensorProto.BFLOAT16
        replacement.dims.extend(init.dims)
        replacement.raw_data = bits.tobytes()
        rebuilt.append(replacement)
        stats["initializers_converted"] += 1
    while graph.initializer:
        graph.initializer.pop()
    graph.initializer.extend(rebuilt)

    # 2. Constant payloads and Cast targets.
    for node in graph.node:
        if node.op_type == "Constant":
            for attr in node.attribute:
                if attr.name == "value" and attr.t.data_type == TensorProto.FLOAT16:
                    bits = fp16_to_bf16_array(numpy_helper.to_array(attr.t))
                    attr.t.data_type = TensorProto.BFLOAT16
                    # Drop any stale typed storage before writing raw BF16 bytes.
                    attr.t.ClearField("raw_data")
                    attr.t.ClearField("float_data")
                    attr.t.ClearField("int32_data")
                    attr.t.raw_data = bits.tobytes()
                    stats["constants_converted"] += 1

        if node.op_type == "Cast":
            for attr in node.attribute:
                if attr.name == "to" and attr.i == TensorProto.FLOAT16:
                    attr.i = TensorProto.BFLOAT16
                    stats["casts_converted"] += 1

    # 3-5. Retag tensor types on value_info, inputs, and outputs.
    def _retag(protos, counter_key: str) -> None:
        # Flip the elem_type of every FP16 tensor-typed proto in the list.
        for proto in protos:
            if (
                proto.type.HasField("tensor_type")
                and proto.type.tensor_type.elem_type == TensorProto.FLOAT16
            ):
                proto.type.tensor_type.elem_type = TensorProto.BFLOAT16
                stats[counter_key] += 1

    _retag(graph.value_info, "value_info_converted")
    _retag(graph.input, "inputs_converted")
    _retag(graph.output, "outputs_converted")

    if verbose:
        logger.info("FP16 to BF16 conversion statistics:")
        logger.info(f"  Initializers: {stats['initializers_converted']}")
        logger.info(f"  Constants: {stats['constants_converted']}")
        logger.info(f"  Cast nodes: {stats['casts_converted']}")
        logger.info(f"  Value_info: {stats['value_info_converted']}")
        logger.info(f"  Inputs: {stats['inputs_converted']}")
        logger.info(f"  Outputs: {stats['outputs_converted']}")

    return stats
+""" + +from typing import Any + +import numpy as np +import torch + + +def get_rope_caches( + model_id: str, + max_seq_len: int, + io_dtype: str = "float16", + trust_remote_code: bool = False, +) -> tuple[np.ndarray, np.ndarray, Any]: + """Compute cos/sin caches matching onnxruntime_genai builder. + + This function computes the rotary position embedding caches required + for GroupQueryAttention (GQA) nodes. The caches are computed based on + the model's configuration from HuggingFace. + + Args: + model_id: HuggingFace model ID or path to config. + max_seq_len: Maximum sequence length for the caches. + io_dtype: Data type for output ("float16", "float32", or "bfloat16"). + trust_remote_code: Whether to trust remote code in HuggingFace model config. + + Returns: + Tuple of (cos_cache, sin_cache, config) where: + - cos_cache: Cosine cache as numpy array with shape [max_seq_len, head_dim//2] + - sin_cache: Sine cache as numpy array with shape [max_seq_len, head_dim//2] + - config: HuggingFace model configuration + + Example: + >>> cos_cache, sin_cache, config = get_rope_caches( + ... model_id="meta-llama/Llama-2-7b-hf", + ... max_seq_len=4096, + ... io_dtype="float16", + ... 
) + >>> print(f"Cache shape: {cos_cache.shape}") + """ + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) + + theta = getattr(config, "rope_theta", 10000.0) + head_dim = config.hidden_size // config.num_attention_heads + partial_factor = getattr(config, "partial_rotary_factor", 1.0) + dim = int(head_dim * partial_factor) + + # Match builder: int64 -> float + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + t = torch.arange(max_seq_len, dtype=torch.int64).float() + + freqs = torch.outer(t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + + cos_cache = emb.cos() + sin_cache = emb.sin() + + # Cast to target dtype + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + target_dtype = dtype_map.get(io_dtype, torch.float16) + + cos_cache = cos_cache.to(target_dtype) + sin_cache = sin_cache.to(target_dtype) + + # Slice to half (GQA expects this) + if cos_cache.shape[-1] == head_dim: + cos_cache = cos_cache[:, : head_dim // 2] + sin_cache = sin_cache[:, : head_dim // 2] + + # Convert to numpy - bfloat16 needs special handling + if io_dtype == "bfloat16": + # bfloat16 can't be converted directly to numpy + # Return as int16 view, add_initializer will handle proper ONNX storage + cos_np = cos_cache.view(torch.int16).numpy() + sin_np = sin_cache.view(torch.int16).numpy() + return cos_np, sin_np, config + else: + return cos_cache.numpy(), sin_cache.numpy(), config diff --git a/modelopt/onnx/graph_surgery/utils/whisper_utils.py b/modelopt/onnx/graph_surgery/utils/whisper_utils.py new file mode 100644 index 000000000..012355af3 --- /dev/null +++ b/modelopt/onnx/graph_surgery/utils/whisper_utils.py @@ -0,0 +1,470 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
def generate_audio_processor_config(
    num_mel_bins: int = 128,
    n_fft: int = 400,
    hop_length: int = 160,
    chunk_size: int = 30,
) -> dict:
    """Build the Whisper audio preprocessing pipeline configuration.

    The pipeline is AudioDecoder -> STFT -> LogMelSpectrum, in the shape
    ONNX Runtime GenAI expects inside audio_processor_config.json.

    Args:
        num_mel_bins: Number of mel frequency bins.
            - 80 for whisper-tiny/base/small/medium/large/large-v2
            - 128 for whisper-large-v3 and whisper-large-v3-turbo
        n_fft: FFT size (default 400 for 16kHz audio with 25ms window).
        hop_length: Hop length for STFT (default 160 for 10ms hop).
        chunk_size: Audio chunk size in seconds (default 30).

    Returns:
        Audio processor configuration dictionary.
    """
    decode_step = {"operation": {"name": "audio_decoder", "type": "AudioDecoder"}}
    stft_step = {
        "operation": {
            "name": "STFT",
            "type": "STFTNorm",
            "attrs": {
                "n_fft": n_fft,
                "frame_length": n_fft,
                "hop_length": hop_length,
            },
        }
    }
    log_mel_step = {
        "operation": {
            "name": "log_mel_spectrogram",
            "type": "LogMelSpectrum",
            "attrs": {
                "chunk_size": chunk_size,
                "hop_length": hop_length,
                "n_fft": n_fft,
                "n_mel": num_mel_bins,
            },
        }
    }
    return {"feature_extraction": {"sequence": [decode_step, stft_step, log_mel_step]}}
+ """ + return { + "feature_extraction": { + "sequence": [ + {"operation": {"name": "audio_decoder", "type": "AudioDecoder"}}, + { + "operation": { + "name": "STFT", + "type": "STFTNorm", + "attrs": { + "n_fft": n_fft, + "frame_length": n_fft, + "hop_length": hop_length, + }, + } + }, + { + "operation": { + "name": "log_mel_spectrogram", + "type": "LogMelSpectrum", + "attrs": { + "chunk_size": chunk_size, + "hop_length": hop_length, + "n_fft": n_fft, + "n_mel": num_mel_bins, + }, + } + }, + ] + } + } + + +def save_audio_processor_config( + output_dir: str, + hf_model_id: str | None = None, + num_mel_bins: int | None = None, + overwrite: bool = False, + trust_remote_code: bool = False, +) -> str: + """Save audio_processor_config.json to output directory. + + Args: + output_dir: Directory to save the config file. + hf_model_id: HuggingFace model ID to extract num_mel_bins from config. + If provided, num_mel_bins is extracted from the model config. + num_mel_bins: Number of mel bins. Used if hf_model_id is not provided. + Default is 128 (for whisper-large-v3 models). + overwrite: Whether to overwrite existing file. + + Returns: + Path to the saved config file. 
+ """ + output_path = os.path.join(output_dir, "audio_processor_config.json") + + # Check if file already exists + if os.path.exists(output_path) and not overwrite: + logger.info(f"audio_processor_config.json already exists at {output_dir}") + return output_path + + # Determine num_mel_bins + if hf_model_id is not None: + from transformers import WhisperConfig + + config = WhisperConfig.from_pretrained(hf_model_id, trust_remote_code=trust_remote_code) + num_mel_bins = config.num_mel_bins + logger.info(f"Extracted num_mel_bins={num_mel_bins} from {hf_model_id}") + elif num_mel_bins is None: + num_mel_bins = 128 # Default for whisper-large-v3 + logger.info(f"Using default num_mel_bins={num_mel_bins}") + + # Generate config + audio_processor_cfg = generate_audio_processor_config(num_mel_bins=num_mel_bins) + + # Save to file + os.makedirs(output_dir, exist_ok=True) + with open(output_path, "w") as f: + json.dump(audio_processor_cfg, f, indent=4) + + logger.info(f"Saved audio_processor_config.json to {output_dir}") + return output_path + + +# --------------------------------------------------------------------------- +# GenAI config +# --------------------------------------------------------------------------- + + +def generate_genai_config( + encoder_filename: str, + decoder_filename: str, + hf_model_id: str | None = None, + trust_remote_code: bool = False, + # Model config (auto-detected from HuggingFace if model_id provided) + num_decoder_layers: int | None = None, + num_encoder_layers: int | None = None, + num_attention_heads: int | None = None, + hidden_size: int | None = None, + vocab_size: int | None = None, + bos_token_id: int | None = None, + eos_token_id: int | None = None, + pad_token_id: int | None = None, + decoder_start_token_id: int | None = None, + context_length: int | None = None, + # Provider options + provider: str = "cuda", + enable_cuda_graph: bool = False, + # Search options + num_beams: int = 1, + max_length: int | None = None, + 
past_present_share_buffer: bool = True, + # Decoder input/output naming patterns + decoder_past_key_pattern: str = "past_key_self_%d", + decoder_past_value_pattern: str = "past_value_self_%d", + decoder_cross_past_key_pattern: str = "past_key_cross_%d", + decoder_cross_past_value_pattern: str = "past_value_cross_%d", + decoder_present_key_pattern: str = "present_key_self_%d", + decoder_present_value_pattern: str = "present_value_self_%d", + # Encoder output naming patterns + encoder_cross_present_key_pattern: str = "present_key_cross_%d", + encoder_cross_present_value_pattern: str = "present_value_cross_%d", +) -> dict[str, Any]: + """Generate genai_config.json configuration for Whisper models. + + This config is required by ONNX Runtime GenAI for running encoder-decoder + models like Whisper. It specifies model architecture, I/O tensor names, + and inference settings. + + Args: + encoder_filename: Filename of the encoder ONNX model. + decoder_filename: Filename of the decoder ONNX model. + hf_model_id: HuggingFace model ID to auto-detect config parameters. + num_decoder_layers: Number of decoder layers. + num_encoder_layers: Number of encoder layers. + num_attention_heads: Number of attention heads. + hidden_size: Hidden dimension size. + vocab_size: Vocabulary size. + bos_token_id: Beginning of sequence token ID. + eos_token_id: End of sequence token ID. + pad_token_id: Padding token ID. + decoder_start_token_id: Decoder start token ID. + context_length: Maximum context/sequence length. + provider: Execution provider ("cuda", "cpu", "NvTensorRtRtx"). + enable_cuda_graph: Whether to enable CUDA graph optimization. + num_beams: Number of beams for beam search. + max_length: Maximum generation length. + past_present_share_buffer: Whether to share KV cache buffer. + decoder_past_key_pattern: Pattern for decoder past key input names. + decoder_past_value_pattern: Pattern for decoder past value input names. 
+ decoder_cross_past_key_pattern: Pattern for cross-attention past key names. + decoder_cross_past_value_pattern: Pattern for cross-attention past value names. + decoder_present_key_pattern: Pattern for decoder present key output names. + decoder_present_value_pattern: Pattern for decoder present value output names. + encoder_cross_present_key_pattern: Pattern for encoder cross-attention key outputs. + encoder_cross_present_value_pattern: Pattern for encoder cross-attention value outputs. + + Returns: + Dictionary containing the complete genai_config.json structure. + """ + # Load config from HuggingFace if model_id provided + if hf_model_id is not None: + from transformers import WhisperConfig + + logger.info(f"Loading config from HuggingFace: {hf_model_id}") + config = WhisperConfig.from_pretrained(hf_model_id, trust_remote_code=trust_remote_code) + + # Extract values from HF config + if num_decoder_layers is None: + num_decoder_layers = config.decoder_layers + if num_encoder_layers is None: + num_encoder_layers = config.encoder_layers + if num_attention_heads is None: + num_attention_heads = config.decoder_attention_heads + if hidden_size is None: + hidden_size = config.d_model + if vocab_size is None: + vocab_size = config.vocab_size + if bos_token_id is None: + bos_token_id = config.bos_token_id + if eos_token_id is None: + eos_token_id = config.eos_token_id + if pad_token_id is None: + pad_token_id = config.pad_token_id + if decoder_start_token_id is None: + decoder_start_token_id = config.decoder_start_token_id + if context_length is None: + context_length = getattr(config, "max_target_positions", None) or config.max_length + if max_length is None: + max_length = getattr(config, "max_target_positions", None) or config.max_length + + # Compute head_size + head_size = hidden_size // num_attention_heads + + # Build provider options (lowercase provider names as expected by GenAI) + if provider == "cuda": + provider_options = [{"cuda": {}}] + elif provider == 
"cpu": + provider_options = [] + elif provider == "NvTensorRtRtx": + provider_options = [ + {"NvTensorRtRtx": {"enable_cuda_graph": "1" if enable_cuda_graph else "0"}} + ] + else: + provider_options = [{provider.lower(): {}}] + + # Build config + genai_config = { + "model": { + "bos_token_id": bos_token_id, + "context_length": context_length, + "decoder_start_token_id": decoder_start_token_id, + "speech": {"config_filename": "audio_processor_config.json"}, + "decoder": { + "session_options": { + "log_id": "onnxruntime-genai", + "provider_options": provider_options, + }, + "filename": decoder_filename, + "head_size": head_size, + "hidden_size": hidden_size, + "inputs": { + "input_ids": "input_ids", + "past_key_names": decoder_past_key_pattern, + "past_value_names": decoder_past_value_pattern, + "cross_past_key_names": decoder_cross_past_key_pattern, + "cross_past_value_names": decoder_cross_past_value_pattern, + "attention_mask": "attention_mask", + "position_ids": "position_ids", + }, + "outputs": { + "logits": "logits", + "present_key_names": decoder_present_key_pattern, + "present_value_names": decoder_present_value_pattern, + }, + "num_attention_heads": num_attention_heads, + "num_hidden_layers": num_decoder_layers, + "num_key_value_heads": num_attention_heads, # Whisper uses MHA, not GQA + }, + "encoder": { + "session_options": { + "log_id": "onnxruntime-genai", + "provider_options": provider_options, + }, + "filename": encoder_filename, + "head_size": head_size, + "hidden_size": hidden_size, + "inputs": {"audio_features": "audio_features"}, + "outputs": { + "encoder_outputs": "encoder_hidden_states", + "cross_present_key_names": encoder_cross_present_key_pattern, + "cross_present_value_names": encoder_cross_present_value_pattern, + }, + "num_attention_heads": num_attention_heads, + "num_hidden_layers": num_encoder_layers, + "num_key_value_heads": num_attention_heads, + }, + "eos_token_id": eos_token_id, + "pad_token_id": pad_token_id, + "type": "whisper", + 
"vocab_size": vocab_size, + }, + "search": { + "diversity_penalty": 0.0, + "do_sample": False, + "early_stopping": True, + "length_penalty": 1.0, + "max_length": max_length, + "min_length": 0, + "no_repeat_ngram_size": 0, + "num_beams": num_beams, + "num_return_sequences": 1, + "past_present_share_buffer": past_present_share_buffer, + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_k": 1, + "top_p": 1.0, + }, + } + + return genai_config + + +def save_genai_config( + output_dir: str, + encoder_filename: str, + decoder_filename: str = "decoder_with_past_model.onnx", + hf_model_id: str | None = None, + overwrite: bool = False, + provider: str = "cuda", + trust_remote_code: bool = False, + **kwargs, +) -> str: + """Save genai_config.json to output directory. + + Args: + output_dir: Directory to save the config file. + encoder_filename: Filename of the encoder ONNX model. + decoder_filename: Filename of the decoder ONNX model. + Default is "decoder_with_past_model.onnx". + hf_model_id: HuggingFace model ID for auto-detecting config. + overwrite: Whether to overwrite existing file. + provider: Execution provider. + **kwargs: Additional arguments passed to generate_genai_config. + + Returns: + Path to the saved config file. 
+ """ + output_path = os.path.join(output_dir, "genai_config.json") + + # Check if file already exists + if os.path.exists(output_path) and not overwrite: + logger.info(f"genai_config.json already exists at {output_dir}") + return output_path + + # Generate config + genai_cfg = generate_genai_config( + encoder_filename=encoder_filename, + decoder_filename=decoder_filename, + hf_model_id=hf_model_id, + trust_remote_code=trust_remote_code, + provider=provider, + **kwargs, + ) + + # Save to file + os.makedirs(output_dir, exist_ok=True) + with open(output_path, "w") as f: + json.dump(genai_cfg, f, indent=4) + + logger.info(f"Saved genai_config.json to {output_dir}") + return output_path + + +def update_genai_config_encoder( + config_path: str, + encoder_filename: str, + encoder_cross_present_key_pattern: str = "present_key_cross_%d", + encoder_cross_present_value_pattern: str = "present_value_cross_%d", +) -> dict[str, Any]: + """Update an existing genai_config.json with encoder settings. + + Useful when you already have a config from ONNX Runtime export and + want to update the encoder section after running encoder surgery. + + Args: + config_path: Path to existing genai_config.json. + encoder_filename: New encoder filename. + encoder_cross_present_key_pattern: Pattern for cross-attention key outputs. + encoder_cross_present_value_pattern: Pattern for cross-attention value outputs. + + Returns: + Updated configuration dictionary. 
+ """ + with open(config_path) as f: + config = json.load(f) + + # Update encoder section + if "model" in config and "encoder" in config["model"]: + config["model"]["encoder"]["filename"] = encoder_filename + config["model"]["encoder"]["outputs"]["cross_present_key_names"] = ( + encoder_cross_present_key_pattern + ) + config["model"]["encoder"]["outputs"]["cross_present_value_names"] = ( + encoder_cross_present_value_pattern + ) + + # Save updated config + with open(config_path, "w") as f: + json.dump(config, f, indent=4) + + logger.info(f"Updated encoder section in {config_path}") + return config + + +def update_genai_config_decoder( + config_path: str, + decoder_filename: str, + decoder_past_key_pattern: str = "past_key_values.%d.decoder.key", + decoder_past_value_pattern: str = "past_key_values.%d.decoder.value", + decoder_present_key_pattern: str = "present.%d.decoder.key", + decoder_present_value_pattern: str = "present.%d.decoder.value", +) -> dict[str, Any]: + """Update an existing genai_config.json with decoder settings. + + Useful when you've run decoder surgery (MHA/GQA fusion) and need + to update the decoder section. + + Args: + config_path: Path to existing genai_config.json. + decoder_filename: New decoder filename. + decoder_past_key_pattern: Pattern for past key input names. + decoder_past_value_pattern: Pattern for past value input names. + decoder_present_key_pattern: Pattern for present key output names. + decoder_present_value_pattern: Pattern for present value output names. + + Returns: + Updated configuration dictionary. 
+ """ + with open(config_path) as f: + config = json.load(f) + + # Update decoder section + if "model" in config and "decoder" in config["model"]: + config["model"]["decoder"]["filename"] = decoder_filename + config["model"]["decoder"]["inputs"]["past_key_names"] = decoder_past_key_pattern + config["model"]["decoder"]["inputs"]["past_value_names"] = decoder_past_value_pattern + config["model"]["decoder"]["outputs"]["present_key_names"] = decoder_present_key_pattern + config["model"]["decoder"]["outputs"]["present_value_names"] = decoder_present_value_pattern + + # Save updated config + with open(config_path, "w") as f: + json.dump(config, f, indent=4) + + logger.info(f"Updated decoder section in {config_path}") + return config diff --git a/tests/unit/onnx/test_dq_transpose_surgery.py b/tests/unit/onnx/test_dq_transpose_surgery.py new file mode 100644 index 000000000..68b28a847 --- /dev/null +++ b/tests/unit/onnx/test_dq_transpose_surgery.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Tests for DQ transpose surgery (transpose_dequantize_linear_weights)."""
+
+import os
+import tempfile
+
+import numpy as np
+import onnx
+import onnxruntime as ort
+import pytest
+from onnx import TensorProto, helper, numpy_helper
+
+from modelopt.onnx.graph_surgery.dq_transpose import transpose_dequantize_linear_weights
+
+# Small toy dimensions so model build + two ORT runs stay well under a second.
+IN_FEATURES = 16
+OUT_FEATURES = 32
+BATCH = 2
+SEQ = 4
+
+
+def _build_dq_matmul_model():
+    """Build minimal model: input -> MatMul(input, DQ(weight, scale, zp)) -> output.
+
+    Uses per-channel INT8 quantization on axis 0 of the weight [IN, OUT].
+    """
+    # Fixed seed keeps the model (and hence test numerics) deterministic.
+    rng = np.random.RandomState(123)
+
+    # Quantized weight [IN_FEATURES, OUT_FEATURES] int8
+    w_int8 = rng.randint(-127, 127, size=(IN_FEATURES, OUT_FEATURES)).astype(np.int8)
+    # Per-channel scale along axis 0 -> shape [IN_FEATURES] (one per row)
+    scale = rng.rand(IN_FEATURES).astype(np.float32) * 0.01 + 0.001
+    # Per-channel zero point
+    zp = np.zeros(IN_FEATURES, dtype=np.int8)
+
+    w_init = numpy_helper.from_array(w_int8, name="weight_quantized")
+    s_init = numpy_helper.from_array(scale, name="weight_scale")
+    zp_init = numpy_helper.from_array(zp, name="weight_zp")
+
+    # DQ with axis=0: one (scale, zp) pair per row of the [IN, OUT] weight.
+    dq_node = helper.make_node(
+        "DequantizeLinear",
+        inputs=["weight_quantized", "weight_scale", "weight_zp"],
+        outputs=["weight_dequantized"],
+        name="dq_weight",
+        axis=0,
+    )
+
+    matmul_node = helper.make_node(
+        "MatMul",
+        inputs=["input", "weight_dequantized"],
+        outputs=["output"],
+        name="matmul_0",
+    )
+
+    graph = helper.make_graph(
+        [dq_node, matmul_node],
+        "test_dq_transpose",
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [BATCH, SEQ, IN_FEATURES]),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("output", TensorProto.FLOAT, [BATCH, SEQ, OUT_FEATURES]),
+        ],
+        initializer=[w_init, s_init, zp_init],
+    )
+
+    # Opset 21 / IR 9 — presumably chosen for full per-channel DQ support;
+    # checker validates the hand-built graph before any test uses it.
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)])
+    model.ir_version = 9
+    onnx.checker.check_model(model)
+    return model
+
+
+def _run_session(model_proto, feeds):
+    """Run inference on in-memory model."""
+    opts = ort.SessionOptions()
+    # Severity 3 = errors only; silences ORT warnings in test output.
+    opts.log_severity_level = 3
+    sess = ort.InferenceSession(
+        model_proto.SerializeToString(), opts, providers=["CPUExecutionProvider"]
+    )
+    output_names = [o.name for o in sess.get_outputs()]
+    return sess.run(output_names, feeds)
+
+
+@pytest.fixture(scope="class")
+def models():
+    """Build original model, apply DQ transpose surgery, return both."""
+    orig = _build_dq_matmul_model()
+
+    with tempfile.TemporaryDirectory() as tmp:
+        orig_path = os.path.join(tmp, "original.onnx")
+        out_path = os.path.join(tmp, "transposed.onnx")
+
+        onnx.save(orig, orig_path)
+        transpose_dequantize_linear_weights(
+            model_path=orig_path,
+            output_path=out_path,
+            use_external_data=False,
+            verbose=True,
+        )
+        # Loaded fully into memory before the temp dir is cleaned up.
+        modified = onnx.load(out_path)
+
+    return orig, modified
+
+
+class TestDqTransposeSurgery:
+    def test_transpose_node_added(self, models):
+        # Surgery should insert exactly one Transpose to undo the weight flip.
+        _, modified = models
+        transpose_nodes = [n for n in modified.graph.node if n.op_type == "Transpose"]
+        assert len(transpose_nodes) == 1, f"Expected 1 Transpose node, got {len(transpose_nodes)}"
+
+    def test_dq_node_preserved(self, models):
+        _, modified = models
+        dq_nodes = [n for n in modified.graph.node if n.op_type == "DequantizeLinear"]
+        assert len(dq_nodes) == 1, f"Expected 1 DQ node, got {len(dq_nodes)}"
+
+    def test_weight_transposed(self, models):
+        orig, modified = models
+
+        orig_w = None
+        for init in orig.graph.initializer:
+            if init.name == "weight_quantized":
+                orig_w = numpy_helper.to_array(init)
+
+        mod_w = None
+        # Substring match: surgery may rename the initializer (e.g. suffix it).
+        for init in modified.graph.initializer:
+            if "weight_quantized" in init.name:
+                mod_w = numpy_helper.to_array(init)
+
+        assert orig_w is not None and mod_w is not None
+        assert mod_w.shape == (OUT_FEATURES, IN_FEATURES), (
+            f"Expected transposed shape, got {mod_w.shape}"
+        )
+        np.testing.assert_array_equal(mod_w, orig_w.T)
+
+    def test_axis_updated(self, models):
+        _, modified = models
+        dq_node = next(n for n in modified.graph.node if n.op_type == "DequantizeLinear")
+        axis_val = None
+        for attr in dq_node.attribute:
+            if attr.name == "axis":
+                axis_val = attr.i
+        # Per-channel axis must follow the weight transpose: 0 -> 1.
+        assert axis_val == 1, f"Expected axis=1 after transpose, got {axis_val}"
+
+    def test_output_matches(self, models):
+        # End-to-end numeric equivalence: transpose + axis update is lossless.
+        orig, modified = models
+        rng = np.random.RandomState(999)
+        x = rng.randn(BATCH, SEQ, IN_FEATURES).astype(np.float32)
+        feeds = {"input": x}
+
+        orig_out = _run_session(orig, feeds)[0]
+        mod_out = _run_session(modified, feeds)[0]
+
+        diff = np.abs(orig_out - mod_out)
+
+        print(f"\n Original shape: {orig_out.shape}")
+        print(f" Original[0,:4]: {orig_out[0, 0, :4]}")
+        print(f" Modified[0,:4]: {mod_out[0, 0, :4]}")
+        print(f" Max abs diff: {diff.max():.6f}")
+        print(f" Mean abs diff: {diff.mean():.6f}")
+
+        np.testing.assert_allclose(orig_out, mod_out, atol=1e-5, rtol=1e-5)
diff --git a/tests/unit/onnx/test_gqa_graph_surgery.py b/tests/unit/onnx/test_gqa_graph_surgery.py
new file mode 100644
index 000000000..d6d92b903
--- /dev/null
+++ b/tests/unit/onnx/test_gqa_graph_surgery.py
@@ -0,0 +1,697 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for GQA graph surgery (replace_attention_with_gqa)."""
+
+import os
+import tempfile
+
+import numpy as np
+import onnx
+import onnxruntime as ort
+import pytest
+import torch
+from onnx import TensorProto, helper, numpy_helper
+
+pytest.importorskip("transformers", reason="transformers required for GQA graph surgery tests")
+
+from modelopt.onnx.graph_surgery.gqa_replacement import replace_attention_with_gqa
+
+# Real HF config is fetched for head counts / rope_theta; vocab is shrunk so
+# the toy model stays tiny.
+MODEL_ID = "Qwen/Qwen2.5-0.5B"
+VOCAB_SIZE = 64
+SEQ_LEN = 4
+MAX_SEQ_LEN = 128
+
+# Module-level RNG: weights are deterministic across the whole test module.
+_RNG = np.random.RandomState(42)
+
+
+def _fp16(*shape):
+    # Small-magnitude fp16 weights to avoid overflow in the un-normalized toy net.
+    return (_RNG.randn(*shape) * 0.02).astype(np.float16)
+
+
+def _init(name, arr):
+    return numpy_helper.from_array(arr, name=name)
+
+
+def _const_node(name, value, dtype=np.int64):
+    # Constant node whose output follows the Optimum "<name>_output_0" convention.
+    arr = np.array(value, dtype=dtype)
+    return helper.make_node(
+        "Constant",
+        inputs=[],
+        outputs=[f"{name}_output_0"],
+        name=name,
+        value=numpy_helper.from_array(arr, name=""),
+    )
+
+
+def _build_toy_model(hidden_size, num_heads, kv_heads, head_dim, inv_freq_np, num_layers=1):
+    """Build a toy model matching real Optimum LLaMA export patterns.
+
+    Includes: shared rotary_emb (inv_freq x position_ids -> Cos/Sin),
+    per-layer rotate_half RoPE, KV cache concat, causal+padding mask.
+    """
+    nodes, inits, vis = [], [], []
+    half_dim = head_dim // 2
+
+    # Shared scalar/axes constants reused across the mask and RoPE subgraphs.
+    inits.append(_init("one_f16", np.array(1.0, dtype=np.float16)))
+    inits.append(_init("neg_large_f16", np.array(-1e4, dtype=np.float16)))
+    inits.append(_init("axes_0", np.array([0], dtype=np.int64)))
+    inits.append(_init("axes_01", np.array([0, 1], dtype=np.int64)))
+    inits.append(_init("axes_12", np.array([1, 2], dtype=np.int64)))
+    inits.append(_init("trilu_k1", np.array(1, dtype=np.int64)))
+    inits.append(_init("onnx::inv_freq", inv_freq_np.reshape(1, half_dim, 1).astype(np.float32)))
+
+    graph_inputs = [
+        helper.make_tensor_value_info("input_ids", TensorProto.INT64, ["B", "S"]),
+        helper.make_tensor_value_info("attention_mask", TensorProto.INT64, ["B", "T"]),
+        helper.make_tensor_value_info("position_ids", TensorProto.INT64, ["B", "S"]),
+    ]
+    graph_outputs = []
+
+    # KV-cache inputs, one key/value pair per layer (Optimum naming).
+    graph_inputs.extend(
+        helper.make_tensor_value_info(
+            f"past_key_values.{lid}.{kv}",
+            TensorProto.FLOAT16,
+            ["B", kv_heads, "P", head_dim],
+        )
+        for lid in range(num_layers)
+        for kv in ("key", "value")
+    )
+
+    inits.append(_init("model.embed_tokens.weight", _fp16(VOCAB_SIZE, hidden_size)))
+    nodes.append(
+        helper.make_node(
+            "Gather",
+            ["model.embed_tokens.weight", "input_ids"],
+            ["/model/embed_tokens/Gather_output_0"],
+            name="/model/embed_tokens/Gather",
+            axis=0,
+        )
+    )
+    hidden = "/model/embed_tokens/Gather_output_0"
+
+    # -- shared rotary_emb: inv_freq x position_ids -> cos/sin (fp16) --
+    re = "/model/rotary_emb"
+    nodes.append(_const_node(f"{re}/Constant_6", [1], np.int64))
+    nodes.append(
+        helper.make_node(
+            "Unsqueeze",
+            ["position_ids", f"{re}/Constant_6_output_0"],
+            [f"{re}/Unsqueeze_1_output_0"],
+            name=f"{re}/Unsqueeze_1",
+        )
+    )
+    # Casts with to=1 produce FLOAT (ONNX TensorProto enum value 1).
+    nodes.append(
+        helper.make_node(
+            "Cast",
+            [f"{re}/Unsqueeze_1_output_0"],
+            [f"{re}/Cast_1_output_0"],
+            name=f"{re}/Cast_1",
+            to=1,
+        )
+    )
+    nodes.append(
+        helper.make_node(
+            "Cast", ["onnx::inv_freq"], [f"{re}/Cast_2_output_0"], name=f"{re}/Cast_2", to=1
+        )
+    )
+    nodes.append(
+        helper.make_node(
+            "MatMul",
+            [f"{re}/Cast_2_output_0", f"{re}/Cast_1_output_0"],
+            [f"{re}/MatMul_output_0"],
+            name=f"{re}/MatMul",
+        )
+    )
+    nodes.append(
+        helper.make_node(
+            "Transpose",
+            [f"{re}/MatMul_output_0"],
+            [f"{re}/Transpose_output_0"],
+            name=f"{re}/Transpose",
+            perm=[0, 2, 1],
+        )
+    )
+    # Duplicate-and-concat: angles repeated to cover both rotate_half halves.
+    nodes.append(
+        helper.make_node(
+            "Concat",
+            [f"{re}/Transpose_output_0", f"{re}/Transpose_output_0"],
+            [f"{re}/Concat_1_output_0"],
+            name=f"{re}/Concat_1",
+            axis=-1,
+        )
+    )
+    nodes.append(
+        helper.make_node(
+            "Cos", [f"{re}/Concat_1_output_0"], [f"{re}/Cos_output_0"], name=f"{re}/Cos"
+        )
+    )
+    nodes.append(
+        helper.make_node(
+            "Sin", [f"{re}/Concat_1_output_0"], [f"{re}/Sin_output_0"], name=f"{re}/Sin"
+        )
+    )
+    # Identity-valued Muls mirror the attention_scaling multiplies in real exports.
+    nodes.append(_const_node(f"{re}/Constant_7", 1.0, np.float32))
+    nodes.append(
+        helper.make_node(
+            "Mul",
+            [f"{re}/Cos_output_0", f"{re}/Constant_7_output_0"],
+            [f"{re}/Mul_1_output_0"],
+            name=f"{re}/Mul_1",
+        )
+    )
+    nodes.append(_const_node(f"{re}/Constant_8", 1.0, np.float32))
+    nodes.append(
+        helper.make_node(
+            "Mul",
+            [f"{re}/Sin_output_0", f"{re}/Constant_8_output_0"],
+            [f"{re}/Mul_2_output_0"],
+            name=f"{re}/Mul_2",
+        )
+    )
+    # to=10 casts back to FLOAT16 (TensorProto enum value 10).
+    nodes.append(
+        helper.make_node(
+            "Cast", [f"{re}/Mul_1_output_0"], [f"{re}/Cast_4_output_0"], name=f"{re}/Cast_4", to=10
+        )
+    )
+    nodes.append(
+        helper.make_node(
+            "Cast", [f"{re}/Mul_2_output_0"], [f"{re}/Cast_5_output_0"], name=f"{re}/Cast_5", to=10
+        )
+    )
+    cos_out = f"{re}/Cast_4_output_0"
+    sin_out = f"{re}/Cast_5_output_0"
+
+    # -- shared causal + padding mask --
+    nodes.append(helper.make_node("Shape", ["input_ids"], ["ids_shape"], name="/model/pos/Shape"))
+    nodes.append(_const_node("/model/pos/C1", 1, np.int64))
+    nodes.append(
+        helper.make_node(
+            "Gather",
+            ["ids_shape", "/model/pos/C1_output_0"],
+            ["seq_len"],
+            name="/model/pos/seq_gather",
+            axis=0,
+        )
+    )
+    nodes.append(
+        helper.make_node("Unsqueeze", ["seq_len", "axes_0"], ["seq_1d"], name="/model/causal/unsq")
+    )
+    nodes.append(
+        helper.make_node(
+            "Concat", ["seq_1d", "seq_1d"], ["causal_shape"], name="/model/causal/cat", axis=0
+        )
+    )
+    nodes.append(
+        helper.make_node(
+            "ConstantOfShape",
+            ["causal_shape"],
+            ["causal_ones"],
+            name="/model/causal/fill",
+            value=numpy_helper.from_array(np.array([1.0], dtype=np.float16), name=""),
+        )
+    )
+    # Strict upper triangle (k=1) * -1e4 => additive causal bias.
+    nodes.append(
+        helper.make_node(
+            "Trilu", ["causal_ones", "trilu_k1"], ["upper_tri"], name="/model/causal/trilu", upper=1
+        )
+    )
+    nodes.append(
+        helper.make_node(
+            "Mul", ["upper_tri", "neg_large_f16"], ["causal_4d_raw"], name="/model/causal/mul"
+        )
+    )
+    nodes.append(
+        helper.make_node(
+            "Unsqueeze", ["causal_4d_raw", "axes_01"], ["causal_4d"], name="/model/causal/unsq4d"
+        )
+    )
+    # Padding mask: (1 - mask) * -1e4, broadcast to 4D and added to causal bias.
+    nodes.append(
+        helper.make_node("Cast", ["attention_mask"], ["pad_f16"], name="/model/pad/cast", to=10)
+    )
+    nodes.append(
+        helper.make_node("Unsqueeze", ["pad_f16", "axes_12"], ["pad_4d"], name="/model/pad/unsq")
+    )
+    nodes.append(helper.make_node("Sub", ["one_f16", "pad_4d"], ["inv_pad"], name="/model/pad/inv"))
+    nodes.append(
+        helper.make_node("Mul", ["inv_pad", "neg_large_f16"], ["pad_bias"], name="/model/pad/mul")
+    )
+    nodes.append(
+        helper.make_node("Add", ["causal_4d", "pad_bias"], ["attn_bias"], name="/model/bias/add")
+    )
+
+    # -- per layer --
+    for lid in range(num_layers):
+        pre = f"/model/layers.{lid}"
+        ap = f"{pre}/self_attn"
+        q_dim = num_heads * head_dim
+        k_dim = kv_heads * head_dim
+
+        ln_w = f"model.layers.{lid}.input_layernorm.weight"
+        ln_b = f"model.layers.{lid}.input_layernorm.bias"
+        inits.append(_init(ln_w, np.ones(hidden_size, dtype=np.float16)))
+        inits.append(_init(ln_b, np.zeros(hidden_size, dtype=np.float16)))
+        ln_out = f"{pre}/input_layernorm/Mul_1_output_0"
+        nodes.append(
+            helper.make_node(
+                "LayerNormalization",
+                [hidden, ln_w, ln_b],
+                [ln_out],
+                name=f"{pre}/input_layernorm/LayerNorm",
+                axis=-1,
+                epsilon=1e-5,
+            )
+        )
+
+        qw = f"model.layers.{lid}.self_attn.q_proj.weight"
+        kw = f"model.layers.{lid}.self_attn.k_proj.weight"
+        vw = f"model.layers.{lid}.self_attn.v_proj.weight"
+        ow = f"model.layers.{lid}.self_attn.o_proj.weight"
+        inits += [
+            _init(qw, _fp16(hidden_size, q_dim)),
+            _init(kw, _fp16(hidden_size, k_dim)),
+            _init(vw, _fp16(hidden_size, k_dim)),
+            _init(ow, _fp16(q_dim, hidden_size)),
+        ]
+
+        for proj, dim, suf in [
+            ("q_proj", q_dim, ""),
+            ("k_proj", k_dim, "_1"),
+            ("v_proj", k_dim, "_2"),
+        ]:
+            w = f"model.layers.{lid}.self_attn.{proj}.weight"
+            nodes.append(
+                helper.make_node(
+                    "MatMul",
+                    [ln_out, w],
+                    [f"{ap}/{proj}/MatMul_output_0"],
+                    name=f"{ap}/{proj}/MatMul",
+                )
+            )
+
+        # Reshape to [B, S, heads, head_dim] then transpose to [B, heads, S, head_dim].
+        inits.append(_init(f"{ap}/q_shape", np.array([0, 0, num_heads, head_dim], np.int64)))
+        inits.append(_init(f"{ap}/kv_shape", np.array([0, 0, kv_heads, head_dim], np.int64)))
+        for tag, proj, shape_name in [
+            ("", "q_proj", "q_shape"),
+            ("_1", "k_proj", "kv_shape"),
+            ("_2", "v_proj", "kv_shape"),
+        ]:
+            nodes.append(
+                helper.make_node(
+                    "Reshape",
+                    [f"{ap}/{proj}/MatMul_output_0", f"{ap}/{shape_name}"],
+                    [f"{ap}/Reshape{tag}_output_0"],
+                    name=f"{ap}/Reshape{tag}",
+                )
+            )
+            nodes.append(
+                helper.make_node(
+                    "Transpose",
+                    [f"{ap}/Reshape{tag}_output_0"],
+                    [f"{ap}/Transpose{tag}_output_0"],
+                    name=f"{ap}/Transpose{tag}",
+                    perm=[0, 2, 1, 3],
+                )
+            )
+
+        qt = f"{ap}/Transpose_output_0"
+        kt = f"{ap}/Transpose_1_output_0"
+        vt = f"{ap}/Transpose_2_output_0"
+
+        # RoPE helper: builds rotate_half pattern for a tensor
+        # (closure over this iteration's `ap` and the shared `nodes` list;
+        # called immediately below, so late binding is not an issue).
+        def _rope(tensor, prefix, cos=cos_out, sin=sin_out):
+            p = f"{ap}/{prefix}"
+            nodes.append(_const_node(f"{p}/c_ax", [1], np.int64))
+            nodes.append(
+                helper.make_node(
+                    "Unsqueeze", [cos, f"{p}/c_ax_output_0"], [f"{p}/cos4d"], name=f"{p}/cos_unsq"
+                )
+            )
+            nodes.append(_const_node(f"{p}/s_ax", [1], np.int64))
+            nodes.append(
+                helper.make_node(
+                    "Unsqueeze", [sin, f"{p}/s_ax_output_0"], [f"{p}/sin4d"], name=f"{p}/sin_unsq"
+                )
+            )
+            nodes.append(
+                helper.make_node("Mul", [tensor, f"{p}/cos4d"], [f"{p}/x_cos"], name=f"{p}/mul_cos")
+            )
+            # Dynamically compute head_dim/2 from the tensor shape, as real
+            # exports do, rather than baking in a constant.
+            nodes.append(helper.make_node("Shape", [tensor], [f"{p}/sh"], name=f"{p}/shape"))
+            nodes.append(_const_node(f"{p}/dim_idx", 3, np.int64))
+            nodes.append(
+                helper.make_node(
+                    "Gather",
+                    [f"{p}/sh", f"{p}/dim_idx_output_0"],
+                    [f"{p}/D"],
+                    name=f"{p}/gather_D",
+                    axis=0,
+                )
+            )
+            nodes.append(_const_node(f"{p}/two", 2, np.int64))
+            nodes.append(
+                helper.make_node(
+                    "Div", [f"{p}/D", f"{p}/two_output_0"], [f"{p}/half"], name=f"{p}/div"
+                )
+            )
+            nodes.append(
+                helper.make_node(
+                    "Unsqueeze", [f"{p}/half", "axes_0"], [f"{p}/half_1d"], name=f"{p}/unsq_half"
+                )
+            )
+            nodes.append(_const_node(f"{p}/zero", [0], np.int64))
+            nodes.append(_const_node(f"{p}/ax", [-1], np.int64))
+            nodes.append(_const_node(f"{p}/step", [1], np.int64))
+            nodes.append(
+                helper.make_node(
+                    "Slice",
+                    [
+                        tensor,
+                        f"{p}/zero_output_0",
+                        f"{p}/half_1d",
+                        f"{p}/ax_output_0",
+                        f"{p}/step_output_0",
+                    ],
+                    [f"{p}/x1"],
+                    name=f"{p}/slice1",
+                )
+            )
+            # INT64_MAX end index = "slice to the end" idiom.
+            nodes.append(_const_node(f"{p}/big", [9223372036854775807], np.int64))
+            nodes.append(_const_node(f"{p}/ax2", [-1], np.int64))
+            nodes.append(_const_node(f"{p}/step2", [1], np.int64))
+            nodes.append(
+                helper.make_node(
+                    "Slice",
+                    [
+                        tensor,
+                        f"{p}/half_1d",
+                        f"{p}/big_output_0",
+                        f"{p}/ax2_output_0",
+                        f"{p}/step2_output_0",
+                    ],
+                    [f"{p}/x2"],
+                    name=f"{p}/slice2",
+                )
+            )
+            # rotate_half: concat(-x2, x1), then x*cos + rotate_half(x)*sin.
+            nodes.append(helper.make_node("Neg", [f"{p}/x2"], [f"{p}/neg_x2"], name=f"{p}/neg"))
+            nodes.append(
+                helper.make_node(
+                    "Concat", [f"{p}/neg_x2", f"{p}/x1"], [f"{p}/rot"], name=f"{p}/concat", axis=-1
+                )
+            )
+            nodes.append(
+                helper.make_node(
+                    "Mul", [f"{p}/rot", f"{p}/sin4d"], [f"{p}/rot_sin"], name=f"{p}/mul_sin"
+                )
+            )
+            out = f"{p}/out"
+            nodes.append(
+                helper.make_node("Add", [f"{p}/x_cos", f"{p}/rot_sin"], [out], name=f"{p}/add")
+            )
+            return out
+
+        q_rope = _rope(qt, "rope_q")
+        k_rope = _rope(kt, "rope_k")
+
+        # KV cache concat along the sequence axis; values bypass RoPE.
+        past_k = f"past_key_values.{lid}.key"
+        past_v = f"past_key_values.{lid}.value"
+        pres_k = f"present.{lid}.key"
+        pres_v = f"present.{lid}.value"
+        nodes.append(
+            helper.make_node("Concat", [past_k, k_rope], [pres_k], name=f"{ap}/Concat_5", axis=2)
+        )
+        nodes.append(
+            helper.make_node("Concat", [past_v, vt], [pres_v], name=f"{ap}/Concat_6", axis=2)
+        )
+        graph_outputs.append(
+            helper.make_tensor_value_info(
+                pres_k, TensorProto.FLOAT16, ["B", kv_heads, "T", head_dim]
+            )
+        )
+        graph_outputs.append(
+            helper.make_tensor_value_info(
+                pres_v, TensorProto.FLOAT16, ["B", kv_heads, "T", head_dim]
+            )
+        )
+
+        # GQA case: repeat each KV head num_heads/kv_heads times
+        # (Unsqueeze -> Expand -> Reshape, the HF repeat_kv export pattern).
+        if kv_heads != num_heads:
+            reps = num_heads // kv_heads
+            inits += [
+                _init(f"{ap}/rk/exp", np.array([1, reps, 1, 1], np.int64)),
+                _init(f"{ap}/rk/ax", np.array([2], np.int64)),
+                _init(f"{ap}/rk/rs", np.array([0, num_heads, -1, head_dim], np.int64)),
+            ]
+            for t, src in [("k", pres_k), ("v", pres_v)]:
+                nodes.append(
+                    helper.make_node(
+                        "Unsqueeze",
+                        [src, f"{ap}/rk/ax"],
+                        [f"{ap}/rk/{t}u"],
+                        name=f"{ap}/repeat_kv/{t}_unsqueeze",
+                    )
+                )
+                nodes.append(
+                    helper.make_node(
+                        "Expand",
+                        [f"{ap}/rk/{t}u", f"{ap}/rk/exp"],
+                        [f"{ap}/rk/{t}e"],
+                        name=f"{ap}/repeat_kv/{t}_expand",
+                    )
+                )
+                nodes.append(
+                    helper.make_node(
+                        "Reshape",
+                        [f"{ap}/rk/{t}e", f"{ap}/rk/rs"],
+                        [f"{ap}/rk/{t}r"],
+                        name=f"{ap}/repeat_kv/{t}_reshape",
+                    )
+                )
+            k_final, v_final = f"{ap}/rk/kr", f"{ap}/rk/vr"
+        else:
+            k_final, v_final = pres_k, pres_v
+
+        # Scaled dot-product attention: (q * 1/sqrt(d)) @ k^T + bias -> softmax -> @ v.
+        nodes.append(
+            helper.make_node(
+                "Transpose",
+                [k_final],
+                [f"{ap}/Transpose_3_output_0"],
+                name=f"{ap}/Transpose_3",
+                perm=[0, 1, 3, 2],
+            )
+        )
+        # Scale is round-tripped through fp16 so it matches the exported constant.
+        scale_val = float(np.array(1.0 / (head_dim**0.5), dtype=np.float16))
+        nodes.append(_const_node(f"{ap}/scale", scale_val, np.float16))
+        nodes.append(
+            helper.make_node(
+                "Mul",
+                [q_rope, f"{ap}/scale_output_0"],
+                [f"{ap}/Mul_8_output_0"],
+                name=f"{ap}/Mul_8",
+            )
+        )
+        nodes.append(
+            helper.make_node(
+                "MatMul",
+                [f"{ap}/Mul_8_output_0", f"{ap}/Transpose_3_output_0"],
+                [f"{ap}/MatMul_output_0"],
+                name=f"{ap}/MatMul",
+            )
+        )
+        nodes.append(
+            helper.make_node(
+                "Add",
+                [f"{ap}/MatMul_output_0", "attn_bias"],
+                [f"{ap}/Add_2_output_0"],
+                name=f"{ap}/Add_2",
+            )
+        )
+        nodes.append(
+            helper.make_node(
+                "Softmax",
+                [f"{ap}/Add_2_output_0"],
+                [f"{ap}/Softmax_output_0"],
+                name=f"{ap}/Softmax",
+                axis=-1,
+            )
+        )
+        nodes.append(
+            helper.make_node(
+                "MatMul",
+                [f"{ap}/Softmax_output_0", v_final],
+                [f"{ap}/MatMul_1_output_0"],
+                name=f"{ap}/MatMul_1",
+            )
+        )
+        nodes.append(
+            helper.make_node(
+                "Transpose",
+                [f"{ap}/MatMul_1_output_0"],
+                [f"{ap}/Transpose_4_output_0"],
+                name=f"{ap}/Transpose_4",
+                perm=[0, 2, 1, 3],
+            )
+        )
+        inits.append(_init(f"{ap}/out_rs", np.array([0, 0, hidden_size], np.int64)))
+        nodes.append(
+            helper.make_node(
+                "Reshape",
+                [f"{ap}/Transpose_4_output_0", f"{ap}/out_rs"],
+                [f"{ap}/Reshape_7_output_0"],
+                name=f"{ap}/Reshape_7",
+            )
+        )
+        nodes.append(
+            helper.make_node(
+                "MatMul",
+                [f"{ap}/Reshape_7_output_0", ow],
+                [f"{ap}/o_proj/MatMul_output_0"],
+                name=f"{ap}/o_proj/MatMul",
+            )
+        )
+        res = f"{pre}/residual_add/output_0"
+        nodes.append(
+            helper.make_node(
+                "Add", [hidden, f"{ap}/o_proj/MatMul_output_0"], [res], name=f"{pre}/residual_add"
+            )
+        )
+        hidden = res
+
+    inits.append(_init("lm_head.weight", _fp16(hidden_size, VOCAB_SIZE)))
+    nodes.append(
+        helper.make_node("MatMul", [hidden, "lm_head.weight"], ["logits"], name="/lm_head/MatMul")
+    )
+    # logits first, then the interleaved present.* outputs appended per layer.
+    graph_outputs.insert(
+        0, helper.make_tensor_value_info("logits", TensorProto.FLOAT16, ["B", "S", VOCAB_SIZE])
+    )
+
+    graph = helper.make_graph(
+        nodes, "test_gqa", graph_inputs, graph_outputs, initializer=inits, value_info=vis
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
+    model.ir_version = 8
+    return model
+
+
+def _get_config():
+    # Local import keeps the module importable when transformers is missing
+    # (importorskip above guards collection); fetches real head/rope settings.
+    from transformers import AutoConfig
+
+    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=False)
+    hidden = cfg.hidden_size
+    heads = cfg.num_attention_heads
+    kv = getattr(cfg, "num_key_value_heads", heads)
+    hd = hidden // heads
+    theta = getattr(cfg, "rope_theta", 10000.0)
+    inv_freq = 1.0 / (theta ** (torch.arange(0, hd, 2, dtype=torch.int64).float() / hd))
+    return hidden, heads, kv, hd, inv_freq.numpy()
+
+
+def _run_session(model_proto, feeds):
+    """Run inference directly from an in-memory ModelProto."""
+    model_bytes = model_proto.SerializeToString()
+    sess = ort.InferenceSession(model_bytes, providers=["CPUExecutionProvider"])
+    return sess.run(None, feeds)
+
+
+@pytest.fixture(scope="module")
+def models_and_config():
+    """Build original model, run GQA surgery, return both protos + config."""
+    hidden, heads, kv, hd, inv_freq_np = _get_config()
+    orig = _build_toy_model(hidden, heads, kv, hd, inv_freq_np)
+    onnx.checker.check_model(orig)
+
+    with tempfile.TemporaryDirectory() as td:
+        orig_path = os.path.join(td, "original.onnx")
+        gqa_path = os.path.join(td, "gqa.onnx")
+        onnx.save(orig, orig_path)
+
+        replace_attention_with_gqa(
+            model_path=orig_path,
+            output_path=gqa_path,
+            hf_model_id=MODEL_ID,
+            max_seq_len=MAX_SEQ_LEN,
+            io_dtype="float16",
+            use_external_data=False,
+        )
+        # Loaded fully into memory before the temp dir is cleaned up.
+        gqa = onnx.load(gqa_path)
+
+    return orig, gqa, {"heads": heads, "kv": kv, "hd": hd}
+
+
+class TestGQAGraphSurgery:
+    def test_gqa_node_exists(self, models_and_config):
+        # One layer in the toy model -> exactly one fused GQA node expected.
+        _, gqa, _ = models_and_config
+        gqa_ops = [n for n in gqa.graph.node if n.op_type == "GroupQueryAttention"]
+        assert len(gqa_ops) == 1
+
+    def test_gqa_attributes(self, models_and_config):
+        _, gqa, cfg = models_and_config
+        gqa_node = next(n for n in gqa.graph.node if n.op_type == "GroupQueryAttention")
+        # a.type == 2 is AttributeProto.INT; everything else read as float here.
+        attrs = {a.name: (a.i if a.type == 2 else a.f) for a in gqa_node.attribute}
+        assert attrs["num_heads"] == cfg["heads"]
+        assert attrs["kv_num_heads"] == cfg["kv"]
+        assert attrs["do_rotary"] == 1
+
+    def test_node_count_reduced(self, models_and_config):
+        # Fusion must strictly shrink the graph.
+        orig, gqa, _ = models_and_config
+        assert len(gqa.graph.node) < len(orig.graph.node)
+
+    def test_rotary_emb_nodes_removed(self, models_and_config):
+        # RoPE is folded into GQA (do_rotary=1), so no rotary_emb nodes remain.
+        _, gqa, _ = models_and_config
+        rotary_names = [n.name for n in gqa.graph.node if "rotary_emb" in n.name]
+        assert len(rotary_names) == 0
+
+    def test_position_ids_removed(self, models_and_config):
+        _, gqa, _ = models_and_config
+        input_names = [i.name for i in gqa.graph.input]
+        assert "position_ids" not in input_names
+
+    def test_logits_match(self, models_and_config):
+        orig, gqa, cfg = models_and_config
+        kv, hd = cfg["kv"], cfg["hd"]
+
+        ids = np.arange(1, SEQ_LEN + 1, dtype=np.int64).reshape(1, SEQ_LEN)
+        mask = np.ones((1, SEQ_LEN), dtype=np.int64)
+        pos = np.arange(SEQ_LEN, dtype=np.int64).reshape(1, SEQ_LEN)
+        # Zero-length past => prefill step.
+        empty_kv = np.zeros((1, kv, 0, hd), dtype=np.float16)
+
+        orig_feeds = {
+            "input_ids": ids,
+            "attention_mask": mask,
+            "position_ids": pos,
+            "past_key_values.0.key": empty_kv,
+            "past_key_values.0.value": empty_kv,
+        }
+        # The GQA model derives positions internally, so no position_ids feed.
+        gqa_feeds = {
+            "input_ids": ids,
+            "attention_mask": mask,
+            "past_key_values.0.key": empty_kv,
+            "past_key_values.0.value": empty_kv,
+        }
+
+        orig_logits = _run_session(orig, orig_feeds)[0].astype(np.float32)
+        gqa_logits = _run_session(gqa, gqa_feeds)[0].astype(np.float32)
+
+        diff = np.abs(orig_logits - gqa_logits)
+        finite = diff[np.isfinite(diff)]
+
+        print(f"\n Original nodes: {len(orig.graph.node)} -> GQA nodes: {len(gqa.graph.node)}")
+        print(f" Logits shape: {orig_logits.shape}")
+        print(f" Original[0,:4]: {orig_logits[0, 0, :4]}")
+        print(f" GQA [0,:4]: {gqa_logits[0, 0, :4]}")
+        if len(finite) > 0:
+            print(f" Max abs diff: {finite.max():.6f}")
+            print(f" Mean abs diff: {finite.mean():.6f}")
+
+        # Loose fp16 tolerance: different kernel orderings accumulate error.
+        assert len(finite) > 0, "All values are non-finite"
+        assert finite.max() < 1.0, f"Max abs diff {finite.max():.4f} exceeds tolerance"