Skip to content

Commit d78bf5d

Browse files
committed
Update _lookup.py
1 parent a669134 commit d78bf5d

File tree

1 file changed

+60
-35
lines changed

1 file changed

+60
-35
lines changed

pytorch_forecasting/_registry/_lookup.py

Lines changed: 60 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -40,44 +40,64 @@ def all_objects(
4040
----------
4141
object_types: str, list of str, optional (default=None)
4242
Which kind of objects should be returned.
43-
if None, no filter is applied and all objects are returned.
44-
if str or list of str, strings define scitypes specified in search
45-
only objects that are of (at least) one of the scitypes are returned
46-
possible str values are entries of registry.BASE_CLASS_REGISTER (first col)
47-
for instance 'regrssor_proba', 'distribution, 'metric'
4843
49-
return_names: bool, optional (default=True)
44+
* if None, no filter is applied and all objects are returned.
45+
* if str or list of str, strings define scitypes specified in search
46+
only objects that are of (at least) one of the scitypes are returned
5047
51-
if True, estimator class name is included in the ``all_objects``
52-
return in the order: name, estimator class, optional tags, either as
53-
a tuple or as pandas.DataFrame columns
48+
return_names: bool, optional (default=True)
5449
55-
if False, estimator class name is removed from the ``all_objects`` return.
50+
* if True, estimator class name is included in the ``all_objects``
51+
return in the order: name, estimator class, optional tags, either as
52+
a tuple or as pandas.DataFrame columns
53+
* if False, estimator class name is removed from the ``all_objects`` return.
5654
57-
filter_tags: dict of (str or list of str), optional (default=None)
55+
filter_tags: dict of (str or list of str or re.Pattern), optional (default=None)
5856
For a list of valid tag strings, use the registry.all_tags utility.
5957
60-
``filter_tags`` subsets the returned estimators as follows:
58+
``filter_tags`` subsets the returned objects as follows:
6159
6260
* each key/value pair is statement in "and"/conjunction
6361
* key is tag name to sub-set on
6462
* value str or list of string are tag values
6563
* condition is "key must be equal to value, or in set(value)"
6664
67-
exclude_estimators: str, list of str, optional (default=None)
68-
Names of estimators to exclude.
65+
In detail, he return will be filtered to keep exactly the classes
66+
where tags satisfy all the filter conditions specified by ``filter_tags``.
67+
Filter conditions are as follows, for ``tag_name: search_value`` pairs in
68+
the ``filter_tags`` dict, applied to a class ``klass``:
69+
70+
- If ``klass`` does not have a tag with name ``tag_name``, it is excluded.
71+
Otherwise, let ``tag_value`` be the value of the tag with name ``tag_name``.
72+
- If ``search_value`` is a string, and ``tag_value`` is a string,
73+
the filter condition is that ``search_value`` must match the tag value.
74+
- If ``search_value`` is a string, and ``tag_value`` is a list,
75+
the filter condition is that ``search_value`` is contained in ``tag_value``.
76+
- If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a string,
77+
the filter condition is that ``search_value.fullmatch(tag_value)``
78+
is true, i.e., the regex matches the tag value.
79+
- If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a list,
80+
the filter condition is that at least one element of ``tag_value``
81+
matches the regex.
82+
- If ``search_value`` is iterable, then the filter condition is that
83+
at least one element of ``search_value`` satisfies the above conditions,
84+
applied to ``tag_value``.
85+
86+
Note: ``re.Pattern`` is supported only from ``scikit-base`` version 0.8.0.
87+
88+
exclude_objects: str, list of str, optional (default=None)
89+
Names of objects to exclude.
6990
7091
as_dataframe: bool, optional (default=False)
7192
72-
True: ``all_objects`` will return a pandas.DataFrame with named
73-
columns for all of the attributes being returned.
74-
75-
False: ``all_objects`` will return a list (either a list of
76-
estimators or a list of tuples, see Returns)
93+
* True: ``all_objects`` will return a ``pandas.DataFrame`` with named
94+
columns for all of the attributes being returned.
95+
* False: ``all_objects`` will return a list (either a list of
96+
objects or a list of tuples, see Returns)
7797
7898
return_tags: str or list of str, optional (default=None)
7999
Names of tags to fetch and return each estimator's value of.
80-
For a list of valid tag strings, use the registry.all_tags utility.
100+
For a list of valid tag strings, use the ``registry.all_tags`` utility.
81101
if str or list of str,
82102
the tag values named in return_tags will be fetched for each
83103
estimator and will be appended as either columns or tuple entries.
@@ -88,27 +108,32 @@ def all_objects(
88108
Returns
89109
-------
90110
all_objects will return one of the following:
91-
1. list of objects, if return_names=False, and return_tags is None
92-
2. list of tuples (optional object name, class, ~optional object
93-
tags), if return_names=True or return_tags is not None.
94-
3. pandas.DataFrame if as_dataframe = True
111+
112+
1. list of objects, if ``return_names=False``, and ``return_tags`` is None
113+
114+
2. list of tuples (optional estimator name, class, optional estimator
115+
tags), if ``return_names=True`` or ``return_tags`` is not ``None``.
116+
117+
3. ``pandas.DataFrame`` if ``as_dataframe = True``
118+
95119
if list of objects:
96120
entries are objects matching the query,
97-
in alphabetical order of object name
121+
in alphabetical order of estimator name
122+
98123
if list of tuples:
99-
list of (optional object name, object, optional object
100-
tags) matching the query, in alphabetical order of object name,
124+
list of (optional estimator name, estimator, optional estimator
125+
tags) matching the query, in alphabetical order of estimator name,
101126
where
102-
``name`` is the object name as string, and is an
103-
optional return
104-
``object`` is the actual object
105-
``tags`` are the object's values for each tag in return_tags
106-
and is an optional return.
107-
if dataframe:
108-
all_objects will return a pandas.DataFrame.
127+
``name`` is the estimator name as string, and is an
128+
optional return
129+
``estimator`` is the actual estimator
130+
``tags`` are the estimator's values for each tag in return_tags
131+
and is an optional return.
132+
133+
if ``DataFrame``:
109134
column names represent the attributes contained in each column.
110135
"objects" will be the name of the column of objects, "names"
111-
will be the name of the column of object class names and the string(s)
136+
will be the name of the column of estimator class names and the string(s)
112137
passed in return_tags will serve as column names for all columns of
113138
tags that were optionally requested.
114139

0 commit comments

Comments
 (0)