Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/winml/modelkit/build/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def build_hf_model(
cache_key: str | None = None,
ep: EPNameOrAlias | None = None,
device: str | None = None,
model_type: str | None = None,
**kwargs: Any,
) -> BuildResult:
"""Build an ONNX model from a HuggingFace model architecture.
Expand Down Expand Up @@ -211,6 +212,7 @@ def _name(base: str) -> str:
model_id,
trust_remote_code,
random_init=random_init,
model_type=model_type,
)

# =========================================================================
Expand Down Expand Up @@ -315,6 +317,31 @@ def _name(base: str) -> str:
else:
logger.info("Quantizing model...")
t0 = time.monotonic()
# Some model types finalize their quant config only once the
# exported ONNX exists (calibration feeds / nodes-to-exclude derived
# from the graph). Resolve the model-type-specific quant policy from
# the quant registry, keyed on the live ``model_type``. Unregistered
# types return None → the quantizer uses its standard task-aware
# DatasetCalibrationReader.
from ..quant import get_quant_finalizer

resolved_model_type = (
getattr(getattr(pytorch_model, "config", None), "model_type", None) or model_type
)
quant_finalizer = get_quant_finalizer(resolved_model_type)
if quant_finalizer is not None:
# Generic id fallback: the policy loads a fresh reference model
# for calibration, so feed it the best-known HF id/path.
resolved_model_id = model_id or getattr(
getattr(pytorch_model, "config", None), "_name_or_path", None
)
config.quant = quant_finalizer.finalize(
config.quant, onnx_path=current_path, model_id=resolved_model_id
)
# The policy may overwrite the quant scheme (dtypes, symmetry,
# nodes-to-exclude) authoritatively, so re-persist the config
# to keep config.json consistent with what was actually applied.
config_path.write_text(json.dumps(config.to_dict(), indent=2))
quant_result = quantize_onnx(
model_path=current_path,
output_path=quantized_path,
Expand Down Expand Up @@ -443,6 +470,7 @@ def _load_model(
trust_remote_code: bool,
random_init: bool = False,
hf_config: Any | None = None,
model_type: str | None = None,
) -> Any:
"""Load PyTorch model — pretrained or random weights.

Expand Down Expand Up @@ -518,6 +546,7 @@ def _load_model(
task=task,
trust_remote_code=effective_trust,
hf_config=hf_config,
model_type=model_type,
)
return pytorch_model

Expand Down
13 changes: 13 additions & 0 deletions src/winml/modelkit/loader/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,19 @@ def resolve_loader_config(
f"attribute. Cannot proceed with config generation."
)

# Explicit model_type override alongside a model_id: honor the requested
# type so downstream class / build-config / export resolution selects the
# variant (e.g. "qwen3_transformer_only") rather than the architecture's
# native type. The model_type-only path above (AutoConfig.for_model) is
# unaffected because it only runs when model_id is None.
if model_id is not None and model_type is not None and hf_config.model_type != model_type:
logger.info(
"Overriding resolved model_type '%s' -> '%s' (explicit request)",
hf_config.model_type,
model_type,
)
hf_config.model_type = model_type

