|
| 1 | +import copy |
1 | 2 | import enum |
2 | 3 | import functools |
3 | 4 | import logging |
|
7 | 8 | Iterable, |
8 | 9 | List, |
9 | 10 | NamedTuple, |
| 11 | + Optional, |
10 | 12 | Sequence, |
11 | 13 | Set, |
12 | 14 | Tuple, |
13 | 15 | Union, |
14 | 16 | ) |
15 | 17 |
|
| 18 | +from rich.console import Console |
| 19 | +from rich.table import Table |
| 20 | + |
16 | 21 | from nucleus.annotation import ( |
| 22 | + AnnotationList, |
17 | 23 | BoxAnnotation, |
18 | 24 | CategoryAnnotation, |
19 | 25 | CuboidAnnotation, |
|
29 | 35 | CuboidPrediction, |
30 | 36 | LinePrediction, |
31 | 37 | PolygonPrediction, |
| 38 | + PredictionList, |
32 | 39 | SegmentationPrediction, |
33 | 40 | ) |
34 | 41 |
|
@@ -568,3 +575,147 @@ def ensureDNFFilters(filters) -> OrAndDNFFilters: |
568 | 575 | formatted_filter.append(and_chain) |
569 | 576 | filters = formatted_filter |
570 | 577 | return filters |
| 578 | + |
| 579 | + |
| 580 | +def pretty_format_filters_with_or_and( |
| 581 | + filters: Optional[Union[ListOfOrAndFilters, ListOfAndFilters]] |
| 582 | +): |
| 583 | + if filters is None: |
| 584 | + return "No filters applied!" |
| 585 | + dnf_filters = ensureDNFFilters(filters) |
| 586 | + or_branches = [] |
| 587 | + for or_branch in dnf_filters: |
| 588 | + and_statements = [] |
| 589 | + for and_branch in or_branch: |
| 590 | + if and_branch.type == FilterType.FIELD: |
| 591 | + class_name = "FieldFilter" |
| 592 | + elif and_branch.type == FilterType.METADATA: |
| 593 | + class_name = "MetadataFilter" |
| 594 | + elif and_branch.type == FilterType.SEGMENT_FIELD: |
| 595 | + class_name = "SegmentFieldFilter" |
| 596 | + elif and_branch.type == FilterType.SEGMENT_METADATA: |
| 597 | + class_name = "SegmentMetadataFilter" |
| 598 | + else: |
| 599 | + raise RuntimeError( |
| 600 | + f"Un-handled filter type: {and_branch.type}" |
| 601 | + ) |
| 602 | + op = ( |
| 603 | + and_branch.op.value |
| 604 | + if isinstance(and_branch.op, FilterOp) |
| 605 | + else and_branch.op |
| 606 | + ) |
| 607 | + value_formatted = ( |
| 608 | + f'"{and_branch.value}"' |
| 609 | + if isinstance(and_branch.value, str) |
| 610 | + else f"{and_branch.value}".replace("'", '"') |
| 611 | + ) |
| 612 | + statement = ( |
| 613 | + f'{class_name}("{and_branch.key}", "{op}", {value_formatted})' |
| 614 | + ) |
| 615 | + and_statements.append(statement) |
| 616 | + |
| 617 | + or_branches.append(and_statements) |
| 618 | + |
| 619 | + and_to_join = [] |
| 620 | + for and_statements in or_branches: |
| 621 | + joined_and = " and ".join(and_statements) |
| 622 | + if len(or_branches) > 1 and len(and_statements) > 1: |
| 623 | + joined_and = "(" + joined_and + ")" |
| 624 | + and_to_join.append(joined_and) |
| 625 | + |
| 626 | + full_statement = " or ".join(and_to_join) |
| 627 | + return full_statement |
| 628 | + |
| 629 | + |
| 630 | +def compose_helpful_filtering_error( |
| 631 | + ann_or_pred_list: Union[AnnotationList, PredictionList], filters |
| 632 | +) -> List[str]: |
| 633 | + prefix = ( |
| 634 | + "Annotations" |
| 635 | + if isinstance(ann_or_pred_list, AnnotationList) |
| 636 | + else "Predictions" |
| 637 | + ) |
| 638 | + msg = [] |
| 639 | + msg.append(f"{prefix}: All items filtered out by:") |
| 640 | + msg.append(f" {pretty_format_filters_with_or_and(filters)}") |
| 641 | + msg.append("") |
| 642 | + console = Console() |
| 643 | + table = Table( |
| 644 | + "Type", |
| 645 | + "Count", |
| 646 | + "Labels", |
| 647 | + title=f"Original {prefix}", |
| 648 | + title_justify="left", |
| 649 | + ) |
| 650 | + for ann_or_pred_type, items in ann_or_pred_list.items(): |
| 651 | + if items and isinstance( |
| 652 | + items[-1], (SegmentationAnnotation, SegmentationPrediction) |
| 653 | + ): |
| 654 | + labels = set() |
| 655 | + for seg in items: |
| 656 | + labels.update(set(s.label for s in seg.annotations)) |
| 657 | + else: |
| 658 | + labels = set(a.label for a in items) |
| 659 | + if items: |
| 660 | + table.add_row(ann_or_pred_type, str(len(items)), str(list(labels))) |
| 661 | + with console.capture() as capture: |
| 662 | + console.print(table) |
| 663 | + msg.append(capture.get()) |
| 664 | + return msg |
| 665 | + |
| 666 | + |
| 667 | +def filter_annotation_list( |
| 668 | + annotations: AnnotationList, annotation_filters |
| 669 | +) -> AnnotationList: |
| 670 | + annotations = copy.deepcopy(annotations) |
| 671 | + if annotation_filters is None or len(annotation_filters) == 0: |
| 672 | + return annotations |
| 673 | + annotations.box_annotations = apply_filters( |
| 674 | + annotations.box_annotations, annotation_filters |
| 675 | + ) |
| 676 | + annotations.line_annotations = apply_filters( |
| 677 | + annotations.line_annotations, annotation_filters |
| 678 | + ) |
| 679 | + annotations.polygon_annotations = apply_filters( |
| 680 | + annotations.polygon_annotations, annotation_filters |
| 681 | + ) |
| 682 | + annotations.cuboid_annotations = apply_filters( |
| 683 | + annotations.cuboid_annotations, annotation_filters |
| 684 | + ) |
| 685 | + annotations.category_annotations = apply_filters( |
| 686 | + annotations.category_annotations, annotation_filters |
| 687 | + ) |
| 688 | + annotations.multi_category_annotations = apply_filters( |
| 689 | + annotations.multi_category_annotations, annotation_filters |
| 690 | + ) |
| 691 | + annotations.segmentation_annotations = apply_filters( |
| 692 | + annotations.segmentation_annotations, annotation_filters |
| 693 | + ) |
| 694 | + return annotations |
| 695 | + |
| 696 | + |
| 697 | +def filter_prediction_list( |
| 698 | + predictions: PredictionList, prediction_filters |
| 699 | +) -> PredictionList: |
| 700 | + predictions = copy.deepcopy(predictions) |
| 701 | + if prediction_filters is None or len(prediction_filters) == 0: |
| 702 | + return predictions |
| 703 | + predictions.box_predictions = apply_filters( |
| 704 | + predictions.box_predictions, prediction_filters |
| 705 | + ) |
| 706 | + predictions.line_predictions = apply_filters( |
| 707 | + predictions.line_predictions, prediction_filters |
| 708 | + ) |
| 709 | + predictions.polygon_predictions = apply_filters( |
| 710 | + predictions.polygon_predictions, prediction_filters |
| 711 | + ) |
| 712 | + predictions.cuboid_predictions = apply_filters( |
| 713 | + predictions.cuboid_predictions, prediction_filters |
| 714 | + ) |
| 715 | + predictions.category_predictions = apply_filters( |
| 716 | + predictions.category_predictions, prediction_filters |
| 717 | + ) |
| 718 | + predictions.segmentation_predictions = apply_filters( |
| 719 | + predictions.segmentation_predictions, prediction_filters |
| 720 | + ) |
| 721 | + return predictions |
0 commit comments