Skip to content

Commit fae81f9

Browse files
committed
Added category handling
tod
1 parent 1153d0b commit fae81f9

File tree

8 files changed

+119
-25
lines changed

8 files changed

+119
-25
lines changed

src/superannotate/lib/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def setup_logging(level=DEFAULT_LOGGING_LEVEL, file_path=LOG_FILE_LOCATION):
179179
"Tokenization",
180180
"ImageAutoAssignEnable",
181181
"TemplateState",
182+
"CategorizeItems",
182183
]
183184

184185
__alL__ = (

src/superannotate/lib/core/service_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,10 @@ class ProjectResponse(ServiceResponse):
226226
res_data: entities.ProjectEntity = None
227227

228228

229+
class ListCategoryResponse(ServiceResponse):
230+
res_data: List[entities.CategoryEntity] = None
231+
232+
229233
class WorkflowResponse(ServiceResponse):
230234
res_data: entities.WorkflowEntity = None
231235

src/superannotate/lib/core/serviceproviders.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99

1010
from lib.core import entities
1111
from lib.core.conditions import Condition
12-
from lib.core.entities import CategoryEntity
1312
from lib.core.jsx_conditions import Query
1413
from lib.core.reporter import Reporter
1514
from lib.core.service_types import AnnotationClassListResponse
1615
from lib.core.service_types import FolderListResponse
1716
from lib.core.service_types import FolderResponse
1817
from lib.core.service_types import IntegrationListResponse
18+
from lib.core.service_types import ListCategoryResponse
1919
from lib.core.service_types import ProjectListResponse
2020
from lib.core.service_types import ProjectResponse
2121
from lib.core.service_types import ServiceResponse
@@ -89,13 +89,13 @@ def list_workflow_roles(self, project_id: int, workflow_id: int):
8989
raise NotImplementedError
9090

9191
@abstractmethod
92-
def list_project_categories(self, project_id: int) -> List[entities.CategoryEntity]:
92+
def list_project_categories(self, project_id: int) -> ListCategoryResponse:
9393
raise NotImplementedError
9494

9595
@abstractmethod
9696
def create_project_categories(
97-
self, project_id: int, categories: List[CategoryEntity]
98-
):
97+
self, project_id: int, categories: List[str]
98+
) -> ServiceResponse:
9999
raise NotImplementedError
100100

101101

