Skip to content

Commit 137c7eb

Browse files
committed
fix typo
1 parent 04567f1 commit 137c7eb

File tree

6 files changed

+197
-7
lines changed

6 files changed

+197
-7
lines changed

pymic/net/net3d/unet3d.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ def __init__(self, params):
9696
self.n_class = self.params['class_num']
9797
self.trilinear = self.params['trilinear']
9898
self.deep_sup = self.params['deep_supervise']
99-
self.stage = self.params['stage']
10099
assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
101100

102101
self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
@@ -134,7 +133,7 @@ def forward(self, x):
134133
x_d1 = self.up3(x_d2, x1)
135134
x_d0 = self.up4(x_d1, x0)
136135
output = self.out_conv(x_d0)
137-
if(self.deep_sup and self.stage == "train"):
136+
if(self.deep_sup):
138137
out_shape = list(output.shape)[2:]
139138
output1 = self.out_conv1(x_d1)
140139
output1 = interpolate(output1, out_shape, mode = 'trilinear')

pymic/net_run/agent_seg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def train_valid(self):
307307
elif(isinstance(iter_save, (tuple, list))):
308308
iter_save_list = iter_save
309309
else:
310-
iter_save_list = range(iter_start, iter_max + 1, iter_save)
310+
iter_save_list = range(0, iter_max + 1, iter_save)
311311

312312
self.max_val_dice = 0.0
313313
self.max_val_it = 0
@@ -519,7 +519,7 @@ def save_ouputs(self, data):
519519
filename_replace_source = self.config['testing'].get('filename_replace_source', None)
520520
filename_replace_target = self.config['testing'].get('filename_replace_target', None)
521521
if(not os.path.exists(output_dir)):
522-
os.mkdir(output_dir)
522+
os.makedirs(output_dir, exist_ok=True)
523523

524524
names, pred = data['names'], data['predict']
525525
if(isinstance(pred, (list, tuple))):

pymic/net_run/net_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
def main():
1111
if(len(sys.argv) < 3):
1212
print('Number of arguments should be 3. e.g.')
13-
print(' pymic_net_run train config.cfg')
13+
print(' pymic_run train config.cfg')
1414
exit()
1515
stage = str(sys.argv[1])
1616
cfg_file = str(sys.argv[2])

