diff --git a/vqvae.py b/vqvae.py index 099fa1915e..87f3286cc2 100755 --- a/vqvae.py +++ b/vqvae.py @@ -35,7 +35,7 @@ def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): embed = torch.randn(dim, n_embed) self.register_buffer("embed", embed) - self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("cluster_size", torch.ones(n_embed)) self.register_buffer("embed_avg", embed.clone()) def forward(self, input):