diff --git a/openml/__init__.py b/openml/__init__.py
index c49505eb9..e23316d4d 100644
--- a/openml/__init__.py
+++ b/openml/__init__.py
@@ -18,9 +18,11 @@
 # License: BSD 3-Clause
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 from . import (
     _api_calls,
-    config,
+    config as _config_module,
     datasets,
     evaluations,
     exceptions,
@@ -49,6 +51,11 @@
     OpenMLTask,
 )
 
+if TYPE_CHECKING:
+    from .config import OpenMLConfigManager
+
+config: OpenMLConfigManager = _config_module._config
+
 
 def populate_cache(
     task_ids: list[int] | None = None,
diff --git a/openml/_api_calls.py b/openml/_api_calls.py
index 81296b3da..a72da1b8c 100644
--- a/openml/_api_calls.py
+++ b/openml/_api_calls.py
@@ -12,7 +12,7 @@
 import xml
 import zipfile
 from pathlib import Path
-from typing import Dict, Tuple, Union
+from typing import Dict, Tuple, Union, cast
 
 import minio
 import requests
@@ -71,7 +71,7 @@
 
 
 def _create_url_from_endpoint(endpoint: str) -> str:
-    url = config.server
+    url = cast(str, config.server)
     if not url.endswith("/"):
         url += "/"
     url += endpoint
@@ -301,7 +301,8 @@
     Presents the URL how to download a given file id
     filename is optional
     """
-    openml_url = config.server.split("/api/")
+    openml_server = cast(str, config.server)
+    openml_url = openml_server.split("/api/")
     url = openml_url[0] + f"/data/download/{file_id!s}"
     if filename is not None:
         url += "/" + filename
diff --git a/openml/cli.py b/openml/cli.py
index d0a46e498..c1363ea74 100644
--- a/openml/cli.py
+++ b/openml/cli.py
@@ -9,6 +9,8 @@
 from typing import Callable
 from urllib.parse import urlparse
 
+from dataclasses import fields
+
 from openml import config
 
 
@@ -339,7 +341,9 @@
         "'https://openml.github.io/openml-python/main/usage.html#configuration'.",
     )
 
-    configurable_fields = [f for f in config._defaults if f not in ["max_retries"]]
+    configurable_fields = [
+        f.name for f in fields(config._config) if f.name not in ["max_retries"]
+    ]
 
     parser_configure.add_argument(
         "field",
diff --git a/openml/config.py b/openml/config.py
index cf66a6346..2ecb3c64f 100644
--- a/openml/config.py
+++ b/openml/config.py
@@ -1,6 +1,7 @@
 """Store module level information like the API key, cache directory and the server"""
 
 # License: BSD 3-Clause
+# ruff: noqa: PLW0603
 from __future__ import annotations
 
 import configparser
@@ -11,127 +12,33 @@
 import shutil
 import warnings
 from contextlib import contextmanager
+from dataclasses import dataclass, field, replace
 from io import StringIO
 from pathlib import Path
 from typing import Any, Iterator, cast
 
-from typing_extensions import Literal, TypedDict
+from typing_extensions import Literal
 from urllib.parse import urlparse
 
 logger = logging.getLogger(__name__)
 openml_logger = logging.getLogger("openml")
-console_handler: logging.StreamHandler | None = None
-file_handler: logging.handlers.RotatingFileHandler | None = None
-
-OPENML_CACHE_DIR_ENV_VAR = "OPENML_CACHE_DIR"
-OPENML_SKIP_PARQUET_ENV_VAR = "OPENML_SKIP_PARQUET"
-_TEST_SERVER_NORMAL_USER_KEY = "normaluser"
-
-
-class _Config(TypedDict):
-    apikey: str
-    server: str
-    cachedir: Path
-    avoid_duplicate_runs: bool
-    retry_policy: Literal["human", "robot"]
-    connection_n_retries: int
-    show_progress: bool
-
-
-def _create_log_handlers(create_file_handler: bool = True) -> None:  # noqa: FBT001, FBT002
-    """Creates but does not attach the log handlers."""
-    global console_handler, file_handler  # noqa: PLW0603
-    if console_handler is not None or file_handler is not None:
-        logger.debug("Requested to create log handlers, but they are already created.")
-        return
-
-    message_format = "[%(levelname)s] [%(asctime)s:%(name)s] %(message)s"
-    output_formatter = logging.Formatter(message_format, datefmt="%H:%M:%S")
-
-    console_handler = logging.StreamHandler()
-    console_handler.setFormatter(output_formatter)
-
-    if create_file_handler:
-        one_mb = 2**20
-        log_path = 
_root_cache_directory / "openml_python.log" - file_handler = logging.handlers.RotatingFileHandler( - log_path, - maxBytes=one_mb, - backupCount=1, - delay=True, - ) - file_handler.setFormatter(output_formatter) - - -def _convert_log_levels(log_level: int) -> tuple[int, int]: - """Converts a log level that's either defined by OpenML/Python to both specifications.""" - # OpenML verbosity level don't match Python values directly: - openml_to_python = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} - python_to_openml = { - logging.DEBUG: 2, - logging.INFO: 1, - logging.WARNING: 0, - logging.CRITICAL: 0, - logging.ERROR: 0, - } - # Because the dictionaries share no keys, we use `get` to convert as necessary: - openml_level = python_to_openml.get(log_level, log_level) - python_level = openml_to_python.get(log_level, log_level) - return openml_level, python_level - - -def _set_level_register_and_store(handler: logging.Handler, log_level: int) -> None: - """Set handler log level, register it if needed, save setting to config file if specified.""" - _oml_level, py_level = _convert_log_levels(log_level) - handler.setLevel(py_level) - - if openml_logger.level > py_level or openml_logger.level == logging.NOTSET: - openml_logger.setLevel(py_level) - - if handler not in openml_logger.handlers: - openml_logger.addHandler(handler) - - -def set_console_log_level(console_output_level: int) -> None: - """Set console output to the desired level and register it with openml logger if needed.""" - global console_handler # noqa: PLW0602 - assert console_handler is not None - _set_level_register_and_store(console_handler, console_output_level) - - -def set_file_log_level(file_output_level: int) -> None: - """Set file output to the desired level and register it with openml logger if needed.""" - global file_handler # noqa: PLW0602 - assert file_handler is not None - _set_level_register_and_store(file_handler, file_output_level) - - -# Default values (see also 
https://github.com/openml/OpenML/wiki/Client-API-Standards) -_user_path = Path("~").expanduser().absolute() def _resolve_default_cache_dir() -> Path: - user_defined_cache_dir = os.environ.get(OPENML_CACHE_DIR_ENV_VAR) + user_defined_cache_dir = os.environ.get("OPENML_CACHE_DIR") if user_defined_cache_dir is not None: return Path(user_defined_cache_dir) if platform.system().lower() != "linux": - return _user_path / ".openml" + return Path("~", ".openml") xdg_cache_home = os.environ.get("XDG_CACHE_HOME") if xdg_cache_home is None: return Path("~", ".cache", "openml") - # This is the proper XDG_CACHE_HOME directory, but - # we unfortunately had a problem where we used XDG_CACHE_HOME/org, - # we check heuristically if this old directory still exists and issue - # a warning if it does. There's too much data to move to do this for the user. - - # The new cache directory exists cache_dir = Path(xdg_cache_home) / "openml" if cache_dir.exists(): return cache_dir - # The old cache directory *does not* exist heuristic_dir_for_backwards_compat = Path(xdg_cache_home) / "org" / "openml" if not heuristic_dir_for_backwards_compat.exists(): return cache_dir @@ -147,378 +54,414 @@ def _resolve_default_cache_dir() -> Path: return Path(xdg_cache_home) -_defaults: _Config = { - "apikey": "", - "server": "https://www.openml.org/api/v1/xml", - "cachedir": _resolve_default_cache_dir(), - "avoid_duplicate_runs": False, - "retry_policy": "human", - "connection_n_retries": 5, - "show_progress": False, -} +@dataclass +class OpenMLConfig: + """Dataclass storing the OpenML configuration.""" -# Default values are actually added here in the _setup() function which is -# called at the end of this module -server = _defaults["server"] + apikey: str = "" + server: str = "https://www.openml.org/api/v1/xml" + cachedir: Path = field(default_factory=_resolve_default_cache_dir) + avoid_duplicate_runs: bool = False + retry_policy: Literal["human", "robot"] = "human" + connection_n_retries: int = 5 + 
show_progress: bool = False + def __setattr__(self, name: str, value: Any) -> None: + if name == "apikey" and value is not None and not isinstance(value, str): + raise ValueError("apikey must be a string or None") -def get_server_base_url() -> str: - """Return the base URL of the currently configured server. + super().__setattr__(name, value) - Turns ``"https://api.openml.org/api/v1/xml"`` in ``"https://www.openml.org/"`` - and ``"https://test.openml.org/api/v1/xml"`` in ``"https://test.openml.org/"`` - Returns - ------- - str - """ - domain, path = server.split("/api", maxsplit=1) - return domain.replace("api", "www") +class OpenMLConfigManager: + """The OpenMLConfigManager manages the configuration of the openml-python package.""" + def __init__(self) -> None: + self.console_handler: logging.StreamHandler | None = None + self.file_handler: logging.handlers.RotatingFileHandler | None = None -apikey: str = _defaults["apikey"] -show_progress: bool = _defaults["show_progress"] -# The current cache directory (without the server name) -_root_cache_directory: Path = Path(_defaults["cachedir"]) -avoid_duplicate_runs = _defaults["avoid_duplicate_runs"] + self.OPENML_CACHE_DIR_ENV_VAR = "OPENML_CACHE_DIR" + self.OPENML_SKIP_PARQUET_ENV_VAR = "OPENML_SKIP_PARQUET" + self._TEST_SERVER_NORMAL_USER_KEY = "normaluser" -retry_policy: Literal["human", "robot"] = _defaults["retry_policy"] -connection_n_retries: int = _defaults["connection_n_retries"] + self._user_path = Path("~").expanduser().absolute() + self._config: OpenMLConfig = OpenMLConfig() + self._root_cache_directory: Path = self._config.cachedir -def set_retry_policy(value: Literal["human", "robot"], n_retries: int | None = None) -> None: - global retry_policy # noqa: PLW0603 - global connection_n_retries # noqa: PLW0603 - default_retries_by_policy = {"human": 5, "robot": 50} + self.logger = logger + self.openml_logger = openml_logger - if value not in default_retries_by_policy: - raise ValueError( - f"Detected 
retry_policy '{value}' but must be one of " - f"{list(default_retries_by_policy.keys())}", - ) - if n_retries is not None and not isinstance(n_retries, int): - raise TypeError(f"`n_retries` must be of type `int` or `None` but is `{type(n_retries)}`.") + self._examples = self.ConfigurationForExamples(self) - if isinstance(n_retries, int) and n_retries < 1: - raise ValueError(f"`n_retries` is '{n_retries}' but must be positive.") + self._setup() - retry_policy = value - connection_n_retries = default_retries_by_policy[value] if n_retries is None else n_retries + def __getattr__(self, name: str) -> Any: + if hasattr(self._config, name): + return getattr(self._config, name) + raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") + _FIELDS = { # noqa: RUF012 + "apikey", + "server", + "cachedir", + "avoid_duplicate_runs", + "retry_policy", + "connection_n_retries", + "show_progress", + } -class ConfigurationForExamples: - """Allows easy switching to and from a test configuration, used for examples.""" + def __setattr__(self, name: str, value: Any) -> None: + # during __init__ before _config exists + if name in { + "_config", + "_root_cache_directory", + "console_handler", + "file_handler", + "logger", + "openml_logger", + "_examples", + "OPENML_CACHE_DIR_ENV_VAR", + "OPENML_SKIP_PARQUET_ENV_VAR", + "_TEST_SERVER_NORMAL_USER_KEY", + "_user_path", + }: + return object.__setattr__(self, name, value) + + if name in self._FIELDS: + # write into dataclass, not manager (prevents shadowing) + if name == "cachedir": + object.__setattr__(self, "_root_cache_directory", Path(value)) + object.__setattr__(self, "_config", replace(self._config, **{name: value})) + return None + + object.__setattr__(self, name, value) + return None + + def _create_log_handlers(self, create_file_handler: bool = True) -> None: # noqa: FBT001, FBT002 + if self.console_handler is not None or self.file_handler is not None: + self.logger.debug("Requested to create log handlers, but 
they are already created.") + return - _last_used_server = None - _last_used_key = None - _start_last_called = False - _test_server = "https://test.openml.org/api/v1/xml" - _test_apikey = _TEST_SERVER_NORMAL_USER_KEY + message_format = "[%(levelname)s] [%(asctime)s:%(name)s] %(message)s" + output_formatter = logging.Formatter(message_format, datefmt="%H:%M:%S") - @classmethod - def start_using_configuration_for_example(cls) -> None: - """Sets the configuration to connect to the test server with valid apikey. + self.console_handler = logging.StreamHandler() + self.console_handler.setFormatter(output_formatter) - To configuration as was before this call is stored, and can be recovered - by using the `stop_use_example_configuration` method. - """ - global server # noqa: PLW0603 - global apikey # noqa: PLW0603 + if create_file_handler: + one_mb = 2**20 + log_path = self._root_cache_directory / "openml_python.log" + self.file_handler = logging.handlers.RotatingFileHandler( + log_path, + maxBytes=one_mb, + backupCount=1, + delay=True, + ) + self.file_handler.setFormatter(output_formatter) + + def _convert_log_levels(self, log_level: int) -> tuple[int, int]: + openml_to_python = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} + python_to_openml = { + logging.DEBUG: 2, + logging.INFO: 1, + logging.WARNING: 0, + logging.CRITICAL: 0, + logging.ERROR: 0, + } + openml_level = python_to_openml.get(log_level, log_level) + python_level = openml_to_python.get(log_level, log_level) + return openml_level, python_level + + def _set_level_register_and_store(self, handler: logging.Handler, log_level: int) -> None: + _oml_level, py_level = self._convert_log_levels(log_level) + handler.setLevel(py_level) + + if self.openml_logger.level > py_level or self.openml_logger.level == logging.NOTSET: + self.openml_logger.setLevel(py_level) + + if handler not in self.openml_logger.handlers: + self.openml_logger.addHandler(handler) + + def set_console_log_level(self, console_output_level: 
int) -> None: + """Set the log level for console output.""" + assert self.console_handler is not None + self._set_level_register_and_store(self.console_handler, console_output_level) + + def set_file_log_level(self, file_output_level: int) -> None: + """Set the log level for file output.""" + assert self.file_handler is not None + self._set_level_register_and_store(self.file_handler, file_output_level) + + def get_server_base_url(self) -> str: + """Get the base URL of the OpenML server (i.e., without /api).""" + domain, _ = self._config.server.split("/api", maxsplit=1) + return domain.replace("api", "www") + + def set_retry_policy( + self, value: Literal["human", "robot"], n_retries: int | None = None + ) -> None: + """Set the retry policy for server connections.""" + default_retries_by_policy = {"human": 5, "robot": 50} + + if value not in default_retries_by_policy: + raise ValueError( + f"Detected retry_policy '{value}' but must be one of " + f"{list(default_retries_by_policy.keys())}", + ) + if n_retries is not None and not isinstance(n_retries, int): + raise TypeError( + f"`n_retries` must be of type `int` or `None` but is `{type(n_retries)}`." + ) - if cls._start_last_called and server == cls._test_server and apikey == cls._test_apikey: - # Method is called more than once in a row without modifying the server or apikey. - # We don't want to save the current test configuration as a last used configuration. - return + if isinstance(n_retries, int) and n_retries < 1: + raise ValueError(f"`n_retries` is '{n_retries}' but must be positive.") - cls._last_used_server = server - cls._last_used_key = apikey - cls._start_last_called = True - - # Test server key for examples - server = cls._test_server - apikey = cls._test_apikey - warnings.warn( - f"Switching to the test server {server} to not upload results to the live server. 
" - "Using the test server may result in reduced performance of the API!", - stacklevel=2, + self._config = replace( + self._config, + retry_policy=value, + connection_n_retries=( + default_retries_by_policy[value] if n_retries is None else n_retries + ), ) - @classmethod - def stop_using_configuration_for_example(cls) -> None: - """Return to configuration as it was before `start_use_example_configuration`.""" - if not cls._start_last_called: - # We don't want to allow this because it will (likely) result in the `server` and - # `apikey` variables being set to None. - raise RuntimeError( - "`stop_use_example_configuration` called without a saved config." - "`start_use_example_configuration` must be called first.", + def _handle_xdg_config_home_backwards_compatibility(self, xdg_home: str) -> Path: + config_dir = Path(xdg_home) / "openml" + + backwards_compat_config_file = Path(xdg_home) / "config" + if not backwards_compat_config_file.exists(): + return config_dir + + try: + self._parse_config(backwards_compat_config_file) + except Exception: # noqa: BLE001 + return config_dir + + correct_config_location = config_dir / "config" + try: + shutil.copy(backwards_compat_config_file, correct_config_location) + self.openml_logger.warning( + "An openml configuration file was found at the old location " + f"at {backwards_compat_config_file}. We have copied it to the new " + f"location at {correct_config_location}. " + "\nTo silence this warning please verify that the configuration file " + f"at {correct_config_location} is correct and delete the file at " + f"{backwards_compat_config_file}." 
) - - global server # noqa: PLW0603 - global apikey # noqa: PLW0603 - - server = cast(str, cls._last_used_server) - apikey = cast(str, cls._last_used_key) - cls._start_last_called = False - - -def _handle_xdg_config_home_backwards_compatibility( - xdg_home: str, -) -> Path: - # NOTE(eddiebergman): A previous bug results in the config - # file being located at `${XDG_CONFIG_HOME}/config` instead - # of `${XDG_CONFIG_HOME}/openml/config`. As to maintain backwards - # compatibility, where users may already may have had a configuration, - # we copy it over an issue a warning until it's deleted. - # As a heurisitic to ensure that it's "our" config file, we try parse it first. - config_dir = Path(xdg_home) / "openml" - - backwards_compat_config_file = Path(xdg_home) / "config" - if not backwards_compat_config_file.exists(): - return config_dir - - # If it errors, that's a good sign it's not ours and we can - # safely ignore it, jumping out of this block. This is a heurisitc - try: - _parse_config(backwards_compat_config_file) - except Exception: # noqa: BLE001 - return config_dir - - # Looks like it's ours, lets try copy it to the correct place - correct_config_location = config_dir / "config" - try: - # We copy and return the new copied location - shutil.copy(backwards_compat_config_file, correct_config_location) - openml_logger.warning( - "An openml configuration file was found at the old location " - f"at {backwards_compat_config_file}. We have copied it to the new " - f"location at {correct_config_location}. " - "\nTo silence this warning please verify that the configuration file " - f"at {correct_config_location} is correct and delete the file at " - f"{backwards_compat_config_file}." - ) - return config_dir - except Exception as e: # noqa: BLE001 - # We failed to copy and its ours, return the old one. 
- openml_logger.warning( - "While attempting to perform a backwards compatible fix, we " - f"failed to copy the openml config file at " - f"{backwards_compat_config_file}' to {correct_config_location}" - f"\n{type(e)}: {e}", - "\n\nTo silence this warning, please copy the file " - "to the new location and delete the old file at " - f"{backwards_compat_config_file}.", - ) - return backwards_compat_config_file - - -def determine_config_file_path() -> Path: - if platform.system().lower() == "linux": - xdg_home = os.environ.get("XDG_CONFIG_HOME") - if xdg_home is not None: - config_dir = _handle_xdg_config_home_backwards_compatibility(xdg_home) + return config_dir + except Exception as e: # noqa: BLE001 + self.openml_logger.warning( + "While attempting to perform a backwards compatible fix, we " + f"failed to copy the openml config file at " + f"{backwards_compat_config_file}' to {correct_config_location}" + f"\n{type(e)}: {e}", + "\n\nTo silence this warning, please copy the file " + "to the new location and delete the old file at " + f"{backwards_compat_config_file}.", + ) + return backwards_compat_config_file + + def determine_config_file_path(self) -> Path: + """Determine the path to the openml configuration file.""" + if platform.system().lower() == "linux": + xdg_home = os.environ.get("XDG_CONFIG_HOME") + if xdg_home is not None: + config_dir = self._handle_xdg_config_home_backwards_compatibility(xdg_home) + else: + config_dir = Path("~", ".config", "openml") else: - config_dir = Path("~", ".config", "openml") - else: - config_dir = Path("~") / ".openml" - - # Still use os.path.expanduser to trigger the mock in the unit test - config_dir = Path(config_dir).expanduser().resolve() - return config_dir / "config" - - -def _setup(config: _Config | None = None) -> None: - """Setup openml package. Called on first import. - - Reads the config file and sets up apikey, server, cache appropriately. 
- key and server can be set by the user simply using - openml.config.apikey = THEIRKEY - openml.config.server = SOMESERVER - We could also make it a property but that's less clear. - """ - global apikey # noqa: PLW0603 - global server # noqa: PLW0603 - global _root_cache_directory # noqa: PLW0603 - global avoid_duplicate_runs # noqa: PLW0603 - global show_progress # noqa: PLW0603 - - config_file = determine_config_file_path() - config_dir = config_file.parent - - # read config file, create directory for config file - try: - if not config_dir.exists(): - config_dir.mkdir(exist_ok=True, parents=True) - except PermissionError: - openml_logger.warning( - f"No permission to create OpenML directory at {config_dir}!" - " This can result in OpenML-Python not working properly." - ) - - if config is None: - config = _parse_config(config_file) - - avoid_duplicate_runs = config["avoid_duplicate_runs"] - apikey = config["apikey"] - server = config["server"] - show_progress = config["show_progress"] - n_retries = int(config["connection_n_retries"]) - - set_retry_policy(config["retry_policy"], n_retries) + config_dir = Path("~") / ".openml" + + config_dir = Path(config_dir).expanduser().resolve() + return config_dir / "config" + + def _parse_config(self, config_file: str | Path) -> dict[str, Any]: + config_file = Path(config_file) + config = configparser.RawConfigParser(defaults=OpenMLConfig().__dict__) # type: ignore + + config_file_ = StringIO() + config_file_.write("[FAKE_SECTION]\n") + try: + with config_file.open("r") as fh: + for line in fh: + config_file_.write(line) + except FileNotFoundError: + self.logger.info( + "No config file found at %s, using default configuration.", config_file + ) + except OSError as e: + self.logger.info("Error opening file %s: %s", config_file, e.args[0]) + config_file_.seek(0) + config.read_file(config_file_) + configuration = dict(config.items("FAKE_SECTION")) + for boolean_field in ["avoid_duplicate_runs", "show_progress"]: + if 
isinstance(config["FAKE_SECTION"][boolean_field], str): + configuration[boolean_field] = config["FAKE_SECTION"].getboolean(boolean_field) # type: ignore + return configuration # type: ignore + + def start_using_configuration_for_example(self) -> None: + """Sets the configuration to connect to the test server with valid apikey.""" + return self._examples.start_using_configuration_for_example() + + def stop_using_configuration_for_example(self) -> None: + """Store the configuration as it was before `start_use_example_configuration`.""" + return self._examples.stop_using_configuration_for_example() + + def _setup(self, config: dict[str, Any] | None = None) -> None: + config_file = self.determine_config_file_path() + config_dir = config_file.parent + + try: + if not config_dir.exists(): + config_dir.mkdir(exist_ok=True, parents=True) + except PermissionError: + self.openml_logger.warning( + f"No permission to create OpenML directory at {config_dir}!" + " This can result in OpenML-Python not working properly." + ) - user_defined_cache_dir = os.environ.get(OPENML_CACHE_DIR_ENV_VAR) - if user_defined_cache_dir is not None: - short_cache_dir = Path(user_defined_cache_dir) - else: - short_cache_dir = Path(config["cachedir"]) - _root_cache_directory = short_cache_dir.expanduser().resolve() - - try: - cache_exists = _root_cache_directory.exists() - # create the cache subdirectory - if not cache_exists: - _root_cache_directory.mkdir(exist_ok=True, parents=True) - _create_log_handlers() - except PermissionError: - openml_logger.warning( - f"No permission to create OpenML directory at {_root_cache_directory}!" - " This can result in OpenML-Python not working properly." 
+ if config is None: + config = self._parse_config(config_file) + + self._config = replace( + self._config, + apikey=config["apikey"], + server=config["server"], + show_progress=config["show_progress"], + avoid_duplicate_runs=config["avoid_duplicate_runs"], + retry_policy=config["retry_policy"], + connection_n_retries=int(config["connection_n_retries"]), ) - _create_log_handlers(create_file_handler=False) - - -def set_field_in_config_file(field: str, value: Any) -> None: - """Overwrites the `field` in the configuration file with the new `value`.""" - if field not in _defaults: - raise ValueError(f"Field '{field}' is not valid and must be one of '{_defaults.keys()}'.") - - # TODO(eddiebergman): This use of globals has gone too far - globals()[field] = value - config_file = determine_config_file_path() - config = _parse_config(config_file) - with config_file.open("w") as fh: - for f in _defaults: - # We can't blindly set all values based on globals() because when the user - # sets it through config.FIELD it should not be stored to file. - # There doesn't seem to be a way to avoid writing defaults to file with configparser, - # because it is impossible to distinguish from an explicitly set value that matches - # the default value, to one that was set to its default because it was omitted. - value = globals()[f] if f == field else config.get(f) # type: ignore - if value is not None: - fh.write(f"{f} = {value}\n") - - -def _parse_config(config_file: str | Path) -> _Config: - """Parse the config file, set up defaults.""" - config_file = Path(config_file) - config = configparser.RawConfigParser(defaults=_defaults) # type: ignore - - # The ConfigParser requires a [SECTION_HEADER], which we do not expect in our config file. 
- # Cheat the ConfigParser module by adding a fake section header - config_file_ = StringIO() - config_file_.write("[FAKE_SECTION]\n") - try: - with config_file.open("r") as fh: - for line in fh: - config_file_.write(line) - except FileNotFoundError: - logger.info("No config file found at %s, using default configuration.", config_file) - except OSError as e: - logger.info("Error opening file %s: %s", config_file, e.args[0]) - config_file_.seek(0) - config.read_file(config_file_) - configuration = dict(config.items("FAKE_SECTION")) - for boolean_field in ["avoid_duplicate_runs", "show_progress"]: - if isinstance(config["FAKE_SECTION"][boolean_field], str): - configuration[boolean_field] = config["FAKE_SECTION"].getboolean(boolean_field) # type: ignore - return configuration # type: ignore - - -def get_config_as_dict() -> _Config: - return { - "apikey": apikey, - "server": server, - "cachedir": _root_cache_directory, - "avoid_duplicate_runs": avoid_duplicate_runs, - "connection_n_retries": connection_n_retries, - "retry_policy": retry_policy, - "show_progress": show_progress, - } - - -# NOTE: For backwards compatibility, we keep the `str` -def get_cache_directory() -> str: - """Get the current cache directory. - - This gets the cache directory for the current server relative - to the root cache directory that can be set via - ``set_root_cache_directory()``. The cache directory is the - ``root_cache_directory`` with additional information on which - subdirectory to use based on the server name. By default it is - ``root_cache_directory / org / openml / www`` for the standard - OpenML.org server and is defined as - ``root_cache_directory / top-level domain / second-level domain / - hostname`` - ``` - - Returns - ------- - cachedir : string - The current cache directory. 
- - """ - url_suffix = urlparse(server).netloc - reversed_url_suffix = os.sep.join(url_suffix.split(".")[::-1]) # noqa: PTH118 - return os.path.join(_root_cache_directory, reversed_url_suffix) # noqa: PTH118 + self.set_retry_policy(config["retry_policy"], self._config.connection_n_retries) -def set_root_cache_directory(root_cache_directory: str | Path) -> None: - """Set module-wide base cache directory. - - Sets the root cache directory, wherin the cache directories are - created to store content from different OpenML servers. For example, - by default, cached data for the standard OpenML.org server is stored - at ``root_cache_directory / org / openml / www``, and the general - pattern is ``root_cache_directory / top-level domain / second-level - domain / hostname``. - - Parameters - ---------- - root_cache_directory : string - Path to use as cache directory. - - See Also - -------- - get_cache_directory - """ - global _root_cache_directory # noqa: PLW0603 - _root_cache_directory = Path(root_cache_directory) - - -start_using_configuration_for_example = ( - ConfigurationForExamples.start_using_configuration_for_example -) -stop_using_configuration_for_example = ConfigurationForExamples.stop_using_configuration_for_example - + user_defined_cache_dir = os.environ.get(self.OPENML_CACHE_DIR_ENV_VAR) + if user_defined_cache_dir is not None: + short_cache_dir = Path(user_defined_cache_dir) + else: + short_cache_dir = Path(config["cachedir"]) + + self._root_cache_directory = short_cache_dir.expanduser().resolve() + self._config = replace(self._config, cachedir=self._root_cache_directory) + + try: + cache_exists = self._root_cache_directory.exists() + if not cache_exists: + self._root_cache_directory.mkdir(exist_ok=True, parents=True) + self._create_log_handlers() + except PermissionError: + self.openml_logger.warning( + f"No permission to create OpenML directory at {self._root_cache_directory}!" + " This can result in OpenML-Python not working properly." 
+ ) + self._create_log_handlers(create_file_handler=False) + + def set_field_in_config_file(self, field: str, value: Any) -> None: + """Set a field in the configuration file.""" + if not hasattr(OpenMLConfig(), field): + raise ValueError( + f"Field '{field}' is not valid and must be one of " + f"'{OpenMLConfig().__dict__.keys()}'." + ) -@contextmanager -def overwrite_config_context(config: dict[str, Any]) -> Iterator[_Config]: - """A context manager to temporarily override variables in the configuration.""" - existing_config = get_config_as_dict() - merged_config = {**existing_config, **config} + self._config = replace(self._config, **{field: value}) + config_file = self.determine_config_file_path() + existing = self._parse_config(config_file) + with config_file.open("w") as fh: + for f in OpenMLConfig().__dict__: + v = value if f == field else existing.get(f) + if v is not None: + fh.write(f"{f} = {v}\n") + + def get_config_as_dict(self) -> dict[str, Any]: + """Get the current configuration as a dictionary.""" + return self._config.__dict__.copy() + + def get_cache_directory(self) -> str: + """Get the cache directory for the current server.""" + url_suffix = urlparse(self._config.server).netloc + reversed_url_suffix = os.sep.join(url_suffix.split(".")[::-1]) # noqa: PTH118 + return os.path.join(self._root_cache_directory, reversed_url_suffix) # noqa: PTH118 + + def set_root_cache_directory(self, root_cache_directory: str | Path) -> None: + """Set the root cache directory.""" + self._root_cache_directory = Path(root_cache_directory) + self._config = replace(self._config, cachedir=self._root_cache_directory) + + @contextmanager + def overwrite_config_context(self, config: dict[str, Any]) -> Iterator[dict[str, Any]]: + """Overwrite the current configuration within a context manager.""" + existing_config = self.get_config_as_dict() + merged_config = {**existing_config, **config} + + self._setup(merged_config) + yield merged_config + self._setup(existing_config) + + 
 class ConfigurationForExamples:
+    """Allows easy switching to and from a test configuration, used for examples."""
+
+    _last_used_server = None
+    _last_used_key = None
+    _start_last_called = False
+
+    def __init__(self, manager: OpenMLConfigManager):
+        self._manager = manager
+        self._test_apikey = manager._TEST_SERVER_NORMAL_USER_KEY
+        self._test_server = "https://test.openml.org/api/v1/xml"
+
+    def start_using_configuration_for_example(self) -> None:
+        """Sets the configuration to connect to the test server with a valid apikey.
+
+        The configuration as it was before this call is stored, and can be recovered
+        by using the `stop_using_configuration_for_example` method.
+        """
+        if (
+            self._start_last_called
+            and self._manager._config.server == self._test_server
+            and self._manager._config.apikey == self._test_apikey
+        ):
+            # Method is called more than once in a row without modifying the server or apikey.
+            # We don't want to save the current test configuration as a last used configuration.
+            return
+
+        self._last_used_server = self._manager._config.server
+        self._last_used_key = self._manager._config.apikey
+        self._start_last_called = True
+
+        # Test server key for examples
+        self._manager._config = replace(
+            self._manager._config,
+            server=self._test_server,
+            apikey=self._test_apikey,
+        )
+        warnings.warn(
+            f"Switching to the test server {self._test_server} to not upload results to "
+            "the live server. Using the test server may result in reduced performance of the "
+            "API!",
+            stacklevel=2,
+        )
-    _setup(merged_config)  # type: ignore
-    yield merged_config  # type: ignore
+    def stop_using_configuration_for_example(self) -> None:
+        """Return to the configuration as it was before `start_using_configuration_for_example`."""
+        if not self._start_last_called:
+            # We don't want to allow this because it will (likely) result in the `server` and
+            # `apikey` variables being set to None.
+            raise RuntimeError(
+                "`stop_use_example_configuration` called without a saved config."
+ "`start_use_example_configuration` must be called first.", + ) + + self._manager._config = replace( + self._manager._config, + server=cast("str", self._last_used_server), + apikey=cast("str", self._last_used_key), + ) + self._start_last_called = False - _setup(existing_config) +_config = OpenMLConfigManager() -__all__ = [ - "get_cache_directory", - "set_root_cache_directory", - "start_using_configuration_for_example", - "stop_using_configuration_for_example", - "get_config_as_dict", -] -_setup() +def __getattr__(name: str) -> Any: + return getattr(_config, name) diff --git a/openml/runs/functions.py b/openml/runs/functions.py index 666b75c37..573d91576 100644 --- a/openml/runs/functions.py +++ b/openml/runs/functions.py @@ -18,7 +18,6 @@ import openml import openml._api_calls import openml.utils -from openml import config from openml.exceptions import ( OpenMLCacheException, OpenMLRunsExistError, @@ -107,7 +106,7 @@ def run_model_on_task( # noqa: PLR0913 """ if avoid_duplicate_runs is None: avoid_duplicate_runs = openml.config.avoid_duplicate_runs - if avoid_duplicate_runs and not config.apikey: + if avoid_duplicate_runs and not openml.config.apikey: warnings.warn( "avoid_duplicate_runs is set to True, but no API key is set. " "Please set your API key in the OpenML configuration file, see" @@ -336,7 +335,7 @@ def run_flow_on_task( # noqa: C901, PLR0912, PLR0915, PLR0913 message = f"Executed Task {task.task_id} with Flow id:{run.flow_id}" else: message = f"Executed Task {task.task_id} on local Flow with name {flow.name}." - config.logger.info(message) + openml.config.logger.info(message) return run @@ -528,7 +527,7 @@ def _run_task_get_arffcontent( # noqa: PLR0915, PLR0912, C901 # The forked child process may not copy the configuration state of OpenML from the parent. # Current configuration setup needs to be copied and passed to the child processes. 
- _config = config.get_config_as_dict() + _config = openml.config.get_config_as_dict() # Execute runs in parallel # assuming the same number of tasks as workers (n_jobs), the total compute time for this # statement will be similar to the slowest run @@ -733,7 +732,7 @@ def _run_task_get_arffcontent_parallel_helper( # noqa: PLR0913 """ # Sets up the OpenML instantiated in the child process to match that of the parent's # if configuration=None, loads the default - config._setup(configuration) + openml.config._setup(configuration) train_indices, test_indices = task.get_train_test_split_indices( repeat=rep_no, @@ -757,7 +756,7 @@ def _run_task_get_arffcontent_parallel_helper( # noqa: PLR0913 else: raise NotImplementedError(task.task_type) - config.logger.info( + openml.config.logger.info( f"Going to run model {model!s} on " f"dataset {openml.datasets.get_dataset(task.dataset_id).name} " f"for repeat {rep_no} fold {fold_no} sample {sample_no}" diff --git a/openml/setups/functions.py b/openml/setups/functions.py index 374911901..90dd73c06 100644 --- a/openml/setups/functions.py +++ b/openml/setups/functions.py @@ -14,7 +14,6 @@ import openml import openml.exceptions import openml.utils -from openml import config from openml.flows import OpenMLFlow, flow_exists from .setup import OpenMLParameter, OpenMLSetup @@ -84,7 +83,7 @@ def _get_cached_setup(setup_id: int) -> OpenMLSetup: OpenMLCacheException If the setup file for the given setup ID is not cached. 
""" - cache_dir = Path(config.get_cache_directory()) + cache_dir = Path(openml.config.get_cache_directory()) setup_cache_dir = cache_dir / "setups" / str(setup_id) try: setup_file = setup_cache_dir / "description.xml" @@ -112,7 +111,7 @@ def get_setup(setup_id: int) -> OpenMLSetup: ------- OpenMLSetup (an initialized openml setup object) """ - setup_dir = Path(config.get_cache_directory()) / "setups" / str(setup_id) + setup_dir = Path(openml.config.get_cache_directory()) / "setups" / str(setup_id) setup_dir.mkdir(exist_ok=True, parents=True) setup_file = setup_dir / "description.xml" diff --git a/openml/tasks/task.py b/openml/tasks/task.py index 395b52482..304bab544 100644 --- a/openml/tasks/task.py +++ b/openml/tasks/task.py @@ -10,8 +10,8 @@ from typing import TYPE_CHECKING, Any, Sequence from typing_extensions import TypedDict +import openml import openml._api_calls -import openml.config from openml import datasets from openml.base import OpenMLBase from openml.utils import _create_cache_directory_for_id diff --git a/openml/utils.py b/openml/utils.py index 7e72e7aee..f4a78fa44 100644 --- a/openml/utils.py +++ b/openml/utils.py @@ -18,8 +18,6 @@ import openml._api_calls import openml.exceptions -from . import config - # Avoid import cycles: https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles if TYPE_CHECKING: from openml.base import OpenMLBase @@ -328,7 +326,7 @@ def _list_all( # noqa: C901 def _get_cache_dir_for_key(key: str) -> Path: - return Path(config.get_cache_directory()) / key + return Path(openml.config.get_cache_directory()) / key def _create_cache_directory(key: str) -> Path: @@ -428,7 +426,7 @@ def safe_func(*args: P.args, **kwargs: P.kwargs) -> R: def _create_lockfiles_dir() -> Path: - path = Path(config.get_cache_directory()) / "locks" + path = Path(openml.config.get_cache_directory()) / "locks" # TODO(eddiebergman): Not sure why this is allowed to error and ignore??? 
with contextlib.suppress(OSError): path.mkdir(exist_ok=True, parents=True) diff --git a/tests/test_openml/test_config.py b/tests/test_openml/test_config.py index 7ef223504..282838414 100644 --- a/tests/test_openml/test_config.py +++ b/tests/test_openml/test_config.py @@ -46,7 +46,7 @@ class TestConfig(openml.testing.TestBase): def test_non_writable_home(self, log_handler_mock, warnings_mock): with tempfile.TemporaryDirectory(dir=self.workdir) as td: os.chmod(td, 0o444) - _dd = copy(openml.config._defaults) + _dd = copy(openml.config.OpenMLConfig().__dict__) _dd["cachedir"] = Path(td) / "something-else" openml.config._setup(_dd) @@ -127,7 +127,6 @@ def test_switch_from_example_configuration(self): openml.config.start_using_configuration_for_example() openml.config.stop_using_configuration_for_example() - assert openml.config.apikey == TestBase.user_key assert openml.config.server == self.production_server