From 3796b7eabe620c9712fb65941ac5d7f441b6e79f Mon Sep 17 00:00:00 2001 From: spalne Date: Mon, 8 Jun 2026 11:17:08 -0700 Subject: [PATCH 01/13] Add qauntization for transformers for qwen0.6B --- qwen3_quantize.py | 256 ++++++++++++++++++++++++ src/winml/modelkit/onnx/__init__.py | 2 + src/winml/modelkit/onnx/qwen_surgery.py | 186 +++++++++++++++++ test_qwen 2.py | 70 +++++++ 4 files changed, 514 insertions(+) create mode 100644 qwen3_quantize.py create mode 100644 src/winml/modelkit/onnx/qwen_surgery.py create mode 100644 test_qwen 2.py diff --git a/qwen3_quantize.py b/qwen3_quantize.py new file mode 100644 index 000000000..655c65e6a --- /dev/null +++ b/qwen3_quantize.py @@ -0,0 +1,256 @@ +"""Qwen3 transformer-only quantization. + +Must be called after the composite Qwen3 model has been built (e.g. by +``test_qwen 2.py``) so that ``decoder_prefill`` / ``decoder_gen`` ONNX files +exist in the winml cache. + +Pipeline: + + 1. Apply ``make_transformer_only`` surgery to each sub-model, producing + ``*_transformer.onnx`` with ``inputs_embeds`` input and + ``output_hidden_states`` output — embeddings and lm_head are stripped + out (ignored, not quantized). + 2. Quantize those transformer-only files via winml-cli's ``quantize_onnx`` + using a calibration reader that runs ``embed_tokens`` in PyTorch on + real text samples. 
+""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any, Iterator + +import numpy as np +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from winml.modelkit.models.winml.composite_model import WinMLCompositeModel +from winml.modelkit.onnx import make_transformer_only +from winml.modelkit.quant import WinMLQuantizationConfig, quantize_onnx + + +logger = logging.getLogger(__name__) + +DEFAULT_MODEL_ID = "Qwen/Qwen3-0.6B" +DEFAULT_MAX_CACHE = 256 +DEFAULT_PREFILL_SEQ = 64 +DEFAULT_GEN_SEQ = 1 +DEFAULT_NUM_SAMPLES = 16 +DEFAULT_PROMPTS = [ + "Solve: 8 * 7 = ?", + "Translate to French: The weather is nice today.", + "Write a short poem about the ocean.", + "Explain gradient descent in one paragraph.", + "What is the capital of Japan?", + "List three uses of magnesium.", + "Summarize the plot of Hamlet in two sentences.", + "Give a Python one-liner to reverse a string.", +] + + +# --------------------------------------------------------------------------- +# Calibration data reader +# --------------------------------------------------------------------------- + + +class Qwen3TransformerCalibReader: + """Yields calibration feeds for the transformer-only Qwen3 ONNX. + + Runs HF ``embed_tokens`` in PyTorch to produce ``inputs_embeds`` since the + embedding layer was stripped from the ONNX graph. All other inputs + (attention_mask, position_ids, past_{i}_key/value) follow the conventions + used by winml-cli's ``WinMLQwen3Model`` runtime. 
+ """ + + def __init__( + self, + embed_tokens: torch.nn.Module, + config: Any, + token_ids_list: list[torch.Tensor], + *, + seq_len: int, + max_cache_len: int, + ) -> None: + self.embed = embed_tokens + self.cfg = config + self.seq_len = seq_len + self.max_cache_len = max_cache_len + self.num_layers = config.num_hidden_layers + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self._samples = list(self._build_samples(token_ids_list)) + self._iter: Iterator[dict[str, np.ndarray]] | None = None + self.rewind() + + def _build_samples( + self, token_ids_list: list[torch.Tensor] + ) -> Iterator[dict[str, np.ndarray]]: + for ids in token_ids_list: + # Right-truncate / pad to seq_len so we feed the static graph shape. + ids = ids[:, : self.seq_len] + real_len = ids.shape[1] + if real_len < self.seq_len: + pad = torch.zeros( + (1, self.seq_len - real_len), dtype=ids.dtype, device=ids.device + ) + ids = torch.cat([ids, pad], dim=1) + + with torch.no_grad(): + embeds = self.embed(ids).to(torch.float32).cpu().numpy() + + # attention_mask: ones for real prompt positions placed at the + # END of the max_cache buffer (sliding-window cache convention), + # zeros elsewhere. + attn_mask = np.zeros((1, self.max_cache_len), dtype=np.int64) + attn_mask[0, -real_len:] = 1 + + # position_ids: 0..seq_len-1 (clamped for padding). 
+ position_ids = np.arange(self.seq_len, dtype=np.int64)[None, :] + + feed: dict[str, np.ndarray] = { + "inputs_embeds": embeds.astype(np.float32), + "attention_mask": attn_mask, + "position_ids": position_ids, + } + kv_shape = (1, self.num_kv_heads, self.max_cache_len, self.head_dim) + zeros = np.zeros(kv_shape, dtype=np.float32) + for i in range(self.num_layers): + feed[f"past_{i}_key"] = zeros + feed[f"past_{i}_value"] = zeros + yield feed + + def get_next(self) -> dict[str, np.ndarray] | None: + try: + return next(self._iter) if self._iter is not None else None + except StopIteration: + return None + + def rewind(self) -> None: + self._iter = iter(self._samples) + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + + +def _tokenize_prompts( + tokenizer: Any, prompts: list[str], num_samples: int +) -> list[torch.Tensor]: + # Cycle through prompts up to num_samples; apply chat template like the + # runtime so calibration distribution matches inference inputs. + out: list[torch.Tensor] = [] + for i in range(num_samples): + prompt = prompts[i % len(prompts)] + text = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ids = tokenizer([text], return_tensors="pt").input_ids + out.append(ids) + return out + + +def quantize_built_model( + model: WinMLCompositeModel, + *, + model_id: str = DEFAULT_MODEL_ID, + max_cache_len: int = DEFAULT_MAX_CACHE, + prefill_seq: int = DEFAULT_PREFILL_SEQ, + num_samples: int = DEFAULT_NUM_SAMPLES, + weight_type: str = "uint8", + activation_type: str = "uint16", +) -> dict[str, Path]: + """Run surgery + transformer-only quantization on an already-built composite. + + Reuses the ONNX files produced by ``WinMLCompositeModel.from_pretrained`` + so this can be called after a build step without re-exporting. 
+ + Returns: mapping of sub-model name → quantized ONNX path. + """ + sub_paths: dict[str, Path] = {} + for name, sub in model.sub_models.items(): + final_path = Path(sub._onnx_path) + # ``_model.onnx`` is the *compiled* QNN EPContext blob — surgery needs + # the uncompiled fp16 graph. ``build.hf`` emits ``{cache_key}_optimized.onnx`` + # alongside it in the same artifacts directory. + if final_path.name.endswith("_model.onnx"): + stem = final_path.name[: -len("_model.onnx")] + optimized = final_path.with_name(f"{stem}_optimized.onnx") + if optimized.exists(): + sub_paths[name] = optimized + continue + print( + f"WARNING: {optimized.name} not found next to {final_path.name}; " + "falling back to the compiled model (surgery will likely fail)." + ) + sub_paths[name] = final_path + + for name, p in sub_paths.items(): + print(f" {name}: {p}") + + print("\n=== Loading HF embed_tokens for calibration ===") + hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32) + hf_model.eval() + embed_tokens = hf_model.get_input_embeddings() + tokenizer = AutoTokenizer.from_pretrained(model_id) + token_ids_list = _tokenize_prompts(tokenizer, DEFAULT_PROMPTS, num_samples) + + seq_by_sub = { + "decoder_prefill": prefill_seq, + "decoder_gen": DEFAULT_GEN_SEQ, + } + + quant_paths: dict[str, Path] = {} + for sub_name, fused_path in sub_paths.items(): + if sub_name not in seq_by_sub: + print(f"\n--- Skipping unknown sub-model {sub_name!r} ---") + continue + + seq_len = seq_by_sub[sub_name] + transformer_path = fused_path.with_name(fused_path.stem + "_transformer.onnx") + quant_path = transformer_path.with_name( + transformer_path.stem + f"_w{weight_type[-1]}a{activation_type[-2:]}.quant.onnx" + ) + + print(f"\n=== Surgery: {sub_name} (seq_len={seq_len}) ===") + print(f" in : {fused_path}") + print(f" out: {transformer_path}") + make_transformer_only(fused_path, transformer_path) + + print(f"\n=== Quantize (transformer only): {sub_name} ===") + print(f" out: 
{quant_path}") + reader = Qwen3TransformerCalibReader( + embed_tokens, + hf_model.config, + token_ids_list, + seq_len=seq_len, + max_cache_len=max_cache_len, + ) + cfg = WinMLQuantizationConfig( + samples=num_samples, + weight_type=weight_type, # type: ignore[arg-type] + activation_type=activation_type, # type: ignore[arg-type] + calibration_method="minmax", + calibration_data=reader, + ) + result = quantize_onnx(transformer_path, output_path=quant_path, config=cfg) + if not result.success: + print(" FAILED:") + for err in result.errors: + print(f" {err}") + raise SystemExit(1) + print( + f" ok — {result.nodes_quantized} QDQ nodes inserted in " + f"{result.total_time_seconds:.1f}s" + ) + quant_paths[sub_name] = quant_path + + print("\n=== Done ===") + return quant_paths + diff --git a/src/winml/modelkit/onnx/__init__.py b/src/winml/modelkit/onnx/__init__.py index a3bc49d51..0287a2ff7 100644 --- a/src/winml/modelkit/onnx/__init__.py +++ b/src/winml/modelkit/onnx/__init__.py @@ -19,6 +19,7 @@ from .io import InputTensorSpec, OutputTensorSpec, generate_inputs_from_onnx, get_io_config from .metadata import capture_metadata, restore_metadata from .persistence import cleanup_onnx, load_onnx, save_onnx +from .qwen_surgery import make_transformer_only from .shape import infer_onnx_shapes, infer_shapes from .utils import EXTERNAL_DATA_THRESHOLD, check_onnx_model, get_model_size @@ -41,6 +42,7 @@ "is_compiled_onnx", "is_quantized_onnx", "load_onnx", + "make_transformer_only", "remove_optional_from_type_annotation", "restore_metadata", "save_onnx", diff --git a/src/winml/modelkit/onnx/qwen_surgery.py b/src/winml/modelkit/onnx/qwen_surgery.py new file mode 100644 index 000000000..cd49ee5ec --- /dev/null +++ b/src/winml/modelkit/onnx/qwen_surgery.py @@ -0,0 +1,186 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +"""Ad-hoc ONNX surgery to turn a Qwen3 decoder ONNX into a transformer-only graph. + +Applied as a post-export surgery on the fused decoder ONNX produced by +``WinMLQwen3Model`` (``decoder_prefill.onnx`` / ``decoder_gen.onnx``). + +The resulting transformer-only ONNX has: + - ``input_ids`` graph input replaced by ``inputs_embeds`` (FLOAT, + ``[batch, seq, hidden_size]``) — the upstream embedding Gather is + removed. + - ``logits`` graph output replaced by ``output_hidden_states`` + (FLOAT, ``[batch, seq, hidden_size]``) — the final ``lm_head`` MatMul + is removed. +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import onnx +from onnx import TensorProto, helper + +from .persistence import load_onnx, save_onnx + + +logger = logging.getLogger(__name__) + + +def _dim(d: onnx.TensorShapeProto.Dimension) -> int | str: + if d.HasField("dim_value"): + return d.dim_value + return d.dim_param or "?" + + +def make_transformer_only( + model_path: str | Path, + output_path: str | Path, + *, + input_ids_name: str = "input_ids", + logits_name: str = "logits", + inputs_embeds_name: str = "inputs_embeds", + output_hidden_states_name: str = "output_hidden_states", +) -> Path: + """Strip the embedding Gather and the lm_head MatMul from a Qwen3 ONNX. + + Args: + model_path: Path to the fused decoder ONNX (logits output, input_ids input). + output_path: Destination for the transformer-only ONNX. + input_ids_name: Name of the input_ids graph input to drop. + logits_name: Name of the logits graph output to drop. + inputs_embeds_name: Display name for the new embeddings input + (used only for logging; the actual tensor keeps its existing + internal name so downstream nodes need no rewiring). + output_hidden_states_name: Display name for the new hidden-state output. + + Returns: + The output path. 
+ """ + model_path = Path(model_path) + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + model = load_onnx(model_path, load_weights=True, validate=False) + graph = model.graph + init_by_name = {init.name: init for init in graph.initializer} + + # -------------------- Embedding removal -------------------- + embed_idx = next( + (i for i, n in enumerate(graph.node) if input_ids_name in n.input), + None, + ) + if embed_idx is None: + msg = f"No node consumes graph input {input_ids_name!r}" + raise RuntimeError(msg) + + embed_node = graph.node[embed_idx] + embed_out_name = embed_node.output[0] + + embed_weight = None + for ipt in embed_node.input: + init = init_by_name.get(ipt) + if init is not None and len(init.dims) == 2: + embed_weight = init + break + if embed_weight is None: + msg = f"Could not find 2-D embedding weight initializer on node {embed_node.name!r}" + raise RuntimeError(msg) + hidden_size = int(embed_weight.dims[1]) + + ids_input = next(i for i in graph.input if i.name == input_ids_name) + batch_dim = _dim(ids_input.type.tensor_type.shape.dim[0]) + seq_dim = _dim(ids_input.type.tensor_type.shape.dim[1]) + + logger.info( + "Removing embedding node %r (%s) — exposing %r as new input %r [%s, %s, %d]", + embed_node.name, + embed_node.op_type, + embed_out_name, + inputs_embeds_name, + batch_dim, + seq_dim, + hidden_size, + ) + + new_embed_input = helper.make_tensor_value_info( + inputs_embeds_name, + TensorProto.FLOAT, + [batch_dim, seq_dim, hidden_size], + ) + + del graph.node[embed_idx] + graph.input.remove(ids_input) + graph.input.append(new_embed_input) + graph.initializer.remove(embed_weight) + + # Rewire any consumer of the removed embedding output to the new input. 
+ for n in graph.node: + for i, name in enumerate(n.input): + if name == embed_out_name: + n.input[i] = inputs_embeds_name + + # -------------------- lm_head removal -------------------- + lmh_idx = next( + (i for i, n in enumerate(graph.node) if logits_name in n.output), + None, + ) + if lmh_idx is None: + msg = f"No node produces graph output {logits_name!r}" + raise RuntimeError(msg) + + lmh_node = graph.node[lmh_idx] + init_names = {init.name for init in graph.initializer} + hidden_in: str | None = None + weight_in: str | None = None + for ipt in lmh_node.input: + if ipt in init_names: + weight_in = ipt + else: + hidden_in = ipt + if hidden_in is None: + msg = f"lm_head node {lmh_node.name!r} has no non-initializer input ({list(lmh_node.input)})" + raise RuntimeError(msg) + + logger.info( + "Removing lm_head node %r (%s) — exposing %r as new output %r", + lmh_node.name, + lmh_node.op_type, + hidden_in, + output_hidden_states_name, + ) + + logits_output = next(o for o in graph.output if o.name == logits_name) + new_hidden_output = helper.make_tensor_value_info( + output_hidden_states_name, + TensorProto.FLOAT, + [batch_dim, seq_dim, hidden_size], + ) + + del graph.node[lmh_idx] + graph.output.remove(logits_output) + # Put hidden states first so it mirrors the original logits position. + graph.output.insert(0, new_hidden_output) + + # Rename the producer of ``hidden_in`` to emit the new graph output name. 
+ for n in graph.node: + for i, name in enumerate(n.output): + if name == hidden_in: + n.output[i] = output_hidden_states_name + for i, name in enumerate(n.input): + if name == hidden_in: + n.input[i] = output_hidden_states_name + + if weight_in is not None and not any(weight_in in n.input for n in graph.node): + wi = next(init for init in graph.initializer if init.name == weight_in) + graph.initializer.remove(wi) + + save_onnx(model, output_path) + logger.info("Wrote transformer-only ONNX → %s", output_path) + return output_path + + +__all__ = ["make_transformer_only"] diff --git a/test_qwen 2.py b/test_qwen 2.py new file mode 100644 index 000000000..6a52dee72 --- /dev/null +++ b/test_qwen 2.py @@ -0,0 +1,70 @@ +"""E2E test for Qwen3 decoder-only pipeline. + +Uses sub_model_kwargs to set per-component shape_config: + - decoder_prefill: max_cache_len=256, seq_len=64 + - decoder_gen: max_cache_len=256, seq_len=1 + +Set env var ``QUANTIZE=1`` to also run the MOPS-style Step 3: +transformer-only surgery + winml quantize on both sub-models +(embeddings and lm_head are stripped and not quantized). 
+""" + +import os + +from transformers import AutoTokenizer + +from winml.modelkit.config import WinMLBuildConfig +from winml.modelkit.models.winml.composite_model import WinMLCompositeModel + +model_id = "Qwen/Qwen3-0.6B" + +model = WinMLCompositeModel.from_pretrained( + model_id, + task="text-generation", + # config=WinMLBuildConfig(quant=None, compile=None), + config=WinMLBuildConfig(quant=None), + precision="fp16", + device="npu", + ep="qnn", + force_rebuild=False, + sub_model_kwargs={ + "decoder_prefill": {"shape_config": {"max_cache_len": 256, "seq_len": 64}}, + "decoder_gen": {"shape_config": {"max_cache_len": 256, "seq_len": 1}}, + }, +) + +# Verify ONNX I/O shapes +for name, sub in model.sub_models.items(): + io = sub.io_config + shapes = dict(zip(io["input_names"], io["input_shapes"])) + print(f"\n=== {name} ===") + for k, v in shapes.items(): + print(f" {k}: {v}") + +tokenizer = AutoTokenizer.from_pretrained(model_id) + +prompt = "8 * 7 = ?" +messages = [{"role": "user", "content": prompt}] +text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, enable_thinking=False, +) +model_inputs = tokenizer([text], return_tensors="pt").to(model.device) + +generated_ids = model.generate(**model_inputs) + +output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist() +content = tokenizer.decode(output_ids, skip_special_tokens=True) +print("\nAnswer:", content) + +if os.environ.get("QUANTIZE") == "1": + # Reuse the already-built decoder_prefill/decoder_gen ONNX files: + # surgery (strip embed + lm_head) + transformer-only quantize. 
+ print("\n=== QUANTIZE=1 — running transformer-only quantization ===") + from qwen3_quantize import quantize_built_model + + quantize_built_model( + model, + model_id=model_id, + max_cache_len=256, + prefill_seq=64, + ) From 1ee316c8350d9a904e5e06a51dacb1a7186658d0 Mon Sep 17 00:00:00 2001 From: spalne Date: Tue, 16 Jun 2026 15:07:46 -0700 Subject: [PATCH 02/13] Quantize transformer-only with fused GQA + GSM8k calibration --- ...e.py => qwen3_transformer_only_quantize.py | 152 ++++---- .../modelkit/models/hf/qwen3_export_ops.py | 211 +++++++++++ .../modelkit/models/hf/qwen3_modeling.py | 237 ++++++++++++ .../models/hf/qwen_transformer_only.py | 354 ++++++++++++++++++ src/winml/modelkit/onnx/__init__.py | 2 - src/winml/modelkit/onnx/qwen_surgery.py | 186 --------- test_qwen 2.py | 70 ---- test_qwen.py | 235 ++++++++++++ 8 files changed, 1100 insertions(+), 347 deletions(-) rename qwen3_quantize.py => qwen3_transformer_only_quantize.py (54%) create mode 100644 src/winml/modelkit/models/hf/qwen3_export_ops.py create mode 100644 src/winml/modelkit/models/hf/qwen3_modeling.py create mode 100644 src/winml/modelkit/models/hf/qwen_transformer_only.py delete mode 100644 src/winml/modelkit/onnx/qwen_surgery.py delete mode 100644 test_qwen 2.py create mode 100644 test_qwen.py diff --git a/qwen3_quantize.py b/qwen3_transformer_only_quantize.py similarity index 54% rename from qwen3_quantize.py rename to qwen3_transformer_only_quantize.py index 655c65e6a..8b4efa9b7 100644 --- a/qwen3_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -1,18 +1,15 @@ -"""Qwen3 transformer-only quantization. +"""Transformer-only w8a16 quantization for Qwen3. -Must be called after the composite Qwen3 model has been built (e.g. by -``test_qwen 2.py``) so that ``decoder_prefill`` / ``decoder_gen`` ONNX files -exist in the winml cache. 
+Targets the transformer-only ONNX produced by +``qwen_transformer_only.install() + test_qwen.py``: -Pipeline: + - **No embedding/lm_head surgery.** The export already excludes both, + so we feed ``WinMLQuantization`` the file directly. + - **Transformer-shaped calibration feeds.** ``input_hidden_states`` (FP32), + ``past_seq_len`` / ``total_seq_len`` (INT32), ``past_keys_{i}`` / + ``past_values_{i}`` (FP16) — names + dtypes match the exported graph. - 1. Apply ``make_transformer_only`` surgery to each sub-model, producing - ``*_transformer.onnx`` with ``inputs_embeds`` input and - ``output_hidden_states`` output — embeddings and lm_head are stripped - out (ignored, not quantized). - 2. Quantize those transformer-only files via winml-cli's ``quantize_onnx`` - using a calibration reader that runs ``embed_tokens`` in PyTorch on - real text samples. +Run via ``test_qwen.py``. """ from __future__ import annotations @@ -26,7 +23,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from winml.modelkit.models.winml.composite_model import WinMLCompositeModel -from winml.modelkit.onnx import make_transformer_only from winml.modelkit.quant import WinMLQuantizationConfig, quantize_onnx @@ -36,31 +32,28 @@ DEFAULT_MAX_CACHE = 256 DEFAULT_PREFILL_SEQ = 64 DEFAULT_GEN_SEQ = 1 -DEFAULT_NUM_SAMPLES = 16 -DEFAULT_PROMPTS = [ - "Solve: 8 * 7 = ?", - "Translate to French: The weather is nice today.", - "Write a short poem about the ocean.", - "Explain gradient descent in one paragraph.", - "What is the capital of Japan?", - "List three uses of magnesium.", - "Summarize the plot of Hamlet in two sentences.", - "Give a Python one-liner to reverse a string.", -] - - -# --------------------------------------------------------------------------- -# Calibration data reader -# --------------------------------------------------------------------------- - - -class Qwen3TransformerCalibReader: - """Yields calibration feeds for the transformer-only Qwen3 ONNX. 
- - Runs HF ``embed_tokens`` in PyTorch to produce ``inputs_embeds`` since the - embedding layer was stripped from the ONNX graph. All other inputs - (attention_mask, position_ids, past_{i}_key/value) follow the conventions - used by winml-cli's ``WinMLQwen3Model`` runtime. +DEFAULT_NUM_SAMPLES = 30 +DEFAULT_CALIB_DATASET = "openai/gsm8k" +DEFAULT_CALIB_DATASET_CONFIG = "main" +DEFAULT_CALIB_SPLIT = "train" +DEFAULT_CALIB_SEED = 42 + + +def _load_gsm8k_prompts(num_samples: int) -> list[str]: + """GSM8K train split, shuffled seed=42 for reproducible calibration.""" + from datasets import load_dataset + + ds = load_dataset(DEFAULT_CALIB_DATASET, DEFAULT_CALIB_DATASET_CONFIG) + split = ds[DEFAULT_CALIB_SPLIT].shuffle(seed=DEFAULT_CALIB_SEED) + return [row["question"] for row in split.select(range(num_samples))] + + +class Qwen3TransformerOnlyCalibReader: + """Yields calibration feeds for the transformer-only ONNX. + + Feeds match the exported graph exactly: ``input_hidden_states`` (FP32), + ``past_seq_len`` (INT32 ``[1,1]``), ``total_seq_len`` (INT32 ``[1]``), + and ``past_keys_{i}`` / ``past_values_{i}`` (FP16, full cache buffer). """ def __init__( @@ -73,7 +66,6 @@ def __init__( max_cache_len: int, ) -> None: self.embed = embed_tokens - self.cfg = config self.seq_len = seq_len self.max_cache_len = max_cache_len self.num_layers = config.num_hidden_layers @@ -85,11 +77,8 @@ def __init__( self._iter: Iterator[dict[str, np.ndarray]] | None = None self.rewind() - def _build_samples( - self, token_ids_list: list[torch.Tensor] - ) -> Iterator[dict[str, np.ndarray]]: + def _build_samples(self, token_ids_list: list[torch.Tensor]) -> Iterator[dict[str, np.ndarray]]: for ids in token_ids_list: - # Right-truncate / pad to seq_len so we feed the static graph shape. 
ids = ids[:, : self.seq_len] real_len = ids.shape[1] if real_len < self.seq_len: @@ -101,25 +90,22 @@ def _build_samples( with torch.no_grad(): embeds = self.embed(ids).to(torch.float32).cpu().numpy() - # attention_mask: ones for real prompt positions placed at the - # END of the max_cache buffer (sliding-window cache convention), - # zeros elsewhere. - attn_mask = np.zeros((1, self.max_cache_len), dtype=np.int64) - attn_mask[0, -real_len:] = 1 - - # position_ids: 0..seq_len-1 (clamped for padding). - position_ids = np.arange(self.seq_len, dtype=np.int64)[None, :] - feed: dict[str, np.ndarray] = { - "inputs_embeds": embeds.astype(np.float32), - "attention_mask": attn_mask, - "position_ids": position_ids, + "input_hidden_states": embeds.astype(np.float32), + # seqlens_k for GQA = (valid context length - 1), i.e. + # ``embeddings.shape[1] - 1``. We pad to seq_len, so the query + # has seq_len valid positions → past_seq_len = seq_len - 1. + # (Using 0 here declares only 1 valid token while feeding a + # seq_len-token query, which makes the GQA prefill kernel read + # out of bounds → native access violation.) 
+ "past_seq_len": np.array([[self.seq_len - 1]], dtype=np.int32), + "total_seq_len": np.array([self.max_cache_len], dtype=np.int32), } kv_shape = (1, self.num_kv_heads, self.max_cache_len, self.head_dim) - zeros = np.zeros(kv_shape, dtype=np.float32) + zeros = np.zeros(kv_shape, dtype=np.float16) for i in range(self.num_layers): - feed[f"past_{i}_key"] = zeros - feed[f"past_{i}_value"] = zeros + feed[f"past_keys_{i}"] = zeros + feed[f"past_values_{i}"] = zeros yield feed def get_next(self) -> dict[str, np.ndarray] | None: @@ -132,16 +118,7 @@ def rewind(self) -> None: self._iter = iter(self._samples) -# --------------------------------------------------------------------------- -# Pipeline -# --------------------------------------------------------------------------- - - -def _tokenize_prompts( - tokenizer: Any, prompts: list[str], num_samples: int -) -> list[torch.Tensor]: - # Cycle through prompts up to num_samples; apply chat template like the - # runtime so calibration distribution matches inference inputs. +def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> list[torch.Tensor]: out: list[torch.Tensor] = [] for i in range(num_samples): prompt = prompts[i % len(prompts)] @@ -166,19 +143,15 @@ def quantize_built_model( weight_type: str = "uint8", activation_type: str = "uint16", ) -> dict[str, Path]: - """Run surgery + transformer-only quantization on an already-built composite. - - Reuses the ONNX files produced by ``WinMLCompositeModel.from_pretrained`` - so this can be called after a build step without re-exporting. + """Quantize the transformer-only ONNX files in-place. - Returns: mapping of sub-model name → quantized ONNX path. + Returns ``{sub_model_name: quantized_path}``. """ + # Locate the un-compiled ONNX for each sub-model (no surgery — file is + # already transformer-only). 
sub_paths: dict[str, Path] = {} for name, sub in model.sub_models.items(): final_path = Path(sub._onnx_path) - # ``_model.onnx`` is the *compiled* QNN EPContext blob — surgery needs - # the uncompiled fp16 graph. ``build.hf`` emits ``{cache_key}_optimized.onnx`` - # alongside it in the same artifacts directory. if final_path.name.endswith("_model.onnx"): stem = final_path.name[: -len("_model.onnx")] optimized = final_path.with_name(f"{stem}_optimized.onnx") @@ -187,7 +160,7 @@ def quantize_built_model( continue print( f"WARNING: {optimized.name} not found next to {final_path.name}; " - "falling back to the compiled model (surgery will likely fail)." + "falling back to the compiled model." ) sub_paths[name] = final_path @@ -199,7 +172,14 @@ def quantize_built_model( hf_model.eval() embed_tokens = hf_model.get_input_embeddings() tokenizer = AutoTokenizer.from_pretrained(model_id) - token_ids_list = _tokenize_prompts(tokenizer, DEFAULT_PROMPTS, num_samples) + + print( + f"=== Loading {num_samples} GSM8K calibration prompts " + f"({DEFAULT_CALIB_DATASET}/{DEFAULT_CALIB_DATASET_CONFIG}, " + f"split={DEFAULT_CALIB_SPLIT}, seed={DEFAULT_CALIB_SEED}) ===" + ) + prompts = _load_gsm8k_prompts(num_samples) + token_ids_list = _tokenize_prompts(tokenizer, prompts, num_samples) seq_by_sub = { "decoder_prefill": prefill_seq, @@ -213,19 +193,14 @@ def quantize_built_model( continue seq_len = seq_by_sub[sub_name] - transformer_path = fused_path.with_name(fused_path.stem + "_transformer.onnx") - quant_path = transformer_path.with_name( - transformer_path.stem + f"_w{weight_type[-1]}a{activation_type[-2:]}.quant.onnx" + quant_path = fused_path.with_name( + fused_path.stem + f"_w{weight_type[-1]}a{activation_type[-2:]}.quant.onnx" ) - print(f"\n=== Surgery: {sub_name} (seq_len={seq_len}) ===") + print(f"\n=== Quantize (transformer-only): {sub_name} (seq_len={seq_len}) ===") print(f" in : {fused_path}") - print(f" out: {transformer_path}") - make_transformer_only(fused_path, 
transformer_path) - - print(f"\n=== Quantize (transformer only): {sub_name} ===") print(f" out: {quant_path}") - reader = Qwen3TransformerCalibReader( + reader = Qwen3TransformerOnlyCalibReader( embed_tokens, hf_model.config, token_ids_list, @@ -239,7 +214,7 @@ def quantize_built_model( calibration_method="minmax", calibration_data=reader, ) - result = quantize_onnx(transformer_path, output_path=quant_path, config=cfg) + result = quantize_onnx(fused_path, output_path=quant_path, config=cfg) if not result.success: print(" FAILED:") for err in result.errors: @@ -253,4 +228,3 @@ def quantize_built_model( print("\n=== Done ===") return quant_paths - diff --git a/src/winml/modelkit/models/hf/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3_export_ops.py new file mode 100644 index 000000000..61d45f0ef --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen3_export_ops.py @@ -0,0 +1,211 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Custom ONNX export ops + the entry point that reshapes HF's Qwen3 modules +for the transformer-only export. + +These reshape the standard HF Qwen3 modules so winml-cli can produce a +QNN-friendly, transformer-only graph: + +- ``LpNormalization`` replaces the eager RMSNorm Mul/Pow/ReduceMean chain. +- ``com.microsoft::GroupQueryAttention`` replaces the eager QKV MatMul + + Softmax + KV-update path (with built-in rotary). +- 1x1 ``Conv`` (NHWC<->NCHW) replaces ``nn.Linear`` for QNN-friendly + projections. + +Everything here operates only on the standard ``transformers.models.qwen3`` +module attributes. 
+""" + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch.onnx import symbolic_helper + + +# ============================================================================= +# Custom ONNX symbolic functions +# ============================================================================= + + +class LpNormOnnxExport(torch.autograd.Function): + """RMSNorm body → ONNX ``LpNormalization`` (p=2 along last dim).""" + + @staticmethod + def symbolic(g, input, axis, p): # noqa: D401 + output_type = input.type().with_sizes(symbolic_helper._get_tensor_sizes(input)) + output = g.op( + "onnx::LpNormalization", + input, + axis_i=int(axis), + p_i=int(p), + ) + return output.setType(output_type) + + @staticmethod + def forward(ctx, input, axis, p): # noqa: ARG004 + return input # placeholder — real compute happens in symbolic + + +class GroupQueryAttentionOnnxExport(torch.autograd.Function): + """Fused Q/K/V + KV-cache + rotary → ``com.microsoft::GroupQueryAttention``.""" + + @staticmethod + def symbolic( + g, + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_sequence_length, + cos_cache, + sin_cache, + do_rotary, + kv_num_heads, + num_heads, + ): + args = [query, key, value, past_key, past_value, seqlens_k, total_sequence_length, cos_cache, sin_cache] + attention_output, present_keys, present_values = g.op( + "com.microsoft::GroupQueryAttention", + *args, + do_rotary_i=int(do_rotary), + kv_num_heads_i=int(kv_num_heads), + num_heads_i=int(num_heads), + outputs=3, + ) + + query_sizes = symbolic_helper._get_tensor_sizes(query) + attention_output.setType(query.type().with_sizes(query_sizes)) + present_keys.setType(past_key.type().with_sizes(symbolic_helper._get_tensor_sizes(past_key))) + present_values.setType(past_value.type().with_sizes(symbolic_helper._get_tensor_sizes(past_value))) + return attention_output, present_keys, present_values + + @staticmethod + def forward( + ctx, + query, + key, + value, + past_key, + past_value, 
+ seqlens_k, + total_sequence_length, + cos_cache, + sin_cache, + do_rotary, + kv_num_heads, + num_heads, + ): # noqa: ARG004 + return query, past_key, past_value # placeholder shapes + + +# ============================================================================= +# 1x1 Conv replacement for nn.Linear +# ============================================================================= + + +class TransposeConv2d1x1Transpose(nn.Module): + """``nn.Linear`` → 1x1 ``Conv2d`` with NHWC<->NCHW permutes.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + weight: torch.nn.Parameter, + bias: torch.nn.Parameter | None = None, + ) -> None: + super().__init__() + # Linear weight is (out, in); Conv2d weight is (out, in, 1, 1). + self.weight = nn.Parameter(weight.data.view(out_channels, in_channels, 1, 1)) + self.bias = bias + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW + x = torch.nn.functional.conv2d(x, self.weight) + x = x.permute(0, 2, 3, 1) # NCHW -> NHWC + if self.bias is not None: + x = x + self.bias + return x + + @classmethod + def from_linear_module(cls, linear: nn.Linear) -> TransposeConv2d1x1Transpose: + return cls(linear.in_features, linear.out_features, linear.weight, linear.bias) + + +# ============================================================================= +# Apply export prep: bind winml Qwen3 export methods onto a loaded model +# ============================================================================= + + +def apply_transformer_only_export_prep(causal_lm: nn.Module, *, matmul_to_conv: bool = True) -> None: + """Mutate ``Qwen3ForCausalLM`` in-place into the export topology. + + Binds the winml-owned export behaviour from :mod:`.qwen3_modeling` onto each + Qwen3 submodule (runs ``prepare_for_onnx_export`` and rebinds ``forward``). + After this call, ``causal_lm.model(inputs_embeds, past_key_values, + past_seq_len, total_seq_len)`` runs the transformer-only forward. 
+ + Args: + causal_lm: A ``transformers.Qwen3ForCausalLM`` instance. + matmul_to_conv: Swap ``nn.Linear`` projections to 1x1 ``Conv2d`` so + QNN sees them as Conv. + """ + from .qwen3_modeling import ( + WinMLQwen3Attention, + WinMLQwen3DecoderLayer, + WinMLQwen3MLP, + WinMLQwen3Model, + WinMLQwen3RMSNorm, + ) + + def _bind(module: nn.Module, owner: type) -> None: + module.forward = owner.forward.__get__(module, type(module)) + + # Identify Qwen3 submodules by their (stock HF) class name so we don't + # depend on importing ``transformers.models.qwen3`` here. + def _is(module: nn.Module, name: str) -> bool: + return type(module).__name__ == name + + # Patch every RMSNorm first (Qwen3RMSNorm appears at top, in q_norm/k_norm, + # in input/post_attention layernorms). + for mod in causal_lm.modules(): + if _is(mod, "Qwen3RMSNorm"): + WinMLQwen3RMSNorm.prepare_for_onnx_export(mod) + _bind(mod, WinMLQwen3RMSNorm) + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Attention"): + WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + _bind(mod, WinMLQwen3Attention) + elif _is(mod, "Qwen3MLP"): + # MLP forward is unchanged; only the projections are swapped to Conv. + WinMLQwen3MLP.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + + # HF moved ``rotary_emb`` from ``Qwen3Attention`` up to ``Qwen3Model``; + # the export forward invokes ``self.rotary_emb`` on the attention module, + # so re-attach a reference from the parent model. 
+ for mod in causal_lm.modules(): + if _is(mod, "Qwen3Model") and hasattr(mod, "rotary_emb"): + for layer in mod.layers: + layer.self_attn.rotary_emb = mod.rotary_emb + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3DecoderLayer"): + _bind(mod, WinMLQwen3DecoderLayer) + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Model"): + WinMLQwen3Model.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + _bind(mod, WinMLQwen3Model) + + +__all__ = [ + "GroupQueryAttentionOnnxExport", + "LpNormOnnxExport", + "TransposeConv2d1x1Transpose", + "apply_transformer_only_export_prep", +] diff --git a/src/winml/modelkit/models/hf/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3_modeling.py new file mode 100644 index 000000000..05a70adfe --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen3_modeling.py @@ -0,0 +1,237 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""winml-owned Qwen3 model definitions for the transformer-only ONNX export. + +Each class is a plain ``nn.Module`` that carries the export-time behaviour +directly (``prepare_for_onnx_export`` + ``forward``). The export entry point +binds these ``forward`` methods onto the corresponding live Qwen3 submodules, +so the stock eager model is left untouched. + +What each class emits: + +- ``WinMLQwen3RMSNorm`` -> ``onnx::LpNormalization`` body. +- ``WinMLQwen3Attention`` -> ``com.microsoft::GroupQueryAttention`` (built-in + rotary) with optional 1x1 ``Conv`` projections. +- ``WinMLQwen3MLP`` -> 1x1 ``Conv`` projections (NHWC). +- ``WinMLQwen3DecoderLayer`` / ``WinMLQwen3Model`` -> transformer-only forward + that threads the KV cache + seq-len tensors and omits embeddings / lm_head. 
+ +``apply_transformer_only_export_prep`` (in ``qwen3_export_ops``) walks a loaded +``Qwen3ForCausalLM``, calls ``prepare_for_onnx_export`` on each submodule, and +binds the matching ``forward`` from these classes onto it. +""" + +from __future__ import annotations + +from typing import Any + +import torch +import torch.nn as nn + +from .qwen3_export_ops import ( + GroupQueryAttentionOnnxExport, + LpNormOnnxExport, + TransposeConv2d1x1Transpose, +) + + +class WinMLQwen3RMSNorm(nn.Module): + """RMSNorm export variant — ``onnx::LpNormalization`` body.""" + + def prepare_for_onnx_export(self) -> None: + # Pre-multiply the gain into the weight (LpNorm has unit gain). + n = self.weight.numel() + scale = torch.sqrt( + torch.tensor([n], device=self.weight.device, dtype=self.weight.dtype) + ) + if torch.any(self.weight.data != torch.ones_like(self.weight)).item(): + new_w = scale * self.weight + else: + new_w = scale + self.weight = nn.Parameter(new_w) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + out = LpNormOnnxExport.apply(hidden_states, -1, 2) + return self.weight * out + + +class WinMLQwen3MLP(nn.Module): + """MLP export variant — 1x1 Conv projections (forward unchanged).""" + + def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: + if not matmul_to_conv: + return + self.gate_proj = TransposeConv2d1x1Transpose.from_linear_module(self.gate_proj) + self.up_proj = TransposeConv2d1x1Transpose.from_linear_module(self.up_proj) + self.down_proj = TransposeConv2d1x1Transpose.from_linear_module(self.down_proj) + + +class WinMLQwen3Attention(nn.Module): + """Attention export variant — fused ``GroupQueryAttention`` op.""" + + def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: + if matmul_to_conv: + self.q_proj = TransposeConv2d1x1Transpose.from_linear_module(self.q_proj) + self.k_proj = TransposeConv2d1x1Transpose.from_linear_module(self.k_proj) + self.v_proj = TransposeConv2d1x1Transpose.from_linear_module(self.v_proj) + 
self.o_proj = TransposeConv2d1x1Transpose.from_linear_module(self.o_proj) + self._matmul_to_conv = matmul_to_conv # noqa: SLF001 + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None, + past_seq_len: torch.Tensor | None = None, + total_seq_len: torch.Tensor | None = None, + **kwargs: Any, # noqa: ARG002 + ) -> tuple[torch.Tensor, None, tuple[torch.Tensor, torch.Tensor]]: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + input_shape = hidden_states.shape[1:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_norm(query_states.view(hidden_shape)) + key_states = self.k_norm(key_states.view(hidden_shape)) + + num_heads = self.config.num_attention_heads + num_kv_heads = self.config.num_key_value_heads + query_dim = num_heads * self.head_dim + key_dim = num_kv_heads * self.head_dim + query_states = query_states.reshape(1, -1, query_dim) + key_states = key_states.reshape(1, -1, key_dim) + + if self._matmul_to_conv: + value_states = value_states.squeeze(0) + + past_keys, past_values = past_key_value + + # GroupQueryAttention requires Q/K/V/past_K/past_V to share dtype. + # The KV cache is FP16, so cast Q/K/V to the same dtype; otherwise ORT + # type inference rejects the node. 
+ kv_dtype = past_keys.dtype + if query_states.dtype != kv_dtype: + query_states = query_states.to(kv_dtype) + key_states = key_states.to(kv_dtype) + value_states = value_states.to(kv_dtype) + + cos, sin = self.rotary_emb( + value_states, + torch.arange(self.config.max_position_embeddings).unsqueeze(0), + ) + cos = cos.squeeze(0)[:, : cos.shape[-1] // 2] + sin = sin.squeeze(0)[:, : sin.shape[-1] // 2] + if cos.dtype != kv_dtype: + cos = cos.to(kv_dtype) + sin = sin.to(kv_dtype) + + if isinstance(past_seq_len, int): + past_seq_len = torch.tensor(past_seq_len) + past_seq_len = torch.atleast_2d(past_seq_len) + + attention_output, present_keys, present_values = GroupQueryAttentionOnnxExport.apply( + query_states, + key_states, + value_states, + past_keys, + past_values, + past_seq_len, + total_seq_len, + cos, + sin, + 1, # do_rotary + num_kv_heads, + num_heads, + ) + + # Cast back to the residual-stream dtype so the downstream Conv + # (o_proj) sees its expected weight dtype. + if attention_output.dtype != hidden_states.dtype: + attention_output = attention_output.to(hidden_states.dtype) + + if self._matmul_to_conv: + attention_output = attention_output.unsqueeze(0) + + attention_output = self.o_proj(attention_output) + return attention_output, None, (present_keys, present_values) + + +class WinMLQwen3DecoderLayer(nn.Module): + """Decoder-layer export variant — threads KV cache + seq-len kwargs.""" + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None, + past_seq_len: torch.Tensor | None = None, + total_seq_len: torch.Tensor | None = None, + use_cache: bool = True, + **kwargs: Any, # noqa: ARG002 + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + attn_out, _, present_kv = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + past_seq_len=past_seq_len, + total_seq_len=total_seq_len, 
+ ) + hidden_states = residual + attn_out + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if use_cache: + outputs += (present_kv,) + return outputs + + +class WinMLQwen3Model(nn.Module): + """Model export variant — transformer-only body (no embeddings / lm_head).""" + + def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: + self._matmul_to_conv = matmul_to_conv # noqa: SLF001 + + def forward( + self, + inputs_embeds: torch.Tensor, + past_key_values: list[tuple[torch.Tensor, torch.Tensor]], + past_seq_len: torch.Tensor, + total_seq_len: torch.Tensor, + use_cache: bool = True, + ) -> tuple[torch.Tensor, tuple[tuple[torch.Tensor, torch.Tensor], ...]]: + hidden_states = inputs_embeds + if self._matmul_to_conv: + hidden_states = hidden_states.unsqueeze(0) # NHWC for Conv path + + present_kvs: tuple[tuple[torch.Tensor, torch.Tensor], ...] = () + for idx, layer in enumerate(self.layers): + out = layer( + hidden_states, + past_key_value=past_key_values[idx], + past_seq_len=past_seq_len, + total_seq_len=total_seq_len, + use_cache=use_cache, + ) + hidden_states = out[0] + if use_cache: + present_kvs += (out[1],) + + hidden_states = self.norm(hidden_states) + if self._matmul_to_conv: + hidden_states = hidden_states.squeeze(0) + return hidden_states, present_kvs + + +__all__ = [ + "WinMLQwen3Attention", + "WinMLQwen3DecoderLayer", + "WinMLQwen3MLP", + "WinMLQwen3Model", + "WinMLQwen3RMSNorm", +] diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py new file mode 100644 index 000000000..8e30b1fb6 --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -0,0 +1,354 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Parallel ``qwen3`` build path that produces a transformer-only ONNX. + +Opt-in via ``install()`` — calling it hot-patches the WinML registries so +that the next ``WinMLAutoModel.from_pretrained("Qwen/Qwen3-*", task="text-generation")`` +exports two transformer-only ONNX files (a prefill/context graph and an +iteration/decode graph) with this I/O: + + Inputs : past_keys_{i}, past_values_{i} (FP16, ``[1, kv_heads, max_cache, head_dim]``), + input_hidden_states (FP32, ``[1, seq_len, hidden]``), + past_seq_len (INT32, ``[1, 1]``), total_seq_len (INT32, ``[1]``) + Outputs: output_hidden_states (FP32), present_keys_{i}, present_values_{i} (FP16) + Ops : ``com.microsoft::GroupQueryAttention`` (do_rotary=1), + ``onnx::LpNormalization`` (RMSNorm), 1x1 ``Conv`` projections. + +The original eager-export path in ``qwen.py`` is left intact — only the +qwen3 entries in the registries are replaced. ``install()`` is idempotent. 
+""" + +from __future__ import annotations + +import logging +from typing import Any, ClassVar + +import torch +import torch.nn as nn +from optimum.exporters.onnx import OnnxConfig +from optimum.utils import NormalizedConfig +from optimum.utils.input_generators import DummyInputGenerator +from transformers import AutoModelForCausalLM + +from ...config import WinMLBuildConfig +from ...export import register_onnx_overwrite +from ...export.config import WinMLExportConfig +from ..winml import register_specialization +from ..winml.decoder_only import WinMLDecoderOnlyModel +from ..winml.kv_cache import WinMLSlidingWindowCache +from .qwen3_export_ops import apply_transformer_only_export_prep + + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Wrapper module +# ============================================================================= + + +class QwenTransformerOnlyDecoderWrapper(nn.Module): + """Wraps ``Qwen3ForCausalLM`` for transformer-only export. + + The wrapper applies the export prep (LpNorm RMSNorm, GQA op, 1x1 + Conv projections) in ``__init__`` and exposes a positional ``forward`` + whose argument order matches :class:`QwenTransformerOnlyPrefillIOConfig.inputs`. + Only ``self.model.model`` (the inner ``Qwen3Model``) is invoked at + export time — embedding lookup and ``lm_head`` stay out of the graph. 
+ """ + + def __init__(self, model: nn.Module, num_layers: int) -> None: + super().__init__() + self.model = model + self.num_layers = num_layers + self.config = model.config + apply_transformer_only_export_prep(model, matmul_to_conv=True) + + @classmethod + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenTransformerOnlyDecoderWrapper: + kwargs.setdefault("torch_dtype", torch.float32) + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **kwargs) + model.config._attn_implementation = "eager" + wrapper = cls(model, model.config.num_hidden_layers) + wrapper.eval() + return wrapper + + def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, ...]: + return tuple(inputs.values()) + + def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Positional inputs (order matches OnnxConfig.inputs): + + past_keys_0, past_values_0, ..., past_keys_{L-1}, past_values_{L-1}, + input_hidden_states, past_seq_len, total_seq_len + + Returns ``(output_hidden_states, present_keys_0, present_values_0, ...)``. 
+ """ + kv_args = args[: 2 * self.num_layers] + input_hidden_states = args[2 * self.num_layers] + past_seq_len = args[2 * self.num_layers + 1] + total_seq_len = args[2 * self.num_layers + 2] + + past_key_values = [ + (kv_args[2 * i], kv_args[2 * i + 1]) for i in range(self.num_layers) + ] + + hidden_states, present_kvs = self.model.model( + inputs_embeds=input_hidden_states, + past_key_values=past_key_values, + past_seq_len=past_seq_len, + total_seq_len=total_seq_len, + use_cache=True, + ) + + out: list[torch.Tensor] = [hidden_states] + for k, v in present_kvs: + out.extend([k, v]) + return tuple(out) + + +# ============================================================================= +# Dummy input generators (transformer-only I/O) +# ============================================================================= + + +class _TransformerOnlyHiddenStateGenerator(DummyInputGenerator): + """Generates ``input_hidden_states`` (FP32, ``[1, seq_len, hidden]``).""" + + SUPPORTED_INPUT_NAMES = ("input_hidden_states",) + + _default_seq_len: ClassVar[int] = 1 + + def __init__( + self, + task: str, + normalized_config: Any, + batch_size: int = 1, + seq_len: int | None = None, + **kwargs: Any, + ) -> None: + self.batch_size = batch_size + self.hidden_size = normalized_config.hidden_size + self.seq_len = seq_len or getattr(normalized_config, "seq_len", self._default_seq_len) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32") -> torch.Tensor: # noqa: ARG002 + if input_name == "input_hidden_states": + return torch.randn(self.batch_size, self.seq_len, self.hidden_size, dtype=torch.float32) + raise ValueError(f"Unknown input: {input_name}") + + +class _TransformerOnlyHiddenStatePrefillGenerator(_TransformerOnlyHiddenStateGenerator): + _default_seq_len = 64 + + +class _TransformerOnlySeqLenGenerator(DummyInputGenerator): + """Generates ``past_seq_len`` (INT32 ``[1,1]``) and ``total_seq_len`` (INT32 ``[1]``).""" + + 
SUPPORTED_INPUT_NAMES = ("past_seq_len", "total_seq_len") + + def __init__(self, task: str, normalized_config: Any, **kwargs: Any) -> None: # noqa: ARG002 + self.max_cache_len = normalized_config.max_cache_len + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32") -> torch.Tensor: # noqa: ARG002 + if input_name == "past_seq_len": + return torch.zeros((1, 1), dtype=torch.int32) + if input_name == "total_seq_len": + return torch.tensor([self.max_cache_len], dtype=torch.int32) + raise ValueError(f"Unknown input: {input_name}") + + +class _TransformerOnlyKvCacheGenerator(DummyInputGenerator): + """Generates ``past_keys_{i}`` / ``past_values_{i}`` (FP16).""" + + SUPPORTED_INPUT_NAMES = () # built dynamically in __init__ + + def __init__( + self, + task: str, + normalized_config: NormalizedConfig, + batch_size: int = 1, + max_cache_len: int | None = None, + **kwargs: Any, + ) -> None: + self.batch_size = batch_size + self.num_layers: int = normalized_config.num_layers + self.num_heads: int = normalized_config.num_attention_heads # KV heads (NormalizedConfig maps it) + self.head_dim: int = normalized_config.head_dim + self.max_cache_len: int = max_cache_len or normalized_config.max_cache_len + self.SUPPORTED_INPUT_NAMES = tuple( + name for i in range(self.num_layers) for name in (f"past_keys_{i}", f"past_values_{i}") + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32") -> torch.Tensor: # noqa: ARG002 + shape = (self.batch_size, self.num_heads, self.max_cache_len, self.head_dim) + return torch.zeros(shape, dtype=torch.float16) + + +# ============================================================================= +# OnnxConfigs — transformer-only I/O layout +# ============================================================================= + + +_QWEN_TRANSFORMER_ONLY_NORMALIZED = NormalizedConfig.with_args( + hidden_size="hidden_size", + 
num_layers="num_hidden_layers", + num_attention_heads="num_key_value_heads", # KV heads (GQA) + head_dim="head_dim", + max_cache_len="max_position_embeddings", + vocab_size="vocab_size", + allow_new=True, +) + + +def _transformer_only_inputs(num_layers: int, kv_seq_axis: str = "max_seq_len") -> dict[str, dict[int, str]]: + """Input ordering: past KV pairs, then hidden states, then seq lens.""" + result: dict[str, dict[int, str]] = {} + for i in range(num_layers): + result[f"past_keys_{i}"] = {2: kv_seq_axis} + result[f"past_values_{i}"] = {2: kv_seq_axis} + result["input_hidden_states"] = {1: "seq_len"} + result["past_seq_len"] = {} + result["total_seq_len"] = {} + return result + + +def _transformer_only_outputs(num_layers: int, kv_seq_axis: str = "max_seq_len") -> dict[str, dict[int, str]]: + result: dict[str, dict[int, str]] = {"output_hidden_states": {1: "seq_len"}} + for i in range(num_layers): + result[f"present_keys_{i}"] = {2: kv_seq_axis} + result[f"present_values_{i}"] = {2: kv_seq_axis} + return result + + +class QwenTransformerOnlyPrefillIOConfig(OnnxConfig): + """Prefill (seq=64) — transformer-only I/O.""" + + NORMALIZED_CONFIG_CLASS = _QWEN_TRANSFORMER_ONLY_NORMALIZED + DUMMY_INPUT_GENERATOR_CLASSES = ( + _TransformerOnlyKvCacheGenerator, + _TransformerOnlyHiddenStatePrefillGenerator, + _TransformerOnlySeqLenGenerator, + ) + + @property + def inputs(self) -> dict[str, dict[int, str]]: + return _transformer_only_inputs(self._normalized_config.num_layers) + + @property + def outputs(self) -> dict[str, dict[int, str]]: + return _transformer_only_outputs(self._normalized_config.num_layers) + + +class QwenTransformerOnlyGenIOConfig(OnnxConfig): + """Generation (seq=1) — transformer-only I/O.""" + + NORMALIZED_CONFIG_CLASS = _QWEN_TRANSFORMER_ONLY_NORMALIZED + DUMMY_INPUT_GENERATOR_CLASSES = ( + _TransformerOnlyKvCacheGenerator, + _TransformerOnlyHiddenStateGenerator, + _TransformerOnlySeqLenGenerator, + ) + + @property + def inputs(self) -> dict[str, 
dict[int, str]]: + return _transformer_only_inputs(self._normalized_config.num_layers) + + @property + def outputs(self) -> dict[str, dict[int, str]]: + return _transformer_only_outputs(self._normalized_config.num_layers) + + +# ============================================================================= +# Build config — TorchScript exporter required for the custom autograd ops +# ============================================================================= + + +QWEN_TRANSFORMER_ONLY_CONFIG = WinMLBuildConfig( + export=WinMLExportConfig(dynamo=False, opset_version=18), + # Pure graph (no post-export RMSNorm fusion / matmul-add fusion). + optim=None, +) + + +# ============================================================================= +# Composite inference wrapper (placeholder so the build pipeline finds a +# composite class — generation isn't yet wired for the transformer-only +# I/O signature). +# ============================================================================= + + +class WinMLQwen3TransformerOnlyModel(WinMLDecoderOnlyModel): + """Composite handle for the transformer-only Qwen3 build (export only). + + ``generate()`` is **not** functional with this build path — the inference + feeds and KV update logic still target the eager I/O signature. Use the + eager :class:`WinMLQwen3Model` for generation; use this class to produce + the transformer-only ONNX for downstream quantization. + """ + + _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = { + "decoder_prefill": "feature-extraction", + "decoder_gen": "text2text-generation", + } + + @classmethod + def get_cache_class(cls) -> type: + return WinMLSlidingWindowCache + + +# ============================================================================= +# install() — hot-patch the registries +# ============================================================================= + + +_INSTALLED = False + + +def install() -> None: + """Replace the qwen3 entries in WinML registries with the transformer-only variants. 
+ + Idempotent. After this call, building any qwen3 model via + :class:`~winml.modelkit.models.winml.composite_model.WinMLCompositeModel` + or :class:`~winml.modelkit.models.auto.WinMLAutoModel` produces + transformer-only ONNX files. + """ + global _INSTALLED + if _INSTALLED: + return + + # 1) Per-model build config + wrapper-class lookup live on the parent + # ``models.hf`` package as module-level dicts; mutating them is the + # documented hook for adding/overriding a model_type. + from .. import hf as _hf_pkg # noqa: PLC0415 + + _hf_pkg.MODEL_BUILD_CONFIGS["qwen3"] = QWEN_TRANSFORMER_ONLY_CONFIG + _hf_pkg.MODEL_CLASS_MAPPING[("qwen3", "feature-extraction")] = QwenTransformerOnlyDecoderWrapper + _hf_pkg.MODEL_CLASS_MAPPING[("qwen3", "text2text-generation")] = QwenTransformerOnlyDecoderWrapper + + # 2) Optimum OnnxConfig (overwrites existing registration). + register_onnx_overwrite("qwen3", "feature-extraction", library_name="transformers")(QwenTransformerOnlyPrefillIOConfig) + register_onnx_overwrite("qwen3", "text2text-generation", library_name="transformers")(QwenTransformerOnlyGenIOConfig) + + # 3) Inference specialization (still GenericTask — wrapper returns raw KV). + register_specialization("qwen3", "feature-extraction", "WinMLModelForGenericTask") + register_specialization("qwen3", "text2text-generation", "WinMLModelForGenericTask") + + # 4) Composite registry — swap to the transformer-only handle. 
+ from ..winml.composite_model import COMPOSITE_MODEL_REGISTRY + + COMPOSITE_MODEL_REGISTRY[("qwen3", "text-generation")] = WinMLQwen3TransformerOnlyModel + + _INSTALLED = True + logger.info("qwen_transformer_only: transformer-only export path installed for qwen3.") + + +__all__ = [ + "QWEN_TRANSFORMER_ONLY_CONFIG", + "QwenTransformerOnlyDecoderWrapper", + "QwenTransformerOnlyGenIOConfig", + "QwenTransformerOnlyPrefillIOConfig", + "WinMLQwen3TransformerOnlyModel", + "install", +] diff --git a/src/winml/modelkit/onnx/__init__.py b/src/winml/modelkit/onnx/__init__.py index 0287a2ff7..a3bc49d51 100644 --- a/src/winml/modelkit/onnx/__init__.py +++ b/src/winml/modelkit/onnx/__init__.py @@ -19,7 +19,6 @@ from .io import InputTensorSpec, OutputTensorSpec, generate_inputs_from_onnx, get_io_config from .metadata import capture_metadata, restore_metadata from .persistence import cleanup_onnx, load_onnx, save_onnx -from .qwen_surgery import make_transformer_only from .shape import infer_onnx_shapes, infer_shapes from .utils import EXTERNAL_DATA_THRESHOLD, check_onnx_model, get_model_size @@ -42,7 +41,6 @@ "is_compiled_onnx", "is_quantized_onnx", "load_onnx", - "make_transformer_only", "remove_optional_from_type_annotation", "restore_metadata", "save_onnx", diff --git a/src/winml/modelkit/onnx/qwen_surgery.py b/src/winml/modelkit/onnx/qwen_surgery.py deleted file mode 100644 index cd49ee5ec..000000000 --- a/src/winml/modelkit/onnx/qwen_surgery.py +++ /dev/null @@ -1,186 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Ad-hoc ONNX surgery to turn a Qwen3 decoder ONNX into a transformer-only graph. - -Applied as a post-export surgery on the fused decoder ONNX produced by -``WinMLQwen3Model`` (``decoder_prefill.onnx`` / ``decoder_gen.onnx``). 
- -The resulting transformer-only ONNX has: - - ``input_ids`` graph input replaced by ``inputs_embeds`` (FLOAT, - ``[batch, seq, hidden_size]``) — the upstream embedding Gather is - removed. - - ``logits`` graph output replaced by ``output_hidden_states`` - (FLOAT, ``[batch, seq, hidden_size]``) — the final ``lm_head`` MatMul - is removed. -""" - -from __future__ import annotations - -import logging -from pathlib import Path - -import onnx -from onnx import TensorProto, helper - -from .persistence import load_onnx, save_onnx - - -logger = logging.getLogger(__name__) - - -def _dim(d: onnx.TensorShapeProto.Dimension) -> int | str: - if d.HasField("dim_value"): - return d.dim_value - return d.dim_param or "?" - - -def make_transformer_only( - model_path: str | Path, - output_path: str | Path, - *, - input_ids_name: str = "input_ids", - logits_name: str = "logits", - inputs_embeds_name: str = "inputs_embeds", - output_hidden_states_name: str = "output_hidden_states", -) -> Path: - """Strip the embedding Gather and the lm_head MatMul from a Qwen3 ONNX. - - Args: - model_path: Path to the fused decoder ONNX (logits output, input_ids input). - output_path: Destination for the transformer-only ONNX. - input_ids_name: Name of the input_ids graph input to drop. - logits_name: Name of the logits graph output to drop. - inputs_embeds_name: Display name for the new embeddings input - (used only for logging; the actual tensor keeps its existing - internal name so downstream nodes need no rewiring). - output_hidden_states_name: Display name for the new hidden-state output. - - Returns: - The output path. 
- """ - model_path = Path(model_path) - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - model = load_onnx(model_path, load_weights=True, validate=False) - graph = model.graph - init_by_name = {init.name: init for init in graph.initializer} - - # -------------------- Embedding removal -------------------- - embed_idx = next( - (i for i, n in enumerate(graph.node) if input_ids_name in n.input), - None, - ) - if embed_idx is None: - msg = f"No node consumes graph input {input_ids_name!r}" - raise RuntimeError(msg) - - embed_node = graph.node[embed_idx] - embed_out_name = embed_node.output[0] - - embed_weight = None - for ipt in embed_node.input: - init = init_by_name.get(ipt) - if init is not None and len(init.dims) == 2: - embed_weight = init - break - if embed_weight is None: - msg = f"Could not find 2-D embedding weight initializer on node {embed_node.name!r}" - raise RuntimeError(msg) - hidden_size = int(embed_weight.dims[1]) - - ids_input = next(i for i in graph.input if i.name == input_ids_name) - batch_dim = _dim(ids_input.type.tensor_type.shape.dim[0]) - seq_dim = _dim(ids_input.type.tensor_type.shape.dim[1]) - - logger.info( - "Removing embedding node %r (%s) — exposing %r as new input %r [%s, %s, %d]", - embed_node.name, - embed_node.op_type, - embed_out_name, - inputs_embeds_name, - batch_dim, - seq_dim, - hidden_size, - ) - - new_embed_input = helper.make_tensor_value_info( - inputs_embeds_name, - TensorProto.FLOAT, - [batch_dim, seq_dim, hidden_size], - ) - - del graph.node[embed_idx] - graph.input.remove(ids_input) - graph.input.append(new_embed_input) - graph.initializer.remove(embed_weight) - - # Rewire any consumer of the removed embedding output to the new input. 
- for n in graph.node: - for i, name in enumerate(n.input): - if name == embed_out_name: - n.input[i] = inputs_embeds_name - - # -------------------- lm_head removal -------------------- - lmh_idx = next( - (i for i, n in enumerate(graph.node) if logits_name in n.output), - None, - ) - if lmh_idx is None: - msg = f"No node produces graph output {logits_name!r}" - raise RuntimeError(msg) - - lmh_node = graph.node[lmh_idx] - init_names = {init.name for init in graph.initializer} - hidden_in: str | None = None - weight_in: str | None = None - for ipt in lmh_node.input: - if ipt in init_names: - weight_in = ipt - else: - hidden_in = ipt - if hidden_in is None: - msg = f"lm_head node {lmh_node.name!r} has no non-initializer input ({list(lmh_node.input)})" - raise RuntimeError(msg) - - logger.info( - "Removing lm_head node %r (%s) — exposing %r as new output %r", - lmh_node.name, - lmh_node.op_type, - hidden_in, - output_hidden_states_name, - ) - - logits_output = next(o for o in graph.output if o.name == logits_name) - new_hidden_output = helper.make_tensor_value_info( - output_hidden_states_name, - TensorProto.FLOAT, - [batch_dim, seq_dim, hidden_size], - ) - - del graph.node[lmh_idx] - graph.output.remove(logits_output) - # Put hidden states first so it mirrors the original logits position. - graph.output.insert(0, new_hidden_output) - - # Rename the producer of ``hidden_in`` to emit the new graph output name. 
- for n in graph.node: - for i, name in enumerate(n.output): - if name == hidden_in: - n.output[i] = output_hidden_states_name - for i, name in enumerate(n.input): - if name == hidden_in: - n.input[i] = output_hidden_states_name - - if weight_in is not None and not any(weight_in in n.input for n in graph.node): - wi = next(init for init in graph.initializer if init.name == weight_in) - graph.initializer.remove(wi) - - save_onnx(model, output_path) - logger.info("Wrote transformer-only ONNX → %s", output_path) - return output_path - - -__all__ = ["make_transformer_only"] diff --git a/test_qwen 2.py b/test_qwen 2.py deleted file mode 100644 index 6a52dee72..000000000 --- a/test_qwen 2.py +++ /dev/null @@ -1,70 +0,0 @@ -"""E2E test for Qwen3 decoder-only pipeline. - -Uses sub_model_kwargs to set per-component shape_config: - - decoder_prefill: max_cache_len=256, seq_len=64 - - decoder_gen: max_cache_len=256, seq_len=1 - -Set env var ``QUANTIZE=1`` to also run the MOPS-style Step 3: -transformer-only surgery + winml quantize on both sub-models -(embeddings and lm_head are stripped and not quantized). 
-""" - -import os - -from transformers import AutoTokenizer - -from winml.modelkit.config import WinMLBuildConfig -from winml.modelkit.models.winml.composite_model import WinMLCompositeModel - -model_id = "Qwen/Qwen3-0.6B" - -model = WinMLCompositeModel.from_pretrained( - model_id, - task="text-generation", - # config=WinMLBuildConfig(quant=None, compile=None), - config=WinMLBuildConfig(quant=None), - precision="fp16", - device="npu", - ep="qnn", - force_rebuild=False, - sub_model_kwargs={ - "decoder_prefill": {"shape_config": {"max_cache_len": 256, "seq_len": 64}}, - "decoder_gen": {"shape_config": {"max_cache_len": 256, "seq_len": 1}}, - }, -) - -# Verify ONNX I/O shapes -for name, sub in model.sub_models.items(): - io = sub.io_config - shapes = dict(zip(io["input_names"], io["input_shapes"])) - print(f"\n=== {name} ===") - for k, v in shapes.items(): - print(f" {k}: {v}") - -tokenizer = AutoTokenizer.from_pretrained(model_id) - -prompt = "8 * 7 = ?" -messages = [{"role": "user", "content": prompt}] -text = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True, enable_thinking=False, -) -model_inputs = tokenizer([text], return_tensors="pt").to(model.device) - -generated_ids = model.generate(**model_inputs) - -output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist() -content = tokenizer.decode(output_ids, skip_special_tokens=True) -print("\nAnswer:", content) - -if os.environ.get("QUANTIZE") == "1": - # Reuse the already-built decoder_prefill/decoder_gen ONNX files: - # surgery (strip embed + lm_head) + transformer-only quantize. 
- print("\n=== QUANTIZE=1 — running transformer-only quantization ===") - from qwen3_quantize import quantize_built_model - - quantize_built_model( - model, - model_id=model_id, - max_cache_len=256, - prefill_seq=64, - ) diff --git a/test_qwen.py b/test_qwen.py new file mode 100644 index 000000000..f958c2932 --- /dev/null +++ b/test_qwen.py @@ -0,0 +1,235 @@ +"""E2E test for the transformer-only Qwen3 export path. + +Produces two transformer-only ONNX files whose I/O matches +``qwen3_gqa_fp16_ctx.onnx`` / ``qwen3_gqa_fp16_iter.onnx``: + + decoder_prefill: input_hidden_states [1, 64, 1024] → output_hidden_states + KV + decoder_gen : input_hidden_states [1, 1, 1024] → output_hidden_states + KV + +with FP16 past/present KV named ``past_keys_{i}`` / ``past_values_{i}``, +``com.microsoft::GroupQueryAttention``, ``LpNormalization``, and 1x1 Conv +projections. + +Important: ``install()`` MUST be called before importing the composite model +machinery so the registry hot-patches take effect. + +Generation (``model.generate(...)``) is NOT supported by this build path — +the inference feeds in ``WinMLDecoderOnlyModel`` still target the eager +I/O signature. Use the eager ``WinMLQwen3Model`` build path for end-to-end +generation. + +Run:: + + python test_qwen.py + +This builds each transformer sub-model and then runs the w8a16 +quantization on the exported transformer ONNX files (no surgery needed — +files are already transformer-only). +""" + +import os +import sys +import pathlib +import subprocess + +# Put the in-repo `src/` ahead of site-packages so `import winml` always +# resolves to the editable source tree — no manual copy-to-venv needed. +_repo_root = pathlib.Path(__file__).resolve().parent +sys.path.insert(0, str(_repo_root / "src")) +sys.path.insert(0, str(_repo_root)) + +model_id = "Qwen/Qwen3-0.6B" +MAX_CACHE = 256 + +# component name -> (HF task, seq_len, artifact prefix). Order matters +# (prefill first).
The prefix is how the built npu_ctx file is named so the +# parent can verify success by artifact appearance (the build segfaults on +# native QNN/ORT teardown AFTER writing the file, so exit codes are unreliable). +SUB_MODELS = { + "decoder_prefill": ("feature-extraction", 64, "feat_"), + "decoder_gen": ("text2text-generation", 1, "txt2txt_"), +} + +ARTIFACTS_DIR = ( + pathlib.Path.home() / ".cache" / "winml" / "artifacts" / model_id.replace("/", "_") +) + + +def _latest_ctx_mtime(prefix: str) -> float: + """Newest mtime of a ``{prefix}*_optimized_npu_ctx.onnx`` artifact, or 0.""" + files = list(ARTIFACTS_DIR.glob(f"{prefix}*_optimized_npu_ctx.onnx")) + return max((f.stat().st_mtime for f in files), default=0.0) + + +def _build_one(task: str, seq_len: int) -> None: + """Build a SINGLE transformer sub-model in this (fresh) process. + + Invoked as a subprocess by ``main()`` so each sub-model exports in a + clean interpreter — building both in one process leaves PyTorch/ORT + state from the first build that corrupts/kills the second. + """ + from winml.modelkit.models.hf.qwen_transformer_only import install as install_qwen_transformer_only + + install_qwen_transformer_only() + + from winml.modelkit.config import WinMLBuildConfig + from winml.modelkit.models.auto import WinMLAutoModel + + WinMLAutoModel.from_pretrained( + model_id, + task=task, + config=WinMLBuildConfig(quant=None, compile=None), + precision="fp16", + device="npu", + ep="qnn", + force_rebuild=True, + shape_config={"max_cache_len": MAX_CACHE, "seq_len": seq_len}, + ) + # The QNN/ORT teardown segfaults (0xC0000005) on interpreter shutdown + # AFTER the artifact is fully written. Skip the buggy cleanup with a hard + # exit so the parent sees a clean exit code 0. 
+ print(f"BUILD COMPLETE: task={task} seq_len={seq_len}", flush=True) + sys.stdout.flush() + sys.stderr.flush() + os._exit(0) + + +def _find_optimized(prefix: str) -> pathlib.Path: + """Locate the cached transformer-only ``{prefix}*_optimized.onnx`` file.""" + cands = [ + p for p in ARTIFACTS_DIR.glob(f"{prefix}*_optimized.onnx") + if not p.name.endswith("_optimized_npu_ctx.onnx") + ] + if not cands: + raise FileNotFoundError( + f"No {prefix}*_optimized.onnx in {ARTIFACTS_DIR} — build the sub-model first." + ) + return max(cands, key=lambda p: p.stat().st_mtime) + + +class _SubShim: + """Minimal stand-in exposing the ``_onnx_path`` quant needs.""" + + def __init__(self, onnx_path: pathlib.Path): + self._onnx_path = str(onnx_path) + + +class _ModelShim: + """Minimal stand-in exposing ``sub_models`` for ``quantize_built_model``.""" + + def __init__(self, sub_models: dict): + self.sub_models = sub_models + + +def _run_quant() -> None: + """Quantize the cached transformer ONNX files (no composite/QNN load). + + Runs as its own subprocess so any ORT teardown crash can't poison the + parent. Builds a shim ``model`` whose ``sub_models[name]._onnx_path`` + point straight at the cached ``*_optimized.onnx`` files. + """ + # Dump the Python traceback if the calibration InferenceSession segfaults + # (otherwise the crash is silent — no Python traceback).
+ import faulthandler + faulthandler.enable() + + from qwen3_transformer_only_quantize import quantize_built_model + + sub_models = { + name: _SubShim(_find_optimized(prefix)) + for name, (_task, _seq, prefix) in SUB_MODELS.items() + } + model = _ModelShim(sub_models) + print("=== Running transformer w8a16 quantization ===", flush=True) + for name, sub in sub_models.items(): + print(f" {name}: {sub._onnx_path}", flush=True) + + try: + quantize_built_model( + model, + model_id=model_id, + max_cache_len=MAX_CACHE, + prefill_seq=64, + ) + except BaseException: + import traceback + print("QUANT FAILED with exception:", flush=True) + traceback.print_exc() + sys.stdout.flush() + sys.stderr.flush() + raise + print("QUANT COMPLETE", flush=True) + sys.stdout.flush() + sys.stderr.flush() + os._exit(0) + + +def main() -> None: + # 1) Build each sub-model in its OWN subprocess (fresh state each time). + # Judge success by whether a FRESH npu_ctx artifact appeared, NOT by the + # subprocess exit code: the native QNN/ORT layer segfaults (0xC0000005) + # on teardown AFTER the artifact is fully written to disk. 
+ import time as _time + + for name, (task, seq_len, prefix) in SUB_MODELS.items(): + print(f"\n########## BUILD {name} (task={task}, seq_len={seq_len}) ##########", flush=True) + before = _latest_ctx_mtime(prefix) + start = _time.time() + rc = subprocess.run( + [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), + "--build-sub", task, str(seq_len)], + cwd=str(_repo_root), + ).returncode + + after = _latest_ctx_mtime(prefix) + if after > before and after >= start - 1: + status = "OK" if rc == 0 else f"OK (ignored teardown exit {rc})" + print(f"########## {name} {status}: fresh {prefix}*_optimized_npu_ctx.onnx ##########", flush=True) + else: + raise SystemExit( + f"Sub-model build failed for {name} (exit {rc}) — " + f"no fresh {prefix}*_optimized_npu_ctx.onnx in {ARTIFACTS_DIR}" + ) + + # 2) Report the built transformer-only ONNX files (no composite/QNN load — + # that creates QNN EP sessions that segfault the parent on teardown). + for name, (_task, _seq, prefix) in SUB_MODELS.items(): + print(f"\n=== {name} ===") + print(f" optimized : {_find_optimized(prefix).name}") + ctx = sorted(ARTIFACTS_DIR.glob(f"{prefix}*_optimized_npu_ctx.onnx")) + if ctx: + print(f" npu_ctx : {ctx[-1].name}") + + # 3) Quantization — run in its OWN subprocess for the same teardown-crash + # isolation. Judge by whether quant files appeared. 
+ print("\n########## QUANTIZE ##########", flush=True) + before = max( + (p.stat().st_mtime for p in ARTIFACTS_DIR.glob("*quant.onnx")), + default=0.0, + ) + qstart = _time.time() + rc = subprocess.run( + [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), "--quant"], + cwd=str(_repo_root), + ).returncode + after_files = list(ARTIFACTS_DIR.glob("*quant.onnx")) + after = max((p.stat().st_mtime for p in after_files), default=0.0) + if after > before and after >= qstart - 1: + status = "OK" if rc == 0 else f"OK (ignored teardown exit {rc})" + print(f"########## QUANTIZE {status} ##########", flush=True) + for p in sorted(after_files, key=lambda x: x.stat().st_mtime)[-len(SUB_MODELS):]: + print(f" {p.name}", flush=True) + else: + raise SystemExit( + f"Quantization failed (exit {rc}) — no fresh *quant.onnx in {ARTIFACTS_DIR}" + ) + + +if __name__ == "__main__": + if len(sys.argv) >= 4 and sys.argv[1] == "--build-sub": + _build_one(sys.argv[2], int(sys.argv[3])) + elif len(sys.argv) >= 2 and sys.argv[1] == "--quant": + _run_quant() + else: + main() + From 78815fd97d7458edc745185d65c8aefdb7b82d67 Mon Sep 17 00:00:00 2001 From: spalne Date: Mon, 22 Jun 2026 10:44:41 -0700 Subject: [PATCH 03/13] Fix Qwen3 w8a16 quant: symmetric int8 weights + exclude GQA from QDQ --- qwen3_transformer_only_quantize.py | 33 ++++++++++++++++++++++++++- src/winml/modelkit/quant/config.py | 9 ++++++++ src/winml/modelkit/quant/quantizer.py | 19 ++++++++++++--- 3 files changed, 57 insertions(+), 4 deletions(-) diff --git a/qwen3_transformer_only_quantize.py b/qwen3_transformer_only_quantize.py index 8b4efa9b7..3ae895ae2 100644 --- a/qwen3_transformer_only_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -133,6 +133,23 @@ def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> l return out +def _gqa_node_names(onnx_path: Path) -> list[str]: + """Return the names of every GroupQueryAttention node in ``onnx_path``. 
+ + These nodes are excluded from quantization so ORT leaves both their + inputs and output in float (``... -> Cast -> GQA -> Cast``), matching + the reference graph which keeps attention entirely out of QDQ. + """ + import onnx + + model = onnx.load(str(onnx_path), load_external_data=False) + return [ + n.name + for n in model.graph.node + if n.op_type == "GroupQueryAttention" and n.name + ] + + def quantize_built_model( model: WinMLCompositeModel, *, @@ -140,7 +157,7 @@ def quantize_built_model( max_cache_len: int = DEFAULT_MAX_CACHE, prefill_seq: int = DEFAULT_PREFILL_SEQ, num_samples: int = DEFAULT_NUM_SAMPLES, - weight_type: str = "uint8", + weight_type: str = "int8", activation_type: str = "uint16", ) -> dict[str, Path]: """Quantize the transformer-only ONNX files in-place. @@ -200,6 +217,11 @@ def quantize_built_model( print(f"\n=== Quantize (transformer-only): {sub_name} (seq_len={seq_len}) ===") print(f" in : {fused_path}") print(f" out: {quant_path}") + gqa_nodes = _gqa_node_names(fused_path) + print( + f" excluding {len(gqa_nodes)} GroupQueryAttention nodes from " + "quantization (inputs + output stay float, Cast -> GQA -> Cast)" + ) reader = Qwen3TransformerOnlyCalibReader( embed_tokens, hf_model.config, @@ -213,6 +235,15 @@ def quantize_built_model( activation_type=activation_type, # type: ignore[arg-type] calibration_method="minmax", calibration_data=reader, + # w8a16: symmetric int8 weights (zp=0) + asymmetric uint16 + # activations, matching the reference quantization. + weight_symmetric=True, + activation_symmetric=False, + # ORT treats GroupQueryAttention as quantizable and wraps both its + # inputs and output in QDQ. The reference keeps attention entirely + # in float (Cast -> GQA -> Cast), so exclude the GQA nodes from + # quantization so no QDQ is inserted around them. 
+ nodes_to_exclude=gqa_nodes, ) result = quantize_onnx(fused_path, output_path=quant_path, config=cfg) if not result.success: diff --git a/src/winml/modelkit/quant/config.py b/src/winml/modelkit/quant/config.py index b9709cc0e..6132be599 100644 --- a/src/winml/modelkit/quant/config.py +++ b/src/winml/modelkit/quant/config.py @@ -68,6 +68,11 @@ class WinMLQuantizationConfig: # Quantization options per_channel: bool = False symmetric: bool = False + # Optional per-target symmetry overrides. When None, fall back to + # ``symmetric``. Lets w8a16 use symmetric weights (int8, zp=0) together + # with asymmetric activations (uint16). + weight_symmetric: bool | None = None + activation_symmetric: bool | None = None # Output settings save_calibration: bool = False @@ -98,6 +103,8 @@ def to_dict(self) -> dict: "activation_type": self.activation_type, "per_channel": self.per_channel, "symmetric": self.symmetric, + "weight_symmetric": self.weight_symmetric, + "activation_symmetric": self.activation_symmetric, "save_calibration": self.save_calibration, "distribution": self.distribution, "seed": self.seed, @@ -139,6 +146,8 @@ def from_dict(cls, data: dict) -> WinMLQuantizationConfig: activation_type=data.get("activation_type", "uint8"), per_channel=data.get("per_channel", False), symmetric=data.get("symmetric", False), + weight_symmetric=data.get("weight_symmetric"), + activation_symmetric=data.get("activation_symmetric"), save_calibration=data.get("save_calibration", False), distribution=data.get("distribution", "uniform"), seed=data.get("seed"), diff --git a/src/winml/modelkit/quant/quantizer.py b/src/winml/modelkit/quant/quantizer.py index c562599de..e5fd30df3 100644 --- a/src/winml/modelkit/quant/quantizer.py +++ b/src/winml/modelkit/quant/quantizer.py @@ -132,10 +132,23 @@ def quantize_onnx( activation_type = activation_type_map[config.activation_type] calibrate_method = calibration_method_map[config.calibration_method] - # Build extra options + # Build extra options. 
Weight/activation symmetry can be controlled + # independently (e.g. w8a16 = symmetric int8 weights + asymmetric + # uint16 activations); fall back to the single ``symmetric`` flag when + # the per-target override is unset. + weight_symmetric = ( + config.weight_symmetric + if config.weight_symmetric is not None + else config.symmetric + ) + activation_symmetric = ( + config.activation_symmetric + if config.activation_symmetric is not None + else config.symmetric + ) extra_options = { - "ActivationSymmetric": config.symmetric, - "WeightSymmetric": config.symmetric, + "ActivationSymmetric": activation_symmetric, + "WeightSymmetric": weight_symmetric, } # Step 1: Generate QDQ config From 95d45d9ad9a9baab2576e2b88d7c3999a60ca3f4 Mon Sep 17 00:00:00 2001 From: spalne Date: Mon, 22 Jun 2026 14:51:23 -0700 Subject: [PATCH 04/13] refactor(qwen): register transformer-only path as a declarative model_type variant --- qwen3_transformer_only_quantize.py | 7 +- src/winml/modelkit/build/hf.py | 4 + src/winml/modelkit/loader/config.py | 13 +++ src/winml/modelkit/loader/hf.py | 13 +++ src/winml/modelkit/models/auto.py | 16 +++- src/winml/modelkit/models/hf/__init__.py | 10 ++ .../models/hf/qwen_transformer_only.py | 93 +++++++++---------- test_qwen.py | 8 +- 8 files changed, 105 insertions(+), 59 deletions(-) diff --git a/qwen3_transformer_only_quantize.py b/qwen3_transformer_only_quantize.py index 3ae895ae2..0b90c8bd0 100644 --- a/qwen3_transformer_only_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -1,7 +1,7 @@ """Transformer-only w8a16 quantization for Qwen3. -Targets the transformer-only ONNX produced by -``qwen_transformer_only.install() + test_qwen.py``: +Targets the transformer-only ONNX produced by the +``qwen3_transformer_only`` build variant (see ``test_qwen.py``): - **No embedding/lm_head surgery.** The export already excludes both, so we feed ``WinMLQuantization`` the file directly. 
@@ -24,6 +24,7 @@ from winml.modelkit.models.winml.composite_model import WinMLCompositeModel from winml.modelkit.quant import WinMLQuantizationConfig, quantize_onnx +from winml.modelkit.quant.config import CalibrationDataReader logger = logging.getLogger(__name__) @@ -48,7 +49,7 @@ def _load_gsm8k_prompts(num_samples: int) -> list[str]: return [row["question"] for row in split.select(range(num_samples))] -class Qwen3TransformerOnlyCalibReader: +class Qwen3TransformerOnlyCalibReader(CalibrationDataReader): """Yields calibration feeds for the transformer-only ONNX. Feeds match the exported graph exactly: ``input_hidden_states`` (FP32), diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index 26356a6eb..dc2661afa 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -91,6 +91,7 @@ def build_hf_model( cache_key: str | None = None, ep: EPNameOrAlias | None = None, device: str | None = None, + model_type: str | None = None, **kwargs: Any, ) -> BuildResult: """Build an ONNX model from a HuggingFace model architecture. @@ -208,6 +209,7 @@ def _name(base: str) -> str: model_id, trust_remote_code, random_init=random_init, + model_type=model_type, ) # ========================================================================= @@ -436,6 +438,7 @@ def _load_model( trust_remote_code: bool, random_init: bool = False, hf_config: Any | None = None, + model_type: str | None = None, ) -> Any: """Load PyTorch model — pretrained or random weights. @@ -511,6 +514,7 @@ def _load_model( task=task, trust_remote_code=effective_trust, hf_config=hf_config, + model_type=model_type, ) return pytorch_model diff --git a/src/winml/modelkit/loader/config.py b/src/winml/modelkit/loader/config.py index cb6cb9af1..b533c1636 100644 --- a/src/winml/modelkit/loader/config.py +++ b/src/winml/modelkit/loader/config.py @@ -218,6 +218,19 @@ def resolve_loader_config( f"attribute. Cannot proceed with config generation." 
) + # Explicit model_type override alongside a model_id: honor the requested + # type so downstream class / build-config / export resolution selects the + # variant (e.g. "qwen3_transformer_only") rather than the architecture's + # native type. The model_type-only path above (AutoConfig.for_model) is + # unaffected because it only runs when model_id is None. + if model_id is not None and model_type is not None and hf_config.model_type != model_type: + logger.info( + "Overriding resolved model_type '%s' -> '%s' (explicit request)", + hf_config.model_type, + model_type, + ) + hf_config.model_type = model_type + # 2. Infer task (depends on: model_type param or hf_config.architectures) if task is None and model_type is not None: supported = get_supported_tasks(model_type, library_name=library_name) diff --git a/src/winml/modelkit/loader/hf.py b/src/winml/modelkit/loader/hf.py index 5a90b5828..7c40c5fee 100644 --- a/src/winml/modelkit/loader/hf.py +++ b/src/winml/modelkit/loader/hf.py @@ -150,6 +150,7 @@ def load_hf_model( user_script: str | None = None, trust_remote_code: bool = False, hf_config: PretrainedConfig | None = None, + model_type: str | None = None, ) -> tuple[nn.Module, PretrainedConfig, str]: """Load, detect task, and prepare HuggingFace model. @@ -224,6 +225,18 @@ def load_hf_model( trust_remote_code=trust_remote_code, ) + # Explicit model_type override: select a registered build variant (e.g. + # "qwen3_transformer_only") rather than the architecture's native type. + # Mutates the freshly-loaded config only; gated on an explicit request so + # normal loading is unaffected. 
+ if model_type is not None and getattr(hf_config, "model_type", None) != model_type: + logger.info( + "Overriding model_type '%s' -> '%s' (explicit request)", + getattr(hf_config, "model_type", None), + model_type, + ) + hf_config.model_type = model_type + # [2] Task & Model Class Resolution if user_script is not None: resolved_class = _load_class_from_script(user_script, model_class) diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index 78f944b36..4767b97db 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -247,6 +247,7 @@ def from_pretrained( trust_remote_code: bool = False, shape_config: dict | None = None, no_compile: bool = False, + model_type: str | None = None, **kwargs: Any, ) -> WinMLPreTrainedModel: """Load appropriate WinML model based on task detection. @@ -278,6 +279,10 @@ def from_pretrained( shape_config: Shape overrides passed to generate_build_config(). Valid keys -- text: sequence_length; vision: height, width; audio: feature_size, nb_max_frames, audio_sequence_length. + model_type: Explicit model_type override. When provided alongside a + HF model_id, selects a registered build variant (e.g. + ``"qwen3_transformer_only"``) instead of the architecture's + native model_type. Leave ``None`` for normal auto-detection. **kwargs: Additional arguments Returns: @@ -334,6 +339,11 @@ def from_pretrained( else: _model_type = None + # Explicit override wins so a variant composite (e.g. + # "qwen3_transformer_only") can be selected over the native type. 
+ if model_type is not None: + _model_type = model_type + if _model_type is not None and (_model_type, task) in COMPOSITE_MODEL_REGISTRY: from .winml.composite_model import WinMLCompositeModel @@ -368,6 +378,7 @@ def from_pretrained( trust_remote_code=trust_remote_code, ep=kwargs.get("ep"), no_compile=no_compile, + model_type=model_type, ) resolved_task = build_config.loader.task @@ -402,7 +413,9 @@ def from_pretrained( from transformers import AutoConfig hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=effective_trust) - model_type = getattr(hf_config, "model_type", "unknown") + # Honor an explicit model_type override; otherwise probe from the config. + if model_type is None: + model_type = getattr(hf_config, "model_type", "unknown") logger.debug("Model type: %s, task: %s", model_type, resolved_task) # ===================================================================== @@ -431,6 +444,7 @@ def from_pretrained( cache_key=cache_key, ep=resolved_ep, device=device, + model_type=model_type, ) onnx_path = result.final_onnx_path diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index c6f4c9520..0d2e538a3 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -56,6 +56,14 @@ from .qwen import QWEN_CONFIG from .qwen import QwenGenIOConfig as _QwenGenIOConfig from .qwen import QwenPrefillIOConfig as _QwenPrefillIOConfig +from .qwen_transformer_only import MODEL_CLASS_MAPPING as _QWEN_TO_CLASS_MAPPING +from .qwen_transformer_only import QWEN_TRANSFORMER_ONLY_CONFIG +from .qwen_transformer_only import ( + QwenTransformerOnlyGenIOConfig as _QwenTransformerOnlyGenIOConfig, # triggers registration +) +from .qwen_transformer_only import ( + QwenTransformerOnlyPrefillIOConfig as _QwenTransformerOnlyPrefillIOConfig, # triggers registration +) from .roberta import ROBERTA_FAMILY_CONFIG from .roberta import RobertaIOConfig as _RobertaIOConfig # triggers registration from 
.sam import MODEL_CLASS_MAPPING as _SAM2_CLASS_MAPPING @@ -92,6 +100,7 @@ **_MARIAN_CLASS_MAPPING, **_MU2_CLASS_MAPPING, **_QWEN_CLASS_MAPPING, + **_QWEN_TO_CLASS_MAPPING, **_SAM2_CLASS_MAPPING, **_SEGFORMER_CLASS_MAPPING, **_SIGLIP_CLASS_MAPPING, @@ -115,6 +124,7 @@ "roberta": ROBERTA_FAMILY_CONFIG, "mu2": MU2_CONFIG, "qwen3": QWEN_CONFIG, + "qwen3-transformer-only": QWEN_TRANSFORMER_ONLY_CONFIG, "siglip": SIGLIP_CONFIG, "siglip-text-model": SIGLIP_CONFIG, "siglip-vision-model": SIGLIP_CONFIG, diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py index 8e30b1fb6..614267df4 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -2,12 +2,17 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Parallel ``qwen3`` build path that produces a transformer-only ONNX. +"""Transformer-only ``qwen3`` build variant, registered as a distinct model_type. -Opt-in via ``install()`` — calling it hot-patches the WinML registries so -that the next ``WinMLAutoModel.from_pretrained("Qwen/Qwen3-*", task="text-generation")`` -exports two transformer-only ONNX files (a prefill/context graph and an -iteration/decode graph) with this I/O: +This module registers a self-contained build path under the model_type +``"qwen3_transformer_only"`` (distinct from the stock ``"qwen3"`` path in +``qwen.py``). Selecting it is explicit — pass ``model_type="qwen3_transformer_only"`` +to ``WinMLAutoModel.from_pretrained(...)`` (or the underlying +``generate_hf_build_config(...)``). Both paths coexist; neither overrides the +other, and there is no import-ordering requirement. 
+ +The variant exports two transformer-only ONNX files (a prefill/context graph +and an iteration/decode graph) with this I/O: Inputs : past_keys_{i}, past_values_{i} (FP16, ``[1, kv_heads, max_cache, head_dim]``), input_hidden_states (FP32, ``[1, seq_len, hidden]``), @@ -16,8 +21,9 @@ Ops : ``com.microsoft::GroupQueryAttention`` (do_rotary=1), ``onnx::LpNormalization`` (RMSNorm), 1x1 ``Conv`` projections. -The original eager-export path in ``qwen.py`` is left intact — only the -qwen3 entries in the registries are replaced. ``install()`` is idempotent. +Registration happens at import time via decorators and module-level mappings, +mirroring ``qwen.py``. The aggregating ``models.hf`` package imports this +module so the entries land in ``MODEL_CLASS_MAPPING`` / ``MODEL_BUILD_CONFIGS``. """ from __future__ import annotations @@ -36,6 +42,7 @@ from ...export import register_onnx_overwrite from ...export.config import WinMLExportConfig from ..winml import register_specialization +from ..winml.composite_model import register_composite_model from ..winml.decoder_only import WinMLDecoderOnlyModel from ..winml.kv_cache import WinMLSlidingWindowCache from .qwen3_export_ops import apply_transformer_only_export_prep @@ -43,6 +50,13 @@ logger = logging.getLogger(__name__) +# Distinct model_type for this variant. The underscore form is what the +# exporter sees on ``model.config.model_type`` and what Optimum's TasksManager +# and ``register_specialization`` are keyed on; the hyphenated form is used for +# the ``MODEL_CLASS_MAPPING`` / ``MODEL_BUILD_CONFIGS`` lookups (those callers +# normalize ``_`` -> ``-``). 
+TRANSFORMER_ONLY_MODEL_TYPE = "qwen3_transformer_only" + # ============================================================================= # Wrapper module @@ -65,6 +79,10 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: self.num_layers = num_layers self.config = model.config apply_transformer_only_export_prep(model, matmul_to_conv=True) + # Tag the config so the exporter resolves this variant's OnnxConfig + # (registered under ``TRANSFORMER_ONLY_MODEL_TYPE``) rather than the + # stock qwen3 one. Mirrors the CLIP/zoedepth sub-model precedent. + self.config.model_type = TRANSFORMER_ONLY_MODEL_TYPE @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenTransformerOnlyDecoderWrapper: @@ -222,6 +240,9 @@ def _transformer_only_outputs(num_layers: int, kv_seq_axis: str = "max_seq_len") return result +@register_onnx_overwrite( + TRANSFORMER_ONLY_MODEL_TYPE, "feature-extraction", library_name="transformers" +) class QwenTransformerOnlyPrefillIOConfig(OnnxConfig): """Prefill (seq=64) — transformer-only I/O.""" @@ -241,6 +262,9 @@ def outputs(self) -> dict[str, dict[int, str]]: return _transformer_only_outputs(self._normalized_config.num_layers) +@register_onnx_overwrite( + TRANSFORMER_ONLY_MODEL_TYPE, "text2text-generation", library_name="transformers" +) class QwenTransformerOnlyGenIOConfig(OnnxConfig): """Generation (seq=1) — transformer-only I/O.""" @@ -279,6 +303,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # ============================================================================= +@register_composite_model(TRANSFORMER_ONLY_MODEL_TYPE, "text-generation") class WinMLQwen3TransformerOnlyModel(WinMLDecoderOnlyModel): """Composite handle for the transformer-only Qwen3 build (export only). 
@@ -299,56 +324,28 @@ def get_cache_class(cls) -> type: # ============================================================================= -# install() — hot-patch the registries +# Declarative registration (import-time) # ============================================================================= +# Wrapper-class lookup keyed by (model_type, task). Keys use the hyphenated +# model_type form because ``_get_custom_model_class`` normalizes ``_`` -> ``-`` +# before lookup. Merged into the aggregate mapping by ``models.hf.__init__``. +MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { + ("qwen3-transformer-only", "feature-extraction"): QwenTransformerOnlyDecoderWrapper, + ("qwen3-transformer-only", "text2text-generation"): QwenTransformerOnlyDecoderWrapper, +} -_INSTALLED = False - - -def install() -> None: - """Replace the qwen3 entries in WinML registries with the transformer-only variants. - - Idempotent. After this call, building any qwen3 model via - :class:`~winml.modelkit.models.winml.composite_model.WinMLCompositeModel` - or :class:`~winml.modelkit.models.auto.WinMLAutoModel` produces - transformer-only ONNX files. - """ - global _INSTALLED - if _INSTALLED: - return - - # 1) Per-model build config + wrapper-class lookup live on the parent - # ``models.hf`` package as module-level dicts; mutating them is the - # documented hook for adding/overriding a model_type. - from .. import hf as _hf_pkg # noqa: PLC0415 - - _hf_pkg.MODEL_BUILD_CONFIGS["qwen3"] = QWEN_TRANSFORMER_ONLY_CONFIG - _hf_pkg.MODEL_CLASS_MAPPING[("qwen3", "feature-extraction")] = QwenTransformerOnlyDecoderWrapper - _hf_pkg.MODEL_CLASS_MAPPING[("qwen3", "text2text-generation")] = QwenTransformerOnlyDecoderWrapper - - # 2) Optimum OnnxConfig (overwrites existing registration). 
- register_onnx_overwrite("qwen3", "feature-extraction", library_name="transformers")(QwenTransformerOnlyPrefillIOConfig) - register_onnx_overwrite("qwen3", "text2text-generation", library_name="transformers")(QwenTransformerOnlyGenIOConfig) - - # 3) Inference specialization (still GenericTask — wrapper returns raw KV). - register_specialization("qwen3", "feature-extraction", "WinMLModelForGenericTask") - register_specialization("qwen3", "text2text-generation", "WinMLModelForGenericTask") - - # 4) Composite registry — swap to the transformer-only handle. - from ..winml.composite_model import COMPOSITE_MODEL_REGISTRY - - COMPOSITE_MODEL_REGISTRY[("qwen3", "text-generation")] = WinMLQwen3TransformerOnlyModel - - _INSTALLED = True - logger.info("qwen_transformer_only: transformer-only export path installed for qwen3.") +# Inference specialization (GenericTask — the wrapper returns raw hidden states / KV). +register_specialization(TRANSFORMER_ONLY_MODEL_TYPE, "feature-extraction", "WinMLModelForGenericTask") +register_specialization(TRANSFORMER_ONLY_MODEL_TYPE, "text2text-generation", "WinMLModelForGenericTask") __all__ = [ + "MODEL_CLASS_MAPPING", "QWEN_TRANSFORMER_ONLY_CONFIG", + "TRANSFORMER_ONLY_MODEL_TYPE", "QwenTransformerOnlyDecoderWrapper", "QwenTransformerOnlyGenIOConfig", "QwenTransformerOnlyPrefillIOConfig", "WinMLQwen3TransformerOnlyModel", - "install", ] diff --git a/test_qwen.py b/test_qwen.py index f958c2932..da23f4481 100644 --- a/test_qwen.py +++ b/test_qwen.py @@ -10,9 +10,6 @@ ``com.microsoft::GroupQueryAttention``, ``LpNormalization``, and 1x1 Conv projections. -Important: ``install()`` MUST be called before importing the composite model -machinery so the registry hot-patches take effect. - Generation (``model.generate(...)``) is NOT supported by this build path — the inference feeds in ``WinMLDecoderOnlyModel`` still target the eager I/O signature. 
Use the eager ``WinMLQwen3Model`` build path for end-to-end @@ -68,16 +65,13 @@ def _build_one(task: str, seq_len: int) -> None: clean interpreter — building both in one process leaves PyTorch/ORT state from the first build that corrupts/kills the second. """ - from winml.modelkit.models.hf.qwen_transformer_only import install as install_qwen_transformer_only - - install_qwen_transformer_only() - from winml.modelkit.config import WinMLBuildConfig from winml.modelkit.models.auto import WinMLAutoModel WinMLAutoModel.from_pretrained( model_id, task=task, + model_type="qwen3_transformer_only", config=WinMLBuildConfig(quant=None, compile=None), precision="fp16", device="npu", From 9cecb03913d9d19a10e1d2c27934414bcdd5ee3f Mon Sep 17 00:00:00 2001 From: spalne Date: Tue, 23 Jun 2026 11:52:52 -0700 Subject: [PATCH 05/13] fix(qwen): calibrate transformer-only decode model on real trajectory --- qwen3_transformer_only_quantize.py | 170 +++++++++++++++++++++++++++-- 1 file changed, 163 insertions(+), 7 deletions(-) diff --git a/qwen3_transformer_only_quantize.py b/qwen3_transformer_only_quantize.py index 0b90c8bd0..81bcb780f 100644 --- a/qwen3_transformer_only_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -34,6 +34,7 @@ DEFAULT_PREFILL_SEQ = 64 DEFAULT_GEN_SEQ = 1 DEFAULT_NUM_SAMPLES = 30 +DEFAULT_DECODE_STEPS = 16 DEFAULT_CALIB_DATASET = "openai/gsm8k" DEFAULT_CALIB_DATASET_CONFIG = "main" DEFAULT_CALIB_SPLIT = "train" @@ -119,6 +120,140 @@ def rewind(self) -> None: self._iter = iter(self._samples) +def _layer_kv(past: Any, i: int) -> tuple[torch.Tensor, torch.Tensor]: + """Extract layer ``i``'s (key, value) from an HF cache, version-agnostic. + + Handles the legacy tuple-of-tuples cache, the older ``DynamicCache`` + (``.key_cache`` / ``.value_cache``), and the newer per-layer + ``DynamicCache`` (``.layers[i].keys`` / ``.values``). 
+ """ + if hasattr(past, "key_cache") and hasattr(past, "value_cache"): + return past.key_cache[i], past.value_cache[i] + if hasattr(past, "layers"): + layer = past.layers[i] + return layer.keys, layer.values + return past[i][0], past[i][1] + + +class Qwen3DecodeTrajectoryCalibReader(CalibrationDataReader): + """Calibrate the iter (seq_len=1) model on REAL decode-step states. + + The naive reader feeds one (repeated) token with a zeroed KV cache and + ``past_seq_len=0`` — a state the model never sees during generation. With + MinMax calibration this collapses the observed activation ranges far below + the real decode distribution, so the resulting w8a16 model degenerates + (e.g. ``Paris -> Parisammedammed...``). + + Instead, drive the HF FP reference model through a real prefill + decode + trajectory and capture, at each decode step, the exact feed the iter ONNX + would receive: the embedding of the *actually generated* token, the real + accumulated KV cache (copied into the fixed ``[1, kv_heads, max_cache, + head_dim]`` FP16 buffer), and the growing ``past_seq_len``. Token + selection uses the HF model's true logits, so the trajectory matches + greedy generation. The QDQ scheme is unchanged — only the calibration + statistics become representative. 
+ """ + + def __init__( + self, + hf_model: torch.nn.Module, + embed_tokens: torch.nn.Module, + config: Any, + token_ids_list: list[torch.Tensor], + *, + prefill_seq: int, + max_cache_len: int, + decode_steps: int = 16, + ) -> None: + self.num_layers = config.num_hidden_layers + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.max_cache_len = max_cache_len + self._samples = list( + self._build_samples( + hf_model, + embed_tokens, + token_ids_list, + prefill_seq=prefill_seq, + decode_steps=decode_steps, + ) + ) + self._iter: Iterator[dict[str, np.ndarray]] | None = None + self.rewind() + + def _kv_buffers(self, past: Any, cur_len: int) -> dict[str, np.ndarray]: + """Copy the ``cur_len`` valid KV positions into fixed FP16 buffers.""" + feed: dict[str, np.ndarray] = {} + for i in range(self.num_layers): + k, v = _layer_kv(past, i) + kbuf = np.zeros( + (1, self.num_kv_heads, self.max_cache_len, self.head_dim), np.float16 + ) + vbuf = np.zeros_like(kbuf) + kbuf[:, :, :cur_len, :] = k[:, :, :cur_len, :].to(torch.float16).cpu().numpy() + vbuf[:, :, :cur_len, :] = v[:, :, :cur_len, :].to(torch.float16).cpu().numpy() + feed[f"past_keys_{i}"] = kbuf + feed[f"past_values_{i}"] = vbuf + return feed + + def _build_samples( + self, + hf_model: torch.nn.Module, + embed_tokens: torch.nn.Module, + token_ids_list: list[torch.Tensor], + *, + prefill_seq: int, + decode_steps: int, + ) -> Iterator[dict[str, np.ndarray]]: + for ids in token_ids_list: + ids = ids[:, :prefill_seq] # real prompt prefix (no pad-token KV) + cur_len = ids.shape[1] + + # FP prefill once to seed a realistic KV cache + first token. 
+ with torch.no_grad(): + out = hf_model(input_ids=ids, use_cache=True) + past = out.past_key_values + tok = int(out.logits[:, -1, :].argmax(-1)) + + for _ in range(decode_steps): + if cur_len >= self.max_cache_len: + break + # The feed the iter model sees for THIS token: embedding of the + # token to process, the KV of the `cur_len` preceding tokens, + # and seqlens_k = (cur_len + 1) - 1 = cur_len. + with torch.no_grad(): + emb = embed_tokens(torch.tensor([[tok]])).to(torch.float32).cpu().numpy() + feed: dict[str, np.ndarray] = { + "input_hidden_states": emb.astype(np.float32), + "past_seq_len": np.array([[cur_len]], dtype=np.int32), + "total_seq_len": np.array([self.max_cache_len], dtype=np.int32), + } + feed.update(self._kv_buffers(past, cur_len)) + yield feed + + # Advance the reference model one real decode step. + with torch.no_grad(): + out = hf_model( + input_ids=torch.tensor([[tok]]), + past_key_values=past, + use_cache=True, + ) + past = out.past_key_values + tok = int(out.logits[:, -1, :].argmax(-1)) + cur_len += 1 + + def get_next(self) -> dict[str, np.ndarray] | None: + try: + return next(self._iter) if self._iter is not None else None + except StopIteration: + return None + + def rewind(self) -> None: + self._iter = iter(self._samples) + + def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> list[torch.Tensor]: out: list[torch.Tensor] = [] for i in range(num_samples): @@ -160,6 +295,7 @@ def quantize_built_model( num_samples: int = DEFAULT_NUM_SAMPLES, weight_type: str = "int8", activation_type: str = "uint16", + decode_steps: int = DEFAULT_DECODE_STEPS, ) -> dict[str, Path]: """Quantize the transformer-only ONNX files in-place. 
@@ -223,13 +359,33 @@ def quantize_built_model( f" excluding {len(gqa_nodes)} GroupQueryAttention nodes from " "quantization (inputs + output stay float, Cast -> GQA -> Cast)" ) - reader = Qwen3TransformerOnlyCalibReader( - embed_tokens, - hf_model.config, - token_ids_list, - seq_len=seq_len, - max_cache_len=max_cache_len, - ) + if sub_name == "decoder_gen": + # The iter model only sees mid-generation states. Calibrate it on a + # real prefill+decode trajectory (true tokens, accumulated KV, + # growing past_seq_len) instead of one token + zeroed KV, which + # would under-range the MinMax activation scales and collapse + # generation. + print( + f" calibrating on decode trajectory ({decode_steps} steps/prompt, " + f"prefill_seq={prefill_seq})" + ) + reader: CalibrationDataReader = Qwen3DecodeTrajectoryCalibReader( + hf_model, + embed_tokens, + hf_model.config, + token_ids_list, + prefill_seq=prefill_seq, + max_cache_len=max_cache_len, + decode_steps=decode_steps, + ) + else: + reader = Qwen3TransformerOnlyCalibReader( + embed_tokens, + hf_model.config, + token_ids_list, + seq_len=seq_len, + max_cache_len=max_cache_len, + ) cfg = WinMLQuantizationConfig( samples=num_samples, weight_type=weight_type, # type: ignore[arg-type] From 08f05d7c399dafe2f60dfccf3cbd3348355ab721 Mon Sep 17 00:00:00 2001 From: spalne Date: Tue, 23 Jun 2026 13:18:25 -0700 Subject: [PATCH 06/13] Fixed small bugs --- qwen3_transformer_only_quantize.py | 18 +++- .../modelkit/models/hf/qwen3_export_ops.py | 81 +++----------- .../modelkit/models/hf/qwen3_modeling.py | 101 ++++++++++++++++-- .../models/hf/qwen_transformer_only.py | 2 +- test_qwen.py | 8 +- 5 files changed, 132 insertions(+), 78 deletions(-) diff --git a/qwen3_transformer_only_quantize.py b/qwen3_transformer_only_quantize.py index 81bcb780f..559620973 100644 --- a/qwen3_transformer_only_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -15,6 +15,7 @@ from __future__ import annotations import logging +import gc from pathlib import 
Path from typing import Any, Iterator @@ -40,6 +41,16 @@ DEFAULT_CALIB_SPLIT = "train" DEFAULT_CALIB_SEED = 42 +# Map an ONNX quantization dtype to the bit-width suffix used in artifact +# filenames (e.g. int8 -> "8", uint16 -> "16"), instead of brittle string +# slicing of the dtype name. +_DTYPE_BITS = { + "int8": "8", + "uint8": "8", + "int16": "16", + "uint16": "16", +} + def _load_gsm8k_prompts(num_samples: int) -> list[str]: """GSM8K train split, shuffled seed=42 for reproducible calibration.""" @@ -348,7 +359,8 @@ def quantize_built_model( seq_len = seq_by_sub[sub_name] quant_path = fused_path.with_name( - fused_path.stem + f"_w{weight_type[-1]}a{activation_type[-2:]}.quant.onnx" + fused_path.stem + + f"_w{_DTYPE_BITS[weight_type]}a{_DTYPE_BITS[activation_type]}.quant.onnx" ) print(f"\n=== Quantize (transformer-only): {sub_name} (seq_len={seq_len}) ===") @@ -414,5 +426,9 @@ def quantize_built_model( ) quant_paths[sub_name] = quant_path + # Free the FP reference model now that calibration is done. + del hf_model, embed_tokens + gc.collect() + print("\n=== Done ===") return quant_paths diff --git a/src/winml/modelkit/models/hf/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3_export_ops.py index 61d45f0ef..5fd3edb68 100644 --- a/src/winml/modelkit/models/hf/qwen3_export_ops.py +++ b/src/winml/modelkit/models/hf/qwen3_export_ops.py @@ -46,7 +46,12 @@ def symbolic(g, input, axis, p): # noqa: D401 @staticmethod def forward(ctx, input, axis, p): # noqa: ARG004 - return input # placeholder — real compute happens in symbolic + # Shape-only tracing placeholder. The real op is emitted by + # ``symbolic`` during ONNX export; ``forward`` exists solely so the + # TorchScript exporter (and Optimum's pre-export dry run) can trace + # output shapes. It returns ``input`` unchanged on purpose and is NOT a + # correct eager RMSNorm — do not call this module for real inference. 
+ return input class GroupQueryAttentionOnnxExport(torch.autograd.Function): @@ -100,6 +105,12 @@ def forward( kv_num_heads, num_heads, ): # noqa: ARG004 + # Shape-only tracing placeholder. The real op is emitted by + # ``symbolic`` during ONNX export; ``forward`` exists solely so the + # TorchScript exporter (and Optimum's pre-export dry run) can trace + # output shapes. It returns the inputs as stand-in present-KV on + # purpose and is NOT correct attention — do not call this module for + # real inference. return query, past_key, past_value # placeholder shapes @@ -136,76 +147,8 @@ def from_linear_module(cls, linear: nn.Linear) -> TransposeConv2d1x1Transpose: return cls(linear.in_features, linear.out_features, linear.weight, linear.bias) -# ============================================================================= -# Apply export prep: bind winml Qwen3 export methods onto a loaded model -# ============================================================================= - - -def apply_transformer_only_export_prep(causal_lm: nn.Module, *, matmul_to_conv: bool = True) -> None: - """Mutate ``Qwen3ForCausalLM`` in-place into the export topology. - - Binds the winml-owned export behaviour from :mod:`.qwen3_modeling` onto each - Qwen3 submodule (runs ``prepare_for_onnx_export`` and rebinds ``forward``). - After this call, ``causal_lm.model(inputs_embeds, past_key_values, - past_seq_len, total_seq_len)`` runs the transformer-only forward. - - Args: - causal_lm: A ``transformers.Qwen3ForCausalLM`` instance. - matmul_to_conv: Swap ``nn.Linear`` projections to 1x1 ``Conv2d`` so - QNN sees them as Conv. 
- """ - from .qwen3_modeling import ( - WinMLQwen3Attention, - WinMLQwen3DecoderLayer, - WinMLQwen3MLP, - WinMLQwen3Model, - WinMLQwen3RMSNorm, - ) - - def _bind(module: nn.Module, owner: type) -> None: - module.forward = owner.forward.__get__(module, type(module)) - - # Identify Qwen3 submodules by their (stock HF) class name so we don't - # depend on importing ``transformers.models.qwen3`` here. - def _is(module: nn.Module, name: str) -> bool: - return type(module).__name__ == name - - # Patch every RMSNorm first (Qwen3RMSNorm appears at top, in q_norm/k_norm, - # in input/post_attention layernorms). - for mod in causal_lm.modules(): - if _is(mod, "Qwen3RMSNorm"): - WinMLQwen3RMSNorm.prepare_for_onnx_export(mod) - _bind(mod, WinMLQwen3RMSNorm) - - for mod in causal_lm.modules(): - if _is(mod, "Qwen3Attention"): - WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) - _bind(mod, WinMLQwen3Attention) - elif _is(mod, "Qwen3MLP"): - # MLP forward is unchanged; only the projections are swapped to Conv. - WinMLQwen3MLP.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) - - # HF moved ``rotary_emb`` from ``Qwen3Attention`` up to ``Qwen3Model``; - # the export forward invokes ``self.rotary_emb`` on the attention module, - # so re-attach a reference from the parent model. 
- for mod in causal_lm.modules(): - if _is(mod, "Qwen3Model") and hasattr(mod, "rotary_emb"): - for layer in mod.layers: - layer.self_attn.rotary_emb = mod.rotary_emb - - for mod in causal_lm.modules(): - if _is(mod, "Qwen3DecoderLayer"): - _bind(mod, WinMLQwen3DecoderLayer) - - for mod in causal_lm.modules(): - if _is(mod, "Qwen3Model"): - WinMLQwen3Model.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) - _bind(mod, WinMLQwen3Model) - - __all__ = [ "GroupQueryAttentionOnnxExport", "LpNormOnnxExport", "TransposeConv2d1x1Transpose", - "apply_transformer_only_export_prep", ] diff --git a/src/winml/modelkit/models/hf/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3_modeling.py index 05a70adfe..d3c538df5 100644 --- a/src/winml/modelkit/models/hf/qwen3_modeling.py +++ b/src/winml/modelkit/models/hf/qwen3_modeling.py @@ -18,7 +18,7 @@ - ``WinMLQwen3DecoderLayer`` / ``WinMLQwen3Model`` -> transformer-only forward that threads the KV cache + seq-len tensors and omits embeddings / lm_head. -``apply_transformer_only_export_prep`` (in ``qwen3_export_ops``) walks a loaded +``apply_transformer_only_export_prep`` (defined below) walks a loaded ``Qwen3ForCausalLM``, calls ``prepare_for_onnx_export`` on each submodule, and binds the matching ``forward`` from these classes onto it. """ @@ -42,15 +42,14 @@ class WinMLQwen3RMSNorm(nn.Module): def prepare_for_onnx_export(self) -> None: # Pre-multiply the gain into the weight (LpNorm has unit gain). + # ``scale`` is shape ``[1]`` and broadcasts over ``self.weight`` + # (shape ``[hidden_size]``), so the result keeps the per-channel + # shape even when the original weights are all ones. 
n = self.weight.numel() scale = torch.sqrt( torch.tensor([n], device=self.weight.device, dtype=self.weight.dtype) ) - if torch.any(self.weight.data != torch.ones_like(self.weight)).item(): - new_w = scale * self.weight - else: - new_w = scale - self.weight = nn.Parameter(new_w) + self.weight = nn.Parameter(scale * self.weight) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: out = LpNormOnnxExport.apply(hidden_states, -1, 2) @@ -228,10 +227,100 @@ def forward( return hidden_states, present_kvs +# ============================================================================= +# Apply export prep: bind winml Qwen3 export methods onto a loaded model +# ============================================================================= + + +def apply_transformer_only_export_prep( + causal_lm: nn.Module, *, matmul_to_conv: bool = True +) -> None: + """Mutate ``Qwen3ForCausalLM`` in-place into the export topology. + + Binds the winml-owned export behaviour (the ``WinMLQwen3*`` classes in this + module) onto each Qwen3 submodule (runs ``prepare_for_onnx_export`` and + rebinds ``forward``). After this call, ``causal_lm.model(inputs_embeds, + past_key_values, past_seq_len, total_seq_len)`` runs the transformer-only + forward. + + Args: + causal_lm: A ``transformers.Qwen3ForCausalLM`` instance. + matmul_to_conv: Swap ``nn.Linear`` projections to 1x1 ``Conv2d`` so + QNN sees them as Conv. + + Raises: + RuntimeError: If any expected Qwen3 submodule class is not found, + meaning the loaded model does not match the expected topology + (e.g. the stock HF class names changed). + """ + + def _bind(module: nn.Module, owner: type) -> None: + module.forward = owner.forward.__get__(module, type(module)) + + # Identify Qwen3 submodules by their (stock HF) class name so we don't + # depend on importing ``transformers.models.qwen3`` here. 
+ def _is(module: nn.Module, name: str) -> bool: + return type(module).__name__ == name + + patched = { + "Qwen3RMSNorm": 0, + "Qwen3Attention": 0, + "Qwen3MLP": 0, + "Qwen3DecoderLayer": 0, + "Qwen3Model": 0, + } + + # Patch every RMSNorm first (Qwen3RMSNorm appears at top, in q_norm/k_norm, + # in input/post_attention layernorms). + for mod in causal_lm.modules(): + if _is(mod, "Qwen3RMSNorm"): + WinMLQwen3RMSNorm.prepare_for_onnx_export(mod) + _bind(mod, WinMLQwen3RMSNorm) + patched["Qwen3RMSNorm"] += 1 + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Attention"): + WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + _bind(mod, WinMLQwen3Attention) + patched["Qwen3Attention"] += 1 + elif _is(mod, "Qwen3MLP"): + # MLP forward is unchanged; only the projections are swapped to Conv. + WinMLQwen3MLP.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + patched["Qwen3MLP"] += 1 + + # HF moved ``rotary_emb`` from ``Qwen3Attention`` up to ``Qwen3Model``; + # the export forward invokes ``self.rotary_emb`` on the attention module, + # so re-attach a reference from the parent model. + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Model") and hasattr(mod, "rotary_emb"): + for layer in mod.layers: + layer.self_attn.rotary_emb = mod.rotary_emb + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3DecoderLayer"): + _bind(mod, WinMLQwen3DecoderLayer) + patched["Qwen3DecoderLayer"] += 1 + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Model"): + WinMLQwen3Model.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + _bind(mod, WinMLQwen3Model) + patched["Qwen3Model"] += 1 + + missing = [name for name, count in patched.items() if count == 0] + if missing: + raise RuntimeError( + "transformer-only export prep found no " + f"{missing} submodule(s) to patch; the loaded model does not match " + "the expected Qwen3 topology (stock HF class names may have changed)." 
+ ) + + __all__ = [ "WinMLQwen3Attention", "WinMLQwen3DecoderLayer", "WinMLQwen3MLP", "WinMLQwen3Model", "WinMLQwen3RMSNorm", + "apply_transformer_only_export_prep", ] diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py index 614267df4..6ac9d0852 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -45,7 +45,7 @@ from ..winml.composite_model import register_composite_model from ..winml.decoder_only import WinMLDecoderOnlyModel from ..winml.kv_cache import WinMLSlidingWindowCache -from .qwen3_export_ops import apply_transformer_only_export_prep +from .qwen3_modeling import apply_transformer_only_export_prep logger = logging.getLogger(__name__) diff --git a/test_qwen.py b/test_qwen.py index da23f4481..14cf4656d 100644 --- a/test_qwen.py +++ b/test_qwen.py @@ -17,7 +17,7 @@ Run:: - python test_qwen_transformer_only.py + python test_qwen.py This builds each transformer sub-model and then runs the w8a16 quantization on the exported transformer ONNX files (no surgery needed — @@ -85,6 +85,8 @@ def _build_one(task: str, seq_len: int) -> None: print(f"BUILD COMPLETE: task={task} seq_len={seq_len}", flush=True) sys.stdout.flush() sys.stderr.flush() + # TODO(winml-cli#836): replace the hard exit once the native QNN/ORT + # teardown segfault (0xC0000005) on interpreter shutdown is fixed upstream. os._exit(0) @@ -155,6 +157,8 @@ def _run_quant() -> None: print("QUANT COMPLETE", flush=True) sys.stdout.flush() sys.stderr.flush() + # TODO(winml-cli#836): replace the hard exit once the native QNN/ORT + # teardown segfault (0xC0000005) on interpreter shutdown is fixed upstream. 
os._exit(0) @@ -173,6 +177,7 @@ def main() -> None: [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), "--build-sub", task, str(seq_len)], cwd=str(_repo_root), + timeout=1800, ).returncode after = _latest_ctx_mtime(prefix) @@ -205,6 +210,7 @@ def main() -> None: rc = subprocess.run( [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), "--quant"], cwd=str(_repo_root), + timeout=1800, ).returncode after_files = list(ARTIFACTS_DIR.glob("*quant.onnx")) after = max((p.stat().st_mtime for p in after_files), default=0.0) From 818cfe47fbf30381c67cc7a7fefb99dd04edc509 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 11:53:25 +0800 Subject: [PATCH 07/13] refactor(qwen): config-driven transformer-only quant + pytest Replace the standalone root-level quant driver and __main__/subprocess test runner with the regular build pipeline and pytest. - Move calibration logic into src/.../hf/qwen_transformer_only_quant.py; the decode wrapper exposes winml_finalize_quant_config, invoked generically from build/hf.py just before quantize_onnx. The build now quantizes via precision=w8a16 + config.quant instead of a separate script. - The hook reads seq_len / max_cache / GQA node names from the exported ONNX and selects the prefill vs decode-trajectory calibration reader, keeping the verified-good scheme (int8-symmetric weights, uint16 activations, minmax, GQA excluded from QDQ). - Delete root qwen3_transformer_only_quantize.py and test_qwen.py. - Add tests/unit/models/qwen_transformer_only (fast, offline) and tests/e2e/models/test_qwen3_transformer_only_quant.py (build+quant+decode-parity, QNN-gated NPU). 
--- src/winml/modelkit/build/hf.py | 8 + .../models/hf/qwen_transformer_only.py | 92 ++++- .../models/hf/qwen_transformer_only_quant.py | 387 +++++++++--------- test_qwen.py | 235 ----------- .../test_qwen3_transformer_only_quant.py | 248 +++++++++++ .../models/qwen_transformer_only/__init__.py | 4 + .../test_quant_calibration.py | 234 +++++++++++ 7 files changed, 753 insertions(+), 455 deletions(-) rename qwen3_transformer_only_quantize.py => src/winml/modelkit/models/hf/qwen_transformer_only_quant.py (60%) delete mode 100644 test_qwen.py create mode 100644 tests/e2e/models/test_qwen3_transformer_only_quant.py create mode 100644 tests/unit/models/qwen_transformer_only/__init__.py create mode 100644 tests/unit/models/qwen_transformer_only/test_quant_calibration.py diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index dc2661afa..4dcf09b5b 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -310,6 +310,14 @@ def _name(base: str) -> str: else: logger.info("Quantizing model...") t0 = time.monotonic() + # Some model wrappers can only finalize their quant config once the + # exported ONNX exists (e.g. calibration feeds / nodes-to-exclude + # derived from the graph). Give the wrapper a chance to populate + # those runtime-only fields here. 
+ if pytorch_model is not None and hasattr(pytorch_model, "winml_finalize_quant_config"): + config.quant = pytorch_model.winml_finalize_quant_config( + config.quant, onnx_path=current_path, model_id=model_id + ) quant_result = quantize_onnx( model_path=current_path, output_path=quantized_path, diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py index 6ac9d0852..fda69495f 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -85,7 +85,10 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: self.config.model_type = TRANSFORMER_ONLY_MODEL_TYPE @classmethod - def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenTransformerOnlyDecoderWrapper: + def from_pretrained( + cls, model_name_or_path: str, **kwargs: Any + ) -> QwenTransformerOnlyDecoderWrapper: + """Load the HF model and wrap it for transformer-only export.""" kwargs.setdefault("torch_dtype", torch.float32) model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **kwargs) model.config._attn_implementation = "eager" @@ -94,24 +97,23 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenTransfor return wrapper def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, ...]: + """Flatten the dummy-input dict into positional export args.""" return tuple(inputs.values()) def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: - """Positional inputs (order matches OnnxConfig.inputs): + """Run the decoder stack on positional inputs (order matches OnnxConfig.inputs). - past_keys_0, past_values_0, ..., past_keys_{L-1}, past_values_{L-1}, - input_hidden_states, past_seq_len, total_seq_len - - Returns ``(output_hidden_states, present_keys_0, present_values_0, ...)``. 
+ Positional inputs are ``past_keys_0, past_values_0, ..., + past_keys_{L-1}, past_values_{L-1}, input_hidden_states, past_seq_len, + total_seq_len``. Returns ``(output_hidden_states, present_keys_0, + present_values_0, ...)``. """ kv_args = args[: 2 * self.num_layers] input_hidden_states = args[2 * self.num_layers] past_seq_len = args[2 * self.num_layers + 1] total_seq_len = args[2 * self.num_layers + 2] - past_key_values = [ - (kv_args[2 * i], kv_args[2 * i + 1]) for i in range(self.num_layers) - ] + past_key_values = [(kv_args[2 * i], kv_args[2 * i + 1]) for i in range(self.num_layers)] hidden_states, present_kvs = self.model.model( inputs_embeds=input_hidden_states, @@ -126,6 +128,27 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: out.extend([k, v]) return tuple(out) + def winml_finalize_quant_config( + self, quant: Any, *, onnx_path: Any, model_id: str | None = None + ) -> Any: + """Build-pipeline hook: attach the calibration reader + GQA exclusions. + + Called by ``build_hf_model`` just before ``quantize_onnx`` (see + ``build/hf.py``). The exported transformer-only graph determines the + calibration feeds (shapes, KV buffers) and which GroupQueryAttention + nodes stay in float, so the live :class:`WinMLQuantizationConfig` can + only be finalized here — not at config-construction time. 
+ """ + from .qwen_transformer_only_quant import ( + DEFAULT_MODEL_ID, + finalize_transformer_only_quant_config, + ) + + resolved_id = model_id or getattr(self.config, "_name_or_path", None) or DEFAULT_MODEL_ID + return finalize_transformer_only_quant_config( + quant, onnx_path=onnx_path, model_id=resolved_id + ) + # ============================================================================= # Dummy input generators (transformer-only I/O) @@ -151,7 +174,13 @@ def __init__( self.hidden_size = normalized_config.hidden_size self.seq_len = seq_len or getattr(normalized_config, "seq_len", self._default_seq_len) - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32") -> torch.Tensor: # noqa: ARG002 + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ) -> torch.Tensor: if input_name == "input_hidden_states": return torch.randn(self.batch_size, self.seq_len, self.hidden_size, dtype=torch.float32) raise ValueError(f"Unknown input: {input_name}") @@ -166,10 +195,16 @@ class _TransformerOnlySeqLenGenerator(DummyInputGenerator): SUPPORTED_INPUT_NAMES = ("past_seq_len", "total_seq_len") - def __init__(self, task: str, normalized_config: Any, **kwargs: Any) -> None: # noqa: ARG002 + def __init__(self, task: str, normalized_config: Any, **kwargs: Any) -> None: self.max_cache_len = normalized_config.max_cache_len - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32") -> torch.Tensor: # noqa: ARG002 + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ) -> torch.Tensor: if input_name == "past_seq_len": return torch.zeros((1, 1), dtype=torch.int32) if input_name == "total_seq_len": @@ -192,14 +227,22 @@ def __init__( ) -> None: self.batch_size = batch_size self.num_layers: int = normalized_config.num_layers - 
self.num_heads: int = normalized_config.num_attention_heads # KV heads (NormalizedConfig maps it) + self.num_heads: int = ( + normalized_config.num_attention_heads + ) # KV heads (NormalizedConfig maps it) self.head_dim: int = normalized_config.head_dim self.max_cache_len: int = max_cache_len or normalized_config.max_cache_len self.SUPPORTED_INPUT_NAMES = tuple( name for i in range(self.num_layers) for name in (f"past_keys_{i}", f"past_values_{i}") ) - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32") -> torch.Tensor: # noqa: ARG002 + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ) -> torch.Tensor: shape = (self.batch_size, self.num_heads, self.max_cache_len, self.head_dim) return torch.zeros(shape, dtype=torch.float16) @@ -220,7 +263,9 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int ) -def _transformer_only_inputs(num_layers: int, kv_seq_axis: str = "max_seq_len") -> dict[str, dict[int, str]]: +def _transformer_only_inputs( + num_layers: int, kv_seq_axis: str = "max_seq_len" +) -> dict[str, dict[int, str]]: """Input ordering: past KV pairs, then hidden states, then seq lens.""" result: dict[str, dict[int, str]] = {} for i in range(num_layers): @@ -232,7 +277,9 @@ def _transformer_only_inputs(num_layers: int, kv_seq_axis: str = "max_seq_len") return result -def _transformer_only_outputs(num_layers: int, kv_seq_axis: str = "max_seq_len") -> dict[str, dict[int, str]]: +def _transformer_only_outputs( + num_layers: int, kv_seq_axis: str = "max_seq_len" +) -> dict[str, dict[int, str]]: result: dict[str, dict[int, str]] = {"output_hidden_states": {1: "seq_len"}} for i in range(num_layers): result[f"present_keys_{i}"] = {2: kv_seq_axis} @@ -255,10 +302,12 @@ class QwenTransformerOnlyPrefillIOConfig(OnnxConfig): @property def inputs(self) -> dict[str, dict[int, str]]: + """ONNX input axes (past KV 
pairs, hidden states, seq lengths).""" return _transformer_only_inputs(self._normalized_config.num_layers) @property def outputs(self) -> dict[str, dict[int, str]]: + """ONNX output axes (hidden states then present KV pairs).""" return _transformer_only_outputs(self._normalized_config.num_layers) @@ -277,10 +326,12 @@ class QwenTransformerOnlyGenIOConfig(OnnxConfig): @property def inputs(self) -> dict[str, dict[int, str]]: + """ONNX input axes (past KV pairs, hidden states, seq lengths).""" return _transformer_only_inputs(self._normalized_config.num_layers) @property def outputs(self) -> dict[str, dict[int, str]]: + """ONNX output axes (hidden states then present KV pairs).""" return _transformer_only_outputs(self._normalized_config.num_layers) @@ -320,6 +371,7 @@ class WinMLQwen3TransformerOnlyModel(WinMLDecoderOnlyModel): @classmethod def get_cache_class(cls) -> type: + """Return the KV-cache class used during generation.""" return WinMLSlidingWindowCache @@ -336,8 +388,12 @@ def get_cache_class(cls) -> type: } # Inference specialization (GenericTask — the wrapper returns raw hidden states / KV). 
-register_specialization(TRANSFORMER_ONLY_MODEL_TYPE, "feature-extraction", "WinMLModelForGenericTask") -register_specialization(TRANSFORMER_ONLY_MODEL_TYPE, "text2text-generation", "WinMLModelForGenericTask") +register_specialization( + TRANSFORMER_ONLY_MODEL_TYPE, "feature-extraction", "WinMLModelForGenericTask" +) +register_specialization( + TRANSFORMER_ONLY_MODEL_TYPE, "text2text-generation", "WinMLModelForGenericTask" +) __all__ = [ diff --git a/qwen3_transformer_only_quantize.py b/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py similarity index 60% rename from qwen3_transformer_only_quantize.py rename to src/winml/modelkit/models/hf/qwen_transformer_only_quant.py index 559620973..f01de2f71 100644 --- a/qwen3_transformer_only_quantize.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py @@ -1,37 +1,58 @@ -"""Transformer-only w8a16 quantization for Qwen3. - -Targets the transformer-only ONNX produced by the -``qwen3_transformer_only`` build variant (see ``test_qwen.py``): - - - **No embedding/lm_head surgery.** The export already excludes both, - so we feed ``WinMLQuantization`` the file directly. - - **Transformer-shaped calibration feeds.** ``input_hidden_states`` (FP32), - ``past_seq_len`` / ``total_seq_len`` (INT32), ``past_keys_{i}`` / - ``past_values_{i}`` (FP16) — names + dtypes match the exported graph. - -Run via ``test_qwen.py``. +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Config-driven w8a16 calibration for the transformer-only Qwen3 build. + +The transformer-only export (:mod:`qwen_transformer_only`) emits a graph whose +only quantization-relevant runtime inputs (the calibration feeds and the +``GroupQueryAttention`` node names to keep in float) can't be known until the +ONNX exists. 
Rather than a standalone post-build script that reaches into +``composite.sub_models[...]._onnx_path``, this module plugs into the normal +build pipeline: :meth:`QwenTransformerOnlyDecoderWrapper.winml_finalize_quant_config` +calls :func:`finalize_transformer_only_quant_config` just before +``quantize_onnx`` runs (see ``build/hf.py``), populating the live +:class:`WinMLQuantizationConfig` with the right +:class:`~winml.modelkit.quant.config.CalibrationDataReader` and +``nodes_to_exclude``. + +The two readers match the exported graph exactly: + + - ``input_hidden_states`` (FP32), ``past_seq_len`` / ``total_seq_len`` + (INT32), ``past_keys_{i}`` / ``past_values_{i}`` (FP16, full cache buffer). + - The prefill reader (``seq_len > 1``) embeds real prompt prefixes. + - The decode reader (``seq_len == 1``) drives a fresh FP reference model + through a real prefill + decode trajectory so MinMax sees representative + mid-generation activation ranges (a single repeated token + zeroed KV + collapses the ranges and degenerates generation). + +The export wrapper surgically replaces its own ``self.model`` (RMSNorm -> +LpNorm-identity, attention -> GQA placeholder, Linear -> 1x1 Conv), so it can't +run real inference; calibration loads a *fresh* ``AutoModelForCausalLM``. 
""" from __future__ import annotations -import logging import gc +import logging from pathlib import Path -from typing import Any, Iterator +from typing import TYPE_CHECKING, Any import numpy as np import torch from transformers import AutoModelForCausalLM, AutoTokenizer -from winml.modelkit.models.winml.composite_model import WinMLCompositeModel -from winml.modelkit.quant import WinMLQuantizationConfig, quantize_onnx -from winml.modelkit.quant.config import CalibrationDataReader +from ...quant.config import CalibrationDataReader, WinMLQuantizationConfig + + +if TYPE_CHECKING: + from collections.abc import Iterator logger = logging.getLogger(__name__) DEFAULT_MODEL_ID = "Qwen/Qwen3-0.6B" -DEFAULT_MAX_CACHE = 256 DEFAULT_PREFILL_SEQ = 64 DEFAULT_GEN_SEQ = 1 DEFAULT_NUM_SAMPLES = 30 @@ -41,16 +62,6 @@ DEFAULT_CALIB_SPLIT = "train" DEFAULT_CALIB_SEED = 42 -# Map an ONNX quantization dtype to the bit-width suffix used in artifact -# filenames (e.g. int8 -> "8", uint16 -> "16"), instead of brittle string -# slicing of the dtype name. -_DTYPE_BITS = { - "int8": "8", - "uint8": "8", - "int16": "16", - "uint16": "16", -} - def _load_gsm8k_prompts(num_samples: int) -> list[str]: """GSM8K train split, shuffled seed=42 for reproducible calibration.""" @@ -61,8 +72,79 @@ def _load_gsm8k_prompts(num_samples: int) -> list[str]: return [row["question"] for row in split.select(range(num_samples))] +def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> list[torch.Tensor]: + out: list[torch.Tensor] = [] + for i in range(num_samples): + prompt = prompts[i % len(prompts)] + text = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ids = tokenizer([text], return_tensors="pt").input_ids + out.append(ids) + return out + + +def _gqa_node_names(onnx_path: Path) -> list[str]: + """Return the names of every GroupQueryAttention node in ``onnx_path``. 
+ + These nodes are excluded from quantization so ORT leaves both their + inputs and output in float (``... -> Cast -> GQA -> Cast``), matching + the reference graph which keeps attention entirely out of QDQ. + """ + import onnx + + model = onnx.load(str(onnx_path), load_external_data=False) + return [n.name for n in model.graph.node if n.op_type == "GroupQueryAttention" and n.name] + + +def _graph_shapes(onnx_path: Path) -> tuple[int, int]: + """Read ``(seq_len, max_cache_len)`` from the exported graph's static inputs. + + ``seq_len`` is the query length (``input_hidden_states`` dim 1) and + ``max_cache_len`` is the KV buffer length (``past_keys_0`` dim 2). The + transformer-only export keeps both axes static, so these fully determine + whether the sub-model is prefill (``seq_len > 1``) or decode (``seq_len == 1``) + and the size of the fixed KV buffers the calibration feeds must match. + """ + import onnx + + model = onnx.load(str(onnx_path), load_external_data=False) + seq_len: int | None = None + max_cache_len: int | None = None + for inp in model.graph.input: + dims = inp.type.tensor_type.shape.dim + if inp.name == "input_hidden_states" and len(dims) >= 2: + seq_len = dims[1].dim_value + elif inp.name == "past_keys_0" and len(dims) >= 3: + max_cache_len = dims[2].dim_value + if seq_len is None or max_cache_len is None: + raise ValueError( + f"Could not read seq_len/max_cache_len from {onnx_path.name}; " + f"found seq_len={seq_len}, max_cache_len={max_cache_len}" + ) + return seq_len, max_cache_len + + +def _layer_kv(past: Any, i: int) -> tuple[torch.Tensor, torch.Tensor]: + """Extract layer ``i``'s (key, value) from an HF cache, version-agnostic. + + Handles the legacy tuple-of-tuples cache, the older ``DynamicCache`` + (``.key_cache`` / ``.value_cache``), and the newer per-layer + ``DynamicCache`` (``.layers[i].keys`` / ``.values``). 
+ """ + if hasattr(past, "key_cache") and hasattr(past, "value_cache"): + return past.key_cache[i], past.value_cache[i] + if hasattr(past, "layers"): + layer = past.layers[i] + return layer.keys, layer.values + return past[i][0], past[i][1] + + class Qwen3TransformerOnlyCalibReader(CalibrationDataReader): - """Yields calibration feeds for the transformer-only ONNX. + """Prefill calibration feeds for the transformer-only ONNX. Feeds match the exported graph exactly: ``input_hidden_states`` (FP32), ``past_seq_len`` (INT32 ``[1,1]``), ``total_seq_len`` (INT32 ``[1]``), @@ -95,9 +177,7 @@ def _build_samples(self, token_ids_list: list[torch.Tensor]) -> Iterator[dict[st ids = ids[:, : self.seq_len] real_len = ids.shape[1] if real_len < self.seq_len: - pad = torch.zeros( - (1, self.seq_len - real_len), dtype=ids.dtype, device=ids.device - ) + pad = torch.zeros((1, self.seq_len - real_len), dtype=ids.dtype, device=ids.device) ids = torch.cat([ids, pad], dim=1) with torch.no_grad(): @@ -107,10 +187,10 @@ def _build_samples(self, token_ids_list: list[torch.Tensor]) -> Iterator[dict[st "input_hidden_states": embeds.astype(np.float32), # seqlens_k for GQA = (valid context length - 1), i.e. # ``embeddings.shape[1] - 1``. We pad to seq_len, so the query - # has seq_len valid positions → past_seq_len = seq_len - 1. + # has seq_len valid positions -> past_seq_len = seq_len - 1. # (Using 0 here declares only 1 valid token while feeding a # seq_len-token query, which makes the GQA prefill kernel read - # out of bounds → native access violation.) + # out of bounds -> native access violation.) 
"past_seq_len": np.array([[self.seq_len - 1]], dtype=np.int32), "total_seq_len": np.array([self.max_cache_len], dtype=np.int32), } @@ -122,30 +202,17 @@ def _build_samples(self, token_ids_list: list[torch.Tensor]) -> Iterator[dict[st yield feed def get_next(self) -> dict[str, np.ndarray] | None: + """Return the next calibration feed, or None when exhausted.""" try: return next(self._iter) if self._iter is not None else None except StopIteration: return None def rewind(self) -> None: + """Reset the iterator so calibration can run another pass.""" self._iter = iter(self._samples) -def _layer_kv(past: Any, i: int) -> tuple[torch.Tensor, torch.Tensor]: - """Extract layer ``i``'s (key, value) from an HF cache, version-agnostic. - - Handles the legacy tuple-of-tuples cache, the older ``DynamicCache`` - (``.key_cache`` / ``.value_cache``), and the newer per-layer - ``DynamicCache`` (``.layers[i].keys`` / ``.values``). - """ - if hasattr(past, "key_cache") and hasattr(past, "value_cache"): - return past.key_cache[i], past.value_cache[i] - if hasattr(past, "layers"): - layer = past.layers[i] - return layer.keys, layer.values - return past[i][0], past[i][1] - - class Qwen3DecodeTrajectoryCalibReader(CalibrationDataReader): """Calibrate the iter (seq_len=1) model on REAL decode-step states. 
@@ -174,7 +241,7 @@ def __init__( *, prefill_seq: int, max_cache_len: int, - decode_steps: int = 16, + decode_steps: int = DEFAULT_DECODE_STEPS, ) -> None: self.num_layers = config.num_hidden_layers self.num_kv_heads = config.num_key_value_heads @@ -199,9 +266,7 @@ def _kv_buffers(self, past: Any, cur_len: int) -> dict[str, np.ndarray]: feed: dict[str, np.ndarray] = {} for i in range(self.num_layers): k, v = _layer_kv(past, i) - kbuf = np.zeros( - (1, self.num_kv_heads, self.max_cache_len, self.head_dim), np.float16 - ) + kbuf = np.zeros((1, self.num_kv_heads, self.max_cache_len, self.head_dim), np.float16) vbuf = np.zeros_like(kbuf) kbuf[:, :, :cur_len, :] = k[:, :, :cur_len, :].to(torch.float16).cpu().numpy() vbuf[:, :, :cur_len, :] = v[:, :, :cur_len, :].to(torch.float16).cpu().numpy() @@ -256,179 +321,97 @@ def _build_samples( cur_len += 1 def get_next(self) -> dict[str, np.ndarray] | None: + """Return the next calibration feed, or None when exhausted.""" try: return next(self._iter) if self._iter is not None else None except StopIteration: return None def rewind(self) -> None: + """Reset the iterator so calibration can run another pass.""" self._iter = iter(self._samples) -def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> list[torch.Tensor]: - out: list[torch.Tensor] = [] - for i in range(num_samples): - prompt = prompts[i % len(prompts)] - text = tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], - tokenize=False, - add_generation_prompt=True, - enable_thinking=False, - ) - ids = tokenizer([text], return_tensors="pt").input_ids - out.append(ids) - return out - - -def _gqa_node_names(onnx_path: Path) -> list[str]: - """Return the names of every GroupQueryAttention node in ``onnx_path``. - - These nodes are excluded from quantization so ORT leaves both their - inputs and output in float (``... -> Cast -> GQA -> Cast``), matching - the reference graph which keeps attention entirely out of QDQ. 
- """ - import onnx - - model = onnx.load(str(onnx_path), load_external_data=False) - return [ - n.name - for n in model.graph.node - if n.op_type == "GroupQueryAttention" and n.name - ] - - -def quantize_built_model( - model: WinMLCompositeModel, +def finalize_transformer_only_quant_config( + quant: WinMLQuantizationConfig, *, + onnx_path: Path, model_id: str = DEFAULT_MODEL_ID, - max_cache_len: int = DEFAULT_MAX_CACHE, prefill_seq: int = DEFAULT_PREFILL_SEQ, - num_samples: int = DEFAULT_NUM_SAMPLES, - weight_type: str = "int8", - activation_type: str = "uint16", decode_steps: int = DEFAULT_DECODE_STEPS, -) -> dict[str, Path]: - """Quantize the transformer-only ONNX files in-place. - - Returns ``{sub_model_name: quantized_path}``. +) -> WinMLQuantizationConfig: + """Populate ``quant`` with the transformer-only w8a16 scheme + runtime fields. + + The build pipeline's device/precision policy only enables quantization and + picks generic dtypes; the transformer-only scheme is fixed and reference- + matched, so this hook is authoritative: + + - **int8-symmetric weights** (zp=0) + **uint16 asymmetric activations**, + - **MinMax** calibration, + - GroupQueryAttention nodes excluded from QDQ (read from the graph), + - the matching :class:`CalibrationDataReader` (prefill vs. decode-trajectory, + chosen by the graph's ``seq_len``). + + Reads static shapes + GQA nodes from ``onnx_path`` and loads a fresh FP + reference model for calibration (the export wrapper's own weights are + surgically replaced and can't run real inference). """ - # Locate the un-compiled ONNX for each sub-model (no surgery — file is - # already transformer-only). 
- sub_paths: dict[str, Path] = {} - for name, sub in model.sub_models.items(): - final_path = Path(sub._onnx_path) - if final_path.name.endswith("_model.onnx"): - stem = final_path.name[: -len("_model.onnx")] - optimized = final_path.with_name(f"{stem}_optimized.onnx") - if optimized.exists(): - sub_paths[name] = optimized - continue - print( - f"WARNING: {optimized.name} not found next to {final_path.name}; " - "falling back to the compiled model." - ) - sub_paths[name] = final_path - - for name, p in sub_paths.items(): - print(f" {name}: {p}") + onnx_path = Path(onnx_path) + seq_len, max_cache_len = _graph_shapes(onnx_path) + gqa_nodes = _gqa_node_names(onnx_path) + + # Fixed, reference-matched w8a16 scheme (authoritative over policy dtypes). + quant.weight_type = "int8" + quant.activation_type = "uint16" + quant.weight_symmetric = True + quant.activation_symmetric = False + quant.calibration_method = "minmax" + num_samples = quant.samples or DEFAULT_NUM_SAMPLES + + logger.info( + "Finalizing transformer-only quant config for %s " + "(seq_len=%d, max_cache_len=%d, %d GQA nodes excluded, %d samples)", + onnx_path.name, + seq_len, + max_cache_len, + len(gqa_nodes), + num_samples, + ) - print("\n=== Loading HF embed_tokens for calibration ===") - hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32) + hf_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32) hf_model.eval() embed_tokens = hf_model.get_input_embeddings() tokenizer = AutoTokenizer.from_pretrained(model_id) - - print( - f"=== Loading {num_samples} GSM8K calibration prompts " - f"({DEFAULT_CALIB_DATASET}/{DEFAULT_CALIB_DATASET_CONFIG}, " - f"split={DEFAULT_CALIB_SPLIT}, seed={DEFAULT_CALIB_SEED}) ===" - ) prompts = _load_gsm8k_prompts(num_samples) token_ids_list = _tokenize_prompts(tokenizer, prompts, num_samples) - seq_by_sub = { - "decoder_prefill": prefill_seq, - "decoder_gen": DEFAULT_GEN_SEQ, - } - - quant_paths: dict[str, Path] = {} - for 
sub_name, fused_path in sub_paths.items(): - if sub_name not in seq_by_sub: - print(f"\n--- Skipping unknown sub-model {sub_name!r} ---") - continue - - seq_len = seq_by_sub[sub_name] - quant_path = fused_path.with_name( - fused_path.stem - + f"_w{_DTYPE_BITS[weight_type]}a{_DTYPE_BITS[activation_type]}.quant.onnx" + reader: CalibrationDataReader + if seq_len == DEFAULT_GEN_SEQ: + # Decode sub-model: calibrate on a real prefill+decode trajectory. + reader = Qwen3DecodeTrajectoryCalibReader( + hf_model, + embed_tokens, + hf_model.config, + token_ids_list, + prefill_seq=prefill_seq, + max_cache_len=max_cache_len, + decode_steps=decode_steps, ) - - print(f"\n=== Quantize (transformer-only): {sub_name} (seq_len={seq_len}) ===") - print(f" in : {fused_path}") - print(f" out: {quant_path}") - gqa_nodes = _gqa_node_names(fused_path) - print( - f" excluding {len(gqa_nodes)} GroupQueryAttention nodes from " - "quantization (inputs + output stay float, Cast -> GQA -> Cast)" + else: + reader = Qwen3TransformerOnlyCalibReader( + embed_tokens, + hf_model.config, + token_ids_list, + seq_len=seq_len, + max_cache_len=max_cache_len, ) - if sub_name == "decoder_gen": - # The iter model only sees mid-generation states. Calibrate it on a - # real prefill+decode trajectory (true tokens, accumulated KV, - # growing past_seq_len) instead of one token + zeroed KV, which - # would under-range the MinMax activation scales and collapse - # generation. 
- print( - f" calibrating on decode trajectory ({decode_steps} steps/prompt, " - f"prefill_seq={prefill_seq})" - ) - reader: CalibrationDataReader = Qwen3DecodeTrajectoryCalibReader( - hf_model, - embed_tokens, - hf_model.config, - token_ids_list, - prefill_seq=prefill_seq, - max_cache_len=max_cache_len, - decode_steps=decode_steps, - ) - else: - reader = Qwen3TransformerOnlyCalibReader( - embed_tokens, - hf_model.config, - token_ids_list, - seq_len=seq_len, - max_cache_len=max_cache_len, - ) - cfg = WinMLQuantizationConfig( - samples=num_samples, - weight_type=weight_type, # type: ignore[arg-type] - activation_type=activation_type, # type: ignore[arg-type] - calibration_method="minmax", - calibration_data=reader, - # w8a16: symmetric int8 weights (zp=0) + asymmetric uint16 - # activations, matching the reference quantization. - weight_symmetric=True, - activation_symmetric=False, - # ORT treats GroupQueryAttention as quantizable and wraps both its - # inputs and output in QDQ. The reference keeps attention entirely - # in float (Cast -> GQA -> Cast), so exclude the GQA nodes from - # quantization so no QDQ is inserted around them. - nodes_to_exclude=gqa_nodes, - ) - result = quantize_onnx(fused_path, output_path=quant_path, config=cfg) - if not result.success: - print(" FAILED:") - for err in result.errors: - print(f" {err}") - raise SystemExit(1) - print( - f" ok — {result.nodes_quantized} QDQ nodes inserted in " - f"{result.total_time_seconds:.1f}s" - ) - quant_paths[sub_name] = quant_path - # Free the FP reference model now that calibration is done. + quant.calibration_data = reader + quant.nodes_to_exclude = gqa_nodes + + # Readers materialize all samples eagerly, so the FP reference is no longer + # needed once they're built. 
del hf_model, embed_tokens gc.collect() - print("\n=== Done ===") - return quant_paths + return quant diff --git a/test_qwen.py b/test_qwen.py deleted file mode 100644 index 14cf4656d..000000000 --- a/test_qwen.py +++ /dev/null @@ -1,235 +0,0 @@ -"""E2E test for the transformer-only Qwen3 export path. - -Produces two transformer-only ONNX files whose I/O matches -``qwen3_gqa_fp16_ctx.onnx`` / ``qwen3_gqa_fp16_iter.onnx``: - - decoder_prefill: input_hidden_states [1, 64, 1024] → output_hidden_states + KV - decoder_gen : input_hidden_states [1, 1, 1024] → output_hidden_states + KV - -with FP16 past/present KV named ``past_keys_{i}`` / ``past_values_{i}``, -``com.microsoft::GroupQueryAttention``, ``LpNormalization``, and 1x1 Conv -projections. - -Generation (``model.generate(...)``) is NOT supported by this build path — -the inference feeds in ``WinMLDecoderOnlyModel`` still target the eager -I/O signature. Use the eager ``WinMLQwen3Model`` build path for end-to-end -generation. - -Run:: - - python test_qwen.py - -This builds each transformer sub-model and then runs the w8a16 -quantization on the exported transformer ONNX files (no surgery needed — -files are already transformer-only). -""" - -import os -import sys -import pathlib -import subprocess - -# Put the in-repo `src/` ahead of site-packages so `import winml` always -# resolves to the editable source tree — no manual copy-to-venv needed. -_repo_root = pathlib.Path(__file__).resolve().parent -sys.path.insert(0, str(_repo_root / "src")) -sys.path.insert(0, str(_repo_root)) - -model_id = "Qwen/Qwen3-0.6B" -MAX_CACHE = 256 - -# component name -> (HF task, seq_len, artifact prefix). Order matters -# (prefill first). The prefix is how the built npu_ctx file is named so the -# parent can verify success by artifact appearance (the build segfaults on -# native QNN/ORT teardown AFTER writing the file, so exit codes are unreliable). 
-SUB_MODELS = { - "decoder_prefill": ("feature-extraction", 64, "feat_"), - "decoder_gen": ("text2text-generation", 1, "txt2txt_"), -} - -ARTIFACTS_DIR = ( - pathlib.Path.home() / ".cache" / "winml" / "artifacts" / model_id.replace("/", "_") -) - - -def _latest_ctx_mtime(prefix: str) -> float: - """Newest mtime of a ``{prefix}*_optimized_npu_ctx.onnx`` artifact, or 0.""" - files = list(ARTIFACTS_DIR.glob(f"{prefix}*_optimized_npu_ctx.onnx")) - return max((f.stat().st_mtime for f in files), default=0.0) - - -def _build_one(task: str, seq_len: int) -> None: - """Build a SINGLE transformer sub-model in this (fresh) process. - - Invoked as a subprocess by ``main()`` so each sub-model exports in a - clean interpreter — building both in one process leaves PyTorch/ORT - state from the first build that corrupts/kills the second. - """ - from winml.modelkit.config import WinMLBuildConfig - from winml.modelkit.models.auto import WinMLAutoModel - - WinMLAutoModel.from_pretrained( - model_id, - task=task, - model_type="qwen3_transformer_only", - config=WinMLBuildConfig(quant=None, compile=None), - precision="fp16", - device="npu", - ep="qnn", - force_rebuild=True, - shape_config={"max_cache_len": MAX_CACHE, "seq_len": seq_len}, - ) - # The QNN/ORT teardown segfaults (0xC0000005) on interpreter shutdown - # AFTER the artifact is fully written. Skip the buggy cleanup with a hard - # exit so the parent sees a clean exit code 0. - print(f"BUILD COMPLETE: task={task} seq_len={seq_len}", flush=True) - sys.stdout.flush() - sys.stderr.flush() - # TODO(winml-cli#836): replace the hard exit once the native QNN/ORT - # teardown segfault (0xC0000005) on interpreter shutdown is fixed upstream. 
- os._exit(0) - - -def _find_optimized(prefix: str) -> pathlib.Path: - """Locate the cached transformer-only ``{prefix}*_optimized.onnx`` file.""" - cands = [ - p for p in ARTIFACTS_DIR.glob(f"{prefix}*_optimized.onnx") - if not p.name.endswith("_optimized_npu_ctx.onnx") - ] - if not cands: - raise FileNotFoundError( - f"No {prefix}*_optimized.onnx in {ARTIFACTS_DIR} — build the sub-model first." - ) - return max(cands, key=lambda p: p.stat().st_mtime) - - -class _SubShim: - """Minimal stand-in exposing the ``_onnx_path`` quant needs.""" - - def __init__(self, onnx_path: pathlib.Path): - self._onnx_path = str(onnx_path) - - -class _ModelShim: - """Minimal stand-in exposing ``sub_models`` for ``quantize_built_model``.""" - - def __init__(self, sub_models: dict): - self.sub_models = sub_models - - -def _run_quant() -> None: - """Quantize the cached transformer ONNX files (no composite/QNN load). - - Runs as its own subprocess so any ORT teardown crash can't poison the - parent. Builds a shim ``model`` whose ``sub_models[name]._onnx_path`` - point straight at the cached ``*_optimized.onnx`` files. - """ - # Dump a native C-stack if the calibration InferenceSession segfaults - # (otherwise the crash is silent — no Python traceback). 
- import faulthandler - faulthandler.enable() - - from qwen3_transformer_only_quantize import quantize_built_model - - sub_models = { - name: _SubShim(_find_optimized(prefix)) - for name, (_task, _seq, prefix) in SUB_MODELS.items() - } - model = _ModelShim(sub_models) - print("=== Running transformer w8a16 quantization ===", flush=True) - for name, sub in sub_models.items(): - print(f" {name}: {sub._onnx_path}", flush=True) - - try: - quantize_built_model( - model, - model_id=model_id, - max_cache_len=MAX_CACHE, - prefill_seq=64, - ) - except BaseException: - import traceback - print("QUANT FAILED with exception:", flush=True) - traceback.print_exc() - sys.stdout.flush() - sys.stderr.flush() - raise - print("QUANT COMPLETE", flush=True) - sys.stdout.flush() - sys.stderr.flush() - # TODO(winml-cli#836): replace the hard exit once the native QNN/ORT - # teardown segfault (0xC0000005) on interpreter shutdown is fixed upstream. - os._exit(0) - - -def main() -> None: - # 1) Build each sub-model in its OWN subprocess (fresh state each time). - # Judge success by whether a FRESH npu_ctx artifact appeared, NOT by the - # subprocess exit code: the native QNN/ORT layer segfaults (0xC0000005) - # on teardown AFTER the artifact is fully written to disk. 
- import time as _time - - for name, (task, seq_len, prefix) in SUB_MODELS.items(): - print(f"\n########## BUILD {name} (task={task}, seq_len={seq_len}) ##########", flush=True) - before = _latest_ctx_mtime(prefix) - start = _time.time() - rc = subprocess.run( - [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), - "--build-sub", task, str(seq_len)], - cwd=str(_repo_root), - timeout=1800, - ).returncode - - after = _latest_ctx_mtime(prefix) - if after > before and after >= start - 1: - status = "OK" if rc == 0 else f"OK (ignored teardown exit {rc})" - print(f"########## {name} {status}: fresh {prefix}*_optimized_npu_ctx.onnx ##########", flush=True) - else: - raise SystemExit( - f"Sub-model build failed for {name} (exit {rc}) — " - f"no fresh {prefix}*_optimized_npu_ctx.onnx in {ARTIFACTS_DIR}" - ) - - # 2) Report the built transformer-only ONNX files (no composite/QNN load — - # that creates QNN EP sessions that segfault the parent on teardown). - for name, (_task, _seq, prefix) in SUB_MODELS.items(): - print(f"\n=== {name} ===") - print(f" optimized : {_find_optimized(prefix).name}") - ctx = sorted(ARTIFACTS_DIR.glob(f"{prefix}*_optimized_npu_ctx.onnx")) - if ctx: - print(f" npu_ctx : {ctx[-1].name}") - - # 3) Quantization — run in its OWN subprocess for the same teardown-crash - # isolation. Judge by whether quant files appeared. 
- print("\n########## QUANTIZE ##########", flush=True) - before = max( - (p.stat().st_mtime for p in ARTIFACTS_DIR.glob("*quant.onnx")), - default=0.0, - ) - qstart = _time.time() - rc = subprocess.run( - [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), "--quant"], - cwd=str(_repo_root), - timeout=1800, - ).returncode - after_files = list(ARTIFACTS_DIR.glob("*quant.onnx")) - after = max((p.stat().st_mtime for p in after_files), default=0.0) - if after > before and after >= qstart - 1: - status = "OK" if rc == 0 else f"OK (ignored teardown exit {rc})" - print(f"########## QUANTIZE {status} ##########", flush=True) - for p in sorted(after_files, key=lambda x: x.stat().st_mtime)[-len(SUB_MODELS):]: - print(f" {p.name}", flush=True) - else: - raise SystemExit( - f"Quantization failed (exit {rc}) — no fresh *quant.onnx in {ARTIFACTS_DIR}" - ) - - -if __name__ == "__main__": - if len(sys.argv) >= 4 and sys.argv[1] == "--build-sub": - _build_one(sys.argv[2], int(sys.argv[3])) - elif len(sys.argv) >= 2 and sys.argv[1] == "--quant": - _run_quant() - else: - main() - diff --git a/tests/e2e/models/test_qwen3_transformer_only_quant.py b/tests/e2e/models/test_qwen3_transformer_only_quant.py new file mode 100644 index 000000000..7c6499e51 --- /dev/null +++ b/tests/e2e/models/test_qwen3_transformer_only_quant.py @@ -0,0 +1,248 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""End-to-end coverage for the transformer-only Qwen3 w8a16 build. + +Replaces the former root-level ``test_qwen.py`` / ``qwen3_transformer_only_quantize.py`` +scripts. 
Quantization is now driven entirely through the standard build +pipeline (``WinMLAutoModel.from_pretrained(..., precision="w8a16")``): the +device/precision policy enables the quantize stage, and +``QwenTransformerOnlyDecoderWrapper.winml_finalize_quant_config`` finalizes the +reference-matched scheme (int8-symmetric weights, uint16 activations, +GroupQueryAttention excluded from QDQ) plus the decode-trajectory calibration +reader. + +These tests download Qwen3-0.6B from HuggingFace and run a full CPU export + +quantize, so they are gated behind ``slow`` + ``network`` and excluded from the +default lane. The QNN/NPU build is additionally gated on a real NPU. + +All expectations are generated in-code (FP reference greedy decode), never +hardcoded from a prior model run. +""" + +from __future__ import annotations + +import numpy as np +import onnx +import onnxruntime as ort +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from winml.modelkit.config import WinMLBuildConfig +from winml.modelkit.models.auto import WinMLAutoModel +from winml.modelkit.quant import WinMLQuantizationConfig + + +pytestmark = [pytest.mark.e2e, pytest.mark.slow, pytest.mark.network] + +MODEL_ID = "Qwen/Qwen3-0.6B" +MAX_CACHE = 256 +PARITY_TOKENS = 8 +DECODE_STEPS = 12 +# Keep CPU calibration cheap: the decode reader emits ``samples * 16`` feeds. 
+CALIB_SAMPLES = 4 + + +def _qnn_available() -> bool: + """True when ONNX Runtime exposes the QNN execution provider (real NPU).""" + return "QNNExecutionProvider" in ort.get_available_providers() + + +def _decoder_onnx_path(model) -> str: + """Locate the quantized decode ONNX behind the composite handle.""" + sub = model.sub_models["decoder_gen"] + return str(sub._onnx_path) + + +def _qdq_counts(onnx_path: str) -> dict[str, int]: + graph = onnx.load(onnx_path, load_external_data=False).graph + counts: dict[str, int] = {} + for node in graph.node: + counts[node.op_type] = counts.get(node.op_type, 0) + 1 + return counts + + +def _gqa_tensor_set(graph) -> set[str]: + tensors: set[str] = set() + for node in graph.node: + if node.op_type == "GroupQueryAttention": + tensors.update(node.input) + tensors.update(node.output) + return tensors + + +@pytest.fixture(scope="module") +def decode_quant_model(tmp_path_factory): + """Build + quantize the decode (seq_len=1) sub-model once on CPU.""" + cache_dir = tmp_path_factory.mktemp("qwen3_w8a16") + return WinMLAutoModel.from_pretrained( + MODEL_ID, + task="text2text-generation", + model_type="qwen3_transformer_only", + config=WinMLBuildConfig(quant=WinMLQuantizationConfig(samples=CALIB_SAMPLES)), + precision="w8a16", + device="cpu", + ep="cpu", + force_rebuild=True, + shape_config={"max_cache_len": MAX_CACHE, "seq_len": 1}, + cache_dir=str(cache_dir), + ) + + +@pytest.mark.timeout(2400) +def test_decode_model_is_quantized_with_gqa_excluded(decode_quant_model): + onnx_path = _decoder_onnx_path(decode_quant_model) + counts = _qdq_counts(onnx_path) + + # QDQ nodes were inserted via the config-driven pipeline. + assert counts.get("QuantizeLinear", 0) > 0 + assert counts.get("DequantizeLinear", 0) > 0 + # GroupQueryAttention survives in float (not quantized away). 
+ assert counts.get("GroupQueryAttention", 0) > 0 + + # GQA exclusion contract: no QuantizeLinear/DequantizeLinear touches a GQA + # input or output tensor (attention stays Cast -> GQA -> Cast). + graph = onnx.load(onnx_path, load_external_data=False).graph + gqa_tensors = _gqa_tensor_set(graph) + touching = [ + node.name + for node in graph.node + if node.op_type in ("QuantizeLinear", "DequantizeLinear") + and (set(node.input) & gqa_tensors or set(node.output) & gqa_tensors) + ] + assert touching == [] + + +def _carry_kv(kv: dict[str, np.ndarray], out: dict[str, np.ndarray], num_layers: int) -> None: + for i in range(num_layers): + kv[f"past_keys_{i}"] = out[f"present_keys_{i}"] + kv[f"past_values_{i}"] = out[f"present_values_{i}"] + + +def _seed_kv_from_fp(past, num_layers, num_kv_heads, head_dim, cur_len): + """Copy an HF FP prefill cache into the decode model's fixed FP16 buffers.""" + kv: dict[str, np.ndarray] = {} + for i in range(num_layers): + layer = past[i] if not hasattr(past, "layers") else None + if layer is not None: + k, v = past[i][0], past[i][1] + else: # newer per-layer DynamicCache + k, v = past.layers[i].keys, past.layers[i].values + kbuf = np.zeros((1, num_kv_heads, MAX_CACHE, head_dim), np.float16) + vbuf = np.zeros_like(kbuf) + kbuf[:, :, :cur_len, :] = k[:, :, :cur_len, :].to(torch.float16).cpu().numpy() + vbuf[:, :, :cur_len, :] = v[:, :, :cur_len, :].to(torch.float16).cpu().numpy() + kv[f"past_keys_{i}"] = kbuf + kv[f"past_values_{i}"] = vbuf + return kv + + +@pytest.mark.timeout(2400) +def test_decode_parity_against_fp_reference(decode_quant_model): + """The w8a16 decode model must track the FP reference token-for-token. + + This is the regression guard against the historical "decode collapse": + a degenerate calibration (single repeated token + zeroed KV) made the + quantized decode model diverge into garbage after ~1 token. 
With the + decode-trajectory reader the quantized greedy trajectory must match the + FP reference for the first ``PARITY_TOKENS`` tokens. + """ + onnx_path = _decoder_onnx_path(decode_quant_model) + session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) + want = {i.name for i in session.get_inputs()} + + hf = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.float32) + hf.eval() + cfg = hf.config + embed = hf.get_input_embeddings() + lm_head = hf.lm_head + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + num_layers = cfg.num_hidden_layers + num_kv_heads = cfg.num_key_value_heads + head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads) + + text = tokenizer.apply_chat_template( + [{"role": "user", "content": "What is the capital of France?"}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ids = tokenizer([text], return_tensors="pt").input_ids + cur_len = ids.shape[1] + assert cur_len < MAX_CACHE + + # --- FP reference greedy decode (generates the expected tokens) --- + with torch.no_grad(): + out = hf(input_ids=ids, use_cache=True) + fp_past = out.past_key_values + first_tok = int(out.logits[:, -1, :].argmax(-1)) + fp_tokens: list[int] = [] + tok, past = first_tok, fp_past + for _ in range(DECODE_STEPS): + with torch.no_grad(): + out = hf(input_ids=torch.tensor([[tok]]), past_key_values=past, use_cache=True) + past = out.past_key_values + tok = int(out.logits[:, -1, :].argmax(-1)) + fp_tokens.append(tok) + + # --- Quantized decode model greedy decode (own KV, FP embed + lm_head) --- + with torch.no_grad(): + seed = hf(input_ids=ids, use_cache=True) + kv = _seed_kv_from_fp(seed.past_key_values, num_layers, num_kv_heads, head_dim, cur_len) + quant_tokens: list[int] = [] + tok, past_len = first_tok, cur_len + for _ in range(DECODE_STEPS): + with torch.no_grad(): + emb = embed(torch.tensor([[tok]])).to(torch.float32).cpu().numpy() + feeds = { + "input_hidden_states": 
emb.astype(np.float32), + "past_seq_len": np.array([[past_len]], np.int32), + "total_seq_len": np.array([MAX_CACHE], np.int32), + **kv, + } + feeds = {k: v for k, v in feeds.items() if k in want} + names = [o.name for o in session.get_outputs()] + outs = dict(zip(names, session.run(None, feeds), strict=False)) + _carry_kv(kv, outs, num_layers) + hidden = torch.tensor(outs["output_hidden_states"][:, 0, :]) + with torch.no_grad(): + tok = int(lm_head(hidden).numpy()[0].argmax()) + quant_tokens.append(tok) + past_len += 1 + + assert quant_tokens[:PARITY_TOKENS] == fp_tokens[:PARITY_TOKENS], ( + f"w8a16 decode diverged from FP reference:\n" + f" fp : {fp_tokens[:PARITY_TOKENS]}\n" + f" quant: {quant_tokens[:PARITY_TOKENS]}" + ) + + +@pytest.mark.npu +@pytest.mark.qnn +@pytest.mark.timeout(2400) +@pytest.mark.skipif(not _qnn_available(), reason="requires QNN execution provider (NPU)") +@pytest.mark.parametrize( + ("task", "seq_len"), + [("feature-extraction", 64), ("text2text-generation", 1)], +) +def test_npu_build_quantizes(task, seq_len, tmp_path): + """On real NPU hardware, the w8a16 pipeline produces a quantized graph.""" + model = WinMLAutoModel.from_pretrained( + MODEL_ID, + task=task, + model_type="qwen3_transformer_only", + precision="w8a16", + device="npu", + ep="qnn", + no_compile=True, + force_rebuild=True, + shape_config={"max_cache_len": MAX_CACHE, "seq_len": seq_len}, + cache_dir=str(tmp_path), + ) + sub_name = "decoder_prefill" if seq_len == 64 else "decoder_gen" + onnx_path = str(model.sub_models[sub_name]._onnx_path) + counts = _qdq_counts(onnx_path) + assert counts.get("QuantizeLinear", 0) > 0 + assert counts.get("GroupQueryAttention", 0) > 0 diff --git a/tests/unit/models/qwen_transformer_only/__init__.py b/tests/unit/models/qwen_transformer_only/__init__.py new file mode 100644 index 000000000..862c45ce3 --- /dev/null +++ b/tests/unit/models/qwen_transformer_only/__init__.py @@ -0,0 +1,4 @@ +# 
------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- diff --git a/tests/unit/models/qwen_transformer_only/test_quant_calibration.py b/tests/unit/models/qwen_transformer_only/test_quant_calibration.py new file mode 100644 index 000000000..f1b160433 --- /dev/null +++ b/tests/unit/models/qwen_transformer_only/test_quant_calibration.py @@ -0,0 +1,234 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for the transformer-only Qwen3 quant calibration readers. + +These are fast, offline tests (no model download, no ONNX Runtime): they +exercise the graph-shape introspection, GroupQueryAttention node discovery, +and the exact feed contract (names / dtypes / shapes) the two calibration +readers must satisfy. All expectations are derived in-code from the inputs, +never hardcoded from a model run. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace + +import numpy as np +import onnx +import torch +from onnx import TensorProto, helper + +from winml.modelkit.models.hf.qwen_transformer_only_quant import ( + Qwen3DecodeTrajectoryCalibReader, + Qwen3TransformerOnlyCalibReader, + _gqa_node_names, + _graph_shapes, +) + + +NUM_LAYERS = 2 +NUM_KV_HEADS = 2 +HEAD_DIM = 4 +HIDDEN = NUM_KV_HEADS * HEAD_DIM +VOCAB = 16 + + +def _fake_config() -> SimpleNamespace: + return SimpleNamespace( + num_hidden_layers=NUM_LAYERS, + num_key_value_heads=NUM_KV_HEADS, + head_dim=HEAD_DIM, + hidden_size=HIDDEN, + num_attention_heads=NUM_KV_HEADS, + ) + + +def _build_tiny_onnx(path, *, seq_len: int, max_cache_len: int) -> None: + """Write a minimal graph carrying the inputs the readers introspect.""" + inputs = [ + helper.make_tensor_value_info( + "input_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] + ), + helper.make_tensor_value_info( + "past_keys_0", TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] + ), + ] + out = helper.make_tensor_value_info( + "output_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] + ) + gqa = helper.make_node( + "GroupQueryAttention", + ["input_hidden_states"], + ["attn_out"], + name="gqa_layer_0", + domain="com.microsoft", + ) + identity = helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) + graph = helper.make_graph([gqa, identity], "tiny", inputs, [out]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)]) + onnx.save(model, str(path)) + + +def test_graph_shapes_and_gqa_nodes(tmp_path): + p = tmp_path / "tiny.onnx" + _build_tiny_onnx(p, seq_len=1, max_cache_len=16) + + assert _graph_shapes(p) == (1, 16) + assert _gqa_node_names(p) == ["gqa_layer_0"] + + +def test_graph_shapes_prefill(tmp_path): + p = tmp_path / "tiny_prefill.onnx" + _build_tiny_onnx(p, seq_len=64, max_cache_len=256) + + assert _graph_shapes(p) == (64, 256) + + +def _drain(reader) -> 
list[dict[str, np.ndarray]]: + feeds = [] + while (feed := reader.get_next()) is not None: + feeds.append(feed) + return feeds + + +def test_prefill_reader_feed_contract(): + seq_len, max_cache_len = 4, 16 + embed = torch.nn.Embedding(VOCAB, HIDDEN) + token_ids = [torch.tensor([[1, 2, 3, 4, 5]])] + + reader = Qwen3TransformerOnlyCalibReader( + embed, + _fake_config(), + token_ids, + seq_len=seq_len, + max_cache_len=max_cache_len, + ) + feeds = _drain(reader) + + assert len(feeds) == len(token_ids) + feed = feeds[0] + + # input_hidden_states: FP32, truncated to seq_len. + assert feed["input_hidden_states"].dtype == np.float32 + assert feed["input_hidden_states"].shape == (1, seq_len, HIDDEN) + + # seqlens_k contract: past_seq_len = seq_len - 1 (INT32 [1,1]). + assert feed["past_seq_len"].dtype == np.int32 + np.testing.assert_array_equal(feed["past_seq_len"], [[seq_len - 1]]) + + # total_seq_len: full cache (INT32 [1]). + assert feed["total_seq_len"].dtype == np.int32 + np.testing.assert_array_equal(feed["total_seq_len"], [max_cache_len]) + + # KV buffers: FP16, full cache shape, present for every layer. + for i in range(NUM_LAYERS): + for prefix in ("past_keys_", "past_values_"): + kv = feed[f"{prefix}{i}"] + assert kv.dtype == np.float16 + assert kv.shape == (1, NUM_KV_HEADS, max_cache_len, HEAD_DIM) + + # rewind() replays the same samples. + reader.rewind() + assert len(_drain(reader)) == len(token_ids) + + +def test_prefill_reader_pads_short_prompts(): + seq_len = 6 # longer than the 3-token prompt -> must pad + embed = torch.nn.Embedding(VOCAB, HIDDEN) + token_ids = [torch.tensor([[1, 2, 3]])] + + reader = Qwen3TransformerOnlyCalibReader( + embed, _fake_config(), token_ids, seq_len=seq_len, max_cache_len=16 + ) + feed = _drain(reader)[0] + assert feed["input_hidden_states"].shape == (1, seq_len, HIDDEN) + + +class _StubCausalLM: + """Minimal HF-like model: grows a tuple-of-tuples KV cache by 1 each call. 
+ + Always predicts ``next_token`` so the trajectory is deterministic. + """ + + def __init__(self, next_token: int) -> None: + self.next_token = next_token + + def _cache(self, length: int): + return tuple( + ( + torch.randn(1, NUM_KV_HEADS, length, HEAD_DIM), + torch.randn(1, NUM_KV_HEADS, length, HEAD_DIM), + ) + for _ in range(NUM_LAYERS) + ) + + def __call__(self, input_ids=None, past_key_values=None, use_cache=True): + if past_key_values is None: + length = input_ids.shape[1] + query_len = length + else: + length = past_key_values[0][0].shape[2] + input_ids.shape[1] + query_len = input_ids.shape[1] + logits = torch.full((1, query_len, VOCAB), -10.0) + logits[..., self.next_token] = 10.0 + return SimpleNamespace(past_key_values=self._cache(length), logits=logits) + + +def test_decode_trajectory_reader_grows_past_seq_len(): + prefill_seq, decode_steps, max_cache_len = 2, 3, 16 + embed = torch.nn.Embedding(VOCAB, HIDDEN) + hf_model = _StubCausalLM(next_token=5) + token_ids = [torch.tensor([[1, 2, 3, 4]])] # truncated to prefill_seq=2 + + reader = Qwen3DecodeTrajectoryCalibReader( + hf_model, + embed, + _fake_config(), + token_ids, + prefill_seq=prefill_seq, + max_cache_len=max_cache_len, + decode_steps=decode_steps, + ) + feeds = _drain(reader) + + assert len(feeds) == len(token_ids) * decode_steps + + # past_seq_len must grow monotonically from prefill_seq (real decode), not + # stay pinned at 0 like the degenerate single-token reader. + seq_lens = [int(f["past_seq_len"][0, 0]) for f in feeds] + assert seq_lens == [prefill_seq, prefill_seq + 1, prefill_seq + 2] + + for f in feeds: + # One token per decode step. 
+ assert f["input_hidden_states"].shape == (1, 1, HIDDEN) + assert f["input_hidden_states"].dtype == np.float32 + cur_len = int(f["past_seq_len"][0, 0]) + for i in range(NUM_LAYERS): + kv = f[f"past_keys_{i}"] + assert kv.dtype == np.float16 + assert kv.shape == (1, NUM_KV_HEADS, max_cache_len, HEAD_DIM) + # Positions beyond the valid context stay zero-padded. + assert np.all(kv[:, :, cur_len:, :] == 0) + + +def test_decode_trajectory_reader_respects_max_cache(): + prefill_seq, decode_steps, max_cache_len = 4, 10, 6 + embed = torch.nn.Embedding(VOCAB, HIDDEN) + hf_model = _StubCausalLM(next_token=2) + token_ids = [torch.tensor([[1, 2, 3, 4, 5, 6]])] + + reader = Qwen3DecodeTrajectoryCalibReader( + hf_model, + embed, + _fake_config(), + token_ids, + prefill_seq=prefill_seq, + max_cache_len=max_cache_len, + decode_steps=decode_steps, + ) + feeds = _drain(reader) + # Trajectory must stop once the cache is full (cur_len reaches max_cache_len). + assert len(feeds) == max_cache_len - prefill_seq + assert max(int(f["past_seq_len"][0, 0]) for f in feeds) == max_cache_len - 1 From a7f518e6ea61138a128eed8473a3fc89783caa8c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 13:42:32 +0800 Subject: [PATCH 08/13] fix(qwen): clean lint + persist finalized quant config + guard dynamic shapes - Add missing docstrings / return-type annotations and drop dead noqa directives across qwen3_export_ops.py, qwen3_modeling.py and the transformer-only registration so 'ruff check src/ tests/' (CI lint) passes. - build/hf.py: re-persist config.json after winml_finalize_quant_config runs, so the saved metadata reflects the actually-applied w8a16 scheme (int8/uint16/symmetry + GQA nodes_to_exclude) rather than the pre-finalize policy dtypes. - qwen_transformer_only_quant._graph_shapes: treat a non-positive dim_value (symbolic/dynamic axis) as a hard error instead of silently returning a zero-length shape. 
--- src/winml/modelkit/build/hf.py | 4 ++ src/winml/modelkit/models/hf/__init__.py | 3 +- .../modelkit/models/hf/qwen3_export_ops.py | 62 +++++++++++++------ .../modelkit/models/hf/qwen3_modeling.py | 20 +++--- .../models/hf/qwen_transformer_only_quant.py | 7 ++- 5 files changed, 66 insertions(+), 30 deletions(-) diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index 2faa3eec2..08698d125 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -325,6 +325,10 @@ def _name(base: str) -> str: config.quant = pytorch_model.winml_finalize_quant_config( config.quant, onnx_path=current_path, model_id=model_id ) + # The hook may overwrite the quant scheme (dtypes, symmetry, + # nodes-to-exclude) authoritatively, so re-persist the config + # to keep config.json consistent with what was actually applied. + config_path.write_text(json.dumps(config.to_dict(), indent=2)) quant_result = quantize_onnx( model_path=current_path, output_path=quantized_path, diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index 0d2e538a3..458bc8e34 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -62,7 +62,8 @@ QwenTransformerOnlyGenIOConfig as _QwenTransformerOnlyGenIOConfig, # triggers registration ) from .qwen_transformer_only import ( - QwenTransformerOnlyPrefillIOConfig as _QwenTransformerOnlyPrefillIOConfig, # triggers registration + # triggers registration + QwenTransformerOnlyPrefillIOConfig as _QwenTransformerOnlyPrefillIOConfig, ) from .roberta import ROBERTA_FAMILY_CONFIG from .roberta import RobertaIOConfig as _RobertaIOConfig # triggers registration diff --git a/src/winml/modelkit/models/hf/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3_export_ops.py index 5fd3edb68..e1eba87c3 100644 --- a/src/winml/modelkit/models/hf/qwen3_export_ops.py +++ b/src/winml/modelkit/models/hf/qwen3_export_ops.py @@ -2,8 +2,7 @@ # Copyright (c) 
Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Custom ONNX export ops + the entry point that reshapes HF's Qwen3 modules -for the transformer-only export. +"""Custom ONNX export ops that reshape HF's Qwen3 modules for export. These reshape the standard HF Qwen3 modules so winml-cli can produce a QNN-friendly, transformer-only graph: @@ -20,6 +19,8 @@ from __future__ import annotations +from typing import Any + import torch import torch.nn as nn from torch.onnx import symbolic_helper @@ -34,7 +35,8 @@ class LpNormOnnxExport(torch.autograd.Function): """RMSNorm body → ONNX ``LpNormalization`` (p=2 along last dim).""" @staticmethod - def symbolic(g, input, axis, p): # noqa: D401 + def symbolic(g, input, axis, p) -> Any: + """Emit the ONNX ``LpNormalization`` node during export.""" output_type = input.type().with_sizes(symbolic_helper._get_tensor_sizes(input)) output = g.op( "onnx::LpNormalization", @@ -45,12 +47,14 @@ def symbolic(g, input, axis, p): # noqa: D401 return output.setType(output_type) @staticmethod - def forward(ctx, input, axis, p): # noqa: ARG004 - # Shape-only tracing placeholder. The real op is emitted by - # ``symbolic`` during ONNX export; ``forward`` exists solely so the - # TorchScript exporter (and Optimum's pre-export dry run) can trace - # output shapes. It returns ``input`` unchanged on purpose and is NOT a - # correct eager RMSNorm — do not call this module for real inference. + def forward(ctx, input, axis, p) -> Any: + """Shape-only tracing placeholder; returns ``input`` unchanged. + + The real op is emitted by ``symbolic`` during ONNX export; ``forward`` + exists solely so the TorchScript exporter (and Optimum's pre-export dry + run) can trace output shapes. It is NOT a correct eager RMSNorm — do + not call this module for real inference. 
+ """ return input @@ -72,8 +76,19 @@ def symbolic( do_rotary, kv_num_heads, num_heads, - ): - args = [query, key, value, past_key, past_value, seqlens_k, total_sequence_length, cos_cache, sin_cache] + ) -> Any: + """Emit the fused ``com.microsoft::GroupQueryAttention`` node.""" + args = [ + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_sequence_length, + cos_cache, + sin_cache, + ] attention_output, present_keys, present_values = g.op( "com.microsoft::GroupQueryAttention", *args, @@ -85,8 +100,12 @@ def symbolic( query_sizes = symbolic_helper._get_tensor_sizes(query) attention_output.setType(query.type().with_sizes(query_sizes)) - present_keys.setType(past_key.type().with_sizes(symbolic_helper._get_tensor_sizes(past_key))) - present_values.setType(past_value.type().with_sizes(symbolic_helper._get_tensor_sizes(past_value))) + present_keys.setType( + past_key.type().with_sizes(symbolic_helper._get_tensor_sizes(past_key)) + ) + present_values.setType( + past_value.type().with_sizes(symbolic_helper._get_tensor_sizes(past_value)) + ) return attention_output, present_keys, present_values @staticmethod @@ -104,13 +123,14 @@ def forward( do_rotary, kv_num_heads, num_heads, - ): # noqa: ARG004 - # Shape-only tracing placeholder. The real op is emitted by - # ``symbolic`` during ONNX export; ``forward`` exists solely so the - # TorchScript exporter (and Optimum's pre-export dry run) can trace - # output shapes. It returns the inputs as stand-in present-KV on - # purpose and is NOT correct attention — do not call this module for - # real inference. + ) -> Any: + """Shape-only tracing placeholder; returns stand-in (output, KV). + + The real op is emitted by ``symbolic`` during ONNX export; ``forward`` + exists solely so the TorchScript exporter (and Optimum's pre-export dry + run) can trace output shapes. It is NOT correct attention — do not call + this module for real inference. 
+ """ return query, past_key, past_value # placeholder shapes @@ -135,6 +155,7 @@ def __init__( self.bias = bias def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the 1x1 conv with NHWC<->NCHW permutes (+ optional bias).""" x = x.permute(0, 3, 1, 2) # NHWC -> NCHW x = torch.nn.functional.conv2d(x, self.weight) x = x.permute(0, 2, 3, 1) # NCHW -> NHWC @@ -144,6 +165,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @classmethod def from_linear_module(cls, linear: nn.Linear) -> TransposeConv2d1x1Transpose: + """Build a 1x1-conv replacement from an existing ``nn.Linear``.""" return cls(linear.in_features, linear.out_features, linear.weight, linear.bias) diff --git a/src/winml/modelkit/models/hf/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3_modeling.py index d3c538df5..f5207d797 100644 --- a/src/winml/modelkit/models/hf/qwen3_modeling.py +++ b/src/winml/modelkit/models/hf/qwen3_modeling.py @@ -41,17 +41,17 @@ class WinMLQwen3RMSNorm(nn.Module): """RMSNorm export variant — ``onnx::LpNormalization`` body.""" def prepare_for_onnx_export(self) -> None: + """Fold the RMSNorm gain into the weight (LpNorm has unit gain).""" # Pre-multiply the gain into the weight (LpNorm has unit gain). # ``scale`` is shape ``[1]`` and broadcasts over ``self.weight`` # (shape ``[hidden_size]``), so the result keeps the per-channel # shape even when the original weights are all ones. 
n = self.weight.numel() - scale = torch.sqrt( - torch.tensor([n], device=self.weight.device, dtype=self.weight.dtype) - ) + scale = torch.sqrt(torch.tensor([n], device=self.weight.device, dtype=self.weight.dtype)) self.weight = nn.Parameter(scale * self.weight) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Apply the LpNormalization-based RMSNorm body.""" out = LpNormOnnxExport.apply(hidden_states, -1, 2) return self.weight * out @@ -60,6 +60,7 @@ class WinMLQwen3MLP(nn.Module): """MLP export variant — 1x1 Conv projections (forward unchanged).""" def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: + """Optionally swap the MLP's linear projections for 1x1 convs.""" if not matmul_to_conv: return self.gate_proj = TransposeConv2d1x1Transpose.from_linear_module(self.gate_proj) @@ -71,12 +72,13 @@ class WinMLQwen3Attention(nn.Module): """Attention export variant — fused ``GroupQueryAttention`` op.""" def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: + """Optionally swap the Q/K/V/O projections for 1x1 convs.""" if matmul_to_conv: self.q_proj = TransposeConv2d1x1Transpose.from_linear_module(self.q_proj) self.k_proj = TransposeConv2d1x1Transpose.from_linear_module(self.k_proj) self.v_proj = TransposeConv2d1x1Transpose.from_linear_module(self.v_proj) self.o_proj = TransposeConv2d1x1Transpose.from_linear_module(self.o_proj) - self._matmul_to_conv = matmul_to_conv # noqa: SLF001 + self._matmul_to_conv = matmul_to_conv def forward( self, @@ -84,8 +86,9 @@ def forward( past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None, past_seq_len: torch.Tensor | None = None, total_seq_len: torch.Tensor | None = None, - **kwargs: Any, # noqa: ARG002 + **kwargs: Any, ) -> tuple[torch.Tensor, None, tuple[torch.Tensor, torch.Tensor]]: + """Run fused GQA attention and return (output, None, present_kv).""" query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) 
@@ -167,8 +170,9 @@ def forward( past_seq_len: torch.Tensor | None = None, total_seq_len: torch.Tensor | None = None, use_cache: bool = True, - **kwargs: Any, # noqa: ARG002 + **kwargs: Any, ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Run the decoder layer (attention + MLP) with residual adds.""" residual = hidden_states hidden_states = self.input_layernorm(hidden_states) attn_out, _, present_kv = self.self_attn( @@ -194,7 +198,8 @@ class WinMLQwen3Model(nn.Module): """Model export variant — transformer-only body (no embeddings / lm_head).""" def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: - self._matmul_to_conv = matmul_to_conv # noqa: SLF001 + """Record whether projections use the 1x1-conv (NHWC) path.""" + self._matmul_to_conv = matmul_to_conv def forward( self, @@ -204,6 +209,7 @@ def forward( total_seq_len: torch.Tensor, use_cache: bool = True, ) -> tuple[torch.Tensor, tuple[tuple[torch.Tensor, torch.Tensor], ...]]: + """Run the transformer-only body, returning hidden states + KV.""" hidden_states = inputs_embeds if self._matmul_to_conv: hidden_states = hidden_states.unsqueeze(0) # NHWC for Conv path diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py b/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py index f01de2f71..b52dfd85e 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py @@ -120,9 +120,12 @@ def _graph_shapes(onnx_path: Path) -> tuple[int, int]: seq_len = dims[1].dim_value elif inp.name == "past_keys_0" and len(dims) >= 3: max_cache_len = dims[2].dim_value - if seq_len is None or max_cache_len is None: + # A symbolic/dynamic axis yields dim_value == 0 (not None), so treat any + # non-positive value as "not a usable static shape" and fail loudly rather + # than silently building zero-length calibration feeds. 
+ if not seq_len or not max_cache_len: raise ValueError( - f"Could not read seq_len/max_cache_len from {onnx_path.name}; " + f"Could not read static seq_len/max_cache_len from {onnx_path.name}; " f"found seq_len={seq_len}, max_cache_len={max_cache_len}" ) return seq_len, max_cache_len From caada38336eb671fa3809a4cd19d983ca8f50d3b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 15:18:36 +0800 Subject: [PATCH 09/13] fix(qwen): address review comments (LpNorm eager norm, CodeQL lint, e2e helper) - LpNormOnnxExport.forward now computes the real L2 normalization instead of a silent identity; export-invariant (node comes from symbolic) and correct in eager. - GroupQueryAttentionOnnxExport.forward keeps the non-raising placeholder, with a docstring explaining why raising is impossible (HTP hierarchy capture runs an eager forward outside trace/export). - Remove unused module-level logger in qwen_transformer_only.py (CodeQL). - Use a single onnx import form in test_quant_calibration.py (CodeQL). - Fix e2e _decoder_onnx_path helper to handle the single-model WinMLModelForGenericTask (.onnx_path) build, not just composite .sub_models. --- .../modelkit/models/hf/qwen3_export_ops.py | 35 ++++++++++++------- .../models/hf/qwen_transformer_only.py | 3 -- .../test_qwen3_transformer_only_quant.py | 14 ++++++-- .../test_quant_calibration.py | 21 ++++++----- 4 files changed, 43 insertions(+), 30 deletions(-) diff --git a/src/winml/modelkit/models/hf/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3_export_ops.py index e1eba87c3..aed592fa7 100644 --- a/src/winml/modelkit/models/hf/qwen3_export_ops.py +++ b/src/winml/modelkit/models/hf/qwen3_export_ops.py @@ -48,14 +48,15 @@ def symbolic(g, input, axis, p) -> Any: @staticmethod def forward(ctx, input, axis, p) -> Any: - """Shape-only tracing placeholder; returns ``input`` unchanged. + """Real ``LpNormalization`` (``input / ||input||_p`` along ``axis``). 
- The real op is emitted by ``symbolic`` during ONNX export; ``forward`` - exists solely so the TorchScript exporter (and Optimum's pre-export dry - run) can trace output shapes. It is NOT a correct eager RMSNorm — do - not call this module for real inference. + The exported node comes from ``symbolic``; this eager body computes the + same value so any eager execution (unit tests, calibration debug runs, + the exporter's own shape-tracing pass) gets correctly normalized output + instead of a silent identity. It matches the ONNX op faithfully (no + RMSNorm epsilon), since that is exactly what ``symbolic`` emits. """ - return input + return input / torch.linalg.vector_norm(input, ord=p, dim=axis, keepdim=True) class GroupQueryAttentionOnnxExport(torch.autograd.Function): @@ -124,14 +125,22 @@ def forward( kv_num_heads, num_heads, ) -> Any: - """Shape-only tracing placeholder; returns stand-in (output, KV). - - The real op is emitted by ``symbolic`` during ONNX export; ``forward`` - exists solely so the TorchScript exporter (and Optimum's pre-export dry - run) can trace output shapes. It is NOT correct attention — do not call - this module for real inference. + """Shape-only tracing placeholder; returns a stand-in ``(output, KV)``. + + The real op is emitted by ``symbolic`` during ONNX export; this body + only needs to return tensors of the right shape/dtype. It deliberately + does NOT raise on eager execution, even though that yields a stale + (never-advanced) KV cache: the HTP export pipeline runs a real eager + ``forward`` pass to capture the module hierarchy (see + ``export/htp/hierarchy.py::trace_model_execution``), and that pass is + indistinguishable from misuse — ``torch.jit.is_tracing()`` and + ``torch.onnx.is_in_onnx_export()`` are both False there — so raising + would break the actual build. There is also no cheap faithful eager + equivalent (correct attention would grow the sequence axis that the + static-shape export freezes). 
This module is export-only by design and + is never run for real inference; calibration loads a fresh real model. """ - return query, past_key, past_value # placeholder shapes + return query, past_key, past_value # placeholder shapes (export-only) # ============================================================================= diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py index fda69495f..28e394a4a 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -28,7 +28,6 @@ from __future__ import annotations -import logging from typing import Any, ClassVar import torch @@ -48,8 +47,6 @@ from .qwen3_modeling import apply_transformer_only_export_prep -logger = logging.getLogger(__name__) - # Distinct model_type for this variant. The underscore form is what the # exporter sees on ``model.config.model_type`` and what Optimum's TasksManager # and ``register_specialization`` are keyed on; the hyphenated form is used for diff --git a/tests/e2e/models/test_qwen3_transformer_only_quant.py b/tests/e2e/models/test_qwen3_transformer_only_quant.py index 7c6499e51..cf5b34132 100644 --- a/tests/e2e/models/test_qwen3_transformer_only_quant.py +++ b/tests/e2e/models/test_qwen3_transformer_only_quant.py @@ -51,9 +51,17 @@ def _qnn_available() -> bool: def _decoder_onnx_path(model) -> str: - """Locate the quantized decode ONNX behind the composite handle.""" - sub = model.sub_models["decoder_gen"] - return str(sub._onnx_path) + """Locate the quantized decode ONNX behind the model handle. + + The decode-only build (``seq_len=1``) returns a single + ``WinMLModelForGenericTask`` whose ``onnx_path`` is the quantized graph; a + full composite build instead exposes it under ``sub_models["decoder_gen"]``. + Handle both so the test does not depend on which wrapper the build picks. 
+ """ + sub_models = getattr(model, "sub_models", None) + if sub_models and "decoder_gen" in sub_models: + return str(sub_models["decoder_gen"].onnx_path) + return str(model.onnx_path) def _qdq_counts(onnx_path: str) -> dict[str, int]: diff --git a/tests/unit/models/qwen_transformer_only/test_quant_calibration.py b/tests/unit/models/qwen_transformer_only/test_quant_calibration.py index f1b160433..75933962e 100644 --- a/tests/unit/models/qwen_transformer_only/test_quant_calibration.py +++ b/tests/unit/models/qwen_transformer_only/test_quant_calibration.py @@ -18,7 +18,6 @@ import numpy as np import onnx import torch -from onnx import TensorProto, helper from winml.modelkit.models.hf.qwen_transformer_only_quant import ( Qwen3DecodeTrajectoryCalibReader, @@ -48,26 +47,26 @@ def _fake_config() -> SimpleNamespace: def _build_tiny_onnx(path, *, seq_len: int, max_cache_len: int) -> None: """Write a minimal graph carrying the inputs the readers introspect.""" inputs = [ - helper.make_tensor_value_info( - "input_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] + onnx.helper.make_tensor_value_info( + "input_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN] ), - helper.make_tensor_value_info( - "past_keys_0", TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] + onnx.helper.make_tensor_value_info( + "past_keys_0", onnx.TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] ), ] - out = helper.make_tensor_value_info( - "output_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] + out = onnx.helper.make_tensor_value_info( + "output_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN] ) - gqa = helper.make_node( + gqa = onnx.helper.make_node( "GroupQueryAttention", ["input_hidden_states"], ["attn_out"], name="gqa_layer_0", domain="com.microsoft", ) - identity = helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) - graph = helper.make_graph([gqa, identity], "tiny", inputs, [out]) - model = 
helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)]) + identity = onnx.helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) + graph = onnx.helper.make_graph([gqa, identity], "tiny", inputs, [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 18)]) onnx.save(model, str(path)) From c97373c3a663365e6f7f2316d2d14c9be0e4f9de Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 16:14:40 +0800 Subject: [PATCH 10/13] fix(build): resolve quant-finalize hook on model class + update model_type-override test - build_hf_model: look up winml_finalize_quant_config on type(pytorch_model) instead of the instance, and call it with explicit self. Fixes the mypy 'Tensor not callable' error (getattr yields Any) and stops the hook firing on raw HF models / MagicMock test doubles (whose attributes are instance-synthesized), which was serializing a MagicMock into config.json. - test_resolve_loader_config: replace the obsolete 'never mutated' test with one asserting the intended explicit-model_type override (needed for variants like qwen3_transformer_only). --- src/winml/modelkit/build/hf.py | 13 ++++++++++--- tests/unit/loader/test_resolve_loader_config.py | 17 +++++++++++------ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index 08698d125..60c27a02b 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -321,9 +321,16 @@ def _name(base: str) -> str: # exported ONNX exists (e.g. calibration feeds / nodes-to-exclude # derived from the graph). Give the wrapper a chance to populate # those runtime-only fields here. 
- if pytorch_model is not None and hasattr(pytorch_model, "winml_finalize_quant_config"): - config.quant = pytorch_model.winml_finalize_quant_config( - config.quant, onnx_path=current_path, model_id=model_id + # Resolve the optional hook on the model's *class* (not the + # instance): a genuine wrapper defines it at class scope, whereas a + # raw HF model — or a test double whose attributes are synthesized + # per-instance — does not, so this avoids firing spuriously. + finalize_quant_config = getattr( + type(pytorch_model), "winml_finalize_quant_config", None + ) + if callable(finalize_quant_config): + config.quant = finalize_quant_config( + pytorch_model, config.quant, onnx_path=current_path, model_id=model_id ) # The hook may overwrite the quant scheme (dtypes, symmetry, # nodes-to-exclude) authoritatively, so re-persist the config diff --git a/tests/unit/loader/test_resolve_loader_config.py b/tests/unit/loader/test_resolve_loader_config.py index ea26e6cff..491af63ce 100644 --- a/tests/unit/loader/test_resolve_loader_config.py +++ b/tests/unit/loader/test_resolve_loader_config.py @@ -142,8 +142,13 @@ def test_model_type_only_creates_default_config(self) -> None: mock_create.assert_called_once_with("bert") assert loader_config.task == "feature-extraction" - def test_hf_config_never_mutated(self) -> None: - """hf_config is never mutated — model_type param does not override it.""" + def test_explicit_model_type_overrides_hf_config(self) -> None: + """An explicit model_type (with a model_id) overrides the resolved type. + + Needed so a variant model_type such as ``qwen3_transformer_only`` selects + the variant rather than the architecture's native type. The override only + applies when a model_id is present and the requested type differs. 
+ """ mock_config = MagicMock() mock_config.model_type = "original_type" mock_class = MagicMock(spec=[]) @@ -164,10 +169,10 @@ def test_hf_config_never_mutated(self) -> None: "some-model", model_type="gpt2", task="text-generation" ) - # hf_config retains its original model_type — never mutated - assert hf_config.model_type == "original_type" - # loader_config.model_type reflects the REAL hf_config, not the param - assert loader_config.model_type == "original_type" + # The explicit model_type wins over the architecture's native type. + assert hf_config.model_type == "gpt2" + # loader_config.model_type reflects the overridden type. + assert loader_config.model_type == "gpt2" def test_auto_detect_task_from_model_type(self) -> None: """model_type without task auto-detects first supported task.""" From 752e6c990b17ca8e8b6b3b8b72ae84443312b2d9 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 18:20:00 +0800 Subject: [PATCH 11/13] refactor(quant): move qwen3 calibration logic into quant registry Relocate the model-specific transformer-only calibration/quant logic out of models/hf (an export-only package) into a new quant/calibration/ subpackage, dispatched via a model_type-keyed registry that mirrors COMPOSITE_MODEL_REGISTRY. - Add quant/calibration/{base,registry}.py: QuantConfigFinalizer protocol + register_quant_finalizer / get_quant_finalizer (lazy, torch-free import). - git mv qwen_transformer_only_quant.py -> quant/calibration/qwen3_transformer_only.py and register Qwen3TransformerOnlyQuantFinalizer for 'qwen3_transformer_only'. - build/hf.py: replace the winml_finalize_quant_config wrapper hook with explicit registry dispatch keyed on config.model_type; unregistered types fall back to the default DatasetCalibrationReader. Preserve the model_id/_name_or_path fallback (now model-agnostic in the build layer). - Remove the hook from the export wrapper (back to export-only). 
- Relocate unit tests to tests/unit/quant/calibration/ and add test_registry.py. w8a16 scheme unchanged; CPU e2e (quantized-graph + GQA-exclusion + FP-parity) and 86 build/quant unit tests pass. --- src/winml/modelkit/build/hf.py | 34 +++++---- .../models/hf/qwen_transformer_only.py | 21 ------ src/winml/modelkit/quant/__init__.py | 4 + .../modelkit/quant/calibration/__init__.py | 23 ++++++ src/winml/modelkit/quant/calibration/base.py | 42 +++++++++++ .../calibration/qwen3_transformer_only.py} | 36 +++++++-- .../modelkit/quant/calibration/registry.py | 73 +++++++++++++++++++ .../test_qwen3_transformer_only_quant.py | 11 +-- .../calibration}/__init__.py | 0 .../calibration/test_qwen3_calibration.py} | 2 +- tests/unit/quant/calibration/test_registry.py | 38 ++++++++++ 11 files changed, 238 insertions(+), 46 deletions(-) create mode 100644 src/winml/modelkit/quant/calibration/__init__.py create mode 100644 src/winml/modelkit/quant/calibration/base.py rename src/winml/modelkit/{models/hf/qwen_transformer_only_quant.py => quant/calibration/qwen3_transformer_only.py} (92%) create mode 100644 src/winml/modelkit/quant/calibration/registry.py rename tests/unit/{models/qwen_transformer_only => quant/calibration}/__init__.py (100%) rename tests/unit/{models/qwen_transformer_only/test_quant_calibration.py => quant/calibration/test_qwen3_calibration.py} (99%) create mode 100644 tests/unit/quant/calibration/test_registry.py diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index 60c27a02b..ef8e794ee 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -317,22 +317,28 @@ def _name(base: str) -> str: else: logger.info("Quantizing model...") t0 = time.monotonic() - # Some model wrappers can only finalize their quant config once the - # exported ONNX exists (e.g. calibration feeds / nodes-to-exclude - # derived from the graph). Give the wrapper a chance to populate - # those runtime-only fields here. 
- # Resolve the optional hook on the model's *class* (not the - # instance): a genuine wrapper defines it at class scope, whereas a - # raw HF model — or a test double whose attributes are synthesized - # per-instance — does not, so this avoids firing spuriously. - finalize_quant_config = getattr( - type(pytorch_model), "winml_finalize_quant_config", None + # Some model types finalize their quant config only once the + # exported ONNX exists (calibration feeds / nodes-to-exclude derived + # from the graph). Resolve the model-type-specific quant policy from + # the quant registry, keyed on the live ``model_type``. Unregistered + # types return None → the quantizer uses its standard task-aware + # DatasetCalibrationReader. + from ..quant import get_quant_finalizer + + resolved_model_type = ( + getattr(getattr(pytorch_model, "config", None), "model_type", None) or model_type ) - if callable(finalize_quant_config): - config.quant = finalize_quant_config( - pytorch_model, config.quant, onnx_path=current_path, model_id=model_id + quant_finalizer = get_quant_finalizer(resolved_model_type) + if quant_finalizer is not None: + # Generic id fallback: the policy loads a fresh reference model + # for calibration, so feed it the best-known HF id/path. + resolved_model_id = model_id or getattr( + getattr(pytorch_model, "config", None), "_name_or_path", None ) - # The hook may overwrite the quant scheme (dtypes, symmetry, + config.quant = quant_finalizer.finalize( + config.quant, onnx_path=current_path, model_id=resolved_model_id + ) + # The policy may overwrite the quant scheme (dtypes, symmetry, # nodes-to-exclude) authoritatively, so re-persist the config # to keep config.json consistent with what was actually applied. 
config_path.write_text(json.dumps(config.to_dict(), indent=2)) diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py index 28e394a4a..bff3cc5c7 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -125,27 +125,6 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: out.extend([k, v]) return tuple(out) - def winml_finalize_quant_config( - self, quant: Any, *, onnx_path: Any, model_id: str | None = None - ) -> Any: - """Build-pipeline hook: attach the calibration reader + GQA exclusions. - - Called by ``build_hf_model`` just before ``quantize_onnx`` (see - ``build/hf.py``). The exported transformer-only graph determines the - calibration feeds (shapes, KV buffers) and which GroupQueryAttention - nodes stay in float, so the live :class:`WinMLQuantizationConfig` can - only be finalized here — not at config-construction time. - """ - from .qwen_transformer_only_quant import ( - DEFAULT_MODEL_ID, - finalize_transformer_only_quant_config, - ) - - resolved_id = model_id or getattr(self.config, "_name_or_path", None) or DEFAULT_MODEL_ID - return finalize_transformer_only_quant_config( - quant, onnx_path=onnx_path, model_id=resolved_id - ) - # ============================================================================= # Dummy input generators (transformer-only I/O) diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index bc8e6ee06..b7bc8bf38 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -24,12 +24,16 @@ __all__ = [ "QuantizeResult", "WinMLQuantizationConfig", + "get_quant_finalizer", "quantize_onnx", + "register_quant_finalizer", ] _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "quantize_onnx": (".quantizer", "quantize_onnx"), + "get_quant_finalizer": (".calibration", "get_quant_finalizer"), + "register_quant_finalizer": 
(".calibration", "register_quant_finalizer"), } diff --git a/src/winml/modelkit/quant/calibration/__init__.py b/src/winml/modelkit/quant/calibration/__init__.py new file mode 100644 index 000000000..88b1434c5 --- /dev/null +++ b/src/winml/modelkit/quant/calibration/__init__.py @@ -0,0 +1,23 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Model-type-specific quantization policies (calibration readers + schemes). + +This subpackage stays import-light on purpose: it exposes only the registry +API. The individual finalizer modules (which pull in torch/transformers) are +imported lazily by :func:`get_quant_finalizer` when their ``model_type`` is +quantized. +""" + +from __future__ import annotations + +from .base import QuantConfigFinalizer +from .registry import get_quant_finalizer, register_quant_finalizer + + +__all__ = [ + "QuantConfigFinalizer", + "get_quant_finalizer", + "register_quant_finalizer", +] diff --git a/src/winml/modelkit/quant/calibration/base.py b/src/winml/modelkit/quant/calibration/base.py new file mode 100644 index 000000000..895c48b63 --- /dev/null +++ b/src/winml/modelkit/quant/calibration/base.py @@ -0,0 +1,42 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Base protocol for model-type-specific quantization policies.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + + +if TYPE_CHECKING: + from pathlib import Path + + from ..config import WinMLQuantizationConfig + + +@runtime_checkable +class QuantConfigFinalizer(Protocol): + """Model-type-specific quant policy. 
+ + Given the freshly exported ONNX, a finalizer populates the live + :class:`WinMLQuantizationConfig` with the fields that can only be known + once the graph exists — the calibration data reader, ``nodes_to_exclude``, + and (where the scheme is fixed and reference-matched) the dtype/symmetry + settings. + + Finalizers are registered per ``model_type`` (see + :func:`.registry.register_quant_finalizer`). Model types without a + registered policy fall back to the quantizer's default + ``DatasetCalibrationReader``. + """ + + def finalize( + self, + quant: WinMLQuantizationConfig, + *, + onnx_path: Path, + model_id: str | None = None, + ) -> WinMLQuantizationConfig: + """Return ``quant`` populated with the graph-derived quant settings.""" + ... diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py similarity index 92% rename from src/winml/modelkit/models/hf/qwen_transformer_only_quant.py rename to src/winml/modelkit/quant/calibration/qwen3_transformer_only.py index b52dfd85e..5abb7e4ce 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py +++ b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py @@ -5,12 +5,13 @@ """Config-driven w8a16 calibration for the transformer-only Qwen3 build. -The transformer-only export (:mod:`qwen_transformer_only`) emits a graph whose -only quantization-relevant runtime inputs (the calibration feeds and the +The transformer-only export (``models.hf.qwen_transformer_only``) emits a graph +whose only quantization-relevant runtime inputs (the calibration feeds and the ``GroupQueryAttention`` node names to keep in float) can't be known until the ONNX exists. 
Rather than a standalone post-build script that reaches into -``composite.sub_models[...]._onnx_path``, this module plugs into the normal -build pipeline: :meth:`QwenTransformerOnlyDecoderWrapper.winml_finalize_quant_config` +``composite.sub_models[...]._onnx_path``, this module registers a quant policy +keyed on ``model_type`` (:class:`Qwen3TransformerOnlyQuantFinalizer`). The build +pipeline resolves it via :func:`~winml.modelkit.quant.get_quant_finalizer` and calls :func:`finalize_transformer_only_quant_config` just before ``quantize_onnx`` runs (see ``build/hf.py``), populating the live :class:`WinMLQuantizationConfig` with the right @@ -43,7 +44,8 @@ import torch from transformers import AutoModelForCausalLM, AutoTokenizer -from ...quant.config import CalibrationDataReader, WinMLQuantizationConfig +from ..config import CalibrationDataReader, WinMLQuantizationConfig +from .registry import register_quant_finalizer if TYPE_CHECKING: @@ -418,3 +420,27 @@ def finalize_transformer_only_quant_config( gc.collect() return quant + + +@register_quant_finalizer("qwen3_transformer_only") +class Qwen3TransformerOnlyQuantFinalizer: + """Registered quant policy for the ``qwen3_transformer_only`` model_type. + + Adapts :func:`finalize_transformer_only_quant_config` to the + :class:`~winml.modelkit.quant.calibration.base.QuantConfigFinalizer` + protocol so the build pipeline resolves the model-specific w8a16 scheme + + calibration reader through the quant registry (keyed on ``model_type``) + rather than a hardcoded hook on the export wrapper. 
+ """ + + def finalize( + self, + quant: WinMLQuantizationConfig, + *, + onnx_path: Path, + model_id: str | None = None, + ) -> WinMLQuantizationConfig: + """Populate ``quant`` with the transformer-only w8a16 scheme + reader.""" + return finalize_transformer_only_quant_config( + quant, onnx_path=onnx_path, model_id=model_id or DEFAULT_MODEL_ID + ) diff --git a/src/winml/modelkit/quant/calibration/registry.py b/src/winml/modelkit/quant/calibration/registry.py new file mode 100644 index 000000000..47698da63 --- /dev/null +++ b/src/winml/modelkit/quant/calibration/registry.py @@ -0,0 +1,73 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Registry mapping ``model_type`` to its quantization policy. + +Mirrors the project's other ``model_type``-keyed registries (e.g. +``COMPOSITE_MODEL_REGISTRY``): a finalizer registers itself with +``@register_quant_finalizer(model_type)`` and the build pipeline resolves it +with :func:`get_quant_finalizer`. + +The registry is intentionally lazy. Importing :mod:`winml.modelkit.quant` +must stay free of heavy deps (torch/transformers); the per-model finalizer +modules — which do pull those in — are only imported the first time their +``model_type`` is actually quantized. +""" + +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from .base import QuantConfigFinalizer + + +# Populated by the ``@register_quant_finalizer`` decorator at import time. +_QUANT_FINALIZER_REGISTRY: dict[str, type[QuantConfigFinalizer]] = {} + +# ``model_type`` -> submodule that defines (and self-registers) its finalizer. +# Looked up lazily so the heavy module loads only when needed. 
Keys must match +# the live ``model_type`` string verbatim (no ``_`` -> ``-`` normalization), +# since lookup is keyed on the exported model's ``config.model_type``. +_KNOWN_FINALIZER_MODULES: dict[str, str] = { + "qwen3_transformer_only": ".qwen3_transformer_only", +} + + +def register_quant_finalizer(model_type: str): + """Class decorator registering a :class:`QuantConfigFinalizer` for ``model_type``.""" + + def decorator(cls: type) -> type: + if not hasattr(cls, "finalize"): + raise TypeError( + f"{cls.__name__} cannot register as a quant finalizer for " + f"{model_type!r}: it must define a ``finalize`` method." + ) + if model_type in _QUANT_FINALIZER_REGISTRY: + raise ValueError( + f"Quant finalizer already registered for {model_type!r}: " + f"{_QUANT_FINALIZER_REGISTRY[model_type].__name__}. " + f"Cannot register {cls.__name__}." + ) + _QUANT_FINALIZER_REGISTRY[model_type] = cls + return cls + + return decorator + + +def get_quant_finalizer(model_type: str | None) -> QuantConfigFinalizer | None: + """Return a finalizer instance for ``model_type``, or ``None`` if unregistered. + + ``None`` means "no model-specific policy" — the quantizer then uses its + standard task-aware ``DatasetCalibrationReader``. + """ + if not model_type: + return None + if model_type not in _QUANT_FINALIZER_REGISTRY and model_type in _KNOWN_FINALIZER_MODULES: + # Triggers the module's ``@register_quant_finalizer`` side effect. + importlib.import_module(_KNOWN_FINALIZER_MODULES[model_type], __package__) + cls = _QUANT_FINALIZER_REGISTRY.get(model_type) + return cls() if cls is not None else None diff --git a/tests/e2e/models/test_qwen3_transformer_only_quant.py b/tests/e2e/models/test_qwen3_transformer_only_quant.py index cf5b34132..831a640e8 100644 --- a/tests/e2e/models/test_qwen3_transformer_only_quant.py +++ b/tests/e2e/models/test_qwen3_transformer_only_quant.py @@ -7,11 +7,12 @@ Replaces the former root-level ``test_qwen.py`` / ``qwen3_transformer_only_quantize.py`` scripts. 
Quantization is now driven entirely through the standard build pipeline (``WinMLAutoModel.from_pretrained(..., precision="w8a16")``): the -device/precision policy enables the quantize stage, and -``QwenTransformerOnlyDecoderWrapper.winml_finalize_quant_config`` finalizes the -reference-matched scheme (int8-symmetric weights, uint16 activations, -GroupQueryAttention excluded from QDQ) plus the decode-trajectory calibration -reader. +device/precision policy enables the quantize stage, and the +``qwen3_transformer_only`` quant policy registered in +``winml.modelkit.quant.calibration`` (resolved via ``get_quant_finalizer``) +finalizes the reference-matched scheme (int8-symmetric weights, uint16 +activations, GroupQueryAttention excluded from QDQ) plus the decode-trajectory +calibration reader. These tests download Qwen3-0.6B from HuggingFace and run a full CPU export + quantize, so they are gated behind ``slow`` + ``network`` and excluded from the diff --git a/tests/unit/models/qwen_transformer_only/__init__.py b/tests/unit/quant/calibration/__init__.py similarity index 100% rename from tests/unit/models/qwen_transformer_only/__init__.py rename to tests/unit/quant/calibration/__init__.py diff --git a/tests/unit/models/qwen_transformer_only/test_quant_calibration.py b/tests/unit/quant/calibration/test_qwen3_calibration.py similarity index 99% rename from tests/unit/models/qwen_transformer_only/test_quant_calibration.py rename to tests/unit/quant/calibration/test_qwen3_calibration.py index 75933962e..5c8bd9d69 100644 --- a/tests/unit/models/qwen_transformer_only/test_quant_calibration.py +++ b/tests/unit/quant/calibration/test_qwen3_calibration.py @@ -19,7 +19,7 @@ import onnx import torch -from winml.modelkit.models.hf.qwen_transformer_only_quant import ( +from winml.modelkit.quant.calibration.qwen3_transformer_only import ( Qwen3DecodeTrajectoryCalibReader, Qwen3TransformerOnlyCalibReader, _gqa_node_names, diff --git a/tests/unit/quant/calibration/test_registry.py 
b/tests/unit/quant/calibration/test_registry.py new file mode 100644 index 000000000..b60f74b9b --- /dev/null +++ b/tests/unit/quant/calibration/test_registry.py @@ -0,0 +1,38 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for the quant finalizer registry. + +Fast, offline: no model download, no ONNX Runtime. Verifies that the +``model_type`` -> quant policy dispatch (lazy import + decorator registration) +resolves the registered Qwen3 finalizer and falls back to ``None`` (the +quantizer's default DatasetCalibrationReader path) for everything else. +""" + +from __future__ import annotations + +from winml.modelkit.quant import get_quant_finalizer +from winml.modelkit.quant.calibration import QuantConfigFinalizer + + +def test_registered_model_type_resolves_finalizer(): + """The qwen3_transformer_only policy is found via lazy registry import.""" + finalizer = get_quant_finalizer("qwen3_transformer_only") + assert finalizer is not None + assert isinstance(finalizer, QuantConfigFinalizer) + assert hasattr(finalizer, "finalize") + # Registry returns the concrete policy class, not the generic protocol. 
+ assert type(finalizer).__name__ == "Qwen3TransformerOnlyQuantFinalizer" + + +def test_unregistered_model_type_returns_none(): + """Unknown / native model types have no policy -> default reader path.""" + assert get_quant_finalizer("resnet") is None + assert get_quant_finalizer("qwen3") is None + + +def test_none_model_type_returns_none(): + """A missing model_type must not raise and must not dispatch a policy.""" + assert get_quant_finalizer(None) is None + assert get_quant_finalizer("") is None From e9dbe2a2770df108c425078c83068a3dd2803998 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 19:09:17 +0800 Subject: [PATCH 12/13] fix(quant): satisfy mypy + CodeQL on calibration registry - annotate register_quant_finalizer return type (mypy no-untyped-def) - add TYPE_CHECKING re-imports so static analyzers see lazy __all__ exports (CodeQL py/undefined-export) - drop bare ... from finalizer Protocol; docstring is the body (CodeQL ineffectual-statement) --- src/winml/modelkit/quant/__init__.py | 11 ++++++++++- src/winml/modelkit/quant/calibration/base.py | 1 - src/winml/modelkit/quant/calibration/registry.py | 4 +++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index b7bc8bf38..e43a69068 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -16,7 +16,7 @@ result = quantize_onnx("model.onnx", WinMLQuantizationConfig(samples=100)) """ -from typing import Any +from typing import TYPE_CHECKING, Any from .config import QuantizeResult, WinMLQuantizationConfig @@ -30,6 +30,15 @@ ] +# Names below are loaded lazily via ``__getattr__`` to avoid pulling in +# onnxruntime.quantization/torch at import time. The TYPE_CHECKING re-imports +# give static analyzers (mypy, CodeQL) visibility into what ``__all__`` exports +# without triggering the heavy imports at runtime. 
+if TYPE_CHECKING: + from .calibration import get_quant_finalizer, register_quant_finalizer + from .quantizer import quantize_onnx + + _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "quantize_onnx": (".quantizer", "quantize_onnx"), "get_quant_finalizer": (".calibration", "get_quant_finalizer"), diff --git a/src/winml/modelkit/quant/calibration/base.py b/src/winml/modelkit/quant/calibration/base.py index 895c48b63..d62ba4322 100644 --- a/src/winml/modelkit/quant/calibration/base.py +++ b/src/winml/modelkit/quant/calibration/base.py @@ -39,4 +39,3 @@ def finalize( model_id: str | None = None, ) -> WinMLQuantizationConfig: """Return ``quant`` populated with the graph-derived quant settings.""" - ... diff --git a/src/winml/modelkit/quant/calibration/registry.py b/src/winml/modelkit/quant/calibration/registry.py index 47698da63..78b321ae4 100644 --- a/src/winml/modelkit/quant/calibration/registry.py +++ b/src/winml/modelkit/quant/calibration/registry.py @@ -22,6 +22,8 @@ if TYPE_CHECKING: + from collections.abc import Callable + from .base import QuantConfigFinalizer @@ -37,7 +39,7 @@ } -def register_quant_finalizer(model_type: str): +def register_quant_finalizer(model_type: str) -> Callable[[type], type]: """Class decorator registering a :class:`QuantConfigFinalizer` for ``model_type``.""" def decorator(cls: type) -> type: From 52745638d9b4e887fe10d4d36a7433813592a3ef Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 22:12:32 +0800 Subject: [PATCH 13/13] Thread model_type + quant finalizer through CLI HF build; move qwen3 transformer-only into subpackage The CLI-only _build_hf_pipeline did not pass loader.model_type to _load_model, so a config requesting qwen3_transformer_only was silently loaded as native qwen3 and crashed at export (embedding got HalfTensor). It also skipped the model-type quant finalizer, producing the default uint8/uint16 minmax scheme instead of the registered int8-sym / GQA-excluded policy. 
Both gaps existed only in the CLI path; the library build_hf_model already handled them. Mirror that logic so winml build produces the verified w8a16 graph (985 Q / 1294 DQ / 28 GQA / 0 QDQ-touching-GQA) end-to-end. Also move qwen3_export_ops, qwen3_modeling and qwen_transformer_only into a models/hf/qwen3/ subpackage and add regression tests for both fixes. --- src/winml/modelkit/commands/build.py | 32 ++++- src/winml/modelkit/models/hf/__init__.py | 8 +- .../modelkit/models/hf/qwen3/__init__.py | 6 + .../models/hf/{ => qwen3}/qwen3_export_ops.py | 0 .../models/hf/{ => qwen3}/qwen3_modeling.py | 0 .../hf/{ => qwen3}/qwen_transformer_only.py | 14 +- .../calibration/qwen3_transformer_only.py | 2 +- tests/unit/commands/test_build.py | 121 ++++++++++++++++++ 8 files changed, 170 insertions(+), 13 deletions(-) create mode 100644 src/winml/modelkit/models/hf/qwen3/__init__.py rename src/winml/modelkit/models/hf/{ => qwen3}/qwen3_export_ops.py (100%) rename src/winml/modelkit/models/hf/{ => qwen3}/qwen3_modeling.py (100%) rename src/winml/modelkit/models/hf/{ => qwen3}/qwen_transformer_only.py (97%) diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index c3ffc660d..0d10ebf67 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -1339,7 +1339,11 @@ def _name(base: str) -> str: # Load + export (blocking) pytorch_model = _load_model( - config, model_id, trust_remote_code=False, hf_config=preloaded_hf_config + config, + model_id, + trust_remote_code=False, + hf_config=preloaded_hf_config, + model_type=config.loader.model_type, ) t0 = time.monotonic() # config.export is None only for the ONNX build path; this is the HF path. 
@@ -1384,6 +1388,32 @@ def _name(base: str) -> str: config_path.write_text(json.dumps(config.to_dict(), indent=2)) # ── Quantize stage ─────────────────────────────────────────── + # Some model types finalize their quant config only once the exported ONNX + # exists (calibration feeds / nodes-to-exclude derived from the graph). + # Resolve the model-type-specific quant policy from the quant registry, + # keyed on the live ``model_type`` — mirrors build.hf.build_hf_model so the + # CLI and library pipelines apply the same scheme. Unregistered types return + # None → the quantizer uses its standard task-aware DatasetCalibrationReader. + if config.quant is not None: + from ..quant import get_quant_finalizer + + resolved_model_type = ( + getattr(getattr(pytorch_model, "config", None), "model_type", None) + or config.loader.model_type + ) + quant_finalizer = get_quant_finalizer(resolved_model_type) + if quant_finalizer is not None: + resolved_model_id = model_id or getattr( + getattr(pytorch_model, "config", None), "_name_or_path", None + ) + config.quant = quant_finalizer.finalize( + config.quant, onnx_path=current_path, model_id=resolved_model_id + ) + # The policy may overwrite the quant scheme (dtypes, symmetry, + # nodes-to-exclude) authoritatively, so re-persist the config to keep + # config.json consistent with what was actually applied. 
+ config_path.write_text(json.dumps(config.to_dict(), indent=2)) + current_path = _run_quantize_stage( config=config, current_path=current_path, diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index 458bc8e34..5c854bb60 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -56,12 +56,12 @@ from .qwen import QWEN_CONFIG from .qwen import QwenGenIOConfig as _QwenGenIOConfig from .qwen import QwenPrefillIOConfig as _QwenPrefillIOConfig -from .qwen_transformer_only import MODEL_CLASS_MAPPING as _QWEN_TO_CLASS_MAPPING -from .qwen_transformer_only import QWEN_TRANSFORMER_ONLY_CONFIG -from .qwen_transformer_only import ( +from .qwen3.qwen_transformer_only import MODEL_CLASS_MAPPING as _QWEN_TO_CLASS_MAPPING +from .qwen3.qwen_transformer_only import QWEN_TRANSFORMER_ONLY_CONFIG +from .qwen3.qwen_transformer_only import ( QwenTransformerOnlyGenIOConfig as _QwenTransformerOnlyGenIOConfig, # triggers registration ) -from .qwen_transformer_only import ( +from .qwen3.qwen_transformer_only import ( # triggers registration QwenTransformerOnlyPrefillIOConfig as _QwenTransformerOnlyPrefillIOConfig, ) diff --git a/src/winml/modelkit/models/hf/qwen3/__init__.py b/src/winml/modelkit/models/hf/qwen3/__init__.py new file mode 100644 index 000000000..332fb9234 --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen3/__init__.py @@ -0,0 +1,6 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +"""Qwen3 transformer-only export support (modeling, export ops, IO configs).""" diff --git a/src/winml/modelkit/models/hf/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3/qwen3_export_ops.py similarity index 100% rename from src/winml/modelkit/models/hf/qwen3_export_ops.py rename to src/winml/modelkit/models/hf/qwen3/qwen3_export_ops.py diff --git a/src/winml/modelkit/models/hf/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py similarity index 100% rename from src/winml/modelkit/models/hf/qwen3_modeling.py rename to src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py similarity index 97% rename from src/winml/modelkit/models/hf/qwen_transformer_only.py rename to src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py index bff3cc5c7..cc4985de0 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py @@ -37,13 +37,13 @@ from optimum.utils.input_generators import DummyInputGenerator from transformers import AutoModelForCausalLM -from ...config import WinMLBuildConfig -from ...export import register_onnx_overwrite -from ...export.config import WinMLExportConfig -from ..winml import register_specialization -from ..winml.composite_model import register_composite_model -from ..winml.decoder_only import WinMLDecoderOnlyModel -from ..winml.kv_cache import WinMLSlidingWindowCache +from ....config import WinMLBuildConfig +from ....export import register_onnx_overwrite +from ....export.config import WinMLExportConfig +from ...winml import register_specialization +from ...winml.composite_model import register_composite_model +from ...winml.decoder_only import WinMLDecoderOnlyModel +from ...winml.kv_cache import WinMLSlidingWindowCache from .qwen3_modeling 
import apply_transformer_only_export_prep diff --git a/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py index 5abb7e4ce..a4dc0c61c 100644 --- a/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py +++ b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py @@ -5,7 +5,7 @@ """Config-driven w8a16 calibration for the transformer-only Qwen3 build. -The transformer-only export (``models.hf.qwen_transformer_only``) emits a graph +The transformer-only export (``models.hf.qwen3.qwen_transformer_only``) emits a graph whose only quantization-relevant runtime inputs (the calibration feeds and the ``GroupQueryAttention`` node names to keep in float) can't be known until the ONNX exists. Rather than a standalone post-build script that reaches into diff --git a/tests/unit/commands/test_build.py b/tests/unit/commands/test_build.py index 00f54fc23..400d3cccd 100644 --- a/tests/unit/commands/test_build.py +++ b/tests/unit/commands/test_build.py @@ -1711,3 +1711,124 @@ def test_returns_compiled_path_when_file_exists( # current_path should be updated to compiled_path assert result == compiled_path + + +class TestBuildHfPipelineModelType: + """Regression: the CLI HF pipeline must thread loader.model_type into _load_model. + + Without this, a config requesting a derived model_type (e.g. + ``qwen3_transformer_only``) is silently loaded as its native type, so the + wrong model class is exported. See _build_hf_pipeline. 
+ """ + + @patch("winml.modelkit.utils.console.StageLive") + @patch("winml.modelkit.export.export_onnx") + @patch("winml.modelkit.build.hf._load_model") + def test_load_model_receives_config_model_type( + self, + mock_load_model: MagicMock, + mock_export_onnx: MagicMock, + mock_stage_live: MagicMock, + tmp_path: Path, + ) -> None: + from winml.modelkit.commands.build import _build_hf_pipeline + + mock_stage_live.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_stage_live.return_value.__exit__ = MagicMock(return_value=False) + + # Stop the pipeline right after export so we only exercise the load call. + sentinel = RuntimeError("stop-after-export") + mock_export_onnx.side_effect = sentinel + + config = MagicMock() + config.loader.model_type = "qwen3_transformer_only" + config.loader.task = "feature-extraction" + config.export = MagicMock() + + with pytest.raises(RuntimeError, match="stop-after-export"): + _build_hf_pipeline( + config=config, + model_id="Qwen/Qwen3-0.6B", + output_dir=tmp_path / "out", + rebuild=True, + cache_key=None, + ep=None, + device="cpu", + extra_kwargs={}, + preloaded_hf_config=None, + ) + + mock_load_model.assert_called_once() + assert mock_load_model.call_args.kwargs["model_type"] == "qwen3_transformer_only" + + @patch("winml.modelkit.commands.build._run_compile_stage") + @patch("winml.modelkit.commands.build._run_quantize_stage") + @patch("winml.modelkit.quant.get_quant_finalizer") + @patch("winml.modelkit.commands.build._run_optimize_stage") + @patch("winml.modelkit.commands.build._show_io") + @patch("winml.modelkit.utils.console.StageLive") + @patch("winml.modelkit.export.export_onnx") + @patch("winml.modelkit.build.hf._load_model") + def test_quant_finalizer_applied_for_registered_model_type( + self, + mock_load_model: MagicMock, + mock_export_onnx: MagicMock, + mock_stage_live: MagicMock, + mock_show_io: MagicMock, + mock_optimize: MagicMock, + mock_get_finalizer: MagicMock, + mock_quantize: MagicMock, + mock_compile: 
MagicMock, + tmp_path: Path, + ) -> None: + """The CLI HF pipeline must apply the registered quant finalizer. + + Mirrors build.hf.build_hf_model: without this the CLI quantizes with the + default task-aware scheme instead of the model-type-specific policy. + """ + from winml.modelkit.commands.build import _build_hf_pipeline + + mock_stage_live.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_stage_live.return_value.__exit__ = MagicMock(return_value=False) + + pytorch_model = MagicMock() + pytorch_model.config.model_type = "qwen3_transformer_only" + mock_load_model.return_value = pytorch_model + + optimized = tmp_path / "optimized.onnx" + mock_optimize.return_value = (optimized, None) + + finalized_quant = MagicMock(name="finalized_quant_config") + finalizer = MagicMock() + finalizer.finalize.return_value = finalized_quant + mock_get_finalizer.return_value = finalizer + + # Stop right after the quantize stage so we don't exercise compile. + mock_quantize.side_effect = RuntimeError("stop-after-quantize") + + config = MagicMock() + config.loader.model_type = "qwen3_transformer_only" + config.loader.task = "text2text-generation" + config.loader.model_class = None + config.export = MagicMock() + config.quant = MagicMock(name="initial_quant_config") + config.to_dict.return_value = {} + + with pytest.raises(RuntimeError, match="stop-after-quantize"): + _build_hf_pipeline( + config=config, + model_id="Qwen/Qwen3-0.6B", + output_dir=tmp_path / "out", + rebuild=True, + cache_key=None, + ep=None, + device="cpu", + extra_kwargs={}, + preloaded_hf_config=None, + ) + + mock_get_finalizer.assert_called_once_with("qwen3_transformer_only") + finalizer.finalize.assert_called_once() + assert finalizer.finalize.call_args.kwargs["model_id"] == "Qwen/Qwen3-0.6B" + # config.quant must be replaced with the finalized scheme before quantize. + assert config.quant is finalized_quant