Commit e29f16c

Sayak Paul authored: [Research Projects] ORPO diffusion for alignment (#7423)

* barebones orpo
* remove reference model.
* full implementation
* change default of beta_orpo
* add a training command.
* fix: dataloading issues.
* interpreting the formulation.
* revert styling
* add: wds full blown version
* fix: per_gpu_batch_siz
* start debuggin
* debugging
* remove print
* fix
* remove filter keys.
* turn on non-blocking calls.
* device_placement
* let's see.
* add bigger training run command
* reinitialize generator for fair repro
* add: detailed readme and requirements

Co-authored-by: Sayak Paul <sayakpaul@Sayaks-MacBook-Pro-2.local>

1 parent f7dfcfd commit e29f16c

File tree

4 files changed: +2307 −0 lines changed
Lines changed: 121 additions & 0 deletions

@@ -0,0 +1,121 @@
This project is an attempt to check if it's possible to apply [ORPO](https://arxiv.org/abs/2403.07691) to a text-conditioned diffusion model to align it on preference data WITHOUT a reference model. The implementation is based on https://github.com/huggingface/trl/pull/1435/.

> [!WARNING]
> We assume that the MSE in the diffusion formulation approximates the log-probs as required by ORPO (hat-tip to [@kashif](https://github.com/kashif) for the idea). So, please consider this to be extremely experimental.
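For intuition, here is a minimal sketch of what that assumption buys us: with `-MSE` standing in for the log-prob, ORPO's odds-ratio penalty can be computed from the two per-sample diffusion losses alone, with no reference model. The function and names below (`orpo_loss`, `beta_orpo`, the 4D latent shapes) are illustrative, not the exact API of the training scripts:

```python
import torch
import torch.nn.functional as F


def orpo_loss(pred_w, target_w, pred_l, target_l, beta_orpo=0.1):
    # Per-sample MSE for the preferred (w) and dispreferred (l) predictions,
    # assuming (B, C, H, W) latents.
    mse_w = F.mse_loss(pred_w.float(), target_w.float(), reduction="none").mean(dim=[1, 2, 3])
    mse_l = F.mse_loss(pred_l.float(), target_l.float(), reduction="none").mean(dim=[1, 2, 3])

    # Assumption from the warning above: log p(y|x) ≈ -MSE,
    # so log odds(y) = log p - log(1 - p).
    log_p_w, log_p_l = -mse_w, -mse_l
    log_odds_w = log_p_w - torch.log1p(-torch.exp(log_p_w).clamp(max=1 - 1e-7))
    log_odds_l = log_p_l - torch.log1p(-torch.exp(log_p_l).clamp(max=1 - 1e-7))

    # ORPO: supervised term on the preferred sample plus the odds-ratio penalty.
    ratio = F.logsigmoid(log_odds_w - log_odds_l)
    return mse_w.mean() - beta_orpo * ratio.mean()
```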
## Training

Here's a training command you can use on a 40GB A100 to validate things on a [small preference dataset](https://hf.co/datasets/kashif/pickascore):
```bash
accelerate launch train_diffusion_orpo_sdxl_lora.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --output_dir="diffusion-sdxl-orpo" \
  --mixed_precision="fp16" \
  --dataset_name=kashif/pickascore \
  --train_batch_size=8 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --rank=8 \
  --learning_rate=1e-5 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=2000 \
  --checkpointing_steps=500 \
  --run_validation --validation_steps=50 \
  --seed="0" \
  --report_to="wandb" \
  --push_to_hub
```

We also provide a simple script to scale up the training on the [yuvalkirstain/pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset:
```bash
accelerate launch --multi_gpu train_diffusion_orpo_sdxl_lora_wds.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --dataset_path="pipe:aws s3 cp s3://diffusion-preference-opt/{00000..00644}.tar -" \
  --output_dir="diffusion-sdxl-orpo-wds" \
  --mixed_precision="fp16" \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --rank=8 \
  --dataloader_num_workers=8 \
  --learning_rate=3e-5 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=50000 \
  --checkpointing_steps=2000 \
  --run_validation --validation_steps=500 \
  --seed="0" \
  --report_to="wandb" \
  --push_to_hub
```

We tested the above on a node of 8 H100s, but it should also work on A100s. It requires the `webdataset` library for faster dataloading. Note that we kept the dataset shards on an S3 bucket, but it should also be possible to store them locally.

You can use the code below to convert the original dataset into `webdataset` shards:
```python
import io
import os

import ray
import webdataset as wds
from datasets import Dataset
from PIL import Image

ray.init(num_cpus=8)


def convert_to_image(im_bytes):
    return Image.open(io.BytesIO(im_bytes)).convert("RGB")


def main():
    dataset_path = "/pickapic_v2/data"
    wds_shards_path = "/pickapic_v2_webdataset"
    # Make sure the output directory for the shards exists.
    os.makedirs(wds_shards_path, exist_ok=True)

    # Get all .parquet files in the dataset path.
    dataset_files = [
        os.path.join(dataset_path, f)
        for f in os.listdir(dataset_path)
        if f.endswith(".parquet")
    ]

    @ray.remote
    def create_shard(path):
        # Get the shard number from the file basename:
        # data-00123-of-01034.parquet -> 00123
        basename = os.path.basename(path)
        shard_num = basename.split("-")[1]
        dataset = Dataset.from_parquet(path)
        # Create a webdataset shard.
        shard = wds.TarWriter(os.path.join(wds_shards_path, f"{shard_num}.tar"))

        for i, example in enumerate(dataset):
            wds_example = {
                "__key__": str(i),
                "original_prompt.txt": example["caption"],
                "jpg_0.jpg": convert_to_image(example["jpg_0"]),
                "jpg_1.jpg": convert_to_image(example["jpg_1"]),
                "label_0.txt": str(example["label_0"]),
                "label_1.txt": str(example["label_1"]),
            }
            shard.write(wds_example)
        shard.close()

    futures = [create_shard.remote(path) for path in dataset_files]
    ray.get(futures)


if __name__ == "__main__":
    main()
```
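
For reference, the shards produced above can be streamed back with `webdataset` using the same `pipe:` spec that the training command passes via `--dataset_path`. This is a minimal sketch of such a loader, not the training script's exact pipeline:

```python
import webdataset as wds

# Stream shards straight from S3 (or a local path) and decode the fields
# written by the conversion script above.
dataset = (
    wds.WebDataset("pipe:aws s3 cp s3://diffusion-preference-opt/{00000..00644}.tar -")
    .decode("pil")  # decodes jpg_0.jpg / jpg_1.jpg into PIL images
    .to_tuple("original_prompt.txt", "jpg_0.jpg", "jpg_1.jpg", "label_0.txt", "label_1.txt")
)

for prompt, image_0, image_1, label_0, label_1 in dataset:
    # e.g. pick the preferred image based on the labels
    break
```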

## Inference

Refer to [sayakpaul/diffusion-sdxl-orpo](https://huggingface.co/sayakpaul/diffusion-sdxl-orpo) for an experimental checkpoint.
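
As a quick check, the checkpoint can be loaded with the standard `diffusers` LoRA API; a minimal sketch, assuming the usual SDXL base setup and that the repo contains standard LoRA weights:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("sayakpaul/diffusion-sdxl-orpo")

image = pipeline("a photo of an astronaut riding a horse", num_inference_steps=30).images[0]
image.save("orpo_sample.png")
```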
Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
datasets
accelerate
transformers
torchvision
wandb
peft
webdataset
