diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 72c63fb86d43..6d449493343c 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -433,8 +433,8 @@ def forward(self, hidden_states): pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long) pos_seq = pos_seq[:, None] - pos_seq[None, :] - pos_seq[pos_seq < -self.max_length] = -self.max_length - pos_seq[pos_seq >= self.max_length] = self.max_length - 1 + pos_seq = torch.where(pos_seq < -self.max_length, -self.max_length, pos_seq) + pos_seq = torch.where(pos_seq >= self.max_length, self.max_length - 1, pos_seq) pos_seq = pos_seq + self.max_length return self.pe_k(pos_seq)