Commit 8ab8a5a

Add new unlearning method UNDIAL

1 parent a897704

File tree

3 files changed: +72 -0 lines

configs/trainer/UNDIAL.yaml
src/trainer/__init__.py
src/trainer/unlearn/undial.py

configs/trainer/UNDIAL.yaml

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
handler: UNDIAL # corresponds to the class defined in src/trainer/unlearn/undial.py
args: # HuggingFace TrainingArguments
  per_device_train_batch_size: 2
  per_device_eval_batch_size: 16
  gradient_accumulation_steps: 4
  learning_rate: 1e-5
  num_train_epochs: 10
method_args: # your own method-specific arguments
  gamma: 1.0
  alpha: 1.0
  beta: 10.0 # strength of the penalty on memorized tokens
  retain_loss_type: NLL
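
For orientation, here is a minimal sketch (not part of the commit) of how these settings are presumably consumed, assuming the repository's convention that handler selects the registered trainer class, the args block populates HuggingFace TrainingArguments, and method_args are forwarded to the trainer constructor; the exact load_trainer signature is not shown in this diff.

from transformers import TrainingArguments

# Hypothetical illustration of how configs/trainer/UNDIAL.yaml maps onto objects;
# the real wiring lives in load_trainer / load_trainer_args in src/trainer/__init__.py,
# which this commit only shows in part.
training_args = TrainingArguments(
    output_dir="./output",  # default supplied in load_trainer_args (see the __init__.py diff below)
    per_device_train_batch_size=2,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    num_train_epochs=10,
)

method_args = {"gamma": 1.0, "alpha": 1.0, "beta": 10.0, "retain_loss_type": "NLL"}
# trainer = UNDIAL(model=model, args=training_args, **method_args)  # hypothetical call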

src/trainer/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,7 @@
 from trainer.unlearn.dpo import DPO
 from trainer.unlearn.simnpo import SimNPO
 from trainer.unlearn.rmu import RMU
+from trainer.unlearn.undial import UNDIAL

 TRAINER_REGISTRY: Dict[str, Any] = {}

@@ -20,6 +21,7 @@ def _register_trainer(trainer_class):

 def load_trainer_args(trainer_args: DictConfig, dataset):
     trainer_args = dict(trainer_args)
+    trainer_args["output_dir"] = trainer_args.pop("output_dir", "./output")
     warmup_epochs = trainer_args.pop("warmup_epochs", None)
     if warmup_epochs:
         batch_size = trainer_args["per_device_train_batch_size"]
@@ -81,3 +83,4 @@ def load_trainer(
 _register_trainer(DPO)
 _register_trainer(SimNPO)
 _register_trainer(RMU)
+_register_trainer(UNDIAL)
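
_register_trainer and load_trainer are only partially visible in these hunks; assuming the registry simply keys each class by its name, the handler: UNDIAL value from the YAML config would resolve roughly like the sketch below (names reproduced from the diff, behavior assumed).

from typing import Any, Dict

TRAINER_REGISTRY: Dict[str, Any] = {}

def _register_trainer(trainer_class):
    # Assumed behavior: key each trainer class by its class name
    TRAINER_REGISTRY[trainer_class.__name__] = trainer_class

class UNDIAL:  # stand-in for trainer.unlearn.undial.UNDIAL
    pass

_register_trainer(UNDIAL)
trainer_cls = TRAINER_REGISTRY["UNDIAL"]  # resolved from `handler: UNDIAL` in the config
assert trainer_cls is UNDIAL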

src/trainer/unlearn/undial.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from trainer.unlearn.grad_diff import GradDiff


class UNDIAL(GradDiff):
    def __init__(self, beta=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.beta = beta
        if self.ref_model is None:
            self.ref_model = self._prepare_ref_model(self.model)

    def compute_loss(self, model, inputs, return_outputs=False):
        forget_inputs = inputs["forget"]
        forget_loss, forget_outputs = self.compute_undial_loss(model, forget_inputs)

        retain_inputs = inputs["retain"]
        retain_inputs = {
            "input_ids": retain_inputs["input_ids"],
            "attention_mask": retain_inputs["attention_mask"],
            "labels": retain_inputs["labels"],
        }
        retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs)

        loss = self.gamma * forget_loss + self.alpha * retain_loss
        return (loss, forget_outputs) if return_outputs else loss

    def compute_undial_loss(self, model, inputs):
        # Forward pass on the student (trainable) model
        outputs = model(**inputs)
        logits = outputs.logits
        labels = inputs["labels"]

        # Shift so that the logits at position t predict the token at position t + 1
        shift_labels = labels[..., 1:].contiguous()
        shift_logits = logits[..., :-1, :].contiguous()

        # Forward pass on the teacher (reference) model, without gradients
        with torch.no_grad():
            teacher_logits = self.ref_model(**inputs).logits
        shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()

        # Build a mask that marks, at every position, the target token to be unlearned
        mask = torch.zeros_like(shift_teacher_logits)
        batch_idx = torch.arange(mask.shape[0]).view(-1, 1, 1)
        seq_idx = torch.arange(mask.shape[1]).view(1, -1, 1)
        mask[batch_idx, seq_idx, shift_labels.unsqueeze(-1)] = 1.0

        # Adjust the teacher logits: subtract beta from the logit of the target token,
        # lowering its probability in the resulting soft label
        pre_softmax = shift_teacher_logits - mask * self.beta
        soft_label = F.softmax(pre_softmax, dim=-1)

        # Cross-entropy between the student logits and the adjusted soft labels
        # (CrossEntropyLoss accepts class probabilities as targets)
        loss_fct = CrossEntropyLoss(reduction="none")
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            soft_label.view(-1, soft_label.size(-1)),
        )
        return loss.mean(), outputs
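
To make the role of beta concrete, here is a small self-contained check (toy numbers, not part of the commit) showing how subtracting beta at the label index pushes teacher probability mass away from the memorized token before the student is trained toward the adjusted distribution.

import torch
import torch.nn.functional as F

# Toy teacher logits over a 4-token vocabulary at a single position;
# the "memorized" target token is index 2.
teacher_logits = torch.tensor([1.0, 0.5, 4.0, 0.2])
label = 2
beta = 10.0  # matches the default in configs/trainer/UNDIAL.yaml

mask = torch.zeros_like(teacher_logits)
mask[label] = 1.0

soft_label = F.softmax(teacher_logits - beta * mask, dim=-1)

print(F.softmax(teacher_logits, dim=-1)[label])  # high probability on the memorized token
print(soft_label[label])                         # sharply reduced after the beta penalty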
