Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,14 @@ objects.
--mount "checkpoint=job_dir:my-training-job:models/best.pth"
```

Mount a raw folder on a registered datastore:

```bash
submit-aml \
--script evaluate.py \
--mount "ref_dir=mydatastore/exports/reference"
```

=== "Python"

```python
Expand All @@ -215,6 +223,12 @@ objects.
script_path="evaluate.py",
datasets_mount=["checkpoint=job_dir:my-training-job:models/best.pth"],
)

# Mount a raw folder on a registered datastore
submit_to_aml(
script_path="evaluate.py",
datasets_mount=["ref_dir=mydatastore/exports/reference"],
)
```

Configure an output datastore:
Expand Down
7 changes: 6 additions & 1 deletion src/submit_aml/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ def submit(
" dataset, the argument should take the form: alias, name and version"
" of the dataset; for example: 'vindr_dir=VINDR-CXR-V2:1'."
" If the version is omitted, the last one will be used."
" To download a raw folder on a registered datastore, the argument should"
" take the form 'alias=datastore/path/to/dir'; for example:"
" 'ref_dir=mydatastore/exports/reference'."
" To download the output folder of a previous job, the argument should take"
" the form 'alias=job_dir:<job_id>:<path/in/job/outputs>'; for example:"
" 'checkpoint=job_dir:crusty_hat_43s6lmvb25:outputs/checkpoint-10000'."
Expand All @@ -160,7 +163,9 @@ def submit(
"-m",
help=(
"Azure ML dataset or job output folder to mount."
" For an Azure ML dataset, the alias, name and version should be provided"
" For an Azure ML dataset, the alias, name and version should be provided;"
" for a raw datastore folder, the alias, datastore and path should be"
" provided (e.g. 'ref_dir=mydatastore/exports/reference');"
" while for a job output folder, the alias, job ID and path in the job"
" outputs should be provided. See the --download option for more"
" information."
Expand Down
58 changes: 51 additions & 7 deletions src/submit_aml/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@ def _extract_alias_path_version(string: str) -> tuple[str, str, str | None]:
sys.exit(1)


def _datastore_uri(datastore: str, path: str) -> str:
"""Build an Azure ML datastore URI for a folder on a datastore.

Args:
datastore: Name of the registered datastore.
path: Folder path within the datastore.

Returns:
An `azureml://` URI of the form
`azureml://datastores/<datastore>/paths/<path>`.
"""
return f"azureml://datastores/{datastore}/paths/{path}"


def _extract_alias_datastore_path(string: str) -> tuple[str, str, str]:
"""Get alias, datastore name and folder path from a string.

Expand Down Expand Up @@ -126,6 +140,24 @@ def _is_alias_job_path_string(string: str) -> bool:
return False


def _is_alias_datastore_path_string(string: str) -> bool:
"""Return True if the string refers to a raw datastore-path folder.

A datastore-path string has the form `'alias=datastore/folder'`. It is
distinguished from a data-asset name (`'alias=name[:version]'`) by the
presence of a `/` in the right-hand side, and from a job-output directory
by not starting with `job_dir:`.

This is intentionally a pure string check: `_extract_alias_datastore_path`
calls `sys.exit(1)` on a non-match rather than raising, so it cannot be
wrapped in try/except the way `_is_alias_job_path_string` is.
"""
if "=" not in string:
return False
_, rhs = string.split("=", 1)
return "/" in rhs and not rhs.startswith("job_dir:")


