1+ # -*- coding: utf-8 -*-
2+ """
3+ Calculating the confidence map of labels of training samples,
4+ which is used in the method of SLSR.
5+ Minqing Zhang et al., Characterizing Label Errors: Confident Learning
6+ for Noisy-Labeled Image Segmentation, MICCAI 2020.
7+ """
8+
9+ from __future__ import print_function , division
10+ import cleanlab
11+ import logging
12+ import os
13+ import scipy
14+ import sys
15+ import torch
16+ import numpy as np
17+ import pandas as pd
18+ import torch .nn as nn
19+ import torchvision .transforms as transforms
20+ from PIL import Image
21+ from pymic .io .nifty_dataset import NiftyDataset
22+ from pymic .transform .trans_dict import TransformDict
23+ from pymic .util .parse_config import *
24+ from pymic .net_run .agent_seg import SegmentationAgent
25+ from pymic .net_run .infer_func import Inferer
26+
def get_confident_map(gt, pred, CL_type='both'):
    """
    Find the samples whose given label is likely wrong with confident
    learning (`cleanlab.pruning.get_noise_indices`).

    gt: given (possibly noisy) class index of each sample, with shape of N.
        `infer_with_cl` passes the argmax of the one-hot labels here.
    pred: digit (logit) prediction of network with shape of NXC.
    CL_type: how the label errors are selected, one of
        'both' or 'Qij': prune_method 'both' on the softmax of `pred`;
        'Cij': prune_method 'both' on the raw digit prediction;
        'intersection': samples flagged by both of the two runs above;
        'union': samples flagged by at least one of the two runs above;
        'prune_by_class' or 'prune_by_noise_rate': the cleanlab prune
        method of that name on the softmax of `pred`.
    return: the noise indicator given by cleanlab, one entry per sample
        (expected to be a boolean mask, as it is combined with `&` / `|`).
    raise: ValueError if `CL_type` is not one of the values listed above.
    """
    supported = ('both', 'Qij', 'Cij', 'intersection', 'union',
                 'prune_by_class', 'prune_by_noise_rate')
    # Fail early with a clear message; without this check an unknown mode
    # only surfaced as an UnboundLocalError on the return statement.
    if CL_type not in supported:
        raise ValueError(
            "Unsupported CL_type {0!r}, it should be one of {1}".format(
                CL_type, supported))

    # Imported here because a bare `import scipy` does not load the
    # `scipy.special` sub-module on older SciPy releases.
    from scipy.special import softmax
    prob = softmax(pred, axis=1)

    def find_noise(psx, prune_method='both'):
        # Single place that fixes the arguments shared by every mode.
        return cleanlab.pruning.get_noise_indices(
            gt, psx, prune_method=prune_method, n_jobs=1)

    if CL_type in ('both', 'Qij'):
        return find_noise(prob)
    if CL_type == 'Cij':
        # NOTE(review): the raw digit prediction (not the softmax output)
        # is passed as `psx` here, exactly as before -- confirm it is
        # intended.
        return find_noise(pred)
    if CL_type in ('prune_by_class', 'prune_by_noise_rate'):
        return find_noise(prob, CL_type)

    # Only 'intersection' and 'union' are left: combine both estimates.
    noise_qij = find_noise(prob)
    noise_cij = find_noise(pred)
    if CL_type == 'intersection':
        return noise_qij & noise_cij
    return noise_qij | noise_cij
