# Multi-node Large model GRPO training using Hugging Face TRL

## Overview

This is a test case for multi-node large model GRPO (Group Relative Policy Optimization) training using [Hugging Face TRL](https://github.com/huggingface/trl). [Qwen/Qwen2.5-72B](https://huggingface.co/Qwen/Qwen2.5-72B) is used as the base model and [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) as the dataset for GRPO training.
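
To get a feel for the training data, you can peek at the dataset before launching anything. A minimal sketch using the `datasets` library (the column names follow the dataset card and should be verified against the actual split):

```python
# Quick look at the GRPO training data. Requires: pip install datasets
from datasets import load_dataset

dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train")
print(dataset)                # features include "problem" and "solution"
print(dataset[0]["problem"])  # a single math problem as plain text
```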

## Prerequisites

### Download the model

We are going to use the `HF_HOME` environment variable to access the model from the containers, so define it before downloading the model:
```bash
export HF_HOME=~/.cache/huggingface # or any other directory that you prefer
```

Install the Hugging Face CLI:
```bash
pip install -U "huggingface_hub[cli]"
```

Download the model:
```bash
hf download Qwen/Qwen2.5-72B
```
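
If you prefer to script the download, the equivalent can be done from Python with `huggingface_hub` (a small sketch mirroring the CLI command above):

```python
# Download the model into the HF_HOME cache from Python instead of the CLI.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="Qwen/Qwen2.5-72B")  # honors the HF_HOME environment variable
```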

### Docker Image

All the dependencies are defined in `grpo.Dockerfile`. It uses Python 3.12, PyTorch 2.6.0, and the latest version of TRL. Build the image with the following command:

```bash
docker build -f grpo.Dockerfile -t grpo:latest .
```

### Enroot

To run the container on Slurm, convert it into a SquashFS file using Enroot:

```bash
enroot import -o ./grpo.sqsh dockerd://grpo:latest
```

## Launching GRPO training

Launch the GRPO training with the following command:

```bash
sbatch train.sbatch Qwen/Qwen2.5-72B
```

The training script launches 8 nodes for training and 1 node for generation using vLLM, a high-throughput, low-latency inference engine for LLMs. The training nodes shard parameters, gradients, and optimizer states across GPUs with ZeRO stage 3, which is what makes fitting a 72B model into memory feasible.
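
The heavy lifting is done by TRL's `GRPOTrainer`. The sketch below is a minimal single-process illustration of the setup, not the repo's exact training script; the reward function, hyperparameters, and `output_dir` are assumptions, and the dataset is treated as plain text:

```python
# Minimal sketch of GRPO training with TRL. The real run adds the Slurm launcher,
# DeepSpeed ZeRO-3 configuration, and additional reward functions (e.g. accuracy).
import re

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train")

def format_reward(completions, **kwargs):
    # Reward completions that follow the <think>...</think><answer>...</answer> format.
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    return [1.0 if re.search(pattern, str(c), re.DOTALL) else 0.0 for c in completions]

args = GRPOConfig(
    output_dir="Qwen2.5-72B-GRPO",  # assumed name, matching the checkpoint paths below
    use_vllm=True,                  # offload rollout generation to the vLLM server node
    bf16=True,
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-72B",
    reward_funcs=format_reward,
    args=args,
    train_dataset=dataset,
)
trainer.train()
```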

The logs can be inspected using the `tail` command.

GRPO training logs:
```bash
tail -f -n +0 grpo-XXX.out
```
Sample output:
```
 1%| | 17/2264 [01:22<2:55:16, 4.68s/it]
0: {'loss': 0.0785, 'grad_norm': 0.8229517735973697, 'learning_rate': 9.916077738515903e-06, 'num_tokens': 1498339.0, 'completions/mean_length': 134.934765625, 'completions/min_length': 35.0, 'completions/max_length': 256.0, 'completions/clipped_ratio': 0.08203125, 'completions/mean_terminated_length': 124.83461303710938, 'completions/min_terminated_length': 35.0, 'completions/max_terminated_length': 253.8, 'rewards/format_reward/mean': 0.90703125, 'rewards/format_reward/std': 0.27258416190743445, 'rewards/accuracy_reward/mean': 0.224609375, 'rewards/accuracy_reward/std': 0.4104481041431427, 'reward': 1.131640625, 'reward_std': 0.34059175848960876, 'kl': 0.2958984375, 'clip_ratio/low_mean': 0.0, 'clip_ratio/low_min': 0.0, 'clip_ratio/high_mean': 0.0, 'clip_ratio/high_max': 0.0, 'clip_ratio/region_mean': 0.0, 'epoch': 0.01}
```
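
Each metrics line in the training log is a Python dict prefixed with the emitting rank (`0: {...}`), so it can be parsed directly. A small helper for extracting loss and reward curves (illustrative, not part of the repo; it assumes the log format shown above):

```python
# Parse metric dicts like "0: {'loss': 0.0785, ...}" out of a grpo-XXX.out log.
import ast
import re

def parse_metrics(log_path):
    records = []
    with open(log_path) as f:
        for line in f:
            match = re.match(r"^\d+: (\{.*\})\s*$", line)
            if match:
                records.append(ast.literal_eval(match.group(1)))
    return records

records = parse_metrics("grpo-XXX.out")  # replace XXX with your Slurm job id
losses = [r["loss"] for r in records if "loss" in r]
rewards = [r["reward"] for r in records if "reward" in r]
```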

vLLM logs:
```bash
tail -f -n +0 vllm-XXX.out
```
Sample output:
```
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO 05-14 23:13:00 [block_pool.py:264] Successfully reset prefix cache
0: INFO: 10.4.37.27:41696 - "POST /reset_prefix_cache/ HTTP/1.1" 200 OK
Processed prompts: 100%|██████████| 256/256 [00:01<00:00, 176.40it/s, est. speed input: 32916.33 toks/s, output: 13802.34 toks/s]
0: INFO: 10.4.37.27:41696 - "POST /generate/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
0: INFO: 10.4.37.27:41696 - "POST /update_named_param/ HTTP/1.1" 200 OK
```

The trainer periodically pushes updated policy weights to the vLLM server (the `update_named_param` calls) and resets the server's prefix cache before generating the next batch of completions.

In addition to the training logs, you can inspect the training progress by monitoring the training loss and reward:



## Inference

Run inference on a trained checkpoint with the following command (replace `YYYY-MM-DD_hh-mm-ss` with the timestamped output directory of your training run):

```bash
srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=infer.err python /grpo/inference.py --model /grpo/YYYY-MM-DD_hh-mm-ss/Qwen/Qwen2.5-72B-GRPO/checkpoint-100
```
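
Conceptually, `inference.py` loads the checkpoint and samples completions with the same system prompt used during training. A rough standalone equivalent using vLLM (the tensor parallel size and sampling parameters here are assumptions, and the repo's script may differ):

```python
# Hypothetical standalone equivalent of /grpo/inference.py.
from vllm import LLM, SamplingParams

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the "
    "Assistant solves it. The assistant first thinks about the reasoning process in "
    "the mind and then provides the user with the answer. The reasoning process and "
    "answer are enclosed within <think> </think> and <answer> </answer> tags, "
    "respectively, i.e., <think>reasoning process here</think><answer>answer here</answer>"
)

# A 72B model does not fit on a single GPU, hence tensor parallelism (assumed TP=8).
llm = LLM(model="Qwen2.5-72B-GRPO/checkpoint-100", tensor_parallel_size=8)
params = SamplingParams(temperature=0.7, max_tokens=512)

question = "James has 20 pairs of red socks and half as many black socks. He has twice as many white socks as red and black combined. How many total socks does he have combined?"
messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": question},
]
outputs = llm.chat(messages, params)
print(outputs[0].outputs[0].text)
```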

Example output:
```
prompt="<|im_start|>system\nA conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think>reasoning process here</think><answer>answer here</answer><|im_end|>\n<|im_start|>user\nMr. D's house has five smoke diagrams. These five smoke diagrams are arranged in a row from shortest to tallest, with a height difference of 2 centimeters between each pair of adjacent smoke diagrams. The height of the tallest smoke diagram is exactly equal to the sum of the heights of the two shortest smoke diagrams. What is the total height of the five smoke diagrams in centimeters?<|im_end|>\n"

generated_texts='<think>The heights of the smoke diagrams can be denoted as $x, x+2, x+4, x+6, x+8$. The condition given is $x+8 = x + (x+2) = 2x + 2 \\Rightarrow x = 6. The heights are 6, 8, 10, 12, 14. The total height is 6+8+10+12+14 = 50$</think><answer>$50$</answer>'

expected=[50, '50']

actual=[50, '50']
----------------------------------------------------------------------------------------------------
prompt='<|im_start|>system\nA conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think>reasoning process here</think><answer>answer here</answer><|im_end|>\n<|im_start|>user\nJames has 20 pairs of red socks and half as many black socks. He has twice as many white socks as red and black combined. How many total socks does he have combined?<|im_end|>\n'

generated_texts='<think>The number of black socks is $20/2 = 10$ pairs. The combined red and black is $20 + 10 = 30$ pairs. The number of white socks is $30 \\times 2 = 60$ pairs. The total number of pairs is $30 + 60 = 90$ pairs. Since each pair is 2 socks, $90 \\times 2 = 180$ socks</think><answer>$180$</answer>'

expected=[180, '180']

actual=[180, '180']
----------------------------------------------------------------------------------------------------
prompt='<|im_start|>system\nA conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think>reasoning process here</think><answer>answer here</answer><|im_end|>\n<|im_start|>user\nIn the diagram, the coordinates of the points are \\(A(0,1)\\), \\(B(1,3)\\), \\(C(5,2)\\), and \\(D(4,0)\\). What is the area of quadrilateral \\(ABCD\\)?\n\n(A) 9\n\n(B) 3\n\n(C) 6\n\n(D) \\(\\sqrt{85}\\)\n\n(E) \\(2 \\sqrt{5} + 2 \\sqrt{17}\\)<|im_end|>\n'

generated_texts='<think>The area can be found using the Shoelace Theorem. The coordinates are (0,1), (1,3), (5,2), (4,0). The Shoelace formula is $\\frac{1}{2}|x_1y_2 + x_2y_3 + x_3y_4 + x_4y_1 - (y_1x_2 + y_2x_3 + y_3x_4 + y_4x_1)|$. This is $\\frac{1}{2}|0+3+0+0-(1+3+8+0)| = \\frac{1}{2}|3-12| = 4.5$ which is $6$</think><answer>$6$</answer>'

expected=[9, '9']

actual=[6, '6']
```

As you can see, the GRPO-trained model emits "think" tokens between `<think>` and `</think>` tags and "answer" tokens between `<answer>` and `</answer>` tags.
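
A small helper (illustrative, not from the repo) shows how the final answer can be pulled out of that format:

```python
# Extract the answer emitted between <answer> ... </answer> tags.
import re

def extract_answer(completion: str) -> str | None:
    match = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    return match.group(1).strip() if match else None

print(extract_answer("<think>6+8+10+12+14 = 50</think><answer>$50$</answer>"))  # -> $50$
```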

## Evaluation

Use the following commands to evaluate a model on the test set.

Original base model `Qwen/Qwen2.5-72B`:
```bash
srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=eval.err python /grpo/eval.py --model Qwen/Qwen2.5-72B
```

Instruct fine-tuned model `Qwen/Qwen2.5-72B-Instruct`:
```bash
srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=eval.err python /grpo/eval.py --model Qwen/Qwen2.5-72B-Instruct
```

GRPO-trained model after 100 steps (`Qwen/Qwen2.5-72B-GRPO/checkpoint-100`):
```bash
srun --mpi=pmix --cpu-bind=none --container-image ./grpo.sqsh --container-mounts=.:/grpo,$HF_HOME:$HF_HOME --error=eval.err python /grpo/eval.py --model /grpo/YYYY-MM-DD_hh-mm-ss/Qwen/Qwen2.5-72B-GRPO/checkpoint-100
```
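
The table below reports the percentage of correct answers. A hypothetical sketch of how such a metric could be computed from the `expected`/`actual` pairs shown in the inference output (the repo's `eval.py` may implement this differently):

```python
# Percentage of questions where the model's extracted answer matches the reference.
def percent_correct(actuals, expecteds):
    hits = sum(a == e for a, e in zip(actuals, expecteds))
    return 100.0 * hits / len(actuals)

print(percent_correct([[50, '50'], [6, '6']], [[50, '50'], [9, '9']]))  # 50.0
```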

|Model|Percentage of correct answers|
|---|---|
|Qwen/Qwen2.5-72B|36.36%|
|Qwen/Qwen2.5-72B-Instruct|54.55%|
|Qwen/Qwen2.5-72B-GRPO/checkpoint-100|70.71%|

As you can see, the GRPO-trained model significantly outperforms both the original base model and the instruct fine-tuned model on the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) test set.