We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 303e624 commit 24c46cfCopy full SHA for 24c46cf
pymic/net_run_nll/nll_dast.py
@@ -207,10 +207,11 @@ def training(self):
207
rank_n = self.noisy_rank.add_val(loss_n)
208
rank_c = self.clean_rank.add_val(loss_c)
209
if loss_n < loss_c:
210
- if rank_c >= rank_length * 0.8:
+ select_ratio = nll_cfg.get('dast_select_ratio', 0.2)
211
+ if rank_c >= rank_length * (1 - select_ratio):
212
loss_dbc = consist_loss(b1_x1_prob, b0_x1_prob)
213
loss = loss + loss_dbc * w_dbc
- if rank_n <= 0.2 * rank_length:
214
+ if rank_n <= rank_length * select_ratio:
215
b0_x1_argmax = torch.argmax(b0_x1_pred, dim = 1, keepdim = True)
216
b0_x1_lab = get_soft_label(b0_x1_argmax, class_num, self.tensor_type)
217
b1_x1_argmax = torch.argmax(b1_x1_pred, dim = 1, keepdim = True)
0 commit comments