48+
class NLLCLSLSR(SegmentationAgent):
    """
    Segmentation agent that runs the network on a dataset and saves, for
    each sample, a map of the pixels whose given label is flagged by
    confident learning (see `get_confident_map`).
    """
    def __init__(self, config, stage='test'):
        super(NLLCLSLSR, self).__init__(config, stage)

    def infer_with_cl(self):
        """
        Predict every batch of `self.test_loader`, run confident learning
        on all pixels at once and save one map per sample to
        `<dataset.root_dir>/slsr_conf/<file name>`.

        A saved pixel is 255 where `get_confident_map` flags the label and
        0 elsewhere. When `testing.ckpt_mode` is 3 the call is delegated to
        `infer_with_multiple_checkpoints` and no map is written here.

        raise: ValueError if several checkpoints are given while
            `ckpt_mode` is not 3.
        """
        device_ids = self.config['testing']['gpus']
        device = torch.device("cuda:{0:}".format(device_ids[0]))
        self.net.to(device)

        if self.config['testing'].get('evaluation_mode', True):
            self.net.eval()
            if self.config['testing'].get('test_time_dropout', False):
                def test_time_dropout(m):
                    # Switch the dropout layers back to training mode so
                    # that they stay active during inference.
                    if isinstance(m, nn.Dropout):
                        logging.info('dropout layer')
                        m.train()
                self.net.apply(test_time_dropout)

        ckpt_mode = self.config['testing']['ckpt_mode']
        ckpt_name = self.get_checkpoint_name()
        if ckpt_mode == 3:
            assert isinstance(ckpt_name, (tuple, list))
            self.infer_with_multiple_checkpoints()
            return
        if isinstance(ckpt_name, (tuple, list)):
            raise ValueError("ckpt_mode should be 3 if ckpt_name is a list")

        # load network parameters from the single checkpoint
        checkpoint = torch.load(ckpt_name, map_location=device)
        self.net.load_state_dict(checkpoint['model_state_dict'])

        # Read outside of the `if` below: class_num is also needed for the
        # reshaping when an inferer has already been created.
        class_num = self.config['network']['class_num']
        if self.inferer is None:
            infer_cfg = self.config['testing']
            infer_cfg['class_num'] = class_num
            self.inferer = Inferer(infer_cfg)

        pred_list = []
        gt_list = []
        filename_list = []   # names of each batch, in loading order
        batch_shapes = []    # (N, H, W) of the prediction of each batch
        with torch.no_grad():
            for data in self.test_loader:
                images = self.convert_tensor_type(data['image'])
                labels = self.convert_tensor_type(data['label_prob'])
                names = data['names']
                filename_list.append(names)
                images = images.to(device)

                pred = self.inferer.run(self.net, images)
                # convert tensor to numpy
                if isinstance(pred, (tuple, list)):
                    pred = [item.cpu().numpy() for item in pred]
                else:
                    pred = pred.cpu().numpy()
                data['predict'] = pred
                # inverse transform, in the reverse order of the forward ones
                for transform in self.transform_list[::-1]:
                    if transform.inverse:
                        data = transform.inverse_transform_for_prediction(data)

                pred = data['predict']
                print(names, pred.shape, labels.shape)
                # convert prediction and label from N, C, H, W to (N*H*W)*C
                pred_2d = np.moveaxis(pred, 1, -1).reshape(-1, class_num)
                lab_2d = np.moveaxis(labels.cpu().numpy(), 1, -1)
                lab_2d = lab_2d.reshape(-1, class_num)
                pred_list.append(pred_2d)
                gt_list.append(lab_2d)
                # Remember the layout so that the flat result can be cut
                # back into images of the right size afterwards.
                batch_shapes.append((pred.shape[0],) + tuple(pred.shape[2:]))

        pred_cat = np.concatenate(pred_list)
        gt_cat = np.concatenate(gt_list)
        # one-hot label -> class index of each pixel
        gt = np.argmax(gt_cat, axis=1).reshape(-1).astype(np.uint8)
        print(gt.shape, pred_cat.shape)
        conf = get_confident_map(gt, pred_cat)
        conf = conf.reshape(-1).astype(np.uint8) * 255

        save_dir = self.config['dataset']['root_dir'] + "/slsr_conf"
        # Saving fails if the output folder is missing, so create it first.
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        offset = 0
        for names, shape in zip(filename_list, batch_shapes):
            count = int(np.prod(shape))
            batch_conf = conf[offset:offset + count].reshape(shape)
            offset += count
            # assumes data['names'] holds one file name per sample of the
            # batch and that the maps are 2D images -- TODO confirm
            for name, conf_map in zip(names, batch_conf):
                filename = name.split('/')[-1]
                dst_path = os.path.join(save_dir, filename)
                Image.fromarray(conf_map).save(dst_path)
135+
def get_confidence_map():
    """
    Command line entry: create the confidence maps of the training labels
    and a training csv file that refers to them.

    The configuration file is read from `sys.argv[1]`. The training set
    (`dataset.train_csv`) is loaded with the validation transforms
    (`dataset.valid_transform`), `NLLCLSLSR.infer_with_cl` is run on it,
    and a copy of the training csv with an additional `pixel_weight`
    column is written next to it as `*_clslsr.csv`.

    raise: ValueError if a name in `dataset.valid_transform` is not a key
        of `TransformDict`. The program exits after printing the usage
        when no configuration file is given.
    """
    if len(sys.argv) < 2:
        # script name + configuration file = 2 arguments
        print('Number of arguments should be 2. e.g.')
        print(' python nll_cl.py config.cfg')
        # sys.exit instead of the `exit` helper, which is only installed
        # by the `site` module and may be missing.
        sys.exit()
    cfg_file = str(sys.argv[1])
    config = parse_config(cfg_file)
    config = synchronize_config(config)

    # set dataset
    transform_names = config['dataset']['valid_transform']
    transform_list = []
    transform_dict = TransformDict
    if transform_names is None or len(transform_names) == 0:
        data_transform = None
    else:
        transform_param = config['dataset']
        transform_param['task'] = 'segmentation'
        for name in transform_names:
            if name not in transform_dict:
                raise ValueError("Undefined transform {0:}".format(name))
            one_transform = transform_dict[name](transform_param)
            transform_list.append(one_transform)
        data_transform = transforms.Compose(transform_list)
    print('transform list', transform_list)

    csv_file = config['dataset']['train_csv']
    modal_num = config['dataset'].get('modal_num', 1)
    dataset = NiftyDataset(root_dir=config['dataset']['root_dir'],
                           csv_file=csv_file,
                           modal_num=modal_num,
                           with_label=True,
                           transform=data_transform)

    agent = NLLCLSLSR(config, 'test')
    agent.set_datasets(None, None, dataset)
    # infer_with_cl walks this list backwards for the inverse transforms
    agent.transform_list = transform_list
    agent.create_dataset()
    agent.create_network()
    agent.infer_with_cl()

    # create training csv for confidence learning
    df_train = pd.read_csv(csv_file)
    # NOTE(review): the maps are saved under the names given by
    # data['names'] while the base name of the label file is used here --
    # confirm that both are the same for the dataset.
    pixel_weight = ["slsr_conf/" + label.split('/')[-1]
                    for label in df_train["label"]]
    train_cl_dict = {"image": df_train["image"],
                     "pixel_weight": pixel_weight,
                     "label": df_train["label"]}
    train_cl_csv = csv_file.replace(".csv", "_clslsr.csv")
    df_cl = pd.DataFrame.from_dict(train_cl_dict)
    df_cl.to_csv(train_cl_csv, index=False)
189+
# Script entry point: the configuration file is taken from the command line
# by get_confidence_map itself.
if __name__ == "__main__":
    get_confidence_map()
0 commit comments