-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
30 lines (25 loc) · 906 Bytes
/
model.py
File metadata and controls
30 lines (25 loc) · 906 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn as nn
class VAE(nn.Module):
    """Small fully connected variational autoencoder.

    Architecture (every layer is ``nn.Linear``):

    * encoder: ``n_features -> hidden_dim`` followed by ReLU, then two
      parallel heads ``hidden_dim -> latent_dim`` producing ``mu`` and
      ``logvar``;
    * decoder: ``latent_dim -> hidden_dim`` followed by ReLU, then
      ``hidden_dim -> n_features``.

    The decoder output has no final activation, so reconstructions are
    unbounded real values.

    Args:
        n_features: Size of the last dimension of the input tensor.
        latent_dim: Size of the latent code ``z``.
        hidden_dim: Width of the single hidden layer on each side. The
            default of 64 matches the previously hard-coded width, so layer
            shapes and ``state_dict`` keys are unchanged when it is omitted.
    """

    def __init__(self, n_features=10, latent_dim=4, hidden_dim=64):
        super().__init__()
        # Encoder trunk plus the two posterior heads.
        self.fc1 = nn.Linear(n_features, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head (mu)
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        # Decoder.
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, n_features)
        self.relu = nn.ReLU()

    def encode(self, x):
        """Map ``x`` to the approximate-posterior parameters.

        Args:
            x: Tensor whose last dimension is ``n_features``.

        Returns:
            Tuple ``(mu, logvar)``, each with last dimension ``latent_dim``.
        """
        h1 = self.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``logvar`` is treated as the log-variance, so the standard deviation
        is ``exp(0.5 * logvar)``. Sampling the noise separately keeps ``z``
        differentiable with respect to ``mu`` and ``logvar``.

        Note: a fresh sample is drawn regardless of ``self.training``.
        Callers that want a deterministic code at evaluation time should
        use ``mu`` directly.

        Args:
            mu: Posterior mean.
            logvar: Posterior log-variance, same shape as ``mu``.

        Returns:
            Tensor with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent code back to input space.

        Args:
            z: Tensor whose last dimension is ``latent_dim``.

        Returns:
            Reconstruction with last dimension ``n_features``. No output
            activation is applied.
        """
        h3 = self.relu(self.fc3(z))
        return self.fc4(h3)

    def forward(self, x):
        """Encode, sample a latent code, and decode.

        Args:
            x: Tensor whose last dimension is ``n_features``.

        Returns:
            Tuple ``(recon, z, mu, logvar)``: the reconstruction, the sampled
            latent code, and the posterior mean and log-variance.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, z, mu, logvar