Skip to content

Commit 24c46cf

Browse files
committed
Update nll_dast.py
add config parameter
1 parent 303e624 commit 24c46cf

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

pymic/net_run_nll/nll_dast.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,11 @@ def training(self):
207207
rank_n = self.noisy_rank.add_val(loss_n)
208208
rank_c = self.clean_rank.add_val(loss_c)
209209
if loss_n < loss_c:
210-
if rank_c >= rank_length * 0.8:
210+
select_ratio = nll_cfg.get('dast_select_ratio', 0.2)
211+
if rank_c >= rank_length * (1 - select_ratio):
211212
loss_dbc = consist_loss(b1_x1_prob, b0_x1_prob)
212213
loss = loss + loss_dbc * w_dbc
213-
if rank_n <= 0.2 * rank_length:
214+
if rank_n <= rank_length * select_ratio:
214215
b0_x1_argmax = torch.argmax(b0_x1_pred, dim = 1, keepdim = True)
215216
b0_x1_lab = get_soft_label(b0_x1_argmax, class_num, self.tensor_type)
216217
b1_x1_argmax = torch.argmax(b1_x1_pred, dim = 1, keepdim = True)

0 commit comments

Comments
 (0)