@@ -362,7 +362,9 @@ def delete_multiple(
362362
raise NotImplementedError
363363

364364
@abstractmethod
365-
def bulk_attach_categories(self, project_id: int, folder_id: int, item_category_map: Dict[int, int]) -> bool:
365+
def bulk_attach_categories(
366+
self, project_id: int, folder_id: int, item_category_map: Dict[int, int]
367+
) -> bool:
366368
raise NotImplementedError
367369

368370

src/superannotate/lib/core/usecases/annotations.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import time
1111
import traceback
1212
from collections import defaultdict
13+
from contextlib import suppress
1314
from dataclasses import dataclass
1415
from itertools import islice
1516
from operator import itemgetter
@@ -1846,6 +1847,7 @@ def __init__(
18461847
self._transform_version = (
18471848
"llmJsonV2" if transform_version is None else transform_version
18481849
)
1850+
self._category_name_to_id_map = {}
18491851

18501852
@property
18511853
def files_queue(self):
@@ -1994,11 +1996,12 @@ def serialize_folder_name(val):
19941996

19951997
def execute(self):
19961998
if self.is_valid():
1999+
# TODO check categories status in the project
2000+
skip_categorization = False
19972001
serialized_original_folder_map = {}
19982002
failed, skipped, uploaded = [], [], []
1999-
distributed_items: Dict[str, Dict[str, Any]] = defaultdict(
2000-
dict
2001-
) # folder_id -> item_name -> annotation
2003+
# folder_id -> item_name -> annotation
2004+
distributed_items: Dict[str, Dict[str, Any]] = defaultdict(dict)
20022005
valid_items_count = 0
20032006
for annotation in self._annotations:
20042007
if self._validate_json(annotation):
@@ -2068,6 +2071,17 @@ def execute(self):
20682071
{i.item.name for i in items_to_upload}
20692072
- set(failed_annotations).union(skipped)
20702073
)
2074+
if not skip_categorization:
2075+
item_id_category_map = {}
2076+
for item_name in uploaded_annotations:
2077+
category = name_annotation_map[item_name]["metadata"].get(
2078+
"item_category"
2079+
)
2080+
if category:
2081+
item_id_category_map[name_item_map[item_name].id] = category
2082+
self._attach_categories(
2083+
folder_id=folder.id, item_id_category_map=item_id_category_map
2084+
)
20712085
workflow = self._service_provider.work_management.get_workflow(
20722086
self._project.workflow_id
20732087
)
@@ -2096,3 +2110,44 @@ def execute(self):
20962110
"skipped": skipped,
20972111
}
20982112
return self._response
2113+
2114+
def _attach_categories(self, folder_id: int, item_id_category_map: Dict[int, str]):
    """Attach categories to uploaded items, creating missing categories first.

    The project's categories are listed once and cached on the instance in
    ``self._category_name_to_id_map`` (category name -> category id). Any
    category name referenced by ``item_id_category_map`` that is not in the
    cache is created through the work-management service, and the cache is
    extended with the returned name/id pairs. Finally the resolved
    item id -> category id mapping is sent in a single bulk attach call.

    :param folder_id: id of the folder that contains the items.
    :param item_id_category_map: item id -> category name. The mapping is
        only read; it is never modified.
    :raises: whatever ``response.raise_for_status()`` raises when listing
        the project categories fails.
    """
    if not item_id_category_map:
        # Nothing to attach: avoid the category listing request entirely.
        return

    work_management = self._service_provider.work_management
    if not self._category_name_to_id_map:
        response = work_management.list_project_categories(self._project.id)
        response.raise_for_status()
        self._category_name_to_id_map = {c.name: c.id for c in response.data}

    # Several items usually share one category; dict.fromkeys de-duplicates
    # the names (keeping first-seen order) so each missing category is
    # requested exactly once.
    categories_to_create = list(
        dict.fromkeys(
            name
            for name in item_id_category_map.values()
            if name not in self._category_name_to_id_map
        )
    )
    if categories_to_create:
        created = work_management.create_project_categories(
            project_id=self._project.id,
            categories=categories_to_create,
        ).data["data"]
        for category in created:
            self._category_name_to_id_map[category["name"]] = category["id"]

    # Best effort: an item whose category is still unknown after the create
    # call (i.e. it was not part of the returned payload) is skipped.
    item_id_category_id_map: Dict[int, int] = {
        item_id: self._category_name_to_id_map[name]
        for item_id, name in item_id_category_map.items()
        if name in self._category_name_to_id_map
    }
    if item_id_category_id_map:
        self._service_provider.items.bulk_attach_categories(
            project_id=self._project.id,
            folder_id=folder_id,
            item_category_map=item_id_category_id_map,
        )

src/superannotate/lib/infrastructure/services/item.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class ItemService(BaseItemService):
2121
URL_DELETE_ITEMS = "image/delete/images"
2222
URL_SET_APPROVAL_STATUSES = "/items/bulk/change"
2323
URL_COPY_MOVE_MULTIPLE = "images/copy-move-images-folders"
24-
URL_ATTACH_CATEGORIES #
24+
URL_ATTACH_CATEGORIES = "items/bulk/setcategory"
2525

2626
def update(self, project: entities.ProjectEntity, item: entities.BaseItemEntity):
2727
return self.client.request(
@@ -214,12 +214,20 @@ def delete_multiple(self, project: entities.ProjectEntity, item_ids: List[int]):
214214
data={"image_ids": item_ids},
215215
)
216216

217-
218-
def bulk_attach_categories(self, project_id: int, folder_id: int, item_category_map: Dict[int, int]) -> bool:
219-
params = {
220-
"project_id": project_id,
221-
"folder_id": folder_id
222-
}
217+
def bulk_attach_categories(
218+
self, project_id: int, folder_id: int, item_category_map: Dict[int, int]
219+
) -> bool:
220+
params = {"project_id": project_id, "folder_id": folder_id}
223221
response = self.client.request(
224-
self.
225-
)
222+
self.URL_ATTACH_CATEGORIES,
223+
"post",
224+
params=params,
225+
data={
226+
"bulk": [
227+
{"item_id": item_id, "categories": [category]}
228+
for item_id, category in item_category_map.items()
229+
]
230+
},
231+
)
232+
response.raise_for_status()
233+
return response.ok

src/superannotate/lib/infrastructure/services/work_management.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from lib.core.jsx_conditions import Filter
99
from lib.core.jsx_conditions import OperatorEnum
1010
from lib.core.jsx_conditions import Query
11+
from lib.core.service_types import ListCategoryResponse
12+
from lib.core.service_types import ServiceResponse
1113
from lib.core.serviceproviders import BaseWorkManagementService
1214

1315

@@ -26,7 +28,7 @@ def _generate_context(**kwargs):
2628
encoded_context = base64.b64encode(json.dumps(kwargs).encode("utf-8"))
2729
return encoded_context.decode("utf-8")
2830