def build_command_inputs(
ml_client: MLClient,
strings_download: list[str] | None,
Expand All @@ -134,10 +166,13 @@ def build_command_inputs(
"""Get dictionaries data assets to be mounted or downloaded.

Args:
strings_download: List of strings of the form `'alias=path:version'` to
be downloaded. If `None`, no data assets will be downloaded.
strings_mount: List of strings of the form `'alias=path:version'` to
be mounted. If `None`, no data assets will be mounted.
strings_download: List of strings to be downloaded. Each is of the form
`'alias=name[:version]'` (registered data asset),
`'alias=datastore/folder'` (raw datastore path), or
`'alias=job_dir:<job_id>:<path>'` (previous job output).
If `None`, no data assets will be downloaded.
strings_mount: List of strings to be mounted, in the same forms as
`strings_download`. If `None`, no data assets will be mounted.
"""
strings_download = [] if strings_download is None else strings_download
strings_mount = [] if strings_mount is None else strings_mount
Expand Down Expand Up @@ -168,7 +203,7 @@ def build_command_outputs(
for string in strings_upload:
alias, datastore, path = _extract_alias_datastore_path(string)
output = Output(
path=f"azureml://datastores/{datastore}/paths/{path}",
path=_datastore_uri(datastore, path),
)
outputs_dict[alias] = output
return outputs_dict
Expand All @@ -182,8 +217,8 @@ def _get_data_assets(
"""Get data assets from Azure ML.

Args:
datasets: List of strings of the form `'alias=path:version'` or
`'alias=job_dir:<job_id>:<path>'`.
datasets: List of strings of the form `'alias=path:version'`,
`'alias=datastore/folder'`, or `'alias=job_dir:<job_id>:<path>'`.
mode: Either `InputOutputModes.DOWNLOAD` or `InputOutputModes.MOUNT`.

Returns:
Expand All @@ -200,6 +235,15 @@ def _get_data_assets(
path=str(azureml_path),
mode=mode,
)
elif _is_alias_datastore_path_string(string):
# Handle raw datastore-path folder format
alias, datastore, folder = _extract_alias_datastore_path(string)
azureml_path = _datastore_uri(datastore, folder)
logger.info(f'Using datastore path "{azureml_path}"...')
inputs[alias] = Input(
path=azureml_path,
mode=mode,
)
else:
# Handle regular data asset format
alias, path, version = _extract_alias_path_version(string)
Expand Down
71 changes: 71 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@

from __future__ import annotations

from unittest.mock import MagicMock

import pytest
from azure.ai.ml.constants import InputOutputModes

from submit_aml.data import _extract_alias_datastore_path
from submit_aml.data import _extract_alias_job_path
from submit_aml.data import _extract_alias_path_version
from submit_aml.data import _is_alias_datastore_path_string
from submit_aml.data import build_command_inputs
from submit_aml.data import build_command_outputs

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -83,3 +88,69 @@ def test_build_command_outputs_valid() -> None:
output = outputs["out_dir"]
assert "mydatastore" in output.path
assert "my_dataset" in output.path


# ---------------------------------------------------------------------------
# _is_alias_datastore_path_string
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
("string", "expected"),
[
("ref=mystore/exports/reference", True),
("ref=mystore/folder", True),
("my_data=MIMIC-CXR-V2", False),
("my_data=MIMIC-CXR-V2:2", False),
("checkpoint=job_dir:my_job_123:models/best.pth", False),
("no_equals_sign", False),
],
)
def test_is_alias_datastore_path_string(string: str, expected: bool) -> None:
"""Only 'alias=datastore/folder' strings are recognised as datastore paths."""
assert _is_alias_datastore_path_string(string) is expected


# ---------------------------------------------------------------------------
# build_command_inputs (datastore-path branch)
# ---------------------------------------------------------------------------


def test_build_command_inputs_datastore_path_mount() -> None:
"""A raw datastore-path string builds an azureml:// Input without AML lookup."""
ml_client = MagicMock()
inputs = build_command_inputs(
ml_client,
strings_download=None,
strings_mount=["ref=mystore/exports/reference"],
)
assert "ref" in inputs
ref = inputs["ref"]
assert ref.path == "azureml://datastores/mystore/paths/exports/reference"
assert ref.mode == InputOutputModes.MOUNT
ml_client.data.get.assert_not_called()


def test_build_command_inputs_datastore_path_download() -> None:
"""The datastore-path branch honours the download mode."""
ml_client = MagicMock()
inputs = build_command_inputs(
ml_client,
strings_download=["ref=mystore/exports/reference"],
strings_mount=None,
)
assert inputs["ref"].mode == InputOutputModes.DOWNLOAD
ml_client.data.get.assert_not_called()


def test_build_command_inputs_data_asset_routes_to_get() -> None:
"""A 'name:version' string still resolves via ml_client.data.get."""
ml_client = MagicMock()
ml_client.data.get.return_value = MagicMock(id="azureml://data-asset-id")
inputs = build_command_inputs(
ml_client,
strings_download=None,
strings_mount=["my_data=MIMIC-CXR-V2:2"],
)
ml_client.data.get.assert_called_once_with(name="MIMIC-CXR-V2", version="2")
assert inputs["my_data"].path == "azureml://data-asset-id"
Loading