Code Review
This pull request introduces a GRPO training script for the GSM8K dataset using Ray, incorporating a brevity reward mechanism to promote concise reasoning. It also adds a patch to the Megatron PEFT implementation to disable merge checks. Key feedback includes correcting the log probability extraction logic for vLLM outputs, moving initialization calls into the main function to prevent side effects, and optimizing performance by instantiating reward functions outside the training loop. Additionally, the commented-out LoRA adapter configuration should be addressed to ensure the intended training behavior.
```python
for sample_response in sample_responses:
    for sequence in sample_response.sequences:
        all_input_data.append(sequence.new_input_feature)
        all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
```
The extraction logic [logprob[0][1] for logprob in sequence.logprobs] appears to be incorrect for the standard vLLM output format. In vLLM, sequence.logprobs is typically a list of dictionaries mapping token IDs to Logprob objects. Accessing logprob[0] would perform a key lookup for token ID 0, and [1] would fail on the resulting Logprob object. This should be updated to correctly extract the log probability of the sampled token.
Suggested change:
```diff
- all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
+ all_old_logps.append([list(lp.values())[0].logprob for lp in sequence.logprobs if lp])
```
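A minimal sketch of the structure this comment describes, using a mock `Logprob` object in place of vLLM's (the class and sample data here are illustrative, not vLLM itself). It assumes vLLM's layout where `sequence.logprobs` is a list with one `dict[token_id, Logprob]` per generated token; looking up the sampled token's ID explicitly is more robust than relying on dict ordering:

```python
from dataclasses import dataclass

@dataclass
class Logprob:          # stand-in for vllm.Logprob; only .logprob is used here
    logprob: float

class Sequence:         # stand-in for a vLLM sequence output
    def __init__(self, token_ids, logprobs):
        self.token_ids = token_ids   # the sampled token IDs
        self.logprobs = logprobs     # list of dicts keyed by token ID

# One dict per generated token, keyed by token ID (illustrative values).
seq = Sequence(
    token_ids=[42, 7],
    logprobs=[{42: Logprob(-0.1)}, {7: Logprob(-1.3)}],
)

# Extract the log probability of each *sampled* token by its ID.
old_logps = [lps[tok].logprob for tok, lps in zip(seq.token_ids, seq.logprobs) if lps]
print(old_logps)  # → [-0.1, -1.3]
```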
```python
swanlab.init(
    project='twinkle',
)
```
Calling swanlab.init at module import time causes side effects whenever the script is imported. This initialization call should be moved into the main function.
```python
accuracy_reward_fn = GSM8KAccuracyReward()
brevity_reward_fn = GSM8KBrevityReward()
```
Instantiating the reward functions GSM8KAccuracyReward and GSM8KBrevityReward inside compute_rewards is inefficient because this function is called in every training iteration. These should be instantiated once outside the training loop and passed in or reused to avoid unnecessary overhead.
Suggested change:
```diff
- accuracy_reward_fn = GSM8KAccuracyReward()
- brevity_reward_fn = GSM8KBrevityReward()
+ # These should ideally be instantiated once outside the loop
+ accuracy_reward_fn = kwargs.get('accuracy_reward_fn', GSM8KAccuracyReward())
+ brevity_reward_fn = kwargs.get('brevity_reward_fn', GSM8KBrevityReward())
```
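A sketch of the hoisting pattern this comment recommends, with hypothetical stand-ins for `GSM8KAccuracyReward` and `GSM8KBrevityReward` (their real constructors and signatures are not shown in this diff, so the classes below are illustrative only):

```python
class AccuracyReward:        # stand-in for GSM8KAccuracyReward
    def __call__(self, completion, answer):
        return 1.0 if answer in completion else 0.0

class BrevityReward:         # stand-in for GSM8KBrevityReward
    def __init__(self, target_len=200):
        self.target_len = target_len
    def __call__(self, completion, answer):
        # shorter completions score higher, capped at 1.0
        return min(1.0, self.target_len / max(len(completion), 1))

# Instantiate once, outside the training loop ...
accuracy_fn, brevity_fn = AccuracyReward(), BrevityReward()

def compute_rewards(completions, answers, accuracy_fn=accuracy_fn, brevity_fn=brevity_fn):
    # ... and reuse the same instances on every iteration.
    return [accuracy_fn(c, a) + brevity_fn(c, a) for c, a in zip(completions, answers)]

print(compute_rewards(["the answer is 4"], ["4"]))  # → [2.0]
```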
cookbook/rl/short_math_grpo.py (Outdated)
```python
    remote_group='model',
)

# model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
```
The call to add_adapter_to_model is currently commented out. Since the script defines LORA_RANK and lora_config, it seems intended for LoRA training. If this remains commented out, the model will perform full parameter fine-tuning, which might lead to Out-Of-Memory (OOM) issues or inconsistent behavior given the sampler configuration.
Suggested change:
```diff
- # model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
+ model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
```
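A sketch of the failure mode described above, using a hypothetical `Model` class (the real model object and `add_adapter_to_model` signature come from the PR's Megatron PEFT code, not shown here): defining `lora_config` has no effect on its own, so a cheap assertion after setup catches the case where the adapter call was left commented out and the run would silently full fine-tune.

```python
class Model:  # hypothetical stand-in; tracks attached adapters
    def __init__(self):
        self.adapters = {}

    def add_adapter_to_model(self, name, config, gradient_accumulation_steps=1):
        self.adapters[name] = config

ADAPTER_NAME = 'lora'
LORA_RANK = 8
lora_config = {'r': LORA_RANK, 'target_modules': ['q_proj', 'v_proj']}  # illustrative

model = Model()
model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)

# Fail fast if the adapter was never attached; otherwise every parameter
# trains, which is where the OOM risk mentioned in the review comes from.
assert ADAPTER_NAME in model.adapters, 'LoRA adapter not attached; would full fine-tune'
```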
PR type
PR information
Write the detailed information that belongs to this PR.
Experiment results
Paste your experiment results here (if needed).