29-
def list_project_categories(self, project_id: int) -> List[CategoryEntity]:
31+
def list_project_categories(self, project_id: int) -> ListCategoryResponse:
3032
return self.client.paginate(
3133
self.URL_LIST_CATEGORIES,
3234
item_type=CategoryEntity,
@@ -39,15 +41,21 @@ def list_project_categories(self, project_id: int) -> List[CategoryEntity]:
3941
)
4042

4143
def create_project_categories(
42-
self, project_id: int, categories: List[CategoryEntity]
43-
):
44+
self, project_id: int, categories: List[str]
45+
) -> ServiceResponse:
4446
response = self.client.request(
45-
"post",
47+
method="post",
4648
url=self.URL_CREATE_CATEGORIES,
47-
pararms={"project_id": project_id},
48-
data={"bulk": [i["name"] for i in categories]},
49+
params={"project_id": project_id},
50+
data={"bulk": [{"name": i} for i in categories]},
51+
headers={
52+
"x-sa-entity-context": self._generate_context(
53+
team_id=self.client.team_id, project_id=project_id
54+
),
55+
},
4956
)
5057
response.raise_for_status()
58+
return response
5159

5260
def get_workflow(self, pk: int) -> WorkflowEntity:
5361
response = self.list_workflows(Filter("id", pk, OperatorEnum.EQ))
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{"metadata": {"folder_name": "test_folder", "name": "0123456789101112", "item_category": "category1"}, "data": {"component_id_0": {"value": ["Partially complete, needs review"]}, "component_id_1": {"value": "I am a text input 001"}, "component_id_2": {"value": 11}}}
2+
{"metadata": {"folder_name": " test_Folder", "name": "item_002", "item_category": "category2"}, "data": {"component_id_0": {"value": ["Partially complete, needs review", "Incomplete"]}, "component_id_1": {"value": "I am a text input 002"}, "component_id_2": {"value": 33}}}
3+
{"metadata": {"folder_name": "Test_Folder ", "name": "item_003", "item_category": "category3"}, "data": {"component_id_0": {"value": ["Partially complete, needs review", "Incomplete"]}, "component_id_1": {"value": "I am a text input 003"}, "component_id_2": {"value": 33}}}

tests/integration/annotations/test_upload_annotations.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ class MultiModalUploadAnnotations(BaseTestCase):
143143
JSONL_ANNOTATIONS_PATH = os.path.join(
144144
DATA_SET_PATH, "multimodal/annotations/jsonl/form1.jsonl"
145145
)
146+
JSONL_ANNOTATIONS_WITH_CATEGORIES_PATH = os.path.join(
147+
DATA_SET_PATH, "multimodal/annotations/jsonl/form1_with_categories.jsonl"
148+
)
146149
CLASSES_TEMPLATE_PATH = os.path.join(
147150
Path(__file__).parent.parent.parent,
148151
"data_set/editor_templates/from1_classes.json",
@@ -154,10 +157,13 @@ def setUp(self, *args, **kwargs):
154157
self.PROJECT_NAME,
155158
self.PROJECT_DESCRIPTION,
156159
"Multimodal",
157-
settings=[{"attribute": "TemplateState", "value": 1}],
160+
settings=[
161+
{"attribute": "CategorizeItems", "value": 1},
162+
{"attribute": "TemplateState", "value": 1},
163+
],
158164
)
159165
project = sa.controller.get_project(self.PROJECT_NAME)
160-
time.sleep(5)
166+
time.sleep(2)
161167
with open(self.EDITOR_TEMPLATE_PATH) as f:
162168
res = sa.controller.service_provider.projects.attach_editor_template(
163169
sa.controller.team, project, template=json.load(f)
@@ -195,3 +201,10 @@ def test_error_upload_from_folder_to_folder_(self):
195201
"You can't include a folder when uploading from within a folder.",
196202
):
197203
sa.upload_annotations(f"{self.PROJECT_NAME}/tmp", annotations=data)
204+
205+
def test_upload_with_categories(self):
    """Upload JSONL annotations from the "with categories" fixture.

    Each line of the fixture file is parsed as one annotation, the whole
    batch is uploaded at project level, and the test then expects exactly
    three annotations to be retrievable from ``test_folder``.
    """
    with open(self.JSONL_ANNOTATIONS_WITH_CATEGORIES_PATH) as jsonl_file:
        # One JSON document per line.
        payload = list(map(json.loads, jsonl_file))
    sa.upload_annotations(f"{self.PROJECT_NAME}", annotations=payload)
    fetched = sa.get_annotations(f"{self.PROJECT_NAME}/test_folder")
    assert len(fetched) == 3

0 commit comments

Comments
 (0)