diff --git a/.gitignore b/.gitignore
index 3d509736..e9765e4a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,12 @@
/models/
/cache/
+/build/
+/executable/
.json
.pdf
.bson
+/evaluation/datasets/aviation/documents
+/evaluation/datasets/nobel/documents
+/evaluation/results/
+/logs/
diff --git a/main.py b/main.py
index de42b137..d34272cf 100644
--- a/main.py
+++ b/main.py
@@ -1,6 +1,7 @@
import logging
import sys
+from PyQt6.QtCore import Qt
from PyQt6.QtWidgets import QApplication
from wannadb.resources import ResourceManager
@@ -14,6 +15,7 @@
with ResourceManager() as resource_manager:
# set up PyQt application
+ QApplication.setAttribute(Qt.ApplicationAttribute.AA_ShareOpenGLContexts)
app = QApplication(sys.argv)
window = MainWindow()
diff --git a/requirements.txt b/requirements.txt
index 70fa4762..ff7baffd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -109,6 +109,7 @@ language-data==1.2.0
# via langcodes
marisa-trie==1.2.0
# via language-data
+markdown==3.7
markdown-it-py==3.0.0
# via rich
markupsafe==2.1.5
@@ -329,5 +330,11 @@ xxhash==3.5.0
yarl==1.12.1
# via aiohttp
+pyqtgraph==0.13.7
+
+PyOpenGL==3.1.9
+
+PyOpenGL_accelerate==3.1.9
+
# The following packages are considered to be unsafe in a requirements file:
# setuptools
diff --git a/scripts/preprocess.py b/scripts/preprocess.py
index 32fd4ab9..564bf924 100644
--- a/scripts/preprocess.py
+++ b/scripts/preprocess.py
@@ -6,11 +6,12 @@
from wannadb.configuration import Pipeline
from wannadb.data.data import Document, DocumentBase
from wannadb.interaction import EmptyInteractionCallback
+from wannadb.preprocessing.dimension_reduction import PCAReducer
from wannadb.preprocessing.embedding import BERTContextSentenceEmbedder, RelativePositionEmbedder, SBERTTextEmbedder, SBERTLabelEmbedder
from wannadb.preprocessing.extraction import StanzaNERExtractor, SpacyNERExtractor
from wannadb.preprocessing.label_paraphrasing import OntoNotesLabelParaphraser, SplitAttributeNameLabelParaphraser
from wannadb.preprocessing.normalization import CopyNormalizer
-from wannadb.preprocessing.other_processing import ContextSentenceCacher
+from wannadb.preprocessing.other_processing import ContextSentenceCacher, DuplicatedNuggetsCleaner
from wannadb.resources import ResourceManager
from wannadb.statistics import Statistics
from wannadb.status import EmptyStatusCallback
@@ -68,7 +69,9 @@ def main() -> None:
SBERTLabelEmbedder("SBERTBertLargeNliMeanTokensResource"),
SBERTTextEmbedder("SBERTBertLargeNliMeanTokensResource"),
BERTContextSentenceEmbedder("BertLargeCasedResource"),
- RelativePositionEmbedder()
+ RelativePositionEmbedder(),
+ DuplicatedNuggetsCleaner(),
+ PCAReducer()
])
document_base = DocumentBase(documents, [])
diff --git a/wannadb/change_captor.py b/wannadb/change_captor.py
new file mode 100644
index 00000000..71b060ce
--- /dev/null
+++ b/wannadb/change_captor.py
@@ -0,0 +1,220 @@
+"""
+Class providing model classes which can be utilized to capture the changes due a user feedback and propagate them to the
+UI.
+These changes are computed after every feedback of the user.
+"""
+
+from typing import Optional, Union, List
+
+from PyQt6.QtGui import QColor
+
+from wannadb.data.data import InformationNugget
+from wannadb_ui.common import ThresholdPosition, AddedReason
+
+
+class BestMatchUpdate:
+ """
+ Instances of this class represent an update of the best match of a document.
+
+ Each instance provide the old best match and the new best match of a document as well as the count specifying how
+ often similar changes of best guesses happened.
+ Another best match change is considered as similar if it happened in the same feedback round and the new best guess
+ is equal.
+
+ Methods
+ -------
+ old_best_match()
+ Returns the old best match of the related document.
+ new_best_match()
+ Returns the new best match of the related document.
+ count()
+ Returns the count of similar best match changes happened in the same feedback round.
+ """
+
+ def __init__(self, old_best_match: str, new_best_match: str, count: int):
+ """
+ Parameters
+ ----------
+ old_best_match: str
+ The old best match of the related document.
+ new_best_match: str
+ The new best match of the related document.
+ count: int
+ The count of similar best match changes happened in the same feedback round.
+ """
+
+ self._old_best_match: str = old_best_match
+ self._new_best_match: str = new_best_match
+ self._count: int = count
+
+ @property
+ def old_best_match(self) -> str:
+ return self._old_best_match
+
+ @property
+ def new_best_match(self) -> str:
+ return self._new_best_match
+
+ @property
+ def count(self) -> int:
+ return self._count
+
+
+class ThresholdPositionUpdate:
+ """
+ Instances of this class represent an update of the position of a nugget's distance relative to the threshold.
+
+ Each instance provide the text of the nugget whose position changed, the old position (above or below), the new
+ position (above or below), the old and new distance of the nugget as well as a count indicating how often similar
+ changes happened in the same feedback round.
+ A change is considered as similar if it happened in the same feedback round, the text represented by the nugget is
+ equal, and it has the same type of the update (above -> below or below -> above).
+
+ As mentioned, an instance of this class can cover multiple changes if the text of the nuggets with a change are
+ equal.
+ In this case the distance related properties are None as we don't refer to a single nugget.
+ """
+
+ def __init__(self,
+ nugget_text: str,
+ old_position: Optional[ThresholdPosition], new_position: ThresholdPosition,
+ old_distance: Optional[float], new_distance: Optional[float],
+ count: int):
+ """
+ Parameters
+ ----------
+ nugget_text: str
+ Text of the nuggets whose position relative to the threshold changed.
+ old_position: ThresholdPosition
+ Previous position of the covered nuggets relative to the threshold (above or below).
+ new_position: ThresholdPosition
+ New position of the covered nuggets relative to the threshold (above or below).
+ old_distance: float
+ Old distance associated with the nugget. If multiple nuggets are covered by this instance, this will be
+ None.
+ new_distance: float
+ New distance associated with the nugget. If multiple nuggets are covered by this instance, this will be
+ None.
+ count: int
+ Number of similar changes happened in the same feedback round.
+ """
+
+ self._best_guess: str = nugget_text
+ self._old_position: Optional[ThresholdPosition] = old_position
+ self._new_position: ThresholdPosition = new_position
+ self._old_distance: float = old_distance
+ self._new_distance: float = new_distance
+ self._count: int = count
+
+ @property
+ def nugget_text(self) -> str:
+ return self._best_guess
+
+ @property
+ def old_position(self) -> Optional[ThresholdPosition]:
+ return self._old_position
+
+ @property
+ def new_position(self) -> ThresholdPosition:
+ return self._new_position
+
+ @property
+ def old_distance(self) -> Optional[float]:
+ return self._old_distance
+
+ @property
+ def new_distance(self) -> Optional[float]:
+ return self._new_distance
+
+ @property
+ def count(self) -> int:
+ return self._count
+
+
+class NewlyAddedNuggetContext:
+ """
+ Instances of this class represent a newly added nugget to the document overview.
+ Each instance provide information about the old and new distance of the nugget as well as the reason why the system
+ newly added the nugget.
+ """
+
+ def __init__(self,
+ nugget: InformationNugget,
+ old_distance: Union[float, None],
+ new_distance: float,
+ added_reason: AddedReason):
+ """
+ Parameters
+ ----------
+ nugget: InformationNugget
+ Newly added nugget.
+ old_distance: float
+ Old distance associated with the nugget.
+ new_distance: float
+ New distance associated with the nugget.
+ added_reason: AddedReason
+ Reason for the nugget being newly added.
+ """
+
+ self._nugget = nugget
+ self._old_distance = old_distance
+ self._new_distance = new_distance
+ self._added_reason = added_reason
+
+ @property
+ def nugget(self):
+ return self._nugget
+
+ @property
+ def old_distance(self):
+ return self._old_distance
+
+ @property
+ def new_distance(self):
+ return self._new_distance
+
+ @property
+ def added_reason(self):
+ return self._added_reason
+
+
+class NuggetUpdatesContext:
+ """
+ Wrapper class wrapping multiple types of nugget related updates.
+ Nugget related updates refer to `NewlyAddedNuggetContext`, `ThresholdPositionUpdate` and `BestMatchUpdate`. Each
+ instance holds a list of updates for all of these 3 update types.
+ """
+
+ def __init__(self,
+ newly_added_nugget_contexts: List[NewlyAddedNuggetContext],
+ best_match_updates: List[BestMatchUpdate],
+ threshold_position_updates: List[ThresholdPositionUpdate]):
+ """
+ Parameters
+ ----------
+ newly_added_nugget_contexts: List[NewlyAddedNuggetContext]
+ List of all `NewlyAddedNuggetContext` instances wrapped by this instance.
+ best_match_updates: List[BestMatchUpdate]
+ List of all `BestMatchUpdate` instances wrapped by this instance.
+ threshold_position_updates: List[ThresholdPositionUpdate]
+ List of all `ThresholdPositionUpdate` instances wrapped by this instance.
+ """
+
+ self._newly_added_nugget_contexts: List[NewlyAddedNuggetContext] = newly_added_nugget_contexts
+ self._best_match_updates: List[BestMatchUpdate] = best_match_updates
+ self._threshold_position_updates: List[ThresholdPositionUpdate] = threshold_position_updates
+
+ @property
+ def newly_added_nugget_contexts(self) -> List[NewlyAddedNuggetContext]:
+ return self._newly_added_nugget_contexts
+
+ @property
+ def best_match_updates(self) -> List[BestMatchUpdate]:
+ return self._best_match_updates
+
+ @property
+ def threshold_position_updates(self) -> List[ThresholdPositionUpdate]:
+ return self._threshold_position_updates
+
+
+
diff --git a/wannadb/data/data.py b/wannadb/data/data.py
index 66c5e69a..c87e25c1 100644
--- a/wannadb/data/data.py
+++ b/wannadb/data/data.py
@@ -6,7 +6,8 @@
import bson
from wannadb.data import signals
-from wannadb.data.signals import BaseSignal, ValueSignal
+from wannadb.data.signals import BaseSignal, ValueSignal, TextEmbeddingSignal, CurrentMatchIndexSignal
+from wannadb.utils import embeddings_equal, get_possible_duplicate
logger: logging.Logger = logging.getLogger(__name__)
@@ -131,6 +132,19 @@ def __setitem__(self, key: Union[str, Type[BaseSignal]], value: Union[BaseSignal
else: # signal not already set and value is not a signal object ==> get signal class by id and create object
self._signals[signal_identifier] = signals.SIGNALS[signal_identifier](value)
+ def duplicates(self, other) -> bool:
+ if not isinstance(other, InformationNugget):
+ return False
+
+ if (TextEmbeddingSignal.identifier not in self._signals or
+ TextEmbeddingSignal.identifier not in other._signals):
+ return False
+
+ return (self.document.name == other.document.name and
+ self._start_char == other._start_char and
+ self._end_char == other._end_char and
+ embeddings_equal(self[TextEmbeddingSignal], other[TextEmbeddingSignal]))
+
class Attribute:
"""
@@ -150,6 +164,7 @@ def __init__(self, name: str) -> None:
:param name: name of the attribute (must be unique in the document base)
"""
self._name: str = name
+ self._confirmed_matches: List[InformationNugget] = []
self._signals: Dict[str, BaseSignal] = {}
@@ -170,6 +185,11 @@ def name(self) -> str:
"""Name of the attribute."""
return self._name
+ @property
+ def confirmed_matches(self) -> List[InformationNugget]:
+ """All nuggets that have been explicitly confirmed as match for this attribute by the user."""
+ return self._confirmed_matches
+
@property
def signals(self) -> Dict[str, BaseSignal]:
"""Signals associated with the attribute."""
@@ -269,6 +289,10 @@ def nuggets(self) -> List[InformationNugget]:
"""Nuggets obtained from the document."""
return self._nuggets
+ @nuggets.setter
+ def nuggets(self, nuggets: List[InformationNugget]) -> None:
+ self._nuggets = nuggets
+
@property
def attribute_mappings(self) -> Dict[str, List[InformationNugget]]:
"""Mappings between attribute names and lists of nuggets associated with them."""
@@ -292,7 +316,8 @@ def sentences(self) -> List[str]:
def sentence(self, idx: int) -> tuple[int, int, str]:
"""Sentence of the document at the given index."""
start_char = self['SentenceStartCharsSignal'][idx]
- end_char = self['SentenceStartCharsSignal'][idx + 1] if idx + 1 < len(self['SentenceStartCharsSignal']) else len(self.text)
+ end_char = self['SentenceStartCharsSignal'][idx + 1] if idx + 1 < len(
+ self['SentenceStartCharsSignal']) else len(self.text)
return start_char, end_char, self.text[start_char:end_char]
def __getitem__(self, item: Union[str, Type[BaseSignal]]) -> Any:
@@ -679,6 +704,8 @@ def from_bson(cls, bson_bytes: bytes) -> "DocumentBase":
# deserialize the document
document: Document = Document(name=serialized_document["name"], text=serialized_document["text"])
+ old_to_new_index: Dict[int, int] = dict()
+ old_index = 0
for serialized_nugget in serialized_document["nuggets"]:
# deserialize the nugget
nugget: InformationNugget = InformationNugget(
@@ -692,15 +719,28 @@ def from_bson(cls, bson_bytes: bytes) -> "DocumentBase":
signal: BaseSignal = BaseSignal.from_serializable(serialized_signal, signal_identifier)
nugget.signals[signal_identifier] = signal
- document.nuggets.append(nugget)
+ possible_duplicate, idx = get_possible_duplicate(nugget, document.nuggets)
+ if possible_duplicate is None:
+ old_to_new_index[old_index] = len(document.nuggets)
+ document.nuggets.append(nugget)
+ else:
+ old_to_new_index[old_index] = idx
+
+ old_index += 1
# deserialize the attribute mappings
for name, indices in serialized_document["attribute_mappings"].items():
- document.attribute_mappings[name] = [document.nuggets[idx] for idx in indices]
+ document.attribute_mappings[name] = [document.nuggets[old_to_new_index[idx]] for idx in indices]
# deserialize the signals
for signal_identifier, serialized_signal in serialized_document["signals"].items():
signal: BaseSignal = BaseSignal.from_serializable(serialized_signal, signal_identifier)
+ if signal_identifier == CurrentMatchIndexSignal.identifier:
+ # As this an index related signal and the original indices changed due to removing duplicated
+ # nuggets, we need to adapt this signals value to the new indices
+ old_index = signal.value
+ signal = CurrentMatchIndexSignal(old_to_new_index[old_index])
+
document.signals[signal_identifier] = signal
document_base.documents.append(document)
diff --git a/wannadb/data/signals.py b/wannadb/data/signals.py
index 4afdcedf..ea8607b1 100644
--- a/wannadb/data/signals.py
+++ b/wannadb/data/signals.py
@@ -356,6 +356,20 @@ class LabelEmbeddingSignal(BaseNumpyArraySignal):
do_serialize: bool = True
+@register_signal
+class PCADimensionReducedLabelEmbeddingSignal(BaseNumpyArraySignal):
+ """Embedding of the nugget's label or attribute's name reduced to 3 dimensions."""
+ identifier: str = "DimensionReducedLabelEmbeddingSignal"
+ do_serialize: bool = True
+
+
+@register_signal
+class TSNEDimensionReducedLabelEmbeddingSignal(BaseNumpyArraySignal):
+ """Embedding of the nugget's label or attribute's name reduced to 3 dimensions."""
+ identifier: str = "TSNEDimensionReducedLabelEmbeddingSignal"
+ do_serialize: bool = True
+
+
@register_signal
class TextEmbeddingSignal(BaseNumpyArraySignal):
"""Embedding of the nugget's text."""
@@ -363,6 +377,20 @@ class TextEmbeddingSignal(BaseNumpyArraySignal):
do_serialize: bool = True
+@register_signal
+class PCADimensionReducedTextEmbeddingSignal(BaseNumpyArraySignal):
+ """Embedding of the nugget's text reduced to 3 dimensions."""
+ identifier: str = "PCADimensionReducedTextEmbeddingSignal"
+ do_serialize: bool = True
+
+
+@register_signal
+class TSNEDimensionReducedTextEmbeddingSignal(BaseNumpyArraySignal):
+ """Embedding of the nugget's text reduced to 3 dimensions."""
+ identifier: str = "TSNEDimensionReducedTextEmbeddingSignal"
+ do_serialize: bool = True
+
+
@register_signal
class ContextSentenceEmbeddingSignal(BaseNumpyArraySignal):
"""Embedding of the nugget's textual context sentence."""
@@ -375,3 +403,10 @@ class DocumentSentenceEmbeddingSignal(BaseNumpyArraySignal):
"""Embedding of the sentences of a document."""
identifier: str = "DocumentSentenceEmbeddingSignal"
do_serialize: bool = True
+
+
+@register_signal
+class CurrentThresholdSignal(BaseFloatSignal):
+ """Current threshold associated with an attribute."""
+ identifier: str = "CurrentThresholdSignal"
+ do_serialize: bool = True
diff --git a/wannadb/matching/matching.py b/wannadb/matching/matching.py
index 289fdfed..0fe2df05 100644
--- a/wannadb/matching/matching.py
+++ b/wannadb/matching/matching.py
@@ -2,19 +2,21 @@
import logging
import random
import time
-from typing import Any, Dict, List, Callable, Tuple, Counter
+from typing import Any, Dict, List, Tuple, Counter, Optional
import numpy as np
from wannadb.configuration import BasePipelineElement, register_configurable_element, Pipeline
from wannadb.data.data import Document, DocumentBase, InformationNugget
from wannadb.data.signals import CachedContextSentenceSignal, CachedDistanceSignal, \
- SentenceStartCharsSignal, CurrentMatchIndexSignal, LabelSignal, ExtractorNameSignal
+ SentenceStartCharsSignal, CurrentMatchIndexSignal, LabelSignal, ExtractorNameSignal, CurrentThresholdSignal
from wannadb.interaction import BaseInteractionCallback
from wannadb.matching.custom_match_extraction import BaseCustomMatchExtractor
from wannadb.matching.distance import BaseDistance
+from wannadb.change_captor import NewlyAddedNuggetContext, NuggetUpdatesContext, BestMatchUpdate, ThresholdPositionUpdate
from wannadb.statistics import Statistics
from wannadb.status import BaseStatusCallback
+from wannadb_ui.common import AddedReason, ThresholdPosition
logger: logging.Logger = logging.getLogger(__name__)
@@ -130,7 +132,9 @@ def _call(
logger.info(f"Matching attribute '{attribute.name}'.")
start_matching: float = time.time()
- self._max_distance = self._default_max_distance
+ self._max_distance = self._default_max_distance # Current threshold
+ self._old_max_distance = -1 # Previous threshold
+ attribute[CurrentThresholdSignal] = CurrentThresholdSignal(self._max_distance)
statistics[attribute.name]["max_distances"] = [self._max_distance]
statistics[attribute.name]["feedback_durations"] = []
if self.store_best_guesses:
@@ -177,8 +181,13 @@ def _sort_remaining_documents():
# iterative user interactions
logger.info("Execute interactive matching.")
tik: float = time.time()
+ self._old_feedback_nuggets: List[InformationNugget] = [] # All nuggets displayed in the previous feedback round
+ self._new_nugget_contexts: List[NewlyAddedNuggetContext] = [] # All nuggets newly displayed in the current feedback round
num_feedback: int = 0
continue_matching: bool = True
+ new_best_matches: Counter[str] = Counter[str]() # New best matches due the user's latest feedback
+ new_to_old_match: Dict[str, str] = {} # All new best matches mapped to the corresponding previous best match of the same document
+ old_distances: Dict[InformationNugget, float] = {} # Values of CachedDistanceSignal for all nuggets in previous feedback round
while continue_matching and num_feedback < self._max_num_feedback and remaining_documents != []:
# sort remaining documents by distance
_sort_remaining_documents()
@@ -232,8 +241,11 @@ def _sort_remaining_documents():
# Add additional documents (most uncertain)...
if self.num_bad_docs > 0 and num_nuggets_above > 0:
k = min(self.num_bad_docs, num_nuggets_above)
- selected_documents.extend(random.choices(remaining_documents[:num_nuggets_above], k=k))
+ new_docs = random.choices(remaining_documents[:num_nuggets_above], k=k)
+ selected_documents.extend(new_docs)
num_nuggets_above -= k
+ # Mark best matches of newly added docs as newly added nuggets if they weren't present in previous feedback round
+ self._update_new_nugget_contexts(new_docs, AddedReason.MOST_UNCERTAIN, old_distances)
# ... and those that recently got interesting additional extractions to the list
if self.num_recent_docs > 0 and len(docs_with_added_nuggets) > 0:
# Create a list up to double the size wanted and then sample from that instead of only taking the same most promising documents potentially over and over again
@@ -241,10 +253,18 @@ def _sort_remaining_documents():
if len(selected_docs_with_added_nuggets) > self.num_recent_docs:
selected_docs_with_added_nuggets = random.choices(selected_docs_with_added_nuggets, k=self.num_recent_docs)
selected_documents.extend(selected_docs_with_added_nuggets)
+ # Mark best matches of newly added docs as newly added nuggets if they weren't present in previous feedback round
+ self._update_new_nugget_contexts(selected_docs_with_added_nuggets,
+ AddedReason.INTERESTING_ADDITIONAL_EXTRACTION,
+ old_distances)
selected_docs_with_added_nuggets = set(selected_docs_with_added_nuggets)
# Now fill the list with documents at threshold
- selected_documents.extend(doc for doc in remaining_documents[higher_left:lower_right] if doc not in selected_docs_with_added_nuggets)
+ docs_at_threshold_to_add = [doc for doc in remaining_documents[higher_left:lower_right] if
+ doc not in selected_docs_with_added_nuggets]
+ selected_documents.extend(docs_at_threshold_to_add)
+ # Mark best matches of selected docs as newly added if they weren't present in previous feedback round
+ self._update_new_nugget_contexts(docs_at_threshold_to_add, AddedReason.AT_THRESHOLD, old_distances)
# Sort to unify the order across the different three sources
selected_documents.sort(key=lambda x: x.nuggets[x[CurrentMatchIndexSignal]][CachedDistanceSignal], reverse=True)
@@ -260,23 +280,51 @@ def _sort_remaining_documents():
doc.nuggets[doc[CurrentMatchIndexSignal]] for doc in selected_documents)
)
)
+ all_guessed_nugget_matches = tuple([doc.nuggets[doc[CurrentMatchIndexSignal]] for doc in document_base.documents])
num_feedback += 1
statistics[attribute.name]["num_feedback"] += 1
+
t0 = time.time()
+
+ # Build all `BestMatchUpdate` instances based on `new_best_matches` dict
+ best_match_updates = [BestMatchUpdate(new_to_old_match[new_best_match],
+ new_best_match,
+ new_best_matches[new_best_match])
+ for new_best_match in new_best_matches.keys()]
+ # Build all `ThresholdPositionUpdate` instances based on old and new distances of all nuggets and the current and previous threshold
+ threshold_position_updates = self._compute_threshold_position_updates(document_base, old_distances)
+ # Gather all update types in `NuggetUpdatesContext` instance which is passed to UI
+ nugget_updates_context = NuggetUpdatesContext(newly_added_nugget_contexts=self._new_nugget_contexts,
+ best_match_updates=best_match_updates,
+ threshold_position_updates=threshold_position_updates)
+
feedback_result: Dict[str, Any] = interaction_callback(
self.identifier,
{
"max-distance": self._max_distance,
+ "max-distance-change": self._max_distance - self._old_max_distance if self._old_max_distance != -1 else 0,
"nuggets": feedback_nuggets,
+ "nugget-updates-context": nugget_updates_context,
+ "all-guessed-nugget-matches": all_guessed_nugget_matches,
"attribute": attribute,
"num-feedback": num_feedback,
"num-nuggets-above": num_nuggets_above,
- "num-nuggets-below": num_nuggets_below
+ "num-nuggets-below": num_nuggets_below,
+ "sampling-mode": self._sampling_mode
}
)
t1 = time.time()
statistics[attribute.name]["feedback_durations"].append(t1 - t0)
+ # Reinit all variables providing information related to previous feedback round
+ self._old_max_distance = self._max_distance
+ self._old_feedback_nuggets = feedback_nuggets
+ old_distances = {nugget: nugget[CachedDistanceSignal] for nugget in document_base.nuggets}
+ # Reset all variables providing information related to current feedback round
+ self._new_nugget_contexts.clear()
+ new_best_matches.clear()
+ new_to_old_match.clear()
+
if feedback_result["message"] == "stop-interactive-matching":
statistics[attribute.name]["stopped_matching_by_hand"] = True
continue_matching = False
@@ -309,7 +357,9 @@ def _sort_remaining_documents():
if feedback_nuggets_old_cached_distances[ix] < self._max_distance:
min_dist = min(min_dist, feedback_nuggets[ix][CachedDistanceSignal])
if min_dist < self._max_distance:
+ self._old_max_distance = self._max_distance
self._max_distance = min_dist
+ attribute[CurrentThresholdSignal] = CurrentThresholdSignal(min_dist)
statistics[attribute.name]["max_distances"].append(min_dist)
logger.info(f"NO MATCH IN DOCUMENT: Decreased the maximum distance to "
f"{self._max_distance}.")
@@ -357,6 +407,9 @@ def run_nugget_pipeline(nuggets):
feedback_result["document"].attribute_mappings[attribute.name] = [confirmed_nugget]
remaining_documents.remove(feedback_result["document"])
+ # add this nugget as a confirmed match to the corresponding attribute
+ attribute.confirmed_matches.append(confirmed_nugget)
+
# update the distances for the other documents
for document in remaining_documents:
new_distances: np.ndarray = self._distance.compute_distances(
@@ -367,10 +420,18 @@ def run_nugget_pipeline(nuggets):
for nugget, new_distance in zip(document.nuggets, new_distances):
if distances_based_on_label or new_distance < nugget[CachedDistanceSignal]:
nugget[CachedDistanceSignal] = new_distance
+
+ previous_best_match: InformationNugget = document.nuggets[document[CurrentMatchIndexSignal]] # Save previous best match
for ix, nugget in enumerate(document.nuggets):
current_guess: InformationNugget = document.nuggets[document[CurrentMatchIndexSignal]]
if nugget[CachedDistanceSignal] < current_guess[CachedDistanceSignal]:
document[CurrentMatchIndexSignal] = ix
+ new_best_match: InformationNugget = document.nuggets[document[CurrentMatchIndexSignal]]
+ # If there's new best match, save it and add mapping to previous best match for later use
+ if previous_best_match != new_best_match:
+ new_best_matches.update([new_best_match.text])
+ new_to_old_match[new_best_match.text] = previous_best_match.text
+
distances_based_on_label = False
# Find more nuggets that are similar to this match
@@ -422,6 +483,9 @@ def run_nugget_pipeline(nuggets):
nugget.document[CurrentMatchIndexSignal] = nugget.document.nuggets.index(nugget)
docs_with_added_nuggets[nugget.document] = distance_difference
logger.info(f"Found nugget better than current best guess for document {nugget.document.name} with distance difference {distance_difference}.")
+ old_distances[nugget] = nugget[CachedDistanceSignal]
+
+ old_distances[confirmed_nugget] = confirmed_nugget[CachedDistanceSignal]
elif feedback_result["message"] == "is-match":
statistics[attribute.name]["num_confirmed_match"] += 1
@@ -438,20 +502,30 @@ def run_nugget_pipeline(nuggets):
if doc in docs_with_added_nuggets:
docs_with_added_nuggets.pop(doc)
+ # add this nugget as a confirmed match to the corresponding attribute
+ attribute.confirmed_matches.append(feedback_result["nugget"])
+
# update the distances for the other documents
for document in remaining_documents:
new_distances: np.ndarray = self._distance.compute_distances(
[feedback_result["nugget"]],
document.nuggets,
statistics["distance"]
- )[0]
+ ) [0]
for nugget, new_distance in zip(document.nuggets, new_distances):
if distances_based_on_label or new_distance < nugget[CachedDistanceSignal]:
nugget[CachedDistanceSignal] = new_distance
+
+ previous_best_match: InformationNugget = document.nuggets[document[CurrentMatchIndexSignal]] # Save previous best match
for ix, nugget in enumerate(document.nuggets):
current_guess: InformationNugget = document.nuggets[document[CurrentMatchIndexSignal]]
if nugget[CachedDistanceSignal] < current_guess[CachedDistanceSignal]:
document[CurrentMatchIndexSignal] = ix
+ new_best_match: InformationNugget = document.nuggets[document[CurrentMatchIndexSignal]]
+ # If there's new best match, save it and add mapping to previous best match for later use
+ if previous_best_match != new_best_match:
+ new_best_matches.update([new_best_match.text])
+ new_to_old_match[new_best_match.text] = previous_best_match.text
distances_based_on_label = False
if self._adjust_threshold:
@@ -473,7 +547,9 @@ def run_nugget_pipeline(nuggets):
if feedback_nuggets_old_cached_distances[ix] > self._max_distance:
max_dist = max(max_dist, feedback_nuggets[ix][CachedDistanceSignal])
if max_dist > self._max_distance:
+ self._old_max_distance = self._max_distance
self._max_distance = max_dist
+ attribute[CurrentThresholdSignal] = CurrentThresholdSignal(max_dist)
statistics[attribute.name]["max_distances"].append(max_dist)
logger.info(f"CONFIRMED NUGGET FROM RANKED LIST: Increased the maximum distance"
f"to {self._max_distance}.")
@@ -527,6 +603,84 @@ def run_nugget_pipeline(nuggets):
statistics[attribute.name]["runtime"] = tak - start_matching
+ def _update_new_nugget_contexts(self, new_docs: List[Document], added_reason: AddedReason,
+ old_distances: Dict[InformationNugget, float]):
+ # Computes the newly added nuggets in this feedback round and creates the corresponding instances wrapping these updates
+ # To determine whether a nugget is newly added, the method considers the `_old_feedback_nuggets` list
+
+ best_matches: List[InformationNugget] = [new_doc.nuggets[new_doc[CurrentMatchIndexSignal]] for new_doc in
+ new_docs]
+
+ self._new_nugget_contexts.extend([NewlyAddedNuggetContext(nugget,
+ old_distances[nugget] if nugget in old_distances else None,
+ nugget[CachedDistanceSignal],
+ added_reason)
+ for nugget in best_matches if nugget not in self._old_feedback_nuggets])
+
+ def _compute_threshold_position_updates(self, document_base, old_distances):
+ # Computes all threshold position updates of the current feedback round based on the old and new distances of the nuggets as well as the old and new threshold
+ threshold_position_updates: Dict[str, Tuple[ThresholdPositionUpdate, Optional[ThresholdPositionUpdate]]] = dict()
+
+ for nugget in document_base.nuggets:
+ # We only care about nuggets representing a current best guesses
+ is_best_guess = nugget.document.nuggets[nugget.document[CurrentMatchIndexSignal]].text == nugget.text
+ if not is_best_guess:
+ continue
+
+ # Since we map the nuggets text to the corresponding update and there can be nuggets with equal texts, there can already be updates created for the current nugget's text
+ old_update = threshold_position_updates[nugget.text][0] if nugget.text in threshold_position_updates else None
+
+ # Compute old and new threshold position of the current nugget
+ if self._old_max_distance == -1:
+ old_threshold_position = None
+ else:
+ old_threshold_position = ThresholdPosition.ABOVE if old_distances[nugget] > self._old_max_distance \
+ else ThresholdPosition.BELOW
+ new_threshold_position = ThresholdPosition.ABOVE if nugget[CachedDistanceSignal] > self._max_distance \
+ else ThresholdPosition.BELOW
+
+ # Create update instances if old and new position differ
+ if old_threshold_position != new_threshold_position:
+ # If there's already a similar update created for the text of the current nugget, replace it by new one and increment its counter by one
+ if (old_update is not None and
+ old_update.old_position == old_threshold_position and
+ old_update.new_position == new_threshold_position):
+ threshold_position_updates[nugget.text] = (ThresholdPositionUpdate(nugget.text,
+ old_threshold_position,
+ new_threshold_position,
+ old_distances[nugget] if nugget in old_distances else None,
+ nugget[CachedDistanceSignal],
+ old_update.count + 1),
+ None)
+ # If there's already an update present whose type (above -> below / below -> above) is different, create new update and keep old one
+ elif old_update is not None:
+ threshold_position_updates[nugget.text] = (old_update,
+ ThresholdPositionUpdate(nugget.text,
+ old_threshold_position,
+ new_threshold_position,
+ old_distances[nugget] if nugget in old_distances else None,
+ nugget[CachedDistanceSignal],
+ 1))
+ # If there's no update present for the text of the current nugget, just create new one
+ else:
+ threshold_position_updates[nugget.text] = (ThresholdPositionUpdate(nugget.text,
+ old_threshold_position,
+ new_threshold_position,
+ old_distances[nugget] if nugget in old_distances else None,
+ nugget[CachedDistanceSignal],
+ 1),
+ None)
+
+ # Create final result by concatenating all created updates
+ result = []
+ for first_update, second_update in threshold_position_updates.values():
+ if second_update is None:
+ result.append(first_update)
+ else:
+ result.extend([first_update, second_update])
+
+ return result
+
def to_config(self) -> Dict[str, Any]:
return {
"identifier": self.identifier,
diff --git a/wannadb/preprocessing/dimension_reduction.py b/wannadb/preprocessing/dimension_reduction.py
new file mode 100644
index 00000000..b449a439
--- /dev/null
+++ b/wannadb/preprocessing/dimension_reduction.py
@@ -0,0 +1,117 @@
+import logging
+from abc import ABC
+from typing import Dict, Any
+
+from numpy import ndarray
+from sklearn.decomposition import PCA
+from sklearn.manifold import TSNE
+
+from wannadb.configuration import BasePipelineElement, register_configurable_element
+from wannadb.data.data import DocumentBase
+from wannadb.data.signals import LabelEmbeddingSignal, TextEmbeddingSignal, PCADimensionReducedLabelEmbeddingSignal, \
+ PCADimensionReducedTextEmbeddingSignal, TSNEDimensionReducedLabelEmbeddingSignal, \
+ TSNEDimensionReducedTextEmbeddingSignal
+from wannadb.interaction import BaseInteractionCallback
+from wannadb.statistics import Statistics
+from wannadb.status import BaseStatusCallback
+
+logger = logging.getLogger(__name__)
+
+
+class DimensionReducer(BasePipelineElement, ABC):
+ identifier: str = "DimensionReducer"
+
+ def __init__(self):
+ super(DimensionReducer, self).__init__()
+
+ def _call(self, document_base: DocumentBase, interaction_callback: BaseInteractionCallback,
+ status_callback: BaseStatusCallback, statistics: Statistics) -> None:
+ pass
+
+ def reduce_dimensions(self, data) -> ndarray:
+ pass
+
+
+@register_configurable_element
+class PCAReducer(DimensionReducer):
+ identifier: str = "PCAReducer"
+
+ def __init__(self):
+ super().__init__()
+ self.pca = PCA(n_components=3)
+
+ def __call__(
+ self,
+ document_base: DocumentBase,
+ interaction_callback: BaseInteractionCallback,
+ status_callback: BaseStatusCallback,
+ statistics: Statistics
+ ) -> None:
+ # Assume that all embeddings have same number of features
+ attribute_embeddings = [attribute[LabelEmbeddingSignal] for attribute in document_base.attributes]
+ nugget_embeddings = [nugget[TextEmbeddingSignal] for nugget in document_base.nuggets]
+ all_embeddings = attribute_embeddings + nugget_embeddings
+
+ if len(all_embeddings) < 3:
+ logger.warning("Not enough data to apply dimension reduction, will not compute them.")
+ return
+
+ dimension_reduced_embeddings = self.reduce_dimensions(all_embeddings)
+
+ for idx, embedding in enumerate(dimension_reduced_embeddings):
+ if idx < len(attribute_embeddings):
+ document_base.attributes[idx][PCADimensionReducedLabelEmbeddingSignal] = (
+ PCADimensionReducedLabelEmbeddingSignal(embedding))
+ else:
+ document_base.nuggets[idx - len(attribute_embeddings)][PCADimensionReducedTextEmbeddingSignal] = (
+ PCADimensionReducedTextEmbeddingSignal(embedding))
+
+ def reduce_dimensions(self, data) -> ndarray:
+ self.pca.fit(data)
+ return self.pca.transform(data)
+
+ def to_config(self) -> Dict[str, Any]:
+ return {
+ "identifier": self.identifier
+ }
+
+ @classmethod
+ def from_config(cls, config: Dict[str, Any]) -> "DimensionReducer":
+ return cls()
+
+
+@register_configurable_element
+class TSNEReducer(DimensionReducer):
+ identifier: str = "TSNEReducer"
+
+ def __init__(self):
+ super().__init__()
+ self.tsne = TSNE(n_components=3, n_iter=300)
+
+ def _call(self, document_base: DocumentBase, interaction_callback: BaseInteractionCallback,
+ status_callback: BaseStatusCallback, statistics: Statistics) -> None:
+ attribute_embeddings = [attribute[LabelEmbeddingSignal] for attribute in document_base.attributes]
+ nugget_embeddings = [nugget[TextEmbeddingSignal] for nugget in document_base.nuggets]
+ all_embeddings = attribute_embeddings + nugget_embeddings
+
+ dimension_reduced_embeddings = self.reduce_dimensions(all_embeddings)
+
+ for idx, embedding in enumerate(dimension_reduced_embeddings):
+ if idx < len(attribute_embeddings):
+ document_base.attributes[idx][TSNEDimensionReducedLabelEmbeddingSignal] = (
+ TSNEDimensionReducedLabelEmbeddingSignal(embedding))
+ else:
+ document_base.nuggets[idx - len(attribute_embeddings)][TSNEDimensionReducedTextEmbeddingSignal] = (
+ TSNEDimensionReducedTextEmbeddingSignal(embedding))
+
+ def reduce_dimensions(self, data) -> ndarray:
+ return self.tsne.fit_transform(data)
+
+ def to_config(self) -> Dict[str, Any]:
+ return {
+ "identifier": self.identifier
+ }
+
+ @classmethod
+ def from_config(cls, config: Dict[str, Any]) -> "DimensionReducer":
+ return cls()
diff --git a/wannadb/preprocessing/embedding.py b/wannadb/preprocessing/embedding.py
index 2f0bd6c7..ecef72b6 100644
--- a/wannadb/preprocessing/embedding.py
+++ b/wannadb/preprocessing/embedding.py
@@ -80,7 +80,8 @@ def _call(
self._embed_documents(document_base, interaction_callback, status_callback, statistics["documents"])
status_callback(f"Embedding documents with {self.identifier}...", 1)
tack: float = time.time()
- logger.info(f"Embedded {len(document_base.documents)} documents with {self.identifier} in {tack - tick} seconds.")
+ logger.info(
+ f"Embedded {len(document_base.documents)} documents with {self.identifier} in {tack - tick} seconds.")
statistics["documents"]["runtime"] = tack - tick
# compute embeddings for the nuggets
@@ -94,13 +95,18 @@ def _call(
if self.generated_signal_identifiers["nuggets"][0] in nuggets[0].signals.keys():
# Try to determine if the dimensions are correct (should match those of the embedding of the attributes)
if len(self.generated_signal_identifiers["attributes"]) > 0:
- if len(attributes) > 0 and attributes[0].signals[self.generated_signal_identifiers["attributes"][0]].value.shape == nuggets[0].signals[self.generated_signal_identifiers["attributes"][0]].value.shape:
- logger.info(f"No need to embedd nuggets again with {self.identifier}, existing embeddings with correct dimensions found.")
+ if len(attributes) > 0 and attributes[0].signals[
+ self.generated_signal_identifiers["attributes"][0]].value.shape == nuggets[0].signals[
+ self.generated_signal_identifiers["attributes"][0]].value.shape:
+ logger.info(
+ f"No need to embedd nuggets again with {self.identifier}, existing embeddings with correct dimensions found.")
return
- logger.info(f"Dimension missmatch, recomputing embeddings for {self.generated_signal_identifiers['nuggets'][0]} with {self.identifier}.")
+ logger.info(
+ f"Dimension missmatch, recomputing embeddings for {self.generated_signal_identifiers['nuggets'][0]} with {self.identifier}.")
else:
# Cannot check dimensions, but assuming they are correct do to lack of other evidence
- logger.info(f"Found existing embeddings for {self.generated_signal_identifiers['nuggets'][0]}, assuming they were created with {self.identifier} (even though dimension check is not possible.")
+ logger.info(
+ f"Found existing embeddings for {self.generated_signal_identifiers['nuggets'][0]}, assuming they were created with {self.identifier} (even though dimension check is not possible.")
return
# If no existing embeddings are found, or dimensions are not matching continue with embedding
@@ -165,6 +171,7 @@ def _embed_documents(
"""
pass # default behavior: do nothing
+
########################################################################################################################
# actual embedders
########################################################################################################################
@@ -513,13 +520,16 @@ def get_candidate_contexts(context_sentence, start_in_context, end_in_context):
"""
prev_candidate_context = None
for candidate_start, candidate_end in zip(map(lambda i: max(0, i), count(start_in_context, -1)),
- map(lambda i: min(i, len(context_sentence) - 1), count(end_in_context, 1))):
+ map(lambda i: min(i, len(context_sentence) - 1),
+ count(end_in_context, 1))):
candidate_context = context_sentence[candidate_start:candidate_end]
yield prev_candidate_context, candidate_context
prev_candidate_context = candidate_context
- for prev, candidate_context in get_candidate_contexts(context_sentence, start_in_context, end_in_context):
- input_ids, token_type_ids, attention_mask, char_to_token = get_encoding_data(candidate_context, device)
+ for prev, candidate_context in get_candidate_contexts(context_sentence, start_in_context,
+ end_in_context):
+ input_ids, token_type_ids, attention_mask, char_to_token = get_encoding_data(candidate_context,
+ device)
# The condition will be true at some point
# because token_type_ids[0] is monotonically increasing with longer context sentences
# and we know that the whole sentence is above the limit
@@ -533,16 +543,19 @@ def get_candidate_contexts(context_sentence, start_in_context, end_in_context):
logger.error(error)
raise RuntimeError(error)
context_sentence = prev
- input_ids, token_type_ids, attention_mask, char_to_token = get_encoding_data(context_sentence, device)
+ input_ids, token_type_ids, attention_mask, char_to_token = get_encoding_data(
+ context_sentence, device)
break
- logger.error(f"==> Using shorter context sentence '{context_sentence}' with {len(token_type_ids[0])} token indices "
- f"for nugget '{context_sentence[start_in_context:end_in_context]}'.")
+ logger.error(
+ f"==> Using shorter context sentence '{context_sentence}' with {len(token_type_ids[0])} token indices "
+ f"for nugget '{context_sentence[start_in_context:end_in_context]}'.")
return input_ids, token_type_ids, attention_mask, char_to_token, context_sentence
- input_ids, token_type_ids, attention_mask, char_to_token, context_sentence = get_encoding_data_with_limited_tokens_for_context(context_sentence,
- start_in_context,
- end_in_context)
+ input_ids, token_type_ids, attention_mask, char_to_token, context_sentence = get_encoding_data_with_limited_tokens_for_context(
+ context_sentence,
+ start_in_context,
+ end_in_context)
outputs = resources.MANAGER[self._bert_resource_identifier]["model"](
input_ids=input_ids,
diff --git a/wannadb/preprocessing/other_processing.py b/wannadb/preprocessing/other_processing.py
index cddb6ebc..53781347 100644
--- a/wannadb/preprocessing/other_processing.py
+++ b/wannadb/preprocessing/other_processing.py
@@ -4,10 +4,11 @@
from wannadb.configuration import BasePipelineElement, register_configurable_element
from wannadb.data.data import DocumentBase, InformationNugget
from wannadb.data.signals import CachedContextSentenceSignal, \
- SentenceStartCharsSignal
+ SentenceStartCharsSignal, TextEmbeddingSignal, CurrentMatchIndexSignal
from wannadb.interaction import BaseInteractionCallback
from wannadb.statistics import Statistics
from wannadb.status import BaseStatusCallback
+from wannadb.utils import get_possible_duplicate
logger: logging.Logger = logging.getLogger(__name__)
@@ -80,3 +81,63 @@ def to_config(self) -> Dict[str, Any]:
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ContextSentenceCacher":
return cls()
+
+
+@register_configurable_element
+class DuplicatedNuggetsCleaner(BasePipelineElement):
+ """
+ Removes duplicated nuggets.
+ We consider a nugget duplicating another nugget if they belong to the same document, are located at the same
+ position within the documents text and have "nearly" the same embedding. "Nearly" in this context refers to a
+ tolerance value which is required while comparing two nugget embeddings as embedding values are represented as
+ floats and therefore can't be compared for exact equality. For more details see
+ :func`~wannadb.utils.embeddings_equal`
+ """
+
+ identifier: str = "DuplicatedNuggetsCleaner"
+
+ required_signal_identifiers: Dict[str, List[str]] = {
+ "nuggets": [TextEmbeddingSignal.identifier],
+ "attributes": [],
+ "documents": []
+ }
+
+ def __init__(self):
+ """Initialize the DuplicatedNuggetsCleaner."""
+ super(DuplicatedNuggetsCleaner, self).__init__()
+ logger.debug(f"Initialized '{self.identifier}'.")
+
+ def _call(self, document_base: DocumentBase, interaction_callback: BaseInteractionCallback,
+ status_callback: BaseStatusCallback, statistics: Statistics) -> None:
+ for document in document_base.documents:
+
+ cleaned_nuggets: List[InformationNugget] = list()
+ old_to_new_index: Dict[int, int] = dict()
+ old_index = 0
+
+ for nugget in document_base.nuggets:
+ possible_duplicate, idx = get_possible_duplicate(nugget, document.nuggets)
+ if possible_duplicate is None:
+ old_to_new_index[old_index] = len(cleaned_nuggets)
+ cleaned_nuggets.append(nugget)
+ else:
+ old_to_new_index[old_index] = idx
+
+ old_index += 1
+
+ logger.info(f"Removed {len(document.nuggets) - len(cleaned_nuggets)} duplicated nuggets from document "
+ f"\"{document.name}\".")
+ document.nuggets = cleaned_nuggets
+
+ if CurrentMatchIndexSignal.identifier in document.signals:
+ old_index = document[CurrentMatchIndexSignal].value
+ document[CurrentMatchIndexSignal] = CurrentMatchIndexSignal(old_to_new_index[old_index])
+
+ def to_config(self) -> Dict[str, Any]:
+ return {
+ "identifier": self.identifier
+ }
+
+ @classmethod
+ def from_config(cls, config: Dict[str, Any]) -> "DuplicatedNuggetsCleaner":
+ return cls()
diff --git a/wannadb/resources.py b/wannadb/resources.py
index c176621a..76779d64 100644
--- a/wannadb/resources.py
+++ b/wannadb/resources.py
@@ -16,6 +16,8 @@
from stanza import Pipeline
from transformers import BertModel, BertTokenizer, BertTokenizerFast
+from wannadb_ui.study import Tracker
+
logger: logging.Logger = logging.getLogger(__name__)
RESOURCES: Dict[str, Type["BaseResource"]] = {}
@@ -104,6 +106,7 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
for resource_identifier in list(self._resources.keys()):
self.unload(resource_identifier)
+ Tracker().dump_report()
tack: float = time.time()
logger.info(f"Unloaded all resources in {tack - tick} seconds.")
logger.info("Exited the resource manager.")
diff --git a/wannadb/utils.py b/wannadb/utils.py
new file mode 100644
index 00000000..2171e599
--- /dev/null
+++ b/wannadb/utils.py
@@ -0,0 +1,83 @@
+"""
+Utility class providing common functionality.
+"""
+
+import math
+
+import numpy as np
+from PyQt6.QtGui import QColor
+
+
+def get_possible_duplicate(nugget_to_check, nugget_list):
+ """
+ Checks the given list for duplicates of the given nugget and returns the first occurring duplicate if present and
+ its index in the list.
+
+ The check whether a nugget duplicates another is realized by the `duplicates(other) -> bool` function of the
+ `InformationNugget` class.
+ """
+
+ for idx, nugget in enumerate(nugget_list):
+ if nugget_to_check.duplicates(nugget):
+ return nugget, idx
+
+ return None, None
+
+
+def positions_equal(position1: np.ndarray, position2: np.ndarray) -> bool:
+ """
+ Checks if the given arrays are equal meaning that each element of the first array is close enough to the
+ corresponding value in the second array.
+ The check for closeness is realized by `math.isclose(...)` function.
+
+ Handles only (1, 3) shaped arrays as this function should only be used for arrays representing 3-dimensional
+ positions.
+ If one of the given arrays doesn't conform to this shape, the function returns `False`.
+
+ Returns
+ -------
+ Whether the given arrays are considered as equal according to the explanation above.
+ """
+
+ if position1.shape != (1, 3) or position2.shape != (1, 3):
+ return False
+
+ return (math.isclose(position1[0][0], position2[0][0], rel_tol=1e-05, abs_tol=1e-05) and
+ math.isclose(position1[0][1], position2[0][1], rel_tol=1e-05, abs_tol=1e-05) and
+ math.isclose(position1[0][2], position2[0][2], rel_tol=1e-05, abs_tol=1e-05))
+
+
+def embeddings_equal(embedding1: np.ndarray, embedding2: np.ndarray) -> bool:
+ if embedding1.shape != embedding2.shape:
+ return False
+
+ arrays_are_close = np.vectorize(math.isclose)
+ return arrays_are_close(embedding1, embedding2, rel_tol=1e-05, abs_tol=1e-05).all()
+
+
+class AccessibleColor:
+ """
+ Utility model class wrapping a color and its corresponding accessible color that is better understandable by users
+ suffering from color blindness.
+ """
+
+ def __init__(self, color: QColor, corresponding_accessible_color: QColor):
+ """
+ Parameters
+ ----------
+ color: QColor
+ Color represented by this instance.
+ corresponding_accessible_color: QColor
+ Accessible color corresponding to the given standard version of the color.
+ """
+
+ self._color = color
+ self._corresponding_accessible_color = corresponding_accessible_color
+
+ @property
+ def color(self):
+ return self._color
+
+ @property
+ def corresponding_accessible_color(self):
+ return self._corresponding_accessible_color
diff --git a/wannadb_ui/common.py b/wannadb_ui/common.py
index 622d69eb..3933360c 100644
--- a/wannadb_ui/common.py
+++ b/wannadb_ui/common.py
@@ -1,8 +1,16 @@
-import abc
+import os
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Union, List, Optional, Tuple
-from PyQt6.QtCore import Qt
-from PyQt6.QtGui import QFont
-from PyQt6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea, QFrame, QHBoxLayout, QDialog, QPushButton
+import markdown
+import pyqtgraph
+from PyQt6.QtCore import Qt, QPoint
+from PyQt6.QtGui import QFont, QPixmap, QPainter, QColor
+from PyQt6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea, QFrame, QHBoxLayout, QDialog, QPushButton, \
+ QMainWindow, QTextEdit
+
+from wannadb.data.data import InformationNugget
# fonts
HEADER_FONT = QFont("Segoe UI", pointSize=20, weight=QFont.Weight.Bold)
@@ -17,6 +25,7 @@
STATUS_BAR_FONT = QFont("Segoe UI", pointSize=11)
STATUS_BAR_FONT_BOLD = QFont("Segoe UI", pointSize=11, weight=QFont.Weight.Bold)
BUTTON_FONT = QFont("Segoe UI", pointSize=11)
+BUTTON_FONT_SMALL = QFont("Segoe UI", pointSize=9)
# colors
WHITE = "#FFFFFF"
@@ -36,6 +45,74 @@
INPUT_DOCS_COLUMN_NAME = "input_document"
+class ThresholdPosition(Enum):
+ ABOVE = 1
+ BELOW = 2
+
+
+class AvailableVisualizationsLevel(Enum):
+ DISABLED = 0
+ LEVEL_1 = 1
+ LEVEL_2 = 2
+
+
+class NuggetUpdateType(Enum):
+ NEWLY_ADDED = 1
+ THRESHOLD_POSITION_UPDATE = 2
+ BEST_MATCH_UPDATE = 3
+
+
+class AddedReason(Enum):
+ """
+ Corresponds to the reason why the framework decided to newly add a nugget to the overview list.
+ """
+
+ MOST_UNCERTAIN = "The documents match belongs to the considered most uncertain matches."
+ INTERESTING_ADDITIONAL_EXTRACTION = "The document recently got interesting additional extraction to the list."
+ AT_THRESHOLD = "The distance of the guessed match is within the considered range around the threshold."
+
+ def __init__(self, corresponding_tooltip_text: str):
+ self._corresponding_tooltip_text = corresponding_tooltip_text
+
+ @property
+ def corresponding_tooltip_text(self):
+ return self._corresponding_tooltip_text
+
+
+class VisualizationProvidingItem:
+ """
+ Abstract class identifying UI items which provide any kind of visualization and therefore requires adapting if the
+ enabled visualization level changes.
+ Forces classes inheriting from this class to implement a method `_adapt_to_visualizations_level(visualizations_level)`
+ adapting the UI element according to the currently enabled visualization level.
+
+ Methods
+ -------
+ update_shown_visualizations(visualization_level: AvailableVisualizationsLevel)
+ Adapts the corresponding UI element to the given visualization level as each level allows different
+ visualization components to be enabled.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def update_shown_visualizations(self, visualization_level: AvailableVisualizationsLevel):
+ """
+ Adapts the corresponding UI element to the given visualization level as each level allows different
+ visualization components to be enabled.
+
+ Parameters
+ ----------
+ visualization_level: AvailableVisualizationsLevel
+ Visualization level to which the UI component needs to be adapted.
+ """
+ self._adapt_to_visualizations_level(visualization_level)
+
+ @abstractmethod
+ def _adapt_to_visualizations_level(self, visualizations_level):
+ pass
+
+
class MainWindowContent(QWidget):
def __init__(self, main_window, header_text):
@@ -61,11 +138,11 @@ def __init__(self, main_window, header_text):
self.controls_widget_layout.setContentsMargins(0, 0, 0, 0)
self.top_widget_layout.addWidget(self.controls_widget, alignment=Qt.AlignmentFlag.AlignRight)
- @abc.abstractmethod
+ @abstractmethod
def enable_input(self):
raise NotImplementedError
- @abc.abstractmethod
+ @abstractmethod
def disable_input(self):
raise NotImplementedError
@@ -135,7 +212,7 @@ def update_item_list(self, item_list, params=None):
# make sure that there are enough item widgets
while len(item_list) > len(self.item_widgets):
- self.item_widgets.append(self.item_type(self.parent))
+ self.item_widgets.append(self._create_new_widget())
# make sure that the correct number of item widgets is shown
while len(item_list) > self.num_visible_item_widgets:
@@ -165,6 +242,31 @@ def disable_input(self):
for item_widget in self.item_widgets:
item_widget.disable_input()
+ def _create_new_widget(self):
+ return self.item_type(self.parent)
+
+
+class VisualizationProvidingCustomScrollableList(CustomScrollableList, VisualizationProvidingItem):
+ """
+ Class realizing a `CustomScrollableList` providing visualizations via the items in the list by inheriting from
+ `CustomScrollableList` and `VisualizationProvidingItem`.
+ """
+
+ def __init__(self, parent, item_type, visualizations_level, attach_visualization_level_observer,
+ floating_widget=None, orientation="vertical", above_widget=None):
+ super().__init__(parent, item_type, floating_widget, orientation, above_widget)
+
+ self.visualizations_level = visualizations_level
+ self.attach_visualization_level_observer = attach_visualization_level_observer
+
+ def _create_new_widget(self):
+ new_widget = self.item_type(self.parent, self.visualizations_level)
+ self.attach_visualization_level_observer(new_widget)
+ return new_widget
+
+ def _adapt_to_visualizations_level(self, visualizations_level):
+ self.visualizations_level = visualizations_level
+
class CustomScrollableListItem(QFrame):
@@ -172,15 +274,15 @@ def __init__(self, parent):
super(CustomScrollableListItem, self).__init__()
self.parent = parent
- @abc.abstractmethod
+ @abstractmethod
def update_item(self, item, params=None):
raise NotImplementedError
- @abc.abstractmethod
+ @abstractmethod
def enable_input(self):
raise NotImplementedError
- @abc.abstractmethod
+ @abstractmethod
def disable_input(self):
raise NotImplementedError
@@ -210,3 +312,192 @@ def show_confirmation_dialog(parent, title_text, explanation_text, accept_text,
no_button.setFocus()
return dialog.exec()
+
+
+class InformationPopup(QMainWindow):
+ """
+ Realizes an information popup as a separate window.
+
+ The content to be displayed within the window is determined by a markdown file.
+ """
+
+ def __init__(self, title: str, content_file_to_display: str):
+ """
+ Initializes an instance of this class by reading the given markdown file and render the corresponding content
+ in the window.
+
+ Parameters
+ ----------
+ title: str
+ The title of the popup.
+ content_file_to_display: str
+ The path to the markdown file determining the content to be displayed.
+ """
+
+ super().__init__()
+
+ # Init widget containing the HTML content defined in the given markdown file
+ self._text_widget = QTextEdit()
+
+ # Read markdown file
+ with open(content_file_to_display, "r") as file:
+ formatted_text = file.read()
+ markdown_result = markdown.markdown(formatted_text)
+
+ self._text_widget.setHtml(markdown_result)
+
+ self.setCentralWidget(self._text_widget)
+
+ self.setWindowTitle(title)
+ self.resize(1000, 700)
+
+
+class InfoDialog(QDialog):
+ """
+ Realizes an information dialog in form of pop-ups.
+
+ The user can click through the dialog via navigation buttons or skip the whole dialog directly.
+
+ Methods
+ -------
+ set_info_list(info_list)
+ Sets the information to display within this dialog.
+ set_image_list(image_list)
+ Sets the images to display within the dialogs.
+ exec()
+ Overwrite `exec()` to avoid multiple executions.
+ """
+
+ def __init__(self):
+ """
+ Initializes the required UI components.
+ The UI consists of the raw information text, an image serving as illustration as well as navigation buttons to
+ open next/previous screen or skip the dialog.
+ """
+ super().__init__()
+
+ self.dialog_shown: bool = False
+
+ self.info_list = None
+ self.image_list = None
+ self.current_index = 0
+ self.base_dir = "" # Base directory to prepend to image paths
+
+ # Set up the dialog layout
+ self.layout = QVBoxLayout()
+
+ # Set a fixed width for the dialog
+ self.setFixedWidth(600)
+
+ # Text edit to display the information text (supports HTML)
+ self.info_text = QTextEdit()
+ self.info_text.setReadOnly(True) # Make sure the text is not editable
+ self.layout.addWidget(self.info_text)
+
+ # Buttons for navigation (Previous, Next, Skip)
+ self.button_layout = QHBoxLayout()
+
+ self.prev_button = QPushButton("Previous")
+ self.prev_button.clicked.connect(self._show_previous)
+ self.button_layout.addWidget(self.prev_button)
+
+ self.next_button = QPushButton("Next")
+ self.next_button.clicked.connect(self._show_next)
+ self.button_layout.addWidget(self.next_button)
+
+ self.skip_button = QPushButton("Skip")
+ self.skip_button.clicked.connect(self._skip)
+ self.button_layout.addWidget(self.skip_button)
+
+ # Add button layout to the main layout
+ self.layout.addLayout(self.button_layout)
+
+ # Set the layout for the dialog
+ self.setLayout(self.layout)
+
+ def load_markdown_file(self, file_path: str):
+ """
+ Load the markdown file and convert it to HTML.
+ """
+ self.base_dir = os.path.dirname(os.path.abspath(file_path)) # Store base directory of the markdown file
+
+ with open(file_path, 'r', encoding='utf-8') as file:
+ markdown_content = file.read()
+ # Use custom delimiter to split sections instead of
+ sections = markdown_content.split('')
+ # Convert each section to HTML and store in info_list
+ self.info_list = [markdown.markdown(section) for section in sections]
+
+ self._update_info()
+
+ def _update_info(self):
+ """
+ Updates the displayed information in the QTextEdit widget.
+ Handles switching between sections.
+ """
+ if self.info_list is not None:
+ # Add base path to image sources in the HTML
+ html_with_images = self._add_base_path_to_images(self.info_list[self.current_index])
+ self.info_text.setHtml(html_with_images)
+ self._update_buttons()
+
+ def _add_base_path_to_images(self, html: str) -> str:
+ """
+ Modify HTML content to prepend base directory to image sources.
+
+ Parameters
+ ----------
+ html: str
+ The HTML content where image paths need to be modified.
+
+ Returns
+ -------
+ str
+ The HTML content with updated image paths.
+ """
+ if self.base_dir:
+ return html.replace('src="', f'src="{self.base_dir}/')
+ return html
+
+ def _update_buttons(self):
+ """
+ Update the state of navigation buttons (enabled/disabled).
+ Controls the "Previous" and "Next" buttons based on the current index.
+ """
+ self.prev_button.setEnabled(self.current_index > 0)
+ self.next_button.setEnabled(self.current_index < len(self.info_list) - 1)
+
+ def _show_previous(self):
+ """
+ Method to show the previous section of information in the dialog.
+ Decreases the current index by one and updates the displayed content.
+ """
+ if self.current_index > 0:
+ self.current_index -= 1
+ self._update_info()
+
+ def _show_next(self):
+ """
+ Method to show the next section of information in the dialog.
+ Increases the current index by one and updates the displayed content.
+ """
+ if self.current_index < len(self.info_list) - 1:
+ self.current_index += 1
+ self._update_info()
+
+ def _skip(self):
+ """
+ Method to skip the dialog and close it.
+ """
+ self.accept()
+
+ def exec(self):
+ """
+ Overwrite `exec()` to avoid multiple executions.
+ If dialog is not shown currently, call `exec()` of superclass, else do nothing.
+
+ For more information check documentation of `exec()` in `QtWidgets` module.
+ """
+ if not self.dialog_shown:
+ super().exec()
+ self.dialog_shown = True
\ No newline at end of file
diff --git a/wannadb_ui/data_insights.py b/wannadb_ui/data_insights.py
new file mode 100644
index 00000000..c7a242ed
--- /dev/null
+++ b/wannadb_ui/data_insights.py
@@ -0,0 +1,536 @@
+"""
+Module providing logic to realize Data Insights section visible in the document overview screen.
+"""
+
+import abc
+import random
+from typing import Generic, TypeVar, List, Tuple
+
+from PyQt6.QtCore import Qt
+from PyQt6.QtWidgets import QWidget, QHBoxLayout, QLabel, QVBoxLayout, QSpacerItem, QSizePolicy, QPushButton
+
+from wannadb.change_captor import BestMatchUpdate, ThresholdPositionUpdate
+from wannadb.utils import AccessibleColor
+from wannadb_ui import visualizations
+from wannadb_ui.common import ThresholdPosition, SUBHEADER_FONT, LABEL_FONT, \
+ BUTTON_FONT
+from wannadb_ui.study import track_button_click
+from wannadb_ui.visualizations import EmbeddingVisualizerWindow
+
+# Refers to the type of items displayed in a ChangesList
+UPDATE_TYPE = TypeVar("UPDATE_TYPE")
+
+
+class ChangesList(QWidget, Generic[UPDATE_TYPE]):
+ """
+ This class realizes a QWidget representing a list of updates.
+ These updates refer to changes induced by the latest user feedback.
+
+ Methods
+ -------
+ update_list(self, updates: List[UPDATE_TYPE]):
+ Updates the list with the given list of items.
+ """
+
+ def __init__(self, info_label_text, tooltip_text):
+ """
+ Initializes an empty UI ChangesList with the given name and tooltip.
+
+ Parameters
+ ----------
+ info_label_text : str
+ Name of this list displayed next to the list itself.
+ tooltip_text: QColor
+ Text further explaining the list's intention displayed if hovering over list's name
+ """
+
+ super(ChangesList, self).__init__()
+
+ # Setup layout
+ self._layout: QHBoxLayout = QHBoxLayout(self)
+ self._layout.setSpacing(0)
+ self._layout.setContentsMargins(0, 0, 0, 0)
+
+ # Init and add name label and tooltip
+ self._info_label: QLabel = QLabel(info_label_text)
+ self._info_label.setContentsMargins(0, 0, 8, 0)
+ self._list_labels: List[QWidget] = list()
+ self._layout.addWidget(self._info_label)
+ self._info_label.setToolTip(tooltip_text)
+
+ def update_list(self, updates: List[UPDATE_TYPE]):
+ """
+ Updates the list by the given list of items.
+
+ First it removes all items from the current list and then adds the new items represented by the given list.
+ In order to keep the UI clear, we only add the seven randomly sampled items of the given list to the UI list.
+ The existence of further - not displayed - items are indicated by a label displaying "... and
+ [NUMBER_OF_MISSING_ITEMS] more.".
+
+ Parameters
+ ----------
+ updates: List[UPDATE_TYPE]
+ Items which should be added to the list.
+ """
+
+ # Remove existing items from list
+ self._reset_list()
+
+ if len(updates) == 0:
+ # We don't want to have a list containing nothing but at least some symbol indicating that the list is empty
+ no_changes_label = QLabel("-")
+ no_changes_label.setContentsMargins(0, 0, 0, 0)
+ self._layout.addWidget(no_changes_label)
+ self._list_labels.append(no_changes_label)
+ return
+
+ # Select the 7 items to be displayed and add them to the UI
+ updates_to_add = random.sample(updates, k=min(7, len(updates)))
+ for update in updates_to_add:
+ label_text, tooltip_text = self._create_label_and_tooltip_text(update)
+ label = QLabel(label_text)
+ label.setContentsMargins(0, 0, 8, 0)
+ label.setToolTip(tooltip_text)
+ self._layout.addWidget(label)
+ self._list_labels.append(label)
+
+ if len(updates) > 7:
+ last_label = QLabel(f"... and {len(updates) - 7} more.")
+ last_label.setContentsMargins(0, 0, 0, 0)
+ self._layout.addWidget(last_label)
+ self._list_labels.append(last_label)
+
+ @abc.abstractmethod
+ def _create_label_and_tooltip_text(self, update: UPDATE_TYPE) -> Tuple[str, str]:
+ # Computes the label text and tooltip corresponding to an update depending on the actual type of the update
+ pass
+
+ def _reset_list(self):
+ # Removes all items from the UI list
+ for list_label in self._list_labels:
+ self._layout.removeWidget(list_label)
+
+ self._list_labels = []
+
+
+class ChangedBestMatchDocumentsList(ChangesList[BestMatchUpdate]):
+ """
+ Realizes a `ChangesList` displaying changed best matches after each user feedback which can be found within the
+ Data Insights section by inheriting from `ChangesList`.
+ """
+
+ def __init__(self):
+ """
+ Determines its tooltip text and name and initializes itself by calling super constructor.
+ """
+
+ tooltip_text = ("The distance associated with each nugget is recomputed after every feedback round.\n"
+ "Therefore the best guess of an document (nugget with lowest distance) might change "
+ "after a feedback round. Such best guesses are listed here.")
+
+ super(ChangedBestMatchDocumentsList, self).__init__("Changed best guesses:", tooltip_text)
+
+ def _create_label_and_tooltip_text(self, update: BestMatchUpdate) -> Tuple[str, str]:
+ # Computes the text and tooltip which should represent the given item in the UI list.
+ label_text = f"{update.new_best_match} {'(' + str(update.count) + ')' if update.count > 1 else ''}"
+ tooltip_text = (f"Previous best match was: {update.old_best_match}\n"
+ f"Changes to token \"{update.new_best_match}\": {update.count}")
+
+ return label_text, tooltip_text
+
+
+class ChangedThresholdPositionList(ChangesList[ThresholdPositionUpdate]):
+ """
+ Realizes an abstract ChangesList displaying nuggets whose threshold position changed (either above or below) due to
+ the latest user feedback which can be found in the Data Insights section by inheriting from `ChangesList`.
+
+ Methods
+ -------
+ update_list(updates: List[ThresholdPositionUpdate])
+ Extracts the relevant updates out of the given list matching and updates the list with the extracted, relevant
+ updates.
+ """
+
+ def __init__(self, info_label_text: str, tooltip_text: str, addressed_change: ThresholdPosition):
+ """
+ Initializes itself by calling super constructor.
+
+ Parameters:
+ -----------
+ info_label_text : str
+ Name of this list displayed next to the list itself.
+ tooltip_text: QColor
+ Text further explaining the list's intention displayed if hovering over list's name
+ addressed_change
+ Determines the change type addressed by this list, either from above to below or below to above.
+ The given position refers to the end position of the relevant updates (E.g. If it's 'below', then this list
+ only cares about 'above' -> 'below' updates).
+ """
+
+ self._addressed_change = addressed_change
+
+ super(ChangedThresholdPositionList, self).__init__(info_label_text, tooltip_text)
+
+ def update_list(self, threshold_updates: List[ThresholdPositionUpdate]):
+ """
+ Extracts the relevant updates out of the given list matching and updates the list with the extracted, relevant
+ updates.
+
+ The given list contains all updates covering changes from above to below as well as below to above the
+ threshold while this list should only display one of these type of changes.
+ Therefore, the mentioned extraction is required.
+
+ Parameters
+ ----------
+ threshold_updates: List[ThresholdPositionUpdate]
+ List of items in which the items to be added can be found. To extract the items to be added from the whole
+ list, filter it according to the change type addressed by the list.
+ """
+
+ # Extract relevant updates
+ relevant_updates = self._extract_relevant_updates(threshold_updates)
+
+ # Add extracted updates to list
+ super().update_list(relevant_updates)
+
+ def _create_label_and_tooltip_text(self, update: ThresholdPositionUpdate) -> Tuple[str, str]:
+ # Computes the label representing one change in the list and the corresponding tooltip
+
+ moving_direction = update.new_position.name.lower()
+
+ label_text = f"{update.nugget_text} {'(' + str(update.count) + ')' if update.count > 1 else ''}"
+ distance_change_text = f"Old distance: {round(update.old_distance, 4)} -> New distance: {round(update.new_distance, 4)}\n" if update.old_distance \
+ else f"Initial distance: {round(update.new_distance, 4)}\n"
+
+ tooltip_text = (f"Due to your last feedback {update.nugget_text} moved {moving_direction} the threshold.\n"
+ f"{distance_change_text}" if not update.count > 1 else "" # If update covers multiple nuggets, don't show distance text as the tooltip refers to multiple nuggets in this case
+ f"This happened for {update.count - 1} similar nuggets as well.")
+
+ return label_text, tooltip_text
+
+ def _extract_relevant_updates(self, threshold_updates: List[ThresholdPositionUpdate]) -> List[ThresholdPositionUpdate]:
+ # Extracts the updates relevant to this list from a list containing all updates by filtering according to the
+ # value of `_addressed_change`.
+
+ return list(filter(lambda update: (update.old_position != update.new_position and
+ update.new_position == self._addressed_change),
+ threshold_updates))
+
+
+class ChangedThresholdPositionToAboveList(ChangedThresholdPositionList):
+ """
+ Realizes a concrete `ChangedThresholdPositionList` displaying threshold updates where the position changed from
+ below to above.
+ """
+
+ def __init__(self):
+ """
+ Initializes itself by determining tooltip, name and calling super constructor
+ """
+
+ tooltip_text = ("The distance associated with each nugget as well as the threshold is recomputed after every "
+ "feedback round.\n"
+ "Therefore the best guess of an document might not be below the threshold anymore. Such best "
+ "guesses are listed here.")
+ super(ChangedThresholdPositionToAboveList, self).__init__("Moved above threshold:",
+ tooltip_text,
+ ThresholdPosition.ABOVE)
+
+
+class ChangedThresholdPositionToBelowList(ChangedThresholdPositionList):
+ """
+ Realizes a concrete `ChangedThresholdPositionList` displaying threshold updates where the position changed from
+ above to below.
+ """
+
+ def __init__(self):
+ """
+ Initializes itself by determining tooltip, name and calling super constructor
+ """
+
+ tooltip_text = ("The distance associated with each nugget as well as the threshold is recomputed after every "
+ "feedback round.\n"
+ "Therefore the best guess of an document might not be above the threshold anymore. Such best "
+ "guesses are listed here.")
+ super(ChangedThresholdPositionToBelowList, self).__init__("Moved below threshold:",
+ tooltip_text,
+ ThresholdPosition.BELOW)
+
+
+class DataInsightsArea:
+ """
+ Abstract superclass responsible for the common logic required for both, the simple and the extended version of the
+ Data Insights area.
+ It only handles the 3D-Grid as it's the only component present in both Data Insight area types.
+
+ Methods
+ -------
+ enable_accessible_color_palette()
+ Enables the accessible color palette in the grid.
+ disable_accessible_color_palette()
+ Disables the accessible color palette in the grid.
+ """
+
+ def __init__(self):
+ """
+ Initializes the Data Insight section by initializing the 3D Grid and setting up the corresponding buttons
+ responsible for opening the grid.
+ """
+
+ # Init 3D-Grid
+ self.suggestion_visualizer = EmbeddingVisualizerWindow([
+ (AccessibleColor(visualizations.WHITE, visualizations.WHITE), 'Below threshold'),
+ (AccessibleColor(visualizations.RED, visualizations.ACC_RED), 'Above threshold'),
+ (AccessibleColor(visualizations.GREEN, visualizations.ACC_GREEN), 'Confirmed match')
+ ])
+
+ # Init and setup button responsible for opening the 3D Grid
+ self.suggestion_visualizer_button = QPushButton("Show Suggestions In 3D-Grid")
+ self.suggestion_visualizer_button.setContentsMargins(0, 0, 0, 0)
+ self.suggestion_visualizer_button.setFont(BUTTON_FONT)
+ self.suggestion_visualizer_button.setMaximumWidth(240)
+ self.suggestion_visualizer_button.clicked.connect(self._show_suggestion_visualizer)
+
+ def enable_accessible_color_palette(self):
+ """
+ Enables the accessible color palette in the grid.
+
+ For further details, check the related method in `EmbeddingVisualizer`.
+ """
+
+ self.suggestion_visualizer.enable_accessible_color_palette()
+
+ def disable_accessible_color_palette(self):
+ """
+ Disables the accessible color palette in the grid.
+
+ For further details, check the related method in `EmbeddingVisualizer`.
+ """
+
+ self.suggestion_visualizer.disable_accessible_color_palette()
+
+ @track_button_click("Show Suggestions In 3D-Grid")
+ def _show_suggestion_visualizer(self):
+ # Opens the 3D-Grid and tracks the click on the corresponding button
+
+ self.suggestion_visualizer.setVisible(True)
+
+
+class SimpleDataInsightsArea(QWidget, DataInsightsArea):
+ """
+ Class realizing the simple version of the Data Insights Area which only contains the 3D-Grid with the best guesses
+ of all best guesses.
+
+ It can be found in the document overview screen if only Level 1 visualization are enabled via the menu.
+
+ Inherits from `QWidget` and `DataInsightsArea`.
+ """
+
+ def __init__(self):
+
+ # Call super constructors
+ QWidget.__init__(self)
+ DataInsightsArea.__init__(self)
+
+ # Set up layout
+ self.layout = QHBoxLayout(self)
+ self.layout.setContentsMargins(0, 0, 0, 0)
+ self.layout.setSpacing(0)
+
+ # Add button to widget
+ self.layout.addWidget(self.suggestion_visualizer_button, 0, Qt.AlignmentFlag.AlignRight)
+
+ # Make itself invisible initially
+ self.setVisible(False)
+
+
+class ExtendedDataInsightsArea(QWidget, DataInsightsArea):
+ """
+ Class realizing the extended Data Insights section providing
+ information about the effects of the user's latest feedback as well as a 3D grid displaying all best guesses of all
+ documents.
+
+ It contains a label indicating the current threshold, lists providing nugget related changes due to the user's last
+ feedback and 3D-Grid displaying the embeddings of all best guesses of all documents.
+ The lists providing information about nugget related changes cover two lists displaying nugget whose position
+ relative to the threshold changed. One list for all "above -> below" changes and one list for all "below -> above"
+ changes. The lists are realized by utilizing instances of `ChangedThresholdPositionList`.
+ Furthermore, there's a list displaying all nuggets who newly became the best guess due to the user's latest
+ feedback.
+
+ It can be found in the document overview screen if Level 2 visualization are enabled via the menu.
+
+ Methods
+ -------
+ update_threshold_value_label(new_threshold_value, threshold_value_change)
+ Updates the label indicating the current threshold with the given, new value and adds a label indicating the
+ change of the threshold by considering the given value change.
+ update_threshold_position_lists(threshold_position_updates: List[ThresholdPositionUpdate])
+ Updates the lists displaying nuggets whose position relative to the threshold changed due to the user's latest
+ feedback by the given list of changes.
+ update_best_match_list(new_best_matches: List[BestMatchUpdate])
+ Updates the list displaying changed best guesses by the given list of changes.
+ hide()
+ Hides itself as well as the possibly opened 3D-Grid.
+ """
+
+ def __init__(self):
+ """
+ Initializes an instance of this class by calling the related super constructors and setting up the required UI
+ components.
+ Setting up the required UI components covers the title displayed above the area, the label indicating the
+ current threshold and the lists showing changes due to the user's latest feedback.
+ """
+
+ # Call super constructors
+ QWidget.__init__(self)
+ DataInsightsArea.__init__(self)
+
+ # Set up layout
+ self.layout = QVBoxLayout(self)
+ self.layout.setSpacing(0)
+ self.layout.setContentsMargins(0, 0, 0, 0)
+
+ # Set up title
+ self.title_label = QLabel("Data Insights")
+ self.title_label.setFont(SUBHEADER_FONT)
+ self.title_label.setContentsMargins(0, 5, 0, 5)
+ self.layout.addWidget(self.title_label)
+
+ # Set up label indicating the current threshold and a possible change of the threshold's value
+ self.threshold_label = QLabel()
+ self.threshold_label.setFont(LABEL_FONT)
+ self.threshold_label.setText("Current Threshold: ")
+ self.threshold_value_label = QLabel()
+ self.threshold_value_label.setFont(LABEL_FONT)
+ self.threshold_change_label = QLabel()
+ self.threshold_change_label.setFont(LABEL_FONT)
+ self.threshold_hbox = QHBoxLayout()
+ self.threshold_hbox.setContentsMargins(0, 0, 0, 0)
+ self.threshold_hbox.setSpacing(0)
+ self.threshold_hbox.addWidget(self.threshold_label)
+ self.threshold_hbox.addWidget(self.threshold_value_label)
+ self.threshold_hbox.addWidget(self.threshold_change_label)
+ self.threshold_hbox.addItem(QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum))
+ self.layout.addLayout(self.threshold_hbox)
+
+ # Set up list displaying nuggets whose position relative to the threshold changed from above to below
+ self.changes_list1_hbox = QHBoxLayout()
+ self.changes_list1_hbox.setContentsMargins(0, 0, 0, 0)
+ self.changes_list1_hbox.setSpacing(0)
+ self.threshold_position_changes_below_list = ChangedThresholdPositionToBelowList()
+ self.changes_list1_hbox.addWidget(self.threshold_position_changes_below_list)
+ self.changes_list1_hbox.addItem(QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum))
+
+ # Set up list displaying nuggets whose position relative to the threshold changed from below to above
+ self.changes_list2_hbox = QHBoxLayout()
+ self.changes_list2_hbox.setContentsMargins(0, 0, 0, 0)
+ self.changes_list2_hbox.setSpacing(0)
+ self.threshold_position_changes_above_list = ChangedThresholdPositionToAboveList()
+ self.changes_list2_hbox.addWidget(self.threshold_position_changes_above_list)
+ self.changes_list2_hbox.addItem(QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum))
+
+ # Set up list displaying changed best guesses
+ self.changes_list3_hbox = QHBoxLayout()
+ self.changes_list3_hbox.setContentsMargins(0, 0, 0, 0)
+ self.changes_list3_hbox.setSpacing(0)
+ self.changes_best_matches_list = ChangedBestMatchDocumentsList()
+ self.changes_list3_hbox.addWidget(self.changes_best_matches_list)
+ self.changes_list3_hbox.addItem(QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum))
+ self.changes_list3_hbox.addWidget(self.suggestion_visualizer_button)
+
+ # Add lists to layout
+ self.layout.addLayout(self.changes_list1_hbox)
+ self.layout.addLayout(self.changes_list2_hbox)
+ self.layout.addLayout(self.changes_list3_hbox)
+
+ def _show_suggestion_visualizer(self):
+ self.suggestion_visualizer.setVisible(True)
+
+ def enable_accessible_color_palette(self):
+ self.accessible_color_palette = True
+ self._enable_accessible_color_palette()
+
+ def disable_accessible_color_palette(self):
+ self.accessible_color_palette = False
+ self._disable_accessible_color_palette()
+
+
+ def update_threshold_value_label(self, new_threshold_value, threshold_value_change):
+ """
+ Updates the label indicating the current threshold with the given, new value and adds a label indicating the
+ change of the threshold by considering the given value change.
+
+ The text of the label indicates the current threshold is set to the given new value.
+ If the given value change is non-zero, a label indicating this value change is added next to the label
+ displaying the actual threshold value.
+
+ Parameters
+ ----------
+ new_threshold_value: float
+ New threshold value used to update the label indicating the current threshold value.
+ threshold_value_change: float
+ Value that indicates how much the threshold has changed compared to the previous one. If non-zero, a label
+ containing this change is added.
+ """
+
+ # Add label indicating the value change if necessary
+ if round(threshold_value_change, 4) != 0:
+ self.threshold_value_label.setStyleSheet("color: orange;")
+ change_text = f'(+{round(threshold_value_change, 4)})' if threshold_value_change > 0 else f'{round(threshold_value_change, 4)})'
+ self.threshold_change_label.setText(change_text)
+ else:
+ self.threshold_value_label.setStyleSheet("")
+ self.threshold_change_label.setText("")
+
+ # Update the label displaying the current threshold
+ self.threshold_value_label.setText(f"{round(new_threshold_value, 4)} ")
+ self.threshold_label.setVisible(True)
+
+ def update_threshold_position_lists(self, threshold_position_updates: List[ThresholdPositionUpdate]):
+ """
+ Updates the lists displaying nuggets whose position relative to the threshold changed due to the user's latest
+ feedback by the given list of changes.
+
+ Each list will extract the relevant changes out of the given list and update itself according to extracted
+ changes.
+
+ Realized by calling `update_list(updates: List[ThresholdPositionUpdate])` method of
+ `ChangedThresholdPositionList` for both instances of the lists displaying the threshold position updates.
+ Further details can be found in the documentation of this method in the `ChangedThresholdPositionList` class.
+
+ Parameters
+ ----------
+ threshold_position_updates: List[ThresholdPositionUpdate]
+ List containing all nuggets whose position relative to the threshold changed due to the user's latest
+ feedback.
+ The list contains both types of changes 'above -> below' and 'below -> above'.
+ """
+
+ self.threshold_position_changes_below_list.update_list(threshold_position_updates)
+ self.threshold_position_changes_above_list.update_list(threshold_position_updates)
+
+ def update_best_match_list(self, new_best_matches: List[BestMatchUpdate]):
+ """
+ Updates the list displaying changed best guesses by the given list of changes.
+
+ Realized by calling `update_list(updates: List[BestMatchUpdate])` method of `ChangedBestMatchList` for the
+ instance representing the list.
+ Further details can be found in the documentation of this method in the `ChangedBestMatchList` class.
+
+ Parameters
+ ----------
+ new_best_matches: List[BestMatchUpdate]
+ List containing changed best guesses by the given list of changes. The `ChangedBestMatchList` instance will
+ update itself by this list.
+ """
+
+ self.changes_best_matches_list.update_list(new_best_matches)
+
+ def hide(self):
+ """
+ Hides itself as well as the possibly opened 3D-Grid.
+ """
+
+ super().hide()
+ self.suggestion_visualizer.hide()
diff --git a/wannadb_ui/interactive_matching.py b/wannadb_ui/interactive_matching.py
index 6f8c2f24..df19641e 100644
--- a/wannadb_ui/interactive_matching.py
+++ b/wannadb_ui/interactive_matching.py
@@ -1,13 +1,21 @@
import logging
+from typing import List
+
+import numpy as np
from PyQt6 import QtGui
-from PyQt6.QtCore import Qt
+from PyQt6.QtCore import Qt, QEvent
from PyQt6.QtGui import QIcon, QTextCursor
-from PyQt6.QtWidgets import QHBoxLayout, QLabel, QPushButton, QTextEdit, QVBoxLayout, QWidget
+from PyQt6.QtWidgets import QHBoxLayout, QLabel, QPushButton, QTextEdit, QVBoxLayout, QWidget, QGridLayout, QSizePolicy
from wannadb.data.signals import CachedContextSentenceSignal, CachedDistanceSignal
+from wannadb.change_captor import NewlyAddedNuggetContext
from wannadb_ui.common import BUTTON_FONT, CODE_FONT, CODE_FONT_BOLD, LABEL_FONT, MainWindowContent, \
- CustomScrollableList, CustomScrollableListItem, WHITE, LIGHT_YELLOW, YELLOW
+ CustomScrollableListItem, WHITE, LIGHT_YELLOW, YELLOW, \
+ VisualizationProvidingItem, AvailableVisualizationsLevel, VisualizationProvidingCustomScrollableList
+from wannadb_ui.data_insights import SimpleDataInsightsArea, ExtendedDataInsightsArea
+from wannadb_ui.visualizations import EmbeddingVisualizerWidget, BarChartVisualizerWidget
+from wannadb_ui.study import Tracker, track_button_click
logger = logging.getLogger(__name__)
@@ -26,8 +34,10 @@ def __init__(self, main_window):
self.stop_button.setMaximumWidth(240)
self.controls_widget_layout.addWidget(self.stop_button)
- self.nugget_list_widget = NuggetListWidget(self)
- self.document_widget = DocumentWidget(self)
+ self.nugget_list_widget = NuggetListWidget(self, main_window)
+ self.document_widget = DocumentWidget(self, main_window)
+
+ main_window.attach_visualization_level_observer(self.nugget_list_widget)
self.show_nugget_list_widget()
@@ -42,13 +52,15 @@ def disable_input(self):
self.document_widget.disable_input()
def handle_feedback_request(self, feedback_request):
- self.header.setText(f"Attribute: {feedback_request['attribute'].name}")
+ attribute = feedback_request['attribute']
+ self.header.setText(f"Attribute: {attribute.name}")
self.nugget_list_widget.update_nuggets(feedback_request)
+ self.document_widget.update_attribute(attribute)
self.enable_input()
self.show_nugget_list_widget()
- def get_document_feedback(self, nugget):
- self.document_widget.update_document(nugget)
+ def get_document_feedback(self, nugget, other_best_guesses):
+ self.document_widget.update_document(nugget, other_best_guesses)
self.show_document_widget()
def show_nugget_list_widget(self):
@@ -65,16 +77,34 @@ def show_document_widget(self):
self.layout.addWidget(self.document_widget)
self.stop_button.hide()
+ def enable_accessible_color_palette(self):
+ self.document_widget.enable_accessible_color_palette()
+ self.nugget_list_widget.enable_accessible_color_palette()
+
+ def disable_accessible_color_palette(self):
+ self.document_widget.disable_accessible_color_palette()
+ self.nugget_list_widget.disable_accessible_color_palette()
+
+ def enable_visualizations(self):
+ self.document_widget.show_visualizations()
+ self.nugget_list_widget.show_visualizations()
+
+ def disable_visualizations(self):
+ self.document_widget.hide_visualizations()
+ self.nugget_list_widget.hide_visualizations()
+
def _stop_button_clicked(self):
self.show_nugget_list_widget()
self.main_window.give_feedback_task({"message": "stop-interactive-matching"})
-class NuggetListWidget(QWidget):
- def __init__(self, interactive_matching_widget):
+class NuggetListWidget(QWidget, VisualizationProvidingItem):
+ def __init__(self, interactive_matching_widget, main_window):
super(NuggetListWidget, self).__init__(interactive_matching_widget)
self.interactive_matching_widget = interactive_matching_widget
+ self.visualization_level = main_window.visualizations_level
+
self.layout = QVBoxLayout(self)
self.layout.setContentsMargins(0, 0, 0, 0)
self.layout.setSpacing(10)
@@ -83,6 +113,14 @@ def __init__(self, interactive_matching_widget):
self.description.setFont(LABEL_FONT)
self.layout.addWidget(self.description)
+ # suggestion visualizer
+ self.simple_visualize_area = SimpleDataInsightsArea()
+ self.extended_visualize_area = ExtendedDataInsightsArea()
+ self.layout.addWidget(self.simple_visualize_area)
+ self.layout.addWidget(self.extended_visualize_area)
+ self.visualizations = True
+ self.accessible_color_palette = False
+
# nugget list
self.num_nuggets_above_label = QLabel("")
self.num_nuggets_above_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
@@ -94,66 +132,122 @@ def __init__(self, interactive_matching_widget):
self.num_nuggets_below_label.setFont(CODE_FONT_BOLD)
# self.num_nuggets_below_label.setStyleSheet(f"color: {YELLOW}")
- self.nugget_list = CustomScrollableList(self, NuggetListItemWidget,
- floating_widget=self.num_nuggets_below_label,
- above_widget=self.num_nuggets_above_label)
+ self.nugget_list = VisualizationProvidingCustomScrollableList(self, NuggetListItemWidget,
+ visualizations_level=main_window.visualizations_level,
+ attach_visualization_level_observer=main_window.attach_visualization_level_observer,
+ floating_widget=self.num_nuggets_below_label,
+ above_widget=self.num_nuggets_above_label)
self.layout.addWidget(self.nugget_list)
def update_nuggets(self, feedback_request):
- self.description.setText("Please confirm or edit the cell value guesses displayed below until you are satisfied with the guessed values, at which point you may continue with the next attribute."
- "\nWannaDB will use your feedback to continuously update its guesses. Note that the cells with low confidence (low confidence bar, light yellow highlights) will be left empty.")
- nuggets = feedback_request["nuggets"]
+ feedback_nuggets = feedback_request["nuggets"]
+ all_guessed_nugget_matches = feedback_request["all-guessed-nugget-matches"]
+ attribute = feedback_request["attribute"]
+ current_threshold = feedback_request["max-distance"]
+ threshold_change = feedback_request["max-distance-change"]
+ nugget_updates_context = feedback_request["nugget-updates-context"]
+
+ self.description.setText(
+ "Please confirm or edit the cell value guesses displayed below until you are satisfied with the guessed values, at which point you may continue with the next attribute."
+ "\nWannaDB will use your feedback to continuously update its guesses. Note that the cells with low confidence (low confidence bar, light yellow highlights) will be left empty.")
+ self.extended_visualize_area.update_threshold_value_label(current_threshold, threshold_change)
+ self.extended_visualize_area.update_best_match_list(nugget_updates_context.best_match_updates)
+ self.extended_visualize_area.update_threshold_position_lists(nugget_updates_context.threshold_position_updates)
+
params = {
- "max_start_chars": max([nugget[CachedContextSentenceSignal]["start_char"] for nugget in nuggets]),
- "max_distance": feedback_request["max-distance"]
+ "max_start_chars": max([nugget[CachedContextSentenceSignal]["start_char"] for nugget in feedback_nuggets]),
+ "max_distance": current_threshold,
+ "other_best_guesses": feedback_nuggets,
+ "new-nuggets": nugget_updates_context.newly_added_nugget_contexts,
+ "num-feedback": feedback_request["num-feedback"]
}
- self.nugget_list.update_item_list(nuggets, params)
+
+ if self.visualization_level == AvailableVisualizationsLevel.LEVEL_1:
+ self.simple_visualize_area.setVisible(True)
+ self.extended_visualize_area.setVisible(False)
+ elif self.visualization_level == AvailableVisualizationsLevel.LEVEL_2:
+ self.extended_visualize_area.setVisible(True)
+ self.simple_visualize_area.setVisible(False)
+
+ self.nugget_list.update_item_list(feedback_nuggets, params)
+ if len(feedback_nuggets) > 0:
+ self.extended_visualize_area.suggestion_visualizer.update_and_display_params(attribute=attribute,
+ nuggets=all_guessed_nugget_matches,
+ currently_highlighted_nugget=None,
+ best_guess=None,
+ other_best_guesses=[])
+ self.simple_visualize_area.suggestion_visualizer.update_and_display_params(attribute=attribute,
+ nuggets=all_guessed_nugget_matches,
+ currently_highlighted_nugget=None,
+ best_guess=None,
+ other_best_guesses=[])
+
if feedback_request["num-nuggets-above"] > 0:
- self.num_nuggets_above_label.setText(f"... and {feedback_request['num-nuggets-above']} more cells that will be left empty ...")
+ self.num_nuggets_above_label.setText(
+ f"... and {feedback_request['num-nuggets-above']} more cells that will be left empty ...")
else:
self.num_nuggets_above_label.setText("")
if feedback_request["num-nuggets-below"] > 0:
- self.num_nuggets_below_label.setText(f"... and {feedback_request['num-nuggets-below']} more cells that will be populated ...")
+ self.num_nuggets_below_label.setText(
+ f"... and {feedback_request['num-nuggets-below']} more cells that will be populated ...")
else:
self.num_nuggets_below_label.setText("")
def enable_input(self):
self.nugget_list.enable_input()
+
+ def enable_accessible_color_palette(self):
+ self.accessible_color_palette = True
+ self.simple_visualize_area.enable_accessible_color_palette()
+ self.extended_visualize_area.enable_accessible_color_palette()
+
+ def disable_accessible_color_palette(self):
+ self.accessible_color_palette = False
+ self.simple_visualize_area.disable_accessible_color_palette()
+ self.extended_visualize_area.disable_accessible_color_palette()
def disable_input(self):
self.nugget_list.disable_input()
+ def _adapt_to_visualizations_level(self, visualizations_level):
+ self.visualization_level = visualizations_level
+
+ if visualizations_level == AvailableVisualizationsLevel.LEVEL_1:
+ self.simple_visualize_area.setVisible(True)
+ self.extended_visualize_area.setVisible(False)
+ elif visualizations_level == AvailableVisualizationsLevel.LEVEL_2:
+ self.extended_visualize_area.setVisible(True)
+ self.simple_visualize_area.setVisible(False)
+ elif visualizations_level == AvailableVisualizationsLevel.DISABLED:
+ self.extended_visualize_area.setVisible(False)
+ self.simple_visualize_area.setVisible(False)
+
-class NuggetListItemWidget(CustomScrollableListItem):
- def __init__(self, nugget_list_widget):
+class NuggetListItemWidget(CustomScrollableListItem, VisualizationProvidingItem):
+ def __init__(self, nugget_list_widget, visualizations_level):
super(NuggetListItemWidget, self).__init__(nugget_list_widget)
self.nugget_list_widget = nugget_list_widget
self.nugget = None
+ self.other_best_guesses = None
+ self._default_stylesheet = "QWidget#nuggetListItemWidget { background-color: white}"
+ self._tooltip_text = ""
+ self._visualizations = visualizations_level == AvailableVisualizationsLevel.LEVEL_2
self.setFixedHeight(45)
self.setObjectName("nuggetListItemWidget")
- self.setStyleSheet("QWidget#nuggetListItemWidget { background-color: white}")
+ self.setStyleSheet(self._default_stylesheet)
self.layout = QHBoxLayout(self)
self.layout.setContentsMargins(20, 0, 20, 0)
self.layout.setSpacing(10)
self.confidence_button = QPushButton()
+ self.confidence_button.event = lambda e: self.handle_tooltip_event(self.confidence_button, e, "NuggetListItemWidget.confidence_button")
self.confidence_button.setFlat(True)
self.confidence_button.setIcon(ICON_LOW_CONFIDENCE)
self.confidence_button.setToolTip("Confidence in this match.")
self.layout.addWidget(self.confidence_button)
- # self.info_button = QPushButton()
- # self.info_button.setFlat(True)
- # self.info_button.setFont(CODE_FONT_BOLD)
- # self.info_button.clicked.connect(self._info_button_clicked)
- # self.layout.addWidget(self.info_button)
-
- # self.left_split_label = QLabel("|")
- # self.left_split_label.setFont(CODE_FONT_BOLD)
- # self.layout.addWidget(self.left_split_label)
-
self.text_edit = QTextEdit()
self.text_edit.setReadOnly(True)
self.text_edit.setFrameStyle(0)
@@ -166,27 +260,29 @@ def __init__(self, nugget_list_widget):
self.text_edit.setText("")
self.layout.addWidget(self.text_edit)
- # self.right_split_label = QLabel("|")
- # self.right_split_label.setFont(CODE_FONT_BOLD)
- # self.layout.addWidget(self.right_split_label)
-
self.match_button = QPushButton()
+ self.match_button.event = lambda e: self.handle_tooltip_event(self.match_button, e, "NuggetListItemWidget.match_button")
self.match_button.setIcon(QIcon("wannadb_ui/resources/correct.svg"))
self.match_button.setToolTip("Confirm this value.")
self.match_button.clicked.connect(self._match_button_clicked)
self.layout.addWidget(self.match_button)
self.fix_button = QPushButton()
+ self.fix_button.event = lambda e: self.handle_tooltip_event(self.fix_button, e, "NuggetListItemWidget.fix_button")
self.fix_button.setIcon(QIcon("wannadb_ui/resources/pencil.svg"))
self.fix_button.setToolTip("Edit this value.")
self.fix_button.clicked.connect(self._fix_button_clicked)
self.layout.addWidget(self.fix_button)
+ self.last_tooltip_text_passed = None
+
def update_item(self, item, params=None):
self.nugget = item
max_start_chars = params["max_start_chars"]
max_distance = params["max_distance"]
+ self.other_best_guesses = [other_best_guess for other_best_guess in params["other_best_guesses"]
+ if other_best_guess != self.nugget]
sentence = self.nugget[CachedContextSentenceSignal]["text"]
start_char = self.nugget[CachedContextSentenceSignal]["start_char"]
@@ -195,12 +291,25 @@ def update_item(self, item, params=None):
if max_distance < self.nugget[CachedDistanceSignal]:
color = LIGHT_YELLOW
self.confidence_button.setIcon(ICON_LOW_CONFIDENCE)
- self.confidence_button.setToolTip("Low confidence in this match, will not be included in result.")
+
+ self.confidence_button.setToolTip(
+ f"Low confidence in this match {self._build_distance_text() if self._visualizations else ''}, "
+ f"will not be included in result.")
else:
color = YELLOW
self.confidence_button.setIcon(ICON_HIGH_CONFIDENCE)
- self.confidence_button.setToolTip("High confidence in this match, will be included in result.")
- self.text_edit.setStyleSheet(f"color: black; background-color: {WHITE}")
+ self.confidence_button.setToolTip(
+ f"High confidence in this match {self._build_distance_text() if self._visualizations else ''}, "
+ f"will be included in result.")
+
+ new_nugget_contexts: List = params["new-nuggets"]
+
+ if self.nugget in map(lambda context: context.nugget, new_nugget_contexts):
+ self._handle_item_is_new(self._extract_matching_context(new_nugget_contexts))
+ else:
+ self._update_stylesheets(False)
+ self._tooltip_text = ""
+ self.setToolTip(self._tooltip_text)
self.text_edit.setText("")
formatted_text = (
@@ -218,6 +327,7 @@ def update_item(self, item, params=None):
# self.info_button.setText(f"{str(round(self.nugget[CachedDistanceSignal], 2)).ljust(4)}")
+ @track_button_click(button_name="nugget_list_match_button")
def _match_button_clicked(self):
self.nugget_list_widget.interactive_matching_widget.main_window.give_feedback_task({
"message": "is-match",
@@ -226,7 +336,7 @@ def _match_button_clicked(self):
})
def _fix_button_clicked(self):
- self.nugget_list_widget.interactive_matching_widget.get_document_feedback(self.nugget)
+ self.nugget_list_widget.interactive_matching_widget.get_document_feedback(self.nugget, self.other_best_guesses)
# def _info_button_clicked(self):
# lines = []
@@ -253,10 +363,78 @@ def disable_input(self):
self.match_button.setDisabled(True)
self.fix_button.setDisabled(True)
+ def _adapt_to_visualizations_level(self, visualizations_level):
+ if visualizations_level == AvailableVisualizationsLevel.LEVEL_2:
+ self._show_visualizations()
+ else:
+ self._hide_visualizations()
+
+ def _show_visualizations(self):
+ self._visualizations = True
+
+ if self._tooltip_text != "":
+ self.setToolTip(self._tooltip_text)
+ self.confidence_button.setToolTip(
+ f"Low confidence in this match {self._build_distance_text()}, will not be included in result.")
+ self._update_stylesheets(item_is_new=self._tooltip_text != "")
+
+ def _hide_visualizations(self):
+ self._visualizations = False
+
+ self.setToolTip("")
+ self.confidence_button.setToolTip(
+ f"Low confidence in this match, will not be included in result.")
+ self._update_stylesheets(False)
+
+ def _handle_item_is_new(self, newly_added_nugget_context):
+ distance_change_text = (f"Old distance: {round(newly_added_nugget_context.old_distance, 4)} -> "
+ f"New distance: {round(newly_added_nugget_context.new_distance, 4)}") \
+ if newly_added_nugget_context.old_distance is not None \
+ else f"Initial distance: {round(newly_added_nugget_context.new_distance, 4)}"
+ self._tooltip_text = (
+ f'Reason for the item to be newly added:\n'
+ f'{newly_added_nugget_context.added_reason.corresponding_tooltip_text}\n\n'
+ f'{distance_change_text}')
+
+ if not self._visualizations:
+ return
+
+ self.setToolTip(self._tooltip_text)
+ self._update_stylesheets(True)
+
+ def _extract_matching_context(self, contexts: List[NewlyAddedNuggetContext]):
+ for context in contexts:
+ if context.nugget == self.nugget:
+ return context
-class DocumentWidget(QWidget):
- def __init__(self, interactive_matching_widget):
- super(DocumentWidget, self).__init__(interactive_matching_widget)
+ raise ValueError(f"Own nugget ({self.nugget}) not in given list: {contexts}")
+
+ def _build_distance_text(self):
+ return f"(Distance: {round(self.nugget[CachedDistanceSignal], 4)})" if self._visualizations else ""
+
+ def _update_stylesheets(self, item_is_new):
+ if item_is_new:
+ self.setStyleSheet((f"QFrame {{ background-color: {'#e7ffe6'}; }}\n"
+ f"QToolTip {{ background-color: {WHITE}; }}"))
+ self.text_edit.setStyleSheet(f"color: black; background-color: {'#e7ffe6'}")
+ else:
+ self.setStyleSheet(self._default_stylesheet)
+ self.text_edit.setStyleSheet(f"color: black; background-color: {WHITE}")
+
+ def event(self, e):
+ return self.handle_tooltip_event(self, e, "NuggetListItemWidget")
+
+ def handle_tooltip_event(self, widget, e, identifier_name):
+ if e.type() == QEvent.Type.ToolTip:
+ tooltip_text = widget.toolTip() # Extract the tooltip text
+ if self.last_tooltip_text_passed != tooltip_text:
+ Tracker().track_tooltip_activation(identifier_name)
+ self.last_tooltip_text_passed = tooltip_text
+ return super(widget.__class__, widget).event(e)
+
+class DocumentWidget(QWidget, VisualizationProvidingItem):
+ def __init__(self, interactive_matching_widget, main_window):
+ super(DocumentWidget, self).__init__(parent=interactive_matching_widget)
self.interactive_matching_widget = interactive_matching_widget
self.layout = QVBoxLayout()
@@ -267,13 +445,18 @@ def __init__(self, interactive_matching_widget):
self.document = None
self.original_nugget = None
self.current_nugget = None
+ self.current_attribute = None
+ self.current_other_best_guesses = None
self.base_formatted_text = ""
self.idx_mapper = {}
self.nuggets_in_order = []
self.nuggets_sorted_by_distance = []
- self.description = QLabel("Please select the correct value by clicking on one of the highlighted snippets. You may also "
- "highlight a different span of text in case the required value is not highlighted already.")
+ main_window.attach_visualization_level_observer(self)
+
+ self.description = QLabel(
+ "Please select the correct value by clicking on one of the highlighted snippets. You may also "
+ "highlight a different span of text in case the required value is not highlighted already.")
self.description.setFont(LABEL_FONT)
self.layout.addWidget(self.description)
@@ -296,15 +479,30 @@ def __init__(self, interactive_matching_widget):
self.custom_selection_item_widget = CustomSelectionItemWidget(self)
self.custom_selection_item_widget.hide()
- self.suggestion_list = CustomScrollableList(self, SuggestionListItemWidget, orientation="horizontal",
- above_widget=self.custom_selection_item_widget)
+ self.suggestion_list = VisualizationProvidingCustomScrollableList(self, SuggestionListItemWidget,
+ main_window.visualizations_level,
+ main_window.attach_visualization_level_observer,
+ orientation="horizontal",
+ above_widget=self.custom_selection_item_widget)
+
self.suggestion_list.setFixedHeight(60)
self.layout.addWidget(self.suggestion_list)
+ self.upper_buttons_widget = QWidget()
+ self.upper_buttons_widget_layout = QHBoxLayout(self.upper_buttons_widget)
+ self.upper_buttons_widget_layout.setContentsMargins(0, 0, 0, 0)
+ self.layout.addWidget(self.upper_buttons_widget)
+ self.cosine_barchart = BarChartVisualizerWidget()
+ self.cosine_barchart.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred)
+ self.upper_buttons_widget_layout.addWidget(self.cosine_barchart)
+
+ self.visualizer = EmbeddingVisualizerWidget()
+ self.visualizer.setFixedHeight(355)
+ self.layout.addWidget(self.visualizer)
+
self.buttons_widget = QWidget()
self.buttons_widget_layout = QHBoxLayout(self.buttons_widget)
self.buttons_widget_layout.setContentsMargins(0, 0, 0, 0)
- # self.buttons_widget_layout.setAlignment(Qt.AlignmentFlag.AlignRight)
self.layout.addWidget(self.buttons_widget)
self.no_match_button = QPushButton("Value Not In Document")
@@ -317,6 +515,7 @@ def __init__(self, interactive_matching_widget):
self.match_button.clicked.connect(self._match_button_clicked)
self.buttons_widget_layout.addWidget(self.match_button)
+ @track_button_click(button_name="document_match_button")
def _match_button_clicked(self):
if self.current_nugget is None:
logger.info("Confirm custom nugget!")
@@ -334,6 +533,7 @@ def _match_button_clicked(self):
"not-a-match": None if self.current_nugget is self.original_nugget else self.original_nugget
})
+ @track_button_click(button_name="document_no_match_button")
def _no_match_button_clicked(self):
self.interactive_matching_widget.main_window.give_feedback_task({
"message": "no-match-in-document",
@@ -390,7 +590,8 @@ def _handle_selection_changed(self):
def _highlight_current_nugget(self):
if self.current_nugget:
mapped_start_char = self.idx_mapper[self.current_nugget.start_char]
- mapped_end_char = self.idx_mapper[self.current_nugget.end_char] if self.current_nugget.end_char < len(self.document.text) else len(self.base_formatted_text)
+ mapped_end_char = self.idx_mapper[self.current_nugget.end_char] if self.current_nugget.end_char < len(
+ self.document.text) else len(self.base_formatted_text)
formatted_text = (
f"{self.base_formatted_text[:mapped_start_char]}"
@@ -400,16 +601,19 @@ def _highlight_current_nugget(self):
)
self.text_edit.setText("")
self.text_edit.textCursor().insertHtml(formatted_text)
+
+ self.visualizer.highlight_selected_nugget(self.current_nugget)
else:
self.text_edit.setText("")
self.text_edit.textCursor().insertHtml(self.base_formatted_text)
self.suggestion_list.update_item_list(self.nuggets_sorted_by_distance, self.current_nugget)
- def update_document(self, nugget):
+ def update_document(self, nugget, other_best_guesses):
self.document = nugget.document
self.original_nugget = nugget
self.current_nugget = nugget
+ self.current_other_best_guesses = other_best_guesses
self.nuggets_sorted_by_distance = list(sorted(self.document.nuggets, key=lambda x: x[CachedDistanceSignal]))
self.nuggets_in_order = list(sorted(self.document.nuggets, key=lambda x: x.start_char))
self.custom_selection_item_widget.hide()
@@ -466,6 +670,14 @@ def update_document(self, nugget):
inside = False
self.base_formatted_text += char
self.idx_mapper[idx] = len(self.base_formatted_text) - 1
+
+ self.visualizer.update_and_display_params(attribute=self.current_attribute,
+ nuggets=self.document.nuggets,
+ currently_highlighted_nugget=nugget,
+ best_guess=self.nuggets_sorted_by_distance[0],
+ other_best_guesses=other_best_guesses)
+ self.cosine_barchart.update_data(self.nuggets_sorted_by_distance)
+
else:
self.idx_mapper = {}
for idx in range(len(self.document.text)):
@@ -473,6 +685,8 @@ def update_document(self, nugget):
self.base_formatted_text = ""
self._highlight_current_nugget()
+ self._highlight_best_guess(
+ self.nuggets_sorted_by_distance[0] if len(self.nuggets_sorted_by_distance) > 0 else None)
scroll_cursor = QTextCursor(self.text_edit.document())
scroll_cursor.setPosition(nugget.start_char)
@@ -489,23 +703,60 @@ def disable_input(self):
self.no_match_button.setDisabled(True)
self.suggestion_list.disable_input()
+ def update_attribute(self, attribute):
+ self.current_attribute = attribute
+
+ def enable_accessible_color_palette(self):
+ self.visualizer.enable_accessible_color_palette()
+
+ def disable_accessible_color_palette(self):
+ self.visualizer.disable_accessible_color_palette()
+
+ def show_visualizations(self):
+ self.upper_buttons_widget.show()
+ self.visualizer.show()
+
+ def _hide_visualizations(self):
+ self.upper_buttons_widget.hide()
+ self.visualizer.hide()
+
+ def _highlight_best_guess(self, best_guess):
+ if best_guess is None:
+ return
-class SuggestionListItemWidget(CustomScrollableListItem):
+ self.visualizer.highlight_best_guess(best_guess)
- def __init__(self, suggestion_list_widget):
+ def _adapt_to_visualizations_level(self, visualizations_level):
+ if (visualizations_level == AvailableVisualizationsLevel.LEVEL_2 or
+ visualizations_level == AvailableVisualizationsLevel.LEVEL_1):
+ self._show_visualizations()
+ elif visualizations_level == AvailableVisualizationsLevel.DISABLED:
+ self._hide_visualizations()
+
+
+class SuggestionListItemWidget(CustomScrollableListItem, VisualizationProvidingItem):
+
+ def __init__(self, suggestion_list_widget, visualizations_level):
super(SuggestionListItemWidget, self).__init__(suggestion_list_widget)
self.suggestion_list_widget = suggestion_list_widget
self.nugget = None
+ self.visualizations = visualizations_level == AvailableVisualizationsLevel.LEVEL_2
- self.setFixedHeight(30)
+ self.setFixedHeight(45)
self.setStyleSheet(f"background-color: {WHITE}")
- self.layout = QHBoxLayout(self)
+ self.layout = QGridLayout(self)
self.layout.setContentsMargins(10, 0, 10, 0)
self.text_label = QLabel()
self.text_label.setFont(CODE_FONT_BOLD)
- self.layout.addWidget(self.text_label)
+ self.layout.addWidget(self.text_label, 0, 0)
+
+ self.certainty_label = QLabel()
+ self.certainty_label.setFont(CODE_FONT)
+ self.layout.addWidget(self.certainty_label, 0, 1)
+ if not self.visualizations:
+ self.certainty_label.hide()
def mousePressEvent(self, a0: QtGui.QMouseEvent) -> None:
self.suggestion_list_widget.interactive_matching_widget.document_widget.current_nugget = self.nugget
@@ -514,9 +765,11 @@ def mousePressEvent(self, a0: QtGui.QMouseEvent) -> None:
def update_item(self, item, params=None):
self.nugget = item
- sanitized_text = self.nugget.text
- sanitized_text = sanitized_text.replace("\n", " ")
+ sanitized_text = self.nugget.text.replace("\n", " ")
+ certainty_value = np.round(1 - self.nugget[CachedDistanceSignal], 3)
self.text_label.setText(sanitized_text)
+ self.certainty_label.setText(str(certainty_value))
+
if self.nugget == params:
self.setStyleSheet(f"background-color: {YELLOW}")
self.suggestion_list_widget.interactive_matching_widget.document_widget.suggestion_list.scroll_area.horizontalScrollBar().setValue(
@@ -531,6 +784,14 @@ def enable_input(self):
def disable_input(self):
pass
+ def _adapt_to_visualizations_level(self, visualizations_level):
+ # Adapt UI element to enabled visualizations (show or hide certainty label)
+
+ if visualizations_level != AvailableVisualizationsLevel.LEVEL_2:
+ self.certainty_label.hide()
+ else:
+ self.certainty_label.show()
+
class CustomSelectionItemWidget(QWidget):
diff --git a/wannadb_ui/main_window.py b/wannadb_ui/main_window.py
index eb8150e7..cf9f909d 100644
--- a/wannadb_ui/main_window.py
+++ b/wannadb_ui/main_window.py
@@ -1,6 +1,7 @@
import enum
import logging
import re
+import wannadb_ui.visualizations as visualizations
from PyQt6.QtCore import QMutex, Qt, QThread, QWaitCondition, pyqtSignal, pyqtSlot
from PyQt6.QtGui import QAction, QIcon
@@ -9,10 +10,12 @@
from wannadb.data.data import DocumentBase
from wannadb.statistics import Statistics
from wannadb_parsql.cache_db import SQLiteCacheDB
-from wannadb_ui.common import MENU_FONT, STATUS_BAR_FONT, STATUS_BAR_FONT_BOLD, RED, BLACK, show_confirmation_dialog
+from wannadb_ui.common import MENU_FONT, STATUS_BAR_FONT, STATUS_BAR_FONT_BOLD, RED, BLACK, show_confirmation_dialog, \
+ AvailableVisualizationsLevel
from wannadb_ui.document_base import DocumentBaseCreatorWidget, DocumentBaseViewerWidget, DocumentBaseCreatingWidget
from wannadb_ui.interactive_matching import InteractiveMatchingWidget
from wannadb_ui.start_menu import StartMenuWidget
+from wannadb_ui.common import InformationPopup
from wannadb_ui.wannadb_api import WannaDBAPI
logger = logging.getLogger(__name__)
@@ -237,6 +240,30 @@ def save_statistics_to_json_task(self):
# noinspection PyUnresolvedReferences
self.save_statistics_to_json.emit(path, self.statistics)
+ def update_visualizations_level(self, visualizations_level):
+ logger.info("Execute task 'enable_visualizations_task'.")
+
+ self.visualizations_level = visualizations_level
+
+ self._set_available_visualization_actions()
+
+ for observer in self.visualizations_level_observers:
+ observer.update_shown_visualizations(visualizations_level)
+
+ def enable_accessible_color_palette_task(self):
+ logger.info("Execute task 'enable_accessible_color_palette_task'.")
+ self.accessible_color_palette = True
+ self.interactive_matching_widget.enable_accessible_color_palette()
+ self.enable_accessible_color_palette_action.setEnabled(False)
+ self.disable_accessible_color_palette_action.setEnabled(True)
+
+ def disable_accessible_color_palette_task(self):
+ logger.info("Execute task 'disable_accessible_color_palette_task'.")
+ self.accessible_color_palette = False
+ self.interactive_matching_widget.disable_accessible_color_palette()
+ self.enable_accessible_color_palette_action.setEnabled(True)
+ self.disable_accessible_color_palette_action.setEnabled(False)
+
def show_document_base_creator_widget_task(self):
logger.info("Execute task 'show_document_base_creator_widget_task'.")
@@ -281,6 +308,9 @@ def give_feedback_task(self, feedback):
self.api.feedback = feedback
self.feedback_cond.wakeAll()
+ self._set_available_visualization_actions()
+ self._enable_color_palette_settings()
+
def interactive_table_population_task(self):
logger.info("Execute task 'interactive_table_population_task'.")
@@ -305,6 +335,25 @@ def match_attribute_task(self, attribute_name):
# noinspection PyUnresolvedReferences
self.interactive_table_population.emit(self.document_base, self.statistics)
+ def open_usage_info_task(self):
+ logger.info("Execute task 'open_usage_info_task'.")
+
+ if self.usage_info_popup.isHidden():
+ self.usage_info_popup.show()
+
+ def open_visualization_info_task(self):
+ logger.info("Execute task 'open_visualization_info_task'.")
+
+ if self.visualization_info_popup.isHidden():
+ self.visualization_info_popup.show()
+
+ def open_general_info_task(self):
+ logger.info("Execute task 'open_general_info_task'.")
+
+ if self.general_info_popup.isHidden():
+ self.general_info_popup.show()
+
+
##################
# controller logic
##################
@@ -335,6 +384,8 @@ def to_start_state(self):
else:
self.enable_collect_statistics_action.setEnabled(True)
+ self._set_available_visualization_actions()
+ self._enable_color_palette_settings()
self.central_widget_layout.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.document_base_viewer_widget.hide()
self.document_base_creator_widget.hide()
@@ -366,6 +417,8 @@ def to_create_document_base_state(self):
else:
self.enable_collect_statistics_action.setEnabled(True)
+ self._set_available_visualization_actions()
+ self._enable_color_palette_settings()
self.central_widget_layout.setAlignment(Qt.AlignmentFlag.AlignLeft)
self.start_menu_widget.hide()
self.document_base_creation_widget.hide()
@@ -386,6 +439,8 @@ def to_creating_document_base_state(self):
self.disable_global_input()
+ self._set_available_visualization_actions()
+ self._enable_color_palette_settings()
self.central_widget_layout.setAlignment(Qt.AlignmentFlag.AlignLeft)
self.start_menu_widget.hide()
self.document_base_creator_widget.hide()
@@ -429,6 +484,8 @@ def to_view_document_base_state(self):
else:
self.enable_collect_statistics_action.setEnabled(True)
+ self._set_available_visualization_actions()
+ self._enable_color_palette_settings()
self.document_base_viewer_widget.enable_input()
self.central_widget_layout.setAlignment(Qt.AlignmentFlag.AlignLeft)
@@ -458,6 +515,8 @@ def to_interactive_matching_state(self):
else:
self.enable_collect_statistics_action.setEnabled(True)
+ self._set_available_visualization_actions()
+ self._enable_color_palette_settings()
self.central_widget_layout.setAlignment(Qt.AlignmentFlag.AlignLeft)
self.start_menu_widget.hide()
self.document_base_viewer_widget.hide()
@@ -472,6 +531,32 @@ def to_interactive_matching_state(self):
self.interactive_matching_widget.show()
self.central_widget_layout.update()
+ def _enable_visualization_settings(self):
+ self.enable_visualizations_action.setEnabled(not self.visualizations)
+ self.disable_visualizations_action.setEnabled(self.visualizations)
+
+ def attach_visualization_level_observer(self, observer):
+ self.visualizations_level_observers.append(observer)
+
+
+ def _set_available_visualization_actions(self):
+ if self.visualizations_level == AvailableVisualizationsLevel.DISABLED:
+ self.enable_lvl1_visualizations_action.setEnabled(True)
+ self.enable_lvl2_visualizations_action.setEnabled(True)
+ self.disable_visualizations_action.setEnabled(False)
+ elif self.visualizations_level == AvailableVisualizationsLevel.LEVEL_1:
+ self.enable_lvl1_visualizations_action.setEnabled(False)
+ self.enable_lvl2_visualizations_action.setEnabled(True)
+ self.disable_visualizations_action.setEnabled(True)
+ elif self.visualizations_level == AvailableVisualizationsLevel.LEVEL_2:
+ self.enable_lvl1_visualizations_action.setEnabled(True)
+ self.enable_lvl2_visualizations_action.setEnabled(False)
+ self.disable_visualizations_action.setEnabled(True)
+
+ def _enable_color_palette_settings(self):
+ self.enable_accessible_color_palette_action.setEnabled(not self.accessible_color_palette)
+ self.disable_accessible_color_palette_action.setEnabled(self.accessible_color_palette)
+
# noinspection PyUnresolvedReferences
def __init__(self) -> None:
super(MainWindow, self).__init__()
@@ -482,8 +567,14 @@ def __init__(self) -> None:
self.document_base = None
self.statistics = None
self.collect_statistics = True
+ self.visualizations_level_observers = list()
+ self.visualizations_level = AvailableVisualizationsLevel.LEVEL_2
+ self.accessible_color_palette = False
self.attributes_to_match = None
self.cache_db = None
+ self.usage_info_popup = InformationPopup("Usage Information", "wannadb_ui/resources/info_popups/usage_info.md")
+ self.visualization_info_popup = InformationPopup("Visualization Information", "wannadb_ui/resources/info_popups/visualization_info.md")
+ self.general_info_popup = InformationPopup("Underlying Ideas / Architecture", "wannadb_ui/resources/info_popups/ideas_and_architecture_info.md")
# set up the api_thread and api and connect slots and signals
self.feedback_mutex = QMutex()
@@ -610,6 +701,44 @@ def __init__(self) -> None:
self.save_statistics_to_json_action.triggered.connect(self.save_statistics_to_json_task)
self._all_actions.append(self.save_statistics_to_json_action)
+ self.enable_lvl1_visualizations_action = QAction("&Level 1", self)
+ self.enable_lvl1_visualizations_action.setStatusTip("Only grid related visualizations are available.")
+ self.enable_lvl1_visualizations_action.triggered.connect(lambda: self.update_visualizations_level(AvailableVisualizationsLevel.LEVEL_1))
+ self._all_actions.append(self.enable_lvl1_visualizations_action)
+
+ self.enable_lvl2_visualizations_action = QAction("&Level 2", self)
+ self.enable_lvl2_visualizations_action.setStatusTip("All visualizations are available.")
+ self.enable_lvl2_visualizations_action.triggered.connect(lambda: self.update_visualizations_level(AvailableVisualizationsLevel.LEVEL_2))
+ self._all_actions.append(self.enable_lvl2_visualizations_action)
+
+ self.disable_visualizations_action = QAction("&Disable", self)
+ self.disable_visualizations_action.setStatusTip("Disable visualization widgets.")
+ self.disable_visualizations_action.triggered.connect(lambda: self.update_visualizations_level(AvailableVisualizationsLevel.DISABLED))
+ self._all_actions.append(self.disable_visualizations_action)
+
+ self.enable_accessible_color_palette_action = QAction("&Enable accessible palette", self)
+ self.enable_accessible_color_palette_action.setStatusTip("Change the color palette to accessible.")
+ self.enable_accessible_color_palette_action.triggered.connect(self.enable_accessible_color_palette_task)
+ self._all_actions.append(self.enable_accessible_color_palette_action)
+
+ self.disable_accessible_color_palette_action = QAction("&Disable accessible palette", self)
+ self.disable_accessible_color_palette_action.setStatusTip("Change the color palette to rgb.")
+ self.disable_accessible_color_palette_action.triggered.connect(self.disable_accessible_color_palette_task)
+ self._all_actions.append(self.disable_accessible_color_palette_action)
+
+ self.open_usage_info = QAction("&Open usage info", self)
+ self.open_usage_info.setStatusTip("Open popup providing some information about the usage of the application.")
+ self.open_usage_info.triggered.connect(self.open_usage_info_task)
+
+ self.open_visualization_info = QAction("&Open visualization info", self)
+ self.open_visualization_info.setStatusTip("Open popup providing some information about the available visualizations.")
+ self.open_visualization_info.triggered.connect(self.open_visualization_info_task)
+
+ self.open_general_info = QAction("&Open general info", self)
+ self.open_general_info.setStatusTip("Open popup providing some general information about the application.")
+ self.open_general_info.triggered.connect(self.open_general_info_task)
+
+
# set up the menu bar
self.menubar = self.menuBar()
self.menubar.setFont(MENU_FONT)
@@ -635,13 +764,34 @@ def __init__(self) -> None:
self.population_menu.addAction(self.forget_matches_for_attribute_action)
self.population_menu.addAction(self.forget_matches_action)
- self.statistics_menu = self.menubar.addMenu("&Statistics")
+ self.settings_menu = self.menubar.addMenu("&Settings")
+ self.settings_menu.setFont(MENU_FONT)
+
+ self.statistics_menu = self.settings_menu.addMenu("&Statistics")
self.statistics_menu.setFont(MENU_FONT)
self.statistics_menu.addAction(self.enable_collect_statistics_action)
self.statistics_menu.addAction(self.disable_collect_statistics_action)
self.statistics_menu.addSeparator()
self.statistics_menu.addAction(self.save_statistics_to_json_action)
+ self.visualizations_menu = self.settings_menu.addMenu("&Visualizations")
+ self.visualizations_menu.setFont(MENU_FONT)
+ self.visualizations_menu.addAction(self.disable_visualizations_action)
+ self.visualizations_menu.addAction(self.enable_accessible_color_palette_action)
+ self.visualizations_menu.addAction(self.disable_accessible_color_palette_action)
+ self.visualizations_menu.addAction(self.enable_lvl1_visualizations_action)
+ self.visualizations_menu.addAction(self.enable_lvl2_visualizations_action)
+
+ self.help_menu = self.menubar.addMenu("&Help")
+ self.help_menu.setFont(MENU_FONT)
+
+ self.general_menu = self.help_menu.addMenu("&General")
+ self.general_menu.addAction(self.open_general_info)
+ self.usage_menu = self.help_menu.addMenu("&Usage")
+ self.usage_menu.addAction(self.open_usage_info)
+ self.visualization_menu = self.help_menu.addMenu("&Visualization")
+ self.visualization_menu.addAction(self.open_visualization_info)
+
# main UI
self.central_widget = QWidget(self)
self.central_widget_layout = QHBoxLayout(self.central_widget)
@@ -659,4 +809,8 @@ def __init__(self) -> None:
self.resize(1400, 800)
self.show()
+ # Information popup
+ self.information_popup = InformationPopup("Quick Start Guide", "wannadb_ui/resources/info_popups/splash_screen.md")
+ self.information_popup.show()
+
logger.info("Initialized MainWindow.")
diff --git a/wannadb_ui/resources/info_popups/barchart_tutorial.md b/wannadb_ui/resources/info_popups/barchart_tutorial.md
new file mode 100644
index 00000000..6efcc708
--- /dev/null
+++ b/wannadb_ui/resources/info_popups/barchart_tutorial.md
@@ -0,0 +1,50 @@
+## Hey there!
+Before you access the cosine-distance scale, take a moment to read the following tips.
+If you are familiar with the metrics used in WANNADB or have gone through this tutorial before,
+feel free to exit using the **skip** button.
+
+
+
+## Cosine Similarity in 2D Plane:
+Imagine that you and a friend are standing in the middle of a field, and both of you
+point in different directions. Each direction you point is like a piece of information.
+The closer your two arms are to pointing in the same direction, the more similar your
+thoughts or ideas are.
+
+### Same direction:
+If you both point in exactly the same direction, it means your ideas (or pieces of information) are exactly alike.
+This is like saying: "We’re thinking the same thing!"
+
+### Opposite direction:
+If you point in completely opposite directions, your ideas are as different as they can be. You’re thinking about completely different things.
+
+### Right angle:
+If your arms are at a 90-degree angle, you're pointing in different directions, but not as different as pointing in opposite directions. You’re thinking about different things, but there might still be a tiny bit of connection.
+
+
+
+
+
+## Multi Dimensionality of Vectors and Cosine Distance:
+Vectors may have more than 2 dimensions, as was the case of you and your friend on the field. The mathematical formula guarantees a value between -1 and 1 for each pair of vectors, for any number of dimensions.
+
+The cosine similarity is equal to 1 when the vectors point at the same direction, -1 when the vectors point in opposite directions, and 0 when the vectors are perpendicular to each other.
+
+As cosine similarity expresses how similar two vectors are, a higher value (in the range from -1 to 1) expresses a higher similarity. In **wanna-db** we use the dual concept of cosine distance. Contrary to cosine similarity, a higher value in the cosine distance metric, means a higher degree of dissimilarity.
+
+_cos-dist(**a**, **b**) = 1 - cos-sim(**a**, **b**)_
+
+
+
+
+
+## Cosine-Driven Choices: Ranking Database Values:
+The bar chart shows all nuggets found inside the documents, lined after each other along the x-axis. The y-axis shows the normalized cosine distance. As we mentioned, the lower the cosine distance is, the more certain we are that the corresponding word belongs to what we are looking for: a value in the database.
+
+### QUESTION:
+After you explore the bar chart, ask yourself - do the answers on the left tend to be more plausible?
+
+### PRO TIP:
+Click on each bar to show the exact value, as well as the full information nugget.
+
+
diff --git a/wannadb_ui/resources/info_popups/cosine_similarity.png b/wannadb_ui/resources/info_popups/cosine_similarity.png
new file mode 100644
index 00000000..e503596c
Binary files /dev/null and b/wannadb_ui/resources/info_popups/cosine_similarity.png differ
diff --git a/wannadb_ui/resources/info_popups/header_image.svg b/wannadb_ui/resources/info_popups/header_image.svg
new file mode 100644
index 00000000..2d0308e7
--- /dev/null
+++ b/wannadb_ui/resources/info_popups/header_image.svg
@@ -0,0 +1,1031 @@
+
\ No newline at end of file
diff --git a/wannadb_ui/resources/info_popups/ideas_and_architecture_info.md b/wannadb_ui/resources/info_popups/ideas_and_architecture_info.md
new file mode 100644
index 00000000..2f7a6aca
--- /dev/null
+++ b/wannadb_ui/resources/info_popups/ideas_and_architecture_info.md
@@ -0,0 +1,30 @@
+# WannaDB: Ad-hoc SQL Queries over Text Collections
+
+
+
+WannaDB allows users to explore unstructured text collections by automatically organizing the relevant information nuggets in a table. It supports ad-hoc SQL queries over text collections using a novel two-phased approach: First, a superset of information nuggets is extracted from the texts using existing extractors such as named entity recognizers. The extractions are then interactively matched to a structured table definition as requested by the user.
+
+Watch our [demo video](https://link.tuda.systems/aset-video) or [read our paper](https://doi.org/10.18420/BTW2023-08) to learn more about the usage and underlying concepts.
+
+
+# Underlying architecture and ideas
+
+This section gives a brief insight into the ideas facilitating the possibilities provided by WannaDB.
+
+## Cosine distance
+
+The cosine distance is used to measure the similarity between text snippets (nuggets) or text snippets and attributes. It's calculated as the cosine distance between the embedding vectors of two nuggets or a nugget and an attribute.
+In order to decide whether a nugget matches an attribute, the system calculates the cosine distance between the nugget and either the attribute to match or the closest nugget identified as a confirmed match for this attribute. The initial distance of a nugget is always its distance to the corresponding attribute. During the feedback process, the feedback process keeps track of the confirmed matches for this attribute and might update the nuggets distance to the distance to a confirmed match if this lowers the nuggets distance. The shorter the computed distance, the greater the confidence that this nugget matches this attribute.
+This distance is calculated for each extracted nugget within a document and the nugget with the lowest distance is considered as the best match of this document.
+As the user gives feedback the cosine distance of a nugget might change as this feedback might result in a close confirmed match.
+
+## Threshold
+
+The threshold refers to the maximum cosine distance from which a nugget isn't be considered as a match for an attribute and therefore not added to the resulting table. The best match of a document is only added to the table if its distance is below the current threshold.
+During the user's feedback process, the threshold might change its value multiple times to utilize the user's feedback as much as possible.
+
+## Interactive matching process
+
+The interactive matching process is the workflow in which the cells of the table are populated based on feedback given by the user.
+In order to populate cells related to some attribute, the user gives feedback to the corresponding nugget matches determined by the system. The feedback provided by the user (confirms/rejects a match) for a specific document influences the computed distance of other nuggets and the threshold as there might be new confirmed matches. Therefore, each feedback round might lead to changing best matches in other documents.
+In this way, the system kind of learns from the user's feedback and tries to improve the resulting table with each feedback round without requiring user feedback for each document.
diff --git a/wannadb_ui/resources/info_popups/overview.svg b/wannadb_ui/resources/info_popups/overview.svg
new file mode 100644
index 00000000..5425b7bc
--- /dev/null
+++ b/wannadb_ui/resources/info_popups/overview.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/wannadb_ui/resources/info_popups/screenshot_bar_chart.png b/wannadb_ui/resources/info_popups/screenshot_bar_chart.png
new file mode 100644
index 00000000..fb0168a7
Binary files /dev/null and b/wannadb_ui/resources/info_popups/screenshot_bar_chart.png differ
diff --git a/wannadb_ui/resources/info_popups/screenshot_grid.png b/wannadb_ui/resources/info_popups/screenshot_grid.png
new file mode 100644
index 00000000..902b5417
Binary files /dev/null and b/wannadb_ui/resources/info_popups/screenshot_grid.png differ
diff --git a/wannadb_ui/resources/info_popups/splash_screen.md b/wannadb_ui/resources/info_popups/splash_screen.md
new file mode 100644
index 00000000..a490fe72
--- /dev/null
+++ b/wannadb_ui/resources/info_popups/splash_screen.md
@@ -0,0 +1,16 @@
+
+
+WannaDB helps you to turn large collections of text into a concise table:
+
+1) Load the text files you want to turn into a table. The system will then automatically prepare them.
+
+2) Specify which columns your table should have. Then start the automatic table filling process.
+
+3) You are requested to give some feedback/examples for possible entries for the first column. The system will learn from them to automatically find similar entries. By giving more feedback, you can improve the quality. Once the preview you see seems of sufficient quality, you can continue this process for the next column.
+
+4) The system will use all the information gathered to fill the final table, without you having to read the other documents. Each row will represent one input document.
+
+To get more detailed instructions or to learn how WannaDB works internally, check the *Help* section in the menu.
+
+
+
diff --git a/wannadb_ui/resources/info_popups/usage_info.md b/wannadb_ui/resources/info_popups/usage_info.md
new file mode 100644
index 00000000..e4f0178a
--- /dev/null
+++ b/wannadb_ui/resources/info_popups/usage_info.md
@@ -0,0 +1,19 @@
+# Usage
+
+### 1. Create or load a document base
+Put all .txt files from which the table is to be created later into a new directory and start the *Create Document Base* dialog.
+Apart from providing the path of the .txt files, the attribute names - columns of the resulting table - need to be specified to create a new document base.
+Alternatively you can load an already existing document base.
+
+### 2. Check document base overview
+After loading or creating a document base, the *Document Base Viewer* will be displayed concluding the current result of the table population.
+In the beginning, there won't be any resulting table as you need to populate the cells of each attribute in the interactive matching process first.
+
+### 3. Populate cells
+From the *Document Base Viewer* screen, you can start populating the cells for each attribute.
+Clicking the corresponding start button will invoke the interactive matching process for an attribute in which text snippets matching an attribute are determined.
+For more information check the *Underlying architecture and ideas* help section.
+
+### 4. Export results
+After matching all attributes, the resulting table can be exported in the menu via *Document Base* -> *Export table to CSV*.
+Furthermore, the resulting table can be saved for later use and modification via *Document base* -> *Save document base*.
\ No newline at end of file
diff --git a/wannadb_ui/resources/info_popups/visualization_info.md b/wannadb_ui/resources/info_popups/visualization_info.md
new file mode 100644
index 00000000..6a4aa46e
--- /dev/null
+++ b/wannadb_ui/resources/info_popups/visualization_info.md
@@ -0,0 +1,12 @@
+# Visualizations
+
+## 3D-Grid
+The 3D-Grid aims to visualize how text snippets are interpreted by the system.
+Each text snippet extracted from one of the provided documents is represented by a vector of numbers within the system. This vector contains information about the meaning of the text snippet and is called *embedding vector* or just *embedding*.
+These embedding vectors are displayed in the 3D-Grid which makes it possible to recognize that words with similar meaning are also mapped to similar embedding vectors.
+Furthermore, the grid shows where a text snippet lies relative to the current threshold. The threshold basically determines whether a found text snippet should be considered as a match or not. To learn more about it, check [this section](#threshold).
+
+## Cosine-Distance Bar-Chart
+This bar-chart attempts to display the system-calculated confidence with which a text snippet from a document matches an attribute.
+In order to determine this certainty, the system uses the similarity between the embedding of a text snippet and the embedding of either the attribute to match or an already confirmed match for this attribute. What exactly is used the comparative value (attribute or already confirmed match) depends on which value results in a higher similarity score.
+To learn more about how this similarity is calculated, check the *Underlying architecture and ideas* help section.
\ No newline at end of file
diff --git a/wannadb_ui/study.py b/wannadb_ui/study.py
new file mode 100644
index 00000000..bb050706
--- /dev/null
+++ b/wannadb_ui/study.py
@@ -0,0 +1,126 @@
+import json
+import logging
+import os
+import time
+from collections import defaultdict
+from functools import wraps
+
+from PyQt6.QtCore import QObject, QTimer, QDateTime, pyqtSignal
+from typing import Dict, Callable
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+# Singleton class for tracking user interaction with a GUI
+class Tracker(QObject):
+ _instance = None # Class-level attribute to store the singleton instance
+ time_spent_signal = pyqtSignal(str, float) # Define the signal with window name and time spent
+
+ def __new__(cls, *args, **kwargs):
+ """Singleton pattern ensures one instance of the class"""
+ if not cls._instance:
+ cls._instance = super(Tracker, cls).__new__(cls, *args, **kwargs)
+ cls._instance._initialized = False
+
+ return cls._instance
+
+ def __init__(self):
+ """Initialize tracking properties if not already initialized"""
+ if not self._initialized:
+ super().__init__() # Call the QObject initializer
+ self.window_open_times = {}
+ self.timer = QTimer()
+ self.button_click_counts = defaultdict(int)
+ self.tooltips_hovered_counts = defaultdict(int)
+ self.total_window_open_times = {}
+ self._initialized = True
+ self.log = ''
+ self.sequence_number = 1
+ self.json_data = []
+
+ def dump_report(self):
+ """Dumps the interaction data to two report files.
+ One of them contains a json representations of the user activiy, the other
+ contains natural text."""
+ log_directory = './logs'
+ log_file = os.path.join(log_directory, 'user_report.txt')
+ os.makedirs(log_directory, exist_ok=True)
+
+ tick: float = time.time()
+ with open(log_file, 'w') as file:
+ file.write(self.log)
+ file.write("\nTotal Statistics:\n")
+ file.write(f"\nButton information:\n")
+ for button_name, number_of_clicks in self.button_click_counts.items():
+ file.write(f"\t'{button_name}' button has been clicked {number_of_clicks} times\n")
+ file.write(f"Window Information:\n")
+ for window_name, time_open_in_sec in self.total_window_open_times.items():
+ file.write(f"\t{window_name} was open for a total of {time_open_in_sec} seconds\n")
+ tack: float = time.time()
+ logger.info(f"Wrote the report in {round(tick - tack, 2)} seconds")
+
+ tick = time.time()
+ json_string = json.dumps(self.json_data, indent=4)
+ with open(os.path.join(log_directory, 'json_report.txt'), 'w') as file:
+ file.write(json_string)
+ tack = time.time()
+ logger.info(f"Dumped the json report file in {round(tick - tack, 2)} seconds")
+
+ def start_timer(self, window_name: str):
+ """Starts the timer for tracking window open time"""
+ self.window_open_times[window_name] = QDateTime.currentDateTime()
+ self.timer.start(1000)
+ self.log += f"{self.sequence_number}. {window_name} was opened\n"
+ self.json_data.append({'type': 'window', 'action': 'open', 'identifier': window_name})
+ self.sequence_number += 1
+
+ def stop_timer(self, window_name: str):
+ """Stops the timer for a window and calculates the time spent"""
+ self.timer.stop()
+ logger.debug(f"window_name = {window_name}")
+ self.calculate_time_spent(window_name)
+
+ def calculate_time_spent(self, window_name: str):
+ """Calculates the time spent in a window and logs the result"""
+ if self.window_open_times[window_name]:
+ current_time = QDateTime.currentDateTime()
+ time_spent = self.window_open_times[window_name].msecsTo(current_time) / 1000.0 # Convert to seconds
+ self.time_spent_signal.emit(window_name, time_spent)
+ self.window_open_times[window_name] = None
+ if window_name in self.total_window_open_times:
+ self.total_window_open_times[window_name] += time_spent
+ else:
+ self.total_window_open_times[window_name] = time_spent
+ self.log += f'{self.sequence_number}. {window_name} was closed. Time spent in {window_name} : {round(time_spent, 2)} seconds.\n'
+ self.sequence_number += 1
+ self.json_data.append(
+ {'type': 'window', 'action': 'close', 'identifier': window_name, 'time_open': time_spent})
+
+ def track_button_click(self, button_name: str):
+ """Tracks button clicks and logs them. Helper method for the decorator below"""
+ self.button_click_counts[button_name] += 1
+ self.log += f'{self.sequence_number}. {button_name} was clicked.\n'
+ self.sequence_number += 1
+ self.json_data.append({'type': 'button', 'identifier': button_name})
+
+ def track_tooltip_activation(self, tooltip_object: str):
+ """Tracks tooltip activations and logs them. Must be manually wired to every added tooltip"""
+ self.tooltips_hovered_counts[tooltip_object] += 1
+ self.log += f'{self.sequence_number}. The following tooltip was activated:\n {tooltip_object} \n'
+ self.sequence_number += 1
+ self.json_data.append({'type': 'tooltip', 'identifier': tooltip_object})
+
+
+def track_button_click(button_name: str):
+ """Decorator to track button clicks. Add to function signature behind a button to start logging"""
+ def decorator(func: Callable):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ print(f"Arguments passed to {func.__name__}: args={args}, kwargs={kwargs}")
+ args = tuple() # empty args, because .connect() implicit arguments are added, which result in an erroneous call of the decorated method
+ Tracker().track_button_click(button_name)
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+ return decorator
diff --git a/wannadb_ui/visualizations.py b/wannadb_ui/visualizations.py
new file mode 100644
index 00000000..739b3052
--- /dev/null
+++ b/wannadb_ui/visualizations.py
@@ -0,0 +1,1310 @@
+"""
+This class provides several classes related to visualization widgets.
+ 1. PointLegend
+ Label serving as a legend for a dyed point.
+ 2. EmbeddingVisualizerLegend
+ Widget serving as a legend for the EmbeddingVisualizer.
+ 3. EmbeddingVisualizer
+ Provides logic for handling a grid displaying dimension reduced nuggets.
+ 4. EmbeddingVisualizerWindow
+ Realizes an EmbeddingVisualizer in a separate window.
+ 5. EmbeddingVisualizerWidget
+ Realizes an EmbeddingVisualizer in a widget.
+ 6. BarChartVisualizerWidget
+ Widget realizing a bar chart displaying nuggets with their certainty with which they match an attribute.
+"""
+
+import logging
+from typing import List, Dict, Tuple, Union
+
+import numpy as np
+import pyqtgraph as pg
+import pyqtgraph.opengl as gl
+from PyQt6.QtCore import Qt, QPoint
+from PyQt6.QtGui import QFont, QColor, QPixmap, QPainter
+from PyQt6.QtWidgets import QWidget, QVBoxLayout, QPushButton, QMainWindow, QHBoxLayout, QFrame, QScrollArea, \
+ QApplication, QLabel
+from matplotlib import pyplot as plt
+from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
+from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
+from matplotlib.colors import LinearSegmentedColormap
+from matplotlib.figure import Figure
+from matplotlib.patches import Rectangle
+from pyqtgraph import Color
+from pyqtgraph.opengl import GLViewWidget, GLScatterPlotItem, GLTextItem
+
+from wannadb.data.data import InformationNugget, Attribute
+from wannadb.data.signals import PCADimensionReducedTextEmbeddingSignal, PCADimensionReducedLabelEmbeddingSignal, \
+ CachedDistanceSignal, CurrentThresholdSignal
+from wannadb.utils import AccessibleColor
+from wannadb_ui.common import BUTTON_FONT_SMALL, InfoDialog
+from wannadb_ui.study import Tracker, track_button_click
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+RED = pg.mkColor('red')
+ACC_RED = pg.mkColor(220, 38, 127)
+BLUE = pg.mkColor('blue')
+ACC_BLUE = pg.mkColor(100, 143, 255)
+GREEN = pg.mkColor('green')
+ACC_GREEN = pg.mkColor(255, 176, 0)
+WHITE = pg.mkColor('white')
+YELLOW = pg.mkColor('yellow')
+ACC_YELLOW = pg.mkColor(254, 97, 0)
+PURPLE = pg.mkColor('purple')
+ACC_PURPLE = pg.mkColor(120, 94, 240)
+EMBEDDING_ANNOTATION_FONT = QFont('Helvetica', 10)
+DEFAULT_NUGGET_SIZE = 10
+HIGHLIGHT_SIZE = 17
+
+def initialize_app():
+ """
+ Initializes the PyQt application and sets up the main window.
+ This function is typically called at the start of the application.
+ """
+ app = QApplication.getInstance()
+ if app is None:
+ app = QApplication([])
+ screen = app.primaryScreen()
+ screen_geometry = screen.geometry()
+ return app, screen_geometry
+
+app, screen_geometry = initialize_app()
+WINDOW_WIDTH = int(screen_geometry.width() * 0.7)
+WINDOW_HEIGHT = int(screen_geometry.height() * 0.7)
+
+
+def _get_colors(distances, color_start='green', color_end='red'):
+ cmap = LinearSegmentedColormap.from_list("CustomMap", [color_start, color_end])
+ norm = plt.Normalize(min(distances), max(distances))
+ colors = [cmap(norm(value)) for value in distances]
+ return colors
+
+
+def _build_nuggets_annotation_text(nugget) -> str:
+ return f"{nugget.text}: {round(nugget[CachedDistanceSignal], 3)}"
+
+
+def _create_sanitized_text(nugget):
+ return nugget.text.replace("\n", " ")
+
+
+class PointLegend(QLabel):
+ """
+ Class realizing a legend for a dyed point by displaying a dyed point next to the meaning of this point within a label.
+
+ In the application, this class is employed to create a legend for the 3D-Grids.
+ The 3D-Grid contains points with different colors. Each color is explained using a label created by this class.
+ """
+ def __init__(self, point_meaning: str, point_color: QColor):
+ """
+ Parameters
+ ----------
+ point_meaning : str
+ the meaning of points with the given color
+ point_color: QColor
+ the color of the points whose meaning is explained by this label
+ """
+
+ super().__init__()
+
+ # Set fixed sizes
+ self._height = 30
+ self._width = 300
+ self._circle_diameter = 10
+
+ # Init pixmap on which all contents will be painted
+ self._pixmap = QPixmap(self._width, self._height)
+ self._pixmap.fill(Qt.GlobalColor.transparent)
+
+ # Init painter used to paint on pixmap
+ self._painter = QPainter(self._pixmap)
+
+ # Init point displayed on pixmap serving as a reference to which points the meaning refers to
+ circle_center = QPoint(self._circle_diameter, round(self._height / 2))
+
+ # Paint point and text on pixmap
+ self._painter.setPen(Qt.PenStyle.NoPen)
+ self._painter.setBrush(point_color)
+ self._painter.drawEllipse(circle_center, self._circle_diameter, self._circle_diameter)
+ self._painter.setFont(BUTTON_FONT_SMALL)
+ self._painter.setPen(pg.mkColor('black'))
+ text_height = self._painter.fontMetrics().height()
+ self._painter.drawText(circle_center.x() + self._circle_diameter + 5,
+ circle_center.y() + round(text_height / 4),
+ f': {point_meaning}')
+
+ self._painter.end()
+
+ # Add pixmap to label represented by this instance
+ self.setPixmap(self._pixmap)
+
+
+class EmbeddingVisualizerLegend(QWidget):
+ """
+ Class realizing a legend for a 3D-Grid realized by the EmbeddingVisualizer class which explains the meaning of all
+ point colors occurring within the grid.
+ Utilizes instances of `PointLegend` to explain the meaning of a specific color.
+
+ Methods
+ -------
+ reset():
+ Removes all widgets - realized as instances of `PointLegend` - contained within this widget.
+ update_colors_and_meanings(colors_with_meanings: List[Tuple[QColor, str]]):
+ Fills this instance with an actual legend explaining the given colors with the given meanings.
+ """
+
+ def __init__(self):
+ """
+ Initializes an instance of this class by creating and setting up the corresponding layout.
+ Initially the widget represented by this instance is empty and doesn't contain anything except an empty layout.
+ """
+
+ super().__init__()
+
+ # Set up the layout
+ self.layout = QHBoxLayout()
+ self.layout.setContentsMargins(0, 0, 0, 0)
+ self.layout.setSpacing(0)
+ self.setLayout(self.layout)
+
+ # Init the list of PointLegends contained by this instance
+ self._point_legends = []
+
+ def reset(self):
+ """
+ Removes all widgets - realized as instances of `PointLegend`.
+
+ After calling this method, the widget represented by this instance is empty and doesn't contain anything except
+ an empty layout.
+ """
+
+ for widget in self._point_legends:
+ self.layout.removeWidget(widget)
+ self._point_legends = []
+
+ def update_colors_and_meanings(self, colors_with_meanings: List[Tuple[QColor, str]]):
+ """
+ Fills this instance with an actual legend explaining the given colors with the given meanings.
+
+ First, this instance is cleared by calling the `reset()` method.
+ Then an explanation for each of the given colors is created by creating `PointLegend` instances for each color
+ with its associated meaning and added to the widget represented by this instance.
+ """
+
+ # Clear this widget
+ self.reset()
+
+ # Add new explanations
+ for color, meaning in colors_with_meanings:
+ point_legend = PointLegend(meaning, color)
+ self.layout.addWidget(point_legend)
+ self._point_legends.append(point_legend)
+
+
+class EmbeddingVisualizer:
+ """
+ Class providing the required logic to handle a 3D-Grid displaying dimension-reduced embedding vectors.
+
+ Methods
+ -------
+ enable_accessible_color_palette_():
+ Replaces the colors of points displayed within the grid by accessible colors allowing people with color
+ blindness to better differentiate the colors.
+ disable_accessible_color_palette_():
+ Replaces the colors of points displayed within the grid by the originally used colors and therefore disables the
+ usage of accessible colors.
+ update_and_display_params(attribute: Attribute,
+ nuggets: List[InformationNugget],
+ currently_highlighted_nugget: Union[InformationNugget, None],
+ best_guess: Union[InformationNugget, None],
+ other_best_guesses: List[InformationNugget]):
+ Removes all currently displayed nuggets and adds the given attribute as well as the nuggets to the grid.
+ highlight_best_guess(best_guess: InformationNugget):
+ Highlights the point representing the given nugget by increasing its size and dying it white.
+ highlight_selected_nugget(newly_selected_nugget: InformationNugget):
+ Highlights the point representing the given nugget by increasing its size and dying it blue.
+ display_other_best_guesses(other_best_guesses: List[InformationNugget]):
+ Adds the given nuggets - corresponding to the best guesses of other documents - to the grid and highlight them
+ by dying them yellow.
+ remove_other_best_guesses(other_best_guesses: List[InformationNugget]):
+ Removes the nuggets corresponding to the best guesses of other documents from the grid.
+ reset():
+ Removes all points and their corresponding annotation text from the grid.
+ """
+
+ def __init__(self,
+ legend: EmbeddingVisualizerLegend,
+ colors_with_meanings: List[Tuple[AccessibleColor, str]],
+ attribute: Attribute = None,
+ nuggets: List[InformationNugget] = None,
+ currently_highlighted_nugget: InformationNugget = None,
+ best_guess: InformationNugget = None,
+ other_best_guesses: List[InformationNugget] = None,
+ accessible_color_palette: bool = False):
+ """
+ Parameters:
+ -----------
+ legend: EmbeddingVisualizerLegend
+ Instance of the legend displayed below the grid and explaining the meaning of the colors occurring in the
+ grid.
+ colors_with_meanings: List[Tuple[QColor, str]]
+ List of colors occurring in the grid associated with their meaning used to fill the given legend.
+ attribute: Attribute = None
+ `Attribute` instance representing the attribute to which the nuggets displayed within the grid belong to as
+ its embedding is displayed in the grid as well.
+ nuggets: List[InformationNugget] = None
+ Nuggets whose dimension-reduced embedding vectors should be displayed within the grid.
+ currently_highlighted_nugget: InformationNugget = None
+ Refers to the nugget which is currently selected and therefore should be highlighted. If none, nothing is
+ highlighted.
+ best_guess: InformationNugget = None
+ Refers to the best guess of the document represented by this grid and therefore should be highlighted. If
+ none, nothing is highlighted. Applicable only in case the grid belongs to the document view and not to the
+ document overview screen.
+ other_best_guesses: List[InformationNugget] = None
+ Nuggets representing best guesses from other documents which should be displayed in this grid as well
+ initially. Applicable only in case the grid belongs to the document view and not to the document overview
+ screen.
+ accessible_color_palette: bool
+ Specifies whether the colors used by the points displayed in the grid are accessible - usable for people
+ with color blindness - or not.
+ """
+
+ self._attribute: Attribute = attribute
+ self._nuggets: List[InformationNugget] = nuggets
+ self._currently_highlighted_nugget: InformationNugget = currently_highlighted_nugget
+ self._best_guess: InformationNugget = best_guess
+ self._other_best_guesses: List[InformationNugget] = other_best_guesses
+ self._nugget_to_displayed_items: Dict[InformationNugget, Tuple[GLScatterPlotItem, GLTextItem]] = dict()
+ self._gl_widget = GLViewWidget()
+ self._accessible_color_palette = accessible_color_palette
+ self._legend = legend
+ self._colors_with_meanings = colors_with_meanings
+
+ # Add the given colors with their meanings to the given legend
+ self._update_legend()
+
+ def enable_accessible_color_palette(self):
+ """
+ Replaces the colors of points displayed within the grid by accessible colors allowing people with color
+ blindness to better differentiate the colors.
+ """
+
+ self._accessible_color_palette = True
+ self._update_legend()
+ self.update_and_display_params(self._attribute,
+ self._nuggets,
+ self._currently_highlighted_nugget,
+ self._best_guess,
+ self._other_best_guesses)
+
+ def disable_accessible_color_palette(self):
+ """
+ Replaces the colors of points displayed within the grid by the originally used colors and therefore disables the
+ usage of accessible colors.
+ """
+
+ self._accessible_color_palette = False
+ self._update_legend()
+ self.update_and_display_params(self._attribute,
+ self._nuggets,
+ self._currently_highlighted_nugget,
+ self._best_guess,
+ self._other_best_guesses)
+
+ def update_and_display_params(self,
+ attribute: Attribute,
+ nuggets: List[InformationNugget],
+ currently_highlighted_nugget: Union[InformationNugget, None],
+ best_guess: Union[InformationNugget, None],
+ other_best_guesses: List[InformationNugget]):
+ """
+ Removes all currently displayed nuggets and adds the given attribute as well as the nuggets to the grid.
+
+ First, removes all currently displayed points.
+ Then adds the dimension-reduced embedding vector of the given attribute and the given nuggets to the grid.
+ Next, the given best guess and `currently_highlighted_nugget` and - if this grid belongs to the document
+ overview - already confirmed matches are highlighted.
+
+ Parameters:
+ -----------
+ attribute: Attribute
+ `Attribute` instance representing the attribute to which the nuggets displayed within the grid as its
+ embedding is displayed in the grid as well.
+ nuggets: List[InformationNugget]
+ Nuggets whose dimension-reduced embedding vectors should be displayed within the grid.
+ currently_highlighted_nugget: InformationNugget
+ Special nugget which should be highlighted. If none, nothing is highlighted.
+ best_guess: InformationNugget
+ Best guess of the document corresponding to the grid which is highlighted. If none, nothing is highlighted.
+ Applicable only in case the grid belongs to the document view and not to the document overview screen.
+ other_best_guesses: List[InformationNugget]
+ Best guesses of other documents which should be displayed in the grid as well. Applicable only in case the
+ grid belongs to the document view and not to the document overview screen.
+ """
+
+ self.reset()
+
+ # Add attribute to grid
+ if attribute is not None:
+ self._display_attribute_embedding(attribute)
+ else:
+ logger.warning("Given attribute is null, can not display.")
+
+ # Add nuggets to the grid
+ if nuggets:
+ self._nuggets = nuggets
+ self._display_nugget_embeddings(nuggets)
+ else:
+ logger.warning("Given nugget list is null or empty, can not display.")
+
+ # Highlight best guess if present
+ if best_guess is not None:
+ self.highlight_best_guess(best_guess)
+ else:
+ logger.info("Given best_guess is null, can not highlight.")
+
+ # Highlight confirmed matches if possible
+ self._highlight_confirmed_matches()
+
+ # Highlight currently selected nugget if possible
+ if currently_highlighted_nugget is not None:
+ self.highlight_selected_nugget(currently_highlighted_nugget)
+ else:
+ logger.info("Given nugget to highlight is null, can not highlight.")
+
+ self._other_best_guesses = other_best_guesses
+
+ def highlight_best_guess(self, best_guess: InformationNugget):
+ """
+ Highlights the point representing the given nugget by increasing its size and dying it white.
+
+ If the best guess is equal to the currently selected nugget, it's highlighted in blue.
+ """
+
+ # Update internal attribute
+ self._best_guess = best_guess
+
+ # Highlight in blue if equal to currently selected nugget
+ if self._best_guess == self._currently_highlighted_nugget:
+ self._highlight_nugget(self._best_guess, ACC_BLUE if self._accessible_color_palette else BLUE, 15)
+ return
+
+ # Highlight given nugget in white and increase size
+ self._highlight_nugget(self._best_guess, WHITE, 15)
+
+ def highlight_selected_nugget(self, newly_selected_nugget: InformationNugget):
+ """
+ Highlights the point representing the given nugget by increasing its size and dying it blue.
+
+ If present, the previously selected nugget is reset to original color and size. Exact reset color and size
+ depend on type of previously selected nugget (best guess, confirmed match, normal nugget)
+ """
+
+ # Determine highlight color and size as well as reset color and size. Highlight values are always blue and 15
+ # while reset values depend on type of previously selected nugget (see above).
+ (highlight_color, highlight_size), (reset_color, reset_size) = self._determine_update_values(
+ previously_selected_nugget=self._currently_highlighted_nugget)
+
+ # Reset currently highlighted nugget to determined color and size
+ if self._currently_highlighted_nugget is not None:
+ currently_highlighted_scatter, _ = self._nugget_to_displayed_items[self._currently_highlighted_nugget]
+ currently_highlighted_scatter.setData(color=reset_color, size=reset_size)
+
+ # Highlight newly selected nugget
+ self._highlight_nugget(nugget_to_highlight=newly_selected_nugget,
+ new_color=highlight_color,
+ new_size=highlight_size)
+
+ # Update internal variable
+ self._currently_highlighted_nugget = newly_selected_nugget
+
+ def display_other_best_guesses(self, other_best_guesses: List[InformationNugget]):
+ """
+ Adds the given nuggets - corresponding to the best guesses of other documents - to the grid and highlight them
+ by dying them yellow.
+ """
+
+ for other_best_guess in other_best_guesses:
+ self._add_other_best_guess(other_best_guess)
+
+ def remove_other_best_guesses(self, other_best_guesses: List[InformationNugget]):
+ """
+ Removes the nuggets corresponding to the best guesses of other documents from the grid.
+ """
+
+ self._remove_nuggets_from_widget(other_best_guesses)
+
+ def reset(self):
+ """
+ Removes all points and their corresponding annotation text from the grid.
+ """
+
+ # Remove widgets
+ for nugget, (scatter, annotation) in self._nugget_to_displayed_items.items():
+ self._gl_widget.removeItem(scatter)
+ self._gl_widget.removeItem(annotation)
+
+ # Reset internal state variables
+ self._nugget_to_displayed_items = {}
+ self._currently_highlighted_nugget = None
+ self._best_guess = None
+
+ def _add_item_to_grid(self,
+ nugget_to_display_context: Tuple[Union[InformationNugget, Attribute], Color],
+ annotation_text: str,
+ size: int = DEFAULT_NUGGET_SIZE):
+ # Determine position of point to display and its color
+ item_to_display, color = nugget_to_display_context
+ position = np.array([item_to_display[PCADimensionReducedTextEmbeddingSignal]]) if isinstance(item_to_display,
+ InformationNugget) \
+ else np.array([item_to_display[PCADimensionReducedLabelEmbeddingSignal]])
+
+ # Create grid items representing the given nugget and annotation text at the computed position
+ scatter = GLScatterPlotItem(pos=position, color=color, size=size, pxMode=True)
+ annotation = GLTextItem(pos=[position[0][0], position[0][1], position[0][2]],
+ color=WHITE,
+ text=annotation_text,
+ font=EMBEDDING_ANNOTATION_FONT)
+
+ # Add created items to grid
+ self._gl_widget.addItem(scatter)
+ self._gl_widget.addItem(annotation)
+
+ # Add created items to internal variable to keep track about the added items
+ if isinstance(item_to_display, InformationNugget):
+ self._nugget_to_displayed_items[item_to_display] = (scatter, annotation)
+
+ def _display_nugget_embeddings(self, nuggets):
+ for nugget in nuggets:
+ nugget_to_display_context = (nugget, self._determine_nuggets_color(nugget))
+
+ self._add_item_to_grid(nugget_to_display_context=nugget_to_display_context,
+ annotation_text=_build_nuggets_annotation_text(nugget))
+
+ def _display_attribute_embedding(self, attribute):
+ self._add_item_to_grid(nugget_to_display_context=(attribute, WHITE),
+ annotation_text=f'Attribute: {attribute.name}')
+ self._attribute = attribute
+
+ def _remove_nuggets_from_widget(self, nuggets_to_remove):
+ # Removes all items associated with the given nuggets from the grid
+ for nugget in nuggets_to_remove:
+ scatter, annotation = self._nugget_to_displayed_items.pop(nugget)
+
+ self._gl_widget.removeItem(scatter)
+ self._gl_widget.removeItem(annotation)
+
+ def _highlight_confirmed_matches(self):
+ # Only relevant if the grid belongs to the document overview view as it highlights the nuggets which are already
+ # confirmed by the user in the feedback process.
+ if self._attribute is None:
+ logger.warning("Attribute has not been initialized yet, can not highlight confirmed matches.")
+ return
+
+ for confirmed_match in self._attribute.confirmed_matches:
+ if confirmed_match in self._nugget_to_displayed_items:
+ self._highlight_nugget(confirmed_match, ACC_GREEN if self._accessible_color_palette else GREEN,
+ DEFAULT_NUGGET_SIZE)
+
+ def _determine_update_values(self, previously_selected_nugget) -> ((int, Color), (int, Color)):
+ # Computes the size and color of a newly selected nugget as well as the size and color of the nugget
+ # which was selected previously
+
+ # Highlight values are always same
+ highlight_color = ACC_BLUE if self._accessible_color_palette else BLUE
+ highlight_size = 15
+
+ # Reset values depend on the type of the nugget whose size and color should be reset
+ if previously_selected_nugget is None:
+ reset_color = WHITE
+ reset_size = DEFAULT_NUGGET_SIZE
+ elif previously_selected_nugget in self._attribute.confirmed_matches:
+ reset_color = ACC_GREEN if self._accessible_color_palette else GREEN
+ reset_size = DEFAULT_NUGGET_SIZE
+ elif previously_selected_nugget == self._best_guess:
+ reset_color = WHITE
+ reset_size = HIGHLIGHT_SIZE
+ else:
+ reset_color = self._determine_nuggets_color(previously_selected_nugget)
+ reset_size = DEFAULT_NUGGET_SIZE
+
+ return (highlight_color, highlight_size), (reset_color, reset_size)
+
+ def _determine_nuggets_color(self, nugget: InformationNugget) -> Color:
+ # Computes the nuggets color based on its type:
+ # Purple -> Failure during computation
+ # White -> Below threshold
+ # Red -> Above Threshold
+
+ if (self._attribute is None or
+ CurrentThresholdSignal.identifier not in self._attribute.signals):
+ logger.warning(f"Could not determine nuggets color from given attribute: {self._attribute}. "
+ f"Will return purple as color highlighting nuggets with this issue.")
+ return ACC_PURPLE if self._accessible_color_palette else PURPLE
+
+ return WHITE if nugget[CachedDistanceSignal] < self._attribute[
+ CurrentThresholdSignal] else ACC_RED if self._accessible_color_palette else RED
+
+ def _add_grids(self):
+ # Adds the UI items realizing the grid
+
+ grid_xy = gl.GLGridItem()
+ self._gl_widget.addItem(grid_xy)
+
+ grid_xz = gl.GLGridItem()
+ grid_xz.rotate(90, 1, 0, 0)
+ self._gl_widget.addItem(grid_xz)
+
+ grid_yz = gl.GLGridItem()
+ grid_yz.rotate(90, 0, 1, 0)
+ self._gl_widget.addItem(grid_yz)
+
+ def _highlight_nugget(self, nugget_to_highlight, new_color, new_size):
+ scatter_to_highlight, _ = self._nugget_to_displayed_items[nugget_to_highlight]
+
+ if scatter_to_highlight is None:
+ logger.warning("Couldn't find nugget to highlight.")
+ return
+
+ scatter_to_highlight.setData(color=new_color, size=new_size)
+
+ def _add_other_best_guess(self, other_best_guess):
+ self._add_item_to_grid(
+ nugget_to_display_context=(other_best_guess, ACC_YELLOW if self._accessible_color_palette else YELLOW),
+ annotation_text=_build_nuggets_annotation_text(other_best_guess),
+ size=HIGHLIGHT_SIZE)
+
+ def _update_legend(self):
+ # Updates the legend associated with this grid according to the current value of the internal variables
+ # `_color_with_meanings` and `_accessible_color_palette`
+
+ def map_to_correct_color(accessible_color):
+ # Maps each color to its standard or accessible version depending on the value of
+ # `_accessible_color_palette`
+ return accessible_color.corresponding_accessible_color if self._accessible_color_palette \
+ else accessible_color.color
+
+ colors_with_meanings = list(map(lambda color_with_meaning: (map_to_correct_color(color_with_meaning[0]),
+ color_with_meaning[1]),
+ self._colors_with_meanings))
+ self._legend.update_colors_and_meanings(colors_with_meanings)
+
+
+class EmbeddingVisualizerWindow(EmbeddingVisualizer, QMainWindow):
+ """
+ Class realizing an `EmbeddingVisualizer` in a separate window by inheriting from `EmbeddingVisualizer` and
+ `QMainWindow`.
+
+ Methods
+ -------
+ showEvent():
+ Shows the associated window.
+ closeEvent():
+ Closes the associated window.
+ """
+
+ def __init__(self,
+ colors_with_meanings: List[Tuple[AccessibleColor, str]],
+ attribute: Attribute = None,
+ nuggets: List[InformationNugget] = None,
+ currently_highlighted_nugget: InformationNugget = None,
+ best_guess: InformationNugget = None,
+ other_best_guesses: List[InformationNugget] = None,
+ accessible_color_palette: bool = False):
+ """
+ Initializes an instance of this class by calling constructor of `EmbeddingVisualizer` and `QMainWindow` and sets
+ up the required UI components.
+ The parameters are propagated to the `EmbeddingVisualizer` constructor in order to add content to the grid
+ initially.
+
+ Parameters
+ ----------
+ colors_with_meanings: List[Tuple[QColor, str]]
+ List of colors occurring in the grid associated with their meaning used to fill the given legend.
+ attribute: Attribute = None
+ `Attribute` instance representing the attribute to which the nuggets displayed within the grid belong to as
+ its embedding is displayed in the grid as well.
+ nuggets: List[InformationNugget] = None
+ Nuggets whose dimension-reduced embedding vectors should be displayed within the grid.
+ currently_highlighted_nugget: InformationNugget = None
+ Refers to the nugget which is currently selected and therefore should be highlighted. If none, nothing is
+ highlighted.
+ best_guess: InformationNugget = None
+ Refers to the best guess of the document represented by this grid and therefore should be highlighted. If
+ none, nothing is highlighted. Applicable only in case the grid belongs to the document view and not to the
+ document overview screen.
+ other_best_guesses: List[InformationNugget] = None
+ Nuggets representing best guesses from other documents which should be displayed in this grid as well
+ initially. Applicable only in case the grid belongs to the document view and not to the document overview
+ screen.
+ accessible_color_palette: bool
+ Specifies whether the colors used by the points displayed in the grid are accessible - usable for people
+ with color blindness - or not.
+ """
+
+ # Call super constructors
+ EmbeddingVisualizer.__init__(self,
+ legend=EmbeddingVisualizerLegend(),
+ colors_with_meanings=colors_with_meanings,
+ attribute=attribute,
+ nuggets=nuggets,
+ currently_highlighted_nugget=currently_highlighted_nugget,
+ best_guess=best_guess,
+ accessible_color_palette=accessible_color_palette)
+ QMainWindow.__init__(self)
+
+ # Set up window
+ self.setWindowTitle("3D Grid Visualizer")
+ self.setGeometry(100, 100, WINDOW_WIDTH, WINDOW_HEIGHT)
+
+ # Set up layout
+ central_widget = QWidget()
+ self.setCentralWidget(central_widget)
+ self.fullscreen_layout = QVBoxLayout()
+ central_widget.setLayout(self.fullscreen_layout)
+
+ # Add grid and legend item to the UI
+ self.fullscreen_layout.addWidget(self._gl_widget, stretch=7)
+ self.fullscreen_layout.addWidget(self._legend, stretch=1)
+
+ self._add_grids()
+
+ # If values which should be displayed in the grid are present, add them the grid, else make itself invisible
+ if (attribute is not None and
+ nuggets is not None and
+ currently_highlighted_nugget is not None and
+ best_guess is not None):
+ self.update_and_display_params(attribute, nuggets, currently_highlighted_nugget, best_guess,
+ other_best_guesses)
+ else:
+ self.setVisible(False)
+
+ def showEvent(self, event):
+ """
+ Shows the associated window and start timer tracking the time, the window is opened.
+ """
+
+ super().showEvent(event)
+ Tracker().start_timer(str(self.__class__))
+
+ def closeEvent(self, event):
+ """
+ Closes the associated window and stop timer tracking the time, the window is opened.
+ """
+
+ Tracker().stop_timer(str(self.__class__))
+ event.accept()
+
+
+class EmbeddingVisualizerWidget(EmbeddingVisualizer, QWidget):
+ """
+ Class realizing an `EmbeddingVisualizer` within a widget by inheriting from `EmbeddingVisualizer` and `QWidget`.
+
+ Each instance of this visualizer is associated with a fullscreen version which displays the same content and can
+ be opened and closed with buttons.
+
+ Methods
+ -------
+ enable_accessible_color_palette():
+ Enables accessible color palette in this visualizer as well as in fullscreen version if opened.
+ disable_accessible_color_palette():
+ Disables accessible color palette in this visualizer as well as in fullscreen version if opened.
+ return_from_embedding_visualizer_window(self):
+ Close fullscreen version of this visualizer.
+ update_other_best_guesses():
+ Update variable holding best guesses from other documents.
+ highlight_selected_nugget(nugget):
+ Highlights selected nugget in this visualizer as well in fullscreen version if opened.
+ highlight_best_guess(best_guess: InformationNugget):
+ Highlights the best guess of the corresponding document in this visualizer as well in fullscreen version if
+ opened.
+ reset():
+ Resets this widget by calling superclass implementation and resetting internal variables.
+ """
+
+ tracker: Tracker = Tracker()
+
+ def __init__(self):
+ """
+ Initializes an instance of this class by determining the colors with their associated meanings used by the
+ corresponding grid, calling the super constructors and setting up the required UI components.
+
+ Required UI components cover the 3D grid, as well as buttons to show grid in separate window as well as adding /
+ removing best guesses from other documents to / from the grid.
+
+ The `EmbeddingVisualizer` is initialized without any nuggets leading to an initially empty grid.
+ """
+
+ # Determine colors with their associated meanings used by the corresponding grid
+ colors_with_meanings = [
+ (AccessibleColor(WHITE, WHITE), 'Below threshold'),
+ (AccessibleColor(RED, ACC_RED), 'Above threshold'),
+ (AccessibleColor(BLUE, ACC_BLUE), 'Documents best match'),
+ (AccessibleColor(YELLOW, ACC_YELLOW), 'Other documents best matches'),
+ (AccessibleColor(PURPLE, ACC_PURPLE), 'Could not determine correct color')
+ ]
+
+ # Call super constructors
+ EmbeddingVisualizer.__init__(self, EmbeddingVisualizerLegend(), colors_with_meanings)
+ QWidget.__init__(self)
+
+ # Set up layout
+ self.layout = QVBoxLayout()
+ self.layout.setContentsMargins(0, 0, 0, 0)
+ self.layout.setSpacing(0)
+ self.setLayout(self.layout)
+
+ # Set up grid widget and add to layout
+ self._gl_widget.setMinimumHeight(300) # Set the initial height of the grid to 200
+ self.layout.addWidget(self._gl_widget)
+
+ self.layout.addWidget(self._legend)
+
+ # Set up buttons and add to layout
+ self.best_guesses_widget = QWidget()
+ self.best_guesses_widget_layout = QHBoxLayout(self.best_guesses_widget)
+ self.best_guesses_widget_layout.setContentsMargins(0, 0, 0, 0)
+ self.best_guesses_widget_layout.setSpacing(0)
+ self.fullscreen_button = QPushButton("Show 3D Grid in windowed fullscreen mode")
+ self.fullscreen_button.clicked.connect(self._show_embedding_visualizer_window)
+ self.best_guesses_widget_layout.addWidget(self.fullscreen_button)
+ self.show_other_best_guesses_button = QPushButton("Show best guesses from other documents")
+ self.show_other_best_guesses_button.clicked.connect(self._handle_show_other_best_guesses_clicked)
+ self.best_guesses_widget_layout.addWidget(self.show_other_best_guesses_button)
+ self.remove_other_best_guesses_button = QPushButton("Stop showing best guesses from other documents")
+ self.remove_other_best_guesses_button.setEnabled(False)
+ self.remove_other_best_guesses_button.clicked.connect(self._handle_remove_other_best_guesses_clicked)
+ self.best_guesses_widget_layout.addWidget(self.remove_other_best_guesses_button)
+ self.layout.addWidget(self.best_guesses_widget)
+
+ # Add items representing the grid itself to the grid widget
+ self._add_grids()
+
+ # Init internal variables
+ self._fullscreen_window = None
+ self._other_best_guesses = None
+
+ def enable_accessible_color_palette(self):
+ """
+ Invokes `enable_accessible_color_palette()` method of the `EmbeddingVisualizer` superclass for this instance. If
+ present, invokes the same method on the `EmbeddingVisualizer` instance realizing the fullscreen version of
+ this visualizer to enable accessible color palette there as well.
+
+ More detailed information about the `enable_accessible_color_palette()` method are elaborated in implementation
+ of superclass.
+ """
+
+ # Call superclass implementation to enable accessible color palette on this grid
+ super().enable_accessible_color_palette()
+
+ # Enable accessible color palette in fullscreen window if present
+ if self._fullscreen_window is not None:
+ self._fullscreen_window.enable_accessible_color_palette()
+
+ def disable_accessible_color_palette(self):
+ """
+ Invokes `disable_accessible_color_palette()` method of the `EmbeddingVisualizer` superclass for this instance.
+ If present, invokes the same method on the `EmbeddingVisualizer` instance realizing the fullscreen version of
+ this visualizer to disable accessible color palette there as well.
+
+ More detailed information about the `disable_accessible_color_palette()` method are elaborated in implementation
+ of superclass.
+ """
+
+ # Call superclass implementation to disable accessible color palette on this grid
+ super().disable_accessible_color_palette()
+
+ # Disable accessible color palette in fullscreen window if present
+ if self._fullscreen_window is not None:
+ self._fullscreen_window.disable_accessible_color_palette()
+
+ def return_from_embedding_visualizer_window(self):
+ """
+ Close fullscreen version of this visualizer.
+ """
+
+ self._fullscreen_window.close()
+ self._fullscreen_window = None
+
+ def update_other_best_guesses(self, other_best_guesses: List[InformationNugget]):
+ """
+ Update variable holding best guesses from other documents.
+
+ Parameters
+ ----------
+ other_best_guesses: List[InformationNugget]
+ List of other best guesses from other documents to which the internal variable should be updated.
+ """
+
+ self._other_best_guesses = other_best_guesses
+
+ def highlight_selected_nugget(self, selected_nugget: InformationNugget):
+ """
+ Highlights selected nugget in this visualizer as well in fullscreen version if present.
+ More details are provided in documentation of implementation in `EmbeddingVisualizer`.
+
+ Realized by calling implementation in `EmbeddingVisualizer` of this method and same method on fullscreen version
+ of this visualizer.
+
+ Parameters
+ ----------
+ selected_nugget: InformationNugget
+ Nugget whose representation in the grid should be highlighted.
+ """
+
+ # Highlight selected nugget in this visualizer
+ super().highlight_selected_nugget(selected_nugget)
+
+ # Highlight selected nugget in fullscreen version
+ if self._fullscreen_window is not None:
+ self._fullscreen_window.highlight_selected_nugget(selected_nugget)
+
+ def highlight_best_guess(self, best_guess: InformationNugget):
+ """
+ Highlights the best guess of the corresponding document in this visualizer as well in fullscreen version if
+ present.
+ More details are provided in documentation of implementation in `EmbeddingVisualizer`.
+
+ Realized by calling implementation in `EmbeddingVisualizer` of this method and same method on fullscreen
+ version of this visualizer.
+
+ Applicable only if this visualizer belongs to the document view as only in this case the visualizer covers one
+ document providing only one best guess.
+
+ Parameters
+ ----------
+ best_guess: InformationNugget
+ Nugget whose representation in the grid should be highlighted.
+ """
+
+ # Highlight selected nugget in this visualizer
+ super().highlight_best_guess(best_guess)
+
+ # Highlight selected nugget in fullscreen version
+ if self._fullscreen_window is not None:
+ self._fullscreen_window.highlight_best_guess(best_guess)
+
+ def reset(self):
+ """
+ Resets this widget by calling superclass implementation and resetting internal variables.
+ More details are provided in documentation of superclass implementation.
+ """
+
+ # Call superclass implementation
+ super().reset()
+
+ # Reset internal variables
+ self._fullscreen_window = None
+ self._other_best_guesses = None
+
+ self.show_other_best_guesses_button.setEnabled(True)
+ self.remove_other_best_guesses_button.setEnabled(False)
+
+ def hide(self):
+ """
+ Hide this widget and close fullscreen version if present.
+ """
+
+ super().hide()
+ if self._fullscreen_window is not None:
+ self._fullscreen_window.close()
+
+ @track_button_click("fullscreen embedding visualizer")
+ def _show_embedding_visualizer_window(self):
+ # Opens the fullscreen version of this visualizer and track that the corresponding has been clicked.
+
+ if self._fullscreen_window is None:
+ self._fullscreen_window = EmbeddingVisualizerWindow(colors_with_meanings=self._colors_with_meanings,
+ attribute=self._attribute,
+ nuggets=list(self._nugget_to_displayed_items.keys()),
+ currently_highlighted_nugget=self._currently_highlighted_nugget,
+ best_guess=self._best_guess)
+ self._fullscreen_window.show()
+
+ @track_button_click(button_name="show other best guesses from other documents")
+ def _handle_show_other_best_guesses_clicked(self):
+ # Adds the best guesses from other documents - contained in the internal variable `_other_best_guesses` - to
+ # this visualizer and the fullscreen version if opened
+ # Track that the corresponding button has been clicked.
+
+ # Log warning if no other best guesses are available
+ if self._other_best_guesses is None:
+ logger.warning("Can not display best guesses from other documents as these best guesses have not been "
+ "initialized yet.")
+ return
+
+ # Only the currently applicable button of the buttons to add and remove other best guesses should be enabled
+ self.show_other_best_guesses_button.setEnabled(False)
+ self.remove_other_best_guesses_button.setEnabled(True)
+
+ # Add other best guesses to this visualizer and fullscreen version if opened
+ self.display_other_best_guesses(self._other_best_guesses)
+ if self._fullscreen_window is not None:
+ self._fullscreen_window.display_other_best_guesses(self._other_best_guesses)
+
+ @track_button_click(button_name="stop showing other best guesses from other documents")
+ def _handle_remove_other_best_guesses_clicked(self):
+ # Removes the best guesses from other documents - contained in the internal variable `_other_best_guesses` -
+ # from this visualizer and the fullscreen version if opened.
+ # Track that the corresponding button has been clicked.
+
+ self.show_other_best_guesses_button.setEnabled(True)
+ self.remove_other_best_guesses_button.setEnabled(False)
+
+ self._remove_nuggets_from_widget(self._other_best_guesses)
+ if self._fullscreen_window is not None:
+ self._fullscreen_window._remove_nuggets_from_widget(self._other_best_guesses)
+
+
+dialog = InfoDialog()
+
+
+class BarChartVisualizerWidget(QWidget):
+ """
+ A QWidget-based class that provides a UI widget for visualizing cosine values in a bar chart.
+ It allows users to update the data, display a bar chart with certainty values, and interact with
+ the chart (e.g., displaying annotations on click).
+ """
+ def __init__(self, parent=None):
+ """
+ Initializes the BarChartVisualizerWidget, sets up the layout and button,
+ and prepares attributes to store data, the chart window, and interactive state.
+ """
+ super(BarChartVisualizerWidget, self).__init__(parent)
+ self.layout = QVBoxLayout(self)
+ self.layout.setContentsMargins(0, 0, 0, 0)
+ self.button = QPushButton("Show Bar Chart with cosine values")
+ self.layout.addWidget(self.button)
+ self.data = []
+ self.button.clicked.connect(self.show_bar_chart)
+ self.window: QMainWindow = None
+ self.current_annotation_index = None
+ self.bar = None
+
+ def update_data(self, nuggets):
+ """
+ Updates the widget's data based on the provided nuggets. Resets any previous state
+ and processes the nuggets to extract text and cosine values.
+
+ :param nuggets: List of information nuggets with cosine similarity values.
+ """
+ self.reset()
+
+ self.data = [(_create_sanitized_text(nugget),
+ np.round(nugget[CachedDistanceSignal], 3))
+ for nugget in nuggets]
+
+ @track_button_click("show bar chart")
+ def show_bar_chart(self):
+ """
+ Displays the bar chart using the current data. If no data is available, the method returns early. Represents a button
+ """
+ if not self.data:
+ return
+ self.plot_bar_chart()
+
+ def _unique_nuggets(self):
+ """
+ Ensures that only the most relevant (i.e., minimal cosine distance) nuggets are included in the data.
+ Filters out duplicates based on text, keeping only the lowest cosine distance for each unique nugget.
+ """
+ min_dict = {}
+ for item in self.data:
+ key, value = item
+ if key not in min_dict or value < min_dict[key]:
+ min_dict[key] = value
+ self.data = [(key, min_dict[key]) for key in min_dict]
+
+ def plot_bar_chart(self):
+ """
+ Generates and displays the bar chart with cosine-based certainty values.
+ Includes interactive functionality for annotations and customizable axes.
+ """
+ self._unique_nuggets()
+ if self.window is not None:
+ self.window.close()
+
+ fig = Figure()
+ ax = fig.add_subplot(111)
+ texts, distances = zip(*self.data)
+
+ rounded_certainties = np.round(np.ones(len(distances)) - distances, 3)
+ x_positions = [0]
+ for i, y_val in enumerate(rounded_certainties):
+ if i == 0:
+ continue
+ if rounded_certainties[i - 1] != y_val:
+ x_positions.append(x_positions[i - 1] + 2)
+ else:
+ x_positions.append(x_positions[i - 1] + 1)
+
+ self.bar = ax.bar(x_positions, rounded_certainties, alpha=0.75, picker=True, color=_get_colors(distances))
+ ax.set_xticks([])
+ ax.set_ylabel('Certainty', fontsize=15)
+ ax.set_xlabel('Information Nuggets', fontsize=15)
+ fig.subplots_adjust(left=0.115, right=0.920, top=0.945, bottom=0.065)
+ for idx, rect in enumerate(self.bar):
+ height = rect.get_height()
+ ax.text(
+ rect.get_x() + rect.get_width() / 2,
+ height / 2,
+ f'{texts[idx]}',
+ ha='center',
+ va='center',
+ rotation=90, # Rotate text by 90 degrees
+ fontsize=12,
+ color='white' # fontcolors[idx]# Optional: Adjust font size
+ )
+
+ self.bar_chart_canvas = FigureCanvas(fig)
+ self.bar_chart_canvas.setMinimumWidth(
+ max(0.9 * WINDOW_WIDTH, len(texts) * 50)) # Set a minimum width based on number of bars
+
+ scroll_area = QScrollArea()
+ scroll_area.setWidget(self.bar_chart_canvas)
+ scroll_area.setWidgetResizable(True)
+ scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOn)
+ scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
+ scroll_area.setFrameShape(QFrame.Shape.NoFrame)
+
+ self.window = QMainWindow()
+ self.window.closeEvent = self.closeWindowEvent
+ self.window.showEvent = self.showWindowEvent
+
+ self.window.setWindowTitle("Bar Chart")
+ self.window.setGeometry(100, 100, WINDOW_WIDTH, WINDOW_HEIGHT)
+ self.window.setCentralWidget(scroll_area)
+
+ self.bar_chart_toolbar = NavigationToolbar(self.bar_chart_canvas, self.window)
+ self.window.addToolBar(self.bar_chart_toolbar)
+
+ self.window.show()
+ self.bar_chart_canvas.draw()
+
+ self.annotation = ax.annotate(
+ "", xy=(0, 0), xytext=(20, 20),
+ textcoords="offset points", bbox=dict(boxstyle="round", fc="w"),
+ arrowprops=dict(arrowstyle="->")
+ )
+ self.annotation.set_visible(False)
+ self.bar_chart_canvas.mpl_connect('pick_event', self.on_pick)
+
+ self.texts = texts
+ self.distances = rounded_certainties
+
+ global dialog
+ dialog.load_markdown_file('wannadb_ui/resources/info_popups/barchart_tutorial.md') # Path to your .md file
+ # dialog.set_image_list([None, 'image1.png', 'image2.png', 'image3.png'])
+ dialog.exec()
+
+ def on_pick(self, event):
+ """
+ Handles click events on the bar chart. When a bar is clicked, displays an annotation
+ with detailed information about the clicked nugget.
+
+ :param event: The pick event triggered by clicking a bar.
+ """
+ if isinstance(event.artist, Rectangle):
+ patch = event.artist
+ index = self.bar.get_children().index(patch)
+ if self.current_annotation_index == index and self.annotation.get_visible():
+ # If the same bar is clicked again, hide the annotation
+ self.annotation.set_visible(False)
+ self.current_annotation_index = None
+ else:
+ # Show annotation for the clicked bar
+ text = f"Information Nugget: \n{self.texts[index]} \n\n Value: {self.distances[index]}"
+ self.annotation.set_text(text)
+ annotation_x = patch.get_x() + patch.get_width() / 2
+ annotation_y = patch.get_height() / 2
+ self.annotation.xy = (annotation_x, annotation_y)
+ self.annotation.set_visible(True)
+ self.current_annotation_index = index
+ self.bar_chart_canvas.draw_idle()
+
+ def reset(self):
+ """
+ Resets the state of the bar chart widget, clearing any previously stored data and bars.
+ """
+ self.data = []
+ self.bar = None
+
+ def showWindowEvent(self, event):
+ """
+ These and method below needed for tracking how much time user spent on the bar chart
+ """
+ super().showEvent(event)
+ Tracker().start_timer(str(self.__class__))
+
+ def closeWindowEvent(self, event):
+ event.accept()
+ Tracker().stop_timer(str(self.__class__))
+
+
+class ScatterPlotVisualizerWidget(QWidget):
+ def __init__(self, parent=None):
+ super(ScatterPlotVisualizerWidget, self).__init__(parent)
+ self.layout = QVBoxLayout(self)
+ self.layout.setContentsMargins(0, 0, 0, 0)
+ self.button = QPushButton("Show Scatter Plot with Cosine Distances")
+ self.layout.addWidget(self.button)
+ self.data = [] # Store data as a list of tuples
+ self.button.clicked.connect(self.show_scatter_plot)
+ self.scatter_plot_canvas = None
+ self.scatter_plot_toolbar = None
+ self.window = None
+ self.annotation = None
+ self.texts = None
+ self.distances = None
+ self.y = None
+ self.scatter = None
+ self.accessible_color_palette = False
+
+ def enable_accessible_color_palette(self):
+ self.accessible_color_palette = True
+
+ def disable_accessible_color_palette(self):
+ self.accessible_color_palette = False
+
+ def update_data(self, nuggets):
+ self.reset()
+
+ self.data = [(_create_sanitized_text(nugget),
+ np.round(nugget[CachedDistanceSignal], 3))
+ for nugget in nuggets]
+
+ def reset(self):
+ self.data = []
+ self.texts = None
+ self.distances = None
+ self.y = None
+ self.scatter = None
+ if self.window is not None:
+ self.window.close()
+ self.scatter_plot_canvas = None
+ self.scatter_plot_toolbar = None
+ self.window = None
+ self.annotation = None
+
+ @track_button_click("show scatter plot")
+ def show_scatter_plot(self):
+ if not self.data:
+ return
+
+ # Clear data to prevent duplication
+ self.data = list(set(self.data))
+
+ # Close existing scatter plot
+ if self.window is not None:
+ self.window.close()
+
+ fig = Figure()
+ ax = fig.add_subplot(111)
+ texts, distances = zip(*self.data)
+
+ # Round the distances to a fixed number of decimal places
+ rounded_distances = np.round(distances, 3)
+
+ # Ensure consistent x-values for the same rounded distance
+ distance_map = {}
+ for original, rounded in zip(distances, rounded_distances):
+ if rounded not in distance_map:
+ distance_map[rounded] = original
+
+ consistent_distances = [distance_map[rd] for rd in rounded_distances]
+
+ # Generate jittered y-values for points with the same x-value
+ unique_distances = {}
+ for i, distance in enumerate(consistent_distances):
+ if distance not in unique_distances:
+ unique_distances[distance] = []
+ unique_distances[distance].append(i)
+
+ y = np.zeros(len(distances))
+ for distance, indices in unique_distances.items():
+ jitter = np.linspace(-0.4, 0.4, len(indices))
+ for j, index in enumerate(indices):
+ y[index] = jitter[j]
+
+ # Generating a list of colors for each point
+ num_points = len(distances)
+ colormap = plt.cm.jet
+ norm = plt.Normalize(min(rounded_distances), max(rounded_distances))
+ colors = colormap(norm(rounded_distances))
+
+ # Plot the points
+ scatter = ax.scatter(rounded_distances, y, c=colors, alpha=0.75, picker=True) # Enable picking
+
+ ax.set_xlabel("Cosine Distance")
+ ax.set_xlim(min(rounded_distances) - 0.05,
+ max(rounded_distances) + 0.05) # Adjust x-axis limits for better visibility
+ ax.set_yticks([]) # Remove y-axis labels to avoid confusion
+ fig.subplots_adjust(left=0.020, right=0.980, top=0.940, bottom=0.075)
+ # fig.tight_layout()
+
+ # Create canvas
+ self.scatter_plot_canvas = FigureCanvas(fig)
+
+ # Create a new window for the plot
+ self.window = QMainWindow()
+ self.window.closeEvent = self.closeWindowEvent
+ self.window.showEvent = self.showWindowEvent
+ self.window.setWindowTitle("Scatter Plot")
+
+ self.window.setGeometry(100, 100, WINDOW_WIDTH, WINDOW_HEIGHT)
+
+ # Set the central widget of the window to the canvas
+ self.window.setCentralWidget(self.scatter_plot_canvas)
+
+ # Add NavigationToolbar to the window
+ self.scatter_plot_toolbar = NavigationToolbar(self.scatter_plot_canvas, self.window)
+ self.window.addToolBar(self.scatter_plot_toolbar)
+
+ # Show the window
+ self.window.show()
+ self.scatter_plot_canvas.draw()
+
+ # Create an annotation box
+ self.annotation = ax.annotate(
+ "", xy=(0, 0), xytext=(20, 20),
+ textcoords="offset points", bbox=dict(boxstyle="round", fc="w"),
+ arrowprops=dict(arrowstyle="->")
+ )
+ self.annotation.set_visible(False)
+
+ # Connect the pick event
+ self.scatter_plot_canvas.mpl_connect("pick_event", self.on_pick)
+
+ # Store the data for use in the event handler
+ self.texts = texts
+ self.distances = rounded_distances
+ self.y = y
+ self.scatter = scatter
+
+ def on_pick(self, event):
+ if event.artist != self.scatter:
+ return
+ # Get index of the picked point
+ ind = event.ind[0]
+
+ # Update annotation text and position
+ self.annotation.xy = (self.distances[ind], self.y[ind])
+ text = f"Text: {self.texts[ind]}\nValue: {self.distances[ind]:.3f}"
+ self.annotation.set_text(text)
+ self.annotation.set_visible(True)
+ self.scatter_plot_canvas.draw_idle()
+
+ def reset(self):
+ self.data = []
+ self.bar = None
+
+ def showWindowEvent(self, event):
+ super().showEvent(event)
+ Tracker().start_timer(str(self.__class__))
+
+ def closeWindowEvent(self, event):
+ event.accept()
+ Tracker().stop_timer(str(self.__class__))
+
diff --git a/wannadb_ui/wannadb_api.py b/wannadb_ui/wannadb_api.py
index 7292422f..b94bcbcd 100644
--- a/wannadb_ui/wannadb_api.py
+++ b/wannadb_ui/wannadb_api.py
@@ -14,13 +14,14 @@
from wannadb.matching.custom_match_extraction import FaissSentenceSimilarityExtractor
from wannadb.matching.distance import SignalsMeanDistance
from wannadb.matching.matching import RankingBasedMatcher
+from wannadb.preprocessing.dimension_reduction import PCAReducer
from wannadb.preprocessing.embedding import BERTContextSentenceEmbedder, RelativePositionEmbedder, \
SBERTTextEmbedder, SBERTLabelEmbedder, SBERTDocumentSentenceEmbedder
from wannadb.preprocessing.extraction import StanzaNERExtractor, SpacyNERExtractor
from wannadb.preprocessing.label_paraphrasing import OntoNotesLabelParaphraser, \
SplitAttributeNameLabelParaphraser
from wannadb.preprocessing.normalization import CopyNormalizer
-from wannadb.preprocessing.other_processing import ContextSentenceCacher
+from wannadb.preprocessing.other_processing import ContextSentenceCacher, DuplicatedNuggetsCleaner
from wannadb.statistics import Statistics
from wannadb.status import StatusCallback
from wannadb_parsql.cache_db import SQLiteCacheDB
@@ -125,7 +126,10 @@ def create_document_base(self, path, attribute_names, statistics):
SBERTTextEmbedder("SBERTBertLargeNliMeanTokensResource"),
BERTContextSentenceEmbedder("BertLargeCasedResource"),
SBERTDocumentSentenceEmbedder("SBERTBertLargeNliMeanTokensResource"),
- RelativePositionEmbedder()
+ RelativePositionEmbedder(),
+ DuplicatedNuggetsCleaner(),
+ PCAReducer(),
+ #TSNEReducer()
])
# run preprocessing phase
@@ -351,6 +355,8 @@ def interactive_table_population(self, document_base, statistics):
ContextSentenceCacher(),
SBERTLabelEmbedder("SBERTBertLargeNliMeanTokensResource"),
SBERTDocumentSentenceEmbedder("SBERTBertLargeNliMeanTokensResource"),
+ PCAReducer(),
+ #TSNEReducer(),
RankingBasedMatcher(
distance=SignalsMeanDistance(
signal_identifiers=[
@@ -375,7 +381,9 @@ def interactive_table_population(self, document_base, statistics):
SBERTLabelEmbedder("SBERTBertLargeNliMeanTokensResource"),
SBERTTextEmbedder("SBERTBertLargeNliMeanTokensResource"),
BERTContextSentenceEmbedder("BertLargeCasedResource"),
- RelativePositionEmbedder()
+ RelativePositionEmbedder(),
+ PCAReducer(),
+ #TSNEReducer()
]
),
find_additional_nuggets=FaissSentenceSimilarityExtractor(num_similar_sentences=20, num_phrases_per_sentence=3),