-
Notifications
You must be signed in to change notification settings - Fork 21
Fix bugs #139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fix bugs #139
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,264 @@ | ||||||||||||
| """GRPO training script for GSM8K dataset. | ||||||||||||
|
|
||||||||||||
| Converted from the Tinker client version to Ray-based training. | ||||||||||||
| Uses short reasoning format: shorter thinking gets higher format reward. | ||||||||||||
| Answer extracted from \\boxed{} or #### format. | ||||||||||||
| """ | ||||||||||||
| import os | ||||||||||||
| import re | ||||||||||||
| from typing import List, Tuple, Dict, Any | ||||||||||||
|
|
||||||||||||
| from peft import LoraConfig | ||||||||||||
|
|
||||||||||||
| import twinkle | ||||||||||||
| from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger | ||||||||||||
| from twinkle.advantage import GRPOAdvantage | ||||||||||||
| from twinkle.checkpoint_engine import CheckpointEngineManager | ||||||||||||
| from twinkle.data_format import SamplingParams | ||||||||||||
| from twinkle.dataloader import DataLoader | ||||||||||||
| from twinkle.dataset import Dataset, DatasetMeta | ||||||||||||
| from twinkle.metric import CompletionRewardMetric | ||||||||||||
| from twinkle.model import TransformersModel | ||||||||||||
| from twinkle.processor import InputProcessor | ||||||||||||
| from twinkle.reward import GSM8KAccuracyReward | ||||||||||||
| from twinkle.reward.base import Reward | ||||||||||||
| from twinkle.sampler import vLLMSampler | ||||||||||||
| from twinkle.preprocessor.llm import GSM8KProcessor | ||||||||||||
|
|
||||||||||||
| logger = get_logger() | ||||||||||||
|
|
||||||||||||
| # ========== Configuration ========== | ||||||||||||
| MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B') | ||||||||||||
| USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1'))) | ||||||||||||
|
|
||||||||||||
| MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) | ||||||||||||
| SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) | ||||||||||||
| NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS | ||||||||||||
|
|
||||||||||||
| NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) | ||||||||||||
| MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) | ||||||||||||
| LEARNING_RATE = float(os.environ.get('LR', 1e-6)) | ||||||||||||
| MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) | ||||||||||||
| BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) | ||||||||||||
| MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4)) | ||||||||||||
| MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) | ||||||||||||
| GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) | ||||||||||||
| ADAPTER_NAME = 'default' | ||||||||||||
| SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50)) | ||||||||||||
| LORA_RANK = int(os.environ.get('LORA_RANK', 16)) | ||||||||||||
|
|
||||||||||||
| SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' | ||||||||||||
| 'and put your final answer within \\boxed{}.') | ||||||||||||
|
|
||||||||||||
| import swanlab | ||||||||||||
| swanlab.init( | ||||||||||||
| project='twinkle', | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ========== Reward Functions ========== | ||||||||||||
| class GSM8KBrevityReward(Reward): | ||||||||||||
| """Brevity reward: rewards shorter completions that contain a valid answer. | ||||||||||||
|
|
||||||||||||
| Returns 0.0 if no valid answer format (\\boxed{} or ####). | ||||||||||||
| Otherwise returns higher score for shorter completions (1.0 at <=200 chars). | ||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: | ||||||||||||
| rewards = [] | ||||||||||||
| for traj in trajectories: | ||||||||||||
| messages = traj.get('messages', []) | ||||||||||||
| completion = '' | ||||||||||||
| for msg in reversed(messages): | ||||||||||||
| if msg.get('role') == 'assistant': | ||||||||||||
| completion = msg.get('content', '') | ||||||||||||
| break | ||||||||||||
|
|
||||||||||||
| has_answer = bool( | ||||||||||||
| re.search(r'\\boxed\{[^}]+\}', completion) | ||||||||||||
| or re.search(r'####\s*[\-\d,\.]+', completion) | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| if not has_answer: | ||||||||||||
| rewards.append(0.0) | ||||||||||||
| else: | ||||||||||||
| length = len(completion) | ||||||||||||
| if length <= 200: | ||||||||||||
| rewards.append(1.0) | ||||||||||||
| else: | ||||||||||||
| rewards.append(max(0.0, 1.0 - (length - 200) / 3000)) | ||||||||||||
| return rewards | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ========== Dataset ========== | ||||||||||||
| def create_gsm8k_dataset(): | ||||||||||||
| dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) | ||||||||||||
| dataset.set_template('Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete') | ||||||||||||
| dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) | ||||||||||||
| dataset.encode(add_generation_prompt=True) | ||||||||||||
| return dataset | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def compute_rewards( | ||||||||||||
| trajectories: List[Dict[str, Any]], | ||||||||||||
| ) -> Tuple[List[float], List[float], List[float]]: | ||||||||||||
| accuracy_reward_fn = GSM8KAccuracyReward() | ||||||||||||
| brevity_reward_fn = GSM8KBrevityReward() | ||||||||||||
|
Comment on lines
+105
to
+106
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instantiating the reward functions
Suggested change
|
||||||||||||
|
|
||||||||||||
| accuracy_rewards = accuracy_reward_fn(trajectories) | ||||||||||||
| brevity_rewards = brevity_reward_fn(trajectories) | ||||||||||||
| total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)] | ||||||||||||
| return total_rewards, brevity_rewards, accuracy_rewards | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ========== Main ========== | ||||||||||||
| def main(): | ||||||||||||
| device_groups = [ | ||||||||||||
| DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'), | ||||||||||||
| DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), | ||||||||||||
| ] | ||||||||||||
|
|
||||||||||||
| model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) | ||||||||||||
| sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) | ||||||||||||
| twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) | ||||||||||||
|
|
||||||||||||
| lora_config = LoraConfig( | ||||||||||||
| target_modules='all-linear', | ||||||||||||
| r=LORA_RANK, | ||||||||||||
| lora_alpha=LORA_RANK * 2, | ||||||||||||
| lora_dropout=0.05, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| if USE_MEGATRON: | ||||||||||||
| from twinkle.model.megatron import MegatronModel | ||||||||||||
| model = MegatronModel( | ||||||||||||
| model_id=MODEL_ID, | ||||||||||||
| device_mesh=model_mesh, | ||||||||||||
| remote_group='model', | ||||||||||||
| mixed_precision='bf16', | ||||||||||||
| ) | ||||||||||||
| else: | ||||||||||||
| model = TransformersModel( | ||||||||||||
| model_id=MODEL_ID, | ||||||||||||
| device_mesh=model_mesh, | ||||||||||||
| remote_group='model', | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1) | ||||||||||||
| if USE_MEGATRON: | ||||||||||||
| model.set_optimizer('default', lr=LEARNING_RATE) | ||||||||||||
| model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE) | ||||||||||||
| else: | ||||||||||||
| model.set_optimizer('AdamW', lr=LEARNING_RATE) | ||||||||||||
| model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0) | ||||||||||||
|
|
||||||||||||
| model.set_loss('GRPOLoss', epsilon=0.2) | ||||||||||||
| model.set_processor(InputProcessor) | ||||||||||||
| model.set_template('Template', model_id=MODEL_ID) | ||||||||||||
|
|
||||||||||||
| sampler = vLLMSampler( | ||||||||||||
| model_id=MODEL_ID, | ||||||||||||
| engine_args={ | ||||||||||||
| 'gpu_memory_utilization': 0.8, | ||||||||||||
| 'max_model_len': 8192, | ||||||||||||
| 'max_lora_rank': 32, # save as lora_config | ||||||||||||
| # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976 | ||||||||||||
| 'enable_lora': True, | ||||||||||||
| }, | ||||||||||||
| device_mesh=sampler_mesh, | ||||||||||||
| remote_group='sampler', | ||||||||||||
| ) | ||||||||||||
| sampler.set_template('Template', model_id=MODEL_ID) | ||||||||||||
|
|
||||||||||||
| ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) | ||||||||||||
|
|
||||||||||||
| GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS | ||||||||||||
| dataloader = DataLoader( | ||||||||||||
| dataset=create_gsm8k_dataset, | ||||||||||||
| batch_size=GLOBAL_BATCH_SIZE, | ||||||||||||
| min_batch_size=GLOBAL_BATCH_SIZE, | ||||||||||||
| device_mesh=model_mesh, | ||||||||||||
| remote_group='model', | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| advantage_fn = GRPOAdvantage() | ||||||||||||
| metrics = CompletionRewardMetric() | ||||||||||||
| sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, temperature=1.0, top_p=0.95) | ||||||||||||
|
|
||||||||||||
| optim_step = 0 | ||||||||||||
| logger.info('Starting GSM8K GRPO training (short reasoning)') | ||||||||||||
| logger.info(get_device_placement()) | ||||||||||||
|
|
||||||||||||
| for batch in dataloader: | ||||||||||||
| if optim_step >= MAX_STEPS: | ||||||||||||
| break | ||||||||||||
|
|
||||||||||||
| metrics.reset() | ||||||||||||
| expand_prompts = [] | ||||||||||||
| for prompt in batch: | ||||||||||||
| expand_prompts.extend([prompt] * NUM_GENERATIONS) | ||||||||||||
|
|
||||||||||||
| ckpt_manager.sync_weights(merge_and_sync=False) | ||||||||||||
| sampler.reset_prefix_cache() | ||||||||||||
|
|
||||||||||||
| sample_responses = sampler.sample( | ||||||||||||
| expand_prompts, | ||||||||||||
| sampling_params, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| all_input_data: List[Dict[str, Any]] = [] | ||||||||||||
| all_old_logps: List[List[float]] = [] | ||||||||||||
| all_completion_lengths: List[int] = [] | ||||||||||||
|
|
||||||||||||
| for sample_response in sample_responses: | ||||||||||||
| for sequence in sample_response.sequences: | ||||||||||||
| all_input_data.append(sequence.new_input_feature) | ||||||||||||
| all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The extraction logic
Suggested change
|
||||||||||||
| all_completion_lengths.append(len(sequence.tokens)) | ||||||||||||
|
|
||||||||||||
| total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(all_input_data) | ||||||||||||
|
|
||||||||||||
| metrics.accumulate( | ||||||||||||
| completion_lengths=all_completion_lengths, | ||||||||||||
| rewards={ | ||||||||||||
| 'total': total_rewards, | ||||||||||||
| 'brevity': brevity_rewards, | ||||||||||||
| 'accuracy': accuracy_rewards, | ||||||||||||
| }, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist() | ||||||||||||
|
|
||||||||||||
| total_completions = len(all_input_data) | ||||||||||||
| for mb_start in range(0, total_completions, MINI_BATCH_SIZE): | ||||||||||||
| mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions) | ||||||||||||
| mb_inputs = all_input_data[mb_start:mb_end] | ||||||||||||
| mb_old_logps = all_old_logps[mb_start:mb_end] | ||||||||||||
| mb_advantages = advantages[mb_start:mb_end] | ||||||||||||
|
|
||||||||||||
| model.forward_backward( | ||||||||||||
| inputs=mb_inputs, | ||||||||||||
| old_logps=mb_old_logps, | ||||||||||||
| advantages=mb_advantages, | ||||||||||||
| micro_batch_size=MICRO_BATCH_SIZE, | ||||||||||||
| ) | ||||||||||||
| model.clip_grad_and_step() | ||||||||||||
| optim_step += 1 | ||||||||||||
|
|
||||||||||||
| if optim_step >= MAX_STEPS: | ||||||||||||
| break | ||||||||||||
| if optim_step % SAVE_STEPS == 0: | ||||||||||||
| model.save(f'math-grpo-checkpoint-{optim_step}') | ||||||||||||
|
|
||||||||||||
| log_dict = metrics.calculate() | ||||||||||||
| log_dict.update(model.calculate_metric(is_training=True)) | ||||||||||||
| swanlab.log(log_dict) | ||||||||||||
| metrics.reset() | ||||||||||||
| logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}') | ||||||||||||
|
|
||||||||||||
| logger.info(f'Training completed. optim_steps={optim_step}') | ||||||||||||
| model.save('math-grpo-final') | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| if __name__ == '__main__': | ||||||||||||
| main() | ||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
swanlab.initcall is placed at the module level. This can lead to unexpected side effects when the script is imported or used in different contexts. It is generally better to move initialization logic inside themain()function or a dedicated setup function.