Skip to content

Commit 452ccef

Browse files
committed
update rotate and rescale
introduce a probability for rotate and rescale
1 parent 3834d34 commit 452ccef

File tree

3 files changed

+37
-21
lines changed

3 files changed

+37
-21
lines changed

pymic/net_run/get_optimizer.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,30 @@ def get_optimizer(name, net_params, optim_params):
1010
lr = optim_params['learning_rate']
1111
momentum = optim_params['momentum']
1212
weight_decay = optim_params['weight_decay']
13+
# see https://www.codeleading.com/article/44815584159/
14+
param_group = [{'params': net_params, 'initial_lr': lr}]
1315
if(keyword_match(name, "SGD")):
14-
return optim.SGD(net_params, lr,
16+
return optim.SGD(param_group, lr,
1517
momentum = momentum, weight_decay = weight_decay)
1618
elif(keyword_match(name, "Adam")):
17-
return optim.Adam(net_params, lr, weight_decay = weight_decay)
19+
return optim.Adam(param_group, lr, weight_decay = weight_decay)
1820
elif(keyword_match(name, "SparseAdam")):
19-
return optim.SparseAdam(net_params, lr)
21+
return optim.SparseAdam(param_group, lr)
2022
elif(keyword_match(name, "Adadelta")):
21-
return optim.Adadelta(net_params, lr, weight_decay = weight_decay)
23+
return optim.Adadelta(param_group, lr, weight_decay = weight_decay)
2224
elif(keyword_match(name, "Adagrad")):
23-
return optim.Adagrad(net_params, lr, weight_decay = weight_decay)
25+
return optim.Adagrad(param_group, lr, weight_decay = weight_decay)
2426
elif(keyword_match(name, "Adamax")):
25-
return optim.Adamax(net_params, lr, weight_decay = weight_decay)
27+
return optim.Adamax(param_group, lr, weight_decay = weight_decay)
2628
elif(keyword_match(name, "ASGD")):
27-
return optim.ASGD(net_params, lr, weight_decay = weight_decay)
29+
return optim.ASGD(param_group, lr, weight_decay = weight_decay)
2830
elif(keyword_match(name, "LBFGS")):
29-
return optim.LBFGS(net_params, lr)
31+
return optim.LBFGS(param_group, lr)
3032
elif(keyword_match(name, "RMSprop")):
31-
return optim.RMSprop(net_params, lr, momentum = momentum,
33+
return optim.RMSprop(param_group, lr, momentum = momentum,
3234
weight_decay = weight_decay)
3335
elif(keyword_match(name, "Rprop")):
34-
return optim.Rprop(net_params, lr)
36+
return optim.Rprop(param_group, lr)
3537
else:
3638
raise ValueError("unsupported optimizer {0:}".format(name))
3739

@@ -57,7 +59,7 @@ def get_lr_scheduler(optimizer, sched_params):
5759
elif(keyword_match(name, "CosineAnnealingLR")):
5860
epoch_max = sched_params["iter_max"] / val_it
5961
epoch_last = sched_params["last_iter"] / val_it
60-
lr_min = sched_params.get("lr_min", 0)
62+
lr_min = sched_params.get("lr_min", 0)
6163
scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
6264
epoch_max, lr_min, epoch_last)
6365
else:

pymic/transform/rescale.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,28 +85,30 @@ class RandomRescale(AbstractTransform):
8585
The arguments should be written in the `params` dictionary, and it has the
8686
following fields:
8787
88-
:param `RandomRescale_lower_bound`: (list/tuple or int)
88+
:param `RandomRescale_lower_bound`: (list/tuple or float)
8989
Desired minimal rescale ratio. If tuple/list, the length should be 3 or 2.
90-
:param `RandomRescale_upper_bound`: (list/tuple or int)
90+
:param `RandomRescale_upper_bound`: (list/tuple or float)
9191
Desired maximal rescale ratio. If tuple/list, the length should be 3 or 2.
92+
:param `RandomRescale_probability`: (optional, float)
93+
The probability of applying RandomRescale. Default is 0.5.
9294
:param `RandomRescale_inverse`: (optional, bool)
9395
Is inverse transform needed for inference. Default is `True`.
9496
"""
9597
def __init__(self, params):
96-
"""
97-
ratio0 (tuple/list or int): Desired minimal rescale ratio.
98-
If tuple/list, the length should be 3 or 2.
99-
ratio1 (tuple/list or int): Desired maximal rescale ratio.
100-
If tuple/list, the length should be 3 or 2.
101-
"""
10298
super(RandomRescale, self).__init__(params)
10399
self.ratio0 = params["RandomRescale_lower_bound".lower()]
104100
self.ratio1 = params["RandomRescale_upper_bound".lower()]
101+
self.prob = params.get('RandomRescale_probability'.lower(), 0.5)
105102
self.inverse = params.get("RandomRescale_inverse".lower(), True)
106103
assert isinstance(self.ratio0, (float, list, tuple))
107104
assert isinstance(self.ratio1, (float, list, tuple))
108105

109106
def __call__(self, sample):
107+
if(np.random.uniform() > self.prob):
108+
sample['RandomRescale_triggered'] = False
109+
return sample
110+
else:
111+
sample['RandomRescale_triggered'] = True
110112
image = sample['image']
111113
input_shape = image.shape
112114
input_dim = len(input_shape) - 1
@@ -117,8 +119,8 @@ def __call__(self, sample):
117119
scale = [self.ratio0[i] + random.random()*(self.ratio1[i] - self.ratio0[i]) \
118120
for i in range(len(self.ratio0))]
119121
else:
120-
scale = [self.ratio0 + random.random()*(self.ratio1 - self.ratio0) \
121-
for i in range(input_dim)]
122+
scale = self.ratio0 + random.random()*(self.ratio1 - self.ratio0)
123+
scale = [scale] * input_dim
122124
scale = [1.0] + scale
123125
image_t = ndimage.interpolation.zoom(image, scale, order = 1)
124126

@@ -136,6 +138,8 @@ def __call__(self, sample):
136138
return sample
137139

138140
def inverse_transform_for_prediction(self, sample):
141+
if(not sample['RandomRescale_triggered']):
142+
return sample
139143
if(isinstance(sample['RandomRescale_origin_shape'], list) or \
140144
isinstance(sample['RandomRescale_origin_shape'], tuple)):
141145
origin_shape = json.loads(sample['RandomRescale_origin_shape'][0])

pymic/transform/rotate.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ class RandomRotate(AbstractTransform):
2727
:param `RandomRotate_angle_range_w`: (list/tuple or None)
2828
Rotation angle (degree) range along width axis (y-z plane), e.g., (-90, 90).
2929
If None, no rotation along this axis. Only used for 3D images.
30+
:param `RandomRotate_probability`: (optional, float)
31+
The probability of applying RandomRotate. Default is 0.5.
3032
:param `RandomRotate_inverse`: (optional, bool)
3133
Is inverse transform needed for inference. Default is `True`.
3234
"""
@@ -35,6 +37,7 @@ def __init__(self, params):
3537
self.angle_range_d = params['RandomRotate_angle_range_d'.lower()]
3638
self.angle_range_h = params['RandomRotate_angle_range_h'.lower()]
3739
self.angle_range_w = params['RandomRotate_angle_range_w'.lower()]
40+
self.prob = params.get('RandomRotate_probability'.lower(), 0.5)
3841
self.inverse = params.get('RandomRotate_inverse'.lower(), True)
3942

4043
def __apply_transformation(self, image, transform_param_list, order = 1):
@@ -50,6 +53,11 @@ def __apply_transformation(self, image, transform_param_list, order = 1):
5053
return image
5154

5255
def __call__(self, sample):
56+
if(np.random.uniform() > self.prob):
57+
sample['RandomRotate_triggered'] = False
58+
return sample
59+
else:
60+
sample['RandomRotate_triggered'] = True
5361
image = sample['image']
5462
input_shape = image.shape
5563
input_dim = len(input_shape) - 1
@@ -79,6 +87,8 @@ def __call__(self, sample):
7987
return sample
8088

8189
def inverse_transform_for_prediction(self, sample):
90+
if(not sample['RandomRotate_triggered']):
91+
return sample
8292
if(isinstance(sample['RandomRotate_Param'], list) or \
8393
isinstance(sample['RandomRotate_Param'], tuple)):
8494
transform_param_list = json.loads(sample['RandomRotate_Param'][0])

0 commit comments

Comments
 (0)