# 2-3. Unified resolution. Task detection — including the no-architectures
# --model-type fallback (first supported task) — now lives in resolve_task.
resolution = resolve_task(hf_config, task=task, model_class=model_class)
Expand Down
13 changes: 13 additions & 0 deletions src/winml/modelkit/loader/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def load_hf_model(
user_script: str | None = None,
trust_remote_code: bool = False,
hf_config: PretrainedConfig | None = None,
model_type: str | None = None,
) -> tuple[nn.Module, PretrainedConfig, str]:
"""Load, detect task, and prepare HuggingFace model.

Expand Down Expand Up @@ -218,6 +219,18 @@ def load_hf_model(
trust_remote_code=trust_remote_code,
)

# Explicit model_type override: select a registered build variant (e.g.
# "qwen3_transformer_only") rather than the architecture's native type.
# Mutates the freshly-loaded config only; gated on an explicit request so
# normal loading is unaffected.
if model_type is not None and getattr(hf_config, "model_type", None) != model_type:
logger.info(
"Overriding model_type '%s' -> '%s' (explicit request)",
getattr(hf_config, "model_type", None),
model_type,
)
hf_config.model_type = model_type

# [2] Task & Model Class Resolution
from .resolution import resolve_task

Expand Down
16 changes: 15 additions & 1 deletion src/winml/modelkit/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def from_pretrained(
trust_remote_code: bool = False,
shape_config: dict | None = None,
no_compile: bool = False,
model_type: str | None = None,
allow_unsupported_nodes: bool = False,
**kwargs: Any,
) -> WinMLPreTrainedModel:
Expand Down Expand Up @@ -300,6 +301,10 @@ def from_pretrained(
shape_config: Shape overrides passed to generate_build_config().
Valid keys -- text: sequence_length; vision: height, width;
audio: feature_size, nb_max_frames, audio_sequence_length.
model_type: Explicit model_type override. When provided alongside a
HF model_id, selects a registered build variant (e.g.
``"qwen3_transformer_only"``) instead of the architecture's
native model_type. Leave ``None`` for normal auto-detection.
allow_unsupported_nodes: If True, warn instead of raising when the
analyzer reports unsupported nodes that persist; the build
proceeds and the EP may fall back to another device for them.
Expand Down Expand Up @@ -361,6 +366,11 @@ def from_pretrained(
else:
_model_type = None

# Explicit override wins so a variant composite (e.g.
# "qwen3_transformer_only") can be selected over the native type.
if model_type is not None:
_model_type = model_type

if _model_type is not None and (_model_type, task) in COMPOSITE_MODEL_REGISTRY:
from .winml.composite_model import WinMLCompositeModel

Expand Down Expand Up @@ -398,6 +408,7 @@ def from_pretrained(
trust_remote_code=trust_remote_code,
ep=kwargs.get("ep"),
no_compile=no_compile,
model_type=model_type,
)

resolved_task = build_config.loader.task
Expand Down Expand Up @@ -432,7 +443,9 @@ def from_pretrained(
from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=effective_trust)
model_type = getattr(hf_config, "model_type", "unknown")
# Honor an explicit model_type override; otherwise probe from the config.
if model_type is None:
model_type = getattr(hf_config, "model_type", "unknown")
logger.debug("Model type: %s, task: %s", model_type, resolved_task)

# =====================================================================
Expand Down Expand Up @@ -470,6 +483,7 @@ def from_pretrained(
cache_key=cache_key,
ep=resolved_ep,
device=device,
model_type=model_type,
allow_unsupported_nodes=allow_unsupported_nodes,
**build_control_kwargs,
)
Expand Down
11 changes: 11 additions & 0 deletions src/winml/modelkit/models/hf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@
from .qwen import QWEN_CONFIG
from .qwen import QwenGenIOConfig as _QwenGenIOConfig
from .qwen import QwenPrefillIOConfig as _QwenPrefillIOConfig
from .qwen_transformer_only import MODEL_CLASS_MAPPING as _QWEN_TO_CLASS_MAPPING
from .qwen_transformer_only import QWEN_TRANSFORMER_ONLY_CONFIG
from .qwen_transformer_only import (
QwenTransformerOnlyGenIOConfig as _QwenTransformerOnlyGenIOConfig, # triggers registration
)
from .qwen_transformer_only import (
# triggers registration
QwenTransformerOnlyPrefillIOConfig as _QwenTransformerOnlyPrefillIOConfig,
)
from .roberta import ROBERTA_FAMILY_CONFIG
from .roberta import RobertaIOConfig as _RobertaIOConfig # triggers registration
from .sam import MODEL_CLASS_MAPPING as _SAM2_CLASS_MAPPING
Expand Down Expand Up @@ -92,6 +101,7 @@
**_MARIAN_CLASS_MAPPING,
**_MU2_CLASS_MAPPING,
**_QWEN_CLASS_MAPPING,
**_QWEN_TO_CLASS_MAPPING,
**_SAM2_CLASS_MAPPING,
**_SEGFORMER_CLASS_MAPPING,
**_SIGLIP_CLASS_MAPPING,
Expand All @@ -115,6 +125,7 @@
"roberta": ROBERTA_FAMILY_CONFIG,
"mu2": MU2_CONFIG,
"qwen3": QWEN_CONFIG,
"qwen3-transformer-only": QWEN_TRANSFORMER_ONLY_CONFIG,
"siglip": SIGLIP_CONFIG,
"siglip-text-model": SIGLIP_CONFIG,
"siglip-vision-model": SIGLIP_CONFIG,
Expand Down
185 changes: 185 additions & 0 deletions src/winml/modelkit/models/hf/qwen3_export_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Custom ONNX export ops that reshape HF's Qwen3 modules for export.

These reshape the standard HF Qwen3 modules so winml-cli can produce a
QNN-friendly, transformer-only graph:

- ``LpNormalization`` replaces the eager RMSNorm Mul/Pow/ReduceMean chain.
- ``com.microsoft::GroupQueryAttention`` replaces the eager QKV MatMul +
Softmax + KV-update path (with built-in rotary).
- 1x1 ``Conv`` (NHWC<->NCHW) replaces ``nn.Linear`` for QNN-friendly
projections.

Everything here operates only on the standard ``transformers.models.qwen3``
module attributes.
"""

from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn
from torch.onnx import symbolic_helper


# =============================================================================
# Custom ONNX symbolic functions
# =============================================================================


class LpNormOnnxExport(torch.autograd.Function):
"""RMSNorm body → ONNX ``LpNormalization`` (p=2 along last dim)."""

