Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cookbook/rl/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.environ.get('LR', 1e-5))
LEARNING_RATE = float(os.environ.get('LR', 1e-6))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 8)) # global completion-level mini-batch-size
Expand Down
264 changes: 264 additions & 0 deletions cookbook/rl/short_math_grpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
"""GRPO training script for GSM8K dataset.

Converted from the Tinker client version to Ray-based training.
Uses short reasoning format: shorter thinking gets higher format reward.
Answer extracted from \\boxed{} or #### format.
"""
import os
import re
from typing import List, Tuple, Dict, Any

from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.metric import CompletionRewardMetric
from twinkle.model import TransformersModel
from twinkle.processor import InputProcessor
from twinkle.reward import GSM8KAccuracyReward
from twinkle.reward.base import Reward
from twinkle.sampler import vLLMSampler
from twinkle.preprocessor.llm import GSM8KProcessor

logger = get_logger()

# ========== Configuration ==========
# All knobs are overridable via environment variables; values shown are defaults.
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3-4B')  # 'ms://' = ModelScope hub
USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))  # 1 -> MegatronModel, 0 -> TransformersModel

# GPU partition: first MODEL_GPUS ranks train, the rest serve vLLM sampling.
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))  # GRPO group size per prompt
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.environ.get('LR', 1e-6))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))  # prompt-level batch size
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4))  # completion-level optimizer step size
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))  # per forward/backward call
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
ADAPTER_NAME = 'default'
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50))  # checkpoint frequency (in optimizer steps)
LORA_RANK = int(os.environ.get('LORA_RANK', 16))

SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning '
                 'and put your final answer within \\boxed{}.')

# NOTE(review): swanlab.init runs at import time, which side-effects any
# importer of this module — consider moving into main(); confirm intended.
import swanlab
swanlab.init(
    project='twinkle',
)
Comment on lines +54 to +56
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The swanlab.init call is placed at the module level. This can lead to unexpected side effects when the script is imported or used in different contexts. It is generally better to move initialization logic inside the main() function or a dedicated setup function.



# ========== Reward Functions ==========
class GSM8KBrevityReward(Reward):
    """Brevity reward: rewards shorter completions that contain a valid answer.

    Returns 0.0 if no valid answer format (\\boxed{} or ####).
    Otherwise returns higher score for shorter completions (1.0 at <=200 chars).
    """

    @staticmethod
    def _last_assistant_content(messages: List[Dict[str, Any]]) -> str:
        """Return the content of the most recent assistant message, or ''."""
        return next(
            (m.get('content', '') for m in reversed(messages) if m.get('role') == 'assistant'),
            '',
        )

    @staticmethod
    def _score(completion: str) -> float:
        """Score a single completion: 0.0 without a final answer, else length-decayed."""
        contains_answer = (
            re.search(r'\\boxed\{[^}]+\}', completion) is not None
            or re.search(r'####\s*[\-\d,\.]+', completion) is not None
        )
        if not contains_answer:
            return 0.0
        n_chars = len(completion)
        if n_chars <= 200:
            return 1.0
        # Linear decay past 200 chars; reaches 0.0 at 3200 chars.
        return max(0.0, 1.0 - (n_chars - 200) / 3000)

    def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
        """Return one brevity score per trajectory."""
        return [
            self._score(self._last_assistant_content(trajectory.get('messages', [])))
            for trajectory in trajectories
        ]


# ========== Dataset ==========
def create_gsm8k_dataset():
    """Build the encoded GSM8K train split.

    Passed uncalled as a dataset factory to DataLoader in main().
    """
    dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
    # Samples exceeding max_length are dropped outright ('delete'), not truncated.
    dataset.set_template('Template', model_id=MODEL_ID, max_length=4096, truncation_strategy='delete')
    # GSM8KProcessor converts raw rows into chat messages with SYSTEM_PROMPT prepended.
    dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT))
    # Pre-encode with the generation prompt appended, ready for the sampler.
    dataset.encode(add_generation_prompt=True)
    return dataset


