diff --git a/rotary_embedding_torch/rotary_embedding_torch.py b/rotary_embedding_torch/rotary_embedding_torch.py index 14cbb60..a8b0c77 100644 --- a/rotary_embedding_torch/rotary_embedding_torch.py +++ b/rotary_embedding_torch/rotary_embedding_torch.py @@ -127,7 +127,7 @@ def get_scale(self, t, cache_key = None): scale = torch.cat((scale, scale), dim = -1) if exists(cache_key): - self.cache[cache_key] = freqs + self.cache[cache_key] = self.freqs return scale