From 80db90fe20984341aeb05c2aef035f18dc5dbeb7 Mon Sep 17 00:00:00 2001 From: Joshua Chung Date: Sat, 7 Jun 2025 23:52:32 +0000 Subject: [PATCH 1/2] transformer.py: replace deprecated `.byte()` with `.bool()` pytorch deprecated byte mask input Closes #4 --- src/model/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model/transformer.py b/src/model/transformer.py index 3f9cb92..c265ba1 100644 --- a/src/model/transformer.py +++ b/src/model/transformer.py @@ -711,7 +711,7 @@ def generate(self, src_enc, src_len, max_len=200, sample_temperature=None): # add to unfinished sentences if cur_len == max_len: - generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index) + generated[-1].masked_fill_(unfinished_sents.bool(), self.eos_index) # sanity check assert (generated == self.eos_index).sum() == 2 * bs From 48a499228cadf6e7e6634def69a2349c83eb8962 Mon Sep 17 00:00:00 2001 From: Joshua Chung Date: Sun, 8 Jun 2025 00:07:00 +0000 Subject: [PATCH 2/2] lstm.py: replace deprecated `.byte()` with `.bool()` --- src/model/lstm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model/lstm.py b/src/model/lstm.py index 0071751..aefb642 100644 --- a/src/model/lstm.py +++ b/src/model/lstm.py @@ -229,7 +229,7 @@ def generate(self, src_enc, src_len, max_len=200, sample_temperature=None): # add to unfinished sentences if cur_len == max_len: - generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index) + generated[-1].masked_fill_(unfinished_sents.bool(), self.eos_index) # sanity check assert (generated == self.eos_index).sum() == 2 * bs