pymic/net_run_nll/nll_clslsr.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Caculating the confidence map of labels of training samples,
4+
which is used in the method of SLSR.
5+
Minqing Zhang et al., Characterizing Label Errors: Confident Learning
6+
for Noisy-Labeled Image Segmentation, MICCAI 2020.
7+
"""
8+
9+
from __future__ import print_function, division
10+
import cleanlab
11+
import logging
12+
import os
13+
import scipy
14+
import sys
15+
import torch
16+
import numpy as np
17+
import pandas as pd
18+
import torch.nn as nn
19+
import torchvision.transforms as transforms
20+
from PIL import Image
21+
from pymic.io.nifty_dataset import NiftyDataset
22+
from pymic.transform.trans_dict import TransformDict
23+
from pymic.util.parse_config import *
24+
from pymic.net_run.agent_seg import SegmentationAgent
25+
from pymic.net_run.infer_func import Inferer
26+
27+
def get_confident_map(gt, pred, CL_type = 'both'):
28+
"""
29+
gt: ground truth label (one-hot) with shape of NXC
30+
pred: digit prediction of network with shape of NXC
31+
"""
32+
prob = scipy.special.softmax(pred, axis = 1)
33+
if CL_type in ['both', 'Qij']:
34+
noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1)
35+
elif CL_type == 'Cij':
36+
noise = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1)
37+
elif CL_type == 'intersection':
38+
noise_qij = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1)
39+
noise_cij = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1)
40+
noise = noise_qij & noise_cij
41+
elif CL_type == 'union':
42+
noise_qij = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1)
43+
noise_cij = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1)
44+
noise = noise_qij | noise_cij
45+
elif CL_type in ['prune_by_class', 'prune_by_noise_rate']:
46+
noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method=CL_type, n_jobs=1)
47+
return noise
48+
49+
class NLLCLSLSR(SegmentationAgent):
50+
def __init__(self, config, stage = 'test'):
51+
super(NLLCLSLSR, self).__init__(config, stage)
52+
53+
def infer_with_cl(self):
54+
device_ids = self.config['testing']['gpus']
55+
device = torch.device("cuda:{0:}".format(device_ids[0]))
56+
self.net.to(device)
57+
58+
if(self.config['testing'].get('evaluation_mode', True)):
59+
self.net.eval()
60+
if(self.config['testing'].get('test_time_dropout', False)):
61+
def test_time_dropout(m):
62+
if(type(m) == nn.Dropout):
63+
logging.info('dropout layer')
64+
m.train()
65+
self.net.apply(test_time_dropout)
66+
67+
ckpt_mode = self.config['testing']['ckpt_mode']
68+
ckpt_name = self.get_checkpoint_name()
69+
if(ckpt_mode == 3):
70+
assert(isinstance(ckpt_name, (tuple, list)))
71+
self.infer_with_multiple_checkpoints()
72+
return
73+
else:
74+
if(isinstance(ckpt_name, (tuple, list))):
75+
raise ValueError("ckpt_mode should be 3 if ckpt_name is a list")
76+
77+
# load network parameters and set the network as evaluation mode
78+
checkpoint = torch.load(ckpt_name, map_location = device)
79+
self.net.load_state_dict(checkpoint['model_state_dict'])
80+
81+
if(self.inferer is None):
82+
infer_cfg = self.config['testing']
83+
class_num = self.config['network']['class_num']
84+
infer_cfg['class_num'] = class_num
85+
self.inferer = Inferer(infer_cfg)
86+
pred_list = []
87+
gt_list = []
88+
filename_list = []
89+
with torch.no_grad():
90+
for data in self.test_loader:
91+
images = self.convert_tensor_type(data['image'])
92+
labels = self.convert_tensor_type(data['label_prob'])
93+
names = data['names']
94+
filename_list.append(names)
95+
images = images.to(device)
96+
97+
pred = self.inferer.run(self.net, images)
98+
# convert tensor to numpy
99+
if(isinstance(pred, (tuple, list))):
100+
pred = [item.cpu().numpy() for item in pred]
101+
else:
102+
pred = pred.cpu().numpy()
103+
data['predict'] = pred
104+
# inverse transform
105+
for transform in self.transform_list[::-1]:
106+
if (transform.inverse):
107+
data = transform.inverse_transform_for_prediction(data)
108+
109+
pred = data['predict']
110+
# conver prediction from N, C, H, W to (N*H*W)*C
111+
print(names, pred.shape, labels.shape)
112+
pred_2d = np.swapaxes(pred, 1, 2)
113+
pred_2d = np.swapaxes(pred_2d, 2, 3)
114+
pred_2d = pred_2d.reshape(-1, class_num)
115+
lab = labels.cpu().numpy()
116+
lab_2d = np.swapaxes(lab, 1, 2)
117+
lab_2d = np.swapaxes(lab_2d, 2, 3)
118+
lab_2d = lab_2d.reshape(-1, class_num)
119+
pred_list.append(pred_2d)
120+
gt_list.append(lab_2d)
121+
122+
pred_cat = np.concatenate(pred_list)
123+
gt_cat = np.concatenate(gt_list)
124+
gt = np.argmax(gt_cat, axis = 1)
125+
gt = gt.reshape(-1).astype(np.uint8)
126+
print(gt.shape, pred_cat.shape)
127+
conf = get_confident_map(gt, pred_cat)
128+
conf = conf.reshape(-1, 256, 256).astype(np.uint8) * 255
129+
save_dir = self.config['dataset']['root_dir'] + "/slsr_conf"
130+
for idx in range(len(filename_list)):
131+
filename = filename_list[idx][0].split('/')[-1]
132+
conf_map = Image.fromarray(conf[idx])
133+
dst_path = os.path.join(save_dir, filename)
134+
conf_map.save(dst_path)
135+
136+
def get_confidence_map():
137+
if(len(sys.argv) < 2):
138+
print('Number of arguments should be 3. e.g.')
139+
print(' python nll_cl.py config.cfg')
140+
exit()
141+
cfg_file = str(sys.argv[1])
142+
config = parse_config(cfg_file)
143+
config = synchronize_config(config)
144+
145+
# set dataset
146+
transform_names = config['dataset']['valid_transform']
147+
transform_list = []
148+
transform_dict = TransformDict
149+
if(transform_names is None or len(transform_names) == 0):
150+
data_transform = None
151+
else:
152+
transform_param = config['dataset']
153+
transform_param['task'] = 'segmentation'
154+
for name in transform_names:
155+
if(name not in transform_dict):
156+
raise(ValueError("Undefined transform {0:}".format(name)))
157+
one_transform = transform_dict[name](transform_param)
158+
transform_list.append(one_transform)
159+
data_transform = transforms.Compose(transform_list)
160+
print('transform list', transform_list)
161+
csv_file = config['dataset']['train_csv']
162+
modal_num = config['dataset'].get('modal_num', 1)
163+
dataset = NiftyDataset(root_dir = config['dataset']['root_dir'],
164+
csv_file = csv_file,
165+
modal_num = modal_num,
166+
with_label= True,
167+
transform = data_transform )
168+
169+
agent = NLLCLSLSR(config, 'test')
170+
agent.set_datasets(None, None, dataset)
171+
agent.transform_list = transform_list
172+
agent.create_dataset()
173+
agent.create_network()
174+
agent.infer_with_cl()
175+
176+
# create training csv for confidence learning
177+
df_train = pd.read_csv(csv_file)
178+
pixel_weight = []
179+
for i in range(len(df_train["label"])):
180+
lab_name = df_train["label"][i].split('/')[-1]
181+
weight_name = "slsr_conf/" + lab_name
182+
pixel_weight.append(weight_name)
183+
train_cl_dict = {"image": df_train["image"],
184+
"pixel_weight": pixel_weight,
185+
"label": df_train["label"]}
186+
train_cl_csv = csv_file.replace(".csv", "_clslsr.csv")
187+
df_cl = pd.DataFrame.from_dict(train_cl_dict)
188+
df_cl.to_csv(train_cl_csv, index = False)
189+
190+
if __name__ == "__main__":
191+
get_confidence_map()

pymic/net_run_ssl/ssl_em.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import numpy as np
55
import torch
6-
from torhc.optim import lr_scheduler
6+
from torch.optim import lr_scheduler
77
from pymic.loss.seg.util import get_soft_label
88
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
99
from pymic.loss.seg.util import get_classwise_dice

pymic/net_run_ssl/ssl_urpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
import torch.nn as nn
66
import numpy as np
7-
from torhc.optim import lr_scheduler
7+
from torch.optim import lr_scheduler
88
from pymic.loss.seg.util import get_soft_label
99
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
1010
from pymic.loss.seg.util import get_classwise_dice

0 commit comments

Comments
 (0)