diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index 78ef8576f..ef8e794ee 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -91,6 +91,7 @@ def build_hf_model( cache_key: str | None = None, ep: EPNameOrAlias | None = None, device: str | None = None, + model_type: str | None = None, **kwargs: Any, ) -> BuildResult: """Build an ONNX model from a HuggingFace model architecture. @@ -211,6 +212,7 @@ def _name(base: str) -> str: model_id, trust_remote_code, random_init=random_init, + model_type=model_type, ) # ========================================================================= @@ -315,6 +317,31 @@ def _name(base: str) -> str: else: logger.info("Quantizing model...") t0 = time.monotonic() + # Some model types finalize their quant config only once the + # exported ONNX exists (calibration feeds / nodes-to-exclude derived + # from the graph). Resolve the model-type-specific quant policy from + # the quant registry, keyed on the live ``model_type``. Unregistered + # types return None → the quantizer uses its standard task-aware + # DatasetCalibrationReader. + from ..quant import get_quant_finalizer + + resolved_model_type = ( + getattr(getattr(pytorch_model, "config", None), "model_type", None) or model_type + ) + quant_finalizer = get_quant_finalizer(resolved_model_type) + if quant_finalizer is not None: + # Generic id fallback: the policy loads a fresh reference model + # for calibration, so feed it the best-known HF id/path. + resolved_model_id = model_id or getattr( + getattr(pytorch_model, "config", None), "_name_or_path", None + ) + config.quant = quant_finalizer.finalize( + config.quant, onnx_path=current_path, model_id=resolved_model_id + ) + # The policy may overwrite the quant scheme (dtypes, symmetry, + # nodes-to-exclude) authoritatively, so re-persist the config + # to keep config.json consistent with what was actually applied. 
+ config_path.write_text(json.dumps(config.to_dict(), indent=2)) quant_result = quantize_onnx( model_path=current_path, output_path=quantized_path, @@ -443,6 +470,7 @@ def _load_model( trust_remote_code: bool, random_init: bool = False, hf_config: Any | None = None, + model_type: str | None = None, ) -> Any: """Load PyTorch model — pretrained or random weights. @@ -518,6 +546,7 @@ def _load_model( task=task, trust_remote_code=effective_trust, hf_config=hf_config, + model_type=model_type, ) return pytorch_model diff --git a/src/winml/modelkit/loader/config.py b/src/winml/modelkit/loader/config.py index ff93a1c34..b97a13825 100644 --- a/src/winml/modelkit/loader/config.py +++ b/src/winml/modelkit/loader/config.py @@ -220,6 +220,19 @@ def resolve_loader_config( f"attribute. Cannot proceed with config generation." ) + # Explicit model_type override alongside a model_id: honor the requested + # type so downstream class / build-config / export resolution selects the + # variant (e.g. "qwen3_transformer_only") rather than the architecture's + # native type. The model_type-only path above (AutoConfig.for_model) is + # unaffected because it only runs when model_id is None. + if model_id is not None and model_type is not None and hf_config.model_type != model_type: + logger.info( + "Overriding resolved model_type '%s' -> '%s' (explicit request)", + hf_config.model_type, + model_type, + ) + hf_config.model_type = model_type + # 2-3. Unified resolution. Task detection — including the no-architectures # --model-type fallback (first supported task) — now lives in resolve_task. 
resolution = resolve_task(hf_config, task=task, model_class=model_class) diff --git a/src/winml/modelkit/loader/hf.py b/src/winml/modelkit/loader/hf.py index 12e325855..619109368 100644 --- a/src/winml/modelkit/loader/hf.py +++ b/src/winml/modelkit/loader/hf.py @@ -144,6 +144,7 @@ def load_hf_model( user_script: str | None = None, trust_remote_code: bool = False, hf_config: PretrainedConfig | None = None, + model_type: str | None = None, ) -> tuple[nn.Module, PretrainedConfig, str]: """Load, detect task, and prepare HuggingFace model. @@ -218,6 +219,18 @@ def load_hf_model( trust_remote_code=trust_remote_code, ) + # Explicit model_type override: select a registered build variant (e.g. + # "qwen3_transformer_only") rather than the architecture's native type. + # Mutates the freshly-loaded config only; gated on an explicit request so + # normal loading is unaffected. + if model_type is not None and getattr(hf_config, "model_type", None) != model_type: + logger.info( + "Overriding model_type '%s' -> '%s' (explicit request)", + getattr(hf_config, "model_type", None), + model_type, + ) + hf_config.model_type = model_type + # [2] Task & Model Class Resolution from .resolution import resolve_task diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index d184091c5..25620d3d5 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -266,6 +266,7 @@ def from_pretrained( trust_remote_code: bool = False, shape_config: dict | None = None, no_compile: bool = False, + model_type: str | None = None, allow_unsupported_nodes: bool = False, **kwargs: Any, ) -> WinMLPreTrainedModel: @@ -300,6 +301,10 @@ def from_pretrained( shape_config: Shape overrides passed to generate_build_config(). Valid keys -- text: sequence_length; vision: height, width; audio: feature_size, nb_max_frames, audio_sequence_length. + model_type: Explicit model_type override. 
When provided alongside a + HF model_id, selects a registered build variant (e.g. + ``"qwen3_transformer_only"``) instead of the architecture's + native model_type. Leave ``None`` for normal auto-detection. allow_unsupported_nodes: If True, warn instead of raising when the analyzer reports unsupported nodes that persist; the build proceeds and the EP may fall back to another device for them. @@ -361,6 +366,11 @@ def from_pretrained( else: _model_type = None + # Explicit override wins so a variant composite (e.g. + # "qwen3_transformer_only") can be selected over the native type. + if model_type is not None: + _model_type = model_type + if _model_type is not None and (_model_type, task) in COMPOSITE_MODEL_REGISTRY: from .winml.composite_model import WinMLCompositeModel @@ -398,6 +408,7 @@ def from_pretrained( trust_remote_code=trust_remote_code, ep=kwargs.get("ep"), no_compile=no_compile, + model_type=model_type, ) resolved_task = build_config.loader.task @@ -432,7 +443,9 @@ def from_pretrained( from transformers import AutoConfig hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=effective_trust) - model_type = getattr(hf_config, "model_type", "unknown") + # Honor an explicit model_type override; otherwise probe from the config. 
+ if model_type is None: + model_type = getattr(hf_config, "model_type", "unknown") logger.debug("Model type: %s, task: %s", model_type, resolved_task) # ===================================================================== @@ -470,6 +483,7 @@ def from_pretrained( cache_key=cache_key, ep=resolved_ep, device=device, + model_type=model_type, allow_unsupported_nodes=allow_unsupported_nodes, **build_control_kwargs, ) diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index c6f4c9520..458bc8e34 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -56,6 +56,15 @@ from .qwen import QWEN_CONFIG from .qwen import QwenGenIOConfig as _QwenGenIOConfig from .qwen import QwenPrefillIOConfig as _QwenPrefillIOConfig +from .qwen_transformer_only import MODEL_CLASS_MAPPING as _QWEN_TO_CLASS_MAPPING +from .qwen_transformer_only import QWEN_TRANSFORMER_ONLY_CONFIG +from .qwen_transformer_only import ( + QwenTransformerOnlyGenIOConfig as _QwenTransformerOnlyGenIOConfig, # triggers registration +) +from .qwen_transformer_only import ( + # triggers registration + QwenTransformerOnlyPrefillIOConfig as _QwenTransformerOnlyPrefillIOConfig, +) from .roberta import ROBERTA_FAMILY_CONFIG from .roberta import RobertaIOConfig as _RobertaIOConfig # triggers registration from .sam import MODEL_CLASS_MAPPING as _SAM2_CLASS_MAPPING @@ -92,6 +101,7 @@ **_MARIAN_CLASS_MAPPING, **_MU2_CLASS_MAPPING, **_QWEN_CLASS_MAPPING, + **_QWEN_TO_CLASS_MAPPING, **_SAM2_CLASS_MAPPING, **_SEGFORMER_CLASS_MAPPING, **_SIGLIP_CLASS_MAPPING, @@ -115,6 +125,7 @@ "roberta": ROBERTA_FAMILY_CONFIG, "mu2": MU2_CONFIG, "qwen3": QWEN_CONFIG, + "qwen3-transformer-only": QWEN_TRANSFORMER_ONLY_CONFIG, "siglip": SIGLIP_CONFIG, "siglip-text-model": SIGLIP_CONFIG, "siglip-vision-model": SIGLIP_CONFIG, diff --git a/src/winml/modelkit/models/hf/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3_export_ops.py new file mode 100644 
index 000000000..aed592fa7 --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen3_export_ops.py @@ -0,0 +1,185 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Custom ONNX export ops that reshape HF's Qwen3 modules for export. + +These reshape the standard HF Qwen3 modules so winml-cli can produce a +QNN-friendly, transformer-only graph: + +- ``LpNormalization`` replaces the eager RMSNorm Mul/Pow/ReduceMean chain. +- ``com.microsoft::GroupQueryAttention`` replaces the eager QKV MatMul + + Softmax + KV-update path (with built-in rotary). +- 1x1 ``Conv`` (NHWC<->NCHW) replaces ``nn.Linear`` for QNN-friendly + projections. + +Everything here operates only on the standard ``transformers.models.qwen3`` +module attributes. +""" + +from __future__ import annotations + +from typing import Any + +import torch +import torch.nn as nn +from torch.onnx import symbolic_helper + + +# ============================================================================= +# Custom ONNX symbolic functions +# ============================================================================= + + +class LpNormOnnxExport(torch.autograd.Function): + """RMSNorm body → ONNX ``LpNormalization`` (p=2 along last dim).""" + + @staticmethod + def symbolic(g, input, axis, p) -> Any: + """Emit the ONNX ``LpNormalization`` node during export.""" + output_type = input.type().with_sizes(symbolic_helper._get_tensor_sizes(input)) + output = g.op( + "onnx::LpNormalization", + input, + axis_i=int(axis), + p_i=int(p), + ) + return output.setType(output_type) + + @staticmethod + def forward(ctx, input, axis, p) -> Any: + """Real ``LpNormalization`` (``input / ||input||_p`` along ``axis``). 
+ + The exported node comes from ``symbolic``; this eager body computes the + same value so any eager execution (unit tests, calibration debug runs, + the exporter's own shape-tracing pass) gets correctly normalized output + instead of a silent identity. It matches the ONNX op faithfully (no + RMSNorm epsilon), since that is exactly what ``symbolic`` emits. + """ + return input / torch.linalg.vector_norm(input, ord=p, dim=axis, keepdim=True) + + +class GroupQueryAttentionOnnxExport(torch.autograd.Function): + """Fused Q/K/V + KV-cache + rotary → ``com.microsoft::GroupQueryAttention``.""" + + @staticmethod + def symbolic( + g, + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_sequence_length, + cos_cache, + sin_cache, + do_rotary, + kv_num_heads, + num_heads, + ) -> Any: + """Emit the fused ``com.microsoft::GroupQueryAttention`` node.""" + args = [ + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_sequence_length, + cos_cache, + sin_cache, + ] + attention_output, present_keys, present_values = g.op( + "com.microsoft::GroupQueryAttention", + *args, + do_rotary_i=int(do_rotary), + kv_num_heads_i=int(kv_num_heads), + num_heads_i=int(num_heads), + outputs=3, + ) + + query_sizes = symbolic_helper._get_tensor_sizes(query) + attention_output.setType(query.type().with_sizes(query_sizes)) + present_keys.setType( + past_key.type().with_sizes(symbolic_helper._get_tensor_sizes(past_key)) + ) + present_values.setType( + past_value.type().with_sizes(symbolic_helper._get_tensor_sizes(past_value)) + ) + return attention_output, present_keys, present_values + + @staticmethod + def forward( + ctx, + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_sequence_length, + cos_cache, + sin_cache, + do_rotary, + kv_num_heads, + num_heads, + ) -> Any: + """Shape-only tracing placeholder; returns a stand-in ``(output, KV)``. 
+ + The real op is emitted by ``symbolic`` during ONNX export; this body + only needs to return tensors of the right shape/dtype. It deliberately + does NOT raise on eager execution, even though that yields a stale + (never-advanced) KV cache: the HTP export pipeline runs a real eager + ``forward`` pass to capture the module hierarchy (see + ``export/htp/hierarchy.py::trace_model_execution``), and that pass is + indistinguishable from misuse — ``torch.jit.is_tracing()`` and + ``torch.onnx.is_in_onnx_export()`` are both False there — so raising + would break the actual build. There is also no cheap faithful eager + equivalent (correct attention would grow the sequence axis that the + static-shape export freezes). This module is export-only by design and + is never run for real inference; calibration loads a fresh real model. + """ + return query, past_key, past_value # placeholder shapes (export-only) + + +# ============================================================================= +# 1x1 Conv replacement for nn.Linear +# ============================================================================= + + +class TransposeConv2d1x1Transpose(nn.Module): + """``nn.Linear`` → 1x1 ``Conv2d`` with NHWC<->NCHW permutes.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + weight: torch.nn.Parameter, + bias: torch.nn.Parameter | None = None, + ) -> None: + super().__init__() + # Linear weight is (out, in); Conv2d weight is (out, in, 1, 1). 
+ self.weight = nn.Parameter(weight.data.view(out_channels, in_channels, 1, 1)) + self.bias = bias + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the 1x1 conv with NHWC<->NCHW permutes (+ optional bias).""" + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW + x = torch.nn.functional.conv2d(x, self.weight) + x = x.permute(0, 2, 3, 1) # NCHW -> NHWC + if self.bias is not None: + x = x + self.bias + return x + + @classmethod + def from_linear_module(cls, linear: nn.Linear) -> TransposeConv2d1x1Transpose: + """Build a 1x1-conv replacement from an existing ``nn.Linear``.""" + return cls(linear.in_features, linear.out_features, linear.weight, linear.bias) + + +__all__ = [ + "GroupQueryAttentionOnnxExport", + "LpNormOnnxExport", + "TransposeConv2d1x1Transpose", +] diff --git a/src/winml/modelkit/models/hf/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3_modeling.py new file mode 100644 index 000000000..f5207d797 --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen3_modeling.py @@ -0,0 +1,332 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""winml-owned Qwen3 model definitions for the transformer-only ONNX export. + +Each class is a plain ``nn.Module`` that carries the export-time behaviour +directly (``prepare_for_onnx_export`` + ``forward``). The export entry point +binds these ``forward`` methods onto the corresponding live Qwen3 submodules, +so the stock eager model is left untouched. + +What each class emits: + +- ``WinMLQwen3RMSNorm`` -> ``onnx::LpNormalization`` body. +- ``WinMLQwen3Attention`` -> ``com.microsoft::GroupQueryAttention`` (built-in + rotary) with optional 1x1 ``Conv`` projections. +- ``WinMLQwen3MLP`` -> 1x1 ``Conv`` projections (NHWC). 
+- ``WinMLQwen3DecoderLayer`` / ``WinMLQwen3Model`` -> transformer-only forward + that threads the KV cache + seq-len tensors and omits embeddings / lm_head. + +``apply_transformer_only_export_prep`` (defined below) walks a loaded +``Qwen3ForCausalLM``, calls ``prepare_for_onnx_export`` on each submodule, and +binds the matching ``forward`` from these classes onto it. +""" + +from __future__ import annotations + +from typing import Any + +import torch +import torch.nn as nn + +from .qwen3_export_ops import ( + GroupQueryAttentionOnnxExport, + LpNormOnnxExport, + TransposeConv2d1x1Transpose, +) + + +class WinMLQwen3RMSNorm(nn.Module): + """RMSNorm export variant — ``onnx::LpNormalization`` body.""" + + def prepare_for_onnx_export(self) -> None: + """Fold the RMSNorm gain into the weight (LpNorm has unit gain).""" + # Pre-multiply the gain into the weight (LpNorm has unit gain). + # ``scale`` is shape ``[1]`` and broadcasts over ``self.weight`` + # (shape ``[hidden_size]``), so the result keeps the per-channel + # shape even when the original weights are all ones. 
+ n = self.weight.numel() + scale = torch.sqrt(torch.tensor([n], device=self.weight.device, dtype=self.weight.dtype)) + self.weight = nn.Parameter(scale * self.weight) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Apply the LpNormalization-based RMSNorm body.""" + out = LpNormOnnxExport.apply(hidden_states, -1, 2) + return self.weight * out + + +class WinMLQwen3MLP(nn.Module): + """MLP export variant — 1x1 Conv projections (forward unchanged).""" + + def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: + """Optionally swap the MLP's linear projections for 1x1 convs.""" + if not matmul_to_conv: + return + self.gate_proj = TransposeConv2d1x1Transpose.from_linear_module(self.gate_proj) + self.up_proj = TransposeConv2d1x1Transpose.from_linear_module(self.up_proj) + self.down_proj = TransposeConv2d1x1Transpose.from_linear_module(self.down_proj) + + +class WinMLQwen3Attention(nn.Module): + """Attention export variant — fused ``GroupQueryAttention`` op.""" + + def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: + """Optionally swap the Q/K/V/O projections for 1x1 convs.""" + if matmul_to_conv: + self.q_proj = TransposeConv2d1x1Transpose.from_linear_module(self.q_proj) + self.k_proj = TransposeConv2d1x1Transpose.from_linear_module(self.k_proj) + self.v_proj = TransposeConv2d1x1Transpose.from_linear_module(self.v_proj) + self.o_proj = TransposeConv2d1x1Transpose.from_linear_module(self.o_proj) + self._matmul_to_conv = matmul_to_conv + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None, + past_seq_len: torch.Tensor | None = None, + total_seq_len: torch.Tensor | None = None, + **kwargs: Any, + ) -> tuple[torch.Tensor, None, tuple[torch.Tensor, torch.Tensor]]: + """Run fused GQA attention and return (output, None, present_kv).""" + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = 
self.v_proj(hidden_states) + + input_shape = hidden_states.shape[1:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_norm(query_states.view(hidden_shape)) + key_states = self.k_norm(key_states.view(hidden_shape)) + + num_heads = self.config.num_attention_heads + num_kv_heads = self.config.num_key_value_heads + query_dim = num_heads * self.head_dim + key_dim = num_kv_heads * self.head_dim + query_states = query_states.reshape(1, -1, query_dim) + key_states = key_states.reshape(1, -1, key_dim) + + if self._matmul_to_conv: + value_states = value_states.squeeze(0) + + past_keys, past_values = past_key_value + + # GroupQueryAttention requires Q/K/V/past_K/past_V to share dtype. + # The KV cache is FP16, so cast Q/K/V to the same dtype; otherwise ORT + # type inference rejects the node. + kv_dtype = past_keys.dtype + if query_states.dtype != kv_dtype: + query_states = query_states.to(kv_dtype) + key_states = key_states.to(kv_dtype) + value_states = value_states.to(kv_dtype) + + cos, sin = self.rotary_emb( + value_states, + torch.arange(self.config.max_position_embeddings).unsqueeze(0), + ) + cos = cos.squeeze(0)[:, : cos.shape[-1] // 2] + sin = sin.squeeze(0)[:, : sin.shape[-1] // 2] + if cos.dtype != kv_dtype: + cos = cos.to(kv_dtype) + sin = sin.to(kv_dtype) + + if isinstance(past_seq_len, int): + past_seq_len = torch.tensor(past_seq_len) + past_seq_len = torch.atleast_2d(past_seq_len) + + attention_output, present_keys, present_values = GroupQueryAttentionOnnxExport.apply( + query_states, + key_states, + value_states, + past_keys, + past_values, + past_seq_len, + total_seq_len, + cos, + sin, + 1, # do_rotary + num_kv_heads, + num_heads, + ) + + # Cast back to the residual-stream dtype so the downstream Conv + # (o_proj) sees its expected weight dtype. 
+ if attention_output.dtype != hidden_states.dtype: + attention_output = attention_output.to(hidden_states.dtype) + + if self._matmul_to_conv: + attention_output = attention_output.unsqueeze(0) + + attention_output = self.o_proj(attention_output) + return attention_output, None, (present_keys, present_values) + + +class WinMLQwen3DecoderLayer(nn.Module): + """Decoder-layer export variant — threads KV cache + seq-len kwargs.""" + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None, + past_seq_len: torch.Tensor | None = None, + total_seq_len: torch.Tensor | None = None, + use_cache: bool = True, + **kwargs: Any, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Run the decoder layer (attention + MLP) with residual adds.""" + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + attn_out, _, present_kv = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + past_seq_len=past_seq_len, + total_seq_len=total_seq_len, + ) + hidden_states = residual + attn_out + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if use_cache: + outputs += (present_kv,) + return outputs + + +class WinMLQwen3Model(nn.Module): + """Model export variant — transformer-only body (no embeddings / lm_head).""" + + def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: + """Record whether projections use the 1x1-conv (NHWC) path.""" + self._matmul_to_conv = matmul_to_conv + + def forward( + self, + inputs_embeds: torch.Tensor, + past_key_values: list[tuple[torch.Tensor, torch.Tensor]], + past_seq_len: torch.Tensor, + total_seq_len: torch.Tensor, + use_cache: bool = True, + ) -> tuple[torch.Tensor, tuple[tuple[torch.Tensor, torch.Tensor], ...]]: + """Run the transformer-only body, returning 
hidden states + KV.""" + hidden_states = inputs_embeds + if self._matmul_to_conv: + hidden_states = hidden_states.unsqueeze(0) # NHWC for Conv path + + present_kvs: tuple[tuple[torch.Tensor, torch.Tensor], ...] = () + for idx, layer in enumerate(self.layers): + out = layer( + hidden_states, + past_key_value=past_key_values[idx], + past_seq_len=past_seq_len, + total_seq_len=total_seq_len, + use_cache=use_cache, + ) + hidden_states = out[0] + if use_cache: + present_kvs += (out[1],) + + hidden_states = self.norm(hidden_states) + if self._matmul_to_conv: + hidden_states = hidden_states.squeeze(0) + return hidden_states, present_kvs + + +# ============================================================================= +# Apply export prep: bind winml Qwen3 export methods onto a loaded model +# ============================================================================= + + +def apply_transformer_only_export_prep( + causal_lm: nn.Module, *, matmul_to_conv: bool = True +) -> None: + """Mutate ``Qwen3ForCausalLM`` in-place into the export topology. + + Binds the winml-owned export behaviour (the ``WinMLQwen3*`` classes in this + module) onto each Qwen3 submodule (runs ``prepare_for_onnx_export`` and + rebinds ``forward``). After this call, ``causal_lm.model(inputs_embeds, + past_key_values, past_seq_len, total_seq_len)`` runs the transformer-only + forward. + + Args: + causal_lm: A ``transformers.Qwen3ForCausalLM`` instance. + matmul_to_conv: Swap ``nn.Linear`` projections to 1x1 ``Conv2d`` so + QNN sees them as Conv. + + Raises: + RuntimeError: If any expected Qwen3 submodule class is not found, + meaning the loaded model does not match the expected topology + (e.g. the stock HF class names changed). + """ + + def _bind(module: nn.Module, owner: type) -> None: + module.forward = owner.forward.__get__(module, type(module)) + + # Identify Qwen3 submodules by their (stock HF) class name so we don't + # depend on importing ``transformers.models.qwen3`` here. 
+ def _is(module: nn.Module, name: str) -> bool: + return type(module).__name__ == name + + patched = { + "Qwen3RMSNorm": 0, + "Qwen3Attention": 0, + "Qwen3MLP": 0, + "Qwen3DecoderLayer": 0, + "Qwen3Model": 0, + } + + # Patch every RMSNorm first (Qwen3RMSNorm appears at top, in q_norm/k_norm, + # in input/post_attention layernorms). + for mod in causal_lm.modules(): + if _is(mod, "Qwen3RMSNorm"): + WinMLQwen3RMSNorm.prepare_for_onnx_export(mod) + _bind(mod, WinMLQwen3RMSNorm) + patched["Qwen3RMSNorm"] += 1 + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Attention"): + WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + _bind(mod, WinMLQwen3Attention) + patched["Qwen3Attention"] += 1 + elif _is(mod, "Qwen3MLP"): + # MLP forward is unchanged; only the projections are swapped to Conv. + WinMLQwen3MLP.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + patched["Qwen3MLP"] += 1 + + # HF moved ``rotary_emb`` from ``Qwen3Attention`` up to ``Qwen3Model``; + # the export forward invokes ``self.rotary_emb`` on the attention module, + # so re-attach a reference from the parent model. + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Model") and hasattr(mod, "rotary_emb"): + for layer in mod.layers: + layer.self_attn.rotary_emb = mod.rotary_emb + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3DecoderLayer"): + _bind(mod, WinMLQwen3DecoderLayer) + patched["Qwen3DecoderLayer"] += 1 + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Model"): + WinMLQwen3Model.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + _bind(mod, WinMLQwen3Model) + patched["Qwen3Model"] += 1 + + missing = [name for name, count in patched.items() if count == 0] + if missing: + raise RuntimeError( + "transformer-only export prep found no " + f"{missing} submodule(s) to patch; the loaded model does not match " + "the expected Qwen3 topology (stock HF class names may have changed)." 
+ ) + + +__all__ = [ + "WinMLQwen3Attention", + "WinMLQwen3DecoderLayer", + "WinMLQwen3MLP", + "WinMLQwen3Model", + "WinMLQwen3RMSNorm", + "apply_transformer_only_export_prep", +] diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py new file mode 100644 index 000000000..bff3cc5c7 --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -0,0 +1,383 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Transformer-only ``qwen3`` build variant, registered as a distinct model_type. + +This module registers a self-contained build path under the model_type +``"qwen3_transformer_only"`` (distinct from the stock ``"qwen3"`` path in +``qwen.py``). Selecting it is explicit — pass ``model_type="qwen3_transformer_only"`` +to ``WinMLAutoModel.from_pretrained(...)`` (or the underlying +``generate_hf_build_config(...)``). Both paths coexist; neither overrides the +other, and there is no import-ordering requirement. + +The variant exports two transformer-only ONNX files (a prefill/context graph +and an iteration/decode graph) with this I/O: + + Inputs : past_keys_{i}, past_values_{i} (FP16, ``[1, kv_heads, max_cache, head_dim]``), + input_hidden_states (FP32, ``[1, seq_len, hidden]``), + past_seq_len (INT32, ``[1, 1]``), total_seq_len (INT32, ``[1]``) + Outputs: output_hidden_states (FP32), present_keys_{i}, present_values_{i} (FP16) + Ops : ``com.microsoft::GroupQueryAttention`` (do_rotary=1), + ``onnx::LpNormalization`` (RMSNorm), 1x1 ``Conv`` projections. + +Registration happens at import time via decorators and module-level mappings, +mirroring ``qwen.py``. The aggregating ``models.hf`` package imports this +module so the entries land in ``MODEL_CLASS_MAPPING`` / ``MODEL_BUILD_CONFIGS``. 
+""" + +from __future__ import annotations + +from typing import Any, ClassVar + +import torch +import torch.nn as nn +from optimum.exporters.onnx import OnnxConfig +from optimum.utils import NormalizedConfig +from optimum.utils.input_generators import DummyInputGenerator +from transformers import AutoModelForCausalLM + +from ...config import WinMLBuildConfig +from ...export import register_onnx_overwrite +from ...export.config import WinMLExportConfig +from ..winml import register_specialization +from ..winml.composite_model import register_composite_model +from ..winml.decoder_only import WinMLDecoderOnlyModel +from ..winml.kv_cache import WinMLSlidingWindowCache +from .qwen3_modeling import apply_transformer_only_export_prep + + +# Distinct model_type for this variant. The underscore form is what the +# exporter sees on ``model.config.model_type`` and what Optimum's TasksManager +# and ``register_specialization`` are keyed on; the hyphenated form is used for +# the ``MODEL_CLASS_MAPPING`` / ``MODEL_BUILD_CONFIGS`` lookups (those callers +# normalize ``_`` -> ``-``). +TRANSFORMER_ONLY_MODEL_TYPE = "qwen3_transformer_only" + + +# ============================================================================= +# Wrapper module +# ============================================================================= + + +class QwenTransformerOnlyDecoderWrapper(nn.Module): + """Wraps ``Qwen3ForCausalLM`` for transformer-only export. + + The wrapper applies the export prep (LpNorm RMSNorm, GQA op, 1x1 + Conv projections) in ``__init__`` and exposes a positional ``forward`` + whose argument order matches :class:`QwenTransformerOnlyPrefillIOConfig.inputs`. + Only ``self.model.model`` (the inner ``Qwen3Model``) is invoked at + export time — embedding lookup and ``lm_head`` stay out of the graph. 
+ """ + + def __init__(self, model: nn.Module, num_layers: int) -> None: + super().__init__() + self.model = model + self.num_layers = num_layers + self.config = model.config + apply_transformer_only_export_prep(model, matmul_to_conv=True) + # Tag the config so the exporter resolves this variant's OnnxConfig + # (registered under ``TRANSFORMER_ONLY_MODEL_TYPE``) rather than the + # stock qwen3 one. Mirrors the CLIP/zoedepth sub-model precedent. + self.config.model_type = TRANSFORMER_ONLY_MODEL_TYPE + + @classmethod + def from_pretrained( + cls, model_name_or_path: str, **kwargs: Any + ) -> QwenTransformerOnlyDecoderWrapper: + """Load the HF model and wrap it for transformer-only export.""" + kwargs.setdefault("torch_dtype", torch.float32) + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **kwargs) + model.config._attn_implementation = "eager" + wrapper = cls(model, model.config.num_hidden_layers) + wrapper.eval() + return wrapper + + def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, ...]: + """Flatten the dummy-input dict into positional export args.""" + return tuple(inputs.values()) + + def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Run the decoder stack on positional inputs (order matches OnnxConfig.inputs). + + Positional inputs are ``past_keys_0, past_values_0, ..., + past_keys_{L-1}, past_values_{L-1}, input_hidden_states, past_seq_len, + total_seq_len``. Returns ``(output_hidden_states, present_keys_0, + present_values_0, ...)``. 
+ """ + kv_args = args[: 2 * self.num_layers] + input_hidden_states = args[2 * self.num_layers] + past_seq_len = args[2 * self.num_layers + 1] + total_seq_len = args[2 * self.num_layers + 2] + + past_key_values = [(kv_args[2 * i], kv_args[2 * i + 1]) for i in range(self.num_layers)] + + hidden_states, present_kvs = self.model.model( + inputs_embeds=input_hidden_states, + past_key_values=past_key_values, + past_seq_len=past_seq_len, + total_seq_len=total_seq_len, + use_cache=True, + ) + + out: list[torch.Tensor] = [hidden_states] + for k, v in present_kvs: + out.extend([k, v]) + return tuple(out) + + +# ============================================================================= +# Dummy input generators (transformer-only I/O) +# ============================================================================= + + +class _TransformerOnlyHiddenStateGenerator(DummyInputGenerator): + """Generates ``input_hidden_states`` (FP32, ``[1, seq_len, hidden]``).""" + + SUPPORTED_INPUT_NAMES = ("input_hidden_states",) + + _default_seq_len: ClassVar[int] = 1 + + def __init__( + self, + task: str, + normalized_config: Any, + batch_size: int = 1, + seq_len: int | None = None, + **kwargs: Any, + ) -> None: + self.batch_size = batch_size + self.hidden_size = normalized_config.hidden_size + self.seq_len = seq_len or getattr(normalized_config, "seq_len", self._default_seq_len) + + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ) -> torch.Tensor: + if input_name == "input_hidden_states": + return torch.randn(self.batch_size, self.seq_len, self.hidden_size, dtype=torch.float32) + raise ValueError(f"Unknown input: {input_name}") + + +class _TransformerOnlyHiddenStatePrefillGenerator(_TransformerOnlyHiddenStateGenerator): + _default_seq_len = 64 + + +class _TransformerOnlySeqLenGenerator(DummyInputGenerator): + """Generates ``past_seq_len`` (INT32 ``[1,1]``) and ``total_seq_len`` (INT32 ``[1]``).""" + + 
SUPPORTED_INPUT_NAMES = ("past_seq_len", "total_seq_len") + + def __init__(self, task: str, normalized_config: Any, **kwargs: Any) -> None: + self.max_cache_len = normalized_config.max_cache_len + + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ) -> torch.Tensor: + if input_name == "past_seq_len": + return torch.zeros((1, 1), dtype=torch.int32) + if input_name == "total_seq_len": + return torch.tensor([self.max_cache_len], dtype=torch.int32) + raise ValueError(f"Unknown input: {input_name}") + + +class _TransformerOnlyKvCacheGenerator(DummyInputGenerator): + """Generates ``past_keys_{i}`` / ``past_values_{i}`` (FP16).""" + + SUPPORTED_INPUT_NAMES = () # built dynamically in __init__ + + def __init__( + self, + task: str, + normalized_config: NormalizedConfig, + batch_size: int = 1, + max_cache_len: int | None = None, + **kwargs: Any, + ) -> None: + self.batch_size = batch_size + self.num_layers: int = normalized_config.num_layers + self.num_heads: int = ( + normalized_config.num_attention_heads + ) # KV heads (NormalizedConfig maps it) + self.head_dim: int = normalized_config.head_dim + self.max_cache_len: int = max_cache_len or normalized_config.max_cache_len + self.SUPPORTED_INPUT_NAMES = tuple( + name for i in range(self.num_layers) for name in (f"past_keys_{i}", f"past_values_{i}") + ) + + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ) -> torch.Tensor: + shape = (self.batch_size, self.num_heads, self.max_cache_len, self.head_dim) + return torch.zeros(shape, dtype=torch.float16) + + +# ============================================================================= +# OnnxConfigs — transformer-only I/O layout +# ============================================================================= + + +_QWEN_TRANSFORMER_ONLY_NORMALIZED = NormalizedConfig.with_args( + hidden_size="hidden_size", + 
num_layers="num_hidden_layers", + num_attention_heads="num_key_value_heads", # KV heads (GQA) + head_dim="head_dim", + max_cache_len="max_position_embeddings", + vocab_size="vocab_size", + allow_new=True, +) + + +def _transformer_only_inputs( + num_layers: int, kv_seq_axis: str = "max_seq_len" +) -> dict[str, dict[int, str]]: + """Input ordering: past KV pairs, then hidden states, then seq lens.""" + result: dict[str, dict[int, str]] = {} + for i in range(num_layers): + result[f"past_keys_{i}"] = {2: kv_seq_axis} + result[f"past_values_{i}"] = {2: kv_seq_axis} + result["input_hidden_states"] = {1: "seq_len"} + result["past_seq_len"] = {} + result["total_seq_len"] = {} + return result + + +def _transformer_only_outputs( + num_layers: int, kv_seq_axis: str = "max_seq_len" +) -> dict[str, dict[int, str]]: + result: dict[str, dict[int, str]] = {"output_hidden_states": {1: "seq_len"}} + for i in range(num_layers): + result[f"present_keys_{i}"] = {2: kv_seq_axis} + result[f"present_values_{i}"] = {2: kv_seq_axis} + return result + + +@register_onnx_overwrite( + TRANSFORMER_ONLY_MODEL_TYPE, "feature-extraction", library_name="transformers" +) +class QwenTransformerOnlyPrefillIOConfig(OnnxConfig): + """Prefill (seq=64) — transformer-only I/O.""" + + NORMALIZED_CONFIG_CLASS = _QWEN_TRANSFORMER_ONLY_NORMALIZED + DUMMY_INPUT_GENERATOR_CLASSES = ( + _TransformerOnlyKvCacheGenerator, + _TransformerOnlyHiddenStatePrefillGenerator, + _TransformerOnlySeqLenGenerator, + ) + + @property + def inputs(self) -> dict[str, dict[int, str]]: + """ONNX input axes (past KV pairs, hidden states, seq lengths).""" + return _transformer_only_inputs(self._normalized_config.num_layers) + + @property + def outputs(self) -> dict[str, dict[int, str]]: + """ONNX output axes (hidden states then present KV pairs).""" + return _transformer_only_outputs(self._normalized_config.num_layers) + + +@register_onnx_overwrite( + TRANSFORMER_ONLY_MODEL_TYPE, "text2text-generation", library_name="transformers" +) 
+class QwenTransformerOnlyGenIOConfig(OnnxConfig): + """Generation (seq=1) — transformer-only I/O.""" + + NORMALIZED_CONFIG_CLASS = _QWEN_TRANSFORMER_ONLY_NORMALIZED + DUMMY_INPUT_GENERATOR_CLASSES = ( + _TransformerOnlyKvCacheGenerator, + _TransformerOnlyHiddenStateGenerator, + _TransformerOnlySeqLenGenerator, + ) + + @property + def inputs(self) -> dict[str, dict[int, str]]: + """ONNX input axes (past KV pairs, hidden states, seq lengths).""" + return _transformer_only_inputs(self._normalized_config.num_layers) + + @property + def outputs(self) -> dict[str, dict[int, str]]: + """ONNX output axes (hidden states then present KV pairs).""" + return _transformer_only_outputs(self._normalized_config.num_layers) + + +# ============================================================================= +# Build config — TorchScript exporter required for the custom autograd ops +# ============================================================================= + + +QWEN_TRANSFORMER_ONLY_CONFIG = WinMLBuildConfig( + export=WinMLExportConfig(dynamo=False, opset_version=18), + # Pure graph (no post-export RMSNorm fusion / matmul-add fusion). + optim=None, +) + + +# ============================================================================= +# Composite inference wrapper (placeholder so the build pipeline finds a +# composite class — generation isn't yet wired for the transformer-only +# I/O signature). +# ============================================================================= + + +@register_composite_model(TRANSFORMER_ONLY_MODEL_TYPE, "text-generation") +class WinMLQwen3TransformerOnlyModel(WinMLDecoderOnlyModel): + """Composite handle for the transformer-only Qwen3 build (export only). + + ``generate()`` is **not** functional with this build path — the inference + feeds and KV update logic still target the eager I/O signature. Use the + eager :class:`WinMLQwen3Model` for generation; use this class to produce + the transformer-only ONNX for downstream quantization. 
+ """ + + _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = { + "decoder_prefill": "feature-extraction", + "decoder_gen": "text2text-generation", + } + + @classmethod + def get_cache_class(cls) -> type: + """Return the KV-cache class used during generation.""" + return WinMLSlidingWindowCache + + +# ============================================================================= +# Declarative registration (import-time) +# ============================================================================= + +# Wrapper-class lookup keyed by (model_type, task). Keys use the hyphenated +# model_type form because ``_get_custom_model_class`` normalizes ``_`` -> ``-`` +# before lookup. Merged into the aggregate mapping by ``models.hf.__init__``. +MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { + ("qwen3-transformer-only", "feature-extraction"): QwenTransformerOnlyDecoderWrapper, + ("qwen3-transformer-only", "text2text-generation"): QwenTransformerOnlyDecoderWrapper, +} + +# Inference specialization (GenericTask — the wrapper returns raw hidden states / KV). 
+register_specialization( + TRANSFORMER_ONLY_MODEL_TYPE, "feature-extraction", "WinMLModelForGenericTask" +) +register_specialization( + TRANSFORMER_ONLY_MODEL_TYPE, "text2text-generation", "WinMLModelForGenericTask" +) + + +__all__ = [ + "MODEL_CLASS_MAPPING", + "QWEN_TRANSFORMER_ONLY_CONFIG", + "TRANSFORMER_ONLY_MODEL_TYPE", + "QwenTransformerOnlyDecoderWrapper", + "QwenTransformerOnlyGenIOConfig", + "QwenTransformerOnlyPrefillIOConfig", + "WinMLQwen3TransformerOnlyModel", +] diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index bc8e6ee06..e43a69068 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -16,7 +16,7 @@ result = quantize_onnx("model.onnx", WinMLQuantizationConfig(samples=100)) """ -from typing import Any +from typing import TYPE_CHECKING, Any from .config import QuantizeResult, WinMLQuantizationConfig @@ -24,12 +24,25 @@ __all__ = [ "QuantizeResult", "WinMLQuantizationConfig", + "get_quant_finalizer", "quantize_onnx", + "register_quant_finalizer", ] +# Names below are loaded lazily via ``__getattr__`` to avoid pulling in +# onnxruntime.quantization/torch at import time. The TYPE_CHECKING re-imports +# give static analyzers (mypy, CodeQL) visibility into what ``__all__`` exports +# without triggering the heavy imports at runtime. 
+if TYPE_CHECKING: + from .calibration import get_quant_finalizer, register_quant_finalizer + from .quantizer import quantize_onnx + + _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "quantize_onnx": (".quantizer", "quantize_onnx"), + "get_quant_finalizer": (".calibration", "get_quant_finalizer"), + "register_quant_finalizer": (".calibration", "register_quant_finalizer"), } diff --git a/src/winml/modelkit/quant/calibration/__init__.py b/src/winml/modelkit/quant/calibration/__init__.py new file mode 100644 index 000000000..88b1434c5 --- /dev/null +++ b/src/winml/modelkit/quant/calibration/__init__.py @@ -0,0 +1,23 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Model-type-specific quantization policies (calibration readers + schemes). + +This subpackage stays import-light on purpose: it exposes only the registry +API. The individual finalizer modules (which pull in torch/transformers) are +imported lazily by :func:`get_quant_finalizer` when their ``model_type`` is +quantized. +""" + +from __future__ import annotations + +from .base import QuantConfigFinalizer +from .registry import get_quant_finalizer, register_quant_finalizer + + +__all__ = [ + "QuantConfigFinalizer", + "get_quant_finalizer", + "register_quant_finalizer", +] diff --git a/src/winml/modelkit/quant/calibration/base.py b/src/winml/modelkit/quant/calibration/base.py new file mode 100644 index 000000000..d62ba4322 --- /dev/null +++ b/src/winml/modelkit/quant/calibration/base.py @@ -0,0 +1,41 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +"""Base protocol for model-type-specific quantization policies.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + + +if TYPE_CHECKING: + from pathlib import Path + + from ..config import WinMLQuantizationConfig + + +@runtime_checkable +class QuantConfigFinalizer(Protocol): + """Model-type-specific quant policy. + + Given the freshly exported ONNX, a finalizer populates the live + :class:`WinMLQuantizationConfig` with the fields that can only be known + once the graph exists — the calibration data reader, ``nodes_to_exclude``, + and (where the scheme is fixed and reference-matched) the dtype/symmetry + settings. + + Finalizers are registered per ``model_type`` (see + :func:`.registry.register_quant_finalizer`). Model types without a + registered policy fall back to the quantizer's default + ``DatasetCalibrationReader``. + """ + + def finalize( + self, + quant: WinMLQuantizationConfig, + *, + onnx_path: Path, + model_id: str | None = None, + ) -> WinMLQuantizationConfig: + """Return ``quant`` populated with the graph-derived quant settings.""" diff --git a/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py new file mode 100644 index 000000000..5abb7e4ce --- /dev/null +++ b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py @@ -0,0 +1,446 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Config-driven w8a16 calibration for the transformer-only Qwen3 build. 
+ +The transformer-only export (``models.hf.qwen_transformer_only``) emits a graph +whose only quantization-relevant runtime inputs (the calibration feeds and the +``GroupQueryAttention`` node names to keep in float) can't be known until the +ONNX exists. Rather than a standalone post-build script that reaches into +``composite.sub_models[...]._onnx_path``, this module registers a quant policy +keyed on ``model_type`` (:class:`Qwen3TransformerOnlyQuantFinalizer`). The build +pipeline resolves it via :func:`~winml.modelkit.quant.get_quant_finalizer` and +calls :func:`finalize_transformer_only_quant_config` just before +``quantize_onnx`` runs (see ``build/hf.py``), populating the live +:class:`WinMLQuantizationConfig` with the right +:class:`~winml.modelkit.quant.config.CalibrationDataReader` and +``nodes_to_exclude``. + +The two readers match the exported graph exactly: + + - ``input_hidden_states`` (FP32), ``past_seq_len`` / ``total_seq_len`` + (INT32), ``past_keys_{i}`` / ``past_values_{i}`` (FP16, full cache buffer). + - The prefill reader (``seq_len > 1``) embeds real prompt prefixes. + - The decode reader (``seq_len == 1``) drives a fresh FP reference model + through a real prefill + decode trajectory so MinMax sees representative + mid-generation activation ranges (a single repeated token + zeroed KV + collapses the ranges and degenerates generation). + +The export wrapper surgically replaces its own ``self.model`` (RMSNorm -> +LpNorm-identity, attention -> GQA placeholder, Linear -> 1x1 Conv), so it can't +run real inference; calibration loads a *fresh* ``AutoModelForCausalLM``. 
+""" + +from __future__ import annotations + +import gc +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from ..config import CalibrationDataReader, WinMLQuantizationConfig +from .registry import register_quant_finalizer + + +if TYPE_CHECKING: + from collections.abc import Iterator + + +logger = logging.getLogger(__name__) + +DEFAULT_MODEL_ID = "Qwen/Qwen3-0.6B" +DEFAULT_PREFILL_SEQ = 64 +DEFAULT_GEN_SEQ = 1 +DEFAULT_NUM_SAMPLES = 30 +DEFAULT_DECODE_STEPS = 16 +DEFAULT_CALIB_DATASET = "openai/gsm8k" +DEFAULT_CALIB_DATASET_CONFIG = "main" +DEFAULT_CALIB_SPLIT = "train" +DEFAULT_CALIB_SEED = 42 + + +def _load_gsm8k_prompts(num_samples: int) -> list[str]: + """GSM8K train split, shuffled seed=42 for reproducible calibration.""" + from datasets import load_dataset + + ds = load_dataset(DEFAULT_CALIB_DATASET, DEFAULT_CALIB_DATASET_CONFIG) + split = ds[DEFAULT_CALIB_SPLIT].shuffle(seed=DEFAULT_CALIB_SEED) + return [row["question"] for row in split.select(range(num_samples))] + + +def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> list[torch.Tensor]: + out: list[torch.Tensor] = [] + for i in range(num_samples): + prompt = prompts[i % len(prompts)] + text = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ids = tokenizer([text], return_tensors="pt").input_ids + out.append(ids) + return out + + +def _gqa_node_names(onnx_path: Path) -> list[str]: + """Return the names of every GroupQueryAttention node in ``onnx_path``. + + These nodes are excluded from quantization so ORT leaves both their + inputs and output in float (``... -> Cast -> GQA -> Cast``), matching + the reference graph which keeps attention entirely out of QDQ. 
+ """ + import onnx + + model = onnx.load(str(onnx_path), load_external_data=False) + return [n.name for n in model.graph.node if n.op_type == "GroupQueryAttention" and n.name] + + +def _graph_shapes(onnx_path: Path) -> tuple[int, int]: + """Read ``(seq_len, max_cache_len)`` from the exported graph's static inputs. + + ``seq_len`` is the query length (``input_hidden_states`` dim 1) and + ``max_cache_len`` is the KV buffer length (``past_keys_0`` dim 2). The + transformer-only export keeps both axes static, so these fully determine + whether the sub-model is prefill (``seq_len > 1``) or decode (``seq_len == 1``) + and the size of the fixed KV buffers the calibration feeds must match. + """ + import onnx + + model = onnx.load(str(onnx_path), load_external_data=False) + seq_len: int | None = None + max_cache_len: int | None = None + for inp in model.graph.input: + dims = inp.type.tensor_type.shape.dim + if inp.name == "input_hidden_states" and len(dims) >= 2: + seq_len = dims[1].dim_value + elif inp.name == "past_keys_0" and len(dims) >= 3: + max_cache_len = dims[2].dim_value + # A symbolic/dynamic axis yields dim_value == 0 (not None), so treat any + # non-positive value as "not a usable static shape" and fail loudly rather + # than silently building zero-length calibration feeds. + if not seq_len or not max_cache_len: + raise ValueError( + f"Could not read static seq_len/max_cache_len from {onnx_path.name}; " + f"found seq_len={seq_len}, max_cache_len={max_cache_len}" + ) + return seq_len, max_cache_len + + +def _layer_kv(past: Any, i: int) -> tuple[torch.Tensor, torch.Tensor]: + """Extract layer ``i``'s (key, value) from an HF cache, version-agnostic. + + Handles the legacy tuple-of-tuples cache, the older ``DynamicCache`` + (``.key_cache`` / ``.value_cache``), and the newer per-layer + ``DynamicCache`` (``.layers[i].keys`` / ``.values``). 
+ """ + if hasattr(past, "key_cache") and hasattr(past, "value_cache"): + return past.key_cache[i], past.value_cache[i] + if hasattr(past, "layers"): + layer = past.layers[i] + return layer.keys, layer.values + return past[i][0], past[i][1] + + +class Qwen3TransformerOnlyCalibReader(CalibrationDataReader): + """Prefill calibration feeds for the transformer-only ONNX. + + Feeds match the exported graph exactly: ``input_hidden_states`` (FP32), + ``past_seq_len`` (INT32 ``[1,1]``), ``total_seq_len`` (INT32 ``[1]``), + and ``past_keys_{i}`` / ``past_values_{i}`` (FP16, full cache buffer). + """ + + def __init__( + self, + embed_tokens: torch.nn.Module, + config: Any, + token_ids_list: list[torch.Tensor], + *, + seq_len: int, + max_cache_len: int, + ) -> None: + self.embed = embed_tokens + self.seq_len = seq_len + self.max_cache_len = max_cache_len + self.num_layers = config.num_hidden_layers + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self._samples = list(self._build_samples(token_ids_list)) + self._iter: Iterator[dict[str, np.ndarray]] | None = None + self.rewind() + + def _build_samples(self, token_ids_list: list[torch.Tensor]) -> Iterator[dict[str, np.ndarray]]: + for ids in token_ids_list: + ids = ids[:, : self.seq_len] + real_len = ids.shape[1] + if real_len < self.seq_len: + pad = torch.zeros((1, self.seq_len - real_len), dtype=ids.dtype, device=ids.device) + ids = torch.cat([ids, pad], dim=1) + + with torch.no_grad(): + embeds = self.embed(ids).to(torch.float32).cpu().numpy() + + feed: dict[str, np.ndarray] = { + "input_hidden_states": embeds.astype(np.float32), + # seqlens_k for GQA = (valid context length - 1), i.e. + # ``embeddings.shape[1] - 1``. We pad to seq_len, so the query + # has seq_len valid positions -> past_seq_len = seq_len - 1. 
+ # (Using 0 here declares only 1 valid token while feeding a + # seq_len-token query, which makes the GQA prefill kernel read + # out of bounds -> native access violation.) + "past_seq_len": np.array([[self.seq_len - 1]], dtype=np.int32), + "total_seq_len": np.array([self.max_cache_len], dtype=np.int32), + } + kv_shape = (1, self.num_kv_heads, self.max_cache_len, self.head_dim) + zeros = np.zeros(kv_shape, dtype=np.float16) + for i in range(self.num_layers): + feed[f"past_keys_{i}"] = zeros + feed[f"past_values_{i}"] = zeros + yield feed + + def get_next(self) -> dict[str, np.ndarray] | None: + """Return the next calibration feed, or None when exhausted.""" + try: + return next(self._iter) if self._iter is not None else None + except StopIteration: + return None + + def rewind(self) -> None: + """Reset the iterator so calibration can run another pass.""" + self._iter = iter(self._samples) + + +class Qwen3DecodeTrajectoryCalibReader(CalibrationDataReader): + """Calibrate the iter (seq_len=1) model on REAL decode-step states. + + The naive reader feeds one (repeated) token with a zeroed KV cache and + ``past_seq_len=0`` — a state the model never sees during generation. With + MinMax calibration this collapses the observed activation ranges far below + the real decode distribution, so the resulting w8a16 model degenerates + (e.g. ``Paris -> Parisammedammed...``). + + Instead, drive the HF FP reference model through a real prefill + decode + trajectory and capture, at each decode step, the exact feed the iter ONNX + would receive: the embedding of the *actually generated* token, the real + accumulated KV cache (copied into the fixed ``[1, kv_heads, max_cache, + head_dim]`` FP16 buffer), and the growing ``past_seq_len``. Token + selection uses the HF model's true logits, so the trajectory matches + greedy generation. The QDQ scheme is unchanged — only the calibration + statistics become representative. 
+ """ + + def __init__( + self, + hf_model: torch.nn.Module, + embed_tokens: torch.nn.Module, + config: Any, + token_ids_list: list[torch.Tensor], + *, + prefill_seq: int, + max_cache_len: int, + decode_steps: int = DEFAULT_DECODE_STEPS, + ) -> None: + self.num_layers = config.num_hidden_layers + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.max_cache_len = max_cache_len + self._samples = list( + self._build_samples( + hf_model, + embed_tokens, + token_ids_list, + prefill_seq=prefill_seq, + decode_steps=decode_steps, + ) + ) + self._iter: Iterator[dict[str, np.ndarray]] | None = None + self.rewind() + + def _kv_buffers(self, past: Any, cur_len: int) -> dict[str, np.ndarray]: + """Copy the ``cur_len`` valid KV positions into fixed FP16 buffers.""" + feed: dict[str, np.ndarray] = {} + for i in range(self.num_layers): + k, v = _layer_kv(past, i) + kbuf = np.zeros((1, self.num_kv_heads, self.max_cache_len, self.head_dim), np.float16) + vbuf = np.zeros_like(kbuf) + kbuf[:, :, :cur_len, :] = k[:, :, :cur_len, :].to(torch.float16).cpu().numpy() + vbuf[:, :, :cur_len, :] = v[:, :, :cur_len, :].to(torch.float16).cpu().numpy() + feed[f"past_keys_{i}"] = kbuf + feed[f"past_values_{i}"] = vbuf + return feed + + def _build_samples( + self, + hf_model: torch.nn.Module, + embed_tokens: torch.nn.Module, + token_ids_list: list[torch.Tensor], + *, + prefill_seq: int, + decode_steps: int, + ) -> Iterator[dict[str, np.ndarray]]: + for ids in token_ids_list: + ids = ids[:, :prefill_seq] # real prompt prefix (no pad-token KV) + cur_len = ids.shape[1] + + # FP prefill once to seed a realistic KV cache + first token. 
+ with torch.no_grad(): + out = hf_model(input_ids=ids, use_cache=True) + past = out.past_key_values + tok = int(out.logits[:, -1, :].argmax(-1)) + + for _ in range(decode_steps): + if cur_len >= self.max_cache_len: + break + # The feed the iter model sees for THIS token: embedding of the + # token to process, the KV of the `cur_len` preceding tokens, + # and seqlens_k = (cur_len + 1) - 1 = cur_len. + with torch.no_grad(): + emb = embed_tokens(torch.tensor([[tok]])).to(torch.float32).cpu().numpy() + feed: dict[str, np.ndarray] = { + "input_hidden_states": emb.astype(np.float32), + "past_seq_len": np.array([[cur_len]], dtype=np.int32), + "total_seq_len": np.array([self.max_cache_len], dtype=np.int32), + } + feed.update(self._kv_buffers(past, cur_len)) + yield feed + + # Advance the reference model one real decode step. + with torch.no_grad(): + out = hf_model( + input_ids=torch.tensor([[tok]]), + past_key_values=past, + use_cache=True, + ) + past = out.past_key_values + tok = int(out.logits[:, -1, :].argmax(-1)) + cur_len += 1 + + def get_next(self) -> dict[str, np.ndarray] | None: + """Return the next calibration feed, or None when exhausted.""" + try: + return next(self._iter) if self._iter is not None else None + except StopIteration: + return None + + def rewind(self) -> None: + """Reset the iterator so calibration can run another pass.""" + self._iter = iter(self._samples) + + +def finalize_transformer_only_quant_config( + quant: WinMLQuantizationConfig, + *, + onnx_path: Path, + model_id: str = DEFAULT_MODEL_ID, + prefill_seq: int = DEFAULT_PREFILL_SEQ, + decode_steps: int = DEFAULT_DECODE_STEPS, +) -> WinMLQuantizationConfig: + """Populate ``quant`` with the transformer-only w8a16 scheme + runtime fields. 
+ + The build pipeline's device/precision policy only enables quantization and + picks generic dtypes; the transformer-only scheme is fixed and reference- + matched, so this hook is authoritative: + + - **int8-symmetric weights** (zp=0) + **uint16 asymmetric activations**, + - **MinMax** calibration, + - GroupQueryAttention nodes excluded from QDQ (read from the graph), + - the matching :class:`CalibrationDataReader` (prefill vs. decode-trajectory, + chosen by the graph's ``seq_len``). + + Reads static shapes + GQA nodes from ``onnx_path`` and loads a fresh FP + reference model for calibration (the export wrapper's own weights are + surgically replaced and can't run real inference). + """ + onnx_path = Path(onnx_path) + seq_len, max_cache_len = _graph_shapes(onnx_path) + gqa_nodes = _gqa_node_names(onnx_path) + + # Fixed, reference-matched w8a16 scheme (authoritative over policy dtypes). + quant.weight_type = "int8" + quant.activation_type = "uint16" + quant.weight_symmetric = True + quant.activation_symmetric = False + quant.calibration_method = "minmax" + num_samples = quant.samples or DEFAULT_NUM_SAMPLES + + logger.info( + "Finalizing transformer-only quant config for %s " + "(seq_len=%d, max_cache_len=%d, %d GQA nodes excluded, %d samples)", + onnx_path.name, + seq_len, + max_cache_len, + len(gqa_nodes), + num_samples, + ) + + hf_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32) + hf_model.eval() + embed_tokens = hf_model.get_input_embeddings() + tokenizer = AutoTokenizer.from_pretrained(model_id) + prompts = _load_gsm8k_prompts(num_samples) + token_ids_list = _tokenize_prompts(tokenizer, prompts, num_samples) + + reader: CalibrationDataReader + if seq_len == DEFAULT_GEN_SEQ: + # Decode sub-model: calibrate on a real prefill+decode trajectory. 
+ reader = Qwen3DecodeTrajectoryCalibReader( + hf_model, + embed_tokens, + hf_model.config, + token_ids_list, + prefill_seq=prefill_seq, + max_cache_len=max_cache_len, + decode_steps=decode_steps, + ) + else: + reader = Qwen3TransformerOnlyCalibReader( + embed_tokens, + hf_model.config, + token_ids_list, + seq_len=seq_len, + max_cache_len=max_cache_len, + ) + + quant.calibration_data = reader + quant.nodes_to_exclude = gqa_nodes + + # Readers materialize all samples eagerly, so the FP reference is no longer + # needed once they're built. + del hf_model, embed_tokens + gc.collect() + + return quant + + +@register_quant_finalizer("qwen3_transformer_only") +class Qwen3TransformerOnlyQuantFinalizer: + """Registered quant policy for the ``qwen3_transformer_only`` model_type. + + Adapts :func:`finalize_transformer_only_quant_config` to the + :class:`~winml.modelkit.quant.calibration.base.QuantConfigFinalizer` + protocol so the build pipeline resolves the model-specific w8a16 scheme + + calibration reader through the quant registry (keyed on ``model_type``) + rather than a hardcoded hook on the export wrapper. + """ + + def finalize( + self, + quant: WinMLQuantizationConfig, + *, + onnx_path: Path, + model_id: str | None = None, + ) -> WinMLQuantizationConfig: + """Populate ``quant`` with the transformer-only w8a16 scheme + reader.""" + return finalize_transformer_only_quant_config( + quant, onnx_path=onnx_path, model_id=model_id or DEFAULT_MODEL_ID + ) diff --git a/src/winml/modelkit/quant/calibration/registry.py b/src/winml/modelkit/quant/calibration/registry.py new file mode 100644 index 000000000..78b321ae4 --- /dev/null +++ b/src/winml/modelkit/quant/calibration/registry.py @@ -0,0 +1,75 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
import importlib
from typing import TYPE_CHECKING


if TYPE_CHECKING:
    from collections.abc import Callable

    from .base import QuantConfigFinalizer


# Populated by the ``@register_quant_finalizer`` decorator at import time.
_QUANT_FINALIZER_REGISTRY: "dict[str, type[QuantConfigFinalizer]]" = {}

# ``model_type`` -> submodule that defines (and self-registers) its finalizer.
# Looked up lazily so the heavy module loads only when needed. Keys must match
# the live ``model_type`` string verbatim (no ``_`` -> ``-`` normalization),
# since lookup is keyed on the exported model's ``config.model_type``.
_KNOWN_FINALIZER_MODULES: "dict[str, str]" = {
    "qwen3_transformer_only": ".qwen3_transformer_only",
}


def register_quant_finalizer(model_type: str) -> "Callable[[type], type]":
    """Class decorator registering a finalizer class for ``model_type``.

    Args:
        model_type: Registry key. Must be a non-empty string, because
            :func:`get_quant_finalizer` treats every falsy key as "no policy"
            and such an entry could never be resolved.

    Returns:
        A decorator that stores the class in the registry and returns it
        unchanged.

    Raises:
        ValueError: If ``model_type`` is empty / not a string, or (when the
            decorator is applied) a *different* class is already registered
            under ``model_type``.
        TypeError: When the decorated class has no callable ``finalize``.
    """
    if not isinstance(model_type, str) or not model_type:
        raise ValueError(
            f"Quant finalizer model_type must be a non-empty string, got {model_type!r}."
        )

    def decorator(cls: type) -> type:
        # ``callable`` (not just ``hasattr``) so a stray non-method attribute
        # named ``finalize`` is rejected at registration time.
        if not callable(getattr(cls, "finalize", None)):
            raise TypeError(
                f"{cls.__name__} cannot register as a quant finalizer for "
                f"{model_type!r}: it must define a ``finalize`` method."
            )
        existing = _QUANT_FINALIZER_REGISTRY.get(model_type)
        if existing is not None and existing is not cls:
            # Re-executing the defining module (e.g. ``importlib.reload``)
            # produces a new class object with the same module/qualname; treat
            # that as a refresh rather than a conflicting registration.
            same_definition = (existing.__module__, existing.__qualname__) == (
                cls.__module__,
                cls.__qualname__,
            )
            if not same_definition:
                raise ValueError(
                    f"Quant finalizer already registered for {model_type!r}: "
                    f"{existing.__name__}. "
                    f"Cannot register {cls.__name__}."
                )
        _QUANT_FINALIZER_REGISTRY[model_type] = cls
        return cls

    return decorator


def get_quant_finalizer(model_type: "str | None") -> "QuantConfigFinalizer | None":
    """Return a new finalizer instance for ``model_type``, or ``None``.

    ``None`` means "no model-specific policy" — the quantizer then uses its
    standard task-aware ``DatasetCalibrationReader``.

    Args:
        model_type: Registry key; ``None`` or ``""`` always yields ``None``.

    Returns:
        An instance of the registered class (constructed with no arguments),
        or ``None`` when the type is neither registered nor listed in
        ``_KNOWN_FINALIZER_MODULES``.

    Raises:
        RuntimeError: If the module listed for ``model_type`` imports cleanly
            but does not register a finalizer under that key. Returning
            ``None`` there would silently select the default reader.
    """
    if not model_type:
        return None

    cls = _QUANT_FINALIZER_REGISTRY.get(model_type)
    if cls is None:
        module_name = _KNOWN_FINALIZER_MODULES.get(model_type)
        if module_name is None:
            return None
        # Triggers the module's ``@register_quant_finalizer`` side effect.
        importlib.import_module(module_name, __package__)
        cls = _QUANT_FINALIZER_REGISTRY.get(model_type)
        if cls is None:
            raise RuntimeError(
                f"Module {module_name!r} was imported for model_type "
                f"{model_type!r} but did not register a quant finalizer under "
                f"that key."
            )
    return cls()
+ weight_symmetric: bool | None = None + activation_symmetric: bool | None = None # Output settings save_calibration: bool = False @@ -98,6 +103,8 @@ def to_dict(self) -> dict: "activation_type": self.activation_type, "per_channel": self.per_channel, "symmetric": self.symmetric, + "weight_symmetric": self.weight_symmetric, + "activation_symmetric": self.activation_symmetric, "save_calibration": self.save_calibration, "distribution": self.distribution, "seed": self.seed, @@ -139,6 +146,8 @@ def from_dict(cls, data: dict) -> WinMLQuantizationConfig: activation_type=data.get("activation_type", "uint8"), per_channel=data.get("per_channel", False), symmetric=data.get("symmetric", False), + weight_symmetric=data.get("weight_symmetric"), + activation_symmetric=data.get("activation_symmetric"), save_calibration=data.get("save_calibration", False), distribution=data.get("distribution", "uniform"), seed=data.get("seed"), diff --git a/src/winml/modelkit/quant/quantizer.py b/src/winml/modelkit/quant/quantizer.py index c562599de..e5fd30df3 100644 --- a/src/winml/modelkit/quant/quantizer.py +++ b/src/winml/modelkit/quant/quantizer.py @@ -132,10 +132,23 @@ def quantize_onnx( activation_type = activation_type_map[config.activation_type] calibrate_method = calibration_method_map[config.calibration_method] - # Build extra options + # Build extra options. Weight/activation symmetry can be controlled + # independently (e.g. w8a16 = symmetric int8 weights + asymmetric + # uint16 activations); fall back to the single ``symmetric`` flag when + # the per-target override is unset. 
+ weight_symmetric = ( + config.weight_symmetric + if config.weight_symmetric is not None + else config.symmetric + ) + activation_symmetric = ( + config.activation_symmetric + if config.activation_symmetric is not None + else config.symmetric + ) extra_options = { - "ActivationSymmetric": config.symmetric, - "WeightSymmetric": config.symmetric, + "ActivationSymmetric": activation_symmetric, + "WeightSymmetric": weight_symmetric, } # Step 1: Generate QDQ config diff --git a/tests/e2e/models/test_qwen3_transformer_only_quant.py b/tests/e2e/models/test_qwen3_transformer_only_quant.py new file mode 100644 index 000000000..831a640e8 --- /dev/null +++ b/tests/e2e/models/test_qwen3_transformer_only_quant.py @@ -0,0 +1,257 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""End-to-end coverage for the transformer-only Qwen3 w8a16 build. + +Replaces the former root-level ``test_qwen.py`` / ``qwen3_transformer_only_quantize.py`` +scripts. Quantization is now driven entirely through the standard build +pipeline (``WinMLAutoModel.from_pretrained(..., precision="w8a16")``): the +device/precision policy enables the quantize stage, and the +``qwen3_transformer_only`` quant policy registered in +``winml.modelkit.quant.calibration`` (resolved via ``get_quant_finalizer``) +finalizes the reference-matched scheme (int8-symmetric weights, uint16 +activations, GroupQueryAttention excluded from QDQ) plus the decode-trajectory +calibration reader. + +These tests download Qwen3-0.6B from HuggingFace and run a full CPU export + +quantize, so they are gated behind ``slow`` + ``network`` and excluded from the +default lane. The QNN/NPU build is additionally gated on a real NPU. + +All expectations are generated in-code (FP reference greedy decode), never +hardcoded from a prior model run. 
def _qnn_available() -> bool:
    """True when ONNX Runtime exposes the QNN execution provider (real NPU)."""
    return "QNNExecutionProvider" in ort.get_available_providers()


def _decoder_onnx_path(model) -> str:
    """Locate the quantized decode ONNX behind the model handle.

    The decode-only build (``seq_len=1``) returns a single
    ``WinMLModelForGenericTask`` whose ``onnx_path`` is the quantized graph; a
    full composite build instead exposes it under ``sub_models["decoder_gen"]``.
    Handle both so the test does not depend on which wrapper the build picks.

    Returns:
        The path as ``str``: the ``decoder_gen`` sub-model's ``onnx_path`` when
        ``model.sub_models`` is non-empty and has that key, otherwise
        ``model.onnx_path``.
    """
    sub_models = getattr(model, "sub_models", None)
    if sub_models and "decoder_gen" in sub_models:
        return str(sub_models["decoder_gen"].onnx_path)
    return str(model.onnx_path)


def _count_op_types(nodes) -> dict[str, int]:
    """Return ``{op_type: occurrences}`` for an iterable of graph nodes."""
    from collections import Counter

    # Converted back to a plain dict so callers keep getting the same type.
    return dict(Counter(node.op_type for node in nodes))


def _qdq_counts(onnx_path: str) -> dict[str, int]:
    """Count nodes per ``op_type`` in the ONNX file at ``onnx_path``.

    External tensor data is not loaded; only the graph structure is read.
    """
    graph = onnx.load(onnx_path, load_external_data=False).graph
    return _count_op_types(graph.node)


def _gqa_tensor_set(graph) -> set[str]:
    """Return the names of all tensors wired to a ``GroupQueryAttention`` node.

    Empty names are skipped: ONNX uses ``""`` as the placeholder for an omitted
    optional input, so keeping it would make every other node that also omits
    an optional input look as if it shared a tensor with GQA.
    """
    tensors: set[str] = set()
    for node in graph.node:
        if node.op_type != "GroupQueryAttention":
            continue
        tensors.update(name for name in node.input if name)
        tensors.update(name for name in node.output if name)
    return tensors
def _carry_kv(kv: dict[str, np.ndarray], out: dict[str, np.ndarray], num_layers: int) -> None:
    """Feed each layer's ``present_*`` output back as the next ``past_*`` input.

    Mutates ``kv`` in place; entries of ``out`` are stored by reference.
    """
    for i in range(num_layers):
        kv[f"past_keys_{i}"] = out[f"present_keys_{i}"]
        kv[f"past_values_{i}"] = out[f"present_values_{i}"]


def _seed_kv_from_fp(past, num_layers, num_kv_heads, head_dim, cur_len, max_cache_len=None):
    """Copy an HF FP prefill cache into the decode model's fixed FP16 buffers.

    Args:
        past: Either an indexable of per-layer ``(keys, values)`` pairs, or an
            object with a ``layers`` sequence whose items expose ``keys`` and
            ``values`` (newer per-layer DynamicCache).
        num_layers: Number of layers to copy.
        num_kv_heads: Size of the head axis of each buffer.
        head_dim: Size of the last axis of each buffer.
        cur_len: Number of leading cache positions to copy; the remainder of
            each buffer stays zero.
        max_cache_len: Length of the cache axis of each buffer. Defaults to the
            module-level ``MAX_CACHE`` when ``None``.

    Returns:
        ``{"past_keys_<i>": ndarray, "past_values_<i>": ndarray}`` with float16
        arrays of shape ``(1, num_kv_heads, max_cache_len, head_dim)``.

    Raises:
        ValueError: If ``cur_len`` is negative or exceeds ``max_cache_len``.
    """
    if max_cache_len is None:
        max_cache_len = MAX_CACHE
    if not 0 <= cur_len <= max_cache_len:
        raise ValueError(
            f"cur_len={cur_len} does not fit a cache of length {max_cache_len}"
        )

    # The cache flavour does not change between layers, so decide it once.
    per_layer_cache = hasattr(past, "layers")
    buffer_shape = (1, num_kv_heads, max_cache_len, head_dim)

    kv: dict[str, np.ndarray] = {}
    for i in range(num_layers):
        if per_layer_cache:  # newer per-layer DynamicCache
            layer = past.layers[i]
            k, v = layer.keys, layer.values
        else:
            layer = past[i]
            k, v = layer[0], layer[1]
        for name, tensor in ((f"past_keys_{i}", k), (f"past_values_{i}", v)):
            buf = np.zeros(buffer_shape, np.float16)
            buf[:, :, :cur_len, :] = tensor[:, :, :cur_len, :].to(torch.float16).cpu().numpy()
            kv[name] = buf
    return kv
+ """ + onnx_path = _decoder_onnx_path(decode_quant_model) + session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) + want = {i.name for i in session.get_inputs()} + + hf = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.float32) + hf.eval() + cfg = hf.config + embed = hf.get_input_embeddings() + lm_head = hf.lm_head + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + num_layers = cfg.num_hidden_layers + num_kv_heads = cfg.num_key_value_heads + head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads) + + text = tokenizer.apply_chat_template( + [{"role": "user", "content": "What is the capital of France?"}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ids = tokenizer([text], return_tensors="pt").input_ids + cur_len = ids.shape[1] + assert cur_len < MAX_CACHE + + # --- FP reference greedy decode (generates the expected tokens) --- + with torch.no_grad(): + out = hf(input_ids=ids, use_cache=True) + fp_past = out.past_key_values + first_tok = int(out.logits[:, -1, :].argmax(-1)) + fp_tokens: list[int] = [] + tok, past = first_tok, fp_past + for _ in range(DECODE_STEPS): + with torch.no_grad(): + out = hf(input_ids=torch.tensor([[tok]]), past_key_values=past, use_cache=True) + past = out.past_key_values + tok = int(out.logits[:, -1, :].argmax(-1)) + fp_tokens.append(tok) + + # --- Quantized decode model greedy decode (own KV, FP embed + lm_head) --- + with torch.no_grad(): + seed = hf(input_ids=ids, use_cache=True) + kv = _seed_kv_from_fp(seed.past_key_values, num_layers, num_kv_heads, head_dim, cur_len) + quant_tokens: list[int] = [] + tok, past_len = first_tok, cur_len + for _ in range(DECODE_STEPS): + with torch.no_grad(): + emb = embed(torch.tensor([[tok]])).to(torch.float32).cpu().numpy() + feeds = { + "input_hidden_states": emb.astype(np.float32), + "past_seq_len": np.array([[past_len]], np.int32), + "total_seq_len": np.array([MAX_CACHE], np.int32), + **kv, + } + feeds 
= {k: v for k, v in feeds.items() if k in want} + names = [o.name for o in session.get_outputs()] + outs = dict(zip(names, session.run(None, feeds), strict=False)) + _carry_kv(kv, outs, num_layers) + hidden = torch.tensor(outs["output_hidden_states"][:, 0, :]) + with torch.no_grad(): + tok = int(lm_head(hidden).numpy()[0].argmax()) + quant_tokens.append(tok) + past_len += 1 + + assert quant_tokens[:PARITY_TOKENS] == fp_tokens[:PARITY_TOKENS], ( + f"w8a16 decode diverged from FP reference:\n" + f" fp : {fp_tokens[:PARITY_TOKENS]}\n" + f" quant: {quant_tokens[:PARITY_TOKENS]}" + ) + + +@pytest.mark.npu +@pytest.mark.qnn +@pytest.mark.timeout(2400) +@pytest.mark.skipif(not _qnn_available(), reason="requires QNN execution provider (NPU)") +@pytest.mark.parametrize( + ("task", "seq_len"), + [("feature-extraction", 64), ("text2text-generation", 1)], +) +def test_npu_build_quantizes(task, seq_len, tmp_path): + """On real NPU hardware, the w8a16 pipeline produces a quantized graph.""" + model = WinMLAutoModel.from_pretrained( + MODEL_ID, + task=task, + model_type="qwen3_transformer_only", + precision="w8a16", + device="npu", + ep="qnn", + no_compile=True, + force_rebuild=True, + shape_config={"max_cache_len": MAX_CACHE, "seq_len": seq_len}, + cache_dir=str(tmp_path), + ) + sub_name = "decoder_prefill" if seq_len == 64 else "decoder_gen" + onnx_path = str(model.sub_models[sub_name]._onnx_path) + counts = _qdq_counts(onnx_path) + assert counts.get("QuantizeLinear", 0) > 0 + assert counts.get("GroupQueryAttention", 0) > 0 diff --git a/tests/unit/loader/test_resolve_loader_config.py b/tests/unit/loader/test_resolve_loader_config.py index ea26e6cff..491af63ce 100644 --- a/tests/unit/loader/test_resolve_loader_config.py +++ b/tests/unit/loader/test_resolve_loader_config.py @@ -142,8 +142,13 @@ def test_model_type_only_creates_default_config(self) -> None: mock_create.assert_called_once_with("bert") assert loader_config.task == "feature-extraction" - def 
test_hf_config_never_mutated(self) -> None: - """hf_config is never mutated — model_type param does not override it.""" + def test_explicit_model_type_overrides_hf_config(self) -> None: + """An explicit model_type (with a model_id) overrides the resolved type. + + Needed so a variant model_type such as ``qwen3_transformer_only`` selects + the variant rather than the architecture's native type. The override only + applies when a model_id is present and the requested type differs. + """ mock_config = MagicMock() mock_config.model_type = "original_type" mock_class = MagicMock(spec=[]) @@ -164,10 +169,10 @@ def test_hf_config_never_mutated(self) -> None: "some-model", model_type="gpt2", task="text-generation" ) - # hf_config retains its original model_type — never mutated - assert hf_config.model_type == "original_type" - # loader_config.model_type reflects the REAL hf_config, not the param - assert loader_config.model_type == "original_type" + # The explicit model_type wins over the architecture's native type. + assert hf_config.model_type == "gpt2" + # loader_config.model_type reflects the overridden type. + assert loader_config.model_type == "gpt2" def test_auto_detect_task_from_model_type(self) -> None: """model_type without task auto-detects first supported task.""" diff --git a/tests/unit/quant/calibration/__init__.py b/tests/unit/quant/calibration/__init__.py new file mode 100644 index 000000000..862c45ce3 --- /dev/null +++ b/tests/unit/quant/calibration/__init__.py @@ -0,0 +1,4 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- diff --git a/tests/unit/quant/calibration/test_qwen3_calibration.py b/tests/unit/quant/calibration/test_qwen3_calibration.py new file mode 100644 index 000000000..5c8bd9d69 --- /dev/null +++ b/tests/unit/quant/calibration/test_qwen3_calibration.py @@ -0,0 +1,233 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for the transformer-only Qwen3 quant calibration readers. + +These are fast, offline tests (no model download, no ONNX Runtime): they +exercise the graph-shape introspection, GroupQueryAttention node discovery, +and the exact feed contract (names / dtypes / shapes) the two calibration +readers must satisfy. All expectations are derived in-code from the inputs, +never hardcoded from a model run. 
NUM_LAYERS = 2
NUM_KV_HEADS = 2
HEAD_DIM = 4
HIDDEN = NUM_KV_HEADS * HEAD_DIM
VOCAB = 16


def _fake_config() -> SimpleNamespace:
    """Return a minimal config-like namespace built from the module constants."""
    return SimpleNamespace(
        num_hidden_layers=NUM_LAYERS,
        num_key_value_heads=NUM_KV_HEADS,
        head_dim=HEAD_DIM,
        hidden_size=HIDDEN,
        num_attention_heads=NUM_KV_HEADS,
    )


def _build_tiny_onnx(path, *, seq_len: int, max_cache_len: int) -> None:
    """Write a minimal graph carrying the inputs the readers introspect.

    The graph has two inputs (``input_hidden_states`` as FLOAT
    ``[1, seq_len, HIDDEN]`` and ``past_keys_0`` as FLOAT16
    ``[1, NUM_KV_HEADS, max_cache_len, HEAD_DIM]``), one
    ``GroupQueryAttention`` node named ``gqa_layer_0`` in the ``com.microsoft``
    domain, and an ``Identity`` producing ``output_hidden_states``.

    Args:
        path: Destination file; converted with ``str()`` before saving.
        seq_len: Sequence axis of the hidden-state input and output.
        max_cache_len: Cache axis of the ``past_keys_0`` input.
    """
    inputs = [
        onnx.helper.make_tensor_value_info(
            "input_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN]
        ),
        onnx.helper.make_tensor_value_info(
            "past_keys_0", onnx.TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM]
        ),
    ]
    out = onnx.helper.make_tensor_value_info(
        "output_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN]
    )
    gqa = onnx.helper.make_node(
        "GroupQueryAttention",
        ["input_hidden_states"],
        ["attn_out"],
        name="gqa_layer_0",
        domain="com.microsoft",
    )
    identity = onnx.helper.make_node("Identity", ["attn_out"], ["output_hidden_states"])
    graph = onnx.helper.make_graph([gqa, identity], "tiny", inputs, [out])
    model = onnx.helper.make_model(
        graph,
        opset_imports=[
            onnx.helper.make_opsetid("", 18),
            # The GQA node lives in ``com.microsoft``; every operator domain a
            # graph uses has to be declared in the model's opset imports.
            onnx.helper.make_opsetid("com.microsoft", 1),
        ],
    )
    onnx.save(model, str(path))
