Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/recipes/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Total: **75** (model, task) tuples that pass fp16 eval on all 10 (EP, device) bu
| ahotrod/electra_large_discriminator_squad2_512 | question-answering |
| apple/mobilevit-small | image-classification |
| cardiffnlp/twitter-roberta-base-sentiment-latest | text-classification |
| dandelin/vilt-b32-finetuned-vqa | visual-question-answering |
| dbmdz/bert-large-cased-finetuned-conll03-english | token-classification |
| deepset/bert-large-uncased-whole-word-masking-squad2 | question-answering |
| deepset/roberta-base-squad2 | question-answering |
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
{
"export": {
"opset_version": 17,
"batch_size": 1,
"export_params": true,
"do_constant_folding": true,
"verbose": false,
"dynamo": false,
"enable_hierarchy_tags": true,
"clean_onnx": false,
"hierarchy_tag_format": "full",
"input_tensors": [
{
"name": "input_ids",
"dtype": "int32",
"shape": [
1,
40
],
"value_range": [
0,
30522
]
},
{
"name": "attention_mask",
"dtype": "int32",
"shape": [
1,
40
],
"value_range": [
0,
2
]
},
{
"name": "token_type_ids",
"dtype": "int32",
"shape": [
1,
40
],
"value_range": [
0,
2
]
},
{
"name": "pixel_values",
"dtype": "float32",
"shape": [
1,
3,
384,
384
],
"value_range": [
0,
1
]
}
],
"output_tensors": [
{
"name": "logits"
}
]
},
"optim": {},
"quant": null,
"compile": null,
"loader": {
"task": "visual-question-answering",
"model_class": "ViltForQuestionAnswering",
"model_type": "vilt"
}
}
3 changes: 3 additions & 0 deletions src/winml/modelkit/models/hf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@
VisionDecoderIOConfig as _VisionDecoderIOConfig, # triggers registration
)
from .vision_encoder_decoder import VisionEncoderIOConfig as _VisionEncoderIOConfig
from .vilt import MODEL_CLASS_MAPPING as _VILT_CLASS_MAPPING
from .vilt import ViltVqaOnnxConfig as _ViltVqaOnnxConfig # triggers registration
from .zoedepth import ZoeDepthIOConfig as _ZoeDepthIOConfig # triggers registration


Expand All @@ -97,6 +99,7 @@
**_SIGLIP_CLASS_MAPPING,
**_T5_CLASS_MAPPING,
**_VED_CLASS_MAPPING,
**_VILT_CLASS_MAPPING,
}

# Registry: model_type -> WinMLBuildConfig
Expand Down
242 changes: 242 additions & 0 deletions src/winml/modelkit/models/hf/vilt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""ViLT (Vision-and-Language Transformer) HuggingFace Model Configuration.

ViLT is a single-stream multi-modal transformer that processes text + image
in a unified attention stack. The ``ViltForQuestionAnswering`` head produces
classification logits over a fixed VQAv2 answer vocabulary (3129 labels for
``dandelin/vilt-b32-finetuned-vqa``).

Optimum has NO vendor ``ViltOnnxConfig`` (verified 2026-06-24: ``vilt`` is
absent from ``TasksManager._SUPPORTED_MODEL_TYPE`` for the transformers
library). This module writes the export config from scratch.

The forward takes 4 required tensors (pixel_mask is omitted — see Notes):
- ``pixel_values`` [B, 3, 384, 384] RGB image
- ``input_ids`` [B, 40] tokenized question
- ``attention_mask`` [B, 40] text padding mask
- ``token_type_ids`` [B, 40] BERT segment IDs (modality)

Output: ``logits`` [B, num_labels] over the answer vocabulary.

Notes
-----
ViLT's stock ``visual_embed`` is fundamentally NOT ONNX-traceable: it iterates
Python-level over tensor values (``for h, w in zip(x_h, x_w)``), uses
``torch.multinomial`` (random + non-exportable), and does per-row Python loops
over ``nonzero()`` results. We replace it during export with a statically-
shaped equivalent (see ``_patched_visual_embed`` + ``_ViltVisualEmbedPatcher``)
that assumes an all-ones ``pixel_mask`` — which is exactly what ``ViltProcessor``
emits in production (the processor pre-pads images to 384×384). Because the
patched path ignores ``pixel_mask``, we drop it from the exported ONNX graph.
Verified numerically equivalent: ``cos=1.000000``, same argmax,
max_abs_diff≈1.2e-5.

