Skip to content

Commit b10eb12

Browse files
committed
Move utility functions to a separate dispatchers file
1 parent decd5e4 commit b10eb12

File tree

3 files changed

+115
-103
lines changed

3 files changed

+115
-103
lines changed

openml/__init__.py

Lines changed: 3 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@
1818
# License: BSD 3-Clause
1919
from __future__ import annotations
2020

21-
from typing import Any, Callable, Dict
22-
2321
from . import (
2422
_api_calls,
2523
config,
2624
datasets,
25+
dispatchers,
2726
evaluations,
2827
exceptions,
2928
extensions,
@@ -36,6 +35,7 @@
3635
)
3736
from .__version__ import __version__
3837
from .datasets import OpenMLDataFeature, OpenMLDataset
38+
from .dispatchers import get, list_all
3939
from .evaluations import OpenMLEvaluation
4040
from .flows import OpenMLFlow
4141
from .runs import OpenMLRun
@@ -51,102 +51,6 @@
5151
OpenMLTask,
5252
)
5353

54-
ListDispatcher = Dict[str, Callable[..., Any]]
55-
GetDispatcher = Dict[str, Callable[..., Any]]
56-
57-
_LIST_DISPATCH: ListDispatcher = {
58-
"dataset": datasets.functions.list_datasets,
59-
"task": tasks.functions.list_tasks,
60-
"flow": flows.functions.list_flows,
61-
"run": runs.functions.list_runs,
62-
}
63-
64-
_GET_DISPATCH: GetDispatcher = {
65-
"dataset": datasets.functions.get_dataset,
66-
"task": tasks.functions.get_task,
67-
"flow": flows.functions.get_flow,
68-
"run": runs.functions.get_run,
69-
}
70-
71-
72-
def list_all(object_type: str, /, **kwargs: Any) -> Any:
73-
"""List OpenML objects by type (e.g., datasets, tasks, flows, runs).
74-
75-
This is a convenience dispatcher that forwards to the existing type-specific
76-
``list_*`` functions. Existing imports remain available for backward compatibility.
77-
78-
Parameters
79-
----------
80-
object_type : str
81-
The type of object to list. Must be one of 'dataset', 'task', 'flow', 'run'.
82-
**kwargs : Any
83-
Additional arguments passed to the underlying list function.
84-
85-
Returns
86-
-------
87-
Any
88-
The result from the type-specific list function (typically a DataFrame).
89-
90-
Raises
91-
------
92-
ValueError
93-
If object_type is not one of the supported types.
94-
"""
95-
if not isinstance(object_type, str):
96-
raise TypeError(f"object_type must be a string, got {type(object_type).__name__}")
97-
98-
func = _LIST_DISPATCH.get(object_type.lower())
99-
if func is None:
100-
valid_types = ", ".join(repr(k) for k in _LIST_DISPATCH)
101-
raise ValueError(
102-
f"Unsupported object_type {object_type!r}; expected one of {valid_types}.",
103-
)
104-
105-
return func(**kwargs)
106-
107-
108-
def get(identifier: int | str, *, object_type: str = "dataset", **kwargs: Any) -> Any:
109-
"""Get an OpenML object by identifier.
110-
111-
Parameters
112-
----------
113-
identifier : int | str
114-
The ID or name of the object to retrieve.
115-
object_type : str, default="dataset"
116-
The type of object to get. Must be one of 'dataset', 'task', 'flow', 'run'.
117-
**kwargs : Any
118-
Additional arguments passed to the underlying get function.
119-
120-
Returns
121-
-------
122-
Any
123-
The requested OpenML object.
124-
125-
Raises
126-
------
127-
ValueError
128-
If object_type is not one of the supported types.
129-
130-
Examples
131-
--------
132-
>>> openml.get(61) # Get dataset 61 (default object_type="dataset")
133-
>>> openml.get("Fashion-MNIST") # Get dataset by name
134-
>>> openml.get(31, object_type="task") # Get task 31
135-
>>> openml.get(10, object_type="flow") # Get flow 10
136-
>>> openml.get(20, object_type="run") # Get run 20
137-
"""
138-
if not isinstance(object_type, str):
139-
raise TypeError(f"object_type must be a string, got {type(object_type).__name__}")
140-
141-
func = _GET_DISPATCH.get(object_type.lower())
142-
if func is None:
143-
valid_types = ", ".join(repr(k) for k in _GET_DISPATCH)
144-
raise ValueError(
145-
f"Unsupported object_type {object_type!r}; expected one of {valid_types}.",
146-
)
147-
148-
return func(identifier, **kwargs)
149-
15054