@staticmethod
def symbolic(g, input, axis, p) -> Any:
"""Emit the ONNX ``LpNormalization`` node during export."""
output_type = input.type().with_sizes(symbolic_helper._get_tensor_sizes(input))
output = g.op(
"onnx::LpNormalization",
input,
axis_i=int(axis),
p_i=int(p),
)
return output.setType(output_type)

@staticmethod
def forward(ctx, input, axis, p) -> Any:
"""Real ``LpNormalization`` (``input / ||input||_p`` along ``axis``).

The exported node comes from ``symbolic``; this eager body computes the
same value so any eager execution (unit tests, calibration debug runs,
the exporter's own shape-tracing pass) gets correctly normalized output
instead of a silent identity. It matches the ONNX op faithfully (no
RMSNorm epsilon), since that is exactly what ``symbolic`` emits.
"""
return input / torch.linalg.vector_norm(input, ord=p, dim=axis, keepdim=True)


class GroupQueryAttentionOnnxExport(torch.autograd.Function):
"""Fused Q/K/V + KV-cache + rotary → ``com.microsoft::GroupQueryAttention``."""

@staticmethod
def symbolic(
g,
query,
key,
value,
past_key,
past_value,
seqlens_k,
total_sequence_length,
cos_cache,
sin_cache,
do_rotary,
kv_num_heads,
num_heads,
) -> Any:
"""Emit the fused ``com.microsoft::GroupQueryAttention`` node."""
args = [
query,
key,
value,
past_key,
past_value,
seqlens_k,
total_sequence_length,
cos_cache,
sin_cache,
]
attention_output, present_keys, present_values = g.op(
"com.microsoft::GroupQueryAttention",
*args,
do_rotary_i=int(do_rotary),
kv_num_heads_i=int(kv_num_heads),
num_heads_i=int(num_heads),
outputs=3,
)

query_sizes = symbolic_helper._get_tensor_sizes(query)
attention_output.setType(query.type().with_sizes(query_sizes))
present_keys.setType(
past_key.type().with_sizes(symbolic_helper._get_tensor_sizes(past_key))
)
present_values.setType(
past_value.type().with_sizes(symbolic_helper._get_tensor_sizes(past_value))
)
return attention_output, present_keys, present_values

@staticmethod
def forward(
ctx,
query,
key,
value,
past_key,
past_value,
seqlens_k,
total_sequence_length,
cos_cache,
sin_cache,
do_rotary,
kv_num_heads,
num_heads,
) -> Any:
"""Shape-only tracing placeholder; returns a stand-in ``(output, KV)``.

The real op is emitted by ``symbolic`` during ONNX export; this body
only needs to return tensors of the right shape/dtype. It deliberately
does NOT raise on eager execution, even though that yields a stale
(never-advanced) KV cache: the HTP export pipeline runs a real eager
``forward`` pass to capture the module hierarchy (see
``export/htp/hierarchy.py::trace_model_execution``), and that pass is
indistinguishable from misuse — ``torch.jit.is_tracing()`` and
``torch.onnx.is_in_onnx_export()`` are both False there — so raising
would break the actual build. There is also no cheap faithful eager
equivalent (correct attention would grow the sequence axis that the
static-shape export freezes). This module is export-only by design and
is never run for real inference; calibration loads a fresh real model.
"""
return query, past_key, past_value # placeholder shapes (export-only)


# =============================================================================
# 1x1 Conv replacement for nn.Linear
# =============================================================================


class TransposeConv2d1x1Transpose(nn.Module):
"""``nn.Linear`` → 1x1 ``Conv2d`` with NHWC<->NCHW permutes."""

def __init__(
self,
in_channels: int,
out_channels: int,
weight: torch.nn.Parameter,
bias: torch.nn.Parameter | None = None,
) -> None:
super().__init__()
# Linear weight is (out, in); Conv2d weight is (out, in, 1, 1).
self.weight = nn.Parameter(weight.data.view(out_channels, in_channels, 1, 1))
self.bias = bias

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply the 1x1 conv with NHWC<->NCHW permutes (+ optional bias)."""
x = x.permute(0, 3, 1, 2) # NHWC -> NCHW
x = torch.nn.functional.conv2d(x, self.weight)
x = x.permute(0, 2, 3, 1) # NCHW -> NHWC
if self.bias is not None:
x = x + self.bias
return x

@classmethod
def from_linear_module(cls, linear: nn.Linear) -> TransposeConv2d1x1Transpose:
"""Build a 1x1-conv replacement from an existing ``nn.Linear``."""
return cls(linear.in_features, linear.out_features, linear.weight, linear.bias)


__all__ = [
"GroupQueryAttentionOnnxExport",
"LpNormOnnxExport",
"TransposeConv2d1x1Transpose",
]
Loading
Loading