diff --git a/src/model/lstm.py b/src/model/lstm.py index 0071751..aefb642 100644 --- a/src/model/lstm.py +++ b/src/model/lstm.py @@ -229,7 +229,7 @@ def generate(self, src_enc, src_len, max_len=200, sample_temperature=None): # add to unfinished sentences if cur_len == max_len: - generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index) + generated[-1].masked_fill_(unfinished_sents.bool(), self.eos_index) # sanity check assert (generated == self.eos_index).sum() == 2 * bs diff --git a/src/model/transformer.py b/src/model/transformer.py index 3f9cb92..c265ba1 100644 --- a/src/model/transformer.py +++ b/src/model/transformer.py @@ -711,7 +711,7 @@ def generate(self, src_enc, src_len, max_len=200, sample_temperature=None): # add to unfinished sentences if cur_len == max_len: - generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index) + generated[-1].masked_fill_(unfinished_sents.bool(), self.eos_index) # sanity check assert (generated == self.eos_index).sum() == 2 * bs