diff --git a/docs/examples.md b/docs/examples.md index 26e671e..e174913 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -168,32 +168,40 @@ Run a grid sweep over hyperparameters: Datasets are passed to the job as [`Input`](https://learn.microsoft.com/en-us/python/api/azure-ai-ml/azure.ai.ml.input?view=azure-python) -objects. +objects. There is one flag per source type (registered data asset, datastore +folder, or previous job output), in either `--mount-*` or `--download-*` form. === "CLI" - Mount a dataset: + Mount or download a registered data asset: ```bash submit-aml \ --script train.py \ - --mount "data=MY-DATASET:2" + --mount-asset "data=MY-DATASET:2" ``` - Download a dataset: + ```bash + submit-aml \ + --script train.py \ + --download-asset "data=MY-DATASET" + ``` + + Mount a folder directly from a datastore (no data-asset registration + required): ```bash submit-aml \ --script train.py \ - --download "data=MY-DATASET" + --mount-datastore "ref=mystore/exports/reference" ``` - Use outputs from a previous job: + Use the outputs of a previous job: ```bash submit-aml \ --script evaluate.py \ - --mount "checkpoint=job_dir:my-training-job:models/best.pth" + --mount-job "checkpoint=my-training-job:models/best.pth" ``` === "Python" @@ -201,30 +209,49 @@ objects. ```python submit_to_aml( script_path="train.py", - datasets_mount=["data=MY-DATASET:2"], + mount_asset=["data=MY-DATASET:2"], ) # Or download instead of mount submit_to_aml( script_path="train.py", - datasets_download=["data=MY-DATASET"], + download_asset=["data=MY-DATASET"], + ) + + # Mount a datastore folder directly + submit_to_aml( + script_path="train.py", + mount_datastore=["ref=mystore/exports/reference"], ) # Use outputs from a previous job submit_to_aml( script_path="evaluate.py", - datasets_mount=["checkpoint=job_dir:my-training-job:models/best.pth"], + mount_job=["checkpoint=my-training-job:models/best.pth"], ) ``` -Configure an output datastore: +!!! note "Deprecated flags" + + The `--mount`, `--download` and `--output` flags (and their + `datasets_mount`, `datasets_download` and `datasets_output` Python + equivalents) are deprecated in favour of the explicit per-source flags + above. They still work but emit a deprecation warning. + +Write outputs to a datastore folder, or register them as a data asset: === "CLI" ```bash submit-aml \ --script train.py \ - --output "results=mydatastore/experiment-outputs" + --output-datastore "results=mydatastore/experiment-outputs" + ``` + + ```bash + submit-aml \ + --script train.py \ + --output-asset "results=my-experiment-results" ``` === "Python" @@ -232,7 +259,13 @@ Configure an output datastore: ```python submit_to_aml( script_path="train.py", - datasets_output=["results=mydatastore/experiment-outputs"], + output_datastore=["results=mydatastore/experiment-outputs"], + ) + + # Or register the outputs as a data asset + submit_to_aml( + script_path="train.py", + output_asset=["results=my-experiment-results"], ) ``` @@ -404,4 +437,3 @@ Submit and wait for the job to complete, streaming logs: ```python submit_to_aml(script_path="train.py", wait_for_completion=True) ``` - diff --git a/src/submit_aml/__main__.py b/src/submit_aml/__main__.py index e7c4bc7..35ed04f 100644 --- a/src/submit_aml/__main__.py +++ b/src/submit_aml/__main__.py @@ -141,13 +141,18 @@ def submit( "--download", "-d", help=( - "Azure ML dataset or job output folder to download. To download an Azure ML" - " dataset, the argument should take the form: alias, name and version" - " of the dataset; for example: 'vindr_dir=VINDR-CXR-V2:1'." - " If the version is omitted, the last one will be used." - " To download the output folder of a previous job, the argument should take" - " the form 'alias=job_dir::'; for example:" - " 'checkpoint=job_dir:crusty_hat_43s6lmvb25:outputs/checkpoint-10000'." + "[DEPRECATED] Use --download-asset, --download-datastore or" + " --download-job instead. Azure ML dataset, datastore folder or job" + " output folder to download. To download an Azure ML dataset, the" + " argument should take the form: alias, name and version of the" + " dataset; for example: 'vindr_dir=VINDR-CXR-V2:1'. If the version is" + " omitted, the last one will be used. To download a datastore folder," + " use 'alias=datastore/folder'. To download the output folder of a" + " previous job, prefer --download-job; on this deprecated flag use" + " the 'alias=job_dir::' form, for example" + " 'checkpoint=job_dir:crusty_hat_43s6lmvb25:outputs/checkpoint-10000'" + " (the bare 'alias=:' form is only recognised as a job" + " when contains a '/', otherwise it is read as a data asset)." " The alias can be used to pass input datasets to the script, e.g.," r" '${{inputs.vindr_dir}}' or '${{inputs.checkpoint}}'." " This option can be used multiple times." @@ -159,21 +164,88 @@ def submit( "--mount", "-m", help=( - "Azure ML dataset or job output folder to mount." - " For an Azure ML dataset, the alias, name and version should be provided" - " while for a job output folder, the alias, job ID and path in the job" + "[DEPRECATED] Use --mount-asset, --mount-datastore or --mount-job" + " instead. Azure ML dataset, datastore folder or job output folder to" + " mount. For an Azure ML dataset, the alias, name and version should be" + " provided; for a datastore folder, use 'alias=datastore/folder'; while" + " for a job output folder, the alias, job ID and path in the job" " outputs should be provided. See the --download option for more" " information." ), rich_help_panel=PANEL_DATA, ), + mount_asset: Optional[List[str]] = typer.Option( # noqa: UP006, UP007 + None, + "--mount-asset", + help=( + "Registered Azure ML data asset to mount, expressed as" + ' "alias=name[:version]". For example: "vindr_dir=VINDR-CXR-V2:1".' + " If the version is omitted, the latest one is used." + r" Pass it to the script with '${{inputs.vindr_dir}}'." + " This option can be used multiple times." + ), + rich_help_panel=PANEL_DATA, + ), + download_asset: Optional[List[str]] = typer.Option( # noqa: UP006, UP007 + None, + "--download-asset", + help=( + "Registered Azure ML data asset to download. Same format as" + " --mount-asset. This option can be used multiple times." + ), + rich_help_panel=PANEL_DATA, + ), + mount_datastore: Optional[List[str]] = typer.Option( # noqa: UP006, UP007 + None, + "--mount-datastore", + help=( + "Datastore folder to mount, expressed as" + ' "alias=datastore/path/to/folder".' + ' For example: "ref=mystore/exports/reference".' + r" Pass it to the script with '${{inputs.ref}}'." + " This option can be used multiple times." + ), + rich_help_panel=PANEL_DATA, + ), + download_datastore: Optional[List[str]] = typer.Option( # noqa: UP006, UP007 + None, + "--download-datastore", + help=( + "Datastore folder to download. Same format as --mount-datastore." + " This option can be used multiple times." + ), + rich_help_panel=PANEL_DATA, + ), + mount_job: Optional[List[str]] = typer.Option( # noqa: UP006, UP007 + None, + "--mount-job", + help=( + "Output of a previous job to mount, expressed as" + ' "alias=:". The path may point at any' + " run artifact, not just files under outputs/." + ' For example: "checkpoint=crusty_hat_43s6lmvb25:models/best.pth".' + r" Pass it to the script with '${{inputs.checkpoint}}'." + " This option can be used multiple times." + ), + rich_help_panel=PANEL_DATA, + ), + download_job: Optional[List[str]] = typer.Option( # noqa: UP006, UP007 + None, + "--download-job", + help=( + "Output of a previous job to download. Same format as" + " --mount-job. This option can be used multiple times." + ), + rich_help_panel=PANEL_DATA, + ), output: Optional[List[str]] = typer.Option( # noqa: UP006, UP007 None, "--output", "-o", help=( - "Alias, datastore and path to folder into which outputs will be written," - ' expressed as "alias=datastore/path/to/dir".' + "[DEPRECATED] Use --output-datastore or --output-asset instead." + " Alias, datastore and path to folder into which outputs will be" + ' written, expressed as "alias=datastore/path/to/dir".' ' For example: "out_dir=mydatastore/my_dataset".' " The alias can be used to pass outputs to the script, e.g.," r' "${{outputs.out_dir}}".' @@ -182,6 +254,32 @@ def submit( ), rich_help_panel=PANEL_DATA, ), + output_datastore: Optional[List[str]] = typer.Option( # noqa: UP006, UP007 + None, + "--output-datastore", + help=( + "Datastore folder into which outputs will be written, expressed as" + ' "alias=datastore/path/to/dir".' + ' For example: "out_dir=mydatastore/my_dataset".' + r" Pass it to the script with '${{outputs.out_dir}}'." + " This option can be used multiple times." + ), + rich_help_panel=PANEL_DATA, + ), + output_asset: Optional[List[str]] = typer.Option( # noqa: UP006, UP007 + None, + "--output-asset", + help=( + "Register the outputs as an Azure ML data asset, expressed as" + ' "alias=name[:version]". For example: "out_dir=my-results".' + " The blobs are written to the workspace's default datastore and" + " registered as a data asset; if the version is omitted, Azure ML" + " auto-increments it." + r" Pass it to the script with '${{outputs.out_dir}}'." + " This option can be used multiple times." + ), + rich_help_panel=PANEL_DATA, + ), command_prefix: str = typer.Option( get_default("command_prefix"), help="Prefix to prepend to the command. For example, `uv run`.", @@ -408,6 +506,14 @@ def submit( datasets_download=datasets_download, datasets_mount=datasets_mount, datasets_output=output, + mount_asset=mount_asset, + download_asset=download_asset, + mount_datastore=mount_datastore, + download_datastore=download_datastore, + mount_job=mount_job, + download_job=download_job, + output_datastore=output_datastore, + output_asset=output_asset, debug=debug, dependency_groups=dependency_groups, description=description, diff --git a/src/submit_aml/aml.py b/src/submit_aml/aml.py index c269bab..a38de21 100644 --- a/src/submit_aml/aml.py +++ b/src/submit_aml/aml.py @@ -345,6 +345,14 @@ def submit_to_aml( datasets_download: TypeOptionalStrList = None, datasets_mount: TypeOptionalStrList = None, datasets_output: TypeOptionalStrList = None, + mount_asset: TypeOptionalStrList = None, + download_asset: TypeOptionalStrList = None, + mount_datastore: TypeOptionalStrList = None, + download_datastore: TypeOptionalStrList = None, + mount_job: TypeOptionalStrList = None, + download_job: TypeOptionalStrList = None, + output_datastore: TypeOptionalStrList = None, + output_asset: TypeOptionalStrList = None, debug: bool = False, dependency_groups: list[str] | None = None, description: str | None = None, @@ -512,8 +520,22 @@ def submit_to_aml( add_service_for_tensorboard(services, tensorboard_dir) # Data - inputs = build_command_inputs(ml_client, datasets_download, datasets_mount) - outputs = build_command_outputs(datasets_output) + inputs = build_command_inputs( + ml_client, + mount_asset=mount_asset, + download_asset=download_asset, + mount_datastore=mount_datastore, + download_datastore=download_datastore, + mount_job=mount_job, + download_job=download_job, + legacy_mount=datasets_mount, + legacy_download=datasets_download, + ) + outputs = build_command_outputs( + output_datastore=output_datastore, + output_asset=output_asset, + legacy_output=datasets_output, + ) # Sweep jobs is_sweep = sweep_inputs is not None and len(sweep_inputs) > 0 diff --git a/src/submit_aml/data.py b/src/submit_aml/data.py index 410568b..ae5f325 100644 --- a/src/submit_aml/data.py +++ b/src/submit_aml/data.py @@ -2,10 +2,13 @@ import re import sys +import warnings +from typing import TypeVar from azure.ai.ml import Input from azure.ai.ml import MLClient from azure.ai.ml import Output +from azure.ai.ml.constants import AssetTypes from azure.ai.ml.constants import InputOutputModes from azure.ai.ml.entities._job.sweep.search_space import SweepDistribution from azure.ai.ml.exceptions import MlException @@ -15,6 +18,24 @@ TypeInputsDict = dict[str, Input | SweepDistribution] TypeOptionalStrList = list[str] | None +_MappingValue = TypeVar("_MappingValue") + + +def _datastore_uri(datastore: str, path: str) -> str: + """Build an Azure ML datastore URI for a folder. + + Args: + datastore: Name of the datastore. + path: Path to the folder within the datastore. + + Returns: + A URI of the form `azureml://datastores//paths/`. + + Examples: + >>> _datastore_uri('mystore', 'exports/reference') + 'azureml://datastores/mystore/paths/exports/reference' + """ + return f"azureml://datastores/{datastore}/paths/{path}" def _extract_alias_path_version(string: str) -> tuple[str, str, str | None]: @@ -27,12 +48,9 @@ def _extract_alias_path_version(string: str) -> tuple[str, str, str | None]: Tuple of alias, path, and version (which may be None if version is not provided). - Raises: - ValueError: If the string is not of the expected format. - Examples: >>> _extract_alias_path_version('my_data=MIMIC-CXR-V2:2') - ('my_data', 'MIMIC-CXR-V2', 2) + ('my_data', 'MIMIC-CXR-V2', '2') >>> _extract_alias_path_version('my_data=MIMIC-CXR-V2') ('my_data', 'MIMIC-CXR-V2', None) """ @@ -64,11 +82,8 @@ def _extract_alias_datastore_path(string: str) -> tuple[str, str, str]: Returns: Tuple of alias, datastore and folder. - Raises: - ValueError: If the string is not of the expected format. - Examples: - >>> get_alias_datastore_path('my_data=inereyedata/output_dataset') + >>> _extract_alias_datastore_path('my_data=inereyedata/output_dataset') ('my_data', 'inereyedata', 'output_dataset') """ pattern = r"(?P[^=]+)=(?P[^/]+)/(?P.+)" @@ -84,146 +99,471 @@ def _extract_alias_datastore_path(string: str) -> tuple[str, str, str]: def _extract_alias_job_path(string: str) -> tuple[str, str, str]: - """Get alias, job ID, and path from a job directory string. + """Get alias, job ID, and path from a job output string. Args: - string: String of the form `'alias=job_dir::'`. + string: String of the form `'alias=:'`. Returns: Tuple of alias, job_id, and path. - Raises: - ValueError: If the string is not of the expected format. - Examples: - >>> _extract_alias_job_path('checkpoint=job_dir:my_job_123:models/best.pth') + >>> _extract_alias_job_path('checkpoint=my_job_123:models/best.pth') ('checkpoint', 'my_job_123', 'models/best.pth') """ - pattern = r"(?P[^=]+)=job_dir:(?P[^:]+):(?P.+)" + pattern = r"(?P[^=]+)=(?P[^:]+):(?P.+)" match = re.match(pattern, string) if match is None: message = ( - f'Invalid job directory string: "{string}".' - ' Expected format: "alias=job_dir:job_id:path".' + f'Invalid job output string: "{string}".' + ' Expected format: "alias=job_id:path".' ) - raise ValueError(message) + logger.error(message) + sys.exit(1) + if match.group("job_id") == "job_dir": + message = ( + f'Invalid job output string: "{string}".' + ' The "job_dir:" prefix is no longer used with the job flags;' + ' use "alias=job_id:path" instead.' + ) + logger.error(message) + sys.exit(1) return match.group("alias"), match.group("job_id"), match.group("path") -def _is_alias_path_version_string(string: str) -> bool: - try: - _extract_alias_path_version(string) - return True - except ValueError: - return False +def _input_from_asset( + ml_client: MLClient, + string: str, + mode: str, +) -> tuple[str, Input]: + """Build an `Input` from a registered data asset string. + Args: + ml_client: Client used to resolve the data asset. + string: String of the form `'alias=name[:version]'`. + mode: Either `InputOutputModes.DOWNLOAD` or `InputOutputModes.MOUNT`. -def _is_alias_job_path_string(string: str) -> bool: - try: - _extract_alias_job_path(string) - return True - except ValueError: - return False + Returns: + Tuple of alias and the resolved `Input`. + Raises: + ValueError: If the data asset cannot be retrieved. + """ + alias, path, version = _extract_alias_path_version(string) + + if version is None: + kwargs = {"label": "latest"} + else: + kwargs = {"version": version} + + with report_time( + f'Retrieving data asset "{path}"...', + f'Retrieved data asset "{path}"', + ): + try: + data = ml_client.data.get(name=path, **kwargs) + except MlException as e: + version_desc = "latest" if version is None else version + msg = ( + f'Error getting data asset with name "{path}"' + f' and version "{version_desc}"' + ) + raise ValueError(msg) from e + return alias, Input(path=data.id, mode=mode) -def build_command_inputs( - ml_client: MLClient, - strings_download: list[str] | None, - strings_mount: list[str] | None, -) -> TypeInputsDict: - """Get dictionaries data assets to be mounted or downloaded. + +def _input_from_datastore(string: str, mode: str) -> tuple[str, Input]: + """Build an `Input` from a datastore-path string. Args: - strings_download: List of strings of the form `'alias=path:version'` to - be downloaded. If `None`, no data assets will be downloaded. - strings_mount: List of strings of the form `'alias=path:version'` to - be mounted. If `None`, no data assets will be mounted. + string: String of the form `'alias=datastore/folder'`. + mode: Either `InputOutputModes.DOWNLOAD` or `InputOutputModes.MOUNT`. + + Returns: + Tuple of alias and the resulting `Input`. """ - strings_download = [] if strings_download is None else strings_download - strings_mount = [] if strings_mount is None else strings_mount - datasets_download = _get_data_assets( - ml_client, - strings_download, - InputOutputModes.DOWNLOAD, + alias, datastore, folder = _extract_alias_datastore_path(string) + azureml_path = _datastore_uri(datastore, folder) + logger.info(f'Using datastore path "{azureml_path}"...') + return alias, Input(path=azureml_path, mode=mode) + + +def _input_from_job(string: str, mode: str) -> tuple[str, Input]: + """Build an `Input` from a previous job's output string. + + Args: + string: String of the form `'alias=:'`. + mode: Either `InputOutputModes.DOWNLOAD` or `InputOutputModes.MOUNT`. + + Returns: + Tuple of alias and the resulting `Input`. + """ + alias, job_id, path = _extract_alias_job_path(string) + azureml_path = _datastore_uri( + "workspaceartifactstore", + f"ExperimentRun/dcid.{job_id}/{path}", ) - datasets_mount = _get_data_assets( - ml_client, - strings_mount, - InputOutputModes.MOUNT, + logger.info(f'Using job output path "{azureml_path}"...') + return alias, Input(path=azureml_path, mode=mode) + + +def _output_from_datastore(string: str) -> tuple[str, Output]: + """Build an `Output` that writes to a datastore folder. + + Args: + string: String of the form `'alias=datastore/folder'`. + + Returns: + Tuple of alias and the resulting `Output`. + """ + alias, datastore, folder = _extract_alias_datastore_path(string) + return alias, Output(path=_datastore_uri(datastore, folder)) + + +def _output_from_asset(string: str) -> tuple[str, Output]: + """Build an `Output` that registers a data asset. + + The blobs are written to the workspace's default datastore at an + Azure ML-managed location and registered as a data asset. + + Args: + string: String of the form `'alias=name[:version]'`. If the version is + omitted, Azure ML auto-increments it. + + Returns: + Tuple of alias and the resulting `Output`. + """ + alias, name, version = _extract_alias_path_version(string) + output = Output(type=AssetTypes.URI_FOLDER, name=name, version=version) + return alias, output + + +# Removal plan for the deprecated data flags (--mount/-m, --download/-d, +# --output/-o), superseded by the explicit-source flags (--{mount,download}- +# {asset,datastore,job} and --output-{datastore,asset}): +# +# 1. Now (1.x): both flag sets work. The legacy flags carry a [DEPRECATED] +# marker in --help and emit a deprecation warning at runtime (via +# `_warn_legacy_input` / `_warn_legacy_output`). This is the grace period +# in which users migrate. +# 2. Before removal: once downstream callers have migrated (grep the known +# consumer repos / run scripts for `--mount`, `--download`, `--output`, +# `-m `, `-d `, `-o ` and the `datasets_{mount,download,output}` kwargs of +# `submit_to_aml`), and the deprecation has shipped in at least one +# tagged release, schedule removal for the next MAJOR version (2.0.0) per +# semver, since dropping a CLI flag is a breaking change. +# 3. At removal (2.0.0): delete the `datasets_download`/`datasets_mount`/ +# `output` typer.Options in `__main__.py`, drop the matching +# `submit_to_aml` parameters and the `legacy_*` parameters and branches in +# `build_command_inputs`/`build_command_outputs`, delete the +# `_legacy_*` helpers and the `_warn_legacy_*` warning helpers, and note +# the breaking change in the changelog. +# +# Until step 3, keep the legacy flags VISIBLE in --help (the [DEPRECATED] +# marker is how users discover the migration path); only hide them as an +# optional last step in a release immediately preceding removal. + +# Replacement input flag (CLI, Python parameter) for each legacy input flag +# base and classified source type, used to tailor the deprecation warning. +_LEGACY_INPUT_FLAGS = { + "mount": ("--mount", "datasets_mount"), + "download": ("--download", "datasets_download"), +} + + +def _warn_legacy_input( + flag_base: str, + source: str, + old_value: str, + new_value: str, + stacklevel: int = 2, +) -> None: + """Warn that a legacy input flag is deprecated, naming the exact fix. + + Emits both a human-facing log line (for CLI users) and a Python + `DeprecationWarning` (for callers of the library API, e.g. `submit_to_aml` + or `build_command_inputs`). + + Args: + flag_base: The legacy flag base, either `'mount'` or `'download'`. + source: The classified source type (`'asset'`, `'datastore'`, or + `'job'`), used to pick the per-source replacement flag. + old_value: The legacy `alias=value` string the user passed. + new_value: The value to use with the replacement flag (identical to + `old_value` except for job values, which drop the `job_dir:` + prefix). + stacklevel: Stack level for the `DeprecationWarning`, so it points at + the API caller rather than this helper. + """ + old_cli, old_param = _LEGACY_INPUT_FLAGS[flag_base] + new_cli = f"--{flag_base}-{source}" + new_param = f"{flag_base}_{source}" + # Both are always emitted: the log line is phrased for CLI users (it names + # flags) and the DeprecationWarning for Python API users (it names + # parameters). + cli_message = ( + f"{old_cli} is deprecated and will be removed in a future release." + f" Replace '{old_cli} {old_value}' with '{new_cli} {new_value}'." ) - return {**datasets_download, **datasets_mount} + api_message = ( + f"The '{old_param}' parameter is deprecated and will be removed in a" + f" future release. Pass [{new_value!r}] to '{new_param}' instead." + ) + logger.warning(cli_message) + warnings.warn(api_message, DeprecationWarning, stacklevel=stacklevel) -def build_command_outputs( - strings_upload: list[str] | None, -) -> dict[str, Output]: - """Get outputs for command. +def _warn_legacy_output(old_value: str, stacklevel: int = 2) -> None: + """Warn that a legacy `--output` value is deprecated, naming the exact fix. + + Emits a human-facing log line naming the replacement CLI flag (for CLI + users) and a Python `DeprecationWarning` naming the replacement parameter + (for callers of the library API, e.g. `submit_to_aml` or + `build_command_outputs`). Args: - strings_upload: List of strings of the form `'alias=datastore/path/to/dir'` to - be uploaded. If `None`, no outputs will be returned. + old_value: The legacy `alias=datastore/folder` string the user passed. + stacklevel: Stack level for the `DeprecationWarning`, so it points at + the API caller rather than this helper. """ - strings_upload = [] if strings_upload is None else strings_upload - outputs_dict = {} - for string in strings_upload: - alias, datastore, path = _extract_alias_datastore_path(string) - output = Output( - path=f"azureml://datastores/{datastore}/paths/{path}", - ) - outputs_dict[alias] = output - return outputs_dict + # Both are always emitted: the log line is phrased for CLI users (it names + # flags) and the DeprecationWarning for Python API users (it names + # parameters). + cli_message = ( + "--output is deprecated and will be removed in a future release." + f" Replace '--output {old_value}' with '--output-datastore {old_value}'." + ) + api_message = ( + "The 'datasets_output' parameter is deprecated and will be removed in a" + f" future release. Pass [{old_value!r}] to 'output_datastore' instead." + ) + logger.warning(cli_message) + warnings.warn(api_message, DeprecationWarning, stacklevel=stacklevel) -def _get_data_assets( +def _classify_legacy_input(string: str) -> str: + """Classify a legacy `--mount`/`--download` value by source type. + + Args: + string: A legacy dataset string. + + Returns: + One of `'job'`, `'datastore'`, or `'asset'`. + + A right-hand side that starts with `job_dir:`, or that has a `:` before its + first `/` (the new `job_id:path` form), is a job output. A `/` that comes + before any `:` signals a datastore folder. Anything else (a bare `name` or + `name:version`) is a data asset. + + Examples: + >>> _classify_legacy_input('ckpt=job_dir:job123:out/best.pth') + 'job' + >>> _classify_legacy_input('ckpt=job123:out/best.pth') + 'job' + >>> _classify_legacy_input('ref=mystore/exports/reference') + 'datastore' + >>> _classify_legacy_input('data=MY-DATASET:2') + 'asset' + """ + if "=" not in string: + return "asset" + _, rhs = string.split("=", 1) + if rhs.startswith("job_dir:"): + return "job" + slash = rhs.find("/") + if slash == -1: + return "asset" + colon = rhs.find(":") + if colon != -1 and colon < slash: + return "job" + return "datastore" + + +def _legacy_input( ml_client: MLClient, - datasets: list[str], + string: str, mode: str, -) -> dict[str, Input]: - """Get data assets from Azure ML. + flag_base: str, +) -> tuple[str, Input]: + """Route a legacy `--mount`/`--download` value to the right builder. + + Emits a deprecation warning naming the exact replacement flag for the + value's classified source type. Args: - datasets: List of strings of the form `'alias=path:version'` or - `'alias=job_dir::'`. + ml_client: Client used to resolve data assets. + string: A legacy dataset string. mode: Either `InputOutputModes.DOWNLOAD` or `InputOutputModes.MOUNT`. + flag_base: The legacy flag base, either `'mount'` or `'download'`, + used to tailor the deprecation warning. + + Returns: + Tuple of alias and the resulting `Input`. + """ + source = _classify_legacy_input(string) + # stacklevel=4: warnings.warn -> _warn_legacy_input -> _legacy_input -> + # build_command_inputs, so the DeprecationWarning points at the API caller. + if source == "job": + # Translate the old "alias=job_dir::" form to the new one. + translated = string.replace("=job_dir:", "=", 1) + _warn_legacy_input(flag_base, "job", string, translated, stacklevel=4) + return _input_from_job(translated, mode) + if source == "datastore": + _warn_legacy_input(flag_base, "datastore", string, string, stacklevel=4) + return _input_from_datastore(string, mode) + _warn_legacy_input(flag_base, "asset", string, string, stacklevel=4) + return _input_from_asset(ml_client, string, mode) + + +def _assign_unique( + mapping: dict[str, _MappingValue], + alias: str, + value: _MappingValue, + *, + kind: str, +) -> None: + """Insert `alias -> value`, raising if the alias is already present. + + Each alias becomes a single `${{inputs.}}` / `${{outputs.}}` + reference, so it must be unique. Reusing an alias (across modes, source + types, or flags) is a user error rather than something to silently resolve. + + Args: + mapping: The inputs or outputs dict being built. + alias: The alias key to insert. + value: The `Input` or `Output` to store. + kind: Either `'input'` or `'output'`, used in the error message. + + Raises: + ValueError: If `alias` is already present in `mapping`. + """ + if alias in mapping: + msg = f"Duplicate {kind} alias {alias!r}: each alias must be unique." + raise ValueError(msg) + mapping[alias] = value + + +def build_command_inputs( + ml_client: MLClient, + legacy_download: list[str] | None = None, + legacy_mount: list[str] | None = None, + *, + mount_asset: list[str] | None = None, + download_asset: list[str] | None = None, + mount_datastore: list[str] | None = None, + download_datastore: list[str] | None = None, + mount_job: list[str] | None = None, + download_job: list[str] | None = None, +) -> TypeInputsDict: + """Build the inputs dictionary for a command job. + + `legacy_download` / `legacy_mount` are kept as the first positional + parameters (in that order) so existing positional callers from before the + explicit-source flags, i.e. `build_command_inputs(client, downloads, + mounts)`, keep working during the 1.x deprecation window. The new + explicit-source flags are keyword-only. + + Args: + ml_client: Client used to resolve data assets. + legacy_download: Deprecated `--download` values, routed by source type. + legacy_mount: Deprecated `--mount` values, routed by source type. + mount_asset: Data assets to mount, as `'alias=name[:version]'`. + download_asset: Data assets to download, as `'alias=name[:version]'`. + mount_datastore: Datastore folders to mount, as `'alias=datastore/folder'`. + download_datastore: Datastore folders to download, as + `'alias=datastore/folder'`. + mount_job: Previous job outputs to mount, as `'alias=:'`. + download_job: Previous job outputs to download, as `'alias=:'`. Returns: Dictionary of `alias: Input` mappings. + + Raises: + ValueError: If the same alias is used more than once across any of the + input flags (an alias maps to a single `${{inputs.}}` + reference, so it must be unique). """ - inputs = {} - for string in datasets: - if _is_alias_job_path_string(string): - # Handle job directory format - alias, job_id, path = _extract_alias_job_path(string) - azureml_path = f"azureml://datastores/workspaceartifactstore/paths/ExperimentRun/dcid.{job_id}/{path}" - logger.info(f'Using job output path "{azureml_path}"...') - inputs[alias] = Input( - path=str(azureml_path), - mode=mode, + inputs: TypeInputsDict = {} + + for string in download_asset or []: + alias, value = _input_from_asset(ml_client, string, InputOutputModes.DOWNLOAD) + _assign_unique(inputs, alias, value, kind="input") + for string in mount_asset or []: + alias, value = _input_from_asset(ml_client, string, InputOutputModes.MOUNT) + _assign_unique(inputs, alias, value, kind="input") + + for string in download_datastore or []: + alias, value = _input_from_datastore(string, InputOutputModes.DOWNLOAD) + _assign_unique(inputs, alias, value, kind="input") + for string in mount_datastore or []: + alias, value = _input_from_datastore(string, InputOutputModes.MOUNT) + _assign_unique(inputs, alias, value, kind="input") + + for string in download_job or []: + alias, value = _input_from_job(string, InputOutputModes.DOWNLOAD) + _assign_unique(inputs, alias, value, kind="input") + for string in mount_job or []: + alias, value = _input_from_job(string, InputOutputModes.MOUNT) + _assign_unique(inputs, alias, value, kind="input") + + if legacy_download: + for string in legacy_download: + alias, value = _legacy_input( + ml_client, string, InputOutputModes.DOWNLOAD, "download" ) - else: - # Handle regular data asset format - alias, path, version = _extract_alias_path_version(string) - - if version is None: - kwargs = {"label": "latest"} - else: - kwargs = {"version": version} - - with report_time( - f'Retrieving data asset "{path}"...', - f'Retrieved data asset "{path}"', - ): - try: - data = ml_client.data.get(name=path, **kwargs) - except MlException as e: - msg = ( - "Error getting data asset with" - f' name "{path}"' - f' and version "{version}"' - ) - raise ValueError(msg) from e - inputs[alias] = Input( - path=data.id, - mode=mode, + _assign_unique(inputs, alias, value, kind="input") + if legacy_mount: + for string in legacy_mount: + alias, value = _legacy_input( + ml_client, string, InputOutputModes.MOUNT, "mount" ) + _assign_unique(inputs, alias, value, kind="input") + return inputs + + +def build_command_outputs( + legacy_output: list[str] | None = None, + *, + output_datastore: list[str] | None = None, + output_asset: list[str] | None = None, +) -> dict[str, Output]: + """Build the outputs dictionary for a command job. + + `legacy_output` is kept as the first positional parameter so existing + positional callers from before the explicit-target flags, i.e. + `build_command_outputs(uploads)`, keep working during the 1.x deprecation + window. The new explicit-target flags are keyword-only. + + Args: + legacy_output: Deprecated `--output` values (datastore folders). + output_datastore: Datastore folders to write to, as + `'alias=datastore/folder'`. + output_asset: Data assets to register, as `'alias=name[:version]'`. + + Returns: + Dictionary of `alias: Output` mappings. + + Raises: + ValueError: If the same alias is used more than once across any of the + output flags (an alias maps to a single `${{outputs.}}` + reference, so it must be unique). + """ + outputs: dict[str, Output] = {} + + for string in output_datastore or []: + alias, value = _output_from_datastore(string) + _assign_unique(outputs, alias, value, kind="output") + for string in output_asset or []: + alias, value = _output_from_asset(string) + _assign_unique(outputs, alias, value, kind="output") + + if legacy_output: + for string in legacy_output: + # stacklevel=3: warnings.warn -> _warn_legacy_output -> + # build_command_outputs, pointing at the API caller. + _warn_legacy_output(string, stacklevel=3) + alias, value = _output_from_datastore(string) + _assign_unique(outputs, alias, value, kind="output") + + return outputs diff --git a/tests/test_data.py b/tests/test_data.py index 8f9faab..381e23a 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,14 +1,49 @@ -"""Tests for data-asset parsing helpers.""" +"""Tests for data-asset parsing helpers and input/output builders.""" from __future__ import annotations +from unittest.mock import Mock +from unittest.mock import patch + import pytest +from azure.ai.ml.constants import InputOutputModes +from azure.ai.ml.exceptions import MlException +from submit_aml.data import _classify_legacy_input +from submit_aml.data import _datastore_uri from submit_aml.data import _extract_alias_datastore_path from submit_aml.data import _extract_alias_job_path from submit_aml.data import _extract_alias_path_version +from submit_aml.data import _input_from_asset +from submit_aml.data import _input_from_datastore +from submit_aml.data import _input_from_job +from submit_aml.data import _output_from_asset +from submit_aml.data import _output_from_datastore +from submit_aml.data import build_command_inputs from submit_aml.data import build_command_outputs + +def _deprecation_log(mock_logger: Mock) -> str: + """Join the raw messages passed to `logger.warning`. + + Asserting on the raw message (rather than `capsys`) avoids the rich + console's ANSI colouring, highlighting, and line wrapping, which otherwise + make multi-token substring checks brittle (and fail under CI's colour mode). + """ + return " ".join(call.args[0] for call in mock_logger.warning.call_args_list) + + +# --------------------------------------------------------------------------- +# _datastore_uri +# --------------------------------------------------------------------------- + + +def test_datastore_uri() -> None: + """A datastore and path are joined into an azureml:// URI.""" + uri = _datastore_uri("mystore", "exports/reference") + assert uri == "azureml://datastores/mystore/paths/exports/reference" + + # --------------------------------------------------------------------------- # _extract_alias_path_version # --------------------------------------------------------------------------- @@ -51,35 +86,359 @@ def test_extract_alias_datastore_path_valid() -> None: def test_extract_alias_job_path_valid() -> None: - """'alias=job_dir:job_id:path' is parsed correctly.""" + """'alias=job_id:path' is parsed correctly (no job_dir: prefix).""" alias, job_id, path = _extract_alias_job_path( - "checkpoint=job_dir:my_job_123:models/best.pth" + "checkpoint=my_job_123:models/best.pth" ) assert alias == "checkpoint" assert job_id == "my_job_123" assert path == "models/best.pth" -def test_extract_alias_job_path_invalid_raises() -> None: - """Strings not matching the job_dir pattern raise ValueError.""" - with pytest.raises(ValueError, match="Invalid job directory"): +def test_extract_alias_job_path_invalid_exits() -> None: + """Strings without a path component exit the process.""" + with pytest.raises(SystemExit): _extract_alias_job_path("bad_format") +def test_extract_alias_job_path_rejects_legacy_prefix() -> None: + """The legacy 'job_dir:' prefix is rejected on the new job flags.""" + with pytest.raises(SystemExit): + _extract_alias_job_path("ckpt=job_dir:my_job_123:models/best.pth") + + # --------------------------------------------------------------------------- -# build_command_outputs +# _classify_legacy_input +# --------------------------------------------------------------------------- + + +def test_classify_legacy_input_job() -> None: + """A job_dir: prefix is classified as a job output.""" + assert _classify_legacy_input("ckpt=job_dir:job123:out/best.pth") == "job" + + +def test_classify_legacy_input_job_new_syntax() -> None: + """A new-style 'job_id:path' value (colon before slash) is a job output.""" + assert _classify_legacy_input("ckpt=job123:outputs/best.pth") == "job" + + +def test_classify_legacy_input_datastore() -> None: + """A slash before any colon signals a datastore path.""" + assert _classify_legacy_input("ref=mystore/exports/reference") == "datastore" + + +def test_classify_legacy_input_datastore_with_colon_in_folder() -> None: + """A colon after the first slash is part of the folder, not a job id.""" + assert _classify_legacy_input("ref=mystore/a:b/c") == "datastore" + + +def test_classify_legacy_input_asset() -> None: + """A plain name[:version] is classified as a data asset.""" + assert _classify_legacy_input("data=MY-DATASET:2") == "asset" + + +def test_classify_legacy_input_missing_equals() -> None: + """A string without '=' falls back to the asset branch.""" + assert _classify_legacy_input("no-equals-here") == "asset" + + +# --------------------------------------------------------------------------- +# input builders +# --------------------------------------------------------------------------- + + +def test_input_from_datastore() -> None: + """A datastore string yields an Input with an azureml:// path and mode.""" + alias, value = _input_from_datastore( + "ref=mystore/exports/reference", + InputOutputModes.MOUNT, + ) + assert alias == "ref" + assert value.path == "azureml://datastores/mystore/paths/exports/reference" + assert value.mode == InputOutputModes.MOUNT + + +def test_input_from_job() -> None: + """A job string yields an Input pointing at the job's run artifacts.""" + alias, value = _input_from_job( + "checkpoint=my_job_123:models/best.pth", + InputOutputModes.DOWNLOAD, + ) + assert alias == "checkpoint" + assert "ExperimentRun/dcid.my_job_123/models/best.pth" in value.path + assert "workspaceartifactstore" in value.path + assert value.mode == InputOutputModes.DOWNLOAD + + +def test_input_from_asset_missing_version_reports_latest() -> None: + """When no version is given, the failure message mentions 'latest'.""" + client = Mock() + client.data.get.side_effect = MlException( + message="boom", no_personal_data_message="boom" + ) + with pytest.raises(ValueError, match='version "latest"'): + _input_from_asset(client, "data=MY-DATASET", InputOutputModes.MOUNT) + + +# --------------------------------------------------------------------------- +# output builders +# --------------------------------------------------------------------------- + + +def test_output_from_datastore() -> None: + """A datastore string yields an Output with an azureml:// path.""" + alias, output = _output_from_datastore("out_dir=mydatastore/my_dataset") + assert alias == "out_dir" + assert output.path == "azureml://datastores/mydatastore/paths/my_dataset" + + +def test_output_from_asset_with_version() -> None: + """An asset string registers an Output with name and version.""" + alias, output = _output_from_asset("out_dir=my-results:3") + assert alias == "out_dir" + assert output.name == "my-results" + assert output.version == "3" + assert output.type == "uri_folder" + + +def test_output_from_asset_without_version() -> None: + """Omitting the version leaves it unset for Azure ML to auto-increment.""" + _, output = _output_from_asset("out_dir=my-results") + assert output.name == "my-results" + assert output.version is None + + # --------------------------------------------------------------------------- +# build_command_inputs +# --------------------------------------------------------------------------- + + +def test_build_command_inputs_empty() -> None: + """No arguments produce an empty dict and never touch the client.""" + client = Mock() + assert build_command_inputs(client) == {} + client.data.get.assert_not_called() + + +def test_build_command_inputs_datastore_and_job_skip_client() -> None: + """Datastore and job inputs are built without calling the client.""" + client = Mock() + inputs = build_command_inputs( + client, + mount_datastore=["ref=mystore/exports/reference"], + download_job=["ckpt=my_job_123:models/best.pth"], + ) + assert set(inputs) == {"ref", "ckpt"} + assert inputs["ref"].mode == InputOutputModes.MOUNT + assert inputs["ckpt"].mode == InputOutputModes.DOWNLOAD + client.data.get.assert_not_called() + +def test_build_command_inputs_asset_calls_client() -> None: + """A data-asset input resolves through the client.""" + client = Mock() + client.data.get.return_value = Mock(id="azureml:resolved-asset:1") + inputs = build_command_inputs(client, mount_asset=["data=MY-DATASET:2"]) + client.data.get.assert_called_once() + assert inputs["data"].path == "azureml:resolved-asset:1" -def test_build_command_outputs_none() -> None: - """None input produces an empty dict.""" - assert build_command_outputs(None) == {} +def test_build_command_inputs_legacy_datastore_routes( + capsys: pytest.CaptureFixture[str], +) -> None: + """Legacy --mount datastore strings route to the datastore builder.""" + client = Mock() + inputs = build_command_inputs( + client, + legacy_mount=["ref=mystore/exports/reference"], + ) + assert inputs["ref"].path == ( + "azureml://datastores/mystore/paths/exports/reference" + ) + client.data.get.assert_not_called() + assert "deprecated" in capsys.readouterr().out.lower() -def test_build_command_outputs_valid() -> None: - """Valid output strings are converted into Output objects.""" + +def test_build_command_inputs_legacy_job_routes() -> None: + """Legacy --download job_dir strings route to the job builder.""" + client = Mock() + inputs = build_command_inputs( + client, + legacy_download=["ckpt=job_dir:my_job_123:models/best.pth"], + ) + assert "ExperimentRun/dcid.my_job_123/models/best.pth" in inputs["ckpt"].path + client.data.get.assert_not_called() + + +def test_build_command_inputs_legacy_job_new_syntax_routes() -> None: + """Legacy values using the new 'job_id:path' form route to the job builder.""" + client = Mock() + inputs = build_command_inputs( + client, + legacy_mount=["ckpt=my_job_123:models/best.pth"], + ) + assert "ExperimentRun/dcid.my_job_123/models/best.pth" in inputs["ckpt"].path + client.data.get.assert_not_called() + + +def test_build_command_inputs_legacy_raises_deprecation_warning() -> None: + """The Python API raises a DeprecationWarning naming the new parameter.""" + client = Mock() + with pytest.warns(DeprecationWarning, match=r"mount_datastore") as record: + build_command_inputs(client, legacy_mount=["ref=mystore/exports/reference"]) + # The Python-facing warning references parameters, not CLI flags, and shows + # a one-element list (the parameters are list[str]). + message = str(record[0].message) + assert "--mount" not in message + assert "['ref=mystore/exports/reference']" in message + + +def test_build_command_inputs_legacy_positional_compat() -> None: + """Positional `(client, downloads, mounts)` calls still work (1.x compat).""" + client = Mock() + inputs = build_command_inputs( + client, + ["dl=mystore/down"], + ["mn=mystore/mount"], + ) + assert inputs["dl"].mode == InputOutputModes.DOWNLOAD + assert inputs["mn"].mode == InputOutputModes.MOUNT + + +def test_build_command_outputs_legacy_positional_compat() -> None: + """Positional `(uploads)` calls still work (1.x compat).""" outputs = build_command_outputs(["out_dir=mydatastore/my_dataset"]) + assert outputs["out_dir"].path == ( + "azureml://datastores/mydatastore/paths/my_dataset" + ) + + +def test_build_command_inputs_legacy_asset_calls_client() -> None: + """Legacy --mount asset strings still resolve through the client.""" + client = Mock() + client.data.get.return_value = Mock(id="azureml:resolved-asset:1") + build_command_inputs(client, legacy_mount=["data=MY-DATASET:2"]) + client.data.get.assert_called_once() + + +def test_build_command_inputs_duplicate_alias_across_modes_raises() -> None: + """The same alias under both download and mount is a hard error.""" + client = Mock() + with pytest.raises(ValueError, match=r"[Dd]uplicate.*alias.*ref"): + build_command_inputs( + client, + download_datastore=["ref=mystore/exports/reference"], + mount_datastore=["ref=mystore/exports/reference"], + ) + + +def test_build_command_inputs_duplicate_alias_same_mode_raises() -> None: + """A repeated alias within a single flag is also a hard error.""" + client = Mock() + with pytest.raises(ValueError, match=r"[Dd]uplicate.*alias.*ref"): + build_command_inputs( + client, + mount_datastore=["ref=mystore/a", "ref=mystore/b"], + ) + + +def test_build_command_inputs_legacy_duplicate_alias_raises() -> None: + """A colliding alias across legacy mount/download raises.""" + client = Mock() + with pytest.raises(ValueError, match=r"[Dd]uplicate.*alias.*ref"): + build_command_inputs( + client, + legacy_download=["ref=mystore/exports/reference"], + legacy_mount=["ref=mystore/exports/reference"], + ) + + +def test_build_command_inputs_legacy_asset_warns_mount_asset() -> None: + """A legacy --mount asset value is told to use --mount-asset specifically.""" + client = Mock() + client.data.get.return_value = Mock(id="azureml:resolved-asset:1") + with patch("submit_aml.data.logger") as mock_logger: + build_command_inputs(client, legacy_mount=["my_alias=data_asset"]) + message = _deprecation_log(mock_logger) + assert "--mount-asset my_alias=data_asset" in message + assert "--mount-datastore" not in message + assert "--mount-job" not in message + # CLI users should not see Python parameter names. + assert "mount_asset" not in message + assert "datasets_mount" not in message + + +def test_build_command_inputs_legacy_datastore_warns_mount_datastore() -> None: + """A legacy --mount datastore value is told to use --mount-datastore.""" + client = Mock() + with patch("submit_aml.data.logger") as mock_logger: + build_command_inputs(client, legacy_mount=["ref=mystore/exports/reference"]) + message = _deprecation_log(mock_logger) + assert "--mount-datastore ref=mystore/exports/reference" in message + + +def test_build_command_inputs_legacy_job_warns_download_job_translated() -> None: + """A legacy --download job value is told to use --download-job, sans prefix.""" + client = Mock() + with patch("submit_aml.data.logger") as mock_logger: + build_command_inputs( + client, + legacy_download=["ckpt=job_dir:my_job_123:models/best.pth"], + ) + message = _deprecation_log(mock_logger) + assert "--download-job ckpt=my_job_123:models/best.pth" in message + # The suggested replacement (after "with") drops the legacy job_dir: prefix. + assert "job_dir:" not in message.split("with", 1)[1] + + +# --------------------------------------------------------------------------- +# build_command_outputs +# --------------------------------------------------------------------------- + + +def test_build_command_outputs_empty() -> None: + """No arguments produce an empty dict.""" + assert build_command_outputs() == {} + + +def test_build_command_outputs_datastore_and_asset() -> None: + """Datastore and asset outputs are both built.""" + outputs = build_command_outputs( + output_datastore=["out_dir=mydatastore/my_dataset"], + output_asset=["asset_dir=my-results:2"], + ) + assert outputs["out_dir"].path == ( + "azureml://datastores/mydatastore/paths/my_dataset" + ) + assert outputs["asset_dir"].name == "my-results" + assert outputs["asset_dir"].version == "2" + + +def test_build_command_outputs_duplicate_alias_raises() -> None: + """The same alias under two output flags is a hard error.""" + with pytest.raises(ValueError, match=r"[Dd]uplicate.*alias.*out_dir"): + build_command_outputs( + output_datastore=["out_dir=mydatastore/my_dataset"], + output_asset=["out_dir=my-results:2"], + ) + + +def test_build_command_outputs_legacy_warns() -> None: + """Legacy --output strings are built and emit a targeted deprecation warning.""" + with patch("submit_aml.data.logger") as mock_logger: + outputs = build_command_outputs( + legacy_output=["out_dir=mydatastore/my_dataset"] + ) assert "out_dir" in outputs - output = outputs["out_dir"] - assert "mydatastore" in output.path - assert "my_dataset" in output.path + message = _deprecation_log(mock_logger) + assert "deprecated" in message.lower() + assert "--output-datastore out_dir=mydatastore/my_dataset" in message + + +def test_build_command_outputs_legacy_raises_deprecation_warning() -> None: + """The Python API raises a DeprecationWarning naming the new parameter.""" + with pytest.warns(DeprecationWarning, match=r"output_datastore") as record: + build_command_outputs(legacy_output=["out_dir=mydatastore/my_dataset"]) + message = str(record[0].message) + assert "--output" not in message + assert "['out_dir=mydatastore/my_dataset']" in message