Skip to content

Commit 75622c3

Browse files
committed
update comment and fix typo
update comment and fix typo
1 parent 24c46cf commit 75622c3

File tree

3 files changed

+15
-19
lines changed

3 files changed

+15
-19
lines changed

pymic/net_run/agent_cls.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,9 @@ def train_valid(self):
195195
ckpt_dir = self.config['training']['ckpt_save_dir']
196196
if(ckpt_dir[-1] == "/"):
197197
ckpt_dir = ckpt_dir[:-1]
198-
ckpt_prefx = ckpt_dir.split('/')[-1]
198+
ckpt_prefix = self.config['training'].get('ckpt_prefix', None)
199+
if(ckpt_prefix is None):
200+
ckpt_prefix = ckpt_dir.split('/')[-1]
199201
iter_start = self.config['training']['iter_start']
200202
iter_max = self.config['training']['iter_max']
201203
iter_valid = self.config['training']['iter_valid']
@@ -206,7 +208,7 @@ def train_valid(self):
206208
self.best_model_wts = None
207209
self.checkpoint = None
208210
if(iter_start > 0):
209-
checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, iter_start)
211+
checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start)
210212
self.checkpoint = torch.load(checkpoint_file, map_location = self.device)
211213
assert(self.checkpoint['iteration'] == iter_start)
212214
self.net.load_state_dict(self.checkpoint['model_state_dict'])
@@ -237,9 +239,9 @@ def train_valid(self):
237239
'valid_pred': valid_scalars[metrics],
238240
'model_state_dict': self.net.state_dict(),
239241
'optimizer_state_dict': self.optimizer.state_dict()}
240-
save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, glob_it)
242+
save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, glob_it)
241243
torch.save(save_dict, save_name)
242-
txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefx), 'wt')
244+
txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt')
243245
txt_file.write(str(glob_it))
244246
txt_file.close()
245247

@@ -248,9 +250,9 @@ def train_valid(self):
248250
'valid_pred': self.max_val_score,
249251
'model_state_dict': self.best_model_wts,
250252
'optimizer_state_dict': self.optimizer.state_dict()}
251-
save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, self.max_val_it)
253+
save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it)
252254
torch.save(save_dict, save_name)
253-
txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefx), 'wt')
255+
txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt')
254256
txt_file.write(str(self.max_val_it))
255257
txt_file.close()
256258
logging.info('The best perfroming iter is {0:}, valid {1:} {2:}'.format(\

pymic/net_run_nll/nll_clslsr.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
Caculating the confidence map of labels of training samples,
44
which is used in the method of SLSR.
55
Minqing Zhang et al., Characterizing Label Errors: Confident Learning
6-
for Noisy-Labeled Image Segmentation, MICCAI 2020.
6+
for Noisy-Labeled Image Segmentation, MICCAI 2020.
7+
https://link.springer.com/chapter/10.1007/978-3-030-59710-8_70
78
"""
89

910
from __future__ import print_function, division

pymic/net_run_nll/nll_trinet.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
Implementation of Co-teaching for learning from noisy samples for
3+
Implementation of trinet for learning from noisy samples for
44
segmentation tasks according to the following paper:
5-
Bo Han et al., Co-teaching: Robust Training of Deep NeuralNetworks
6-
with Extremely Noisy Labels, NeurIPS, 2018
7-
The author's original implementation was:
8-
https://github.com/bhanML/Co-teaching
9-
10-
5+
Tianwei Zhang, Lequan Yu, Na Hu, Su Lv, Shi Gu:
6+
Robust Medical Image Segmentation from Non-expert Annotations with Tri-network.
7+
MICCAI 2020.
8+
https://link.springer.com/chapter/10.1007/978-3-030-59719-1_25
119
"""
1210
from __future__ import print_function, division
1311
import logging
@@ -48,11 +46,6 @@ def forward(self, x):
4846
return (out1 + out2 + out3) / 3
4947

5048
class NLLTriNet(SegmentationAgent):
51-
"""
52-
Co-teaching: Robust Training of Deep Neural Networks with Extremely
53-
Noisy Labels
54-
https://arxiv.org/abs/1804.06872
55-
"""
5649
def __init__(self, config, stage = 'train'):
5750
super(NLLTriNet, self).__init__(config, stage)
5851

0 commit comments

Comments
 (0)