import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from trainer.unlearn.grad_diff import GradDiff

class UNDIAL(GradDiff):
    """UNDIAL unlearning trainer.

    The frozen reference (teacher) model's logit at each ground-truth token is
    lowered by `beta` before the softmax, and the trainable (student) model is
    distilled toward the resulting soft distribution on the forget set.
    """

    def __init__(self, beta=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.beta = beta
        # GradDiff may not have built a reference model; UNDIAL needs one as the teacher.
        if self.ref_model is None:
            self.ref_model = self._prepare_ref_model(self.model)

    def compute_loss(self, model, inputs, return_outputs=False):
        # Distillation loss on the forget set (teacher logits adjusted by beta).
        forget_inputs = inputs["forget"]
        forget_loss, forget_outputs = self.compute_undial_loss(model, forget_inputs)

        # Standard language-modeling loss on the retain set (inherited from GradDiff).
        retain_inputs = inputs["retain"]
        retain_inputs = {
            "input_ids": retain_inputs["input_ids"],
            "attention_mask": retain_inputs["attention_mask"],
            "labels": retain_inputs["labels"],
        }
        retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs)

        # Weighted combination, using the gamma/alpha coefficients from GradDiff.
        loss = self.gamma * forget_loss + self.alpha * retain_loss
        return (loss, forget_outputs) if return_outputs else loss

    def compute_undial_loss(self, model, inputs):
        # Forward pass on the student (trainable) model
        outputs = model(**inputs)
        logits = outputs.logits
        labels = inputs["labels"]

        # Shift so that position t predicts token t+1
        shift_labels = labels[..., 1:].contiguous()
        shift_logits = logits[..., :-1, :].contiguous()

        # Forward pass on the teacher (reference) model, without gradients
        with torch.no_grad():
            teacher_logits = self.ref_model(**inputs).logits
        shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()

        # Build a one-hot mask over the vocabulary that marks, at each position,
        # the ground-truth token that needs to be unlearned
        mask = torch.zeros_like(shift_teacher_logits)
        batch_idx = torch.arange(mask.shape[0]).view(-1, 1, 1)
        seq_idx = torch.arange(mask.shape[1]).view(1, -1, 1)
        mask[batch_idx, seq_idx, shift_labels.unsqueeze(-1)] = 1.0

        # Adjust the teacher logits: subtract beta from the logit of the true token,
        # then softmax to obtain the soft target distribution
        pre_softmax = shift_teacher_logits - mask * self.beta
        soft_label = F.softmax(pre_softmax, dim=-1)

        # Cross-entropy between student logits and the soft teacher targets
        loss_fct = CrossEntropyLoss(reduction="none")
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            soft_label.view(-1, soft_label.size(-1)),
        )
        return loss.mean(), outputs
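

# A minimal, self-contained sketch of the logit adjustment UNDIAL performs, on toy
# tensors. The shapes, beta value, and token index below are illustrative assumptions,
# not part of the trainer. It shows how subtracting beta at the true-token index
# lowers that token's probability in the soft target the student is trained toward.
if __name__ == "__main__":
    beta = 3.0                                               # assumed unlearning strength
    teacher_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])   # one position, vocab of 4
    true_token = torch.tensor([0])                           # token to unlearn at this position

    mask = torch.zeros_like(teacher_logits)
    mask[torch.arange(1), true_token] = 1.0

    soft_label = F.softmax(teacher_logits - mask * beta, dim=-1)
    print("original teacher probs:", F.softmax(teacher_logits, dim=-1))
    print("adjusted soft targets: ", soft_label)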