diff --git a/pyproject.toml b/pyproject.toml index c5c422e..e0b9548 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ classifiers = [ dependencies = [ "azure-ai-ml", "azure-identity", - "loguru", "rich", "tomli>=1.1.0; python_version < '3.11'", "typer>=0.16.0", diff --git a/src/submit_aml/__main__.py b/src/submit_aml/__main__.py index 4e4699f..e7c4bc7 100644 --- a/src/submit_aml/__main__.py +++ b/src/submit_aml/__main__.py @@ -5,14 +5,13 @@ from typing import Optional import typer -from rich.console import Console from .aml import CredentialType from .aml import submit_to_aml from .command import get_sweep_inputs_from_args from .config import get_default from .environment import parse_key_value_pairs -from .logger import logger +from .errors import report_exception PANEL_AZURE = "Azure" PANEL_COMMAND = "Command" @@ -439,10 +438,9 @@ def submit( wait_for_completion=stream_logs, workspace_name=workspace_name, ) - except Exception: - logger.critical("Failed to submit job to Azure ML. 
# Words for the single-digit numbers, keyed by the number itself.
_NUMBER_WORDS = dict(
    enumerate(
        (
            "zero",
            "one",
            "two",
            "three",
            "four",
            "five",
            "six",
            "seven",
            "eight",
            "nine",
        )
    )
)


def _spell_number(number: int) -> str:
    """Spell out a single-digit number, or fall back to its digits.

    Numbers that have an entry in ``_NUMBER_WORDS`` (zero to nine) are returned
    as words; every other value is returned as ``str(number)``.

    Examples:
        >>> _spell_number(1)
        'one'
        >>> _spell_number(9)
        'nine'
        >>> _spell_number(10)
        '10'

    Args:
        number: The number to render.

    Returns:
        The word for ``number`` when one is known, otherwise its digits.
    """
    try:
        return _NUMBER_WORDS[number]
    except KeyError:
        return str(number)


def _count_noun(number: int, noun: str) -> str:
    """Render a count and its noun, adding a plural ``s`` unless the count is one.

    Examples:
        >>> _count_noun(1, "node")
        'one node'
        >>> _count_noun(2, "node")
        'two nodes'
        >>> _count_noun(4, "GPU")
        'four GPUs'

    Args:
        number: The count.
        noun: The singular noun; an ``s`` is appended when ``number`` is not one.

    Returns:
        The spelled-out count followed by the (possibly pluralised) noun.
    """
    suffix = "" if number == 1 else "s"
    return f"{_spell_number(number)} {noun}{suffix}"
def _print_job_summary(job: Job) -> None:
    """Print a summary table for a submitted job.

    The table lists the run ID and, when present, the display name, the
    experiment name and a clickable Studio URL.

    By the time this function runs the job has already been accepted by Azure
    ML, so a missing Studio endpoint must not abort the command (and must not
    be reported as a failed submission). In that case the URL row is left out
    and a warning is logged instead of raising.

    Args:
        job: The job returned by Azure ML after submission.
    """
    table = Table(
        box=box.ROUNDED,
        show_header=False,
        title="Job submitted",
        title_style="bold green",
        title_justify="left",
    )
    table.add_column(style="bold cyan", justify="right")
    table.add_column(overflow="fold")

    table.add_row("Run ID", job.name)
    if job.display_name is not None:
        table.add_row("Display name", job.display_name)
    if job.experiment_name is not None:
        table.add_row("Experiment", job.experiment_name)

    # Look the endpoint up defensively instead of asserting: asserts vanish
    # under ``python -O`` and a missing service is not a submission failure.
    studio = (job.services or {}).get("Studio")
    url = getattr(studio, "endpoint", None)
    if url:
        # Render the URL as a hyperlink so it stays clickable even if it wraps.
        # See https://github.com/Textualize/rich/issues/886#issuecomment-756406589.
        table.add_row("Studio URL", Text(url, style=f"link {url}"))

    console.print(table)

    if not url:
        logger.warning("The Studio URL is not available for this job.")
) - ( - source_dir, - project_dir, - script_path, - ml_client, - description, - instance_count, - distribution, - experiment_name, - ) = setup( - source_dir, - project_dir, - script_path, - subscription_id, - resource_group, - workspace_name, - description, - num_gpus, - num_nodes, - experiment_name, - credential_type=credential_type, - ) + logger.info("Configuring job...") + with indent(): + ( + source_dir, + project_dir, + script_path, + ml_client, + description, + instance_count, + distribution, + experiment_name, + ) = setup( + source_dir, + project_dir, + script_path, + subscription_id, + resource_group, + workspace_name, + description, + num_gpus, + num_nodes, + experiment_name, + credential_type=credential_type, + ) - environment = infer_environment( - ml_client=ml_client, - project_dir=project_dir, - base_docker_image=base_docker_image, - dependency_groups=dependency_groups, - optional_dependencies=optional_dependencies, - aml_environment=aml_environment, - build_docker_context=build_docker_context, - conda_env_file=conda_env_file, - docker_run=docker_run, - dry_run=dry_run, - ) + logger.info("Preparing environment...") + with indent(): + environment = infer_environment( + ml_client=ml_client, + project_dir=project_dir, + base_docker_image=base_docker_image, + dependency_groups=dependency_groups, + optional_dependencies=optional_dependencies, + aml_environment=aml_environment, + build_docker_context=build_docker_context, + conda_env_file=conda_env_file, + docker_run=docker_run, + dry_run=dry_run, + ) if only_environment: msg = "The environment build has been submitted. No job will be submitted." 
logger.warning(msg) diff --git a/src/submit_aml/data.py b/src/submit_aml/data.py index b52685c..410568b 100644 --- a/src/submit_aml/data.py +++ b/src/submit_aml/data.py @@ -11,6 +11,7 @@ from azure.ai.ml.exceptions import MlException from .logger import logger +from .progress import report_time TypeInputsDict = dict[str, Input | SweepDistribution] TypeOptionalStrList = list[str] | None @@ -208,17 +209,19 @@ def _get_data_assets( else: kwargs = {"version": version} - logger.info(f'Retrieving data asset "{path}"...') - try: - data = ml_client.data.get(name=path, **kwargs) - except MlException as e: - msg = ( - "Error getting data asset with" - f' name "{path}"' - f' and version "{version}"' - ) - raise ValueError(msg) from e - logger.success(f'Found data asset with path "{path}"') + with report_time( + f'Retrieving data asset "{path}"...', + f'Retrieved data asset "{path}"', + ): + try: + data = ml_client.data.get(name=path, **kwargs) + except MlException as e: + msg = ( + "Error getting data asset with" + f' name "{path}"' + f' and version "{version}"' + ) + raise ValueError(msg) from e inputs[alias] = Input( path=data.id, mode=mode, diff --git a/src/submit_aml/environment.py b/src/submit_aml/environment.py index c3e68ee..3fb9ae5 100644 --- a/src/submit_aml/environment.py +++ b/src/submit_aml/environment.py @@ -13,6 +13,7 @@ from .defaults import DEFAULT_DOCKER_IMAGE from .defaults import DEFAULT_UV_SYNC_COMMAND from .logger import logger +from .progress import report_time def parse_key_value_pairs( @@ -393,17 +394,20 @@ def _register_environment(ml_client: MLClient, environment: Environment) -> Envi kwargs = {"label": "latest"} else: kwargs = {"version": environment.version} - logger.info( - f'Checking if environment "{environment.name}" ({kwargs}) exists...' - ) - env = ml_client.environments.get(environment.name, **kwargs) + start_msg = f'Checking if environment "{environment.name}" ({kwargs}) exists...' 
def _summarize_exception(exc: BaseException) -> str:
    """Return a concise one-line summary of an exception.

    The summary combines the exception class name with the first line of its
    (whitespace-stripped) message, so verbose multi-line errors are reduced to
    their headline. Exceptions with an empty message are summarised by their
    class name alone.

    Args:
        exc: The exception to summarise.

    Returns:
        A single-line summary string.
    """
    message = str(exc).strip()
    if not message:
        return exc.__class__.__name__
    # The message is stripped and non-empty, so the first line always exists.
    first_line = message.splitlines()[0].strip()
    return f"{exc.__class__.__name__}: {first_line}"


def write_traceback_log(exc: BaseException) -> Path:
    """Write the full traceback of an exception to a temporary log file.

    The log also records the command line and a UTC timestamp to make the file
    useful for debugging and bug reports. The command line is shell-quoted so
    that arguments containing spaces or quotes keep their boundaries and the
    command can be copied back into a shell.

    Args:
        exc: The exception whose traceback should be saved.

    Returns:
        The path to the created log file. The file is not deleted
        automatically.
    """
    # Imported locally so this function does not depend on a module-level
    # import being added elsewhere.
    import shlex

    with tempfile.NamedTemporaryFile(
        mode="w",
        prefix="submit-aml-",
        suffix=".log",
        delete=False,  # The user needs the file after the process exits.
        encoding="utf-8",
    ) as handle:
        handle.write(f"Command: {shlex.join(sys.argv)}\n")
        handle.write(f"Timestamp: {datetime.now(timezone.utc).isoformat()}\n\n")
        handle.write(
            "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
        )
    return Path(handle.name)


def report_exception(exc: BaseException, *, message: str) -> Path:
    """Report an exception to the user without dumping the full traceback.

    A concise error panel is printed together with the path to a temporary log
    file containing the complete traceback.

    Args:
        exc: The exception to report.
        message: A short, human-readable description of what failed.

    Returns:
        The path to the log file holding the full traceback.
    """
    log_path = write_traceback_log(exc)

    # Build the panel body from plain ``Text`` so nothing in the exception
    # message is interpreted as rich markup.
    body = Text()
    body.append(message, style="bold")
    body.append("\n\n")
    body.append(_summarize_exception(exc), style="red")

    console.print()
    console.print(
        Panel(
            body,
            title="Error",
            title_align="left",
            border_style="red",
        )
    )
    logger.info(f"Full traceback written to {log_path}")
    return log_path
_highlighter = ReprHighlighter()
"""Rich highlighter applied to messages that are not coloured as a whole."""

_INDENT = "  "
"""Text inserted once per indentation level."""

_depth: contextvars.ContextVar[int] = contextvars.ContextVar(
    "submit_aml_log_depth",
    default=0,
)
"""Context variable holding the current indentation depth (0 at top level)."""

# Maps a level name to its ``(glyph, rich style)`` pair. The glyphs are flat
# (non-emoji) symbols; only their colour differs between levels.
_LEVEL_STYLES: dict[str, tuple[str, str]] = {
    "DEBUG": ("·", "dim"),
    "INFO": ("•", "cyan"),
    "SUCCESS": ("✓", "green"),
    "WARNING": ("▲", "yellow"),
    "ERROR": ("✗", "red"),
    "CRITICAL": ("✗", "bold red"),
}

# Levels whose message text, and not only the glyph, takes the level style.
_COLOURED_MESSAGE_LEVELS = frozenset({"DEBUG", "WARNING", "ERROR", "CRITICAL"})


@contextmanager
def indent() -> Iterator[None]:
    """Run the enclosed block one indentation level deeper.

    The depth stored in the context variable is incremented on entry and
    restored on exit, including when the block raises.

    Yields:
        ``None``.
    """
    restore_token = _depth.set(get_depth() + 1)
    try:
        yield
    finally:
        _depth.reset(restore_token)


def get_depth() -> int:
    """Return the current logging depth.

    Returns:
        The number of active indentation levels.
    """
    return _depth.get()


def format_log_line(
    level_name: str,
    message: str,
    depth: int,
    *,
    highlight: bool = True,
    width: int | None = None,
) -> Text:
    """Build the rich text for one log entry.

    The first rendered line consists of ``depth`` levels of indentation, the
    level glyph in the level style, a space and the message. The message is
    wrapped to the remaining width, and every continuation line is padded with
    spaces so it lines up under the start of the message.

    Args:
        level_name: Name of the log level (e.g. ``"INFO"``). Unknown names fall
            back to the ``"•"`` glyph with no style.
        message: The (already interpolated) message to render.
        depth: Number of indentation levels to apply.
        highlight: Whether to run the highlighter over the message. Has no
            effect for levels whose whole message is already coloured.
        width: Total width available for the rendered line. Defaults to the
            shared console width.

    Returns:
        A [`rich.text.Text`][] instance ready to be printed.
    """
    glyph, glyph_style = _LEVEL_STYLES.get(level_name, ("•", ""))
    body_style = glyph_style if level_name in _COLOURED_MESSAGE_LEVELS else ""

    lead = _INDENT * depth
    # Indentation + glyph + the single space separating glyph and message.
    gutter_width = len(lead) + len(glyph) + 1
    hanging_pad = " " * gutter_width

    body = Text(message, style=body_style)
    if highlight and not body_style:
        _highlighter.highlight(body)

    total_width = console.size.width if width is None else width
    # Never wrap to less than one column, however deep the indentation.
    wrapped = body.wrap(console, max(total_width - gutter_width, 1))

    rendered = Text()
    for position, segment in enumerate(wrapped):
        if position == 0:
            rendered.append(lead)
            rendered.append(glyph, style=glyph_style)
            rendered.append(" ")
        else:
            rendered.append("\n")
            rendered.append(hanging_pad)
        rendered.append_text(segment)
    return rendered


class Logger:
    """Small console logger exposing loguru-style level methods.

    Every level method takes a message plus optional positional and keyword
    arguments. If any argument is supplied the message is interpolated with
    ``message.format(*args, **kwargs)``; with no arguments the message is
    printed as given, so pre-formatted strings containing literal braces are
    safe.
    """

    def _log(
        self,
        level_name: str,
        message: str,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> None:
        """Interpolate ``message`` if needed and print it at the current depth."""
        rendered = message.format(*args, **kwargs) if (args or kwargs) else message
        console.print(format_log_line(level_name, rendered, get_depth()))

    def debug(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Emit ``message`` at the DEBUG level."""
        self._log("DEBUG", message, args, kwargs)

    def info(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Emit ``message`` at the INFO level."""
        self._log("INFO", message, args, kwargs)

    def success(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Emit ``message`` at the SUCCESS level."""
        self._log("SUCCESS", message, args, kwargs)

    def warning(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Emit ``message`` at the WARNING level."""
        self._log("WARNING", message, args, kwargs)

    def error(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Emit ``message`` at the ERROR level."""
        self._log("ERROR", message, args, kwargs)

    def critical(self, message: str, *args: Any, **kwargs: Any) -> None:
        """Emit ``message`` at the CRITICAL level."""
        self._log("CRITICAL", message, args, kwargs)


logger = Logger()
"""Shared logger instance imported by the rest of the package."""
# Whether a spinner is currently being displayed. Rich allows only one live
# display per console at a time, so nested spinners fall back to plain logging.
_spinner_active = False


@contextmanager
def report_time(
    start_msg: str,
    end_msg: str,
    *,
    spinner: bool = True,
) -> Iterator[None]:
    """Show a spinner while a block runs and report the elapsed time.

    While the block executes, a spinner with a live elapsed-time counter is
    displayed next to ``start_msg``. Output emitted inside the block is
    indented one level. When the block completes without raising, the
    (transient) spinner is cleared and ``end_msg`` is logged together with the
    elapsed time. If the block raises, the exception propagates and no success
    message is logged.

    The spinner is skipped (``start_msg`` is logged as a plain header instead)
    when ``spinner`` is ``False``, when the console is not a terminal, or when a
    spinner is already active. Disable the spinner for operations that render
    their own progress output (e.g. uploads), so the two live displays do not
    clash.

    Args:
        start_msg: Message shown next to the spinner during execution.
        end_msg: Message logged after the block completes.
        spinner: Whether to display a spinner. Set to ``False`` for operations
            that print their own progress.

    Yields:
        ``None``.
    """
    global _spinner_active
    # Use a monotonic clock: wall-clock time can jump (NTP, DST, manual
    # changes) and would then report a wrong or even negative duration.
    begin = time.perf_counter()

    if not spinner or _spinner_active or not console.is_terminal:
        logger.info(start_msg)
        with indent():
            yield
    else:
        # Imported locally so this function does not depend on a module-level
        # import being added elsewhere.
        from rich.markup import escape

        # The description is rendered by a markup-enabled text column, so
        # square brackets in the message (paths, dict reprs, ...) must be
        # escaped to be shown literally instead of being parsed as markup.
        description = f"{_INDENT * get_depth()}{escape(start_msg)}"
        _spinner_active = True
        try:
            with BarlessProgress(transient=True) as progress:
                # ``total=None`` gives an indeterminate task (spinner only).
                progress.add_task(description, total=None)
                with indent():
                    yield
        finally:
            # Always release the flag so later calls can show a spinner again.
            _spinner_active = False

    delta = _natural_delta(time.perf_counter() - begin)
    logger.success(f"{end_msg} in {delta}.")