# -*- coding: utf-8 -*-
"""
Implementation of DAST for noise-robust learning, according to the following paper:
Shuojue Yang, Guotai Wang, Hui Sun, Xiangde Luo, Peng Sun, Kang Li, Qijun Wang,
Shaoting Zhang: Learning COVID-19 Pneumonia Lesion Segmentation from Imperfect
Annotations via Divergence-Aware Selective Training.
JBHI 2022. https://ieeexplore.ieee.org/document/9770406
"""

from __future__ import print_function, division
import random
import torch
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim import lr_scheduler
from pymic.io.nifty_dataset import NiftyDataset
from pymic.loss.seg.util import get_soft_label
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
from pymic.loss.seg.util import get_classwise_dice
from pymic.net_run.agent_seg import SegmentationAgent
from pymic.util.parse_config import *
from pymic.util.ramps import get_rampup_ratio

class Rank(object):
    """
    Dynamically rank the current training sample against a queue of recent
    values of a given metric.
    """
    def __init__(self, queue_length = 100):
        self.vals = []
        self.queue_length = queue_length

    def add_val(self, val):
        """
        Update the queue and calculate the rank of the input value.

        Return
        ---------
        rank: rank of the input value within the queue, in the range
            [0, self.queue_length - 1]; -1 is returned while the queue
            is still being filled.
        """
        if len(self.vals) < self.queue_length:
            self.vals.append(val)
            rank = -1
        else:
            self.vals.pop(0)
            self.vals.append(val)
            assert len(self.vals) == self.queue_length
            idxes = np.argsort(self.vals)
            rank = np.where(idxes == self.queue_length - 1)[0][0]
        return rank
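
# A small usage sketch with hypothetical values: with queue_length = 3 the
# first three calls return -1 while the queue fills, after which the rank of
# the newest value among the stored ones is returned.
#
#   r = Rank(queue_length = 3)
#   r.add_val(0.5)    # -> -1
#   r.add_val(0.2)    # -> -1
#   r.add_val(0.9)    # -> -1
#   r.add_val(0.4)    # queue is now [0.2, 0.9, 0.4] -> rank 1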

class ConsistLoss(nn.Module):
    """
    Symmetrized KL divergence between two probability predictions, used to
    measure the divergence between the clean and noisy branches.
    """
    def __init__(self):
        super(ConsistLoss, self).__init__()

    def kl_div_map(self, input, label):
        kl_map = torch.sum(label * (torch.log(label + 1e-16) - torch.log(input + 1e-16)), dim = 1)
        return kl_map

    def kl_loss(self, input, target, size_average = True):
        kl_div = self.kl_div_map(input, target)
        if size_average:
            return torch.mean(kl_div)
        else:
            return kl_div

    def forward(self, input1, input2, size_average = True):
        kl1 = self.kl_loss(input1, input2.detach(), size_average = size_average)
        kl2 = self.kl_loss(input2, input1.detach(), size_average = size_average)
        return (kl1 + kl2) / 2
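
# A minimal usage sketch (shapes are illustrative). Both inputs are expected
# to be softmax probabilities of shape [N, C, ...]; with size_average = False
# the returned divergence map has shape [N, ...].
#
#   consist = ConsistLoss()
#   p1 = torch.softmax(torch.randn(2, 4, 8, 8), dim = 1)
#   p2 = torch.softmax(torch.randn(2, 4, 8, 8), dim = 1)
#   d  = consist(p1, p2)    # scalar; symmetric in value w.r.t. p1 and p2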

def get_ce(prob, soft_y, size_avg = True):
    # clamp the probability into [5e-4, 0.9995] so that log() stays finite
    prob = prob * 0.999 + 5e-4
    ce = - soft_y * torch.log(prob)
    ce = torch.sum(ce, dim = 1)   # sum over the class dimension
    if(size_avg):
        ce = torch.mean(ce)
    return ce
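
# With the clamp above, a maximally confident wrong prediction (probability
# about 5e-4 for the labeled class) contributes -log(5e-4), roughly 7.6,
# instead of an infinite loss.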

@torch.no_grad()
def select_criterion(no_noisy_sample, cl_noisy_sample, label):
    """
    no_noisy_sample: noisy branch's output probability for the noisy sample
    cl_noisy_sample: clean branch's output probability for the noisy sample
    label: noisy label
    """
    l_n = get_ce(no_noisy_sample, label, size_avg = False)
    l_c = get_ce(cl_noisy_sample, label, size_avg = False)
    js_distance = ConsistLoss()
    variance = js_distance(no_noisy_sample, cl_noisy_sample, size_average = False)
    # exp(-16 * variance) is close to 1 where the two branches agree and
    # decays towards 0 where they diverge, down-weighting divergent pixels
    exp_variance = torch.exp(-16 * variance)
    loss_n = torch.mean(l_c * exp_variance).item()
    loss_c = torch.mean(l_n * exp_variance).item()
    return loss_n, loss_c
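
# The two scalars returned above are fed into the Rank queues in
# NLLDAST.training(): a comparatively high loss_c triggers divergence-based
# consistency (DBC) for severely noisy batches, while a comparatively low
# loss_n lets a nearly clean batch be used for supplementary training (ST).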

class NLLDAST(SegmentationAgent):
    """
    Segmentation agent for Divergence-Aware Selective Training (DAST).
    It assumes a dual-branch network whose forward pass returns a pair of
    predictions, one from the clean branch and one from the noisy branch.
    """
    def __init__(self, config, stage = 'train'):
        super(NLLDAST, self).__init__(config, stage)
        self.train_set_noise    = None
        self.train_loader_noise = None
        self.trainIter_noise    = None
        self.noisy_rank = None
        self.clean_rank = None

    def get_noisy_dataset_from_config(self):
        """
        Create a dataset for the images with noisy labels based on the config.
        """
        root_dir  = self.config['dataset']['root_dir']
        modal_num = self.config['dataset'].get('modal_num', 1)
        transform_names = self.config['dataset']['train_transform']

        self.transform_list = []
        if(transform_names is None or len(transform_names) == 0):
            data_transform = None
        else:
            transform_param = self.config['dataset']
            transform_param['task'] = 'segmentation'
            for name in transform_names:
                if(name not in self.transform_dict):
                    raise ValueError("Undefined transform {0:}".format(name))
                one_transform = self.transform_dict[name](transform_param)
                self.transform_list.append(one_transform)
            data_transform = transforms.Compose(self.transform_list)

        csv_file = self.config['dataset'].get('train_csv_noise', None)
        dataset  = NiftyDataset(root_dir = root_dir,
                                csv_file = csv_file,
                                modal_num = modal_num,
                                with_label = True,
                                transform = data_transform)
        return dataset

    def create_dataset(self):
        super(NLLDAST, self).create_dataset()
        if(self.stage == 'train'):
            if(self.train_set_noise is None):
                self.train_set_noise = self.get_noisy_dataset_from_config()
            if(self.deterministic):
                def worker_init_fn(worker_id):
                    random.seed(self.random_seed + worker_id)
                worker_init = worker_init_fn
            else:
                worker_init = None

            bn_train_noise = self.config['dataset']['train_batch_size_noise']
            num_worker = self.config['dataset'].get('num_worker', 16)
            self.train_loader_noise = torch.utils.data.DataLoader(self.train_set_noise,
                batch_size = bn_train_noise, shuffle = True, num_workers = num_worker,
                worker_init_fn = worker_init)

    def training(self):
        class_num  = self.config['network']['class_num']
        iter_valid = self.config['training']['iter_valid']
        nll_cfg    = self.config['noisy_label_learning']
        iter_max   = self.config['training']['iter_max']
        rampup_start = nll_cfg.get('rampup_start', 0)
        rampup_end   = nll_cfg.get('rampup_end', iter_max)
        train_loss = 0
        train_loss_sup = 0
        train_loss_reg = 0
        train_dice_list = []
        self.net.train()

        rank_length  = nll_cfg.get("dast_rank_length", 20)
        consist_loss = ConsistLoss()
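        # each iteration draws one clean batch and one noisy batch, runs both
        # through the dual-branch network in a single forward pass, and then
        # applies DAST's selective losses on the noisy batch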
        for it in range(iter_valid):
            try:
                data_cl = next(self.trainIter)
            except StopIteration:
                self.trainIter = iter(self.train_loader)
                data_cl = next(self.trainIter)
            try:
                data_no = next(self.trainIter_noise)
            except StopIteration:
                self.trainIter_noise = iter(self.train_loader_noise)
                data_no = next(self.trainIter_noise)

            # get the inputs
            x0 = self.convert_tensor_type(data_cl['image'])       # clean samples
            y0 = self.convert_tensor_type(data_cl['label_prob'])
            x1 = self.convert_tensor_type(data_no['image'])       # noisy samples
            y1 = self.convert_tensor_type(data_no['label_prob'])
            inputs = torch.cat([x0, x1], dim = 0).to(self.device)
            y0, y1 = y0.to(self.device), y1.to(self.device)

            # zero the parameter gradients
            self.optimizer.zero_grad()

            # forward + backward + optimize
            b0_pred, b1_pred = self.net(inputs)
            n0 = list(x0.shape)[0]      # number of clean samples
            b0_x0_pred = b0_pred[:n0]   # prediction of clean samples from the clean branch
            b0_x1_pred = b0_pred[n0:]   # prediction of noisy samples from the clean branch
            b1_x1_pred = b1_pred[n0:]   # prediction of noisy samples from the noisy branch

            # supervised loss for the clean and noisy branches, respectively
            loss_sup_cl = self.get_loss_value(data_cl, b0_x0_pred, y0)
            loss_sup_no = self.get_loss_value(data_no, b1_x1_pred, y1)
            loss_sup = (loss_sup_cl + loss_sup_no) / 2
            loss = loss_sup

            # severe noise suppression (DBC) & supplementary training (ST)
            rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
            w_dbc = nll_cfg.get('dast_dbc_w', 0.1) * rampup_ratio
            w_st  = nll_cfg.get('dast_st_w',  0.1) * rampup_ratio
            b1_x1_prob = nn.Softmax(dim = 1)(b1_x1_pred)
            b0_x1_prob = nn.Softmax(dim = 1)(b0_x1_pred)
            loss_n, loss_c = select_criterion(b1_x1_prob, b0_x1_prob, y1)
            rank_n = self.noisy_rank.add_val(loss_n)
            rank_c = self.clean_rank.add_val(loss_c)
            if loss_n < loss_c:
                # severely noisy batch: regularize with divergence-based
                # consistency (DBC) instead of trusting the noisy labels
                if rank_c >= rank_length * 0.8:
                    loss_dbc = consist_loss(b1_x1_prob, b0_x1_prob)
                    loss = loss + loss_dbc * w_dbc
                # nearly clean batch: use it for supplementary training (ST)
                # with a sharpened pseudo label
                if rank_n <= 0.2 * rank_length:
                    b0_x1_argmax = torch.argmax(b0_x1_pred, dim = 1, keepdim = True)
                    b0_x1_lab = get_soft_label(b0_x1_argmax, class_num, self.tensor_type)
                    b1_x1_argmax = torch.argmax(b1_x1_pred, dim = 1, keepdim = True)
                    b1_x1_lab = get_soft_label(b1_x1_argmax, class_num, self.tensor_type)
                    # pseudo label: average of both branch predictions and the noisy label,
                    # then sharpened with temperature T = 0.5 to push values towards 0 or 1
                    pseudo_label = (b0_x1_lab + b1_x1_lab + y1) / 3
                    sharpen = lambda p, T: p**(1.0 / T) / (p**(1.0 / T) + (1 - p)**(1.0 / T))
                    loss_st = torch.mean(torch.abs(b0_x1_prob - sharpen(pseudo_label, 0.5)))
                    loss = loss + loss_st * w_st

            loss.backward()
            self.optimizer.step()
            if(self.scheduler is not None and \
                not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
                self.scheduler.step()

            train_loss = train_loss + loss.item()
            train_loss_sup = train_loss_sup + loss_sup.item()
            # train_loss_reg = train_loss_reg + loss_reg.item()
            # get dice evaluation for each class in annotated images
            if(isinstance(b0_x0_pred, tuple) or isinstance(b0_x0_pred, list)):
                p0 = b0_x0_pred[0]
            else:
                p0 = b0_x0_pred
            p0_argmax = torch.argmax(p0, dim = 1, keepdim = True)
            p0_soft   = get_soft_label(p0_argmax, class_num, self.tensor_type)
            p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0)
            dice_list = get_classwise_dice(p0_soft, y0)
            train_dice_list.append(dice_list.cpu().numpy())
        train_avg_loss = train_loss / iter_valid
        train_avg_loss_sup = train_loss_sup / iter_valid
        train_avg_loss_reg = train_loss_reg / iter_valid
        train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
        train_avg_dice = train_cls_dice.mean()

        train_scalers = {'loss': train_avg_loss, 'loss_sup': train_avg_loss_sup,
            'loss_reg': train_avg_loss_reg, 'regular_w': w_dbc,
            'avg_dice': train_avg_dice, 'class_dice': train_cls_dice}
        return train_scalers

    def train_valid(self):
        self.trainIter_noise = iter(self.train_loader_noise)
        nll_cfg = self.config['noisy_label_learning']
        rank_length = nll_cfg.get("dast_rank_length", 20)
        self.noisy_rank = Rank(rank_length)
        self.clean_rank = Rank(rank_length)
        super(NLLDAST, self).train_valid()
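
# A minimal, hypothetical sketch of a dual-branch network compatible with the
# forward pass above (self.net(inputs) must return two predictions). The
# backbone/head layout below is an assumption for illustration, not the
# network used in the paper.
#
#   class DualBranchNet(nn.Module):
#       def __init__(self, backbone, class_num, feature_chns):
#           super(DualBranchNet, self).__init__()
#           self.backbone = backbone                            # shared feature extractor
#           self.head0 = nn.Conv2d(feature_chns, class_num, 1)  # clean branch
#           self.head1 = nn.Conv2d(feature_chns, class_num, 1)  # noisy branch
#
#       def forward(self, x):
#           f = self.backbone(x)
#           return self.head0(f), self.head1(f)
#
# Driving the agent then follows the usual PyMIC pattern, e.g.
# agent = NLLDAST(config, 'train'); agent.run() -- assuming the parent
# SegmentationAgent exposes run().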