Merged
Commits
47 commits
8fb9991
wip
tastelikefeet Mar 30, 2026
35efd81
wip
tastelikefeet Mar 30, 2026
01e7535
wip
tastelikefeet Mar 30, 2026
399960a
wip
tastelikefeet Mar 30, 2026
f2bd846
fix
tastelikefeet Mar 30, 2026
096c193
wip
tastelikefeet Mar 30, 2026
53c19a7
wip
tastelikefeet Mar 30, 2026
d951d47
wip
tastelikefeet Mar 31, 2026
ad82a77
wip
tastelikefeet Mar 31, 2026
48cbf13
wip
tastelikefeet Mar 31, 2026
43ef29e
wip
tastelikefeet Mar 31, 2026
c3c7620
fix
tastelikefeet Mar 31, 2026
dec91c9
lint code
tastelikefeet Mar 31, 2026
0a1c34c
fix
tastelikefeet Mar 31, 2026
13ec7f5
Merge commit 'a89ede55e3daa4fc36f0319c77847e0bf257fcce' into feat/mbr…
tastelikefeet Mar 31, 2026
ab0b161
fix
tastelikefeet Mar 31, 2026
d50465f
fix
tastelikefeet Mar 31, 2026
08d2daf
wip
tastelikefeet Apr 1, 2026
fa6b463
wip
tastelikefeet Apr 1, 2026
db71d2e
wip
tastelikefeet Apr 1, 2026
a218b9d
Merge commit 'a222914cae55cca628bf5154bf88ae037cebe7f7' into feat/mbr…
tastelikefeet Apr 1, 2026
6c23ca1
wip
tastelikefeet Apr 1, 2026
3862735
Merge branch 'feat/mbridge' of https://github.com/tastelikefeet/twink…
tastelikefeet Apr 1, 2026
ec2ac0b
wip
tastelikefeet Apr 1, 2026
dc1ee9c
wip
tastelikefeet Apr 1, 2026
f44fe4c
wip
tastelikefeet Apr 1, 2026
4555bfc
wip
tastelikefeet Apr 2, 2026
e9cf0da
wip
tastelikefeet Apr 2, 2026
30e1907
fix
tastelikefeet Apr 2, 2026
1e56c29
wip
tastelikefeet Apr 2, 2026
2a8895c
Merge branch 'feat/mbridge' of https://github.com/tastelikefeet/twink…
tastelikefeet Apr 2, 2026
682afab
wip
tastelikefeet Apr 2, 2026
9e1f823
wip
tastelikefeet Apr 3, 2026
d823f7e
lint code
tastelikefeet Apr 3, 2026
7c9d854
Merge branch 'main' into feat/mbridge
tastelikefeet Apr 3, 2026
9eb15b8
wip
tastelikefeet Apr 3, 2026
5f2d047
Merge branch 'feat/mbridge' of https://github.com/tastelikefeet/twink…
tastelikefeet Apr 3, 2026
ec64959
fix
tastelikefeet Apr 3, 2026
afa2bdf
wip
tastelikefeet Apr 3, 2026
8af8b82
wip
tastelikefeet Apr 3, 2026
9a819e7
wip
tastelikefeet Apr 3, 2026
29cb370
wip
tastelikefeet Apr 3, 2026
153afd8
fix
tastelikefeet Apr 3, 2026
347bd6b
wip
tastelikefeet Apr 3, 2026
ec6cf8c
Merge branch 'feat/mbridge' of https://github.com/tastelikefeet/twink…
tastelikefeet Apr 3, 2026
6488169
fix
tastelikefeet Apr 3, 2026
5752d45
fix
tastelikefeet Apr 3, 2026
6 changes: 3 additions & 3 deletions cookbook/megatron/tp.py
@@ -8,7 +8,7 @@
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.model import MegatronModel
 from twinkle.preprocessor import SelfCognitionProcessor
-# Construct a device_mesh, tp=pp=cp=2, dp=1
+# Construct a device_mesh, tp=pp=dp=2
 device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2)
 # use torchrun mode
 twinkle.initialize(mode='local', global_device_mesh=device_mesh)
@@ -19,7 +19,7 @@
 def eval(model):
     # 100 Samples
     dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
-    dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B')
+    dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
     dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
     dataset.encode()
     dataloader = DataLoader(dataset=dataset, batch_size=16)
@@ -33,7 +33,7 @@ def train():
     # 1000 samples
     dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
     # Set template to prepare encoding
-    dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B')
+    dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
     # Preprocess the dataset to standard format
     dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
     # Encode dataset
6 changes: 3 additions & 3 deletions cookbook/rl/grpo.py
@@ -40,7 +40,7 @@

 def create_gsm8k_dataset():
     dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
-    dataset.set_template('Template', model_id=MODEL_ID, max_length=400)
+    dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=400)
     dataset.map(GSM8KProcessor())
     dataset.encode(add_generation_prompt=True)
     return dataset
@@ -94,7 +94,7 @@ def main():
     model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)
     model.set_loss('GRPOLoss', epsilon=0.2)
     model.set_processor(InputProcessor)
-    model.set_template('Template', model_id=MODEL_ID)
+    model.set_template('Qwen3_5Template', model_id=MODEL_ID)

     sampler = vLLMSampler(
         model_id=MODEL_ID,
@@ -108,7 +108,7 @@
         device_mesh=sampler_mesh,
         remote_group='sampler',
     )
-    sampler.set_template(Template, model_id=MODEL_ID)
+    sampler.set_template('Qwen3_5Template', model_id=MODEL_ID)

     ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)
288 changes: 288 additions & 0 deletions cookbook/rl/grpo_mm.py
@@ -0,0 +1,288 @@
"""GRPO training script for OlympiadBench multimodal math/physics dataset.

Supports three subsets:
- OE_MM_maths_zh_CEE: Multimodal math problems (Chinese CEE)
- OE_MM_physics_zh_CEE: Multimodal physics problems (Chinese CEE)
- OE_TO_maths_zh_CEE: Text-only math problems (Chinese CEE)
"""
import os
from typing import List, Tuple, Dict, Any

from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import DatasetMeta, LazyDataset
from twinkle.metric import CompletionRewardMetric
from twinkle.model import TransformersModel
from twinkle.preprocessor.olympiad_bench import OlympiadBenchProcessor
from twinkle.reward.olympiad_bench import (
    OlympiadBenchAccuracyReward,
    OlympiadBenchFormatReward,
    OlympiadBenchQualityReward,
)
from twinkle.sampler import vLLMSampler

import swanlab
swanlab.init(
    project='twinkle',
)
logger = get_logger()

# Model configuration
MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))

# GPU configuration
MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 4))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

# Training hyperparameters
NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.environ.get('LR', 1e-5))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 4))
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4))
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1))
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
ADAPTER_NAME = 'default'
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50))

# Dataset configuration
SUBSETS = [
    'OE_MM_maths_zh_CEE',
    'OE_MM_physics_zh_CEE',
    'OE_TO_maths_zh_CEE',
]


def create_olympiad_dataset():
    """Create OlympiadBench dataset with all three subsets mixed."""
    # Create dataset with first subset
    ds = DatasetMeta(
        'ms://AI-ModelScope/OlympiadBench',
        subset_name=SUBSETS[0],
        split='train',
    )
    dataset = LazyDataset(ds)
    dataset.map(OlympiadBenchProcessor(language='zh'), dataset_meta=ds)

    # Add remaining subsets
    for subset in SUBSETS[1:]:
        ds = DatasetMeta(
            'ms://AI-ModelScope/OlympiadBench',
            subset_name=subset,
            split='train',
        )
        dataset.add_dataset(ds)
        dataset.map(OlympiadBenchProcessor(language='zh'), dataset_meta=ds)

    # Set template and preprocess
    dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=2048, enable_thinking=False)
    # Mix all datasets (interleave)
    dataset.mix_dataset(interleave=True)
    return dataset


def compute_rewards(
    trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], Dict[str, List[float]]]:
    """Compute rewards for trajectories.

    Three core rewards, all normalized to [0, 1]:
    - Accuracy: Answer correctness (weight: 2.0)
    - Format: Answer formatting and consistency (weight: 1.0)
    - Quality: Reasoning, length, repetition (weight: 1.0)

    Returns:
        total_rewards: Weighted sum normalized to [0, 1]
        reward_dict: Individual reward components for logging
    """
    accuracy_fn = OlympiadBenchAccuracyReward()
    format_fn = OlympiadBenchFormatReward()
    quality_fn = OlympiadBenchQualityReward()

    accuracy = accuracy_fn(trajectories)
    format_r = format_fn(trajectories)
    quality = quality_fn(trajectories)

    # Weights: accuracy most important, format and quality equal
    total_rewards = [
        (2.0 * a + 1.0 * f + 1.0 * q) / 4.0
        for a, f, q in zip(accuracy, format_r, quality)
    ]
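    # Worked example (illustrative only): accuracy=1.0, format=0.5, quality=0.5
    # gives (2.0*1.0 + 1.0*0.5 + 1.0*0.5) / 4.0 = 0.75.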

    return total_rewards, {
        'accuracy': accuracy,
        'format': format_r,
        'quality': quality,
    }


def main():
    # Device groups: model and sampler on separate GPUs
    device_groups = [
        DeviceGroup(name='model', ranks=MODEL_GPUS, device_type='GPU'),
        DeviceGroup(name='sampler', ranks=SAMPLER_GPUS, device_type='GPU'),
    ]

    model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
    sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)

    # LoRA configuration
    lora_config = LoraConfig(
        target_modules=['all-linear'],
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
    )
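    # Note: the adapter rank r=16 must not exceed max_lora_rank (32) passed to
    # the vLLM sampler below, or vLLM will refuse to load the LoRA weights.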

    # Model setup
    if USE_MEGATRON:
        from twinkle.model.megatron import MegatronModel
        model = MegatronModel(
            model_id=MODEL_ID,
            device_mesh=model_mesh,
            remote_group='model',
        )
    else:
        from transformers import Qwen3_5ForConditionalGeneration
        model = TransformersModel(
            model_id=MODEL_ID,
            model_cls=Qwen3_5ForConditionalGeneration,
            device_mesh=model_mesh,
            remote_group='model',
        )

    model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)

    if USE_MEGATRON:
        model.set_optimizer('default', lr=LEARNING_RATE, adapter_name=ADAPTER_NAME)
        model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE, adapter_name=ADAPTER_NAME)
    else:
        model.set_optimizer('AdamW', lr=LEARNING_RATE)
        model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)

    model.set_loss('GRPOLoss', epsilon=0.2, adapter_name=ADAPTER_NAME)
    model.set_template('Qwen3_5Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME, enable_thinking=False)

    # Sampler setup
    sampler = vLLMSampler(
        model_id=MODEL_ID,
        engine_args={
            'gpu_memory_utilization': 0.8,
            'max_model_len': 32000,
            'max_lora_rank': 32,
            'enable_lora': True,
            'limit_mm_per_prompt': {'image': 9},  # OlympiadBench has up to 9 images
        },
        device_mesh=sampler_mesh,
        remote_group='sampler',
    )
    sampler.set_template('Qwen3_5Template', model_id=MODEL_ID, enable_thinking=False)

    # Checkpoint manager
    ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)

    # DataLoader
    GLOBAL_BATCH_SIZE = BATCH_SIZE
    dataloader = DataLoader(
        dataset=create_olympiad_dataset,
        batch_size=GLOBAL_BATCH_SIZE,
        min_batch_size=GLOBAL_BATCH_SIZE,
        device_mesh=model_mesh,
    )

    # RL components
    advantage_fn = GRPOAdvantage()
    metrics = CompletionRewardMetric()

    sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS, num_samples=1, logprobs=1)
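    # logprobs=1 requests per-token logprobs from the sampler; they are reused
    # below as old_logps for the clipped importance ratio in GRPOLoss.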

    optim_step = 0
    logger.info(f'Starting OlympiadBench GRPO training on subsets: {SUBSETS}')
    logger.info(get_device_placement())

    for batch in dataloader:
        if optim_step >= MAX_STEPS:
            break

        metrics.reset()

        # Sync weights to sampler
        ckpt_manager.sync_weights(merge_and_sync=False)
        sampler.reset_prefix_cache()

        # Sample multiple completions per prompt
        sample_responses = sampler.sample(
            batch * NUM_GENERATIONS,
            sampling_params,
        )

        all_input_data: List[Dict[str, Any]] = []
        all_old_logps: List[List[float]] = []
        all_completion_lengths: List[int] = []

        for sample_response in sample_responses:
            for sequence in sample_response.sequences:
                all_input_data.append(sequence.new_input_feature)
                all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
                all_completion_lengths.append(len(sequence.tokens))
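        # Layout assumption: each entry of sequence.logprobs is a ranked list
        # of (token_id, logprob) pairs, so logprob[0][1] is the logprob of the
        # sampled token.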

        # Compute rewards
        total_rewards, reward_dict = compute_rewards(all_input_data)

        metrics.accumulate(
            completion_lengths=all_completion_lengths,
            rewards={
                'total': total_rewards,
                **reward_dict,
            },
        )

        # Compute advantages
        advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()
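        # scale='group' normalizes rewards within each group of NUM_GENERATIONS
        # completions of the same prompt, presumably the standard GRPO form:
        #   A_i = (r_i - mean(r_group)) / std(r_group)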

        # Mini-batch training
        total_completions = len(all_input_data)
        for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
            mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
            mb_inputs = all_input_data[mb_start:mb_end]
            mb_old_logps = all_old_logps[mb_start:mb_end]
            mb_advantages = advantages[mb_start:mb_end]

            model.forward_backward(
                inputs=mb_inputs,
                old_logps=mb_old_logps,
                advantages=mb_advantages,
                micro_batch_size=MICRO_BATCH_SIZE,
                adapter_name=ADAPTER_NAME,
            )
            model.clip_grad_and_step(adapter_name=ADAPTER_NAME)
            optim_step += 1

            if optim_step >= MAX_STEPS:
                break

            if optim_step % SAVE_STEPS == 0:
                model.save(f'olympiad-grpo-mixed-checkpoint-{optim_step}', adapter_name=ADAPTER_NAME)

        log_dict = metrics.calculate()
        log_dict.update(model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME))
        metrics.reset()
        logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')
        swanlab.log(log_dict)

    logger.info(f'Training completed. optim_steps={optim_step}')
    model.save('olympiad-grpo-mixed-final', adapter_name=ADAPTER_NAME)


if __name__ == '__main__':
    main()
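# Example launch (an assumption, not part of this PR): every knob above is read
# from environment variables, so an 8-GPU run might look like
#   MODEL_GPUS=4 SAMPLER_GPUS=4 USE_MEGATRON=0 python cookbook/rl/grpo_mm.py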