Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ classifiers = [
dependencies = [
"azure-ai-ml",
"azure-identity",
"loguru",
"rich",
"tomli>=1.1.0; python_version < '3.11'",
"typer>=0.16.0",
Expand Down
10 changes: 4 additions & 6 deletions src/submit_aml/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
from typing import Optional

import typer
from rich.console import Console

from .aml import CredentialType
from .aml import submit_to_aml
from .command import get_sweep_inputs_from_args
from .config import get_default
from .environment import parse_key_value_pairs
from .logger import logger
from .errors import report_exception

PANEL_AZURE = "Azure"
PANEL_COMMAND = "Command"
Expand Down Expand Up @@ -439,10 +438,9 @@ def submit(
wait_for_completion=stream_logs,
workspace_name=workspace_name,
)
except Exception:
logger.critical("Failed to submit job to Azure ML. Reason:")
console = Console()
console.print_exception()
except Exception as exc:
report_exception(exc, message="The Azure ML job could not be submitted.")
raise typer.Exit(code=1) from exc


if __name__ == "__main__":
Expand Down
185 changes: 133 additions & 52 deletions src/submit_aml/aml.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from azure.ai.ml.sweep import Choice
from azure.identity import AzureCliCredential
from azure.identity import ManagedIdentityCredential
from rich.console import Console
from rich import box
from rich.table import Table
from rich.text import Text

from .command import TypeServices
from .command import add_service_for_debugging
Expand All @@ -32,6 +34,8 @@
from .environment import add_profiler_env_variables
from .environment import infer_environment
from .environment import log_environment_variables
from .logger import console
from .logger import indent
from .logger import logger
from .logger import suppress_azure_warnings
from .paths import get_cwd
Expand All @@ -40,6 +44,57 @@
TypeInputsDict = dict[str, Input | Choice]
_MAX_SWEEP_DESCRIPTION_LENGTH = 511

_NUMBER_WORDS = {
0: "zero",
1: "one",
2: "two",
3: "three",
4: "four",
5: "five",
6: "six",
7: "seven",
8: "eight",
9: "nine",
}


def _spell_number(number: int) -> str:
"""Return a natural-language representation of a small non-negative number.

Numbers below ten are spelled out; larger numbers are returned as digits.

Examples:
>>> _spell_number(1)
'one'
>>> _spell_number(9)
'nine'
>>> _spell_number(10)
'10'

Args:
number: The number to render.
"""
return _NUMBER_WORDS.get(number, str(number))


def _count_noun(number: int, noun: str) -> str:
"""Return a spelled-out count followed by a correctly pluralised noun.

Examples:
>>> _count_noun(1, "node")
'one node'
>>> _count_noun(2, "node")
'two nodes'
>>> _count_noun(4, "GPU")
'four GPUs'

Args:
number: The count.
noun: The singular noun to pluralise when ``number`` is not one.
"""
plural = noun if number == 1 else f"{noun}s"
return f"{_spell_number(number)} {plural}"


class CredentialType(str, Enum):
"""Credential type used to authenticate with Azure ML."""
Expand Down Expand Up @@ -175,15 +230,15 @@ def setup(
if num_gpus is None:
instance_count = num_nodes
distribution = MpiDistribution()
logger.info(f'Using "MPI" distribution with {num_nodes} nodes.')
logger.info(f'Using "MPI" distribution with {_count_noun(num_nodes, "node")}.')
else:
instance_count = num_nodes
distribution = PyTorchDistribution(
process_count_per_instance=num_gpus,
)
logger.info(
f'Using "PyTorch" distribution with {num_nodes} nodes and '
f"{num_gpus} GPUs per node."
f'Using "PyTorch" distribution with {_count_noun(num_nodes, "node")}'
f" and {_count_noun(num_gpus, 'GPU')} per node."
)

experiment_name = _sanitize_experiment_name(experiment_name)
Expand Down Expand Up @@ -231,22 +286,12 @@ def _submit(

start_msg = "Submitting job to Azure Machine Learning..."
end_msg = "Job submitted successfully"
with report_time(start_msg, end_msg):
# The SDK renders its own tqdm upload progress bar, so we don't show a
# spinner here (two live displays would clash).
with report_time(start_msg, end_msg, spinner=False):
returned_job = ml_client.create_or_update(command_job)

logger.info(f'Run ID: "{returned_job.name}"')

if returned_job.display_name is not None:
logger.info(f'Display name: "{returned_job.display_name}"')

logger.info("Studio URL:")
assert returned_job.services is not None
url = returned_job.services["Studio"].endpoint
# Log the run URL. We use this instead of the logger so the URL is clickable and
# not split over multiple lines.
# See https://github.com/Textualize/rich/issues/886#issuecomment-756406589
# for more details.
Console().print(url, style=f"link {url}")
_print_job_summary(returned_job)

if wait_for_completion:
logger.info("Starting logs streaming...")
Expand All @@ -256,6 +301,38 @@ def _submit(
return returned_job


def _print_job_summary(job: Job) -> None:
"""Print a summary table for a submitted job.

Args:
job: The job returned by Azure ML after submission.
"""
assert job.services is not None
url = job.services["Studio"].endpoint
assert url is not None

table = Table(
box=box.ROUNDED,
show_header=False,
title="Job submitted",
title_style="bold green",
title_justify="left",
)
table.add_column(style="bold cyan", justify="right")
table.add_column(overflow="fold")

table.add_row("Run ID", job.name)
if job.display_name is not None:
table.add_row("Display name", job.display_name)
if job.experiment_name is not None:
table.add_row("Experiment", job.experiment_name)
# Render the URL as a hyperlink so it stays clickable even if it wraps.
# See https://github.com/Textualize/rich/issues/886#issuecomment-756406589.
table.add_row("Studio URL", Text(url, style=f"link {url}"))

Comment thread
fepegar marked this conversation as resolved.
console.print(table)


def submit_to_aml(
*,
aml_environment: str | None = None,
Expand Down Expand Up @@ -365,41 +442,45 @@ def submit_to_aml(
"Conda environments manage their own"
" dependencies."
)
(
source_dir,
project_dir,
script_path,
ml_client,
description,
instance_count,
distribution,
experiment_name,
) = setup(
source_dir,
project_dir,
script_path,
subscription_id,
resource_group,
workspace_name,
description,
num_gpus,
num_nodes,
experiment_name,
credential_type=credential_type,
)
logger.info("Configuring job...")
with indent():
(
source_dir,
project_dir,
script_path,
ml_client,
description,
instance_count,
distribution,
experiment_name,
) = setup(
source_dir,
project_dir,
script_path,
subscription_id,
resource_group,
workspace_name,
description,
num_gpus,
num_nodes,
experiment_name,
credential_type=credential_type,
)

environment = infer_environment(
ml_client=ml_client,
project_dir=project_dir,
base_docker_image=base_docker_image,
dependency_groups=dependency_groups,
optional_dependencies=optional_dependencies,
aml_environment=aml_environment,
build_docker_context=build_docker_context,
conda_env_file=conda_env_file,
docker_run=docker_run,
dry_run=dry_run,
)
logger.info("Preparing environment...")
with indent():
environment = infer_environment(
ml_client=ml_client,
project_dir=project_dir,
base_docker_image=base_docker_image,
dependency_groups=dependency_groups,
optional_dependencies=optional_dependencies,
aml_environment=aml_environment,
build_docker_context=build_docker_context,
conda_env_file=conda_env_file,
docker_run=docker_run,
dry_run=dry_run,
)
if only_environment:
msg = "The environment build has been submitted. No job will be submitted."
logger.warning(msg)
Expand Down
25 changes: 14 additions & 11 deletions src/submit_aml/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from azure.ai.ml.exceptions import MlException

from .logger import logger
from .progress import report_time

TypeInputsDict = dict[str, Input | SweepDistribution]
TypeOptionalStrList = list[str] | None
Expand Down Expand Up @@ -208,17 +209,19 @@ def _get_data_assets(
else:
kwargs = {"version": version}

logger.info(f'Retrieving data asset "{path}"...')
try:
data = ml_client.data.get(name=path, **kwargs)
except MlException as e:
msg = (
"Error getting data asset with"
f' name "{path}"'
f' and version "{version}"'
)
raise ValueError(msg) from e
logger.success(f'Found data asset with path "{path}"')
with report_time(
f'Retrieving data asset "{path}"...',
f'Retrieved data asset "{path}"',
):
try:
data = ml_client.data.get(name=path, **kwargs)
except MlException as e:
msg = (
"Error getting data asset with"
f' name "{path}"'
f' and version "{version}"'
)
raise ValueError(msg) from e
inputs[alias] = Input(
path=data.id,
mode=mode,
Expand Down
16 changes: 10 additions & 6 deletions src/submit_aml/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .defaults import DEFAULT_DOCKER_IMAGE
from .defaults import DEFAULT_UV_SYNC_COMMAND
from .logger import logger
from .progress import report_time


def parse_key_value_pairs(
Expand Down Expand Up @@ -393,17 +394,20 @@ def _register_environment(ml_client: MLClient, environment: Environment) -> Envi
kwargs = {"label": "latest"}
else:
kwargs = {"version": environment.version}
logger.info(
f'Checking if environment "{environment.name}" ({kwargs}) exists...'
)
env = ml_client.environments.get(environment.name, **kwargs)
start_msg = f'Checking if environment "{environment.name}" ({kwargs}) exists...'
with report_time(start_msg, "Environment lookup complete"):
env = ml_client.environments.get(environment.name, **kwargs)
msg = (
f'Found a registered environment with name "{environment.name}"'
f' and version "{env.version}"'
)
logger.info(msg)
except ResourceNotFoundError:
logger.info("Environment not found. Registering a new one...")
env = ml_client.environments.create_or_update(environment)
with report_time(
"Environment not found. Registering a new one...",
"Environment registered",
spinner=False,
):
env = ml_client.environments.create_or_update(environment)
logger.info(f'Registered environment: "{env.name}" (version: "{env.version}")')
return env
Loading
Loading