|
18 | 18 | # License: BSD 3-Clause |
19 | 19 | from __future__ import annotations |
20 | 20 |
|
21 | | -from typing import Any, Callable, Dict |
22 | | - |
23 | 21 | from . import ( |
24 | 22 | _api_calls, |
25 | 23 | config, |
26 | 24 | datasets, |
| 25 | + dispatchers, |
27 | 26 | evaluations, |
28 | 27 | exceptions, |
29 | 28 | extensions, |
|
36 | 35 | ) |
37 | 36 | from .__version__ import __version__ |
38 | 37 | from .datasets import OpenMLDataFeature, OpenMLDataset |
| 38 | +from .dispatchers import get, list_all |
39 | 39 | from .evaluations import OpenMLEvaluation |
40 | 40 | from .flows import OpenMLFlow |
41 | 41 | from .runs import OpenMLRun |
|
51 | 51 | OpenMLTask, |
52 | 52 | ) |
53 | 53 |
|
# Both dispatch tables share one shape: object-type name -> type-specific function.
ListDispatcher = GetDispatcher = Dict[str, Callable[..., Any]]

# Lookup table used by ``list_all``. Keys are the accepted ``object_type`` names;
# their insertion order is the order in which they are reported in error messages.
_LIST_DISPATCH: ListDispatcher = {
    "dataset": datasets.functions.list_datasets,
    "task": tasks.functions.list_tasks,
    "flow": flows.functions.list_flows,
    "run": runs.functions.list_runs,
}

# Lookup table used by ``get``; same keys and ordering as the list table.
_GET_DISPATCH: GetDispatcher = {
    "dataset": datasets.functions.get_dataset,
    "task": tasks.functions.get_task,
    "flow": flows.functions.get_flow,
    "run": runs.functions.get_run,
}
70 | | - |
71 | | - |
def list_all(object_type: str, /, **kwargs: Any) -> Any:
    """List OpenML objects of a single type.

    This is a convenience dispatcher that forwards to the existing type-specific
    ``list_*`` functions. Existing imports remain available for backward compatibility.

    Parameters
    ----------
    object_type : str
        The type of object to list (positional-only). Matched case-insensitively
        against the singular names 'dataset', 'task', 'flow', 'run'.
    **kwargs : Any
        Additional arguments passed unchanged to the underlying list function.

    Returns
    -------
    Any
        The result from the type-specific list function (typically a DataFrame).

    Raises
    ------
    TypeError
        If object_type is not a string.
    ValueError
        If object_type is not one of the supported types.
    """
    # Reject non-strings up front so the caller gets a clear TypeError instead of
    # an AttributeError from the ``.lower()`` call below.
    if not isinstance(object_type, str):
        raise TypeError(f"object_type must be a string, got {type(object_type).__name__}")

    # Lower-case so that e.g. "Dataset" resolves to the same entry as "dataset".
    func = _LIST_DISPATCH.get(object_type.lower())
    if func is None:
        valid_types = ", ".join(repr(k) for k in _LIST_DISPATCH)
        raise ValueError(
            f"Unsupported object_type {object_type!r}; expected one of {valid_types}.",
        )

    return func(**kwargs)
106 | | - |
107 | | - |
108 | | -def get(identifier: int | str, *, object_type: str = "dataset", **kwargs: Any) -> Any: |
109 | | - """Get an OpenML object by identifier. |
110 | | -
|
111 | | - Parameters |
112 | | - ---------- |
113 | | - identifier : int | str |
114 | | - The ID or name of the object to retrieve. |
115 | | - object_type : str, default="dataset" |
116 | | - The type of object to get. Must be one of 'dataset', 'task', 'flow', 'run'. |
117 | | - **kwargs : Any |
118 | | - Additional arguments passed to the underlying get function. |
119 | | -
|
120 | | - Returns |
121 | | - ------- |
122 | | - Any |
123 | | - The requested OpenML object. |
124 | | -
|
125 | | - Raises |
126 | | - ------ |
127 | | - ValueError |
128 | | - If object_type is not one of the supported types. |
129 | | -
|
130 | | - Examples |
131 | | - -------- |
132 | | - >>> openml.get(61) # Get dataset 61 (default object_type="dataset") |
133 | | - >>> openml.get("Fashion-MNIST") # Get dataset by name |
134 | | - >>> openml.get(31, object_type="task") # Get task 31 |
135 | | - >>> openml.get(10, object_type="flow") # Get flow 10 |
136 | | - >>> openml.get(20, object_type="run") # Get run 20 |
137 | | - """ |
138 | | - if not isinstance(object_type, str): |
139 | | - raise TypeError(f"object_type must be a string, got {type(object_type).__name__}") |
140 | | - |
141 | | - func = _GET_DISPATCH.get(object_type.lower()) |
142 | | - if func is None: |
143 | | - valid_types = ", ".join(repr(k) for k in _GET_DISPATCH) |
144 | | - raise ValueError( |
145 | | - f"Unsupported object_type {object_type!r}; expected one of {valid_types}.", |
146 | | - ) |
147 | | - |
148 | | - return func(identifier, **kwargs) |
149 | | - |
150 | 54 |
|
151 | 55 | def populate_cache( |
152 | 56 | task_ids: list[int] | None = None, |
@@ -206,6 +110,7 @@ def populate_cache( |
206 | 110 | "OpenMLStudy", |
207 | 111 | "OpenMLBenchmarkSuite", |
208 | 112 | "datasets", |
| 113 | + "dispatchers", |
209 | 114 | "evaluations", |
210 | 115 | "exceptions", |
211 | 116 | "extensions", |
|
0 commit comments