Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/submit_aml/aml.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,14 @@ def submit_to_aml(
return None

# Build command that will be run
if project_dir != source_dir:
relative_project_dir = project_dir.relative_to(source_dir)
if command_prefix.startswith("uv run") and project_dir != source_dir:
try:
relative_project_dir = project_dir.relative_to(source_dir)
except ValueError as exc:
raise ValueError(
f"The project directory '{project_dir}' must be inside the source"
f" directory '{source_dir}' to append the uv --project flag."
) from exc
command_prefix += f" --project {relative_project_dir}"

if services is None:
Expand Down
89 changes: 89 additions & 0 deletions tests/test_aml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from __future__ import annotations

from pathlib import Path
from unittest.mock import patch

import pytest

from submit_aml.aml import CredentialType
from submit_aml.aml import _sanitize_experiment_name
from submit_aml.aml import get_client
from submit_aml.aml import submit_to_aml


def test_sanitize_none_returns_none() -> None:
Expand Down Expand Up @@ -56,3 +58,90 @@ def test_get_client_msi_uses_managed_identity(
"""CredentialType.MANAGED_IDENTITY uses ManagedIdentityCredential."""
get_client("sub", "rg", "ws", credential_type=CredentialType.MANAGED_IDENTITY)
mock_msi_cred.assert_called_once() # type: ignore[union-attr]


@pytest.mark.parametrize(
("command_prefix", "expected_in_command"),
[
("uv run", "--project subproject"),
("python", None),
],
ids=["uv-run-appends-project", "non-uv-skips-project"],
)
@patch("submit_aml.aml._submit")
@patch("submit_aml.aml.instantiate_command")
@patch("submit_aml.aml.infer_environment")
@patch("submit_aml.aml.setup")
def test_project_flag_only_appended_for_uv_run(
mock_setup: object,
mock_infer_env: object,
mock_instantiate: object,
mock_submit: object,
command_prefix: str,
expected_in_command: str | None,
) -> None:
"""``--project`` is appended only for ``uv run`` prefixes."""
source_dir = Path("/repo")
project_dir = source_dir / "subproject"
mock_setup.return_value = ( # type: ignore[attr-defined]
source_dir,
project_dir,
"run.py",
object(), # ml_client
"description",
1, # instance_count
None, # distribution
"experiment",
)

submit_to_aml(
command_prefix=command_prefix,
compute_target="cpu-cluster",
script_path="run.py",
subscription_id="sub",
resource_group="rg",
workspace_name="ws",
dry_run=True,
)

command = mock_instantiate.call_args.kwargs["command"] # type: ignore[attr-defined]
if expected_in_command is None:
assert "--project" not in command
else:
assert expected_in_command in command


@patch("submit_aml.aml._submit")
@patch("submit_aml.aml.instantiate_command")
@patch("submit_aml.aml.infer_environment")
@patch("submit_aml.aml.setup")
def test_uv_run_prefix_raises_when_project_not_under_source(
mock_setup: object,
mock_infer_env: object,
mock_instantiate: object,
mock_submit: object,
) -> None:
"""A clear ValueError is raised when project_dir is not inside source_dir."""
source_dir = Path("/repo")
project_dir = Path("/other/project")
mock_setup.return_value = ( # type: ignore[attr-defined]
source_dir,
project_dir,
"run.py",
object(), # ml_client
"description",
1, # instance_count
None, # distribution
"experiment",
)

with pytest.raises(ValueError, match="must be inside the source directory"):
submit_to_aml(
command_prefix="uv run",
compute_target="cpu-cluster",
script_path="run.py",
subscription_id="sub",
resource_group="rg",
workspace_name="ws",
dry_run=True,
)
Loading