From bd3734f036bf1c409a61236f91ea4626fadbbdef Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:34:01 +0000 Subject: [PATCH 1/2] Initial plan From af56bbec7b7046bb4aa70c9c4813701c41622e42 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:40:04 +0000 Subject: [PATCH 2/2] fix: ensure label_indices uses correct device and dtype Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com> --- .gitignore | 1 + torchTextClassifiers/model/components/text_embedder.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index e65366e..586ae23 100644 --- a/.gitignore +++ b/.gitignore @@ -183,3 +183,4 @@ example_files/ _site/ .quarto/ **/*.quarto_ipynb +my_ttc/ diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py index 7bd030c..b5ce92e 100644 --- a/torchTextClassifiers/model/components/text_embedder.py +++ b/torchTextClassifiers/model/components/text_embedder.py @@ -324,7 +324,9 @@ def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = F compute_attention_matrix = bool(compute_attention_matrix) # 1. Create label indices [0, 1, ..., C-1] for the whole batch - label_indices = torch.arange(self.num_classes).expand(B, -1) + label_indices = torch.arange( + self.num_classes, dtype=torch.long, device=token_embeddings.device + ).expand(B, -1) all_label_embeddings = self.label_embeds( label_indices