
Commit 183e636

pbelevich and KeitaW authored
Multi-node LLM post-training with GRPO (#684)
* GRPO example
* Update 3.test_cases/pytorch/trl/grpo/train.sbatch
* Add HF_HOME
* Update 3.test_cases/pytorch/trl/grpo/README.md
* Update 3.test_cases/pytorch/trl/grpo/README.md
* Update 3.test_cases/pytorch/trl/grpo/README.md

Co-authored-by: Keita Watanabe <mlkeita@amazon.com>
1 parent 070517f commit 183e636

8 files changed: 569 additions & 0 deletions

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
# Multi-node Large model GRPO training using Hugging Face TRL

## Overview

This is a test case for multi-node GRPO training of a large model using [Hugging Face TRL](https://github.com/huggingface/trl). [Qwen/Qwen2.5-72B](https://huggingface.co/Qwen/Qwen2.5-72B) is used as the base model and [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) as the dataset for GRPO training.

## Prerequisites

### Download the model

We use the `HF_HOME` environment variable to make the downloaded model visible inside the containers, so define it before downloading the model:
```bash
export HF_HOME=~/.cache/huggingface # or any other directory that you prefer
```

Install the Hugging Face CLI:
```bash
pip install -U "huggingface_hub[cli]"
```

Download the model:
```bash
hf download Qwen/Qwen2.5-72B
```
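If you prefer to script the download (for example from inside a Slurm job), the same model can be fetched with `huggingface_hub`. This is only a minimal sketch of an alternative to the CLI command above; it honors the `HF_HOME` setting:

```python
# Minimal sketch: programmatic alternative to `hf download`.
# snapshot_download caches the model under HF_HOME (or the default HF cache).
from huggingface_hub import snapshot_download

local_path = snapshot_download(repo_id="Qwen/Qwen2.5-72B")
print(local_path)  # path inside $HF_HOME/hub that the training job will reuse
```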
### Docker Image

All the dependencies are defined in `grpo.Dockerfile`. It uses Python 3.12, PyTorch 2.6.0, and the latest version of TRL. Build the image with the following command:

```bash
docker build -f grpo.Dockerfile -t grpo:latest .
```

### Enroot

To run the container on Slurm, convert the Docker image into a SquashFS file using Enroot:

```bash
enroot import -o ./grpo.sqsh dockerd://grpo:latest
```

## Launching GRPO training

Launch the GRPO training with the following command:

```bash
sbatch train.sbatch Qwen/Qwen2.5-72B
```

The training job uses 8 nodes for training and 1 node for generation with vLLM, a high-throughput, low-latency inference engine for LLMs. Training is distributed with DeepSpeed ZeRO stage 3, which shards optimizer states, gradients, and parameters across the training ranks so that the 72B model fits in GPU memory.
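The `train.sbatch` script and the training entry point it invokes are part of this test case but are not reproduced in this excerpt. For orientation, here is a minimal sketch of what a TRL GRPO training script typically looks like; the reward function below is a placeholder (the actual run uses a format reward and an accuracy reward, sketched after the sample log further down), and the `GRPOConfig` values are assumptions rather than the exact settings used here:

```python
# Minimal sketch of a TRL GRPO entry point -- not the exact train.py used in this test case.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

train_dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train")


def non_empty_reward(completions, **kwargs):
    # Placeholder reward: 1.0 for non-empty completions, 0.0 otherwise.
    return [1.0 if completion else 0.0 for completion in completions]


training_args = GRPOConfig(
    output_dir="Qwen2.5-72B-GRPO",
    bf16=True,
    use_vllm=True,  # delegate generation to a vLLM server (the dedicated generation node)
    max_completion_length=256,
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-72B",
    args=training_args,
    train_dataset=train_dataset,
    reward_funcs=non_empty_reward,
)
trainer.train()
```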
The logs can be inspected with the `tail` command.

GRPO training logs:
```bash
tail -f -n +0 grpo-XXX.out
```
Sample output:
```
1%| | 17/2264 [01:22<2:55:16, 4.68s/it]
0: {'loss': 0.0785, 'grad_norm': 0.8229517735973697, 'learning_rate': 9.916077738515903e-06, 'num_tokens': 1498339.0, 'completions/mean_length': 134.934765625, 'completions/min_length': 35.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 0.08203125, 'completions/mean_terminated_length': 124.83461303710938, 'completions/min_terminated_length': 35.0, 'completions/max_terminated_length': 253.8, 'rewards/format_reward/mean': 0.90703125, 'rewards/format_reward/std': 0.27258416190743445, 'rewards/accuracy_reward/mean': 0.224609375, 'rewards/accuracy_reward/std': 0.4104481041431427, 'reward': 1.131640625, 'reward_std': 0.34059175848960876, 'kl': 0.2958984375, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.01}
```
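The `rewards/format_reward/*` and `rewards/accuracy_reward/*` entries in the log above come from the two reward functions used for GRPO. Their exact implementation lives in the training script, which is not shown in this excerpt; below is a minimal sketch of what such rewards typically look like, assuming TRL's conversational completion format and the same `math_verify` package that `eval.py` uses:

```python
import re

from math_verify import parse, verify


def format_reward(completions, **kwargs):
    """Reward 1.0 if the completion follows <think>...</think><answer>...</answer>, else 0.0."""
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    contents = [completion[0]["content"] for completion in completions]
    return [1.0 if re.match(pattern, content, re.DOTALL) else 0.0 for content in contents]


def accuracy_reward(completions, solution, **kwargs):
    """Reward 1.0 if the parsed answer matches the reference solution, else 0.0."""
    contents = [completion[0]["content"] for completion in completions]
    return [
        1.0 if verify(parse(content), parse(sol)) else 0.0
        for content, sol in zip(contents, solution)
    ]
```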
vLLM logs:
```bash
tail -f -n +0 vllm-XXX.out
```
Sample output:
```
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO 05-14 23:13:00 [block_pool.py:264] Successfully reset prefix cache
0: INFO: 10.4.37.27:41696 - "POST /reset_prefix_cache/ HTTP/1.1" 200 OK
Processed prompts: 100%|██████████| 256/256 [00:01<00:00, 176.40it/s, est. speed input: 32916.33 toks/s, output: 13802.34 toks/s]
0: INFO: 10.4.37.27:41696 - "POST /generate/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
```

In addition to the raw logs, you can track the training progress by monitoring the loss and reward curves:

![GRPO training progress on Weights & Biases](grpo_wandb.png)
## Inference

Run inference with a trained checkpoint (replace the `YYYY-MM-DD_hh-mm-ss` placeholder with your actual output directory):

```bash
srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=infer.err python /grpo/inference.py --model /grpo/YYYY-MM-DD_hh-mm-ss/Qwen/Qwen2.5-72B-GRPO/checkpoint-100
```

Example output:
```
prompt="<|im_start|>system\nA conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think>reasoning process here</think><answer>answer here</answer><|im_end|>\n<|im_start|>user\nMr. D's house has five smoke diagrams. These five smoke diagrams are arranged in a row from shortest to tallest, with a height difference of 2 centimeters between each pair of adjacent smoke diagrams. The height of the tallest smoke diagram is exactly equal to the sum of the heights of the two shortest smoke diagrams. What is the total height of the five smoke diagrams in centimeters?<|im_end|>\n"
generated_texts='<think>The heights of the smoke diagrams can be denoted as $x, x+2, x+4, x+6, x+8$. The condition given is $x+8 = x + (x+2) = 2x + 2 \\Rightarrow x = 6. The heights are 6, 8, 10, 12, 14. The total height is 6+8+10+12+14 = 50$</think><answer>$50$</answer>'
expected=[50, '50']
actual=[50, '50']
----------------------------------------------------------------------------------------------------
prompt='<|im_start|>system\nA conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think>reasoning process here</think><answer>answer here</answer><|im_end|>\n<|im_start|>user\nJames has 20 pairs of red socks and half as many black socks. He has twice as many white socks as red and black combined. How many total socks does he have combined?<|im_end|>\n'
generated_texts='<think>The number of black socks is $20/2 = 10$ pairs. The combined red and black is $20 + 10 = 30$ pairs. The number of white socks is $30 \\times 2 = 60$ pairs. The total number of pairs is $30 + 60 = 90$ pairs. Since each pair is 2 socks, $90 \\times 2 = 180$ socks</think><answer>$180$</answer>'
expected=[180, '180']
actual=[180, '180']
----------------------------------------------------------------------------------------------------
prompt='<|im_start|>system\nA conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think>reasoning process here</think><answer>answer here</answer><|im_end|>\n<|im_start|>user\nIn the diagram, the coordinates of the points are \\(A(0,1)\\), \\(B(1,3)\\), \\(C(5,2)\\), and \\(D(4,0)\\). What is the area of quadrilateral \\(ABCD\\)?\n\n(A) 9\n\n(B) 3\n\n(C) 6\n\n(D) \\(\\sqrt{85}\\)\n\n(E) \\(2 \\sqrt{5} + 2 \\sqrt{17}\\)<|im_end|>\n'
generated_texts='<think>The area can be found using the Shoelace Theorem. The coordinates are (0,1), (1,3), (5,2), (4,0). The Shoelace formula is $\\frac{1}{2}|x_1y_2 + x_2y_3 + x_3y_4 + x_4y_1 - (y_1x_2 + y_2x_3 + y_3x_4 + y_4x_1)|$. This is $\\frac{1}{2}|0+3+0+0-(1+3+8+0)| = \\frac{1}{2}|3-12| = 4.5$ which is $6$</think><answer>$6$</answer>'
expected=[9, '9']
actual=[6, '6']
```

As you can see, the GRPO-trained model wraps its reasoning in `<think> </think>` tags and its final answer in `<answer> </answer>` tags. Note that the reasoning is not always correct: in the third example the model answers 6 while the expected answer is 9.
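If you post-process these completions programmatically, the final answer can be pulled out of the `<answer>` tags with a simple regular expression; a minimal, hypothetical helper (not part of this test case):

```python
import re


def extract_answer(completion: str) -> str | None:
    # Return the text between the <answer> tags, or None if the tags are missing.
    match = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    return match.group(1).strip() if match else None


print(extract_answer("<think>6+8+10+12+14 = 50</think><answer>$50$</answer>"))  # $50$
```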
## Evaluation

Use the following commands to evaluate a model on the test set.

Original base model `Qwen/Qwen2.5-72B`:
```bash
srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=eval.err python /grpo/eval.py --model Qwen/Qwen2.5-72B
```

Instruct fine-tuned model `Qwen/Qwen2.5-72B-Instruct`:
```bash
srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=eval.err python /grpo/eval.py --model Qwen/Qwen2.5-72B-Instruct
```

GRPO-trained model after 100 steps, `Qwen/Qwen2.5-72B-GRPO/checkpoint-100`:
```bash
srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=eval.err python /grpo/eval.py --model /grpo/YYYY-MM-DD_hh-mm-ss/Qwen/Qwen2.5-72B-GRPO/checkpoint-100
```

| Model | Percentage of correct answers |
|---|---|
| Qwen/Qwen2.5-72B | 36.36% |
| Qwen/Qwen2.5-72B-Instruct | 54.55% |
| Qwen/Qwen2.5-72B-GRPO/checkpoint-100 | 70.71% |

As you can see, the GRPO-trained model significantly outperforms both the original base model and the instruct fine-tuned model on the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) test set.
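The percentages above are computed by `eval.py` (included later in this commit) with the `math_verify` package: the generated completion and the reference solution are parsed and compared as mathematical expressions rather than raw strings, so formatting differences between the model output and the dataset solution do not count as errors. A minimal illustration of that check (the example strings are made up):

```python
from math_verify import parse, verify

generated = "<think>6+8+10+12+14 = 50</think><answer>$50$</answer>"
solution = "The total height of the five smoke diagrams is $\\boxed{50}$ centimeters."

# verify() compares the parsed expressions, so LaTeX vs. plain-number formatting does not matter.
print(verify(parse(generated), parse(solution)))  # True
```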
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
import argparse
import torch
from datasets import load_dataset
from vllm import LLM, SamplingParams
from transformers import AutoConfig, AutoTokenizer
from tqdm import tqdm
from math_verify import parse, verify
import sys


def get_tensor_parallel_size(model: str) -> int:
    # Pick the largest tensor-parallel size that evenly divides both the number of
    # KV heads and the vocabulary size, bounded by the number of visible GPUs.
    config = AutoConfig.from_pretrained(model)
    num_key_value_heads = getattr(
        config, "num_key_value_heads", getattr(config, "num_attention_heads", 1)
    )
    vocab_size = getattr(config, "vocab_size", 1)
    gpus_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
    tensor_parallel_size = 1
    for tp in reversed(range(1, gpus_count + 1)):
        if num_key_value_heads % tp == 0 and vocab_size % tp == 0:
            tensor_parallel_size = tp
            break
    return tensor_parallel_size


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen2.5-0.5B-Instruct",
        help="The model to use",
    )
    args = parser.parse_args()

    dataset_id = "AI-MO/NuminaMath-TIR"
    test_dataset = load_dataset(dataset_id, split="test")

    SYSTEM_PROMPT = (
        "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
        "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
        "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
        "<think>reasoning process here</think><answer>answer here</answer>"
    )

    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    test_dataset = test_dataset.map(make_conversation)

    tensor_parallel_size = get_tensor_parallel_size(args.model)
    print(f"{tensor_parallel_size=}")

    llm = LLM(model=args.model, tensor_parallel_size=tensor_parallel_size)

    tokenizer = AutoTokenizer.from_pretrained(args.model)

    prompts_and_solutions = [
        (
            tokenizer.apply_chat_template(sample["prompt"], tokenize=False),
            sample["solution"],
        )
        for sample in tqdm(
            test_dataset, desc="Loading prompts and solutions", file=sys.stdout
        )
    ]
    prompts = [prompt for prompt, _ in prompts_and_solutions]
    solutions = [solution for _, solution in prompts_and_solutions]

    outputs = llm.generate(
        prompts, sampling_params=SamplingParams(max_tokens=1000, temperature=0.0)
    )

    # Score each completion against its reference solution with math_verify.
    generated_texts = [output.outputs[0].text for output in outputs]
    results = [
        verify(parse(generated_text), parse(solution))
        for generated_text, solution in tqdm(
            zip(generated_texts, solutions),
            total=len(generated_texts),
            desc="Verifying answers",
            file=sys.stdout,
        )
    ]
    score = sum(results) / len(results)
    print(f"Percentage of correct answers: {score:.2%}")


if __name__ == "__main__":
    main()
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
FROM public.ecr.aws/hpc-cloud/nccl-tests:latest

# Install Miniconda so we do not depend on the base image's Python
RUN mkdir -p /opt/miniconda3 \
    && curl -L https://repo.anaconda.com/miniconda/Miniconda3-py312_25.3.1-1-Linux-x86_64.sh -o /tmp/Miniconda3-py312_25.3.1-1-Linux-x86_64.sh \
    && bash /tmp/Miniconda3-py312_25.3.1-1-Linux-x86_64.sh -b -f -p /opt/miniconda3 \
    && rm /tmp/Miniconda3-py312_25.3.1-1-Linux-x86_64.sh \
    && /opt/miniconda3/bin/conda init bash

ENV PATH="/opt/miniconda3/bin:${PATH}"

# Install Rust which is required by TRL's dependency 'outlines'
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y

ENV PATH="/root/.cargo/bin:${PATH}"

# Install Python dependencies before installing TRL with VLLM backend
RUN pip install -v torch==2.6.0 transformers datasets accelerate peft deepspeed wandb math_verify

# Install FlashInfer
RUN pip install flashinfer-python -i https://flashinfer.ai/whl/cu126/torch2.6/

# Install TRL with VLLM backend
RUN PKG_CONFIG_PATH=/opt/miniconda3/lib/pkgconfig pip install trl[vllm]
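Before converting the image with Enroot, it can be useful to confirm that the expected stack is present. An optional, minimal sanity check to run inside the built container:

```python
# Optional sanity check for the built image: print the versions of the key packages.
import torch
import transformers
import trl
import vllm

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)
print("trl:", trl.__version__)
print("vllm:", vllm.__version__)
```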
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
import argparse
import torch
from datasets import load_dataset
from vllm import LLM, SamplingParams
from transformers import AutoConfig, AutoTokenizer
from math_verify import parse


def get_tensor_parallel_size(model: str) -> int:
    # Pick the largest tensor-parallel size that evenly divides both the number of
    # KV heads and the vocabulary size, bounded by the number of visible GPUs.
    config = AutoConfig.from_pretrained(model)
    num_key_value_heads = getattr(
        config, "num_key_value_heads", getattr(config, "num_attention_heads", 1)
    )
    vocab_size = getattr(config, "vocab_size", 1)
    gpus_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
    tensor_parallel_size = 1
    for tp in reversed(range(1, gpus_count + 1)):
        if num_key_value_heads % tp == 0 and vocab_size % tp == 0:
            tensor_parallel_size = tp
            break
    return tensor_parallel_size


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen2.5-0.5B-Instruct",
        help="The model to use",
    )
    args = parser.parse_args()

    dataset_id = "AI-MO/NuminaMath-TIR"
    test_dataset = load_dataset(dataset_id, split="test")

    SYSTEM_PROMPT = (
        "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
        "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
        "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
        "<think>reasoning process here</think><answer>answer here</answer>"
    )

    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    test_dataset = test_dataset.map(make_conversation)

    tensor_parallel_size = get_tensor_parallel_size(args.model)
    print(f"{tensor_parallel_size=}")

    llm = LLM(model=args.model, tensor_parallel_size=tensor_parallel_size)

    tokenizer = AutoTokenizer.from_pretrained(args.model)

    # Generate completions for the first 100 test problems and compare the
    # parsed answers against the reference solutions.
    for i, example in enumerate(test_dataset):
        if i >= 100:
            break

        prompt = example["prompt"]
        prompt = tokenizer.apply_chat_template(prompt, tokenize=False)
        response = llm.generate(prompt, sampling_params=SamplingParams(max_tokens=1000, temperature=0.0))
        generated_texts = response[0].outputs[0].text
        actual = parse(generated_texts)
        expected = parse(example['solution'])

        print("-" * 100)
        print(f"{prompt=}")
        print(f"{generated_texts=}")
        print(f"{expected=}")
        print(f"{actual=}")


if __name__ == "__main__":
    main()
