|
11 | 11 | __author__ = ["fkiraly"] |
12 | 12 | # all_objects is based on the sklearn utility all_estimators |
13 | 13 |
|
| 14 | +from inspect import isclass |
14 | 15 | from pathlib import Path |
15 | 16 |
|
16 | 17 | from skbase.lookup import all_objects as _all_objects |
@@ -133,25 +134,39 @@ def all_objects( |
133 | 134 | result = [] |
134 | 135 | ROOT = str(Path(__file__).parent.parent) # package root directory |
135 | 136 |
|
136 | | - if isinstance(filter_tags, str): |
137 | | - filter_tags = {filter_tags: True} |
138 | | - filter_tags = filter_tags.copy() if filter_tags else None |
139 | | - |
140 | | - if object_types: |
141 | | - if filter_tags and "object_type" not in filter_tags.keys(): |
142 | | - object_tag_filter = {"object_type": object_types} |
143 | | - elif filter_tags: |
144 | | - filter_tags_filter = filter_tags.get("object_type", []) |
145 | | - if isinstance(object_types, str): |
146 | | - object_types = [object_types] |
147 | | - object_tag_update = {"object_type": object_types + filter_tags_filter} |
148 | | - filter_tags.update(object_tag_update) |
| 137 | + def _coerce_to_str(obj): |
| 138 | + if isinstance(obj, (list, tuple)): |
| 139 | + return [_coerce_to_str(o) for o in obj] |
| 140 | + if isclass(obj): |
| 141 | + obj = obj.get_tag("object_type") |
| 142 | + return obj |
| 143 | + |
| 144 | + def _coerce_to_list_of_str(obj): |
| 145 | + obj = _coerce_to_str(obj) |
| 146 | + if isinstance(obj, str): |
| 147 | + return [obj] |
| 148 | + return obj |
| 149 | + |
| 150 | + if object_types is not None: |
| 151 | + object_types = _coerce_to_list_of_str(object_types) |
| 152 | + object_types = list(set(object_types)) |
| 153 | + |
| 154 | + if object_types is not None: |
| 155 | + if filter_tags is None: |
| 156 | + filter_tags = {} |
| 157 | + elif isinstance(filter_tags, str): |
| 158 | + filter_tags = {filter_tags: True} |
149 | 159 | else: |
150 | | - object_tag_filter = {"object_type": object_types} |
151 | | - if filter_tags: |
152 | | - filter_tags.update(object_tag_filter) |
| 160 | + filter_tags = filter_tags.copy() |
| 161 | + |
| 162 | + if "object_type" in filter_tags: |
| 163 | + obj_field = filter_tags["object_type"] |
| 164 | + obj_field = _coerce_to_list_of_str(obj_field) |
| 165 | + obj_field = obj_field + object_types |
153 | 166 | else: |
154 | | - filter_tags = object_tag_filter |
| 167 | + obj_field = object_types |
| 168 | + |
| 169 | + filter_tags["object_type"] = obj_field |
155 | 170 |
|
156 | 171 | result = _all_objects( |
157 | 172 | object_types=[_BaseObject], |
|
0 commit comments