def _drain(reader) -> list[dict[str, np.ndarray]]:
    """Collect feeds from ``reader.get_next()`` until it returns ``None``."""
    collected = []
    while True:
        item = reader.get_next()
        if item is None:
            return collected
        collected.append(item)


class _StubCausalLM:
    """Minimal HF-like model: grows a tuple-of-tuples KV cache by 1 each call.

    Always predicts ``next_token`` so the trajectory is deterministic.
    """

    def __init__(self, next_token: int) -> None:
        self.next_token = next_token

    def _cache(self, length: int):
        """Build a random ``(keys, values)`` pair per layer with ``length`` positions."""
        shape = (1, NUM_KV_HEADS, length, HEAD_DIM)
        layers = []
        for _ in range(NUM_LAYERS):
            keys = torch.randn(*shape)
            values = torch.randn(*shape)
            layers.append((keys, values))
        return tuple(layers)

    def __call__(self, input_ids=None, past_key_values=None, use_cache=True):
        """Return a namespace with a cache extended by the query and fixed logits.

        The new cache length is the previous cache length (0 when
        ``past_key_values`` is ``None``) plus the number of query tokens.
        Logits are ``-10.0`` everywhere except ``10.0`` at ``next_token``.
        """
        query_len = input_ids.shape[1]
        if past_key_values is None:
            cached_len = 0
        else:
            cached_len = past_key_values[0][0].shape[2]
        logits = torch.full((1, query_len, VOCAB), -10.0)
        logits[..., self.next_token] = 10.0
        grown_cache = self._cache(cached_len + query_len)
        return SimpleNamespace(past_key_values=grown_cache, logits=logits)
