Skip to content

Add ViLT (dandelin/vilt-b32-finetuned-vqa) visual-question-answering support#951

Draft
ssss141414 wants to merge 1 commit into
mainfrom
shzhen/add-vilt-vqa
Draft

Add ViLT (dandelin/vilt-b32-finetuned-vqa) visual-question-answering support#951
ssss141414 wants to merge 1 commit into
mainfrom
shzhen/add-vilt-vqa

Conversation

@ssss141414

Copy link
Copy Markdown
Contributor

Summary

Adds first-class support for ViLT under the visual-question-answering task, validated on dandelin/vilt-b32-finetuned-vqa.

ViLT has no vendor optimum coverage, and its stock ViltEmbeddings.visual_embed is fundamentally not ONNX-traceable (Python iteration over tensor shapes, torch.multinomial, per-row nonzero() loops). Eager works because the loops resolve concretely; tracing fails. This PR therefore ships:

  1. A from-scratch ViltVqaOnnxConfig(OnnxConfig) registered via @register_onnx_overwrite("vilt", "visual-question-answering").
  2. A _ViltVisualEmbedPatcher(ModelPatcher) that swaps visual_embed for a static-shape replacement using nn.functional.interpolate(spatial_pos, size=(H, W), mode='bilinear', align_corners=True) and a synthesized all-ones token mask.
  3. Pinned static H/W on pixel_values; pixel_mask is intentionally omitted from the export signature since the patched path doesn't read it (leaving it in would create a dead graph input).
  4. Recipe + README row + model class mapping wired into models/hf/__init__.py.

Files changed

File Kind
src/winml/modelkit/models/hf/vilt.py NEW (190 LOC)
src/winml/modelkit/models/hf/__init__.py +3 (wiring)
examples/recipes/dandelin_vilt-b32-finetuned-vqa/visual-question-answering_config.json NEW
examples/recipes/README.md +1 row

Validation (dandelin/vilt-b32-finetuned-vqa @ CPU fp32)

Gate Result
L0 build ✅ Build complete in 62.9s (Export 29.8s, Optimize 32.2s); 449.2 MB optimized ONNX
L1 perf ✅ mean=67.49 ms, p50=65.83 ms, p90=76.52 ms, throughput=14.82 samples/sec, std=5.92 ms (20 iters, warmup 3)
L2 numerics (PT vs ORT) ✅ cos=1.000000, max_abs_diff=4.2e-5, top class match (3129-way head)
Patched-vs-original PT parity ✅ cos=1.000000, max_abs_diff=1.2e-5, same argmax
L3 dataset eval ⏭ skipped (no default VQA dataset wired)

Notes for reviewers

  • Inputs declared: input_ids, attention_mask, token_type_ids (dynamic batch_size/sequence_length), pixel_values (only batch_size dynamic, H/W=384 static).
  • Output: logits (3129-way), batch_size dynamic.
  • Opset 17, fp32, CPU/auto resolution.
  • Recipe value_range for mask-of-ones inputs must be [1, 2] not [0, 1] because randint high is exclusive — relevant if anyone re-derives this recipe.

…support

Adds OnnxConfig + ModelPatcher for ViLT visual-question-answering since vendor optimum coverage is absent and stock ViltEmbeddings.visual_embed is not ONNX-traceable (Python iteration over tensor shapes, torch.multinomial, per-row nonzero loops). Patcher swaps in a static-shape replacement using nn.functional.interpolate for spatial position embeddings and a synthesized all-ones token mask. H/W axes are pinned static; pixel_mask is intentionally dropped since the patched path does not reference it.

Validated on dandelin/vilt-b32-finetuned-vqa @ CPU fp32:
- L0 build: 62.9s, 449.2 MB optimized ONNX
- L1 perf: p50=65.83ms, throughput=14.82 samples/sec (20 iters, warmup 3)
- L2 numerics: cos=1.000000, max_abs_diff=4.2e-5, top-class match (3129-way head)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant