|
1 | 1 | from typing import List |
2 | 2 | from typing import Optional |
| 3 | +from typing import Union |
| 4 | + |
| 5 | +from pydantic import Field |
| 6 | +from pydantic import StrictStr |
| 7 | +from pydantic import ValidationError |
| 8 | +from pydantic.error_wrappers import ErrorWrapper |
3 | 9 |
|
4 | 10 | from superannotate_schemas.schemas.base import BaseAttribute |
5 | | -from superannotate_schemas.schemas.base import BaseInstance |
| 11 | +from superannotate_schemas.schemas.base import BaseDocumentInstance |
6 | 12 | from superannotate_schemas.schemas.base import BaseMetadata as Metadata |
7 | | -from superannotate_schemas.schemas.base import Tag |
8 | | -from superannotate_schemas.schemas.base import NotEmptyStr |
9 | | - |
10 | 13 | from superannotate_schemas.schemas.base import BaseModel |
11 | | -from pydantic import Field |
12 | | -from pydantic import StrictStr |
| 14 | +from superannotate_schemas.schemas.base import INVALID_DICT_MESSAGE |
| 15 | +from superannotate_schemas.schemas.base import NotEmptyStr |
| 16 | +from superannotate_schemas.schemas.base import Tag |
| 17 | +from superannotate_schemas.schemas.enums import DocumentAnnotationTypeEnum |
13 | 18 |
|
14 | 19 |
|
15 | 20 | class Attribute(BaseAttribute): |
16 | 21 | name: NotEmptyStr |
17 | 22 | group_name: NotEmptyStr = Field(alias="groupName") |
18 | 23 |
|
19 | 24 |
|
20 | | -class DocumentInstance(BaseInstance): |
| 25 | +class EntityInstance(BaseDocumentInstance): |
21 | 26 | start: int |
22 | 27 | end: int |
23 | 28 | attributes: Optional[List[Attribute]] = Field(list()) |
24 | 29 |
|
25 | 30 |
|
| 31 | +class TagInstance(BaseDocumentInstance): |
| 32 | + attributes: Optional[List[Attribute]] = Field(list()) |
| 33 | + class_name: NotEmptyStr = Field(alias="className") |
| 34 | + |
| 35 | + |
| 36 | +class DocumentInstance(BaseDocumentInstance): |
| 37 | + pass |
| 38 | + |
| 39 | + |
| 40 | +ANNOTATION_TYPES = { |
| 41 | + DocumentAnnotationTypeEnum.ENTITY: EntityInstance, |
| 42 | + DocumentAnnotationTypeEnum.TAG: TagInstance, |
| 43 | +} |
| 44 | + |
| 45 | + |
| 46 | +class AnnotationInstance(BaseModel): |
| 47 | + __root__: Union[TagInstance, EntityInstance] |
| 48 | + |
| 49 | + @classmethod |
| 50 | + def __get_validators__(cls): |
| 51 | + yield cls.return_action |
| 52 | + |
| 53 | + @classmethod |
| 54 | + def return_action(cls, values): |
| 55 | + try: |
| 56 | + try: |
| 57 | + instance_type = values["type"] |
| 58 | + except KeyError: |
| 59 | + raise ValidationError( |
| 60 | + [ErrorWrapper(ValueError("field required"), "type")], cls |
| 61 | + ) |
| 62 | + return ANNOTATION_TYPES[instance_type](**values) |
| 63 | + except KeyError: |
| 64 | + raise ValidationError( |
| 65 | + [ |
| 66 | + ErrorWrapper( |
| 67 | + ValueError( |
| 68 | + f"invalid type, valid types are {', '.join(ANNOTATION_TYPES.keys())}" |
| 69 | + ), |
| 70 | + "type", |
| 71 | + ) |
| 72 | + ], |
| 73 | + cls, |
| 74 | + ) |
| 75 | + except TypeError as e: |
| 76 | + raise TypeError(INVALID_DICT_MESSAGE) from e |
| 77 | + |
| 78 | + |
26 | 79 | class DocumentAnnotation(BaseModel): |
27 | 80 | metadata: Metadata |
28 | | - instances: Optional[List[DocumentInstance]] = Field(list()) |
| 81 | + instances: Optional[List[AnnotationInstance]] = Field(list()) |
29 | 82 | tags: Optional[List[Tag]] = Field(list()) |
30 | 83 | free_text: Optional[StrictStr] = Field(None, alias="freeText") |
0 commit comments