+ assert f["input_hidden_states"].shape == (1, 1, HIDDEN) + assert f["input_hidden_states"].dtype == np.float32 + cur_len = int(f["past_seq_len"][0, 0]) + for i in range(NUM_LAYERS): + kv = f[f"past_keys_{i}"] + assert kv.dtype == np.float16 + assert kv.shape == (1, NUM_KV_HEADS, max_cache_len, HEAD_DIM) + # Positions beyond the valid context stay zero-padded. + assert np.all(kv[:, :, cur_len:, :] == 0) + + +def test_decode_trajectory_reader_respects_max_cache(): + prefill_seq, decode_steps, max_cache_len = 4, 10, 6 + embed = torch.nn.Embedding(VOCAB, HIDDEN) + hf_model = _StubCausalLM(next_token=2) + token_ids = [torch.tensor([[1, 2, 3, 4, 5, 6]])] + + reader = Qwen3DecodeTrajectoryCalibReader( + hf_model, + embed, + _fake_config(), + token_ids, + prefill_seq=prefill_seq, + max_cache_len=max_cache_len, + decode_steps=decode_steps, + ) + feeds = _drain(reader) + # Trajectory must stop once the cache is full (cur_len reaches max_cache_len). + assert len(feeds) == max_cache_len - prefill_seq + assert max(int(f["past_seq_len"][0, 0]) for f in feeds) == max_cache_len - 1 diff --git a/tests/unit/quant/calibration/test_registry.py b/tests/unit/quant/calibration/test_registry.py new file mode 100644 index 000000000..b60f74b9b --- /dev/null +++ b/tests/unit/quant/calibration/test_registry.py @@ -0,0 +1,38 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for the quant finalizer registry. + +Fast, offline: no model download, no ONNX Runtime. Verifies that the +``model_type`` -> quant policy dispatch (lazy import + decorator registration) +resolves the registered Qwen3 finalizer and falls back to ``None`` (the +quantizer's default DatasetCalibrationReader path) for everything else. 
def test_registered_model_type_resolves_finalizer():
    """Looking up ``qwen3_transformer_only`` yields its concrete policy object."""
    policy = get_quant_finalizer("qwen3_transformer_only")
    assert policy is not None
    assert isinstance(policy, QuantConfigFinalizer)
    assert hasattr(policy, "finalize")
    # The instance must be of the concrete policy class, checked by name.
    policy_class_name = type(policy).__name__
    assert policy_class_name == "Qwen3TransformerOnlyQuantFinalizer"


def test_unregistered_model_type_returns_none():
    """Model types without a policy resolve to ``None`` (default reader path)."""
    for native_type in ("resnet", "qwen3"):
        assert get_quant_finalizer(native_type) is None


def test_none_model_type_returns_none():
    """A missing or empty model_type neither raises nor dispatches a policy."""
    for missing in (None, ""):
        assert get_quant_finalizer(missing) is None