diff --git a/torchTextClassifiers/model/model.py b/torchTextClassifiers/model/model.py
index 5e4cc66..3630e62 100644
--- a/torchTextClassifiers/model/model.py
+++ b/torchTextClassifiers/model/model.py
@@ -6,7 +6,7 @@
 """

 import logging
-from typing import Annotated, Optional
+from typing import Annotated, Optional, Union

 import torch
 from torch import nn
@@ -120,7 +120,7 @@ def forward(
         categorical_vars: Annotated[torch.Tensor, "batch num_cats"],
         return_label_attention_matrix: bool = False,
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
         """
         Memory-efficient forward pass implementation.

@@ -128,15 +128,24 @@ def forward(
             input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + padded text
             attention_mask (torch.Tensor[int]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
             categorical_vars (torch.Tensor[Long]): Additional categorical features, (batch_size, num_categorical_features)
+            return_label_attention_matrix (bool): If True, returns a dict with logits and label_attention_matrix

         Returns:
-            torch.Tensor: Model output scores for each class - shape (batch_size, num_classes)
-                Raw, not softmaxed.
+            Union[torch.Tensor, dict[str, torch.Tensor]]:
+                - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes)
+                  containing raw logits (not softmaxed)
+                - If return_label_attention_matrix is True: dict with keys:
+                    - "logits": torch.Tensor of shape (batch_size, num_classes)
+                    - "label_attention_matrix": torch.Tensor of shape (batch_size, num_classes, seq_len)
         """
         encoded_text = input_ids  # clearer name
         label_attention_matrix = None
         if self.text_embedder is None:
             x_text = encoded_text.float()
+            if return_label_attention_matrix:
+                raise ValueError(
+                    "return_label_attention_matrix=True requires a text_embedder with label attention enabled"
+                )
         else:
             text_embed_output = self.text_embedder(
                 input_ids=encoded_text,
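
A minimal usage sketch of the new return contract, not part of the patch: it assumes `model` is an already-constructed instance of the classifier defined in model.py, built with a text_embedder that has label attention enabled; the tensor shapes, vocabulary size, and pad-id convention below are illustrative only.

    import torch

    batch_size, seq_len, num_cats = 4, 32, 2
    input_ids = torch.randint(1, 1000, (batch_size, seq_len))       # tokenized + padded text (illustrative values)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)  # 1 for real tokens, 0 for padding
    categorical_vars = torch.randint(0, 5, (batch_size, num_cats))  # additional categorical features

    # Default behaviour: a plain tensor of raw logits, shape (batch_size, num_classes)
    logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        categorical_vars=categorical_vars,
    )

    # With the new flag, forward() returns a dict instead of a tensor
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        categorical_vars=categorical_vars,
        return_label_attention_matrix=True,
    )
    logits = out["logits"]                 # (batch_size, num_classes)
    attn = out["label_attention_matrix"]   # (batch_size, num_classes, seq_len)

Note that, per the patch, requesting the attention matrix from a model constructed without a text_embedder raises a ValueError rather than silently returning a plain tensor.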