From 991f06c8eacf39bd63de85d12ac9b5d086f55255 Mon Sep 17 00:00:00 2001 From: Hualiang Xie Date: Wed, 24 Jun 2026 15:49:45 +0800 Subject: [PATCH 1/3] # type: ignore[misc] # optimum base is untyped --- src/winml/modelkit/__init__.py | 3 ++- src/winml/modelkit/cli.py | 11 +++++++++-- src/winml/modelkit/models/__init__.py | 4 ++-- src/winml/modelkit/models/hf/bart.py | 6 +++--- src/winml/modelkit/models/hf/bert.py | 2 +- src/winml/modelkit/models/hf/blip.py | 4 ++-- src/winml/modelkit/models/hf/clip.py | 4 ++-- src/winml/modelkit/models/hf/convnext.py | 2 +- src/winml/modelkit/models/hf/depth_pro.py | 4 ++-- src/winml/modelkit/models/hf/marian.py | 6 +++--- src/winml/modelkit/models/hf/mu2.py | 4 ++-- src/winml/modelkit/models/hf/qwen.py | 4 ++-- src/winml/modelkit/models/hf/roberta.py | 8 ++++---- src/winml/modelkit/models/hf/sam.py | 20 ++++++++++---------- src/winml/modelkit/models/hf/segformer.py | 4 ++-- src/winml/modelkit/models/hf/siglip.py | 4 ++-- src/winml/modelkit/models/hf/t5.py | 4 ++-- src/winml/modelkit/models/hf/zoedepth.py | 2 +- 18 files changed, 52 insertions(+), 44 deletions(-) diff --git a/src/winml/modelkit/__init__.py b/src/winml/modelkit/__init__.py index 3e3142d71..463a60b7f 100644 --- a/src/winml/modelkit/__init__.py +++ b/src/winml/modelkit/__init__.py @@ -31,6 +31,7 @@ import logging import sys from importlib.metadata import PackageNotFoundError, version +from typing import Any # Force utf-8 stdout/stderr so emoji and Unicode output (rich console, logs, @@ -98,7 +99,7 @@ def _preload_bundled_onnxruntime_dll() -> None: } -def __getattr__(name: str): +def __getattr__(name: str) -> Any: """Lazy-load heavy exports on first access (PEP 562). This avoids importing torch/transformers/optimum (~30s) when only diff --git a/src/winml/modelkit/cli.py b/src/winml/modelkit/cli.py index 2e4745950..d910242c7 100644 --- a/src/winml/modelkit/cli.py +++ b/src/winml/modelkit/cli.py @@ -23,9 +23,14 @@ import logging from importlib import import_module from pathlib import Path +from typing import TYPE_CHECKING import click + +if TYPE_CHECKING: + from rich.console import Console + from . import __version__ from .telemetry import ActionGroup from .telemetry import telemetry as _telemetry_mod @@ -78,7 +83,7 @@ def _gradient_color(t: float) -> tuple[int, int, int]: return _GRADIENT[-1][1] -def _print_banner(version: str, *, _console: object | None = None) -> None: +def _print_banner(version: str, *, _console: Console | None = None) -> None: """Print the WinML CLI gradient banner to stderr using Rich.""" from rich.console import Console # lazy import - keeps startup fast from rich.text import Text @@ -205,7 +210,9 @@ def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None discovered = attr return discovered - def resolve_command(self, ctx: click.Context, args: list[str]): + def resolve_command( + self, ctx: click.Context, args: list[str] + ) -> tuple[str | None, click.Command | None, list[str]]: """Seed ``self.commands`` so Click can emit a did-you-mean hint on typos.""" # Click's NoSuchCommand exception uses self.commands to find suggestions. for name in self.list_commands(ctx): diff --git a/src/winml/modelkit/models/__init__.py b/src/winml/modelkit/models/__init__.py index 2fa39bda4..c5ce1464c 100644 --- a/src/winml/modelkit/models/__init__.py +++ b/src/winml/modelkit/models/__init__.py @@ -22,7 +22,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from .hf import MODEL_BUILD_CONFIGS @@ -60,7 +60,7 @@ } -def __getattr__(name: str): +def __getattr__(name: str) -> Any: """Lazy load modules that would cause circular imports.""" if name in _LAZY_IMPORTS: module_path, attr_name = _LAZY_IMPORTS[name] diff --git a/src/winml/modelkit/models/hf/bart.py b/src/winml/modelkit/models/hf/bart.py index bb204ac4f..278b1223b 100644 --- a/src/winml/modelkit/models/hf/bart.py +++ b/src/winml/modelkit/models/hf/bart.py @@ -358,7 +358,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: @register_onnx_overwrite("bart", "feature-extraction", library_name="transformers") -class BartEncoderIOConfig(OnnxConfig): +class BartEncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for BART encoder (feature-extraction task). Inputs: input_ids, attention_mask @@ -385,7 +385,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 } -class _BartDecoderNormalizedConfig(NormalizedConfig): +class _BartDecoderNormalizedConfig(NormalizedConfig): # type: ignore[misc] # optimum base is untyped """NormalizedConfig for BART decoder-side export. Maps NormalizedConfig attributes to BartConfig's decoder-side attrs. @@ -404,7 +404,7 @@ def head_dim(self) -> int: @register_onnx_overwrite("bart", "text2text-generation", library_name="transformers") -class BartDecoderIOConfig(OnnxConfig): +class BartDecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for BART decoder with sliding-window KV cache. Inputs: decoder_input_ids, encoder_hidden_states, attention_mask, diff --git a/src/winml/modelkit/models/hf/bert.py b/src/winml/modelkit/models/hf/bert.py index d537c8010..c0df6ee7c 100644 --- a/src/winml/modelkit/models/hf/bert.py +++ b/src/winml/modelkit/models/hf/bert.py @@ -44,7 +44,7 @@ @register_onnx_overwrite("bert", *COMMON_TEXT_TASKS, library_name="transformers") -class BertIOConfig(BertOnnxConfig): +class BertIOConfig(BertOnnxConfig): # type: ignore[misc] # optimum base is untyped """BERT OnnxConfig using max_position_embeddings as sequence_length. Inputs: diff --git a/src/winml/modelkit/models/hf/blip.py b/src/winml/modelkit/models/hf/blip.py index 4aa2ee727..5063f9561 100644 --- a/src/winml/modelkit/models/hf/blip.py +++ b/src/winml/modelkit/models/hf/blip.py @@ -85,7 +85,7 @@ @register_onnx_overwrite("blip", "image-to-text", library_name="transformers") @register_onnx_overwrite("blip", "image-text-to-text", library_name="transformers") -class BlipCaptioningIOConfig(OnnxConfig): +class BlipCaptioningIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """Monolithic ONNX config for BLIP captioning — single-graph fallback. Traces ``BlipForConditionalGeneration.forward`` with pixel_values + @@ -148,7 +148,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: @register_onnx_overwrite("blip", "feature-extraction", library_name="transformers") -class BlipVisionEncoderIOConfig(OnnxConfig): +class BlipVisionEncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for the BLIP vision encoder. ``image-feature-extraction`` is a synonym that Optimum's TasksManager diff --git a/src/winml/modelkit/models/hf/clip.py b/src/winml/modelkit/models/hf/clip.py index 045fcb160..28ed62b41 100644 --- a/src/winml/modelkit/models/hf/clip.py +++ b/src/winml/modelkit/models/hf/clip.py @@ -73,7 +73,7 @@ # Optimum ONNX Export Config Registrations # ============================================================================= @register_onnx_overwrite("clip_text_model", "feature-extraction", library_name="transformers") -class CLIPTextModelIOConfig(CLIPTextWithProjectionOnnxConfig): +class CLIPTextModelIOConfig(CLIPTextWithProjectionOnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for CLIPTextModelWithProjection from transformers. Model: openai/clip-vit-base-patch32 (text encoder only) @@ -108,7 +108,7 @@ def inputs(self) -> dict[str, dict[int, str]]: @register_onnx_overwrite("clip_vision_model", "feature-extraction", library_name="transformers") -class CLIPVisionModelIOConfig(CLIPVisionModelOnnxConfig): +class CLIPVisionModelIOConfig(CLIPVisionModelOnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for CLIPVisionModelWithProjection from transformers. Model: openai/clip-vit-base-patch32 (vision encoder only) diff --git a/src/winml/modelkit/models/hf/convnext.py b/src/winml/modelkit/models/hf/convnext.py index f1d453c74..5a43b0a30 100644 --- a/src/winml/modelkit/models/hf/convnext.py +++ b/src/winml/modelkit/models/hf/convnext.py @@ -102,7 +102,7 @@ def _build_patching_specs() -> list[PatchingSpec]: "image-classification", library_name="transformers", ) -class ConvNextIOConfig(ConvNextOnnxConfig): +class ConvNextIOConfig(ConvNextOnnxConfig): # type: ignore[misc] # optimum base is untyped """ConvNextOnnxConfig override that adds a LayerNorm fusion patch. Inherits all I/O specs from Optimum's ``ConvNextOnnxConfig``. The only diff --git a/src/winml/modelkit/models/hf/depth_pro.py b/src/winml/modelkit/models/hf/depth_pro.py index a5d53a770..3fe865c65 100644 --- a/src/winml/modelkit/models/hf/depth_pro.py +++ b/src/winml/modelkit/models/hf/depth_pro.py @@ -30,7 +30,7 @@ from ...export import register_onnx_overwrite -class _DepthProNormalizedConfig(NormalizedConfig): +class _DepthProNormalizedConfig(NormalizedConfig): # type: ignore[misc] # optimum base is untyped """Normalized config for DepthPro with computed image_size. image_size is derived from patch_size / min(scaled_images_ratios), @@ -47,7 +47,7 @@ def image_size(self) -> int: @register_onnx_overwrite("depth_pro", "depth-estimation", library_name="transformers") -class DepthProIOConfig(OnnxConfig): +class DepthProIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for DepthPro depth estimation. Model: apple/DepthPro-hf diff --git a/src/winml/modelkit/models/hf/marian.py b/src/winml/modelkit/models/hf/marian.py index 6251ff4ce..8540972a6 100644 --- a/src/winml/modelkit/models/hf/marian.py +++ b/src/winml/modelkit/models/hf/marian.py @@ -398,7 +398,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: @register_onnx_overwrite("marian", "feature-extraction", library_name="transformers") -class MarianEncoderIOConfig(OnnxConfig): +class MarianEncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Marian encoder (feature-extraction task). Inputs: input_ids, attention_mask @@ -425,7 +425,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 } -class _MarianDecoderNormalizedConfig(NormalizedConfig): +class _MarianDecoderNormalizedConfig(NormalizedConfig): # type: ignore[misc] # optimum base is untyped """NormalizedConfig for Marian decoder-side export. Maps NormalizedConfig attributes to MarianConfig's decoder-side attrs. @@ -444,7 +444,7 @@ def head_dim(self) -> int: @register_onnx_overwrite("marian", "text2text-generation", library_name="transformers") -class MarianDecoderIOConfig(OnnxConfig): +class MarianDecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Marian decoder with sliding-window KV cache. Inputs: decoder_input_ids, encoder_hidden_states, attention_mask, diff --git a/src/winml/modelkit/models/hf/mu2.py b/src/winml/modelkit/models/hf/mu2.py index 3efcabc5d..2a8dd11cc 100644 --- a/src/winml/modelkit/models/hf/mu2.py +++ b/src/winml/modelkit/models/hf/mu2.py @@ -194,7 +194,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: @register_onnx_overwrite("mu2", "feature-extraction", library_name="transformers") -class Mu2EncoderIOConfig(OnnxConfig): +class Mu2EncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Mu2 encoder (feature-extraction task).""" NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( @@ -218,7 +218,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 @register_onnx_overwrite("mu2", "text2text-generation", library_name="transformers") -class Mu2DecoderIOConfig(OnnxConfig): +class Mu2DecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Mu2 decoder with static KV cache.""" NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( diff --git a/src/winml/modelkit/models/hf/qwen.py b/src/winml/modelkit/models/hf/qwen.py index 6f88a078d..0b8c9b45b 100644 --- a/src/winml/modelkit/models/hf/qwen.py +++ b/src/winml/modelkit/models/hf/qwen.py @@ -262,7 +262,7 @@ def _qwen_io_outputs(num_layers: int) -> dict[str, dict[int, str]]: @register_onnx_overwrite("qwen3", "feature-extraction", library_name="transformers") -class QwenPrefillIOConfig(OnnxConfig): +class QwenPrefillIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Qwen3 prefill (feature-extraction task). Inputs: input_ids [1, 64], attention_mask [1, 256], position_ids [1, 64], @@ -283,7 +283,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 @register_onnx_overwrite("qwen3", "text2text-generation", library_name="transformers") -class QwenGenIOConfig(OnnxConfig): +class QwenGenIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Qwen3 generation (text2text-generation task). Inputs: input_ids [1, 1], attention_mask [1, 256], position_ids [1, 1], diff --git a/src/winml/modelkit/models/hf/roberta.py b/src/winml/modelkit/models/hf/roberta.py index db4fb9103..94b31bf7e 100644 --- a/src/winml/modelkit/models/hf/roberta.py +++ b/src/winml/modelkit/models/hf/roberta.py @@ -122,7 +122,7 @@ def __init__(self, config, task, **kwargs): @register_onnx_overwrite("roberta", *COMMON_TEXT_TASKS, library_name="transformers") -class RobertaIOConfig(_RobertaPositionOffsetMixin, RobertaOnnxConfig): +class RobertaIOConfig(_RobertaPositionOffsetMixin, RobertaOnnxConfig): # type: ignore[misc] # optimum base is untyped """Roberta OnnxConfig with position-offset-adjusted sequence_length. Inputs (same as DistilBERT — no token_type_ids): @@ -137,17 +137,17 @@ class RobertaIOConfig(_RobertaPositionOffsetMixin, RobertaOnnxConfig): @register_onnx_overwrite("xlm-roberta", *COMMON_TEXT_TASKS, library_name="transformers") -class XLMRobertaIOConfig(_RobertaPositionOffsetMixin, XLMRobertaOnnxConfig): +class XLMRobertaIOConfig(_RobertaPositionOffsetMixin, XLMRobertaOnnxConfig): # type: ignore[misc] # optimum base is untyped """XLM-Roberta OnnxConfig with position-offset-adjusted sequence_length.""" @register_onnx_overwrite("camembert", *COMMON_TEXT_TASKS, library_name="transformers") -class CamemBERTIOConfig(_RobertaPositionOffsetMixin, CamembertOnnxConfig): +class CamemBERTIOConfig(_RobertaPositionOffsetMixin, CamembertOnnxConfig): # type: ignore[misc] # optimum base is untyped """CamemBERT OnnxConfig with position-offset-adjusted sequence_length.""" @register_onnx_overwrite("mpnet", *COMMON_TEXT_TASKS, library_name="transformers") -class MPNetIOConfig(_RobertaPositionOffsetMixin, MPNetOnnxConfig): +class MPNetIOConfig(_RobertaPositionOffsetMixin, MPNetOnnxConfig): # type: ignore[misc] # optimum base is untyped """MPNet OnnxConfig with position-offset-adjusted sequence_length. MPNet, like Roberta-family models, sets: diff --git a/src/winml/modelkit/models/hf/sam.py b/src/winml/modelkit/models/hf/sam.py index 421bd954d..0d5976118 100644 --- a/src/winml/modelkit/models/hf/sam.py +++ b/src/winml/modelkit/models/hf/sam.py @@ -593,7 +593,7 @@ def _patched_sam2_prompt_encoder_forward( } -class Sam2ModelPatcher(ModelPatcher): +class Sam2ModelPatcher(ModelPatcher): # type: ignore[misc] # optimum base is untyped """Custom ModelPatcher that applies SAM2 QNN-compatible patches during export. Patches Sam2MultiScaleBlock and Sam2PromptEncoder forward methods on all @@ -636,7 +636,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): # ============================================================================= # Custom Dummy Input Generators for SAM2 # ============================================================================= -class Sam2PointsInputGenerator(DummyInputGenerator): +class Sam2PointsInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum base is untyped """Points input generator for SAM2 decoder. Generates: @@ -684,7 +684,7 @@ def generate( ) -class Sam2EmbeddingsInputGenerator(DummyInputGenerator): +class Sam2EmbeddingsInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum base is untyped """Embeddings input generator for SAM2 mask generation decoder. Generates raw (pre-projection) encoder outputs: @@ -728,7 +728,7 @@ def generate( return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) -class Sam2MaskInputGenerator(DummyInputGenerator): +class Sam2MaskInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum base is untyped """Mask input generator for SAM2 decoder refinement. Generates: @@ -767,7 +767,7 @@ def generate( # ============================================================================= # Normalized Config with Default Image Size # ============================================================================= -class Sam2NormalizedVisionConfig(NormalizedVisionConfig): +class Sam2NormalizedVisionConfig(NormalizedVisionConfig): # type: ignore[misc] # optimum base is untyped """NormalizedVisionConfig with default image_size for SAM2. SAM2 uses 1024x1024 input images by default. @@ -798,7 +798,7 @@ def __getattr__(self, attr_name: str): @register_onnx_overwrite("sam2", "feature-extraction", library_name="transformers") @register_onnx_overwrite("sam2_video", "feature-extraction", library_name="transformers") @register_onnx_overwrite("sam2_vision_model", "feature-extraction", library_name="transformers") -class Sam2ImageEncoderIOConfig(OnnxConfig): +class Sam2ImageEncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for SAM2 image encoder (vision_encoder component). Task: image-feature-extraction (encoder-only export) @@ -839,7 +839,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # ----------------------------------------------------------------------------- @register_onnx_overwrite("sam2", "image-segmentation", library_name="transformers") @register_onnx_overwrite("sam2_video", "image-segmentation", library_name="transformers") -class Sam2IOConfig(OnnxConfig): +class Sam2IOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for SAM2 full model (encoder + decoder monolith). Task: image-segmentation (full model export) @@ -885,7 +885,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # ----------------------------------------------------------------------------- @register_onnx_overwrite("sam2", "mask-generation", library_name="transformers") @register_onnx_overwrite("sam2_video", "mask-generation", library_name="transformers") -class Sam2MaskGenerationIOConfig(OnnxConfig): +class Sam2MaskGenerationIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for SAM2MaskGeneration (decoder with raw FPN inputs). Model: facebook/sam2-hiera-small (decoder wrapper) @@ -941,7 +941,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # ============================================================================= # SAM v1 Custom Dummy Input Generators # ============================================================================= -class SamEmbeddingsInputGenerator(DummyInputGenerator): +class SamEmbeddingsInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum base is untyped """Embeddings input generator for SAM v1 mask generation decoder. Generates: @@ -982,7 +982,7 @@ def generate( # Mask generation export (SAMMaskGeneration wrapper) - SAM v1 # ----------------------------------------------------------------------------- @register_onnx_overwrite("sam", "mask-generation", library_name="transformers") -class SamMaskGenerationIOConfig(OnnxConfig): +class SamMaskGenerationIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for SAMMaskGeneration (SAM v1 decoder). Model: facebook/sam-vit-huge, facebook/sam-vit-large, facebook/sam-vit-base diff --git a/src/winml/modelkit/models/hf/segformer.py b/src/winml/modelkit/models/hf/segformer.py index 2e70b44da..0748545b0 100644 --- a/src/winml/modelkit/models/hf/segformer.py +++ b/src/winml/modelkit/models/hf/segformer.py @@ -32,7 +32,7 @@ } -class _SegformerVisionInputGenerator(DummyVisionInputGenerator): +class _SegformerVisionInputGenerator(DummyVisionInputGenerator): # type: ignore[misc] # optimum base is untyped """Vision input generator that uses preprocessor resolution over config.image_size. Optimum's DummyVisionInputGenerator prioritizes normalized_config.image_size @@ -74,7 +74,7 @@ def __init__( @register_onnx_overwrite("segformer", "image-segmentation", library_name="transformers") -class SegformerIOConfig(OnnxConfig): +class SegformerIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for Segformer semantic segmentation. Model: nvidia/segformer-b0-finetuned-ade-512-512 diff --git a/src/winml/modelkit/models/hf/siglip.py b/src/winml/modelkit/models/hf/siglip.py index c7f240d5a..6c55f9b72 100644 --- a/src/winml/modelkit/models/hf/siglip.py +++ b/src/winml/modelkit/models/hf/siglip.py @@ -64,7 +64,7 @@ # Optimum ONNX Export Config Registrations # ============================================================================= @register_onnx_overwrite("siglip_text_model", "feature-extraction", library_name="transformers") -class SiglipTextModelIOConfig(SiglipTextOnnxConfig): +class SiglipTextModelIOConfig(SiglipTextOnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for SiglipTextModel (text encoder only). Uses ``max_position_embeddings`` (64 for SigLIP) as the fixed sequence @@ -83,7 +83,7 @@ class SiglipTextModelIOConfig(SiglipTextOnnxConfig): @register_onnx_overwrite("siglip_vision_model", "feature-extraction", library_name="transformers") -class SiglipVisionModelIOConfig(SiglipVisionModelOnnxConfig): +class SiglipVisionModelIOConfig(SiglipVisionModelOnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for SiglipVisionModel (vision encoder only). Uses Optimum defaults; no overrides needed. diff --git a/src/winml/modelkit/models/hf/t5.py b/src/winml/modelkit/models/hf/t5.py index 686f43562..48a906f43 100644 --- a/src/winml/modelkit/models/hf/t5.py +++ b/src/winml/modelkit/models/hf/t5.py @@ -203,7 +203,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: @register_onnx_overwrite("t5", "feature-extraction", library_name="transformers") -class T5EncoderIOConfig(OnnxConfig): +class T5EncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for T5 encoder (feature-extraction task). Inputs: input_ids, attention_mask @@ -231,7 +231,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # noqa: D102 @register_onnx_overwrite("t5", "text2text-generation", library_name="transformers") -class T5DecoderIOConfig(OnnxConfig): +class T5DecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for T5 decoder with sliding-window KV cache. Inputs: decoder_input_ids, encoder_hidden_states, attention_mask, diff --git a/src/winml/modelkit/models/hf/zoedepth.py b/src/winml/modelkit/models/hf/zoedepth.py index e648b8548..83dda0e05 100644 --- a/src/winml/modelkit/models/hf/zoedepth.py +++ b/src/winml/modelkit/models/hf/zoedepth.py @@ -29,7 +29,7 @@ @register_onnx_overwrite("zoedepth", "depth-estimation", library_name="transformers") -class ZoeDepthIOConfig(OnnxConfig): +class ZoeDepthIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """ONNX config for ZoeDepth depth estimation. Model: Intel/zoedepth-nyu-kitti From 9f8f73c9cdf7f8e348b0c49601bdcca73745bf89 Mon Sep 17 00:00:00 2001 From: Hualiang Xie Date: Wed, 24 Jun 2026 15:57:26 +0800 Subject: [PATCH 2/3] forward --- .../modelkit/models/winml/depth_estimation.py | 9 +++++++-- .../models/winml/feature_extraction.py | 2 +- .../models/winml/image_classification.py | 8 +++++--- .../models/winml/image_segmentation.py | 20 ++++++++++--------- .../modelkit/models/winml/object_detection.py | 10 +++++----- .../models/winml/question_answering.py | 8 +++++--- .../models/winml/sequence_classification.py | 8 +++++--- .../winml/zero_shot_image_classification.py | 2 +- 8 files changed, 40 insertions(+), 27 deletions(-) diff --git a/src/winml/modelkit/models/winml/depth_estimation.py b/src/winml/modelkit/models/winml/depth_estimation.py index dda6b5c3e..f57844390 100644 --- a/src/winml/modelkit/models/winml/depth_estimation.py +++ b/src/winml/modelkit/models/winml/depth_estimation.py @@ -12,13 +12,16 @@ from __future__ import annotations import logging -from typing import Any +from typing import TYPE_CHECKING, Any, cast from transformers.modeling_outputs import DepthEstimatorOutput from .base import WinMLPreTrainedModel +if TYPE_CHECKING: + import torch + logger = logging.getLogger(__name__) @@ -48,4 +51,6 @@ def forward(self, **kwargs: Any) -> DepthEstimatorOutput: # Fall back to first output for non-standard output names. predicted_depth = next(iter(outputs.values())) - return DepthEstimatorOutput(predicted_depth=predicted_depth) + # transformers' Output fields are annotated FloatTensor (legacy, over-narrow); + # the ONNX session returns a real float Tensor. + return DepthEstimatorOutput(predicted_depth=cast("torch.FloatTensor", predicted_depth)) diff --git a/src/winml/modelkit/models/winml/feature_extraction.py b/src/winml/modelkit/models/winml/feature_extraction.py index 3b4ffca66..df3444159 100644 --- a/src/winml/modelkit/models/winml/feature_extraction.py +++ b/src/winml/modelkit/models/winml/feature_extraction.py @@ -15,7 +15,7 @@ from collections import OrderedDict from typing import Any -from transformers.utils import ModelOutput +from transformers.utils.generic import ModelOutput from .base import WinMLPreTrainedModel diff --git a/src/winml/modelkit/models/winml/image_classification.py b/src/winml/modelkit/models/winml/image_classification.py index 500cee1b3..f4e749dde 100644 --- a/src/winml/modelkit/models/winml/image_classification.py +++ b/src/winml/modelkit/models/winml/image_classification.py @@ -11,7 +11,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from transformers.modeling_outputs import ImageClassifierOutput @@ -32,7 +32,7 @@ class WinMLModelForImageClassification(WinMLPreTrainedModel): Pipeline execution is done by WinMLAutoModel factory. """ - def forward( + def forward( # type: ignore[override] # HF-pipeline base uses generic **kwargs; task-specific signature self, pixel_values: torch.Tensor | np.ndarray, **kwargs: Any, @@ -53,7 +53,9 @@ def forward( # Get logits (by name or first output) logits = outputs.get("logits", next(iter(outputs.values()))) - return ImageClassifierOutput(logits=logits) + # transformers' Output fields are annotated FloatTensor (legacy, over-narrow); + # the ONNX session returns a real float Tensor. + return ImageClassifierOutput(logits=cast("torch.FloatTensor", logits)) @property def num_labels(self) -> int: diff --git a/src/winml/modelkit/models/winml/image_segmentation.py b/src/winml/modelkit/models/winml/image_segmentation.py index 8fe572c63..13d9d180c 100644 --- a/src/winml/modelkit/models/winml/image_segmentation.py +++ b/src/winml/modelkit/models/winml/image_segmentation.py @@ -19,11 +19,11 @@ import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch from transformers.modeling_outputs import SemanticSegmenterOutput -from transformers.utils import ModelOutput +from transformers.utils.generic import ModelOutput from .base import WinMLPreTrainedModel @@ -48,10 +48,10 @@ class ImageSegmentationOutput(ModelOutput): outputs.pred_boxes — [B, num_queries, 4] """ - loss: torch.FloatTensor | None = None - logits: torch.FloatTensor | None = None - pred_boxes: torch.FloatTensor | None = None - pred_masks: torch.FloatTensor | None = None + loss: torch.Tensor | None = None + logits: torch.Tensor | None = None + pred_boxes: torch.Tensor | None = None + pred_masks: torch.Tensor | None = None class WinMLModelForImageSegmentation(WinMLPreTrainedModel): @@ -65,7 +65,7 @@ class WinMLModelForImageSegmentation(WinMLPreTrainedModel): Pipeline execution is done by WinMLAutoModel factory. """ - def forward( + def forward( # type: ignore[override] # HF-pipeline base uses generic **kwargs; task-specific signature self, pixel_values: torch.Tensor | np.ndarray, pixel_mask: torch.Tensor | np.ndarray | None = None, @@ -131,7 +131,7 @@ class WinMLModelForSemanticSegmentation(WinMLPreTrainedModel): Pipeline execution is done by WinMLAutoModel factory. """ - def forward( + def forward( # type: ignore[override] # HF-pipeline base uses generic **kwargs; task-specific signature self, pixel_values: torch.Tensor | np.ndarray, **kwargs: Any, @@ -152,7 +152,9 @@ def forward( # Get logits (by name or first output) logits = outputs.get("logits", next(iter(outputs.values()))) - return SemanticSegmenterOutput(logits=logits) + # transformers' Output fields are annotated FloatTensor (legacy, over-narrow); + # the ONNX session returns a real float Tensor. + return SemanticSegmenterOutput(logits=cast("torch.FloatTensor", logits)) @property def num_labels(self) -> int: diff --git a/src/winml/modelkit/models/winml/object_detection.py b/src/winml/modelkit/models/winml/object_detection.py index 58e4186b8..d008d3256 100644 --- a/src/winml/modelkit/models/winml/object_detection.py +++ b/src/winml/modelkit/models/winml/object_detection.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any import torch -from transformers.utils import ModelOutput +from transformers.utils.generic import ModelOutput from .base import WinMLPreTrainedModel @@ -38,9 +38,9 @@ class ObjectDetectionOutput(ModelOutput): outputs.pred_boxes — [B, num_queries, 4] """ - loss: torch.FloatTensor | None = None - logits: torch.FloatTensor | None = None - pred_boxes: torch.FloatTensor | None = None + loss: torch.Tensor | None = None + logits: torch.Tensor | None = None + pred_boxes: torch.Tensor | None = None class WinMLModelForObjectDetection(WinMLPreTrainedModel): @@ -51,7 +51,7 @@ class WinMLModelForObjectDetection(WinMLPreTrainedModel): so that image_processor.post_process_object_detection() works. """ - def forward( + def forward( # type: ignore[override] # HF-pipeline base uses generic **kwargs; task-specific signature self, pixel_values: torch.Tensor | np.ndarray, pixel_mask: torch.Tensor | np.ndarray | None = None, diff --git a/src/winml/modelkit/models/winml/question_answering.py b/src/winml/modelkit/models/winml/question_answering.py index 4dc30ab22..f9ded4fde 100644 --- a/src/winml/modelkit/models/winml/question_answering.py +++ b/src/winml/modelkit/models/winml/question_answering.py @@ -15,7 +15,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from transformers.modeling_outputs import QuestionAnsweringModelOutput @@ -74,7 +74,9 @@ def forward( formatted = self._format_inputs(**inputs) outputs = self._run_inference(formatted) + # transformers' Output fields are annotated FloatTensor (legacy, over-narrow); + # the ONNX session returns real float Tensors. return QuestionAnsweringModelOutput( - start_logits=outputs.get("start_logits"), - end_logits=outputs.get("end_logits"), + start_logits=cast("torch.FloatTensor | None", outputs.get("start_logits")), + end_logits=cast("torch.FloatTensor | None", outputs.get("end_logits")), ) diff --git a/src/winml/modelkit/models/winml/sequence_classification.py b/src/winml/modelkit/models/winml/sequence_classification.py index df84948ac..602590972 100644 --- a/src/winml/modelkit/models/winml/sequence_classification.py +++ b/src/winml/modelkit/models/winml/sequence_classification.py @@ -11,7 +11,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from transformers.modeling_outputs import SequenceClassifierOutput @@ -37,7 +37,7 @@ class WinMLModelForSequenceClassification(WinMLPreTrainedModel): Pipeline execution is done by WinMLAutoModel factory. """ - def forward( + def forward( # type: ignore[override] # HF-pipeline base uses generic **kwargs; task-specific signature self, input_ids: torch.Tensor | np.ndarray, attention_mask: torch.Tensor | np.ndarray | None = None, @@ -70,7 +70,9 @@ def forward( # Get logits (by name or first output) logits = outputs.get("logits", next(iter(outputs.values()))) - return SequenceClassifierOutput(logits=logits) + # transformers' Output fields are annotated FloatTensor (legacy, over-narrow); + # the ONNX session returns a real float Tensor. + return SequenceClassifierOutput(logits=cast("torch.FloatTensor", logits)) @property def num_labels(self) -> int: diff --git a/src/winml/modelkit/models/winml/zero_shot_image_classification.py b/src/winml/modelkit/models/winml/zero_shot_image_classification.py index 94132ede8..9ab089c45 100644 --- a/src/winml/modelkit/models/winml/zero_shot_image_classification.py +++ b/src/winml/modelkit/models/winml/zero_shot_image_classification.py @@ -16,7 +16,7 @@ import numpy as np import torch -from transformers.utils import ModelOutput +from transformers.utils.generic import ModelOutput from .composite_model import WinMLCompositeModel, register_composite_model From a8958f73b399780645572e00f4f8e1767f765e8e Mon Sep 17 00:00:00 2001 From: Hualiang Xie Date: Wed, 24 Jun 2026 16:26:57 +0800 Subject: [PATCH 3/3] so many.. --- src/winml/modelkit/models/hf/bart.py | 30 ++++++++---- .../modelkit/models/hf/decoder_wrapper.py | 2 +- .../modelkit/models/hf/depth_anything.py | 4 +- src/winml/modelkit/models/hf/marian.py | 28 +++++++---- .../models/hf/vision_encoder_decoder.py | 4 +- src/winml/modelkit/models/winml/base.py | 5 +- .../modelkit/models/winml/composite_model.py | 12 +++-- .../modelkit/models/winml/decoder_only.py | 40 ++++++++++------ .../modelkit/models/winml/encoder_decoder.py | 48 ++++++++++++------- .../winml/zero_shot_image_classification.py | 9 +++- 10 files changed, 119 insertions(+), 63 deletions(-) diff --git a/src/winml/modelkit/models/hf/bart.py b/src/winml/modelkit/models/hf/bart.py index 278b1223b..cb44d2d74 100644 --- a/src/winml/modelkit/models/hf/bart.py +++ b/src/winml/modelkit/models/hf/bart.py @@ -71,7 +71,7 @@ from __future__ import annotations import logging -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast import torch import torch.nn as nn @@ -93,6 +93,9 @@ from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLStaticCache +if TYPE_CHECKING: + from transformers import GenerationConfig, PretrainedConfig + logger = logging.getLogger(__name__) @@ -140,7 +143,7 @@ def _patched_bart_learned_forward( - self, + self: Any, # monkey-patched onto BartLearnedPositionalEmbedding (HF internal) input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None, @@ -229,10 +232,14 @@ def forward( attention_mask: torch.Tensor, ) -> torch.Tensor: """Return encoder last hidden state.""" - return self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - ).last_hidden_state + # self.encoder is a torch submodule (untyped __call__ -> Any). + return cast( + "torch.Tensor", + self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ).last_hidden_state, + ) class BartDecoderWrapper(nn.Module): @@ -262,8 +269,10 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: super().__init__() self.model = model self.num_layers = num_layers - # Expose config for OnnxConfig / NormalizedConfig access - self.config = model.config + # Expose config for OnnxConfig / NormalizedConfig access. + # model is typed nn.Module, so torch's __getattr__ types .config as + # Tensor | Module; it is really the model's PretrainedConfig. + self.config = cast("PretrainedConfig", model.config) @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> BartDecoderWrapper: @@ -400,7 +409,8 @@ class _BartDecoderNormalizedConfig(NormalizedConfig): # type: ignore[misc] # o @property def head_dim(self) -> int: - return self.hidden_size // self.num_attention_heads + # hidden_size / num_attention_heads come from the untyped NormalizedConfig base. + return cast("int", self.hidden_size // self.num_attention_heads) @register_onnx_overwrite("bart", "text2text-generation", library_name="transformers") @@ -517,7 +527,7 @@ def get_cache_class(cls) -> type: return WinMLStaticCache # static cache (index_put_ → ScatterND) @property - def generation_config(self): # noqa: D102 + def generation_config(self) -> GenerationConfig: # noqa: D102 if not hasattr(self, "_generation_config"): from transformers import GenerationConfig diff --git a/src/winml/modelkit/models/hf/decoder_wrapper.py b/src/winml/modelkit/models/hf/decoder_wrapper.py index 5e7b69dfd..ab6932df6 100644 --- a/src/winml/modelkit/models/hf/decoder_wrapper.py +++ b/src/winml/modelkit/models/hf/decoder_wrapper.py @@ -156,7 +156,7 @@ def _invoke_hf( """Call the HF decoder with ``past_key_values=``. Returns logits.""" -class WinMLStaticCacheDecoderIOConfig(OnnxConfig): +class WinMLStaticCacheDecoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum/transformers base is untyped """Semantic-name contract used by ``WinMLDecoderWrapper._make_cache``. Subclasses declare their own ``inputs`` / ``outputs`` bodies (each diff --git a/src/winml/modelkit/models/hf/depth_anything.py b/src/winml/modelkit/models/hf/depth_anything.py index 3e30a34e6..456b93eee 100644 --- a/src/winml/modelkit/models/hf/depth_anything.py +++ b/src/winml/modelkit/models/hf/depth_anything.py @@ -23,7 +23,7 @@ from ...export import register_onnx_overwrite -class _DepthAnythingVisionInputGenerator(DummyVisionInputGenerator): +class _DepthAnythingVisionInputGenerator(DummyVisionInputGenerator): # type: ignore[misc] # optimum/transformers base is untyped """Vision input generator that lets explicit height/width override config.image_size. Optimum's DummyVisionInputGenerator prioritizes normalized_config.image_size @@ -62,7 +62,7 @@ def __init__( @register_onnx_overwrite("depth_anything", "depth-estimation", library_name="transformers") -class DepthAnythingIOConfig(OnnxConfig): +class DepthAnythingIOConfig(OnnxConfig): # type: ignore[misc] # optimum/transformers base is untyped """ONNX config for Depth Anything depth estimation. Model: depth-anything/Depth-Anything-V2-Small-hf diff --git a/src/winml/modelkit/models/hf/marian.py b/src/winml/modelkit/models/hf/marian.py index 8540972a6..f3b2abb13 100644 --- a/src/winml/modelkit/models/hf/marian.py +++ b/src/winml/modelkit/models/hf/marian.py @@ -85,7 +85,7 @@ from __future__ import annotations import logging -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast import torch import torch.nn as nn @@ -107,6 +107,9 @@ from ..winml.kv_cache import PastKeyValueInputGenerator, WinMLStaticCache +if TYPE_CHECKING: + from transformers import GenerationConfig, PretrainedConfig + logger = logging.getLogger(__name__) @@ -177,7 +180,7 @@ def _patched_marian_sinusoidal_forward( - self, + self: Any, # monkey-patched onto MarianSinusoidalPositionalEmbedding (HF internal) input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None, @@ -262,10 +265,14 @@ def forward( attention_mask: torch.Tensor, ) -> torch.Tensor: """Return encoder last hidden state.""" - return self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - ).last_hidden_state + # self.encoder is a torch submodule (untyped __call__ -> Any). + return cast( + "torch.Tensor", + self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ).last_hidden_state, + ) class MarianDecoderWrapper(nn.Module): @@ -301,7 +308,9 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: self.model = model self.num_layers = num_layers # Expose config for OnnxConfig / NormalizedConfig access - self.config = model.config + # model is typed nn.Module, so torch's __getattr__ types .config as + # Tensor | Module; it is really the model's PretrainedConfig. + self.config = cast("PretrainedConfig", model.config) @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> MarianDecoderWrapper: @@ -440,7 +449,8 @@ class _MarianDecoderNormalizedConfig(NormalizedConfig): # type: ignore[misc] # @property def head_dim(self) -> int: - return self.hidden_size // self.num_attention_heads + # hidden_size / num_attention_heads come from the untyped NormalizedConfig base. + return cast("int", self.hidden_size // self.num_attention_heads) @register_onnx_overwrite("marian", "text2text-generation", library_name="transformers") @@ -554,7 +564,7 @@ def get_cache_class(cls) -> type: return WinMLStaticCache # static cache (index_put_ → ScatterND) @property - def generation_config(self): # noqa: D102 + def generation_config(self) -> GenerationConfig: # noqa: D102 if not hasattr(self, "_generation_config"): from transformers import GenerationConfig diff --git a/src/winml/modelkit/models/hf/vision_encoder_decoder.py b/src/winml/modelkit/models/hf/vision_encoder_decoder.py index a576b0d51..14e9a7c74 100644 --- a/src/winml/modelkit/models/hf/vision_encoder_decoder.py +++ b/src/winml/modelkit/models/hf/vision_encoder_decoder.py @@ -104,7 +104,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: @register_onnx_overwrite( "vision-encoder-decoder", "feature-extraction", library_name="transformers" ) -class VisionEncoderIOConfig(OnnxConfig): +class VisionEncoderIOConfig(OnnxConfig): # type: ignore[misc] # optimum/transformers base is untyped """ONNX config for the vision encoder.""" NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( @@ -243,7 +243,7 @@ def _build_ved_patching_specs() -> list[PatchingSpec]: # ============================================================================= -class _VedDecoderNormalizedConfig(NormalizedConfig): +class _VedDecoderNormalizedConfig(NormalizedConfig): # type: ignore[misc] # optimum/transformers base is untyped """VED decoder NormalizedConfig. Per-architecture field paths (``decoder.d_model`` vs ``decoder.n_embd`` diff --git a/src/winml/modelkit/models/winml/base.py b/src/winml/modelkit/models/winml/base.py index e8656d646..ea9172032 100644 --- a/src/winml/modelkit/models/winml/base.py +++ b/src/winml/modelkit/models/winml/base.py @@ -21,7 +21,7 @@ import logging from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import numpy as np import torch @@ -246,7 +246,8 @@ def task(self) -> str | None: if build_config is not None: loader = getattr(build_config, "loader", None) if loader: - return loader.task + # loader comes from getattr (Any); task is a str | None field. + return cast("str | None", loader.task) return None @property diff --git a/src/winml/modelkit/models/winml/composite_model.py b/src/winml/modelkit/models/winml/composite_model.py index 8c839859a..1d89b6c66 100644 --- a/src/winml/modelkit/models/winml/composite_model.py +++ b/src/winml/modelkit/models/winml/composite_model.py @@ -41,7 +41,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast import torch @@ -49,7 +49,7 @@ if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Callable, Mapping from pathlib import Path from transformers import PretrainedConfig @@ -66,7 +66,7 @@ COMPOSITE_MODEL_REGISTRY: dict[tuple[str, str], type[WinMLCompositeModel]] = {} -def register_composite_model(model_type: str, task: str): +def register_composite_model(model_type: str, task: str) -> Callable[[type], type]: """Class decorator that registers a composite model for `winml config`.""" def decorator(cls: type) -> type: @@ -109,7 +109,7 @@ class WinMLCompositeModel(PreTrainedModel): def __init__( self, sub_models: dict[str, Any], - config: PretrainedConfig, + config: PretrainedConfig | None, device: str = "cpu", ) -> None: self.sub_models = sub_models @@ -243,7 +243,9 @@ def from_onnx( # Resolve concrete class from registry model_type = getattr(hf_config, "model_type", None) if hf_config else None if not cls._SUB_MODEL_CONFIG: - resolved_cls = COMPOSITE_MODEL_REGISTRY.get((model_type, task)) + # model_type/task may be None; the str-keyed registry simply misses + # (returns None, handled below). dict.get tolerates any hashable key. + resolved_cls = COMPOSITE_MODEL_REGISTRY.get(cast("tuple[str, str]", (model_type, task))) if resolved_cls is None: raise ValueError( f"No composite model for ({model_type!r}, {task!r}). " diff --git a/src/winml/modelkit/models/winml/decoder_only.py b/src/winml/modelkit/models/winml/decoder_only.py index 3bfa77700..d47e78703 100644 --- a/src/winml/modelkit/models/winml/decoder_only.py +++ b/src/winml/modelkit/models/winml/decoder_only.py @@ -58,7 +58,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch from optimum.utils.input_generators import DummyInputGenerator @@ -71,6 +71,8 @@ if TYPE_CHECKING: from transformers import Cache, PretrainedConfig + from .kv_cache import WinMLCache + logger = logging.getLogger(__name__) @@ -79,7 +81,7 @@ # ========================================================================= -class DecoderOnlyInputGenerator(DummyInputGenerator): +class DecoderOnlyInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum/transformers base is untyped """Generates base inputs for decoder-only models with static KV cache. Produces ``input_ids``, ``attention_mask``, ``position_ids``, and @@ -118,7 +120,9 @@ def __init__( self.batch_size = batch_size self.vocab_size = normalized_config.vocab_size self.max_cache_len = max_cache_len or normalized_config.max_cache_len - self.seq_len: int = seq_len or getattr(normalized_config, "seq_len", self._default_seq_len) + self.seq_len: int = seq_len or cast( + "int", getattr(normalized_config, "seq_len", self._default_seq_len) + ) def generate( self, @@ -129,11 +133,15 @@ def generate( ) -> torch.Tensor: """Generate a dummy tensor for the given input name.""" if input_name == "input_ids": - return self.random_int_tensor( - (self.batch_size, self.seq_len), - max_value=self.vocab_size, - framework=framework, - dtype=int_dtype, + # optimum's DummyInputGenerator is untyped, so random_int_tensor returns Any. + return cast( + "torch.Tensor", + self.random_int_tensor( + (self.batch_size, self.seq_len), + max_value=self.vocab_size, + framework=framework, + dtype=int_dtype, + ), ) if input_name == "attention_mask": mask = torch.zeros(self.batch_size, self.max_cache_len, dtype=torch.int64) @@ -225,7 +233,7 @@ def __init__( # ----- Cache + GenerationMixin interface ----- @classmethod - def get_cache_class(cls) -> type: + def get_cache_class(cls) -> type[WinMLCache]: """Return the WinMLCache subclass. Subclasses must override.""" raise NotImplementedError @@ -250,6 +258,8 @@ def _resolve_cache(self, past_key_values: Any) -> Any: if isinstance(past_key_values, WinMLCache): return past_key_values + if self.config is None: + raise ValueError("Decoder-only generation requires an HF config to build the KV cache.") kv_shape = [1, self._num_kv_heads, self._max_cache_len, self._head_dim] cache = self.get_cache_class().create(self.config, kv_shape, self._kv_dtype) cache.reset() @@ -258,7 +268,7 @@ def _resolve_cache(self, past_key_values: Any) -> Any: def can_generate(self) -> bool: # noqa: D102 return True - def prepare_inputs_for_generation( + def prepare_inputs_for_generation( # type: ignore[override] # GenerationMixin's base signature differs; static-cache flow self, input_ids: torch.LongTensor, past_key_values: Cache | None = None, @@ -269,7 +279,7 @@ def prepare_inputs_for_generation( from .kv_cache import WinMLCache if isinstance(past_key_values, WinMLCache) and past_key_values.get_seq_length() > 0: - input_ids = input_ids[:, -1:] + input_ids = cast("torch.LongTensor", input_ids[:, -1:]) else: past_key_values = None return { @@ -280,7 +290,7 @@ def prepare_inputs_for_generation( # ----- Forward ----- - def forward( + def forward( # type: ignore[override] # HF-pipeline base uses generic **kwargs; task-specific signature self, *, input_ids: torch.Tensor, @@ -311,8 +321,10 @@ def forward( else: logits = self._run_gen(input_ids, cache) + # transformers' Output fields are annotated FloatTensor (legacy, over-narrow); + # the ONNX session returns a real float Tensor. return CausalLMOutputWithPast( - logits=logits, + logits=cast("torch.FloatTensor", logits), past_key_values=cache, ) @@ -395,4 +407,4 @@ def _run_gen(self, input_ids: torch.Tensor, cache: Any) -> torch.Tensor: outputs = self._gen_model(**feeds) cache.update_all_layers(outputs) - return outputs["logits"] + return cast("torch.Tensor", outputs["logits"]) diff --git a/src/winml/modelkit/models/winml/encoder_decoder.py b/src/winml/modelkit/models/winml/encoder_decoder.py index ae1b669fe..8430b2477 100644 --- a/src/winml/modelkit/models/winml/encoder_decoder.py +++ b/src/winml/modelkit/models/winml/encoder_decoder.py @@ -57,7 +57,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch from optimum.utils.input_generators import DummyInputGenerator @@ -72,6 +72,8 @@ from optimum.utils import NormalizedConfig from transformers import Cache, PretrainedConfig + from .kv_cache import WinMLCache + logger = logging.getLogger(__name__) @@ -80,7 +82,7 @@ # ============================================================================= -class EncoderDecoderInputGenerator(DummyInputGenerator): +class EncoderDecoderInputGenerator(DummyInputGenerator): # type: ignore[misc] # optimum/transformers base is untyped """Generates decoder base inputs for encoder-decoder models. Produces ``decoder_input_ids``, ``encoder_hidden_states``, @@ -108,7 +110,9 @@ def __init__( ) -> None: self.batch_size = batch_size self.d_model = normalized_config.hidden_size - self.enc_seq = sequence_length or getattr(normalized_config, "sequence_length", 16) + self.enc_seq: int = sequence_length or cast( + "int", getattr(normalized_config, "sequence_length", 16) + ) self.max_cache_len = max_cache_len or normalized_config.max_cache_len self.vocab_size = normalized_config.vocab_size @@ -120,18 +124,25 @@ def generate( float_dtype: str = "fp32", ) -> torch.Tensor: """Generate a dummy tensor for the given input name.""" + # optimum's DummyInputGenerator is untyped, so random_*_tensor returns Any. if input_name == "decoder_input_ids": - return self.random_int_tensor( - (self.batch_size, 1), - max_value=self.vocab_size, - framework=framework, - dtype=int_dtype, + return cast( + "torch.Tensor", + self.random_int_tensor( + (self.batch_size, 1), + max_value=self.vocab_size, + framework=framework, + dtype=int_dtype, + ), ) if input_name == "encoder_hidden_states": - return self.random_float_tensor( - (self.batch_size, self.enc_seq, self.d_model), - framework=framework, - dtype=float_dtype, + return cast( + "torch.Tensor", + self.random_float_tensor( + (self.batch_size, self.enc_seq, self.d_model), + framework=framework, + dtype=float_dtype, + ), ) if input_name == "attention_mask": return torch.ones(self.batch_size, self.enc_seq, dtype=torch.int64) @@ -226,7 +237,8 @@ def __init__(self, encoder: Any, expected: dict[str, list[int]]) -> None: def forward(self, **kwargs: Any) -> BaseModelOutput: feeds = pad_inputs(kwargs, self._expected) - return self._encoder(**feeds) + # self._encoder is a torch Module (untyped __call__ -> Any). + return cast("BaseModelOutput", self._encoder(**feeds)) def get_encoder(self) -> torch.nn.Module: """Return encoder for GenerationMixin (already wrapped with padding).""" @@ -235,7 +247,7 @@ def get_encoder(self) -> torch.nn.Module: def can_generate(self) -> bool: # noqa: D102 return True - def prepare_inputs_for_generation( + def prepare_inputs_for_generation( # type: ignore[override] # GenerationMixin's base signature differs; static-cache flow self, input_ids: torch.LongTensor, past_key_values: Cache | None = None, @@ -260,7 +272,7 @@ def prepare_inputs_for_generation( # ----- Cache management ----- @classmethod - def get_cache_class(cls) -> type: + def get_cache_class(cls) -> type[WinMLCache]: """Return the WinMLCache subclass. Subclasses must override.""" raise NotImplementedError @@ -282,6 +294,8 @@ def _resolve_cache(self, past_key_values: Any) -> Any: return past_key_values # (3) Create fresh cache and reset + if self.config is None: + raise ValueError("Encoder-decoder generation requires an HF config to build the KV cache.") kv_shape = self._dec_expected["past_0_key"] cache = self.get_cache_class().create(self.config, kv_shape, self._kv_dtype) cache.reset() @@ -316,7 +330,9 @@ def forward( encoder_outputs = self._encoder(input_ids=input_ids, **model_kwargs) if encoder_outputs is None: raise ValueError("Either encoder_outputs or input_ids required") - enc_h = encoder_outputs["last_hidden_state"] + # The encoder wrapper always returns a dict-like BaseModelOutput; the tuple + # arm of the annotation exists only for GenerationMixin signature compat. + enc_h = cast("BaseModelOutput", encoder_outputs)["last_hidden_state"] # Resolve or create cache (subclasses override get_cache_class). cache = self._resolve_cache(past_key_values) diff --git a/src/winml/modelkit/models/winml/zero_shot_image_classification.py b/src/winml/modelkit/models/winml/zero_shot_image_classification.py index 9ab089c45..ba40a295d 100644 --- a/src/winml/modelkit/models/winml/zero_shot_image_classification.py +++ b/src/winml/modelkit/models/winml/zero_shot_image_classification.py @@ -12,7 +12,7 @@ import logging from dataclasses import dataclass -from typing import Any, ClassVar +from typing import Any, ClassVar, cast import numpy as np import torch @@ -101,7 +101,12 @@ def forward( def _preprocess_vision(self, pixel_values: torch.Tensor | None) -> dict[str, np.ndarray]: """Torch→numpy via the sub-model's formatter.""" - return self.sub_models["image-encoder"]._format_inputs(pixel_values=pixel_values) + # sub_models values are Any (heterogeneous WinML models); _format_inputs + # returns a {name: ndarray} feed dict. + return cast( + "dict[str, np.ndarray]", + self.sub_models["image-encoder"]._format_inputs(pixel_values=pixel_values), + ) def _run_vision(self, inputs: dict[str, np.ndarray]) -> torch.Tensor: """Run vision encoder over ``M`` images, batching per the ONNX's fixed batch dim."""