From 1e4f53e932491037b604b59150d5748f3984a91b Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 4 Apr 2026 16:15:40 +0800 Subject: [PATCH 1/5] wip --- cookbook/rl/short_math_grpo.py | 263 +++++++++++++++++++++++++++++ src/twinkle/patch/megatron_peft.py | 5 + 2 files changed, 268 insertions(+) create mode 100644 cookbook/rl/short_math_grpo.py diff --git a/cookbook/rl/short_math_grpo.py b/cookbook/rl/short_math_grpo.py new file mode 100644 index 00000000..f2cfbe71 --- /dev/null +++ b/cookbook/rl/short_math_grpo.py @@ -0,0 +1,263 @@ +"""GRPO training script for GSM8K dataset. + +Converted from the Tinker client version to Ray-based training. +Uses short reasoning format: shorter thinking gets higher format reward. +Answer extracted from \\boxed{} or #### format. +""" +import os +import re +from typing import List, Tuple, Dict, Any + +from peft import LoraConfig + +import twinkle +from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger +from twinkle.advantage import GRPOAdvantage +from twinkle.checkpoint_engine import CheckpointEngineManager +from twinkle.data_format import SamplingParams +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.metric import CompletionRewardMetric +from twinkle.model import TransformersModel +from twinkle.processor import InputProcessor +from twinkle.reward import GSM8KAccuracyReward +from twinkle.reward.base import Reward +from twinkle.sampler import vLLMSampler +from twinkle.preprocessor.llm import GSM8KProcessor + +logger = get_logger() + +# ========== Configuration ========== +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B') +USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0'))) + +MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) +SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) +NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS + +NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) +MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) +LEARNING_RATE = float(os.environ.get('LR', 1e-5)) +MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) +MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4)) +MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) +GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) +ADAPTER_NAME = 'default' +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50)) +LORA_RANK = int(os.environ.get('LORA_RANK', 16)) + +SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' + 'and put your final answer within \\boxed{}.') + +import swanlab +swanlab.init( + project='twinkle', +) + + +# ========== Reward Functions ========== +class GSM8KBrevityReward(Reward): + """Brevity reward: rewards shorter completions that contain a valid answer. + + Returns 0.0 if no valid answer format (\\boxed{} or ####). + Otherwise returns higher score for shorter completions (1.0 at <=200 chars). + """ + + def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: + rewards = [] + for traj in trajectories: + messages = traj.get('messages', []) + completion = '' + for msg in reversed(messages): + if msg.get('role') == 'assistant': + completion = msg.get('content', '') + break + + has_answer = bool( + re.search(r'\\boxed\{[^}]+\}', completion) + or re.search(r'####\s*[\-\d,\.]+', completion) + ) + + if not has_answer: + rewards.append(0.0) + else: + length = len(completion) + if length <= 200: + rewards.append(1.0) + else: + rewards.append(max(0.0, 1.0 - (length - 200) / 3000)) + return rewards + + +# ========== Dataset ========== +def create_gsm8k_dataset(): + dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) + dataset.set_template('Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete') + dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) + dataset.encode(add_generation_prompt=True) + return dataset + + +def compute_rewards( + trajectories: List[Dict[str, Any]], +) -> Tuple[List[float], List[float], List[float]]: + accuracy_reward_fn = GSM8KAccuracyReward() + brevity_reward_fn = GSM8KBrevityReward() + + accuracy_rewards = accuracy_reward_fn(trajectories) + brevity_rewards = brevity_reward_fn(trajectories) + total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)] + return total_rewards, brevity_rewards, accuracy_rewards + + +# ========== Main ========== +def main(): + device_groups = [ + DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'), + DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'), + ] + + model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS) + sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS) + twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) + + lora_config = LoraConfig( + target_modules='all-linear', + r=LORA_RANK, + lora_alpha=LORA_RANK * 2, + lora_dropout=0.05, + ) + + if USE_MEGATRON: + from twinkle.model.megatron import MegatronModel + model = MegatronModel( + model_id=MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + mixed_precision='bf16', + ) + else: + model = TransformersModel( + model_id=MODEL_ID, + device_mesh=model_mesh, + remote_group='model', + ) + + # model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1) + if USE_MEGATRON: + model.set_optimizer('default', lr=LEARNING_RATE) + model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE) + else: + model.set_optimizer('AdamW', lr=LEARNING_RATE) + model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0) + + model.set_loss('GRPOLoss', epsilon=0.2) + model.set_processor(InputProcessor) + model.set_template('Template', model_id=MODEL_ID) + + sampler = vLLMSampler( + model_id=MODEL_ID, + engine_args={ + 'gpu_memory_utilization': 0.8, + 'max_model_len': 8192, + 'max_lora_rank': max(32, LORA_RANK), + 'enable_lora': False, + }, + device_mesh=sampler_mesh, + remote_group='sampler', + ) + sampler.set_template('Template', model_id=MODEL_ID) + + ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) + + GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS + dataloader = DataLoader( + dataset=create_gsm8k_dataset, + batch_size=GLOBAL_BATCH_SIZE, + min_batch_size=GLOBAL_BATCH_SIZE, + device_mesh=model_mesh, + remote_group='model', + ) + + advantage_fn = GRPOAdvantage() + metrics = CompletionRewardMetric() + sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, temperature=1.0, top_p=0.95) + + optim_step = 0 + logger.info('Starting GSM8K GRPO training (short reasoning)') + logger.info(get_device_placement()) + + for batch in dataloader: + if optim_step >= MAX_STEPS: + break + + metrics.reset() + expand_prompts = [] + for prompt in batch: + expand_prompts.extend([prompt] * NUM_GENERATIONS) + + ckpt_manager.sync_weights(merge_and_sync=True) + sampler.reset_prefix_cache() + + sample_responses = sampler.sample( + expand_prompts, + sampling_params, + ) + + all_input_data: List[Dict[str, Any]] = [] + all_old_logps: List[List[float]] = [] + all_completion_lengths: List[int] = [] + + for sample_response in sample_responses: + for sequence in sample_response.sequences: + all_input_data.append(sequence.new_input_feature) + all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs]) + all_completion_lengths.append(len(sequence.tokens)) + + total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(all_input_data) + + metrics.accumulate( + completion_lengths=all_completion_lengths, + rewards={ + 'total': total_rewards, + 'brevity': brevity_rewards, + 'accuracy': accuracy_rewards, + }, + ) + + advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist() + + total_completions = len(all_input_data) + for mb_start in range(0, total_completions, MINI_BATCH_SIZE): + mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions) + mb_inputs = all_input_data[mb_start:mb_end] + mb_old_logps = all_old_logps[mb_start:mb_end] + mb_advantages = advantages[mb_start:mb_end] + + model.forward_backward( + inputs=mb_inputs, + old_logps=mb_old_logps, + advantages=mb_advantages, + micro_batch_size=MICRO_BATCH_SIZE, + ) + model.clip_grad_and_step() + optim_step += 1 + + if optim_step >= MAX_STEPS: + break + if optim_step % SAVE_STEPS == 0: + model.save(f'math-grpo-checkpoint-{optim_step}') + + log_dict = metrics.calculate() + log_dict.update(model.calculate_metric(is_training=True)) + swanlab.log(log_dict) + metrics.reset() + logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}') + + logger.info(f'Training completed. optim_steps={optim_step}') + model.save('math-grpo-final') + + +if __name__ == '__main__': + main() diff --git a/src/twinkle/patch/megatron_peft.py b/src/twinkle/patch/megatron_peft.py index 435947f9..75727b93 100644 --- a/src/twinkle/patch/megatron_peft.py +++ b/src/twinkle/patch/megatron_peft.py @@ -14,6 +14,11 @@ def __call__(self, *args, **kwargs): if MegatronPeft._peft_patched: return + + def _check_merge_allowed(*args, **kwargs): + pass + + BaseTuner._check_merge_allowed = _check_merge_allowed _origin_get_tied_target_modules = BaseTuner._get_tied_target_modules From d07d01bf27c1ca5f061ea722d64bebdd63d63219 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sat, 4 Apr 2026 23:54:07 +0800 Subject: [PATCH 2/5] wip --- cookbook/rl/grpo.py | 2 +- cookbook/rl/short_math_grpo.py | 13 +++++++------ src/twinkle/model/megatron/megatron.py | 2 ++ 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo.py index bc864309..0df8dd73 100644 --- a/cookbook/rl/grpo.py +++ b/cookbook/rl/grpo.py @@ -29,7 +29,7 @@ NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) -LEARNING_RATE = float(os.environ.get('LR', 1e-5)) +LEARNING_RATE = float(os.environ.get('LR', 1e-6)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 200)) BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8)) # global completion-level mini-batch-size diff --git a/cookbook/rl/short_math_grpo.py b/cookbook/rl/short_math_grpo.py index f2cfbe71..3508b9a0 100644 --- a/cookbook/rl/short_math_grpo.py +++ b/cookbook/rl/short_math_grpo.py @@ -29,7 +29,7 @@ # ========== Configuration ========== MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B') -USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '0'))) +USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1'))) MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4)) @@ -37,7 +37,7 @@ NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) -LEARNING_RATE = float(os.environ.get('LR', 1e-5)) +LEARNING_RATE = float(os.environ.get('LR', 1e-6)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4)) @@ -144,7 +144,7 @@ def main(): remote_group='model', ) - # model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1) + model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1) if USE_MEGATRON: model.set_optimizer('default', lr=LEARNING_RATE) model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE) @@ -161,8 +161,9 @@ def main(): engine_args={ 'gpu_memory_utilization': 0.8, 'max_model_len': 8192, - 'max_lora_rank': max(32, LORA_RANK), - 'enable_lora': False, + 'max_lora_rank': 32, # save as lora_config + # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976 + 'enable_lora': True, }, device_mesh=sampler_mesh, remote_group='sampler', @@ -197,7 +198,7 @@ def main(): for prompt in batch: expand_prompts.extend([prompt] * NUM_GENERATIONS) - ckpt_manager.sync_weights(merge_and_sync=True) + ckpt_manager.sync_weights(merge_and_sync=False) sampler.reset_prefix_cache() sample_responses = sampler.sample( diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 01604375..1b21f408 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -1424,6 +1424,7 @@ def weight_generator(): # Skip LoRA-specific weights for base model sync if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name: continue + name = name.replace('base_model.model.', '') yield name, tensor else: @@ -1446,6 +1447,7 @@ def _raw_weights(): # Skip LoRA-specific weights for base model sync if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name: continue + name = name.replace('base_model.model.', '') yield name, tensor def weight_generator(): From 936a8e3357a704ac2d0339f6a131c3685c4ad0c0 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 5 Apr 2026 11:10:29 +0800 Subject: [PATCH 3/5] wip --- src/twinkle/model/megatron/megatron.py | 58 +++++++++++++------ .../model/transformers/transformers.py | 23 +++++++- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 1b21f408..ef63c898 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -36,6 +36,9 @@ from twinkle.template import Template from twinkle.utils import construct_class, selective_log_softmax from .strategy import MegatronStrategy +from twinkle.utils import get_logger + +logger = get_logger() @dataclass @@ -1399,17 +1402,26 @@ def merge_lora(): if isinstance(_model, PeftModel): _model.unmerge_adapter() - def _add_base_layer_suffix(params): - for name, param in params: - if name.endswith('.weight'): - base_layer_name = f'{name[:-7]}.base_layer.weight' - if base_layer_name in model_keys or not model_keys: - name = base_layer_name - elif name.endswith('.bias'): - base_layer_name = f'{name[:-5]}.base_layer.bias' - if base_layer_name in model_keys or not model_keys: - name = base_layer_name - yield name, param + def _normalize(name: str, keep_base_layer: bool) -> str: + name = name.replace('base_model.model.', '') + if not keep_base_layer: + name = name.replace('.base_layer', '') + return name + + def _print_weight_example(names): + for name in names[:3]: + logger.info(f'Sync weight: {name}') + + def _add_base_layer_suffix(name): + if name.endswith('.weight'): + base_layer_name = f'{name[:-7]}.base_layer.weight' + if base_layer_name in model_keys or not model_keys: + name = base_layer_name + elif name.endswith('.bias'): + base_layer_name = f'{name[:-5]}.base_layer.bias' + if base_layer_name in model_keys or not model_keys: + name = base_layer_name + return name is_peft_format = (adapter_name != _default_adapter_name) if base_sync_done and adapter_name: @@ -1418,43 +1430,55 @@ def _add_base_layer_suffix(params): def weight_generator(): with merge_lora(): + names = [] for name, tensor in self.get_hf_state_dict(adapter_name=''): if name is None or tensor is None: continue # Skip LoRA-specific weights for base model sync if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name: continue - name = name.replace('base_model.model.', '') + name = _normalize(name, keep_base_layer=False) + names.append(name) yield name, tensor + _print_weight_example(names) else: def weight_generator(): + names = [] for name, tensor in self.get_hf_state_dict(adapter_name=adapter_name): if name is None or tensor is None: continue if 'lora' not in name: continue - name = name.replace('base_model.model.', '') + name = _normalize(name, keep_base_layer=True) + logger.info(f'Sync weight: {name}') + names.append(name) yield name, tensor + _print_weight_example(names) else: # Need to synchronize the base model # First full base-model sync. - def _raw_weights(): + def _raw_weights(add_base_layer_suffix=False): + names = [] for name, tensor in self.get_hf_state_dict(adapter_name=''): if name is None or tensor is None: continue # Skip LoRA-specific weights for base model sync if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name: continue - name = name.replace('base_model.model.', '') + name = _normalize(name, keep_base_layer=False) + if add_base_layer_suffix: + name = _add_base_layer_suffix(name) + names.append(name) yield name, tensor + _print_weight_example(names) def weight_generator(): if is_peft_format and (not merge_and_sync): - yield from _add_base_layer_suffix(_raw_weights()) + yield from _raw_weights(True) else: - yield from _raw_weights() + yield from _raw_weights(False) is_sender = (engine.rank is not None and engine.rank == 0) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 88e08a2e..a7dc7a0b 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -40,6 +40,9 @@ from twinkle.utils import construct_class, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm +from twinkle.utils import get_logger + +logger = get_logger() @dataclass @@ -1165,6 +1168,10 @@ def _normalize(name: str, keep_base_layer: bool) -> str: name = name.replace('.base_layer', '') return name + def _print_weight_example(names): + for name in names[:3]: + logger.info(f'Sync weight: {name}') + def _is_lora_key(name: str) -> bool: return 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name @@ -1176,11 +1183,15 @@ def _is_lora_key(name: str) -> bool: def weight_generator(): if isinstance(model, PeftModel): model.merge_adapter() + names = [] for name, tensor in model.state_dict().items(): if _is_lora_key(name): continue tensor = Torch.to_local_tensor(tensor) - yield _normalize(name, keep_base_layer=False), tensor + name = _normalize(name, keep_base_layer=False) + names.append(name) + yield name, tensor + _print_weight_example(names) if isinstance(model, PeftModel): model.unmerge_adapter() else: @@ -1190,9 +1201,13 @@ def weight_generator(): lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name) def weight_generator(): + names = [] for name, tensor in lora_state_dict.items(): tensor = Torch.to_local_tensor(tensor) + name = _normalize(name, keep_base_layer=True) + names.append(name) yield name, tensor + _print_weight_example(names) else: # First full base-model sync. Whether to keep ``.base_layer.`` @@ -1203,11 +1218,15 @@ def weight_generator(): state_dict = model.state_dict() def weight_generator(): + names = [] for name, tensor in state_dict.items(): if _is_lora_key(name): continue tensor = Torch.to_local_tensor(tensor) - yield _normalize(name, keep_base_layer=keep_base_layer), tensor + name = _normalize(name, keep_base_layer=keep_base_layer) + names.append(name) + yield name, tensor + _print_weight_example(names) # Run async send_weights in a dedicated event loop thread. # We cannot use the Ray worker's event loop because it may already From d4bcbdd2e15e3ce121ef83176a69e462dc6217a9 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 5 Apr 2026 12:21:47 +0800 Subject: [PATCH 4/5] lint --- cookbook/rl/grpo_mm.py | 6 +++++- cookbook/rl/short_math_grpo.py | 18 +++++++++++------- src/twinkle/model/megatron/megatron.py | 6 ++---- src/twinkle/model/transformers/transformers.py | 3 +-- src/twinkle/patch/megatron_peft.py | 2 +- 5 files changed, 20 insertions(+), 15 deletions(-) diff --git a/cookbook/rl/grpo_mm.py b/cookbook/rl/grpo_mm.py index 0dcef2c2..d6f934d5 100644 --- a/cookbook/rl/grpo_mm.py +++ b/cookbook/rl/grpo_mm.py @@ -138,7 +138,11 @@ def main(): # LoRA configuration lora_config = LoraConfig( - target_modules=['all-linear'], + target_modules=[ + 'q_proj', 'k_proj', 'v_proj', 'o_proj', + 'gate_proj', 'up_proj', 'down_proj', + 'in_proj_qkv', 'in_proj_z', 'in_proj_a', 'in_proj_b', 'out_proj', + ], r=16, lora_alpha=32, lora_dropout=0.05, diff --git a/cookbook/rl/short_math_grpo.py b/cookbook/rl/short_math_grpo.py index 3508b9a0..55939cbd 100644 --- a/cookbook/rl/short_math_grpo.py +++ b/cookbook/rl/short_math_grpo.py @@ -28,7 +28,7 @@ logger = get_logger() # ========== Configuration ========== -MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B') +MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B') USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1'))) MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4)) @@ -37,14 +37,14 @@ NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) -LEARNING_RATE = float(os.environ.get('LR', 1e-6)) +LEARNING_RATE = float(os.environ.get('LR', 1e-5)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4)) MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4)) MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1)) ADAPTER_NAME = 'default' -SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50)) +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 1000)) LORA_RANK = int(os.environ.get('LORA_RANK', 16)) SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning ' @@ -93,7 +93,7 @@ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: # ========== Dataset ========== def create_gsm8k_dataset(): dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) - dataset.set_template('Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete') + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete', enable_thinking=False) dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT)) dataset.encode(add_generation_prompt=True) return dataset @@ -123,7 +123,11 @@ def main(): twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False) lora_config = LoraConfig( - target_modules='all-linear', + target_modules=[ + 'q_proj', 'k_proj', 'v_proj', 'o_proj', + 'gate_proj', 'up_proj', 'down_proj', + 'in_proj_qkv', 'in_proj_z', 'in_proj_a', 'in_proj_b', 'out_proj', + ], r=LORA_RANK, lora_alpha=LORA_RANK * 2, lora_dropout=0.05, @@ -154,7 +158,7 @@ def main(): model.set_loss('GRPOLoss', epsilon=0.2) model.set_processor(InputProcessor) - model.set_template('Template', model_id=MODEL_ID) + model.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False) sampler = vLLMSampler( model_id=MODEL_ID, @@ -168,7 +172,7 @@ def main(): device_mesh=sampler_mesh, remote_group='sampler', ) - sampler.set_template('Template', model_id=MODEL_ID) + sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False) ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index ef63c898..9b485f55 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -34,9 +34,8 @@ from twinkle.patch import Patch, apply_patch from twinkle.processor import InputProcessor from twinkle.template import Template -from twinkle.utils import construct_class, selective_log_softmax +from twinkle.utils import construct_class, get_logger, selective_log_softmax from .strategy import MegatronStrategy -from twinkle.utils import get_logger logger = get_logger() @@ -1407,7 +1406,7 @@ def _normalize(name: str, keep_base_layer: bool) -> str: if not keep_base_layer: name = name.replace('.base_layer', '') return name - + def _print_weight_example(names): for name in names[:3]: logger.info(f'Sync weight: {name}') @@ -1452,7 +1451,6 @@ def weight_generator(): if 'lora' not in name: continue name = _normalize(name, keep_base_layer=True) - logger.info(f'Sync weight: {name}') names.append(name) yield name, tensor _print_weight_example(names) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index a7dc7a0b..e8f9bdda 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -37,10 +37,9 @@ from twinkle.patch import Patch, apply_patch from twinkle.processor import InputProcessor from twinkle.template import Template -from twinkle.utils import construct_class, selective_log_softmax, torch_util +from twinkle.utils import construct_class, get_logger, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm -from twinkle.utils import get_logger logger = get_logger() diff --git a/src/twinkle/patch/megatron_peft.py b/src/twinkle/patch/megatron_peft.py index 75727b93..7663b010 100644 --- a/src/twinkle/patch/megatron_peft.py +++ b/src/twinkle/patch/megatron_peft.py @@ -14,7 +14,7 @@ def __call__(self, *args, **kwargs): if MegatronPeft._peft_patched: return - + def _check_merge_allowed(*args, **kwargs): pass From 4158ccf65f97a35bee84fe51da229597623e5167 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Sun, 5 Apr 2026 12:22:22 +0800 Subject: [PATCH 5/5] fix --- cookbook/rl/grpo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo.py index 0df8dd73..bc864309 100644 --- a/cookbook/rl/grpo.py +++ b/cookbook/rl/grpo.py @@ -29,7 +29,7 @@ NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8)) MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096)) -LEARNING_RATE = float(os.environ.get('LR', 1e-6)) +LEARNING_RATE = float(os.environ.get('LR', 1e-5)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 200)) BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8)) # global completion-level mini-batch-size