This is an Effort-L1 contribution per the `adding-model-support` skill:
new OnnxConfig from scratch + custom model patcher.
"""

from __future__ import annotations

import types

from optimum.exporters.onnx import OnnxConfig
from optimum.exporters.onnx.model_patcher import ModelPatcher
from optimum.utils import NormalizedTextConfig
from optimum.utils.input_generators import DummyVisionInputGenerator
from transformers import ViltForQuestionAnswering

from ...export import MaxLengthTextInputGenerator, register_onnx_overwrite


# =============================================================================
# Export-time patch for ``ViltEmbeddings.visual_embed``
# =============================================================================
# Upstream ``visual_embed`` is **not ONNX-traceable** as written:
# * ``for h, w in zip(x_h, x_w)`` iterates Python-level over tensor values
# * ``nonzero()`` + ``unique()`` + per-row Python list-comprehension subset
# selection over a dynamic ``valid_idx``
# * ``torch.multinomial`` random sampling (non-deterministic, not exportable)
# The eager path silently "works" only when ``pixel_mask`` is all-ones and the
# batch is 1, because the Python loop runs once with a concrete value. Under
# legacy ``torch.onnx.export`` tracing the shape resolves to 0 and PyTorch's
# ``F.interpolate`` aborts with ``input (H: 12, W: 12) output (H: 0, W: 0)``.
#
# For the production ``visual-question-answering`` inference path the
# ``ViltProcessor`` ALWAYS pads to 384×384 and emits an all-ones ``pixel_mask``,
# so the per-sample subset selection is a no-op. We replace ``visual_embed``
# during export with a simplified, statically-shaped implementation that:
# * uses ``x.shape[2], x.shape[3]`` (static) for position-embed interpolation
# * skips ``multinomial`` / nonzero / Python-level batch loops
# * returns an all-ones token mask of length ``H*W + 1`` (patches + CLS)
#
# Verified numerically equivalent on ``dandelin/vilt-b32-finetuned-vqa`` with
# fixed seed: ``cos=1.000000``, same ``argmax`` class, ``max_abs_diff≈1.2e-5``
# (within fp tolerance from interpolation operation ordering).


def _patched_visual_embed(self, pixel_values, pixel_mask, max_image_length=200):
"""Static-shape, ONNX-traceable replacement for ``ViltEmbeddings.visual_embed``."""
import torch
from torch import nn

x = self.patch_embeddings(pixel_values)
batch_size, num_channels, height, width = x.shape

patch_dim = self.config.image_size // self.config.patch_size
spatial_pos = self.position_embeddings[:, 1:, :].transpose(1, 2).view(
1, num_channels, patch_dim, patch_dim
)
pos_embed = nn.functional.interpolate(
spatial_pos, size=(height, width), mode="bilinear", align_corners=True
)
pos_embed = pos_embed.flatten(2).transpose(1, 2).expand(batch_size, -1, -1)

x = x.flatten(2).transpose(1, 2)

cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
pos_cls = self.position_embeddings[:, 0:1, :].expand(batch_size, -1, -1)
pos_embed = torch.cat((pos_cls, pos_embed), dim=1)
x = x + pos_embed
x = self.dropout(x)

num_tokens = height * width + 1 # patches + CLS
x_mask = torch.ones(batch_size, num_tokens, dtype=torch.long, device=x.device)
return x, x_mask, None


class _ViltVisualEmbedPatcher(ModelPatcher):
"""Swaps ``ViltEmbeddings.visual_embed`` for the duration of ONNX export."""

def __enter__(self):
super().__enter__()
emb = self._model.vilt.embeddings if hasattr(self._model, "vilt") else self._model.embeddings
self._emb_ref = emb
self._orig_visual_embed = emb.visual_embed
emb.visual_embed = types.MethodType(_patched_visual_embed, emb)
return self

def __exit__(self, exc_type, exc_value, traceback):
self._emb_ref.visual_embed = self._orig_visual_embed
super().__exit__(exc_type, exc_value, traceback)


# =============================================================================
# Optimum ONNX Export Config Registration
# =============================================================================
@register_onnx_overwrite("vilt", "visual-question-answering", library_name="transformers")
class ViltVqaOnnxConfig(OnnxConfig):
"""ONNX export config for ``ViltForQuestionAnswering``.