15155
def populate_cache(
15256
task_ids: list[int] | None = None,
@@ -206,6 +110,7 @@ def populate_cache(
206110
"OpenMLStudy",
207111
"OpenMLBenchmarkSuite",
208112
"datasets",
113+
"dispatchers",
209114
"evaluations",
210115
"exceptions",
211116
"extensions",

openml/dispatchers.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""OpenML API dispatchers for unified get/list operations."""
2+
3+
# License: BSD 3-Clause
4+
from __future__ import annotations
5+
6+
from typing import Any, Callable, Dict
7+
8+
from .datasets import get_dataset, list_datasets
9+
from .flows import get_flow, list_flows
10+
from .runs import get_run, list_runs
11+
from .tasks import get_task, list_tasks
12+
13+
# Aliases for the two lookup tables below. Both map an object-type name to the
# callable that handles that type; they are kept as separate names so the list
# and get tables can be annotated independently.
ListDispatcher = Dict[str, Callable[..., Any]]
GetDispatcher = Dict[str, Callable[..., Any]]

# Object-type name -> ``list_*`` function, consulted by ``list_all``.
# Keys are lowercase because callers' input is lower-cased before lookup, and
# the insertion order is the order shown in the "expected one of" error message.
_LIST_DISPATCH: ListDispatcher = {
    "dataset": list_datasets,
    "task": list_tasks,
    "flow": list_flows,
    "run": list_runs,
}

# Object-type name -> ``get_*`` function, consulted by ``get``.
# Same key set and ordering conventions as ``_LIST_DISPATCH``.
_GET_DISPATCH: GetDispatcher = {
    "dataset": get_dataset,
    "task": get_task,
    "flow": get_flow,
    "run": get_run,
}
29+
30+
31+
def _resolve_dispatcher(
    table: Dict[str, Callable[..., Any]],
    object_type: str,
) -> Callable[..., Any]:
    """Validate ``object_type`` and return the matching callable from ``table``.

    Shared by ``list_all`` and ``get`` so both report bad input identically.

    Raises
    ------
    TypeError
        If ``object_type`` is not a string.
    ValueError
        If the lower-cased ``object_type`` is not a key of ``table``.
    """
    # Checked explicitly so a non-string produces a clear message instead of an
    # AttributeError from the ``.lower()`` call below.
    if not isinstance(object_type, str):
        raise TypeError(f"object_type must be a string, got {type(object_type).__name__}")

    func = table.get(object_type.lower())
    if func is None:
        valid_types = ", ".join(repr(k) for k in table)
        raise ValueError(
            f"Unsupported object_type {object_type!r}; expected one of {valid_types}.",
        )
    return func


def list_all(object_type: str, /, **kwargs: Any) -> Any:
    """List OpenML objects of one type.

    This is a convenience dispatcher that forwards to the existing type-specific
    ``list_*`` functions. Existing imports remain available for backward compatibility.

    Parameters
    ----------
    object_type : str
        The type of object to list (positional-only). Must be one of
        'dataset', 'task', 'flow', 'run'; matching is case-insensitive.
    **kwargs : Any
        Additional arguments passed unchanged to the underlying list function.

    Returns
    -------
    Any
        The result from the type-specific list function (typically a DataFrame).

    Raises
    ------
    TypeError
        If object_type is not a string.
    ValueError
        If object_type is not one of the supported types.
    """
    return _resolve_dispatcher(_LIST_DISPATCH, object_type)(**kwargs)


def get(identifier: int | str, *, object_type: str = "dataset", **kwargs: Any) -> Any:
    """Get an OpenML object by identifier.

    Parameters
    ----------
    identifier : int | str
        The ID or name of the object to retrieve. It is passed as the first
        positional argument to the underlying get function.
    object_type : str, default="dataset"
        The type of object to get (keyword-only). Must be one of
        'dataset', 'task', 'flow', 'run'; matching is case-insensitive.
    **kwargs : Any
        Additional arguments passed unchanged to the underlying get function.

    Returns
    -------
    Any
        The requested OpenML object.

    Raises
    ------
    TypeError
        If object_type is not a string.
    ValueError
        If object_type is not one of the supported types.

    Examples
    --------
    >>> openml.get(61)  # Get dataset 61 (default object_type="dataset")
    >>> openml.get("Fashion-MNIST")  # Get dataset by name
    >>> openml.get(31, object_type="task")  # Get task 31
    >>> openml.get(10, object_type="flow")  # Get flow 10
    >>> openml.get(20, object_type="run")  # Get run 20
    """
    return _resolve_dispatcher(_GET_DISPATCH, object_type)(identifier, **kwargs)

tests/test_openml/test_openml.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_populate_cache(
4747
def test_list_dispatch(self, list_datasets_mock, list_tasks_mock):
4848
# Need to patch after import, so update dispatch dict
4949
with mock.patch.dict(
50-
"openml._LIST_DISPATCH",
50+
"openml.dispatchers._LIST_DISPATCH",
5151
{
5252
"dataset": list_datasets_mock,
5353
"task": list_tasks_mock,
@@ -64,20 +64,20 @@ def test_list_dispatch(self, list_datasets_mock, list_tasks_mock):
6464
def test_get_dispatch(self, get_dataset_mock, get_task_mock):
6565
# Need to patch after import, so update dispatch dict
6666
with mock.patch.dict(
67-
"openml._GET_DISPATCH",
67+
"openml.dispatchers._GET_DISPATCH",
6868
{
6969
"dataset": get_dataset_mock,
7070
"task": get_task_mock,
7171
},
7272
):
73-
openml.get(61)
73+
openml.get(61)
7474
get_dataset_mock.assert_called_with(61)
7575

76-
openml.get("Fashion-MNIST", version=2)
76+
openml.get("Fashion-MNIST", version=2)
7777
get_dataset_mock.assert_called_with("Fashion-MNIST", version=2)
7878

7979
openml.get("Fashion-MNIST")
8080
get_dataset_mock.assert_called_with("Fashion-MNIST")
8181

82-
openml.get(31, object_type="task")
82+
openml.get(31, object_type="task")
8383
get_task_mock.assert_called_with(31)

0 commit comments

Comments
 (0)