Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/winml/modelkit/commands/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,20 @@ def eval(
cfg.precision,
)

# The build-pipeline flags only take effect when eval rebuilds the model.
# With a pre-built ONNX path and skip_build (the default), they are no-ops
# forwarded to from_onnx, so warn the user that they were ignored — mirrors
# the --precision warning above. Shared with perf via utils/cli.py.
build_flags_warning = cli_utils.ignored_build_flags_warning(
skip_build_onnx=cfg.model_path is not None and cfg.skip_build,
quant=cfg.quant,
optimize=cfg.optimize,
analyze=cfg.analyze,
max_optim_iterations=cfg.max_optim_iterations,
)
if build_flags_warning:
logger.warning(build_flags_warning)

logger.debug("Effective eval config: %s", cfg.to_dict())

json_mode = output_format == "json"
Expand Down
12 changes: 12 additions & 0 deletions src/winml/modelkit/commands/perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,6 +1749,18 @@ def perf(
"pre-exported ONNX files (shapes are baked into the model)."
)
config.shape_config = None
# Build-pipeline flags are forwarded to from_onnx but no-op when the
# build is skipped (the default). Warn so the silent no-op is visible
# — shared detection with eval via utils/cli.py.
build_flags_warning = cli_utils.ignored_build_flags_warning(
skip_build_onnx=skip_build,
quant=quant,
optimize=optimize,
analyze=analyze,
max_optim_iterations=max_optim_iterations,
)
if build_flags_warning:
console.print(f"[yellow]Warning:[/yellow] {build_flags_warning}")
console.print(f"[dim]Benchmarking ONNX:[/dim] {model_path}")
else:
if precision != "auto":
Expand Down
16 changes: 0 additions & 16 deletions src/winml/modelkit/optim/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,6 @@ def _convert_to_kwargs(config: dict[str, Any], all_caps: dict[str, Any]) -> dict
return result


def _hack_inject_quant_preprocess_metadata(model: onnx.ModelProto) -> None:
"""Inject metadata that signals pre-processing was done.

Suppresses the ORT quantization warning:
'Please consider to run pre-processing before quantization.'
"""
metadata = {"onnx.quant.pre_process": "onnxruntime.quant"}
if model.metadata_props:
for prop in model.metadata_props:
metadata[prop.key] = prop.value
onnx.helper.set_model_props(model, metadata)