def compute_rewards(
    trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], List[float], List[float]]:
    """Compute per-trajectory rewards for a sampled batch.

    Args:
        trajectories: One dict per sampled completion (the sampler's
            ``new_input_feature`` records, containing the chat messages).

    Returns:
        Tuple of parallel lists ``(total, brevity, accuracy)`` where
        ``total[i] == accuracy[i] + brevity[i]``.
    """
    # The reward callables are created once on first call and cached on the
    # function object — this runs every training iteration, so avoid
    # re-instantiating them each time (flagged in review).
    fns = getattr(compute_rewards, '_reward_fns', None)
    if fns is None:
        fns = (GSM8KAccuracyReward(), GSM8KBrevityReward())
        compute_rewards._reward_fns = fns
    accuracy_reward_fn, brevity_reward_fn = fns

    accuracy_rewards = accuracy_reward_fn(trajectories)
    brevity_rewards = brevity_reward_fn(trajectories)
    # Element-wise sum; GRPOAdvantage normalizes per generation group later.
    total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)]
    return total_rewards, brevity_rewards, accuracy_rewards


# ========== Main ==========
def main():
    """Run GRPO training of a LoRA adapter on GSM8K with accuracy + brevity rewards."""
    # Split the GPUs into two disjoint Ray groups: trainer vs. vLLM sampler.
    device_groups = [
        DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
        DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU'),
    ]

    # Pure data parallelism on both sides (dp_size == world_size).
    model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
    sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)

    lora_config = LoraConfig(
        target_modules='all-linear',
        r=LORA_RANK,
        lora_alpha=LORA_RANK * 2,
        lora_dropout=0.05,
    )

    if USE_MEGATRON:
        # Imported lazily so the Megatron stack is only required when selected.
        from twinkle.model.megatron import MegatronModel
        model = MegatronModel(
            model_id=MODEL_ID,
            device_mesh=model_mesh,
            remote_group='model',
            mixed_precision='bf16',
        )
    else:
        model = TransformersModel(
            model_id=MODEL_ID,
            device_mesh=model_mesh,
            remote_group='model',
        )

    # NOTE(review): gradient_accumulation_steps is hard-coded to 1 here even
    # though GRADIENT_ACCUMULATION_STEPS is configurable above — confirm intended.
    model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
    if USE_MEGATRON:
        # Megatron uses its own optimizer/scheduler names.
        model.set_optimizer('default', lr=LEARNING_RATE)
        model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE)
    else:
        model.set_optimizer('AdamW', lr=LEARNING_RATE)
        model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)

    # GRPO clipped-surrogate loss; epsilon is the PPO-style clip range.
    model.set_loss('GRPOLoss', epsilon=0.2)
    model.set_processor(InputProcessor)
    model.set_template('Template', model_id=MODEL_ID)

    sampler = vLLMSampler(
        model_id=MODEL_ID,
        engine_args={
            'gpu_memory_utilization': 0.8,
            'max_model_len': 8192,
            'max_lora_rank': 32,  # same as lora_config — must be >= LORA_RANK
            # NOTE: To use enable_lora with qwen3.5, ensure vLLM includes PR https://github.com/vllm-project/vllm/pull/36976
            'enable_lora': True,
        },
        device_mesh=sampler_mesh,
        remote_group='sampler',
    )
    sampler.set_template('Template', model_id=MODEL_ID)

    # Pushes trained weights from the model group into the sampler group each step.
    ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)

    GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
    dataloader = DataLoader(
        dataset=create_gsm8k_dataset,  # factory, instantiated by the loader
        batch_size=GLOBAL_BATCH_SIZE,
        min_batch_size=GLOBAL_BATCH_SIZE,
        device_mesh=model_mesh,
        remote_group='model',
    )

    advantage_fn = GRPOAdvantage()
    metrics = CompletionRewardMetric()
    # logprobs=1 so the sampler returns per-token logprobs needed as old_logps.
    sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1, temperature=1.0, top_p=0.95)

    optim_step = 0
    logger.info('Starting GSM8K GRPO training (short reasoning)')
    logger.info(get_device_placement())

    for batch in dataloader:
        if optim_step >= MAX_STEPS:
            break

        metrics.reset()
        # Duplicate each prompt NUM_GENERATIONS times to form GRPO groups.
        expand_prompts = []
        for prompt in batch:
            expand_prompts.extend([prompt] * NUM_GENERATIONS)

        # Sync latest adapter weights to the sampler before rollout.
        ckpt_manager.sync_weights(merge_and_sync=False)
        sampler.reset_prefix_cache()

        sample_responses = sampler.sample(
            expand_prompts,
            sampling_params,
        )

        all_input_data: List[Dict[str, Any]] = []
        all_old_logps: List[List[float]] = []
        all_completion_lengths: List[int] = []

        for sample_response in sample_responses:
            for sequence in sample_response.sequences:
                all_input_data.append(sequence.new_input_feature)
                # NOTE(review): assumes each sequence.logprobs entry is indexable
                # as [0][1] to yield the sampled token's logprob. Raw vLLM returns
                # a dict {token_id: Logprob} per position, for which this would
                # fail — confirm twinkle's sampler output format.
                all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
                all_completion_lengths.append(len(sequence.tokens))

        total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(all_input_data)

        metrics.accumulate(
            completion_lengths=all_completion_lengths,
            rewards={
                'total': total_rewards,
                'brevity': brevity_rewards,
                'accuracy': accuracy_rewards,
            },
        )

        # Group-normalized advantages (one group = NUM_GENERATIONS completions).
        advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()

        # One optimizer step per mini-batch of completions.
        total_completions = len(all_input_data)
        for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
            mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
            mb_inputs = all_input_data[mb_start:mb_end]
            mb_old_logps = all_old_logps[mb_start:mb_end]
            mb_advantages = advantages[mb_start:mb_end]

            model.forward_backward(
                inputs=mb_inputs,
                old_logps=mb_old_logps,
                advantages=mb_advantages,
                micro_batch_size=MICRO_BATCH_SIZE,
            )
            model.clip_grad_and_step()
            optim_step += 1

            # NOTE(review): breaking before the save check can skip a checkpoint
            # when MAX_STEPS coincides with a SAVE_STEPS boundary — confirm.
            if optim_step >= MAX_STEPS:
                break
            if optim_step % SAVE_STEPS == 0:
                model.save(f'math-grpo-checkpoint-{optim_step}')

        log_dict = metrics.calculate()
        log_dict.update(model.calculate_metric(is_training=True))
        swanlab.log(log_dict)
        metrics.reset()  # NOTE(review): also reset at loop top — likely redundant
        logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')

    logger.info(f'Training completed. optim_steps={optim_step}')
    model.save('math-grpo-final')


if __name__ == '__main__':
    main()
2 changes: 2 additions & 0 deletions src/twinkle/model/megatron/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,6 +1424,7 @@ def weight_generator():
# Skip LoRA-specific weights for base model sync
if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
continue
name = name.replace('base_model.model.', '')
yield name, tensor

else:
Expand All @@ -1446,6 +1447,7 @@ def _raw_weights():
# Skip LoRA-specific weights for base model sync
if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
continue
name = name.replace('base_model.model.', '')
yield name, tensor

def weight_generator():
Expand Down
5 changes: 5 additions & 0 deletions src/twinkle/patch/megatron_peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ def __call__(self, *args, **kwargs):

if MegatronPeft._peft_patched:
return

def _check_merge_allowed(*args, **kwargs):
pass

BaseTuner._check_merge_allowed = _check_merge_allowed

_origin_get_tied_target_modules = BaseTuner._get_tied_target_modules

Expand Down
Loading