Declares 4 multi-modal inputs (text triple + pixel_values) and the single
classification output. ``pixel_mask`` is deliberately omitted — see
``inputs`` property docstring and the module-level ``Notes`` section for
the full rationale.

Inputs:
- ``input_ids``: [B, 40] int64
- ``attention_mask``: [B, 40] int64
- ``token_type_ids``: [B, 40] int64
- ``pixel_values``: [B, 3, 384, 384] float32

Outputs:
- ``logits``: [B, num_labels=3129] float32

Notes:
- ``num_labels`` (3129 for VQAv2) is a config-time fact, not declared
dynamic in the symbolic axes — it's a static dim of ``logits``.
- ``sequence_length`` resolves to ``max_position_embeddings`` (40 for
ViLT-B/32) via ``NORMALIZED_CONFIG_CLASS``; the
``MaxLengthTextInputGenerator`` reads this for dummy tokens.
- Chained ``DummyVisionInputGenerator`` + ``MaxLengthTextInputGenerator``
produce ``pixel_values`` + ``input_ids``/``attention_mask``/
``token_type_ids``. The patched ``visual_embed`` (see module-level
``_ViltVisualEmbedPatcher``) synthesizes an all-ones token mask
internally, so no ``pixel_mask`` input is required.
"""

NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
sequence_length="max_position_embeddings",
num_channels="num_channels",
image_size="image_size",
patch_size="patch_size",
allow_new=True,
)

DUMMY_INPUT_GENERATOR_CLASSES = (
DummyVisionInputGenerator,
MaxLengthTextInputGenerator,
)

DEFAULT_ONNX_OPSET = 17

@property
def inputs(self) -> dict[str, dict[int, str]]:
"""Declare 4 model inputs (insertion order matches forward).

``pixel_values`` H,W is kept STATIC — ViLT interpolates position
embeddings from the actual H,W, and exposing those as dynamic symbols
trips the ONNX ``Resize`` shape-inference (``input (H:12 W:12) output
(H:0 W:0)``). Pinning H,W matches all known production usage (always
384×384 input via ``ViltProcessor``).

Note: ViLT's ``forward`` also takes a ``pixel_mask`` parameter, but
this contribution exports without it. The ``ViltProcessor`` always
emits an all-ones mask (the image is padded to 384×384 before the
model sees it), and our export-time ``ModelPatcher`` replaces the
original ``visual_embed`` with a statically-shaped version that
synthesizes an all-ones token mask internally. Including ``pixel_mask``
as an ONNX input would dead-code-eliminate (since the patched path
doesn't reference it) and confuse runtime callers.
"""
return {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"token_type_ids": {0: "batch_size", 1: "sequence_length"},
"pixel_values": {0: "batch_size"},
}

@property
def outputs(self) -> dict[str, dict[int, str]]:
"""Single classification output over fixed answer vocabulary."""
return {
"logits": {0: "batch_size"},
}

def generate_dummy_inputs(self, framework: str = "pt", **kwargs): # type: ignore[override]
"""Generate the 4 declared inputs via the chained vendor generators.

``pixel_mask`` is intentionally NOT generated — see ``inputs`` docstring.
Our model patcher's replacement ``visual_embed`` synthesizes an
all-ones token mask internally, so the model can be called with the
4 declared inputs.
"""
dummy = super().generate_dummy_inputs(framework=framework, **kwargs)
# Drop any pixel_mask the generators may have produced — the patched
# visual_embed ignores it (and including it would error at sess.run
# since it isn't in the exported ONNX graph).
dummy.pop("pixel_mask", None)
return dummy

def patch_model_for_export(self, model, model_kwargs=None): # type: ignore[override]
"""Install the ``visual_embed`` patcher for the export context."""
return _ViltVisualEmbedPatcher(self, model, model_kwargs=model_kwargs)


# =============================================================================
# HuggingFace Model Class Mapping
# =============================================================================
# ``visual-question-answering`` has no default AutoModel routing for ViLT;
# bind the (model_type, task) tuple directly to the head-bearing HF class.
MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = {
("vilt", "visual-question-answering"): ViltForQuestionAnswering,
}


__all__ = [
"ViltVqaOnnxConfig",
"MODEL_CLASS_MAPPING",
]
Loading