def optimize_onnx(
model: str | Path | onnx.ModelProto,
output: str | Path | None = None,
Expand Down Expand Up @@ -272,9 +259,6 @@ def optimize_onnx(
optimized_model = optimizer.optimize(loaded_model, **optimizer_kwargs)
optimized_model = optimizer.optimize(optimized_model, **optimizer_kwargs)

# Step 9.5: Inject quant pre-processing metadata to suppress ORT warning
_hack_inject_quant_preprocess_metadata(optimized_model)

# Step 10: Save if output path provided
if output is not None:
output_path = Path(output)
Expand Down
20 changes: 13 additions & 7 deletions src/winml/modelkit/quant/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,20 @@ def quantize_onnx(
extra_options=extra_options,
)

# Step 2: Capture metadata before ORT quantization (it rebuilds the graph)
# Step 2: Load the input model, capture its metadata snapshot (ORT
# rebuilds the graph during quantization, so we restore afterwards),
# and tag it as pre-processed so quantize_static() does not emit the
# "run pre-processing before quantization" warning. We hand this
# in-memory ModelProto to ORT directly rather than mutating the user's
# input file on disk.
from onnxruntime.quantization.quant_utils import add_pre_process_metadata

from ..onnx import capture_metadata, load_onnx, restore_metadata, save_onnx
from .qdq_fix import fix_qdq_dtype_info

pre_quant_model = load_onnx(model_path, load_weights=False, validate=False)
metadata_snapshot = capture_metadata(pre_quant_model)
del pre_quant_model
input_model = load_onnx(model_path, validate=False)
metadata_snapshot = capture_metadata(input_model)
add_pre_process_metadata(input_model)

# Step 3: Apply quantization
if use_external_data:
Expand All @@ -171,9 +178,8 @@ def quantize_onnx(
# output directory rather than the process CWD. Without this, a stale
# .onnx.data sidecar in the process CWD from a previous build triggers
# a false-positive FileExistsError even when the output dir is clean.
# Use absolute paths so the chdir does not break relative input/output
# Use an absolute output path so the chdir does not break its
# resolution. output_path.parent is guaranteed to exist (caller mkdir).
abs_model_input = str(Path(model_path).resolve())
abs_model_output = str(Path(output_path).resolve())
# Remove stale output artifacts from a previous build. ORT/onnx refuse
# to overwrite an existing external-data sidecar (e.g. quantized.onnx.data),
Expand All @@ -187,7 +193,7 @@ def quantize_onnx(
try:
os.chdir(output_path.parent)
quantize(
model_input=abs_model_input,
model_input=input_model,
model_output=abs_model_output,
quant_config=qdq_config,
)
Expand Down
52 changes: 50 additions & 2 deletions src/winml/modelkit/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,9 +631,9 @@ def max_optim_iterations_option(optional_message: str | None = None) -> Callable
Returns:
Decorator function.
"""
base_help = "Maximum autoconf re-optimization rounds (default: 3). --no-analyze sets this to 0."
base_help = "Maximum autoconf re-optimization rounds (default: 3). --no-analyze sets this to 0"
if optional_message:
base_help = f"{base_help} {optional_message}"
base_help = f"{base_help}. {optional_message}"
return click.option(
"--max-optim-iterations",
"max_optim_iterations",
Expand Down Expand Up @@ -676,6 +676,54 @@ def build_pipeline_extra_kwargs(
return extra


def ignored_build_flags_warning(
*,
skip_build_onnx: bool,
quant: bool = True,
optimize: bool = True,
analyze: bool = True,
max_optim_iterations: int | None = None,
) -> str | None:
"""Build a warning for build-pipeline flags that are no-ops on a pre-built ONNX.

Commands that accept a pre-built ``.onnx`` input (``eval``, ``perf``) forward
``--no-quant``/``--no-optimize``/``--no-analyze``/``--max-optim-iterations`` to
``from_onnx``, but with ``skip_build`` (the default) no build runs, so those
toggles silently take no effect. This returns a message naming the flags the
user actually set (or ``None`` when nothing was set or a build will run), so
callers can surface it through their own logger/console — mirroring the
``--precision``-ignored warning.

Args:
skip_build_onnx: True when the input is a pre-built ONNX *and* the build
is skipped (the precondition under which the flags are no-ops).
quant/optimize/analyze: Enabled-semantics toggles (False = user passed
the ``--no-*`` form).
max_optim_iterations: Explicit value, or ``None`` when left at default.

Returns:
Warning message, or ``None`` if no ignored flags apply.
"""
if not skip_build_onnx:
return None
ignored = [
flag
for flag, was_set in (
("--no-quant", not quant),
("--no-optimize", not optimize),
("--no-analyze", not analyze),
("--max-optim-iterations", max_optim_iterations is not None),
)
if was_set
]
if not ignored:
return None
return (
f"{', '.join(ignored)} ignored for pre-built ONNX inputs "
"(no build runs; pass --no-skip-build to rebuild)."
)


def allow_unsupported_nodes_option(optional_message: str | None = None) -> Callable[[F], F]:
"""Add shared --allow-unsupported-nodes option to a Click command.

Expand Down
84 changes: 84 additions & 0 deletions tests/unit/commands/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,90 @@ def test_default_dataset_logs_warning(
), f"expected warning not found in {msgs!r}"


# ---------------------------------------------------------------------------
# Build-pipeline flags ignored for pre-built ONNX inputs
# ---------------------------------------------------------------------------


class TestPrebuiltOnnxIgnoredBuildFlags:
"""A pre-built ONNX path with skip_build (the default) makes the build
flags (--no-quant/--no-optimize/--no-analyze/--max-optim-iterations)
no-ops, so the command warns they were ignored."""

@staticmethod
def _run(runner: CliRunner, args: list[str]):
"""Invoke eval with ``evaluate`` stubbed so only the CLI front-half
(config resolution + warnings) runs. Returns the CliRunner result.

``commands.eval`` imports ``evaluate`` lazily via ``from ..eval import
evaluate``, so the stub is installed on the ``winml.modelkit.eval``
package where that import resolves it."""
from winml.modelkit.commands.eval import eval as eval_cmd

with (
patch("winml.modelkit.eval.evaluate", return_value=object()),
patch("winml.modelkit.commands.eval._write_and_display", return_value=None),
):
return runner.invoke(eval_cmd, args, obj={"debug": False})

def test_no_quant_warns_for_prebuilt_onnx(
self,
runner: CliRunner,
onnx_file,
caplog,
):
import logging as _logging

with caplog.at_level(_logging.WARNING, logger="winml.modelkit.commands.eval"):
result = self._run(
runner,
[
"-m",
str(onnx_file),
"--model-id",
"some/model",
"--task",
"image-classification",
"--no-quant",
"--max-optim-iterations",
"5",
],
)
assert result.exit_code == 0, result.output
msgs = [r.getMessage() for r in caplog.records]
assert any(
"--no-quant" in m and "--max-optim-iterations" in m and "pre-built ONNX" in m
for m in msgs
), f"expected ignored-build-flags warning not found in {msgs!r}"

def test_no_warning_when_flags_left_default(
self,
runner: CliRunner,
onnx_file,
caplog,
):
"""Default build flags emit no ignored-flags warning."""
import logging as _logging

with caplog.at_level(_logging.WARNING, logger="winml.modelkit.commands.eval"):
result = self._run(
runner,
[
"-m",
str(onnx_file),
"--model-id",
"some/model",
"--task",
"image-classification",
],
)
assert result.exit_code == 0, result.output
msgs = [r.getMessage() for r in caplog.records]
assert not any("ignored for pre-built ONNX inputs (no build runs" in m for m in msgs), (
f"unexpected ignored-build-flags warning in {msgs!r}"
)


# ---------------------------------------------------------------------------
# --format json
# ---------------------------------------------------------------------------
Expand Down
66 changes: 66 additions & 0 deletions tests/unit/commands/test_perf_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,72 @@ def capture_config(config: BenchmarkConfig) -> MagicMock:
assert "Benchmarking ONNX" in result.output
assert captured["config"].shape_config is None

def test_cli_onnx_warns_ignored_build_flags(self, runner: CliRunner, tmp_path: Path) -> None:
"""Build-pipeline flags are no-ops for a pre-built ONNX with skip_build,
so the CLI surfaces a warning naming the flags the user set."""
onnx_file = tmp_path / "model.onnx"
onnx_file.write_bytes(b"fake onnx")

def capture_config(_config: BenchmarkConfig) -> MagicMock:
mock = MagicMock()
mock.run.return_value = MagicMock()
return mock

with (
patch(
"winml.modelkit.commands.perf.PerfBenchmark",
side_effect=capture_config,
),
patch("winml.modelkit.commands.perf.display_console_report"),
patch("winml.modelkit.commands.perf.write_json_report"),
):
result = runner.invoke(
perf,
[
"-m",
str(onnx_file),
"--no-quant",
"--no-optimize",
"-o",
str(tmp_path / "out.json"),
],
obj={},
)

assert result.exit_code == 0, result.output
assert "--no-quant" in result.output
assert "--no-optimize" in result.output
assert "pre-built ONNX" in result.output

def test_cli_onnx_no_build_flag_warning_at_defaults(
self, runner: CliRunner, tmp_path: Path
) -> None:
"""No ignored-build-flags warning when the flags are left at defaults."""
onnx_file = tmp_path / "model.onnx"
onnx_file.write_bytes(b"fake onnx")

def capture_config(_config: BenchmarkConfig) -> MagicMock:
mock = MagicMock()
mock.run.return_value = MagicMock()
return mock

with (
patch(
"winml.modelkit.commands.perf.PerfBenchmark",
side_effect=capture_config,
),
patch("winml.modelkit.commands.perf.display_console_report"),
patch("winml.modelkit.commands.perf.write_json_report"),
):
result = runner.invoke(
perf,
["-m", str(onnx_file), "-o", str(tmp_path / "out.json")],
obj={},
)

assert result.exit_code == 0, result.output
assert "ignored for pre-built ONNX inputs (no build runs" not in result.output

def test_cli_onnx_not_found_error(self, runner: CliRunner, tmp_path: Path) -> None:
"""CLI with non-existent .onnx file should raise FileNotFoundError."""
missing = tmp_path / "missing.onnx"
Expand Down
19 changes: 16 additions & 3 deletions tests/unit/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ class _FakeOrtQuantizationModule(ModuleType):
QuantType: Any
get_qdq_config: Any
quantize: Any
quant_utils: ModuleType


class _FakeOrtQuantUtilsModule(ModuleType):
add_pre_process_metadata: Any


class _FakeOnnxModule(ModuleType):
Expand Down Expand Up @@ -74,9 +79,13 @@ def _install_fake_ort_quantization(monkeypatch: pytest.MonkeyPatch, *, quantize_
)
quant_module.get_qdq_config = lambda **_: SimpleNamespace(use_external_data_format=False)
quant_module.quantize = quantize_impl
quant_utils_module = _FakeOrtQuantUtilsModule("onnxruntime.quantization.quant_utils")
quant_utils_module.add_pre_process_metadata = lambda _model: None
quant_module.quant_utils = quant_utils_module
ort_module.quantization = quant_module
monkeypatch.setitem(sys.modules, "onnxruntime", ort_module)
monkeypatch.setitem(sys.modules, "onnxruntime.quantization", quant_module)
monkeypatch.setitem(sys.modules, "onnxruntime.quantization.quant_utils", quant_utils_module)


def test_quantize_onnx_removes_only_exact_external_data_sidecar(
Expand All @@ -92,8 +101,12 @@ def test_quantize_onnx_removes_only_exact_external_data_sidecar(
exact_sidecar.write_text("stale")
extra_suffix_sidecar.write_text("keep")

def fake_quantize(*, model_input: str, model_output: str, quant_config) -> None:
assert model_input == str(model_path.resolve())
# The quantizer hands the in-memory input model (not the path) to ORT so it
# can tag it as pre-processed without mutating the user's input file.
input_model = SimpleNamespace()

def fake_quantize(*, model_input, model_output: str, quant_config) -> None:
assert model_input is input_model
assert model_output == str(output_path.resolve())
assert quant_config.use_external_data_format is True
assert not exact_sidecar.exists()
Expand All @@ -106,7 +119,7 @@ def fake_quantize(*, model_input: str, model_output: str, quant_config) -> None:
quantized_model = SimpleNamespace(
graph=SimpleNamespace(node=[SimpleNamespace(op_type="QuantizeLinear")])
)
load_results = [SimpleNamespace(), quantized_model]
load_results = [input_model, quantized_model]
fake_onnx_module.capture_metadata = lambda _model: SimpleNamespace(node_count=0)
fake_onnx_module.load_onnx = lambda *_args, **_kwargs: load_results.pop(0)
fake_onnx_module.restore_metadata = lambda *_args, **_kwargs: None
Expand Down
Loading
Loading