From 8e7239dfae6565b2f1b06e1867b707c8dbcfb64e Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 27 Jan 2026 09:45:49 +0000
Subject: [PATCH 1/3] Initial plan

From 48cc6ac21edabc3b4c46d58267665044f1cf5005 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 27 Jan 2026 09:50:02 +0000
Subject: [PATCH 2/3] Apply attention mask in LabelAttentionClassifier cross-attention

Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
---
 .../model/components/text_embedder.py        | 36 +++++++++++++++----
 1 file changed, 29 insertions(+), 7 deletions(-)

diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py
index ddad80a..da21c30 100644
--- a/torchTextClassifiers/model/components/text_embedder.py
+++ b/torchTextClassifiers/model/components/text_embedder.py
@@ -235,7 +235,9 @@ def _get_sentence_embedding(
         if self.enable_label_attention:
             label_attention_result = self.label_attention_module(
-                token_embeddings, compute_attention_matrix=return_label_attention_matrix
+                token_embeddings,
+                attention_mask=attention_mask,
+                compute_attention_matrix=return_label_attention_matrix,
             )
             sentence_embedding = label_attention_result[
                 "sentence_embedding"
             ]
@@ -320,10 +322,11 @@ def __init__(self, config: TextEmbedderConfig):
         self.c_v = nn.Linear(self.embedding_dim, self.n_kv_head * self.head_dim, bias=False)
         self.c_proj = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)

-    def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = False):
+    def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = None, compute_attention_matrix: Optional[bool] = False):
         """
         Args:
             token_embeddings (torch.Tensor), shape (batch, seq_len, d_model): Embedded tokens from the text input.
+            attention_mask (torch.Tensor, optional), shape (batch, seq_len): Attention mask indicating non-pad tokens (1 for real tokens, 0 for padding).
             compute_attention_matrix (bool): Whether to compute and return the attention matrix.
         Returns:
             dict: {
@@ -358,7 +361,18 @@ def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = F
             v.transpose(1, 2),
         )  # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
-        y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=self.enable_gqa)
+        # Prepare attention mask for scaled_dot_product_attention
+        # attention_mask: (B, T) with 1 for real tokens, 0 for padding
+        # F.scaled_dot_product_attention expects a boolean attn_mask broadcastable to (B, H, Q, K)
+        # where True means "attend to" and False means "mask out" (ignore)
+        attn_mask = None
+        if attention_mask is not None:
+            # Convert: 1 (real) -> True (attend to), 0 (padding) -> False (mask out)
+            attn_mask = attention_mask.bool()  # (B, T)
+            # Expand to (B, 1, 1, T) for broadcasting across heads and queries
+            attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)
+
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False, enable_gqa=self.enable_gqa)

         # Re-assemble the heads side by side and project back to residual stream
         y = y.transpose(1, 2).contiguous().view(B, self.num_classes, -1)  # (bs, n_labels, d_model)
         y = self.c_proj(y)

         attention_matrix = None
         if compute_attention_matrix:
@@ -366,9 +380,17 @@ def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = F
-            # size (B, n_head, n_labels, seq_len) - we let the user handle aggregation over heads if desired
-            attention_matrix = torch.softmax(
-                torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5), dim=-1
-            )
+            # Compute attention scores
+            # size (B, n_head, n_labels, seq_len)
+            attention_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
+            
+            # Apply mask to attention scores before softmax
+            if attention_mask is not None:
+                # attn_mask has shape (B, 1, 1, T) with True at real tokens and broadcasts
+                # over the scores of shape (B, n_head, n_labels, T)
+                # Set padded positions to -inf so they become 0 after softmax
+                attention_scores = attention_scores.masked_fill(~attn_mask, float('-inf'))
+            
+            attention_matrix = torch.softmax(attention_scores, dim=-1)

         return {"sentence_embedding": y, "attention_matrix": attention_matrix}

From 95b39690250d41f463b4c3ace13068cf7eeea2fc Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 27 Jan 2026 09:51:20 +0000
Subject: [PATCH 3/3] Fix trailing whitespace in attention matrix computation

Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
---
 torchTextClassifiers/model/components/text_embedder.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py
index da21c30..627bbf0 100644
--- a/torchTextClassifiers/model/components/text_embedder.py
+++ b/torchTextClassifiers/model/components/text_embedder.py
@@ -383,14 +383,14 @@ def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = Non
             # Compute attention scores
             # size (B, n_head, n_labels, seq_len)
             attention_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
-            
+
             # Apply mask to attention scores before softmax
             if attention_mask is not None:
                 # attn_mask has shape (B, 1, 1, T) with True at real tokens and broadcasts
                 # over the scores of shape (B, n_head, n_labels, T)
                 # Set padded positions to -inf so they become 0 after softmax
                 attention_scores = attention_scores.masked_fill(~attn_mask, float('-inf'))
-            
+
             attention_matrix = torch.softmax(attention_scores, dim=-1)

         return {"sentence_embedding": y, "attention_matrix": attention_matrix}
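
Note (not part of the patches above): a minimal, self-contained sketch of the masking convention the cross-attention relies on. For torch.nn.functional.scaled_dot_product_attention, a boolean attn_mask means True = "attend to" and False = "ignore", while Tensor.masked_fill takes the inverted mask (True = "mask out"). The snippet checks that the fused call and the manual score/softmax path agree on a toy batch; all tensor sizes and variable names here are illustrative and are not taken from the repository.

# Sketch only: verifies the boolean-mask convention for scaled_dot_product_attention.
import torch
import torch.nn.functional as F

B, H, L, T, D = 2, 4, 5, 7, 16  # batch, heads, label queries, key/value seq_len, head_dim
q = torch.randn(B, H, L, D)     # label queries, in the spirit of LabelAttentionClassifier
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)

attention_mask = torch.ones(B, T, dtype=torch.long)
attention_mask[:, -2:] = 0  # last two key positions are padding

# SDPA convention: True = attend, False = ignore; broadcastable to (B, H, L, T)
attn_mask = attention_mask.bool().unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)
y_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)

# Manual reference: mask the scores with the *inverted* mask before softmax
scores = torch.matmul(q, k.transpose(-2, -1)) / (D ** 0.5)   # (B, H, L, T)
scores = scores.masked_fill(~attn_mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)
y_manual = torch.matmul(weights, v)

print(torch.allclose(y_sdpa, y_manual, atol=1e-5))  # expected: True
assert weights[..., -2:].abs().max() == 0  # no attention paid to padded keys

An equivalent alternative is a float attn_mask of zeros and -inf that is added to the scores, which avoids keeping two boolean conventions in sync between the fused call and the explicit attention-matrix computation.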