Skip to content

Commit 3834d34

Browse files
committed
Update learning rate scheduler
To stay consistent with PyTorch, one step of the learning rate scheduler corresponds to one epoch on the training set.
1 parent 934816c commit 3834d34

File tree

17 files changed

+21
-74
lines changed

17 files changed

+21
-74
lines changed

pymic/net_run/agent_seg.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,6 @@ def training(self):
162162
loss = self.get_loss_value(data, outputs, labels_prob)
163163
loss.backward()
164164
self.optimizer.step()
165-
if(self.scheduler is not None and \
166-
not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
167-
self.scheduler.step()
168-
169165
train_loss = train_loss + loss.item()
170166
# get dice evaluation for each class
171167
if(isinstance(outputs, tuple) or isinstance(outputs, list)):
@@ -219,10 +215,6 @@ def validation(self):
219215
valid_avg_loss = np.asarray(valid_loss_list).mean()
220216
valid_cls_dice = np.asarray(valid_dice_list).mean(axis = 0)
221217
valid_avg_dice = valid_cls_dice.mean()
222-
223-
if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
224-
self.scheduler.step(valid_avg_dice)
225-
226218
valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,\
227219
'class_dice': valid_cls_dice}
228220
return valid_scalers
@@ -300,9 +292,13 @@ def train_valid(self):
300292
t0 = time.time()
301293
train_scalars = self.training()
302294
t1 = time.time()
303-
304295
valid_scalars = self.validation()
305296
t2 = time.time()
297+
if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
298+
self.scheduler.step(valid_scalars['avg_dice'])
299+
else:
300+
self.scheduler.step()
301+
306302
self.glob_it = it + iter_valid
307303
logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it))
308304
logging.info('learning rate {0:}'.format(lr_value))

pymic/net_run/get_optimizer.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,28 @@ def get_optimizer(name, net_params, optim_params):
3838

3939
def get_lr_scheduler(optimizer, sched_params):
4040
name = sched_params["lr_scheduler"]
41+
val_it = sched_params["iter_valid"]
4142
if(name is None):
4243
return None
43-
lr_gamma = sched_params["lr_gamma"]
4444
if(keyword_match(name, "ReduceLROnPlateau")):
4545
patience_it = sched_params["ReduceLROnPlateau_patience".lower()]
46-
val_it = sched_params["iter_valid"]
47-
patience = patience_it / val_it
46+
patience = patience_it / val_it
47+
lr_gamma = sched_params["lr_gamma"]
4848
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
4949
mode = "max", factor=lr_gamma, patience = patience)
5050
elif(keyword_match(name, "MultiStepLR")):
5151
lr_milestones = sched_params["lr_milestones"]
52-
last_iter = sched_params["last_iter"]
52+
lr_milestones = [int(item / val_it) for item in lr_milestones]
53+
epoch_last = sched_params["last_iter"] / val_it
54+
lr_gamma = sched_params["lr_gamma"]
5355
scheduler = lr_scheduler.MultiStepLR(optimizer,
54-
lr_milestones, lr_gamma, last_iter)
56+
lr_milestones, lr_gamma, epoch_last)
57+
elif(keyword_match(name, "CosineAnnealingLR")):
58+
epoch_max = sched_params["iter_max"] / val_it
59+
epoch_last = sched_params["last_iter"] / val_it
60+
lr_min = sched_params.get("lr_min", 0)
61+
scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
62+
epoch_max, lr_min, epoch_last)
5563
else:
5664
raise ValueError("unsupported lr scheduler {0:}".format(name))
5765
return scheduler

pymic/net_run_nll/nll_co_teaching.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,6 @@ def training(self):
128128

