diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py
index 4caadea..1dfce87 100644
--- a/torchTextClassifiers/model/components/text_embedder.py
+++ b/torchTextClassifiers/model/components/text_embedder.py
@@ -1,6 +1,6 @@
 import math
 from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, Optional
 
 import torch
 import torch.nn as nn
@@ -129,7 +129,7 @@ def forward(
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         return_label_attention_matrix: bool = False,
-    ) -> dict[str, Optional[torch.Tensor]]:
+    ) -> Dict[str, Optional[torch.Tensor]]:
         """Converts input token IDs to their corresponding embeddings.
 
         Args:
@@ -200,15 +200,18 @@ def _get_sentence_embedding(
         token_embeddings: torch.Tensor,
         attention_mask: torch.Tensor,
         return_label_attention_matrix: bool = False,
-    ) -> torch.Tensor:
+    ) -> Dict[str, Optional[torch.Tensor]]:
         """
         Compute sentence embedding from embedded tokens - "remove" second dimension.
 
         Args (output from dataset collate_fn):
             token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text
            attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
+            return_label_attention_matrix (bool): Whether to compute and return the label attention matrix
 
         Returns:
-            torch.Tensor: Sentence embeddings, shape (batch_size, embedding_dim)
+            Dict[str, Optional[torch.Tensor]]: A dictionary containing:
+                - 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled
+                - 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None
         """
         # average over non-pad token embeddings
@@ -219,14 +222,20 @@ def _get_sentence_embedding(
         if self.attention_config is not None:
             if self.attention_config.aggregation_method is not None:  # default is "mean"
                 if self.attention_config.aggregation_method == "first":
-                    return token_embeddings[:, 0, :]
+                    return {
+                        "sentence_embedding": token_embeddings[:, 0, :],
+                        "label_attention_matrix": None,
+                    }
                 elif self.attention_config.aggregation_method == "last":
                     lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
-                    return token_embeddings[
-                        torch.arange(token_embeddings.size(0)),
-                        lengths - 1,
-                        :,
-                    ]
+                    return {
+                        "sentence_embedding": token_embeddings[
+                            torch.arange(token_embeddings.size(0)),
+                            lengths - 1,
+                            :,
+                        ],
+                        "label_attention_matrix": None,
+                    }
                 else:
                     if self.attention_config.aggregation_method != "mean":
                         raise ValueError(
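
Note (not part of the patch): because _get_sentence_embedding and forward now return a dictionary instead of a bare tensor, call sites must unpack the result. The sketch below is a minimal, hypothetical illustration of consuming the new contract; the dummy shapes and the nn.Linear head are assumptions for illustration, not the library's actual classification head.

# Minimal sketch (assumption, not part of this patch): consuming the dict
# returned by _get_sentence_embedding / forward.
import torch
import torch.nn as nn

batch_size, embedding_dim, n_classes = 4, 16, 3

# Stand-in for the embedder output; 'label_attention_matrix' stays None
# unless label attention is enabled and requested.
outputs = {
    "sentence_embedding": torch.randn(batch_size, embedding_dim),
    "label_attention_matrix": None,
}

head = nn.Linear(embedding_dim, n_classes)  # hypothetical classification head
logits = head(outputs["sentence_embedding"])  # shape (batch_size, n_classes)

if outputs["label_attention_matrix"] is not None:
    # e.g. inspect per-label attention weights for interpretability
    print(outputs["label_attention_matrix"].shape)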