
Commit 303e624

add dast for nll
add dast for nll; set the output mode of the dual-branch network
1 parent 137c7eb commit 303e624
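
For orientation, the configuration entries that this commit starts reading are collected below. The key names are taken from the diffs that follow; the section layout and the example values are only illustrative assumptions, not part of the commit.

# Hypothetical sketch of the new options read by this commit (values are placeholders).
config = {
    "network": {
        "output_mode": "average",           # UNet2D_DualBranch: "average", "first", or "second"
    },
    "dataset": {
        "train_csv_noise": "path/to/train_noisy.csv",   # noisy-label subset (path is an assumption)
        "train_batch_size_noise": 4,
    },
    "noisy_label_learning": {
        "nll_method": "DAST",               # key name assumed; dispatched via NLLMethodDict
        "dast_dbc_w": 0.1,                  # weight of the divergence-based consistency loss
        "dast_st_w": 0.1,                   # weight of the supplementary-training loss
        "dast_rank_length": 20,             # length of the Rank queues
        "rampup_start": 0,
        "rampup_end": 10000,
    },
}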


4 files changed: +271 −46 lines changed


pymic/loss/seg/ce.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from pymic.loss.seg.util import reshape_tensor_to_2D
 
 class CrossEntropyLoss(nn.Module):
-    def __init__(self, params):
+    def __init__(self, params = None):
         super(CrossEntropyLoss, self).__init__()
         if(params is None):
             self.softmax = True
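
The only change here is the default argument. A brief, hedged illustration of what it permits (everything beyond the `softmax` flag shown in this hunk is assumed unchanged):

from pymic.loss.seg.ce import CrossEntropyLoss

# Before this commit a params dict was mandatory; now the loss can be built bare,
# in which case softmax is enabled by default, per the hunk above.
loss_fn = CrossEntropyLoss()    # equivalent to CrossEntropyLoss(params = None)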

pymic/net/net2d/unet2d_dual_branch.py

Lines changed: 7 additions & 44 deletions
@@ -16,6 +16,7 @@
 class UNet2D_DualBranch(nn.Module):
     def __init__(self, params):
         super(UNet2D_DualBranch, self).__init__()
+        self.output_mode = params.get("output_mode", "average")
         self.encoder  = Encoder(params)
         self.decoder1 = Decoder(params)
         self.decoder2 = Decoder(params)
@@ -41,47 +42,9 @@ def forward(self, x):
         if(self.training):
             return output1, output2
         else:
-            return (output1 + output2)/2
-# for backup
-class DualBranchUNet2D(UNet2D):
-    def __init__(self, params):
-        params['deep_supervise'] = False
-        super(DualBranchUNet2D, self).__init__(params)
-        if(len(self.ft_chns) == 5):
-            self.up1_aux = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear)
-            self.up2_aux = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear)
-            self.up3_aux = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear)
-            self.up4_aux = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear)
-
-        self.out_conv_aux = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
-
-    def forward(self, x):
-        x_shape = list(x.shape)
-        if(len(x_shape) == 5):
-            [N, C, D, H, W] = x_shape
-            new_shape = [N*D, C, H, W]
-            x = torch.transpose(x, 1, 2)
-            x = torch.reshape(x, new_shape)
-
-        x0 = self.in_conv(x)
-        x1 = self.down1(x0)
-        x2 = self.down2(x1)
-        x3 = self.down3(x2)
-        if(len(self.ft_chns) == 5):
-            x4 = self.down4(x3)
-            x_d3, x_d3_aux = self.up1(x4, x3), self.up1_aux(x4, x3)
-        else:
-            x_d3, x_d3_aux = x3, x3
-
-        x_d2, x_d2_aux = self.up2(x_d3, x2), self.up2_aux(x_d3_aux, x2)
-        x_d1, x_d1_aux = self.up3(x_d2, x1), self.up3_aux(x_d2_aux, x1)
-        x_d0, x_d0_aux = self.up4(x_d1, x0), self.up4_aux(x_d1_aux, x0)
-        output, output_aux = self.out_conv(x_d0), self.out_conv_aux(x_d0_aux)
-
-        if(len(x_shape) == 5):
-            new_shape = [N, D] + list(output.shape)[1:]
-            output = torch.reshape(output, new_shape)
-            output = torch.transpose(output, 1, 2)
-            output_aux = torch.reshape(output_aux, new_shape)
-            output_aux = torch.transpose(output_aux, 1, 2)
-        return output, output_aux
+            if(self.output_mode == "average"):
+                return (output1 + output2)/2
+            elif(self.output_mode == "first"):
+                return output1
+            else:
+                return output2
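
A short sketch of the inference-time behavior that the new "output_mode" option controls. Only "output_mode" is introduced by this commit; the remaining keys follow PyMIC's usual 2D U-Net settings and are assumptions here.

import torch
from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch

params = {"in_chns": 1, "class_num": 2,
          "feature_chns": [16, 32, 64, 128, 256],     # assumed network settings
          "dropout": [0.0, 0.0, 0.3, 0.4, 0.5],
          "bilinear": True,
          "output_mode": "average"}                   # "average" (default), "first", or "second"
net = UNet2D_DualBranch(params)

net.eval()                                  # in .train() mode both branch outputs are returned
with torch.no_grad():
    pred = net(torch.rand(2, 1, 64, 64))    # averaged prediction of the two decoders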

pymic/net_run_nll/nll_dast.py

Lines changed: 260 additions & 0 deletions
@@ -0,0 +1,260 @@
+# -*- coding: utf-8 -*-
+"""
+Implementation of DAST for noise robust learning according to the following paper.
+    Shuojue Yang, Guotai Wang, Hui Sun, Xiangde Luo, Peng Sun, Kang Li, Qijun Wang,
+    Shaoting Zhang: Learning COVID-19 Pneumonia Lesion Segmentation from Imperfect
+    Annotations via Divergence-Aware Selective Training.
+    JBHI 2022. https://ieeexplore.ieee.org/document/9770406
+"""
+
+from __future__ import print_function, division
+import random
+import torch
+import numpy as np
+import torch.nn as nn
+import torchvision.transforms as transforms
+from torch.optim import lr_scheduler
+from pymic.io.nifty_dataset import NiftyDataset
+from pymic.loss.seg.util import get_soft_label
+from pymic.loss.seg.util import reshape_prediction_and_ground_truth
+from pymic.loss.seg.util import get_classwise_dice
+from pymic.net_run.agent_seg import SegmentationAgent
+from pymic.util.parse_config import *
+from pymic.util.ramps import get_rampup_ratio
+
+class Rank(object):
+    """
+    Dynamically rank the current training sample with specific metrics
+    """
+    def __init__(self, quene_length = 100):
+        self.vals = []
+        self.quene_length = quene_length
+
+    def add_val(self, val):
+        """
+        Update the quene and calculate the order of the input value.
+
+        Return
+        ---------
+        rank: rank of the input value with a range of (0, self.quene_length)
+        """
+        if len(self.vals) < self.quene_length:
+            self.vals.append(val)
+            rank = -1
+        else:
+            self.vals.pop(0)
+            self.vals.append(val)
+            assert len(self.vals) == self.quene_length
+            idxes = np.argsort(self.vals)
+            rank = np.where(idxes == self.quene_length-1)[0][0]
+        return rank
+
+class ConsistLoss(nn.Module):
+    def __init__(self):
+        super(ConsistLoss, self).__init__()
+
+    def kl_div_map(self, input, label):
+        kl_map = torch.sum(label * (torch.log(label + 1e-16) - torch.log(input + 1e-16)), dim = 1)
+        return kl_map
+
+    def kl_loss(self, input, target, size_average=True):
+        kl_div = self.kl_div_map(input, target)
+        if size_average:
+            return torch.mean(kl_div)
+        else:
+            return kl_div
+
+    def forward(self, input1, input2, size_average = True):
+        kl1 = self.kl_loss(input1, input2.detach(), size_average=size_average)
+        kl2 = self.kl_loss(input2, input1.detach(), size_average=size_average)
+        return (kl1 + kl2) / 2
+
+def get_ce(prob, soft_y, size_avg = True):
+    prob = prob * 0.999 + 5e-4
+    ce = - soft_y * torch.log(prob)
+    ce = torch.sum(ce, dim = 1) # shape is [N]
+    if(size_avg):
+        ce = torch.mean(ce)
+    return ce
+
+@torch.no_grad()
+def select_criterion(no_noisy_sample, cl_noisy_sample, label):
+    """
+    no_noisy_sample: noisy branch's output probability for noisy sample
+    cl_noisy_sample: clean branch's output probability for noisy sample
+    label: noisy label
+    """
+    l_n = get_ce(no_noisy_sample, label, size_avg = False)
+    l_c = get_ce(cl_noisy_sample, label, size_avg = False)
+    js_distance = ConsistLoss()
+    variance = js_distance(no_noisy_sample, cl_noisy_sample, size_average=False)
+    exp_variance = torch.exp(-16 * variance)
+    loss_n = torch.mean(l_c * exp_variance).item()
+    loss_c = torch.mean(l_n * exp_variance).item()
+    return loss_n, loss_c
+
+class NLLDAST(SegmentationAgent):
+    def __init__(self, config, stage = 'train'):
+        super(NLLDAST, self).__init__(config, stage)
+        self.train_set_noise    = None
+        self.train_loader_noise = None
+        self.trainIter_noise    = None
+        self.noisy_rank = None
+        self.clean_rank = None
+
+    def get_noisy_dataset_from_config(self):
+        root_dir  = self.config['dataset']['root_dir']
+        modal_num = self.config['dataset'].get('modal_num', 1)
+        transform_names = self.config['dataset']['train_transform']
+
+        self.transform_list = []
+        if(transform_names is None or len(transform_names) == 0):
+            data_transform = None
+        else:
+            transform_param = self.config['dataset']
+            transform_param['task'] = 'segmentation'
+            for name in transform_names:
+                if(name not in self.transform_dict):
+                    raise(ValueError("Undefined transform {0:}".format(name)))
+                one_transform = self.transform_dict[name](transform_param)
+                self.transform_list.append(one_transform)
+            data_transform = transforms.Compose(self.transform_list)
+
+        csv_file = self.config['dataset'].get('train_csv_noise', None)
+        dataset = NiftyDataset(root_dir=root_dir,
+                               csv_file  = csv_file,
+                               modal_num = modal_num,
+                               with_label= True,
+                               transform = data_transform )
+        return dataset
+
+    def create_dataset(self):
+        super(NLLDAST, self).create_dataset()
+        if(self.stage == 'train'):
+            if(self.train_set_noise is None):
+                self.train_set_noise = self.get_noisy_dataset_from_config()
+            if(self.deterministic):
+                def worker_init_fn(worker_id):
+                    random.seed(self.random_seed + worker_id)
+                worker_init = worker_init_fn
+            else:
+                worker_init = None
+
+            bn_train_noise = self.config['dataset']['train_batch_size_noise']
+            num_worker = self.config['dataset'].get('num_workder', 16)
+            self.train_loader_noise = torch.utils.data.DataLoader(self.train_set_noise,
+                batch_size = bn_train_noise, shuffle=True, num_workers= num_worker,
+                worker_init_fn=worker_init)
+
+    def training(self):
+        class_num   = self.config['network']['class_num']
+        iter_valid  = self.config['training']['iter_valid']
+        nll_cfg     = self.config['noisy_label_learning']
+        iter_max    = self.config['training']['iter_max']
+        rampup_start = nll_cfg.get('rampup_start', 0)
+        rampup_end   = nll_cfg.get('rampup_end', iter_max)
+        train_loss  = 0
+        train_loss_sup = 0
+        train_loss_reg = 0
+        train_dice_list = []
+        self.net.train()
+
+        rank_length  = nll_cfg.get("dast_rank_length", 20)
+        consist_loss = ConsistLoss()
+        for it in range(iter_valid):
+            try:
+                data_cl = next(self.trainIter)
+            except StopIteration:
+                self.trainIter = iter(self.train_loader)
+                data_cl = next(self.trainIter)
+            try:
+                data_no = next(self.trainIter_noise)
+            except StopIteration:
+                self.trainIter_noise = iter(self.train_loader_noise)
+                data_no = next(self.trainIter_noise)
+
+            # get the inputs
+            x0 = self.convert_tensor_type(data_cl['image'])  # clean sample
+            y0 = self.convert_tensor_type(data_cl['label_prob'])
+            x1 = self.convert_tensor_type(data_no['image'])  # noisy sample
+            y1 = self.convert_tensor_type(data_no['label_prob'])
+            inputs = torch.cat([x0, x1], dim = 0).to(self.device)
+            y0, y1 = y0.to(self.device), y1.to(self.device)
+
+            # zero the parameter gradients
+            self.optimizer.zero_grad()
+
+            # forward + backward + optimize
+            b0_pred, b1_pred = self.net(inputs)
+            n0 = list(x0.shape)[0]     # number of clean samples
+            b0_x0_pred = b0_pred[:n0]  # prediction of clean samples from clean branch
+            b0_x1_pred = b0_pred[n0:]  # prediction of noisy samples from clean branch
+            b1_x1_pred = b1_pred[n0:]  # prediction of noisy samples from noisy branch
+
+            # supervised loss for the clean and noisy branches, respectively
+            loss_sup_cl = self.get_loss_value(data_cl, b0_x0_pred, y0)
+            loss_sup_no = self.get_loss_value(data_no, b1_x1_pred, y1)
+            loss_sup = (loss_sup_cl + loss_sup_no) / 2
+            loss = loss_sup
+
+            # Severe Noise Suppression & Supplementary Training
+            rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
+            w_dbc = nll_cfg.get('dast_dbc_w', 0.1) * rampup_ratio
+            w_st  = nll_cfg.get('dast_st_w', 0.1) * rampup_ratio
+            b1_x1_prob = nn.Softmax(dim = 1)(b1_x1_pred)
+            b0_x1_prob = nn.Softmax(dim = 1)(b0_x1_pred)
+            loss_n, loss_c = select_criterion(b1_x1_prob, b0_x1_prob, y1)
+            rank_n = self.noisy_rank.add_val(loss_n)
+            rank_c = self.clean_rank.add_val(loss_c)
+            if loss_n < loss_c:
+                if rank_c >= rank_length * 0.8:
+                    loss_dbc = consist_loss(b1_x1_prob, b0_x1_prob)
+                    loss = loss + loss_dbc * w_dbc
+                if rank_n <= 0.2 * rank_length:
+                    b0_x1_argmax = torch.argmax(b0_x1_pred, dim = 1, keepdim = True)
+                    b0_x1_lab = get_soft_label(b0_x1_argmax, class_num, self.tensor_type)
+                    b1_x1_argmax = torch.argmax(b1_x1_pred, dim = 1, keepdim = True)
+                    b1_x1_lab = get_soft_label(b1_x1_argmax, class_num, self.tensor_type)
+                    pseudo_label = (b0_x1_lab + b1_x1_lab + y1) / 3
+                    sharpen = lambda p,T: p**(1.0/T)/(p**(1.0/T) + (1-p)**(1.0/T))
+                    b0_x1_prob = nn.Softmax(dim = 1)(b0_x1_pred)
+                    loss_st = torch.mean(torch.abs(b0_x1_prob - sharpen(pseudo_label, 0.5)))
+                    loss = loss + loss_st * w_st
+
+            loss.backward()
+            self.optimizer.step()
+            if(self.scheduler is not None and \
+                not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
+                self.scheduler.step()
+
+            train_loss = train_loss + loss.item()
+            train_loss_sup = train_loss_sup + loss_sup.item()
+            # train_loss_reg = train_loss_reg + loss_reg.item()
+            # get dice evaluation for each class in annotated images
+            if(isinstance(b0_x0_pred, tuple) or isinstance(b0_x0_pred, list)):
+                p0 = b0_x0_pred[0]
+            else:
+                p0 = b0_x0_pred
+            p0_argmax = torch.argmax(p0, dim = 1, keepdim = True)
+            p0_soft = get_soft_label(p0_argmax, class_num, self.tensor_type)
+            p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0)
+            dice_list = get_classwise_dice(p0_soft, y0)
+            train_dice_list.append(dice_list.cpu().numpy())
+        train_avg_loss = train_loss / iter_valid
+        train_avg_loss_sup = train_loss_sup / iter_valid
+        train_avg_loss_reg = train_loss_reg / iter_valid
+        train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
+        train_avg_dice = train_cls_dice.mean()
+
+        train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
+            'loss_reg':train_avg_loss_reg, 'regular_w':w_dbc,
+            'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
+        return train_scalers
+
+    def train_valid(self):
+        self.trainIter_noise = iter(self.train_loader_noise)
+        nll_cfg     = self.config['noisy_label_learning']
+        rank_length = nll_cfg.get("dast_rank_length", 20)
+        self.noisy_rank = Rank(rank_length)
+        self.clean_rank = Rank(rank_length)
+        super(NLLDAST, self).train_valid()
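
The Rank queue above returns -1 until it has seen quene_length values, and afterwards returns the position of the newest value in the sorted window; the training loop then only applies the consistency term when the clean-branch loss ranks high (>= 0.8 of the queue length) and the supplementary-training term when the noisy-branch loss ranks low (<= 0.2). A small self-contained illustration (the class body is trimmed from the diff above; the numeric values are made up):

import numpy as np

class Rank(object):  # trimmed copy of the class in this file, for illustration only
    def __init__(self, quene_length = 100):
        self.vals = []
        self.quene_length = quene_length
    def add_val(self, val):
        if len(self.vals) < self.quene_length:
            self.vals.append(val)
            return -1
        self.vals.pop(0)
        self.vals.append(val)
        idxes = np.argsort(self.vals)
        return np.where(idxes == self.quene_length - 1)[0][0]

r = Rank(quene_length = 5)
for v in [0.9, 0.7, 0.8, 0.6, 0.5]:
    print(r.add_val(v))    # -1 while the queue is still filling
print(r.add_val(0.55))     # 1: the newest value is the 2nd smallest in the current window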

pymic/net_run_nll/nll_main.py

Lines changed: 3 additions & 1 deletion
@@ -7,9 +7,11 @@
 from pymic.util.parse_config import *
 from pymic.net_run_nll.nll_co_teaching import NLLCoTeaching
 from pymic.net_run_nll.nll_trinet import NLLTriNet
+from pymic.net_run_nll.nll_dast import NLLDAST
 
 NLLMethodDict = {'CoTeaching': NLLCoTeaching,
-    "TriNet": NLLTriNet}
+    "TriNet": NLLTriNet,
+    "DAST": NLLDAST}
 
 def main():
     if(len(sys.argv) < 3):
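
With the new dictionary entry, DAST is selected through the same dispatch as the other noisy-label-learning methods. A hedged sketch: only the NLLMethodDict lookup comes from the diff above; the config path, the nll_method key name, and the run() call are assumptions.

from pymic.util.parse_config import parse_config
from pymic.net_run_nll.nll_main import NLLMethodDict

config = parse_config("config/dast_train.cfg")       # path and file are placeholders
method = config['noisy_label_learning'].get('nll_method', 'DAST')   # key name assumed
agent  = NLLMethodDict[method](config, 'train')      # builds an NLLDAST agent
agent.run()                                          # run() assumed inherited from the base agent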
