Skip to content

Commit 16d4f90

Browse files
committed
Add pytorch model in models
1 parent 655376b commit 16d4f90

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed

mplc/models.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
from sklearn.metrics import log_loss
66
from tensorflow.keras.backend import dot
77
from tensorflow.keras.layers import Dense
8+
import torch, torchvision
9+
import torch.nn as nn
10+
import torch.optim as optim
11+
import torch.utils.data as data
12+
import torchvision.transforms as transforms
813

914

1015
class LogisticRegression(skLR):
@@ -88,6 +93,152 @@ def load_model(path):
8893
path.replace('.h5', '.joblib')
8994
return load(path)
9095

96+
class cifar100_dataset(torch.utils.data.Dataset):
97+
98+
def __init__(self, x, y, transform=[]):
99+
self.x = x
100+
self.y = y
101+
self.transform = transform
102+
103+
def __len__(self):
104+
return len(self.x)
105+
106+
def __getitem__(self, index):
107+
108+
x = self.x[index]
109+
y = torch.tensor(int(self.y[index][0]))
110+
111+
if self.transform:
112+
x = self.transform(x)
113+
114+
return x, y
115+
116+
class ModelPytorch(nn.Module):
117+
def __init__(self):
118+
super(ModelPytorch, self).__init__()
119+
model = torchvision.models.vgg16()
120+
self.features = nn.Sequential(model.features)
121+
self.avgpool = nn.AdaptiveAvgPool2d(output_size=(7, 7))
122+
self.classifier = nn.Sequential(
123+
nn.Linear(25088, 4096),
124+
nn.ReLU(inplace=True),
125+
nn.Dropout(p=0.5, inplace=False),
126+
nn.Linear(4096, 4096),
127+
nn.ReLU(inplace=True),
128+
nn.Dropout(p=0.5, inplace=False),
129+
nn.Linear(4096, 1000)
130+
)
131+
self.optimizer = optim.Adam(model.parameters(), lr=1e-3)
132+
133+
134+
def forward(self, x):
135+
x = self.features(x)
136+
x = self.avgpool(x)
137+
x = x.view(x.size(0), -1)
138+
return self.classifier(x)
139+
140+
141+
def fit(self, x_train, y_train, batch_size, validation_data, epochs=1, verbose=False, callbacks=None):
142+
criterion = nn.CrossEntropyLoss()
143+
transform = transforms.Compose([transforms.ToTensor()])
144+
145+
train_data = cifar100_dataset(x_train, y_train, transform)
146+
train_loader = data.DataLoader(train_data, batch_size=int(batch_size), shuffle=True)
147+
148+
history = super(ModelPytorch, self).train()
149+
150+
for batch_idx, (image, label) in enumerate(train_loader):
151+
images, labels = torch.autograd.Variable(image), torch.autograd.Variable(label)
152+
153+
outputs = self.forward(images)
154+
loss = criterion(outputs, labels)
155+
156+
self.optimizer.zero_grad()
157+
loss.backward()
158+
self.optimizer.step()
159+
160+
[loss, acc] = self.evaluate(x_train, y_train)
161+
[val_loss, val_acc] = self.evaluate(*validation_data)
162+
# Mimic Keras' history
163+
history.history = {
164+
'loss': [loss],
165+
'accuracy': [acc],
166+
'val_loss': [val_loss],
167+
'val_accuracy': [val_acc]
168+
}
169+
170+
return history
171+
172+
def evaluate(self, x_eval, y_eval, **kwargs):
173+
criterion = nn.CrossEntropyLoss()
174+
transform = transforms.Compose([transforms.ToTensor()])
175+
176+
test_data = cifar100_dataset(x_eval, y_eval, transform)
177+
test_loader = data.DataLoader(test_data, shuffle=True)
178+
179+
self.eval()
180+
181+
with torch.no_grad():
182+
183+
y_true_np = []
184+
y_pred_np = []
185+
count=0
186+
for i, (images, labels) in enumerate(test_loader):
187+
count+= 1
188+
N = images.size(0)
189+
190+
images = torch.autograd.Variable(images)
191+
labels = torch.autograd.Variable(labels)
192+
193+
outputs = self(images)
194+
predictions = outputs.max(1, keepdim=True)[1]
195+
196+
val_loss =+ criterion(outputs, labels).item()
197+
val_acc =+ (predictions.eq(labels.view_as(predictions)).sum().item() / N)
198+
199+
model_evaluation = [val_loss/count, val_acc/count]
200+
201+
return model_evaluation
202+
203+
204+
def save_weights(self, path):
205+
if '.h5' in path:
206+
logger.debug('Automatically switch file format from .h5 to .pth')
207+
path.replace('.h5', '.pth')
208+
torch.save(self.state_dict(), path)
209+
210+
211+
def load_weights(self, path):
212+
if '.h5' in path:
213+
logger.debug('Automatically switch file format from .h5 to .pth')
214+
path.replace('.h5', '.pth')
215+
weights = torch.load(path)
216+
self.set_weights(weights)
217+
218+
219+
def get_weights(self):
220+
return self.state_dict()
221+
222+
223+
def set_weights(self, weights):
224+
self.load_state_dict(weights)
225+
226+
227+
def save_model(self, path):
228+
if '.h5' in path:
229+
logger.debug('Automatically switch file format from .h5 to .pth')
230+
path.replace('.h5', '.pth')
231+
torch.save(self, path)
232+
233+
234+
@staticmethod
235+
def load_model(path):
236+
if '.h5' in path:
237+
logger.debug('Automatically switch file format from .h5 to .pth')
238+
path.replace('.h5', '.pth')
239+
model = torch.load(path)
240+
return model.eval()
241+
91242

92243
class NoiseAdaptationChannel(Dense):
93244
"""

0 commit comments

Comments
 (0)