129129
loss.backward()
130130
self.optimizer.step()
131-
if(self.scheduler is not None and \
132-
not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
133-
self.scheduler.step()
134131

135132
train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item()
136133
train_loss_no_select2 = train_loss_no_select2 + loss2.mean().item()

pymic/net_run_nll/nll_dast.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,6 @@ def training(self):
239239

240240
loss.backward()
241241
self.optimizer.step()
242-
if(self.scheduler is not None and \
243-
not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
244-
self.scheduler.step()
245242

246243
train_loss = train_loss + loss.item()
247244
train_loss_sup = train_loss_sup + loss_sup.item()

pymic/net_run_nll/nll_trinet.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import torch
99
import torch.nn as nn
1010
import torch.optim as optim
11-
from torch.optim import lr_scheduler
1211
from pymic.loss.seg.util import get_soft_label
1312
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
1413
from pymic.loss.seg.util import get_classwise_dice
@@ -125,9 +124,6 @@ def training(self):
125124

126125
loss.backward()
127126
self.optimizer.step()
128-
if(self.scheduler is not None and \
129-
not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
130-
self.scheduler.step()
131127

132128
train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item()
133129
train_loss_no_select2 = train_loss_no_select2 + loss2.mean().item()

pymic/net_run_ssl/ssl_cct.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import torch.nn as nn
66
import torch.nn.functional as F
77
import numpy as np
8-
from torch.optim import lr_scheduler
98
from pymic.loss.seg.util import get_soft_label
109
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
1110
from pymic.loss.seg.util import get_classwise_dice
@@ -139,9 +138,7 @@ def training(self):
139138

140139
loss.backward()
141140
self.optimizer.step()
142-
if(self.scheduler is not None and \
143-
not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
144-
self.scheduler.step()
141+
145142
train_loss = train_loss + loss.item()
146143
train_loss_sup = train_loss_sup + loss_sup.item()
147144
train_loss_reg = train_loss_reg + loss_reg.item()

pymic/net_run_ssl/ssl_cps.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55
import torch
66
import torch.nn as nn
7-
from torch.optim import lr_scheduler
87
from pymic.loss.seg.util import get_soft_label
98
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
109
from pymic.loss.seg.util import get_classwise_dice
@@ -26,7 +25,7 @@ def forward(self, x):
2625
if(self.training):
2726
return out1, out2
2827
else:
29-
return (out1 + out2) / 3
28+
return (out1 + out2) / 2
3029

3130
class SSLCPS(SSLSegAgent):
3231
"""
@@ -117,9 +116,6 @@ def training(self):
117116

118117
loss.backward()
119118
self.optimizer.step()
120-
if(self.scheduler is not None and \
121-
not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
122-
self.scheduler.step()
123119

124120
train_loss = train_loss + loss.item()
125121
train_loss_sup1 = train_loss_sup1 + loss_sup1.item()

pymic/net_run_ssl/ssl_em.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
import numpy as np
55
import torch
6-
from torch.optim import lr_scheduler
76
from pymic.loss.seg.util import get_soft_label
87
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
98
from pymic.loss.seg.util import get_classwise_dice
@@ -83,9 +82,6 @@ def training(self):
8382
# if (self.config['training']['use'])
8483
loss.backward()
8584
self.optimizer.step()
86-
if(self.scheduler is not None and \
87-
not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
88-
self.scheduler.step()
8985

9086
train_loss = train_loss + loss.item()
9187
train_loss_sup = train_loss_sup + loss_sup.item()

pymic/net_run_ssl/ssl_mt.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
import torch
55
import numpy as np
6-
from torch.optim import lr_scheduler
76
from pymic.loss.seg.util import get_soft_label
87
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
98
from pymic.loss.seg.util import get_classwise_dice
@@ -102,9 +101,6 @@ def training(self):
102101

103102
loss.backward()
104103
self.optimizer.step()
105-
if(self.scheduler is not None and \
106-
not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
107-
self.scheduler.step()
108104

109105
# update EMA
110106
alpha = ssl_cfg.get('ema_decay', 0.99)

pymic/net_run_ssl/ssl_uamt.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
import torch
55
import numpy as np
6-
from torch.optim import lr_scheduler
76
from pymic.loss.seg.util import get_soft_label
87
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
98
from pymic.loss.seg.util import get_classwise_dice
@@ -104,10 +103,6 @@ def training(self):
104103

105104
loss.backward()
106105
self.optimizer.step()
107-
if(self.scheduler is not None and \
108-
not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
109-
self.scheduler.step()
110-
111106

112107
# update EMA
113108
alpha = ssl_cfg.get('ema_decay', 0.99)

0 commit comments

Comments (0)