|
1 | 1 | import datetime |
2 | 2 | import warnings |
| 3 | +from dataclasses import dataclass |
| 4 | +from enum import Enum |
3 | 5 | from typing import Dict, Iterable, List, Optional, Set, Tuple, Union |
4 | 6 |
|
5 | 7 | import requests |
|
17 | 19 | ) |
18 | 20 |
|
19 | 21 |
|
class _SliceBuilderMethodsMeta(type(Enum)):
    """Enum metaclass that makes ``x in SliceBuilderMethods`` accept raw values.

    Class-level membership (``"Random" in SliceBuilderMethods``) is resolved on
    the metaclass, not on the enum members. The stock enum metaclass raises
    ``TypeError`` for non-member operands on Python 3.8-3.11, so membership is
    implemented here by attempting a value lookup instead.
    """

    def __contains__(cls, item):
        try:
            cls(item)
        except ValueError:
            return False
        return True


class SliceBuilderMethods(str, Enum, metaclass=_SliceBuilderMethodsMeta):
    """
    Which method to use for sampling the dataset items.
      - Random: randomly select items
      - Uniqueness: Prioritizes more unique images based on model embedding distance, so that the final sample has fewer similar images.

    Both members and their string values can be tested for membership, e.g.
    ``"Random" in SliceBuilderMethods``.
    """

    RANDOM = "Random"
    UNIQUENESS = "Uniqueness"

    def __contains__(self, item):
        """Return True if ``item`` is a member or a valid member value.

        The lookup is delegated to the enum class: ``self`` is a single
        str-valued member and cannot be called to perform a value lookup.
        """
        return item in type(self)

    @staticmethod
    def options():
        """Return the string values of all available sampling methods."""
        return [method.value for method in SliceBuilderMethods]
| 42 | + |
| 43 | + |
@dataclass
class SliceBuilderFilterAutotag:
    """
    Helper class for specifying an autotag filter for building a slice.

    Args:
        autotag_id: Filter items that belong to this autotag
        score_range: Specify the range of the autotag items' score that should be considered, between [-1, 1].
            For example, [-0.3, 0.7].

    Raises:
        AssertionError: If score_range does not hold exactly two values, or if
            any value falls outside [-1, 1].
    """

    autotag_id: str
    # Scores are fractional (see the example above), so the bounds are floats.
    score_range: List[float]

    def __post_init__(self):
        # Raised explicitly (rather than via `assert`) so validation still runs
        # under `python -O`; the exception type is kept for existing callers.
        if len(self.score_range) != 2:
            raise AssertionError(
                f"Autotag score range must contain exactly two values. But got {self.score_range}."
            )
        if min(self.score_range) < -1 or max(self.score_range) > 1:
            raise AssertionError(
                f"Autotag score range must be within [-1, 1]. But got {self.score_range}."
            )
| 64 | + |
| 65 | + |
@dataclass
class SliceBuilderFilters:
    """
    Optional filters that narrow the pool of dataset items a slice is built from.

    Either filter, both, or neither may be set; both default to None.

    Args:
        slice_id: Only consider items that belong to this existing slice
        autotag: Only consider items matching this autotag filter (see SliceBuilderFilterAutotag)
    """

    slice_id: Optional[str] = None
    autotag: Optional[SliceBuilderFilterAutotag] = None
| 79 | + |
| 80 | + |
20 | 81 | class Slice: |
21 | 82 | """A Slice represents a subset of DatasetItems in your Dataset. |
22 | 83 |
|
@@ -502,3 +563,50 @@ def check_annotations_are_in_slice( |
502 | 563 | annotations_are_in_slice, |
503 | 564 | reference_ids_not_found_in_slice, |
504 | 565 | ) |
| 566 | + |
| 567 | + |
def create_slice_builder_payload(
    name: str,
    sample_size: int,
    sample_method: Union[str, "SliceBuilderMethods"],
    filters: Optional["SliceBuilderFilters"],
):
    """
    Format the slice builder request payload from the helper dataclasses.

    Args:
        name: Name for the slice being created
        sample_size: Number of items to sample
        sample_method: Method to use for sampling the dataset items; either a
            SliceBuilderMethods member or its string value
        filters: Optional set of filters to apply when collecting the dataset items

    Returns:
        A request friendly payload: a dict with the keys "name", "sampleSize",
        "sampleMethod" and "filters".

    Raises:
        AssertionError: If sample_method is not one of the available methods.
    """

    # Validate via a value lookup: `x in SliceBuilderMethods` raises TypeError
    # for plain strings on Python 3.8-3.11. Raised explicitly (not `assert`) so
    # the check survives `python -O`; the exception type is kept for callers.
    try:
        method = SliceBuilderMethods(sample_method)
    except ValueError:
        raise AssertionError(
            f"Method {sample_method} not available. Must be one of: {SliceBuilderMethods.options()}"
        ) from None

    filter_payload: Dict[str, Union[str, dict]] = {}
    if filters is not None:
        if filters.slice_id is not None:
            filter_payload["sliceId"] = filters.slice_id
        if filters.autotag is not None:
            filter_payload["autotag"] = {
                "autotagId": filters.autotag.autotag_id,
                "range": filters.autotag.score_range,
            }

    return {
        "name": name,
        "sampleSize": sample_size,
        # Always send the plain string value, whether an enum or str was given.
        "sampleMethod": method.value,
        "filters": filter_payload,
    }
0 commit comments