Skip to content

Commit a669134

Browse files
committed
Update _lookup.py
1 parent ef37f55 commit a669134

File tree

1 file changed

+32
-17
lines changed

1 file changed

+32
-17
lines changed

pytorch_forecasting/_registry/_lookup.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
__author__ = ["fkiraly"]
1212
# all_objects is based on the sklearn utility all_estimators
1313

14+
from inspect import isclass
1415
from pathlib import Path
1516

1617
from skbase.lookup import all_objects as _all_objects
@@ -133,25 +134,39 @@ def all_objects(
133134
result = []
134135
ROOT = str(Path(__file__).parent.parent) # package root directory
135136

136-
if isinstance(filter_tags, str):
137-
filter_tags = {filter_tags: True}
138-
filter_tags = filter_tags.copy() if filter_tags else None
139-
140-
if object_types:
141-
if filter_tags and "object_type" not in filter_tags.keys():
142-
object_tag_filter = {"object_type": object_types}
143-
elif filter_tags:
144-
filter_tags_filter = filter_tags.get("object_type", [])
145-
if isinstance(object_types, str):
146-
object_types = [object_types]
147-
object_tag_update = {"object_type": object_types + filter_tags_filter}
148-
filter_tags.update(object_tag_update)
137+
def _coerce_to_str(obj):
138+
if isinstance(obj, (list, tuple)):
139+
return [_coerce_to_str(o) for o in obj]
140+
if isclass(obj):
141+
obj = obj.get_tag("object_type")
142+
return obj
143+
144+
def _coerce_to_list_of_str(obj):
145+
obj = _coerce_to_str(obj)
146+
if isinstance(obj, str):
147+
return [obj]
148+
return obj
149+
150+
if object_types is not None:
151+
object_types = _coerce_to_list_of_str(object_types)
152+
object_types = list(set(object_types))
153+
154+
if object_types is not None:
155+
if filter_tags is None:
156+
filter_tags = {}
157+
elif isinstance(filter_tags, str):
158+
filter_tags = {filter_tags: True}
149159
else:
150-
object_tag_filter = {"object_type": object_types}
151-
if filter_tags:
152-
filter_tags.update(object_tag_filter)
160+
filter_tags = filter_tags.copy()
161+
162+
if "object_type" in filter_tags:
163+
obj_field = filter_tags["object_type"]
164+
obj_field = _coerce_to_list_of_str(obj_field)
165+
obj_field = obj_field + object_types
153166
else:
154-
filter_tags = object_tag_filter
167+
obj_field = object_types
168+
169+
filter_tags["object_type"] = obj_field
155170

156171
result = _all_objects(
157172
object_types=[_BaseObject],

0 commit comments

Comments
 (0)