diff --git a/cookbook/client/server/megatron/run.sh b/cookbook/client/server/megatron/run.sh
index c7db36d1..14966ce9 100644
--- a/cookbook/client/server/megatron/run.sh
+++ b/cookbook/client/server/megatron/run.sh
@@ -1,6 +1,341 @@
+#!/bin/bash
+
+# ============================================
+# Twinkle Megatron 服务启动脚本
+# ============================================
+# 功能:启动 Ray 集群(支持多 GPU/CPU 节点)、Prometheus 监控和 Twinkle 服务器
+#
+# 用法:./run.sh [选项]
+#
+# 选项:
+# --head NODE Head 节点 GPU 配置,格式 "设备列表:数量" (默认: 0,1,2,3:4)
+# --gpu-workers LIST GPU Worker 列表,分号分隔多个节点 (默认: 4,5,6,7:4)
+# --cpu-workers N CPU Worker 数量 (默认: 1)
+# --temp-dir DIR Ray 临时目录 (默认: /dashscope/caches/application/ray_logs)
+# --help 显示帮助信息
+#
+# 示例:
+# ./run.sh # 使用默认配置
+# ./run.sh --head "0,1,2,3" --gpu-workers "4,5,6,7" --cpu-workers 1
+# ./run.sh --head "0,1,2,3" --gpu-workers "" --cpu-workers 0
+# ./run.sh --head "" --cpu-workers 4 # 纯 CPU 模式
+# ./run.sh --temp-dir /tmp/my_ray_logs # 自定义临时目录
+# ============================================
+
+set -e # 遇到错误立即退出
+
+# ============================================
+# 配置区(根据你的环境修改)
+# ============================================
+
+# --- Ray 集群配置 ---
+# Head 节点(必须是第一个启动)
+# 格式:"GPU设备列表:GPU数量",如 "0,1,2,3:4"
+# 如果不需要 GPU,设为空字符串 ""
+# 可通过命令行参数 $1 传入
+
+# GPU Worker 节点列表(可以有多个)
+# 格式:用分号分隔的 "GPU设备列表:GPU数量"
+# 示例:"4,5,6,7:4" 或 "4,5,6,7:4;8,9,10,11:4"
+# 可通过命令行参数 $2 传入
+
+# CPU Worker 数量
+# 可通过命令行参数 $3 传入
+
+# --- 网络配置 ---
+RAY_PORT=6379
+RAY_ADDRESS="127.0.0.1:$RAY_PORT"
+
+# --- 路径配置 ---
+DEFAULT_TEMP_DIR="/dashscope/caches/application/ray_logs"
+LOG_FILE="run.log"
+
+# --- Prometheus 监控配置 ---
+PROMETHEUS_BIN="/dashscope/caches/application/monitor/prometheus-3.10.0.linux-amd64/prometheus"
+PROMETHEUS_CONFIG_SUFFIX="session_latest/metrics/prometheus/prometheus.yml"
+
+# --- Ray 日志轮转配置 ---
export RAY_ROTATION_MAX_BYTES=1024
export RAY_ROTATION_BACKUP_COUNT=1
-CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --port=6379 --num-gpus=4 --disable-usage-stats --include-dashboard=false
-CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=127.0.0.1:6379 --num-gpus=4
-CUDA_VISIBLE_DEVICES="" ray start --address=127.0.0.1:6379 --num-gpus=0
-python "$(dirname "$0")/server.py"
+
+# ============================================
+# 参数解析(支持 --key=value 或 --key value 格式)
+# ============================================
+
+# 默认值
+HEAD_NODE="0,1,2,3"
+GPU_WORKERS_INPUT="4,5,6,7"
+CPU_WORKER_COUNT="1"
+TEMP_DIR="$DEFAULT_TEMP_DIR"
+
+# 解析命名参数
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ --head)
+ HEAD_NODE="$2"
+ shift 2
+ ;;
+ --head=*)
+ HEAD_NODE="${1#*=}"
+ shift
+ ;;
+ --gpu-workers)
+ GPU_WORKERS_INPUT="$2"
+ shift 2
+ ;;
+ --gpu-workers=*)
+ GPU_WORKERS_INPUT="${1#*=}"
+ shift
+ ;;
+ --cpu-workers)
+ CPU_WORKER_COUNT="$2"
+ shift 2
+ ;;
+ --cpu-workers=*)
+ CPU_WORKER_COUNT="${1#*=}"
+ shift
+ ;;
+ --temp-dir)
+ TEMP_DIR="$2"
+ shift 2
+ ;;
+ --temp-dir=*)
+ TEMP_DIR="${1#*=}"
+ shift
+ ;;
+ --help|-h)
+ echo "用法: ./run.sh [选项]"
+ echo ""
+ echo "选项:"
+ echo " --head NODE Head 节点 GPU 设备列表,逗号分隔 (默认: 0,1,2,3)"
+ echo " --gpu-workers LIST GPU Worker 列表,分号分隔多个节点 (默认: 4,5,6,7)"
+ echo " --cpu-workers N CPU Worker 数量 (默认: 1)"
+ echo " --temp-dir DIR Ray 临时目录"
+ echo " --help, -h 显示帮助信息"
+ echo ""
+ echo "示例:"
+ echo " ./run.sh # 默认配置"
+ echo " ./run.sh --head '0,1,2,3' --gpu-workers '4,5,6,7'"
+ echo " ./run.sh --head '0,1,2,3,4,5,6,7' # 单机 8 卡"
+ echo " ./run.sh --gpu-workers '4,5,6,7;8,9,10,11' # 多 GPU Worker"
+ echo " ./run.sh --cpu-workers 4 --head '' # 纯 CPU 模式"
+ exit 0
+ ;;
+ *)
+ print_error "未知参数: $1"
+ echo "使用 --help 查看帮助"
+ exit 1
+ ;;
+ esac
+done
+
+# 将分号分隔的字符串转为数组
+if [ -z "$GPU_WORKERS_INPUT" ]; then
+ GPU_WORKERS=()
+else
+ IFS=';' read -ra GPU_WORKERS <<< "$GPU_WORKERS_INPUT"
+fi
+
+PROMETHEUS_CONFIG="${TEMP_DIR}/${PROMETHEUS_CONFIG_SUFFIX}"
+
+# ============================================
+# 辅助函数
+# ============================================
+print_info() {
+ echo -e "\033[36m[INFO]\033[0m $1"
+}
+
+print_success() {
+ echo -e "\033[32m[SUCCESS]\033[0m $1"
+}
+
+print_warning() {
+ echo -e "\033[33m[WARNING]\033[0m $1"
+}
+
+print_error() {
+ echo -e "\033[31m[ERROR]\033[0m $1"
+}
+
+print_separator() {
+ echo "============================================"
+}
+
+print_header() {
+ echo ""
+ print_separator
+ echo -e "\033[1;34m $1 \033[0m"
+ print_separator
+}
+
+# 解析节点配置 "devices" -> 返回 devices 和自动计算 _gpu_count
+# 示例: "0,1,2,3" -> devices="0,1,2,3", count=4
+parse_node_config() {
+ local config="$1"
+ if [ -z "$config" ]; then
+ _gpu_devices=""
+ _gpu_count=0
+ return
+ fi
+ _gpu_devices="$config"
+ # 通过逗号数量+1计算 GPU 数量
+ local comma_count=$(echo "$config" | tr -cd ',' | wc -c)
+ _gpu_count=$((comma_count + 1))
+}
+
+# ============================================
+# 开始启动
+# ============================================
+print_header "Twinkle Megatron 服务启动脚本"
+
+# 打印配置信息
+print_info "集群配置:"
+echo ""
+
+# 解析并显示 Head 节点
+parse_node_config "$HEAD_NODE"
+if [ -n "$_gpu_devices" ]; then
+ echo " [Head 节点]"
+ echo " - GPU 设备: $_gpu_devices"
+ echo " - GPU 数量: $_gpu_count"
+else
+ echo " [Head 节点] CPU only"
+fi
+
+# 显示 GPU Worker 节点
+if [ ${#GPU_WORKERS[@]} -gt 0 ]; then
+ echo ""
+ echo " [GPU Worker 节点] 共 ${#GPU_WORKERS[@]} 个"
+ for i in "${!GPU_WORKERS[@]}"; do
+ parse_node_config "${GPU_WORKERS[$i]}"
+ echo " Worker $((i+1)): GPU=$_gpu_devices, Count=$_gpu_count"
+ done
+fi
+
+# 显示 CPU Worker
+if [ "$CPU_WORKER_COUNT" -gt 0 ]; then
+ echo ""
+ echo " [CPU Worker 节点] $CPU_WORKER_COUNT 个"
+fi
+
+echo ""
+print_info "运行参数:"
+echo " - Ray 地址: $RAY_ADDRESS"
+echo " - 临时目录: $TEMP_DIR"
+echo " - 日志文件: $LOG_FILE"
+echo ""
+
+# 检查临时目录
+if [ ! -d "$TEMP_DIR" ]; then
+ print_info "创建临时目录: $TEMP_DIR"
+ mkdir -p "$TEMP_DIR"
+fi
+
+# ============================================
+# 停止已有 Ray 集群和 Prometheus
+# ============================================
+print_header "清理环境"
+print_info "停止已有的 Ray 集群..."
+ray stop --force 2>/dev/null || true
+
+print_info "停止已有的 Prometheus..."
+pkill prometheus 2>/dev/null || true
+
+# ============================================
+# 启动 Ray Head 节点
+# ============================================
+print_header "启动 Ray 集群"
+
+parse_node_config "$HEAD_NODE"
+if [ -n "$_gpu_devices" ]; then
+ print_info "启动 Head 节点 (GPU: $_gpu_devices)..."
+ CUDA_VISIBLE_DEVICES="$_gpu_devices" ray start --head \
+ --port=$RAY_PORT \
+ --num-gpus=$_gpu_count \
+ --disable-usage-stats \
+ --include-dashboard=true \
+ --temp-dir="$TEMP_DIR"
+else
+ print_info "启动 Head 节点 (CPU only)..."
+ CUDA_VISIBLE_DEVICES="" ray start --head \
+ --port=$RAY_PORT \
+ --num-gpus=0 \
+ --disable-usage-stats \
+ --include-dashboard=true \
+ --temp-dir="$TEMP_DIR"
+fi
+print_success "Head 节点启动成功!"
+
+# ============================================
+# 启动 GPU Worker 节点
+# ============================================
+for i in "${!GPU_WORKERS[@]}"; do
+ parse_node_config "${GPU_WORKERS[$i]}"
+ print_info "启动 GPU Worker $((i+1)) (GPU: $_gpu_devices)..."
+ CUDA_VISIBLE_DEVICES="$_gpu_devices" ray start \
+ --address=$RAY_ADDRESS \
+ --num-gpus=$_gpu_count
+ print_success "GPU Worker $((i+1)) 启动成功!"
+done
+
+# ============================================
+# 启动 CPU Worker 节点
+# ============================================
+if [ "$CPU_WORKER_COUNT" -gt 0 ]; then
+ print_info "启动 $CPU_WORKER_COUNT 个 CPU Worker..."
+ for ((i=1; i<=CPU_WORKER_COUNT; i++)); do
+ CUDA_VISIBLE_DEVICES="" ray start \
+ --address=$RAY_ADDRESS \
+ --num-gpus=0
+ done
+ print_success "CPU Worker 启动成功!"
+fi
+
+# ============================================
+# 显示集群状态
+# ============================================
+echo ""
+print_info "集群状态:"
+ray status 2>/dev/null || true
+
+# ============================================
+# 启动 Prometheus 监控(可选)
+# ============================================
+print_header "启动监控(可选)"
+
+PROMETHEUS_PID=""
+if [ -f "$PROMETHEUS_BIN" ]; then
+ print_info "检测到 Prometheus,正在启动监控服务..."
+
+ # 等待 Ray 生成 Prometheus 配置
+ sleep 2
+
+ if [ -f "$PROMETHEUS_CONFIG" ]; then
+ nohup "$PROMETHEUS_BIN" --config.file="$PROMETHEUS_CONFIG" > prometheus.log 2>&1 &
+ PROMETHEUS_PID=$!
+ print_success "Prometheus 监控已启动 (PID: $PROMETHEUS_PID)"
+ echo " - 监控日志: prometheus.log"
+ echo " - 配置文件: $PROMETHEUS_CONFIG"
+ else
+ print_warning "Prometheus 配置文件不存在,跳过监控启动"
+ echo " - 预期路径: $PROMETHEUS_CONFIG"
+ fi
+else
+ print_warning "未检测到 Prometheus,跳过监控启动"
+ echo " - 预期路径: $PROMETHEUS_BIN"
+fi
+
+# ============================================
+# 启动 Twinkle 服务器
+# ============================================
+print_header "启动 Twinkle 服务器"
+
+print_info "日志输出到: $LOG_FILE"
+echo ""
+
+# 启动服务器并实时显示日志
+nohup python server.py > "$LOG_FILE" 2>&1 &
+SERVER_PID=$!
+
+# 实时显示日志
+tail -f "$LOG_FILE"
diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml
index 0f66dd57..6d584455 100644
--- a/cookbook/client/server/megatron/server_config.yaml
+++ b/cookbook/client/server/megatron/server_config.yaml
@@ -87,7 +87,7 @@ applications:
nproc_per_node: 4 # Number of GPU processes per node
device_group:
name: model
- ranks: 4 # GPU rank indices
+ ranks: 4
device_type: cuda
device_mesh:
device_type: cuda
diff --git a/cookbook/client/server/megatron/server_config_4b.yaml b/cookbook/client/server/megatron/server_config_4b.yaml
index e191b981..5dd8a696 100644
--- a/cookbook/client/server/megatron/server_config_4b.yaml
+++ b/cookbook/client/server/megatron/server_config_4b.yaml
@@ -7,7 +7,7 @@ proxy_location: EveryNode
# HTTP listener settings
http_options:
host: 0.0.0.0 # Listen on all network interfaces
- port: 8000 # Port number for the server
+ port: 9000 # Port number for the server
# Applications: each entry defines a service component deployed on the server
applications:
@@ -39,25 +39,24 @@ applications:
import_path: model
args:
use_megatron: true
- model_cls: Qwen3_5ForConditionalGeneration
model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier
max_length: 10240
nproc_per_node: 2 # Number of GPU processes per node
device_group:
name: model
- ranks: 2 # GPU rank indices
+ ranks: 2
device_type: cuda
device_mesh:
device_type: cuda
dp_size: 2
queue_config:
rps_limit: 100 # Max requests per second
- tps_limit: 10000 # Max tokens per second for a single user
- max_input_tokens: 10000 # Maximum input tokens per request
+ tps_limit: 100000 # Max tokens per second for a single user
+ max_input_tokens: 60000 # Maximum input tokens per request
adapter_config:
adapter_timeout: 30 # Seconds before idle adapter unload
adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours)
- max_loras: 1 # Maximum number of LoRA adapters per model
+ max_loras: 5 # Maximum number of LoRA adapters per model
deployments:
- name: ModelManagement
autoscaling_config:
@@ -80,8 +79,8 @@ applications:
nproc_per_node: 2 # Number of GPU processes per node
sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
engine_args: # vLLM engine-specific settings
- max_model_len: 4096 # Maximum sequence length the engine supports
- gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0)
+ max_model_len: 16000 # Maximum sequence length the engine supports
+ gpu_memory_utilization: 0.7 # Fraction of GPU memory to use (0.0-1.0)
enable_lora: true # Allow loading LoRA adapters during inference
logprobs_mode: processed_logprobs # Logprobs mode for sampling results
device_group: # Logical device group for the sampler
diff --git a/cookbook/client/tinker/self_host/dpo.py b/cookbook/client/tinker/self_host/dpo.py
new file mode 100644
index 00000000..d55e9ce3
--- /dev/null
+++ b/cookbook/client/tinker/self_host/dpo.py
@@ -0,0 +1,207 @@
+# Tinker-Compatible Client - DPO (Direct Preference Optimization) Training with LoRA
+#
+# This script demonstrates how to fine-tune a language model using DPO
+# through the Tinker-compatible client API.
+#
+# Training flow per step:
+# 1. forward_backward with 'cross_entropy' + disable_lora=True
+# → base-model forward pass; LoRA weights are NOT in the computation graph
+# so backward accumulates zero LoRA gradients (safe to discard).
+# 2. Attach returned per-token ref logps to each datum's loss_fn_inputs.
+# 3. forward_backward with 'importance_sampling'
+# → server detects ref_logps and switches to DPOLoss + DPOMetric.
+# 4. optim_step → update LoRA, DPO metrics returned automatically.
+#
+# The server must be running first (see server.py and server_config.yaml).
+
+import os
+import numpy as np
+import torch
+from tqdm import tqdm
+from typing import Any, Dict, List
+
+import swanlab
+
+from tinker import types
+from twinkle import init_tinker_client, get_logger
+from twinkle.dataset import Dataset, DatasetMeta, LazyDataset
+from twinkle.dataloader import DataLoader
+from twinkle.preprocessor import EmojiDPOProcessor
+from twinkle.server.common import input_feature_to_datum
+
+logger = get_logger()
+
+# Initialize the Tinker client before importing ServiceClient
+init_tinker_client()
+
+from tinker import ServiceClient # noqa: E402 (must follow init_tinker_client)
+
+# ---------------------------------------------------------------------------
+# Configuration
+# ---------------------------------------------------------------------------
+base_model = 'Qwen/Qwen3.5-4B'
+base_url = 'http://localhost:8000'
+api_key = 'EMPTY_API_KEY'
+dataset_id = 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji'
+
+batch_size = 4
+learning_rate = 1e-4
+dpo_beta = 0.1
+sft_weight = 1.0
+max_length = 2048
+lora_rank = 8
+system_prompt = 'You are a helpful assistant.'
+use_swanlab = True
+
+
+# ---------------------------------------------------------------------------
+# Dataset helpers (reused from twinkle/self_host/dpo.py)
+# ---------------------------------------------------------------------------
+
+def create_dpo_dataset():
+ """Create DPO dataset with positive/negative format."""
+ dataset = LazyDataset(DatasetMeta(dataset_id, data_slice=range(5000)))
+ dataset.set_template('Qwen3_5Template', model_id=f'ms://{base_model}', max_length=max_length)
+ dataset.map(
+ EmojiDPOProcessor,
+ init_args={'system': system_prompt},
+ )
+ # EmojiDPOProcessor returns {'positive': InputFeature, 'negative': InputFeature, ...}
+ # encode handles this format automatically
+ dataset.encode()
+ return dataset
+
+
+def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """Reorganise batch into DP-safe interleaved format [pos_1, neg_1, pos_2, neg_2, ...].
+
+ Args:
+ batch: List of rows, each with 'positive' and 'negative' InputFeatures.
+
+ Returns:
+ Interleaved list so each DP worker slice contains complete pairs.
+ """
+ result = []
+ for row in batch:
+ base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')}
+ pos_sample = {**base_fields, **row['positive']}
+ neg_sample = {**base_fields, **row['negative']}
+ result.append(pos_sample)
+ result.append(neg_sample)
+ return result
+
+
+# ---------------------------------------------------------------------------
+# Training
+# ---------------------------------------------------------------------------
+
+def train():
+ # Step 0: Initialize SwanLab if enabled
+ if use_swanlab:
+ swanlab.login(api_key=os.environ['SWANLAB_API_KEY'])
+ swanlab.init(
+ project='twinkle-dpo',
+ experiment_name='dpo-lora-training',
+ config={
+ 'base_model': base_model,
+ 'batch_size': batch_size,
+ 'learning_rate': learning_rate,
+ 'dpo_beta': dpo_beta,
+ 'sft_weight': sft_weight,
+ 'max_length': max_length,
+ 'lora_rank': lora_rank,
+ },
+ )
+ logger.info('SwanLab initialized')
+
+ # Step 1: Prepare dataset & dataloader
+ logger.info('Loading DPO dataset...')
+ dataset = create_dpo_dataset()
+ dataloader = DataLoader(dataset=dataset, batch_size=batch_size)
+ logger.info(f'Dataset ready: {len(dataloader)} steps per epoch')
+
+ # Step 2: Connect to server and create LoRA training client
+ service_client = ServiceClient(base_url=base_url, api_key=api_key)
+ training_client = service_client.create_lora_training_client(
+ base_model=base_model,
+ rank=lora_rank,
+ )
+ logger.info(f'LoRA training client created (rank={lora_rank})')
+ logger.info(f'Starting DPO training: beta={dpo_beta}, lr={learning_rate}')
+
+ # Step 3: Training loop
+ for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
+ # Normalise numpy / torch tensors to plain Python lists for serialisation
+ for row in batch:
+ for key in list(row.keys()):
+ if isinstance(row[key], np.ndarray):
+ row[key] = row[key].tolist()
+ elif isinstance(row[key], torch.Tensor):
+ row[key] = row[key].cpu().numpy().tolist()
+
+ # Build interleaved [pos, neg, pos, neg, ...] batch
+ dpo_batch = prepare_dpo_batch(batch)
+
+ # Convert each InputFeature dict to a Tinker Datum
+ input_datums = [input_feature_to_datum(row) for row in dpo_batch]
+
+ # -----------------------------------------------------------------
+ # A. Reference forward pass (base model, disable_lora=True)
+ # LoRA weights are outside the computation graph → backward
+ # produces zero LoRA gradients, so this call is safe.
+ # -----------------------------------------------------------------
+ ref_result = training_client.forward(
+ input_datums,
+ 'cross_entropy',
+ loss_fn_config={'disable_lora': True},
+ ).result()
+
+ # -----------------------------------------------------------------
+ # B. Attach per-token ref logps to each datum's loss_fn_inputs
+ # -----------------------------------------------------------------
+ for datum, ref_out in zip(input_datums, ref_result.loss_fn_outputs):
+ ref_logprobs_np = np.array(ref_out['logprobs'].tolist(), dtype=np.float32)
+ datum.loss_fn_inputs['ref_logps'] = types.TensorData.from_numpy(ref_logprobs_np)
+
+ # -----------------------------------------------------------------
+ # C. DPO forward_backward
+ # Server detects ref_logps → sets DPOLoss + DPOMetric automatically.
+ # Optional DPO hyper-params can be forwarded via loss_fn_config.
+ # (e.g. beta, sft_weight, not support dpo_loss_type for tinker)
+ # -----------------------------------------------------------------
+ fwdbwd_result = training_client.forward_backward(
+ input_datums,
+ 'importance_sampling',
+ loss_fn_config={
+ 'dpo_beta': dpo_beta,
+ 'dpo_sft_weight': sft_weight,
+ },
+ ).result()
+
+ # -----------------------------------------------------------------
+ # D. Optimizer step — DPOMetric is calculated automatically on the
+ # server and returned inside optim_result.metrics.
+ # -----------------------------------------------------------------
+ optim_result = training_client.optim_step(
+ types.AdamParams(learning_rate=learning_rate)
+ ).result()
+
+ logger.info(f'[Step {step}] metrics={optim_result.metrics}')
+
+ # Log metrics to SwanLab
+ if use_swanlab and optim_result.metrics:
+ swanlab.log(optim_result.metrics, step=step)
+
+ # Step 4: Save checkpoint
+ save_result = training_client.save_state('dpo-lora-final').result()
+ logger.info(f'Saved checkpoint: {save_result.path}')
+
+ # Step 5: (Optional) Upload to ModelScope Hub
+ # YOUR_USER_NAME = 'your_username'
+ # hub_model_id = f'{YOUR_USER_NAME}/twinkle-tinker-dpo-lora'
+ # training_client.publish_checkpoint_from_tinker_path(save_result.path).result()
+ # logger.info(f'Uploaded checkpoint to hub: {hub_model_id}')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/client/tinker/self_host/sample.py b/cookbook/client/tinker/self_host/sample.py
index a1f21f00..5a84c75b 100644
--- a/cookbook/client/tinker/self_host/sample.py
+++ b/cookbook/client/tinker/self_host/sample.py
@@ -8,7 +8,7 @@
from tinker import types
from twinkle.data_format import Message, Trajectory
-from twinkle.template import Template
+from twinkle.template import Template, Qwen3_5Template
from twinkle import init_tinker_client
# Step 1: Initialize Tinker client
@@ -20,21 +20,21 @@
base_model = 'Qwen/Qwen3.5-4B'
service_client = ServiceClient(
base_url='http://localhost:8000',
- api_key='EMPTY-TOKEN'
+ api_key='EMPTY_TOKEN'
)
# Step 3: Create a sampling client by loading weights from a saved checkpoint.
# The model_path is a twinkle:// URI pointing to a previously saved LoRA checkpoint.
# The server will load the base model and apply the LoRA adapter weights.
sampling_client = service_client.create_sampling_client(
- model_path='twinkle://xxx-Qwen_Qwen3.5-4B-xxx/weights/twinkle-lora-1',
+ model_path='twinkle://20260409_202355-Qwen_Qwen3_5-4B-d0360fdb/sampler_weights/20260409_202639',
base_model=base_model
)
# Step 4: Load the tokenizer locally to encode the prompt and decode the results
print(f'Using model {base_model}')
-template = Template(model_id=f'ms://{base_model}')
+template = Qwen3_5Template(model_id=f'ms://{base_model}')
trajectory = Trajectory(
messages=[
@@ -43,7 +43,7 @@
]
)
-input_feature = template.encode(trajectory, add_generation_prompt=True)
+input_feature = template.batch_encode([trajectory], add_generation_prompt=True)[0]
input_ids = input_feature['input_ids'].tolist()
diff --git a/cookbook/client/tinker/self_host/short_math_grpo.py b/cookbook/client/tinker/self_host/short_math_grpo.py
index f6fe8b45..f077c669 100644
--- a/cookbook/client/tinker/self_host/short_math_grpo.py
+++ b/cookbook/client/tinker/self_host/short_math_grpo.py
@@ -1,12 +1,12 @@
-# Tinker-Compatible Client - Math GRPO Training Example
+# Tinker-Compatible Client - GSM8K GRPO Training Example
#
-# This script demonstrates Math problem training using the
+# This script demonstrates GSM8K math problem training using the
# Tinker-compatible client API with save_weights_for_sampler for weight sync.
# Instead of calling sync_weights directly, it periodically saves weights and
# creates a sampling client for generation.
#
# Flow:
-# 1. Prepare Math dataset (client-side)
+# 1. Prepare GSM8K dataset (client-side)
# 2. Initialize Tinker-compatible training & sampling clients
# 3. Training loop:
# a. Every SYNC_INTERVAL steps: save_weights_for_sampler → sampling_client
@@ -22,15 +22,15 @@
import os
import re
from tinker import types
-from typing import List, Tuple
+from typing import List, Tuple, Dict, Any
from twinkle import init_tinker_client
from twinkle import get_logger
from twinkle.advantage import GRPOAdvantage
-from twinkle.data_format import Message, Trajectory
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
-from twinkle.preprocessor import Preprocessor
+from twinkle.preprocessor.llm import GSM8KProcessor
+from twinkle.reward import GSM8KAccuracyReward
from twinkle.reward.base import Reward
from twinkle.metric import CompletionRewardMetric
from twinkle.template import Template
@@ -39,173 +39,84 @@
# ========== Configuration ==========
BASE_MODEL = 'Qwen/Qwen3.5-4B'
-NUM_GENERATIONS = 8
+NUM_GENERATIONS = 4
MAX_NEW_TOKENS = 4096
-LEARNING_RATE = 1e-4
+LEARNING_RATE = 1e-5
MAX_STEPS = 1000
-BATCH_SIZE = 2
+BATCH_SIZE = 4
TEMPERATURE = 1.0
SYNC_INTERVAL = 1 # Save weights for sampler every N steps
-LORA_RANK = 8
+LORA_RANK = 16
DATA_NUM = 2000 # Number of Math samples to use
-SYSTEM_PROMPT = ('You are a math assistant that values brevity. '
- 'Solve problems with minimal but correct reasoning.\n\n'
- 'Rules:\n'
- '1. Use tags for reasoning\n'
- '2. Final answer after ####\n\n'
- 'Example:\nKey step1 -> Ket step 2 -> conclusion\n#### 42')
+SYSTEM_PROMPT = ('You are a helpful math assistant. Solve the problem with minimal but correct reasoning '
+ 'and put your final answer within \\boxed{}.')
+# ========== Reward Functions ==========
+class GSM8KBrevityReward(Reward):
+ """Brevity reward: rewards shorter completions that contain a valid answer.
-class MathPreprocessor(Preprocessor):
-
- def __call__(self, rows):
- rows = self.map_col_to_row(rows)
- rows = [self.preprocess(row) for row in rows]
- rows = self.map_row_to_col(rows)
- return rows
-
- def preprocess(self, sample):
- if sample['level'] not in ('Level 4', 'Level 5'):
- return Trajectory(messages=[], user_data=[])
-
- def get_boxed_answer(text):
- match = re.search(r'\\boxed{([^}]*)}', text)
- return match.group(1) if match else None
-
- ground_truth = get_boxed_answer(sample['solution'])
- if ground_truth is None:
- return Trajectory(messages=[], user_data=[])
- problem = sample['problem']
- return Trajectory(
- messages=[
- Message(role='system', content=SYSTEM_PROMPT),
- Message(role='user', content=problem),
- ],
- user_data=[('ground_truth', ground_truth)],
- )
-
-
-# ========== Math Reward Functions ==========
-class MathAccuracyReward(Reward):
- """Accuracy reward for Math: checks if the model's answer matches ground truth.
-
- Extracts the last '#### ' from model output and compares with ground truth.
- Returns 1.0 for correct, 0.0 for incorrect.
- """
-
- @staticmethod
- def extract_answer(completion: str) -> str:
- """Extract the last #### answer from model completion."""
- # Only check last 500 chars for efficiency
- text = completion[-500:] if len(completion) > 500 else completion
- matches = re.findall(r'####\s*([\-\d,\.\s]+)', text)
- if matches:
- return matches[-1].replace(',', '').replace(' ', '').strip()
- return ''
-
- def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]:
- rewards = []
- for trajectory in trajectories:
- messages = trajectory.get('messages', [])
- # Get model completion (last assistant message)
- completion = ''
- for msg in reversed(messages):
- if msg.get('role') == 'assistant':
- completion = msg.get('content', '')
- break
-
- # Get ground truth from user_data
- gt = ''
- user_data = trajectory.get('user_data', [])
- if isinstance(user_data, list):
- for item in user_data:
- if isinstance(item, (list, tuple)) and len(item) == 2:
- if item[0] == 'ground_truth':
- gt = str(item[1])
- break
-
- predicted = self.extract_answer(completion)
-
- # Numeric comparison
- correct = False
- if predicted and gt:
- try:
- correct = abs(float(predicted) - float(gt)) < 1e-5
- except (ValueError, OverflowError):
- correct = predicted == gt
-
- rewards.append(1.0 if correct else 0.0)
- return rewards
-
-
-class MathFormatReward(Reward):
- """Format reward: checks format and rewards shorter completions.
-
- Returns higher score for shorter completions (1.0 at length 100 or less).
- Returns 0.0 if format is incorrect.
+ Returns 0.0 if no valid answer format (\\boxed{} or ####).
+ Otherwise returns higher score for shorter completions (1.0 at <=200 chars).
"""
- def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]:
+ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
rewards = []
- for trajectory in trajectories:
- messages = trajectory.get('messages', [])
+ for traj in trajectories:
+ messages = traj.get('messages', [])
completion = ''
for msg in reversed(messages):
if msg.get('role') == 'assistant':
completion = msg.get('content', '')
break
- has_think = bool(re.search(r'.*?', completion, re.DOTALL))
- has_answer = bool(re.search(r'####\s*[\-\d,\.]+', completion))
+ has_answer = bool(
+ re.search(r'\\boxed\{[^}]+\}', completion)
+ or re.search(r'####\s*[\-\d,\.]+', completion)
+ )
- if not (has_think and has_answer):
+ if not has_answer:
rewards.append(0.0)
else:
length = len(completion)
- if length <= 100:
+ if length <= 200:
rewards.append(1.0)
else:
- reward = max(0.0, 1.0 - (length - 100) / 2000)
- rewards.append(reward)
-
+ rewards.append(max(0.0, 1.0 - (length - 200) / 3000))
return rewards
-def create_math_dataset():
- """Create Math dataset."""
- meta = DatasetMeta(
- 'ms://modelscope/competition_math',
- subset_name='default',
- split='train',
- data_slice=range(DATA_NUM),
- )
- dataset = Dataset(meta)
- dataset.set_template('Qwen3_5Template', model_id=BASE_MODEL, max_length=4096, truncation_strategy='delete')
- dataset.map(MathPreprocessor())
- dataset.filter(lambda row: bool(row['messages']))
+# ========== Dataset ==========
+def create_gsm8k_dataset():
+ """Create GSM8K dataset."""
+ dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train', data_slice=range(DATA_NUM)))
+ dataset.set_template('Qwen3_5Template', model_id=f'ms://{BASE_MODEL}', max_length=4096,
+ truncation_strategy='delete', enable_thinking=False)
+ dataset.map(GSM8KProcessor(system=SYSTEM_PROMPT))
dataset.encode(add_generation_prompt=True)
return dataset
-def compute_rewards(trajectories: List[Trajectory], ) -> Tuple[List[float], List[float], List[float]]:
- """Compute accuracy and format rewards for Math."""
- accuracy_reward_fn = MathAccuracyReward()
- format_reward_fn = MathFormatReward()
+def compute_rewards(
+ trajectories: List[Dict[str, Any]],
+) -> Tuple[List[float], List[float], List[float]]:
+ """Compute accuracy and brevity rewards for GSM8K."""
+ accuracy_reward_fn = GSM8KAccuracyReward()
+ brevity_reward_fn = GSM8KBrevityReward()
- accuracy_rewards = accuracy_reward_fn(trajectories, [])
- format_rewards = format_reward_fn(trajectories, [])
- total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
- return total_rewards, format_rewards, accuracy_rewards
+ accuracy_rewards = accuracy_reward_fn(trajectories)
+ brevity_rewards = brevity_reward_fn(trajectories)
+ total_rewards = [a + b for a, b in zip(accuracy_rewards, brevity_rewards)]
+ return total_rewards, brevity_rewards, accuracy_rewards
def main():
- logger.info('Starting Math GRPO training...')
+ logger.info('Starting GSM8K GRPO training...')
# Step 1: Prepare dataset and dataloader (client-side)
- dataset = create_math_dataset()
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
+ dataset = create_gsm8k_dataset()
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0)
template = Template(model_id=f'ms://{BASE_MODEL}')
logger.info('Dataset and template initialized')
@@ -254,7 +165,7 @@ def main():
if step % SYNC_INTERVAL == 0:
logger.info(f'Step {step}: Saving weights for sampler...')
- sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'Math-step-{step}'))
+ sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'GSM8K-step-{step}'))
logger.info(f'Step {step}: Sampling client ready')
if sampling_client is None:
@@ -303,8 +214,8 @@ def main():
},
{
'role': 'user',
- 'content': 'Math problem'
- }, # Placeholder
+ 'content': 'Math problem' # Placeholder
+ },
{
'role': 'assistant',
'content': decoded_text
@@ -317,14 +228,12 @@ def main():
completion_lengths.append(len(seq.tokens))
# ========== 4. Compute rewards ==========
- total_rewards, format_rewards, accuracy_rewards = compute_rewards(trajectories)
+ total_rewards, brevity_rewards, accuracy_rewards = compute_rewards(trajectories)
metrics.accumulate(
- None,
- None,
completion_lengths=completion_lengths,
rewards={
'total': total_rewards,
- 'format': format_rewards,
+ 'brevity': brevity_rewards,
'accuracy': accuracy_rewards,
})
@@ -407,7 +316,7 @@ def main():
step += 1
# Save final checkpoint
- save_future = training_client.save_state('Math-grpo-final')
+ save_future = training_client.save_state('gsm8k-grpo-final')
save_result = save_future.result()
logger.info(f'Saved final checkpoint to {save_result.path}')
diff --git a/cookbook/client/twinkle/self_host/dpo.py b/cookbook/client/twinkle/self_host/dpo.py
new file mode 100644
index 00000000..acec9a09
--- /dev/null
+++ b/cookbook/client/twinkle/self_host/dpo.py
@@ -0,0 +1,207 @@
+# Twinkle Client - DPO (Direct Preference Optimization) Training with LoRA
+#
+# This script demonstrates how to fine-tune a language model using DPO
+# through the Twinkle client-server architecture.
+# The server must be running first (see server.py and server_config.yaml).
+
+# Step 1: Load environment variables from a .env file (e.g., API tokens)
+import dotenv
+import os
+from typing import Any, Dict, List
+
+dotenv.load_dotenv('.env')
+import numpy as np
+import torch
+from peft import LoraConfig
+
+from twinkle import get_logger
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle_client import init_twinkle_client
+from twinkle.dataloader import DataLoader
+from twinkle_client.model import MultiLoraTransformersModel
+from twinkle.loss import DPOLoss
+from twinkle.metric import DPOMetric
+from twinkle.preprocessor import EmojiDPOProcessor
+from twinkle.processor import InputProcessor
+
+logger = get_logger()
+
+# Configuration (direct values, not from env)
+base_model = 'Qwen/Qwen3.5-4B'
+base_url = 'http://localhost:8000'
+dataset_id = 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji'
+
+batch_size = 4
+gradient_accumulation_steps = 2
+learning_rate = 1e-4
+dpo_beta = 0.1
+sft_weight = 1.0
+loss_type = 'sigmoid'
+max_length = 2048
+adapter_name = 'default'
+system_prompt = 'You are a helpful assistant.'
+
+# Step 2: Initialize the Twinkle client to communicate with the remote server.
+# - base_url: the address of the running Twinkle server
+# - api_key: authentication token (loaded from environment variable)
+client = init_twinkle_client(base_url=base_url, api_key=os.environ.get('MODELSCOPE_TOKEN'))
+
+# Step 3: Query the server for existing training runs and their checkpoints.
+# This is useful for resuming a previous training session.
+runs = client.list_training_runs()
+
+resume_path = None
+for run in runs:
+ logger.info(run.model_dump_json(indent=2))
+ # List all saved checkpoints for this training run
+ checkpoints = client.list_checkpoints(run.training_run_id)
+
+ for checkpoint in checkpoints:
+ logger.info(checkpoint.model_dump_json(indent=2))
+ # Uncomment the line below to resume from a specific checkpoint:
+ # resume_path = checkpoint.twinkle_path
+
+
+def create_dpo_dataset():
+ """Create DPO dataset with positive/negative format."""
+ dataset = Dataset(DatasetMeta(dataset_id, data_slice=range(100)))
+ dataset.set_template('Qwen3_5Template', model_id=f'ms://{base_model}', max_length=max_length)
+ dataset.map(
+ EmojiDPOProcessor,
+ init_args={
+ 'system': system_prompt,
+ }
+ )
+ # DPO preprocessor returns {'positive': [...], 'negative': [...]}
+ # batch_encode handles this format automatically
+ dataset.encode()
+ return dataset
+
+
+def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """Prepare DPO batch: reorganize batch for training with DP-safe interleaving.
+
+ Args:
+ batch: List of rows, each with 'positive' and 'negative' InputFeatures
+ and other fields (question, etc.)
+
+ Returns:
+ List interleaved as [pos_1, neg_1, pos_2, neg_2, ...] to ensure each DP
+ worker gets complete positive/negative pairs after slicing.
+ Each item contains all original fields plus the InputFeature fields.
+ """
+ result = []
+
+ for row in batch:
+ # Get base fields (excluding positive/negative)
+ base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')}
+
+ # Positive sample: merge base fields with positive InputFeature
+ pos_sample = {**base_fields, **row['positive']}
+ # Negative sample: merge base fields with negative InputFeature
+ neg_sample = {**base_fields, **row['negative']}
+
+ # Interleave: [pos, neg] per pair for DP-safe slicing
+ result.append(pos_sample)
+ result.append(neg_sample)
+
+ return result
+
+
+def train():
+ # Step 4: Prepare the dataset
+
+ # Load the DPO dataset from ModelScope
+ dataset = create_dpo_dataset()
+
+ # Wrap the dataset into a DataLoader that yields batches
+ dataloader = DataLoader(dataset=dataset, batch_size=batch_size)
+
+ # Step 5: Configure the model
+
+ # Create a multi-LoRA Transformers model pointing to the base model on ModelScope
+ model = MultiLoraTransformersModel(model_id=f'ms://{base_model}')
+
+ # Define LoRA configuration: apply low-rank adapters to all linear layers
+ lora_config = LoraConfig(
+ target_modules='all-linear',
+ r=8,
+ lora_alpha=32,
+ lora_dropout=0.05,
+ )
+
+ # Attach the LoRA adapter named 'default' to the model.
+ # gradient_accumulation_steps means gradients are accumulated over micro-batches
+ # before an optimizer step, effectively increasing the batch size.
+ model.add_adapter_to_model(adapter_name, lora_config, gradient_accumulation_steps=gradient_accumulation_steps)
+
+ # Set the same chat template used during data preprocessing
+ model.set_template('Qwen3_5Template')
+
+ # Set the input processor (pads sequences on the right side)
+ model.set_processor('InputProcessor', padding_side='right')
+
+ # Use DPO loss for preference optimization
+ model.set_loss('DPOLoss', beta=dpo_beta, loss_type=loss_type, reference_free=False, sft_weight=sft_weight)
+
+ # Add DPO metric for logging
+ model.add_metric('DPOMetric', beta=dpo_beta)
+
+ # Use the Adam optimizer with the configured learning rate (see `learning_rate` above)
+ model.set_optimizer('Adam', lr=learning_rate)
+
+ # Step 6: Optionally resume from a previous checkpoint
+ if resume_path:
+ logger.info(f'Resuming training from {resume_path}')
+ model.load(resume_path, load_optimizer=True)
+
+ # Step 7: Run the training loop
+ logger.info(model.get_train_configs().model_dump())
+
+ optim_step = 0
+ max_steps = len(dataloader)
+ logger.info(f'Starting LoRA DPO training: loss_type={loss_type}, beta={dpo_beta}, lr={learning_rate}')
+ logger.info(f'Using base model (disable_lora=True) as reference model')
+
+ for batch in dataloader:
+ # batch is List[Dict] with 'positive' and 'negative' keys
+ # Convert numpy/torch tensors to lists for serialization
+ for row in batch:
+ for key in row:
+ if isinstance(row[key], np.ndarray):
+ row[key] = row[key].tolist()
+ elif isinstance(row[key], torch.Tensor):
+ row[key] = row[key].cpu().numpy().tolist()
+
+ dpo_batch = prepare_dpo_batch(batch)
+
+ # Get reference outputs using base model (without LoRA adapter)
+ # disable_lora=True tells the model to skip LoRA and use base weights
+ ref_outputs = model.forward_only(inputs=dpo_batch, disable_lora=True)
+ model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs.result)
+ model.clip_grad_and_step()
+
+ optim_step += 1
+
+ # Logging
+ if optim_step % gradient_accumulation_steps == 0:
+ metrics = model.calculate_metric(is_training=True)
+ logger.info(f'[Step {optim_step // gradient_accumulation_steps}/{max_steps}] {metrics}')
+
+ # Step 8: Save the trained checkpoint
+ twinkle_path = model.save(name='dpo-lora-final', save_optimizer=True)
+ logger.info(f'Saved checkpoint: {twinkle_path}')
+
+ # Step 9: Upload the checkpoint to ModelScope Hub
+ # YOUR_USER_NAME = "your_username"
+ # hub_model_id = f'{YOUR_USER_NAME}/twinkle-dpo-lora'
+ # model.upload_to_hub(
+ # checkpoint_dir=twinkle_path,
+ # hub_model_id=hub_model_id,
+ # async_upload=False
+ # )
+ # logger.info(f"Uploaded checkpoint to hub: {hub_model_id}")
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/rl/dpo_multi_lora.py b/cookbook/rl/dpo_multi_lora.py
new file mode 100644
index 00000000..7c09bf61
--- /dev/null
+++ b/cookbook/rl/dpo_multi_lora.py
@@ -0,0 +1,216 @@
+"""DPO (Direct Preference Optimization) Training with MultiLoRA (Megatron Backend).
+
+MultiLoRA-based DPO training: uses the base model (without LoRA adapter) as reference
+model by calling forward_only with disable_lora=True. This eliminates the need for
+a separate reference model GPU group.
+
+Uses Megatron backend with MultiLoRAMegatronModel for efficient multi-tenant LoRA training.
+
+Pipeline:
+ 1. Load preference dataset with chosen/rejected pairs.
+ 2. Encode positive and negative separately.
+ 3. Compute reference model log probabilities using base model (disable_lora=True).
+ 4. Train policy model (with LoRA adapter) using DPO loss.
+
+Architecture (Ray - Single Group):
+ ┌─────────────────────────────────────────────────────────────────┐
+ │ Driver (CPU) │
+ │ dataloader ──► batched preference pairs │
+ │ policy_model.forward_only(disable_lora=True) ──► ref logps │
+ │ policy_model.forward_backward() ──► DPO loss + gradient │
+ └─────────────────────────────────────────────────────────────────┘
+ │
+ PolicyModel (with LoRA adapter)
+ - forward_only(disable_lora=True) → base model inference (reference)
+ - forward_backward() → LoRA adapter training (policy)
+
+DPO data format (after preprocessing):
+ - positive: List[Trajectory] - chosen responses
+ - negative: List[Trajectory] - rejected responses
+
+Environment variables (all optional):
+ MODEL_ID – (default: ms://Qwen/Qwen3.5-4B)
+ DATASET_ID – (default: ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji)
+ MODEL_GPUS – GPUs for policy model (default: 2)
+ BATCH_SIZE – global batch size (preference pairs) (default: 8)
+ MAX_STEPS – total optimization steps (default: 1000)
+ LR – learning rate (default: 1e-4)
+ DPO_BETA – DPO temperature parameter (default: 0.1)
+ LOSS_TYPE – DPO variant (sigmoid/hinge/ipo) (default: sigmoid)
+ SAVE_STEPS – checkpoint save interval (default: 100)
+ MAX_LENGTH – max sequence length (default: 2048)
+"""
+
+import os
+from typing import Any, Dict, List, Optional
+
+from peft import LoraConfig
+
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
+from twinkle.data_format import Trajectory
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.loss import DPOLoss
+from twinkle.metric import DPOMetric
+from twinkle.preprocessor import EmojiDPOProcessor
+from twinkle.processor import InputProcessor
+
+logger = get_logger()
+
+# ── Configuration ─────────────────────────────────────────────────────────────
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
+DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')
+
+MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2))
+
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 8)) # Number of preference pairs
+GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 2))
+LEARNING_RATE = float(os.environ.get('LR', 1e-4)) # LoRA DPO requires higher LR (1e-4 to 3e-4)
+DPO_BETA = float(os.environ.get('DPO_BETA', 0.1))
+SFT_WEIGHT = float(os.environ.get('SFT_WEIGHT', 1.0)) # SFT loss weight for regularization
+LOSS_TYPE = os.environ.get('LOSS_TYPE', 'sigmoid') # sigmoid, hinge, ipo
+SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 100))
+MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 2048))
+ADAPTER_NAME = 'default_0'
+SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', 'You are a helpful assistant.')
+
+
+def create_dpo_dataset():
+ """Create DPO dataset with positive/negative format."""
+ dataset = Dataset(DatasetMeta(DATASET_ID, data_slice=range(50)))
+ dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
+ dataset.map(
+ EmojiDPOProcessor,
+ init_args={
+ 'system': SYSTEM_PROMPT,
+ }
+ )
+ # DPO preprocessor returns {'positive': [...], 'negative': [...]}
+ # batch_encode handles this format automatically
+ dataset.encode(load_from_cache_file=True)
+ return dataset
+
+
+def prepare_dpo_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """Prepare DPO batch: reorganize batch for training with DP-safe interleaving.
+
+ Args:
+ batch: List of rows, each with 'positive' and 'negative' InputFeatures
+ and other fields (question, etc.)
+
+ Returns:
+ List interleaved as [pos_1, neg_1, pos_2, neg_2, ...] to ensure each DP
+ worker gets complete positive/negative pairs after slicing.
+ Each item contains all original fields plus the InputFeature fields.
+ """
+ result = []
+
+ for row in batch:
+ # Get base fields (excluding positive/negative)
+ base_fields = {k: v for k, v in row.items() if k not in ('positive', 'negative')}
+
+ # Positive sample: merge base fields with positive InputFeature
+ pos_sample = {**base_fields, **row['positive']}
+ # Negative sample: merge base fields with negative InputFeature
+ neg_sample = {**base_fields, **row['negative']}
+
+ # Interleave: [pos, neg] per pair for DP-safe slicing
+ result.append(pos_sample)
+ result.append(neg_sample)
+
+ return result
+
+
+# ── Main Training Loop ────────────────────────────────────────────────────────
+
+def main():
+ # Set up device groups - only one group for LoRA training
+ device_groups = [
+ DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
+ ]
+
+ # Configure device mesh for MultiLoRA Megatron: dp=2, pp=1
+ from twinkle.model import MultiLoraMegatronModel
+ policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=2, pp_size=1)
+ ModelClass = MultiLoraMegatronModel
+
+ twinkle.initialize(mode='ray', nproc_per_node=MODEL_GPUS, groups=device_groups)
+
+ # ── DataLoader Setup ──────────────────────────────────────────────────────
+ dataloader = DataLoader(
+ dataset=create_dpo_dataset,
+ batch_size=BATCH_SIZE,
+ min_batch_size=BATCH_SIZE,
+ device_mesh=policy_mesh,
+ )
+
+ # ── Policy Model Setup with LoRA ──────────────────────────────────────────
+ lora_config = LoraConfig(
+ target_modules='all-linear',
+ r=8,
+ lora_alpha=32,
+ lora_dropout=0.05,
+ )
+
+ policy_model = ModelClass(
+ model_id=MODEL_ID,
+ device_mesh=policy_mesh,
+ remote_group='policy',
+ )
+ MAX_STEPS = len(dataloader)
+ policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
+
+ # Configure optimizer based on backend
+ policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME)
+ policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME)
+
+
+ # Set up loss function and metrics
+ loss_fn = DPOLoss(
+ beta=DPO_BETA,
+ loss_type=LOSS_TYPE,
+ reference_free=False, # We use base model as reference via disable_lora=True
+ sft_weight=SFT_WEIGHT,
+ )
+
+ policy_model.set_loss(loss_fn, adapter_name=ADAPTER_NAME)
+ policy_model.add_metric(DPOMetric, beta=DPO_BETA, adapter_name=ADAPTER_NAME)
+ policy_model.set_processor(InputProcessor, adapter_name=ADAPTER_NAME)
+ policy_model.set_template('Qwen3_5Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME)
+
+ optim_step = 0
+ backend_name = 'MultiLoRA Megatron'
+ logger.info(get_device_placement())
+ logger.info(f'Starting MultiLoRA DPO training ({backend_name}): loss_type={LOSS_TYPE}, beta={DPO_BETA}, lr={LEARNING_RATE}')
+ logger.info(f'Using base model (disable_lora=True) as reference model')
+
+ # ── Training Loop ─────────────────────────────────────────────────────────
+ for batch in dataloader:
+ # batch is List[Dict] with 'positive' and 'negative' keys
+ dpo_batch = prepare_dpo_batch(batch)
+
+ # Get reference outputs using base model (without LoRA adapter)
+ # disable_lora=True tells the model to skip LoRA and use base weights
+ ref_outputs = policy_model.forward_only(inputs=dpo_batch, disable_lora=True, adapter_name=ADAPTER_NAME)
+ policy_model.forward_backward(inputs=dpo_batch, ref_outputs=ref_outputs, adapter_name=ADAPTER_NAME)
+ policy_model.clip_grad_and_step(adapter_name=ADAPTER_NAME)
+
+ optim_step += 1
+
+ # Logging
+ if optim_step % GRADIENT_ACCUMULATION_STEPS == 0:
+ metrics = policy_model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME)
+ logger.info(f'[Step {optim_step // GRADIENT_ACCUMULATION_STEPS}/{MAX_STEPS}] {metrics}')
+
+ # Checkpointing
+ if optim_step % SAVE_STEPS == 0:
+ policy_model.save(f'dpo-lora-checkpoint-{optim_step}', adapter_name=ADAPTER_NAME)
+
+ # ── Save Final Checkpoint ─────────────────────────────────────────────────
+ logger.info(f'Training completed. Total steps: {optim_step}')
+ policy_model.save('dpo-lora-final-checkpoint', adapter_name=ADAPTER_NAME)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/src/twinkle/loss/dpo.py b/src/twinkle/loss/dpo.py
index 44d81f82..d50837a0 100644
--- a/src/twinkle/loss/dpo.py
+++ b/src/twinkle/loss/dpo.py
@@ -310,6 +310,8 @@ def __call__(
reference_rejected_logps = ref_rejected_logps.to(device=device, dtype=dtype)
elif ref_logps is not None:
# Per-token reference log probs provided, need to align and sum
+ if not torch.is_tensor(ref_logps):
+ ref_logps = torch.as_tensor(ref_logps)
ref_logps_aligned = self._align_logps(ref_logps, labels.shape, device, dtype)
ref_chosen, ref_rejected = self._split_chosen_rejected(ref_logps_aligned)
reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels)
diff --git a/src/twinkle/metric/dpo.py b/src/twinkle/metric/dpo.py
index 5ce61410..8a1d4d6c 100644
--- a/src/twinkle/metric/dpo.py
+++ b/src/twinkle/metric/dpo.py
@@ -50,6 +50,9 @@ def _align_logps(self, logps, target_shape, device, dtype):
Aligned tensor with shape matching target_shape
"""
import torch
+
+ if not torch.is_tensor(logps):
+ logps = torch.as_tensor(logps)
logps = logps.to(device=device, dtype=dtype)
batch_size, src_len = logps.shape
_, target_len = target_shape
diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
index c134e41c..4cdbba84 100644
--- a/src/twinkle/model/megatron/megatron.py
+++ b/src/twinkle/model/megatron/megatron.py
@@ -1169,11 +1169,14 @@ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=Non
# Save config on rank 0 only
if dp_rank == 0:
self.hf_config.save_pretrained(output_dir)
+ if isinstance(model[0], PeftModel):
+ model[0].peft_config[adapter_name].save_pretrained(output_dir)
def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_converter=None):
"""Save in Megatron checkpoint format."""
os.makedirs(output_dir, exist_ok=True)
-
+ from megatron.core import parallel_state as mpu
+ dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0
state_dict = self._get_trainable_parameters(adapter_name)
cpu_state_dict = {}
for k, v in state_dict.items():
@@ -1189,6 +1192,12 @@ def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_convert
rank = dist.get_rank() if dist.is_initialized() else 0
checkpoint_path = os.path.join(output_dir, f'model_rank{rank}.pt')
torch.save(cpu_state_dict, checkpoint_path)
+ # Save config on rank 0 only
+ model = self.strategy.unwrap_model(self.model)
+ if dp_rank == 0:
+ self.hf_config.save_pretrained(output_dir)
+ if isinstance(model[0], PeftModel):
+ model[0].peft_config[adapter_name].save_pretrained(output_dir)
def _save_tokenizer(self, output_dir: str, **kwargs):
from twinkle.utils import is_last_rank
diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py
index 77cb330e..346b9f86 100644
--- a/src/twinkle/model/megatron/multi_lora_megatron.py
+++ b/src/twinkle/model/megatron/multi_lora_megatron.py
@@ -2,6 +2,7 @@
import os
import torch.distributed as dist
import torch.nn as nn
+from functools import partial
from peft import LoraConfig
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@@ -184,12 +185,12 @@ def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs):
with self.multi_adapter.save_context(kwargs.get('adapter_name')) as real_adapter_name:
save_format = kwargs.pop('save_format', 'hf') # 'hf' or 'megatron'
+ # Use partial to bind adapter_name to save_lora_converter
+ lora_converter = partial(self.multi_adapter.save_lora_converter, adapter_name=real_adapter_name)
if save_format == 'hf':
- self._save_hf_format(
- checkpoint_dir, real_adapter_name, lora_converter=self.multi_adapter.save_lora_converter)
+ self._save_hf_format(checkpoint_dir, real_adapter_name, lora_converter=lora_converter)
else:
- self._save_megatron_format(
- checkpoint_dir, real_adapter_name, lora_converter=self.multi_adapter.save_lora_converter)
+ self._save_megatron_format(checkpoint_dir, real_adapter_name, lora_converter=lora_converter)
self._save_tokenizer(checkpoint_dir, adapter_name=kwargs.get('adapter_name'))
# Final synchronization to ensure all ranks complete save
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
index 91aee481..4c3bc6de 100644
--- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
+++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
@@ -235,14 +235,15 @@ async def _sample_single(
"""
multi_modal_data = self._extract_multi_modal_data(feat)
response = await self.engine.sample(
- prompt=feat['prompt'],
+ prompt=feat['prompt'] if 'prompt' in feat else feat['input_ids'],
sampling_params=sampling_params,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=feat.get('mm_processor_kwargs'),
)
- feat['input_ids'] = response.prompt_token_ids
- feat['labels'] = [-100] * len(response.prompt_token_ids)
+ if 'input_ids' not in feat:
+ feat['input_ids'] = response.prompt_token_ids
+ feat['labels'] = [-100] * len(response.prompt_token_ids)
if not logprobs_only:
# response.sequences contains num_samples sequences for this prompt
sequences = []
@@ -318,7 +319,7 @@ def sample(
inputs_list = self._normalize_inputs(inputs)
# Check if inputs are Trajectory (not encoded) - aligned with Model.forward logic
- is_trajectory = 'prompt' not in inputs_list[0] or 'input_ids' not in inputs_list[0]
+ is_trajectory = 'prompt' not in inputs_list[0] and 'input_ids' not in inputs_list[0]
logprobs_only = False
if sampling_params.max_tokens == 0:
sampling_params.max_tokens = 1
diff --git a/src/twinkle/server/common/datum.py b/src/twinkle/server/common/datum.py
index e78f20fb..1cb1510e 100644
--- a/src/twinkle/server/common/datum.py
+++ b/src/twinkle/server/common/datum.py
@@ -71,6 +71,11 @@ def extract_rl_feature(datum: types.Datum | list[types.Datum]) -> dict:
if 'advantages' in d.loss_fn_inputs:
advantages = d.loss_fn_inputs['advantages'].to_numpy().tolist()
result['advantages'].append(advantages)
+
+ # Forward per-token reference log-probs from loss_fn_inputs (consumed by DPO loss)
+ if 'ref_logps' in d.loss_fn_inputs:
+ ref_logps = d.loss_fn_inputs['ref_logps'].to_numpy().tolist()
+ result['ref_logps'].append(ref_logps)
return result
diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py
index 79a90349..84910baf 100644
--- a/src/twinkle/server/gateway/server.py
+++ b/src/twinkle/server/gateway/server.py
@@ -13,6 +13,7 @@
from typing import Any
import twinkle_client.types as types
+from twinkle.server.utils.metrics import create_metrics_middleware
from twinkle.server.utils.state import get_server_state
from twinkle.server.utils.validation import verify_request_token
from twinkle.utils.logger import get_logger
@@ -93,6 +94,8 @@ def build_server_app(deploy_options: dict[str, Any],
async def verify_token(request: Request, call_next):
return await verify_request_token(request=request, call_next=call_next)
+ app.middleware('http')(create_metrics_middleware('Gateway'))
+
def get_self() -> GatewayServer:
return serve.get_replica_context().servable_object
diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py
index 9271b681..9a2c2f27 100644
--- a/src/twinkle/server/gateway/twinkle_gateway_handlers.py
+++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py
@@ -125,3 +125,17 @@ async def get_checkpoint_path(request: Request, run_id: str, checkpoint_id: str)
ckpt_dir = checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id)
return types.CheckpointPathResponse(path=str(ckpt_dir), twinkle_path=checkpoint.twinkle_path)
+
+ @app.get('/twinkle/status')
+ async def status(
+ request: Request,
+ self: GatewayServer = Depends(self_fn),
+ ) -> dict:
+ cleanup_stats = await self.state.get_cleanup_stats()
+ return {
+ 'resources': cleanup_stats['resource_counts'],
+ 'cleanup': {
+ 'running': cleanup_stats['cleanup_running'],
+ 'expiration_timeout': cleanup_stats['expiration_timeout'],
+ },
+ }
diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py
index cf084606..9d28e52b 100644
--- a/src/twinkle/server/launcher.py
+++ b/src/twinkle/server/launcher.py
@@ -128,7 +128,12 @@ def _init_ray(self) -> None:
# Use runtime_env to apply patches in worker processes
# This is required because Ray Serve's ProxyActor runs in separate processes
runtime_env = get_runtime_env_for_patches()
- ray.init(namespace=namespace, runtime_env=runtime_env)
+ # Connect to existing cluster if available, otherwise start local instance
+ ray.init(
+ address='auto',
+ namespace=namespace,
+ runtime_env=runtime_env,
+ )
logger.info(f'Ray initialized with namespace={namespace}')
self._ray_initialized = True
diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py
index 41351811..037a3f31 100644
--- a/src/twinkle/server/model/app.py
+++ b/src/twinkle/server/model/app.py
@@ -15,6 +15,7 @@
import twinkle
from twinkle import DeviceGroup, DeviceMesh
from twinkle.server.utils.lifecycle import AdapterManagerMixin
+from twinkle.server.utils.metrics import create_metrics_middleware
from twinkle.server.utils.state import ServerStateProxy, get_server_state
from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
from twinkle.server.utils.validation import get_token_from_request, verify_request_token
@@ -81,7 +82,7 @@ def __init__(self,
self._replica_registered = False
# Initialize mixins
- self._init_task_queue(TaskQueueConfig.from_dict(queue_config))
+ self._init_task_queue(TaskQueueConfig.from_dict(queue_config), deployment_name='Model')
self._init_adapter_manager(**adapter_config)
# Note: countdown task is started lazily in _ensure_sticky()
@@ -164,6 +165,8 @@ def build_model_app(model_id: str,
async def verify_token(request: Request, call_next):
return await verify_request_token(request=request, call_next=call_next)
+ app.middleware('http')(create_metrics_middleware('Model'))
+
def get_self() -> ModelManagement:
return serve.get_replica_context().servable_object
diff --git a/src/twinkle/server/model/backends/common.py b/src/twinkle/server/model/backends/common.py
index cd0ca21d..2cc1e091 100644
--- a/src/twinkle/server/model/backends/common.py
+++ b/src/twinkle/server/model/backends/common.py
@@ -128,6 +128,91 @@ class TwinkleCompatModelBase:
def get_template(self, adapter_name: str) -> Template:
return self.optimizer_group[adapter_name].template
+ def _tinker_setup_loss(self, loss_fn: str, inputs, adapter_name: str, kwargs: dict):
+ """Set up loss function based on loss_fn; pops DPO/GRPO-specific params from kwargs in-place."""
+ if loss_fn == 'cross_entropy':
+ self.set_loss('CrossEntropyLoss', adapter_name=adapter_name)
+ elif loss_fn == 'importance_sampling':
+ has_ref_logps = any('ref_logps' in d.loss_fn_inputs for d in inputs)
+ if has_ref_logps:
+ beta = kwargs.pop('dpo_beta', 0.1)
+ loss_type = kwargs.pop('dpo_loss_type', 'sigmoid')
+ sft_weight = kwargs.pop('dpo_sft_weight', 0.0)
+ self.set_loss(
+ 'DPOLoss', adapter_name=adapter_name, beta=beta, loss_type=loss_type, sft_weight=sft_weight)
+ # Only add DPOMetric if not already present for this adapter
+ self._ensure_dpo_metric(adapter_name, beta)
+ else:
+ epsilon = kwargs.pop('epsilon', 0.2)
+ grpo_beta = kwargs.pop('beta', 0.0)
+ self.set_loss('GRPOLoss', adapter_name=adapter_name, epsilon=epsilon, beta=grpo_beta)
+ else:
+ self.set_loss('CrossEntropyLoss', adapter_name=adapter_name)
+
+ def _ensure_dpo_metric(self, adapter_name: str, beta: float):
+ """Add DPOMetric for the adapter if not already present.
+
+ This prevents duplicate metric accumulation across training steps.
+ """
+ from twinkle.metric.dpo import DPOMetric
+ optimizer_config = self.optimizer_group[adapter_name]
+ # Check if DPOMetric already exists in training metrics
+ for metric in optimizer_config.train_status.metrics:
+ if isinstance(metric, DPOMetric):
+ return
+ self.add_metric('DPOMetric', adapter_name=adapter_name, beta=beta)
+
+ def _tinker_build_output(self, inputs, outputs):
+ """Extract logits/logps from model outputs and build per-datum output list."""
+ logits = outputs.get('logits')
+ if logits is not None:
+ logits = self._normalize_tensor_output(logits)
+ logps = outputs.get('logps', None)
+ if logps is not None:
+ logps = self._normalize_tensor_output(logps)
+ return self._get_forward_output(inputs, logits, logps)
+
+ @staticmethod
+ def _normalize_tensor_output(value):
+ """Normalize various output formats (tensor, list of tensors, nested lists, floats) to a single tensor.
+
+ Handles:
+ - torch.Tensor: detach and move to cpu
+ - list of torch.Tensor: cat along dim=0
+ - nested lists: recursively flatten and cat
+ - list of floats/int: convert to tensor
+ """
+ if value is None:
+ return None
+
+ if isinstance(value, torch.Tensor):
+ return value.detach().cpu()
+
+ if isinstance(value, list):
+ return torch.as_tensor(value, dtype=torch.float32).detach().cpu()
+
+ if isinstance(value, (int, float)):
+ return torch.tensor([value], dtype=torch.float32)
+
+ raise ValueError(f'Unexpected type for tensor output: {type(value)}')
+
+ @staticmethod
+ def _tinker_prepare_ref_outputs(loss_values: dict, loss_kwargs: dict):
+ """Convert ref_logps list-of-lists into a padded tensor and inject into loss_kwargs.
+
+ Returns the ref_outputs dict (or None if ref_logps not present), so callers
+ can optionally propagate it to train_status.forward_kwargs.
+ """
+ if 'ref_logps' not in loss_values:
+ return None
+ import torch.nn.functional as F
+ ref_logps_lists = loss_values.pop('ref_logps')
+ max_len = max(len(r) for r in ref_logps_lists)
+ padded = [F.pad(torch.tensor(r, dtype=torch.float32), (0, max_len - len(r))) for r in ref_logps_lists]
+ ref_outputs_dict = {'logps': torch.stack(padded)}
+ loss_kwargs['ref_outputs'] = ref_outputs_dict
+ return ref_outputs_dict
+
@staticmethod
def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: torch.Tensor) -> List[dict]:
"""Convert raw logits to the expected output format with logprobs and elementwise_loss."""
diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py
index 831b9468..55cc4e72 100644
--- a/src/twinkle/server/model/backends/megatron_model.py
+++ b/src/twinkle/server/model/backends/megatron_model.py
@@ -8,6 +8,7 @@
from twinkle import remote_class, remote_function
from twinkle.data_format import InputFeature, Trajectory
+from twinkle.infra import collect_tensor_dict
from twinkle.model.megatron import MultiLoraMegatronModel
from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature
from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics,
@@ -24,55 +25,33 @@ class TwinkleCompatMegatronModel(MultiLoraMegatronModel, TwinkleCompatModelBase)
@remote_function(dispatch='slice_dp', collect=collect_forward_backward_results, sync=True)
def tinker_forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs):
"""Combined forward and backward pass."""
- if loss_fn == 'importance_sampling':
- super().set_loss('GRPOLoss', adapter_name=adapter_name, epsilon=0.2, beta=0.0)
+ self._tinker_setup_loss(loss_fn, inputs, adapter_name, kwargs)
template = self.get_template(adapter_name=adapter_name)
input_features = datum_to_input_feature(inputs, template)
loss_values = extract_rl_feature(inputs)
loss_kwargs = kwargs.copy()
+ # ref_logps → padded tensor; megatron forward_backward auto-stores loss_kwargs in
+ # train_status.forward_kwargs (megatron.py:465), so DPOMetric reads it next step.
+ self._tinker_prepare_ref_outputs(loss_values, loss_kwargs)
loss_kwargs.update(loss_values)
+
outputs = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs)
loss = outputs.get('loss', None)
- logits_list = outputs.get('logits', [])
- logps = outputs.get('logps', [])
- if logits_list is None and logps is None:
- return [None, None]
-
- logits = None
- if logits_list is not None:
- if isinstance(logits_list, torch.Tensor):
- logits = logits_list.detach()
- else:
- logits = torch.cat([logit.detach() for logit in logits_list], dim=0)
- logps = logps.detach().cpu()
- results = self._get_forward_output(inputs, logits, logps)
-
if isinstance(loss, torch.Tensor):
loss = loss.item()
else:
- loss = float(loss)
-
+ loss = float(loss) if loss is not None else 0.0
+ results = self._tinker_build_output(inputs, outputs)
return [results, loss]
- @remote_function(dispatch='slice_dp', collect='flatten')
- def tinker_forward_only(self, *, inputs: List[types.Datum], **kwargs):
+ @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results)
+ def tinker_forward_only(self, *, inputs: List[types.Datum], adapter_name: str = None, **kwargs):
"""Forward pass without gradient computation."""
- template = self.get_template(**kwargs)
+ template = self.get_template(adapter_name)
input_features = datum_to_input_feature(inputs, template)
- outputs = super().forward_only(inputs=input_features, **kwargs)
- logits = outputs.get('logits', None)
- logps = outputs.get('logps', None)
-
- if logits is not None:
- if isinstance(logits, torch.Tensor):
- logits = logits.detach().cpu()
- elif isinstance(logits, list) and len(logits) > 0:
- logits = torch.cat([logit.detach().cpu() for logit in logits], dim=0)
- results = self._get_forward_output(inputs, logits, logps)
- else:
- results = [{'logprobs': None, 'elementwise_loss': None} for _ in inputs]
-
- return results
+ outputs = super().forward_only(inputs=input_features, adapter_name=adapter_name, **kwargs)
+ results = self._tinker_build_output(inputs, outputs)
+ return [results, 0.0]
@remote_function(dispatch='all')
def tinker_step(self, *, adam_params: types.AdamParams, **kwargs):
@@ -119,7 +98,13 @@ def tinker_load(self, checkpoint_dir: str, **kwargs):
# Twinkle-native methods (InputFeature/Trajectory-based I/O)
# ------------------------------------------------------------------
- @remote_function(dispatch='slice_dp', collect='mean')
+ @remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
+ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
+ """Forward-only for twinkle-native clients (InputFeature/Trajectory I/O)."""
+ output = super().forward_only(inputs=inputs, **kwargs)
+ return to_cpu_safe_output(output)
+
+ @remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
**kwargs):
"""Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O)."""
diff --git a/src/twinkle/server/model/backends/transformers_model.py b/src/twinkle/server/model/backends/transformers_model.py
index fe30f616..121b4df7 100644
--- a/src/twinkle/server/model/backends/transformers_model.py
+++ b/src/twinkle/server/model/backends/transformers_model.py
@@ -11,6 +11,7 @@
from twinkle import remote_class, remote_function
from twinkle.data_format import InputFeature, Trajectory
+from twinkle.infra import collect_tensor_dict
from twinkle.model import MultiLoraTransformersModel
from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature
from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics,
@@ -30,43 +31,33 @@ class TwinkleCompatTransformersModel(MultiLoraTransformersModel, TwinkleCompatMo
# Tinker-compat methods (Datum-based I/O)
# ------------------------------------------------------------------
- @remote_function(dispatch='slice_dp', collect='flatten')
- def tinker_forward_only(self, *, inputs: List[types.Datum], **kwargs):
- template = self.get_template(**kwargs)
+ @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results)
+ def tinker_forward_only(self, *, inputs: List[types.Datum], adapter_name: str = None, **kwargs):
+ template = self.get_template(adapter_name)
input_features = datum_to_input_feature(inputs, template)
- outputs = super().forward_only(inputs=input_features, **kwargs)
- logits = outputs.get('logits')
- if logits is not None:
- logits = logits.detach().cpu()
- logps = outputs.get('logps', None)
- if logps is not None:
- logps = logps.detach().cpu()
- results = self._get_forward_output(inputs, logits, logps)
- return results
+ outputs = super().forward_only(inputs=input_features, adapter_name=adapter_name, **kwargs)
+ results = self._tinker_build_output(inputs, outputs)
+ return [results, 0.0]
@remote_function(dispatch='slice_dp', collect=collect_forward_backward_results)
def tinker_forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs):
- if loss_fn == 'cross_entropy':
- super().set_loss('CrossEntropyLoss', adapter_name=adapter_name)
- elif loss_fn == 'importance_sampling':
- super().set_loss('GRPOLoss', adapter_name=adapter_name, epsilon=0.2, beta=0.0)
- else:
- super().set_loss('CrossEntropyLoss', adapter_name=adapter_name)
+ self._tinker_setup_loss(loss_fn, inputs, adapter_name, kwargs)
template = self.get_template(adapter_name)
input_features = datum_to_input_feature(inputs, template)
outputs = super().forward(inputs=input_features, adapter_name=adapter_name, **kwargs)
loss_values = extract_rl_feature(inputs)
loss_kwargs = kwargs.copy()
+ # Convert ref_logps list-of-lists into a padded tensor wrapped in ref_outputs
+ # so that DPOLoss and DPOMetric can consume it via ref_outputs.get('logps').
+ ref_outputs_dict = self._tinker_prepare_ref_outputs(loss_values, loss_kwargs)
+ if ref_outputs_dict is not None:
+ # Propagate to train_status.forward_kwargs so DPOMetric.accumulate
+ # gets ref_outputs on the next forward() call (where accumulate_metrics runs).
+ self.optimizer_group[adapter_name].train_status.forward_kwargs['ref_outputs'] = ref_outputs_dict
loss_kwargs.update(loss_values)
loss = super().calculate_loss(adapter_name=adapter_name, **loss_kwargs)
super().backward(adapter_name=adapter_name, **kwargs)
- logits = outputs.get('logits')
- if logits is not None:
- logits = logits.detach()
- logps = outputs.get('logps', None)
- if logps is not None:
- logps = logps.detach().cpu()
- results = self._get_forward_output(inputs, logits, logps)
+ results = self._tinker_build_output(inputs, outputs)
return [results, loss]
@remote_function()
@@ -106,7 +97,13 @@ def tinker_load(self, checkpoint_dir: str, **kwargs):
# Twinkle-native methods (InputFeature/Trajectory-based I/O)
# ------------------------------------------------------------------
- @remote_function(dispatch='slice_dp', collect='mean')
+ @remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
+ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
+ """Forward-only for twinkle-native clients (InputFeature/Trajectory I/O)."""
+ output = super().forward_only(inputs=inputs, **kwargs)
+ return to_cpu_safe_output(output)
+
+ @remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
**kwargs):
"""Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O)."""
diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py
index 3f702621..88a59b2c 100644
--- a/src/twinkle/server/model/tinker_handlers.py
+++ b/src/twinkle/server/model/tinker_handlers.py
@@ -17,6 +17,7 @@
from .app import ModelManagement
from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager
+from twinkle.server.utils import get_template_for_model
from twinkle.utils.logger import get_logger
logger = get_logger()
@@ -46,7 +47,9 @@ async def _create_adapter():
adapter_name = self.get_adapter_name(adapter_name=_model_id)
self.register_resource(adapter_name, token, session_id=body.session_id)
self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg)
- self.model.set_template('Template', adapter_name=adapter_name, model_id=self.base_model)
+ # Select template based on model type
+ template = get_template_for_model(self.base_model)
+ self.model.set_template(template, adapter_name=adapter_name, model_id=self.base_model)
self.model.set_processor('InputProcessor', adapter_name=adapter_name)
self.model.set_optimizer('Adam', adapter_name=adapter_name)
self.set_resource_state(adapter_name, 'grad_ready', False)
@@ -111,12 +114,12 @@ async def _do_forward():
self.assert_resource_exists(adapter_name)
datum_list = body.forward_input.data
loss_fn_config = body.forward_input.loss_fn_config or {}
- output = self.model.tinker_forward_only(inputs=datum_list, adapter_name=adapter_name)
- loss = self.model.calculate_loss(adapter_name=adapter_name, **loss_fn_config)
+ output, loss = self.model.tinker_forward_only(
+ inputs=datum_list, adapter_name=adapter_name, **loss_fn_config)
return types.ForwardBackwardOutput(
loss_fn_output_type='CrossEntropyLossReturn',
loss_fn_outputs=output,
- metrics={'loss:sum': loss},
+ metrics={'loss:avg': loss},
)
except Exception:
logger.error(traceback.format_exc())
diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py
index 92cdd62d..40fdadbe 100644
--- a/src/twinkle/server/processor/app.py
+++ b/src/twinkle/server/processor/app.py
@@ -21,6 +21,7 @@
import twinkle
from twinkle import DeviceGroup, DeviceMesh, get_logger
from twinkle.server.utils.lifecycle import ProcessorManagerMixin
+from twinkle.server.utils.metrics import create_metrics_middleware
from twinkle.server.utils.state import ServerStateProxy, get_server_state
from twinkle.server.utils.validation import verify_request_token
from .twinkle_handlers import _register_processor_routes
@@ -124,6 +125,8 @@ def build_processor_app(ncpu_proc_per_node: int,
async def verify_token(request: Request, call_next):
return await verify_request_token(request=request, call_next=call_next)
+ app.middleware('http')(create_metrics_middleware('Processor'))
+
def get_self() -> ProcessorManagement:
return serve.get_replica_context().servable_object
diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py
index 0443df94..65cb2f5d 100644
--- a/src/twinkle/server/sampler/app.py
+++ b/src/twinkle/server/sampler/app.py
@@ -13,6 +13,7 @@
import twinkle
from twinkle import DeviceGroup, DeviceMesh
+from twinkle.server.utils.metrics import create_metrics_middleware
from twinkle.server.utils.state import ServerStateProxy, get_server_state
from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
from twinkle.server.utils.validation import get_token_from_request, verify_request_token
@@ -50,6 +51,7 @@ def __init__(self,
else:
self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
self.sampler_type = sampler_type
+ self.model_id = model_id
replica_context = serve.get_replica_context()
replica_id = replica_context.replica_id.unique_id
@@ -76,11 +78,10 @@ def __init__(self,
remote_group=self.device_group.name,
**kwargs)
- self.sampler.set_template('Template', model_id=model_id)
self.state: ServerStateProxy = get_server_state()
# Initialize task queue mixin
- self._init_task_queue(TaskQueueConfig.from_dict(queue_config))
+ self._init_task_queue(TaskQueueConfig.from_dict(queue_config), deployment_name='Sampler')
@serve.multiplexed(max_num_models_per_replica=5)
async def _sticky_entry(self, sticky_key: str):
@@ -135,6 +136,8 @@ def build_sampler_app(model_id: str,
async def verify_token(request: Request, call_next):
return await verify_request_token(request=request, call_next=call_next)
+ app.middleware('http')(create_metrics_middleware('Sampler'))
+
def get_self() -> SamplerManagement:
return serve.get_replica_context().servable_object
diff --git a/src/twinkle/server/sampler/tinker_handlers.py b/src/twinkle/server/sampler/tinker_handlers.py
index f0106024..ae7496c0 100644
--- a/src/twinkle/server/sampler/tinker_handlers.py
+++ b/src/twinkle/server/sampler/tinker_handlers.py
@@ -17,6 +17,7 @@
from twinkle.data_format import SamplingParams
from twinkle.server.common.checkpoint_factory import create_checkpoint_manager
+from twinkle.server.utils import get_template_for_model
from twinkle.utils.logger import get_logger
logger = get_logger()
@@ -48,6 +49,10 @@ async def _do_sample():
# Extract prompt token IDs from ModelInput
prompt_inputs = {'input_ids': body.prompt.to_ints()}
+ # Set template for sampler based on model type
+ template = get_template_for_model(self.model_id)
+ self.sampler.set_template(template, model_id=self.model_id)
+
# Get model_path from body or sampling session
model_path = body.model_path
if not model_path and body.sampling_session_id:
diff --git a/src/twinkle/server/utils/__init__.py b/src/twinkle/server/utils/__init__.py
index fdab4278..9b4abe66 100644
--- a/src/twinkle/server/utils/__init__.py
+++ b/src/twinkle/server/utils/__init__.py
@@ -5,3 +5,4 @@
from .lifecycle import AdapterManagerMixin, ProcessorManagerMixin, SessionResourceMixin
from .rate_limiter import RateLimiter
from .task_queue import QueueState, TaskQueueConfig, TaskQueueMixin, TaskStatus
+from .template_utils import get_template_for_model
diff --git a/src/twinkle/server/utils/metrics.py b/src/twinkle/server/utils/metrics.py
new file mode 100644
index 00000000..eee915d7
--- /dev/null
+++ b/src/twinkle/server/utils/metrics.py
@@ -0,0 +1,267 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Central metrics module for Twinkle server observability.
+
+Provides ray.util.metrics instruments that feed both the Ray Dashboard
+(port 8265) and Prometheus (via /api/prometheus).
+
+All metric names use the ``twinkle_`` prefix. Metric instances are
+cached per deployment to avoid duplicate registration.
+
+Public entry-points:
+
+* ``create_metrics_middleware(deployment)`` – FastAPI HTTP middleware
+* ``get_task_metrics(deployment)`` – task-queue / rate-limit gauges
+* ``get_resource_metrics()`` – ServerState resource gauges
+"""
+from __future__ import annotations
+
+import time
+from pydantic import BaseModel, ConfigDict
+from ray.util.metrics import Counter, Gauge, Histogram
+from typing import Any, Callable
+
+from twinkle.utils.logger import get_logger
+
+logger = get_logger()
+
+# ---------------------------------------------------------------------------
+# Histogram bucket boundaries (seconds) – shared by all histograms
+# ---------------------------------------------------------------------------
+_HISTOGRAM_BOUNDARIES = [
+ 0.01,
+ 0.05,
+ 0.1,
+ 0.25,
+ 0.5,
+ 1.0,
+ 2.5,
+ 5.0,
+ 10.0,
+ 30.0,
+ 60.0,
+ 120.0,
+ 300.0,
+]
+
+# ---------------------------------------------------------------------------
+# Lazy caches – populated on first call per deployment / globally
+# ---------------------------------------------------------------------------
+_task_metrics_cache: dict[str, TaskMetrics] = {}
+_resource_metrics_cache: ResourceMetrics | None = None
+_request_metrics_cache: dict[str, _RequestMetrics] = {}
+
+# ---------------------------------------------------------------------------
+# Pydantic models for structured metric access
+# ---------------------------------------------------------------------------
+
+
+class TaskMetrics(BaseModel):
+ """Task queue metrics container.
+
+ Attributes:
+ queue_depth: Current number of queued tasks.
+ tasks_total: Total task completions.
+ execution_seconds: Pure task execution time in seconds.
+ queue_wait_seconds: Time from enqueue to execution start.
+ rate_limit_rejections: Total rate-limit rejections.
+ rate_limiter_active_tokens: Tokens tracked by rate limiter.
+ """
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ queue_depth: Gauge
+ tasks_total: Counter
+ execution_seconds: Histogram
+ queue_wait_seconds: Histogram
+ rate_limit_rejections: Counter
+ rate_limiter_active_tokens: Gauge
+
+
+class ResourceMetrics(BaseModel):
+ """Resource gauge metrics container.
+
+ Attributes:
+ active_sessions: Current active session count.
+ active_models: Current registered model count.
+ active_sampling_sessions: Current sampling session count.
+ active_futures: Current future/request count.
+ """
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ active_sessions: Gauge
+ active_models: Gauge
+ active_sampling_sessions: Gauge
+ active_futures: Gauge
+
+
+class _RequestMetrics(BaseModel):
+ """HTTP request metrics container (internal)."""
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+ requests_total: Counter
+ request_duration_seconds: Histogram
+
+
+# ---------------------------------------------------------------------------
+# A. Request-level metrics (FastAPI middleware)
+# ---------------------------------------------------------------------------
+
+
+def _get_request_metrics(deployment: str) -> _RequestMetrics:
+ """Return (or create) per-deployment HTTP request metrics."""
+ if deployment in _request_metrics_cache:
+ return _request_metrics_cache[deployment]
+
+ metrics = _RequestMetrics(
+ requests_total=Counter(
+ 'twinkle_requests_total',
+ description='Total HTTP requests.',
+ tag_keys=('deployment', 'method', 'status'),
+ ),
+ request_duration_seconds=Histogram(
+ 'twinkle_request_duration_seconds',
+ description='End-to-end HTTP request latency in seconds.',
+ boundaries=_HISTOGRAM_BOUNDARIES,
+ tag_keys=('deployment', 'method'),
+ ),
+ )
+ _request_metrics_cache[deployment] = metrics
+ return metrics
+
+
+def create_metrics_middleware(deployment: str) -> Callable:
+ """Return a FastAPI ``http`` middleware that records request metrics.
+
+ Usage inside a ``build_*_app()`` function::
+
+ from twinkle.server.utils.metrics import create_metrics_middleware
+ metrics_mw = create_metrics_middleware("Model")
+ app.middleware('http')(metrics_mw)
+
+ Because FastAPI executes middleware in LIFO order, registering this
+ **after** ``verify_token`` means it wraps the outermost layer and
+ captures full end-to-end latency including auth.
+ """
+
+ async def metrics_middleware(request: Any, call_next: Callable) -> Any:
+ start = time.monotonic()
+ response = await call_next(request)
+ elapsed = time.monotonic() - start
+ status = str(response.status_code)
+ method = request.scope['route'].path if 'route' in request.scope else request.url.path
+ m = _get_request_metrics(deployment)
+ m.requests_total.inc(tags={
+ 'deployment': deployment,
+ 'method': method,
+ 'status': status,
+ })
+ m.request_duration_seconds.observe(
+ elapsed, tags={
+ 'deployment': deployment,
+ 'method': method,
+ })
+ return response
+
+ return metrics_middleware
+
+
+# ---------------------------------------------------------------------------
+# B. Task-queue metrics
+# ---------------------------------------------------------------------------
+
+
+def get_task_metrics(deployment: str) -> TaskMetrics:
+ """Return (or create) per-deployment task-queue metrics.
+
+ Returns a :class:`TaskMetrics` Pydantic model with:
+
+ - ``queue_depth`` – Gauge
+ - ``tasks_total`` – Counter
+ - ``execution_seconds`` – Histogram
+ - ``queue_wait_seconds`` – Histogram
+ - ``rate_limit_rejections`` – Counter
+ - ``rate_limiter_active_tokens`` – Gauge
+ """
+ if deployment in _task_metrics_cache:
+ return _task_metrics_cache[deployment]
+
+ metrics = TaskMetrics(
+ queue_depth=Gauge(
+ 'twinkle_task_queue_depth',
+ description='Current number of queued tasks.',
+ tag_keys=('deployment', ),
+ ),
+ tasks_total=Counter(
+ 'twinkle_tasks_total',
+ description='Total task completions.',
+ tag_keys=('deployment', 'task_type', 'status'),
+ ),
+ execution_seconds=Histogram(
+ 'twinkle_task_execution_seconds',
+ description='Pure task execution time in seconds.',
+ boundaries=_HISTOGRAM_BOUNDARIES,
+ tag_keys=('deployment', 'task_type'),
+ ),
+ queue_wait_seconds=Histogram(
+ 'twinkle_task_queue_wait_seconds',
+ description='Time from enqueue to execution start in seconds.',
+ boundaries=_HISTOGRAM_BOUNDARIES,
+ tag_keys=('deployment', 'task_type'),
+ ),
+ rate_limit_rejections=Counter(
+ 'twinkle_rate_limit_rejections_total',
+ description='Total rate-limit rejections.',
+ tag_keys=('deployment', ),
+ ),
+ rate_limiter_active_tokens=Gauge(
+ 'twinkle_rate_limiter_active_tokens',
+ description='Number of tokens tracked by the rate limiter.',
+ tag_keys=('deployment', ),
+ ),
+ )
+ _task_metrics_cache[deployment] = metrics
+ return metrics
+
+
+# ---------------------------------------------------------------------------
+# C. Resource gauges (ServerState actor, updated every 15 s)
+# ---------------------------------------------------------------------------
+
+
+def get_resource_metrics() -> ResourceMetrics:
+ """Return (or create) global resource gauge metrics.
+
+ Returns a :class:`ResourceMetrics` Pydantic model with:
+
+ - ``active_sessions`` – Gauge
+ - ``active_models`` – Gauge
+ - ``active_sampling_sessions`` – Gauge
+ - ``active_futures`` – Gauge
+ """
+ global _resource_metrics_cache
+ if _resource_metrics_cache is not None:
+ return _resource_metrics_cache
+
+ metrics = ResourceMetrics(
+ active_sessions=Gauge(
+ 'twinkle_active_sessions',
+ description='Current active session count.',
+ ),
+ active_models=Gauge(
+ 'twinkle_active_models',
+ description='Current registered model count.',
+ ),
+ active_sampling_sessions=Gauge(
+ 'twinkle_active_sampling_sessions',
+ description='Current sampling session count.',
+ ),
+ active_futures=Gauge(
+ 'twinkle_active_futures',
+ description='Current future/request count.',
+ ),
+ )
+ _resource_metrics_cache = metrics
+ return metrics
diff --git a/src/twinkle/server/utils/rate_limiter.py b/src/twinkle/server/utils/rate_limiter.py
index beefaa83..845cf246 100644
--- a/src/twinkle/server/utils/rate_limiter.py
+++ b/src/twinkle/server/utils/rate_limiter.py
@@ -41,6 +41,8 @@ def __init__(
window_seconds: float = 1.0,
token_cleanup_multiplier: float = 10.0,
token_cleanup_interval: float = 60.0,
+ active_tokens_gauge=None,
+ deployment_name: str = '',
):
"""Initialize the rate limiter.
@@ -53,6 +55,8 @@ def __init__(
will be removed. Default is 10.0 (10x the window).
token_cleanup_interval: How often to run the cleanup task in seconds.
Default is 60.0 (every minute).
+ active_tokens_gauge: Optional ray.util.metrics Gauge for tracking active token count.
+ deployment_name: Deployment name for metrics labels.
"""
self.rps_limit = rps_limit
self.tps_limit = tps_limit
@@ -72,6 +76,10 @@ def __init__(
self._cleanup_task: asyncio.Task | None = None
self._cleanup_started = False
+ # Metrics gauge for active token count
+ self._active_tokens_gauge = active_tokens_gauge
+ self._deployment_name = deployment_name
+
def _cleanup_old_requests(self, token: str, current_time: float) -> None:
"""Remove requests outside the sliding window.
@@ -122,6 +130,10 @@ async def _cleanup_inactive_tokens(self) -> None:
logger.debug(f'[RateLimiter] Cleaned up {len(tokens_to_remove)} inactive tokens. '
f'Active tokens remaining: {len(self._token_requests)}')
+ if self._active_tokens_gauge is not None:
+ tags = {'deployment': self._deployment_name} if self._deployment_name else {}
+ self._active_tokens_gauge.set(len(self._token_requests), tags=tags)
+
except asyncio.CancelledError:
logger.debug('[RateLimiter] Cleanup task cancelled')
break
@@ -193,6 +205,9 @@ async def check_and_record(self, token: str, input_tokens: int) -> tuple[bool, s
# Record this request
self._token_requests[token].append((current_time, input_tokens))
+ if self._active_tokens_gauge is not None:
+ tags = {'deployment': self._deployment_name} if self._deployment_name else {}
+ self._active_tokens_gauge.set(len(self._token_requests), tags=tags)
return True, None
def get_stats(self, token: str) -> dict[str, Any]:
diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py
index a42aa9a4..8e7689b2 100644
--- a/src/twinkle/server/utils/state/server_state.py
+++ b/src/twinkle/server/utils/state/server_state.py
@@ -9,6 +9,7 @@
from datetime import datetime
from typing import Any
+from twinkle.server.utils.metrics import get_resource_metrics
from twinkle.utils.logger import get_logger
from .config_manager import ConfigManager
from .future_manager import FutureManager
@@ -51,6 +52,11 @@ def __init__(
self._cleanup_task: asyncio.Task | None = None
self._cleanup_running = False
+ # Metrics loop state
+ self._metrics_task: asyncio.Task | None = None
+ self._metrics_running = False
+ self._metrics_update_interval: float = float(kwargs.get('metrics_update_interval', 15.0))
+
# ----- Session Management -----
async def create_session(self, payload: dict[str, Any]) -> str:
@@ -284,6 +290,22 @@ async def _cleanup_loop(self) -> None:
logger.warning(f'[ServerState Cleanup] Error during cleanup: {e}')
continue
+ async def _metrics_loop(self) -> None:
+ """Background task that updates resource gauge metrics every N seconds."""
+ resource_metrics = get_resource_metrics()
+ while self._metrics_running:
+ try:
+ await asyncio.sleep(self._metrics_update_interval)
+ resource_metrics.active_sessions.set(self._session_mgr.count())
+ resource_metrics.active_models.set(self._model_mgr.count())
+ resource_metrics.active_sampling_sessions.set(self._sampling_mgr.count())
+ resource_metrics.active_futures.set(self._future_mgr.count())
+ except asyncio.CancelledError:
+ break
+ except Exception as e:
+ logger.debug(f'[ServerState] Error updating metrics: {e}')
+ continue
+
async def start_cleanup_task(self) -> bool:
"""Start the background cleanup task.
@@ -294,6 +316,9 @@ async def start_cleanup_task(self) -> bool:
return False
self._cleanup_running = True
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
+ if not self._metrics_running:
+ self._metrics_running = True
+ self._metrics_task = asyncio.create_task(self._metrics_loop())
return True
async def stop_cleanup_task(self) -> bool:
@@ -308,6 +333,10 @@ async def stop_cleanup_task(self) -> bool:
if self._cleanup_task:
self._cleanup_task.cancel()
self._cleanup_task = None
+ self._metrics_running = False
+ if self._metrics_task:
+ self._metrics_task.cancel()
+ self._metrics_task = None
return True
async def get_cleanup_stats(self) -> dict[str, Any]:
diff --git a/src/twinkle/server/utils/task_queue.py b/src/twinkle/server/utils/task_queue.py
index d0985c15..9cd99253 100644
--- a/src/twinkle/server/utils/task_queue.py
+++ b/src/twinkle/server/utils/task_queue.py
@@ -18,6 +18,7 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Deque, Dict, Optional
+from twinkle.server.utils.metrics import get_task_metrics
from twinkle.utils.logger import get_logger
from .rate_limiter import RateLimiter
@@ -157,11 +158,12 @@ async def _do_work():
# Type hint for state attribute that inheriting classes must provide
state: ServerStateProxy
- def _init_task_queue(self, config: TaskQueueConfig | None = None) -> None:
+ def _init_task_queue(self, config: TaskQueueConfig | None = None, deployment_name: str = '') -> None:
"""Initialize the task queue system.
Args:
config: Optional TaskQueueConfig. If None, uses default config.
+ deployment_name: Deployment name for metrics labels (e.g. 'Model', 'Sampler').
"""
self._task_queue_config = config or TaskQueueConfig()
# Per-key queues, but executed by a single global worker.
@@ -169,6 +171,10 @@ def _init_task_queue(self, config: TaskQueueConfig | None = None) -> None:
self._queue_order: Deque[str] = deque()
self._new_task_event: asyncio.Event = asyncio.Event()
+ # Metrics initialization
+ self._deployment_name = deployment_name
+ self._task_metrics = get_task_metrics(deployment_name) if deployment_name else None
+
# Initialize rate limiter for RPS/TPS control
self._rate_limiter = RateLimiter(
rps_limit=self._task_queue_config.rps_limit,
@@ -176,6 +182,8 @@ def _init_task_queue(self, config: TaskQueueConfig | None = None) -> None:
window_seconds=self._task_queue_config.window_seconds,
token_cleanup_multiplier=self._task_queue_config.token_cleanup_multiplier,
token_cleanup_interval=self._task_queue_config.token_cleanup_interval,
+ active_tokens_gauge=self._task_metrics.rate_limiter_active_tokens if self._task_metrics else None,
+ deployment_name=deployment_name,
)
# Start the rate limiter cleanup task
self._rate_limiter.start_cleanup_task()
@@ -247,6 +255,18 @@ async def _queue_worker(self) -> None:
except asyncio.QueueEmpty:
continue
+ # Record queue wait time and update depth gauge
+ if self._task_metrics:
+ queue_wait = time.monotonic() - task.created_at
+ task_type_label = task.task_type or 'unknown'
+ self._task_metrics.queue_wait_seconds.observe(
+ queue_wait, tags={
+ 'deployment': self._deployment_name,
+ 'task_type': task_type_label
+ })
+ total_depth = sum(qq.qsize() for qq in self._task_queues.values())
+ self._task_metrics.queue_depth.set(total_depth, tags={'deployment': self._deployment_name})
+
now = time.monotonic()
# Global queue timeout
@@ -263,6 +283,13 @@ async def _queue_worker(self) -> None:
queue_state=QueueState.PAUSED_CAPACITY.value,
queue_state_reason=error_payload['error'],
)
+ if self._task_metrics:
+ self._task_metrics.tasks_total.inc(
+ tags={
+ 'deployment': self._deployment_name,
+ 'task_type': task.task_type or 'unknown',
+ 'status': 'timeout'
+ })
q.task_done()
continue
@@ -273,6 +300,8 @@ async def _queue_worker(self) -> None:
await self.state.store_future_status(
task.request_id, TaskStatus.RUNNING.value, task.model_id, queue_state=QueueState.ACTIVE.value)
+ exec_start = time.monotonic()
+ task_status = 'completed'
try:
coro = task.coro_factory()
result = await coro
@@ -283,6 +312,7 @@ async def _queue_worker(self) -> None:
result=result,
queue_state=QueueState.ACTIVE.value)
except Exception:
+ task_status = 'failed'
error_payload = {'error': traceback.format_exc(), 'category': 'Server'}
await self.state.store_future_status(
task.request_id,
@@ -292,6 +322,20 @@ async def _queue_worker(self) -> None:
queue_state=QueueState.ACTIVE.value)
finally:
q.task_done()
+ if self._task_metrics:
+ exec_time = time.monotonic() - exec_start
+ self._task_metrics.execution_seconds.observe(
+ exec_time,
+ tags={
+ 'deployment': self._deployment_name,
+ 'task_type': task.task_type or 'unknown'
+ })
+ self._task_metrics.tasks_total.inc(
+ tags={
+ 'deployment': self._deployment_name,
+ 'task_type': task.task_type or 'unknown',
+ 'status': task_status
+ })
# Keep serial semantics: execute at most one runnable task per loop
break
@@ -409,6 +453,8 @@ async def _perform_preflight_checks(
# Check rate limits
allowed, reason = await self._rate_limiter.check_and_record(token, input_tokens)
if not allowed:
+ if self._task_metrics:
+ self._task_metrics.rate_limit_rejections.inc(tags={'deployment': self._deployment_name})
error_msg = f'Rate limit exceeded: {reason}'
error_payload = {'error': error_msg, 'category': 'User'}
await self.state.store_future_status(
@@ -506,6 +552,10 @@ async def schedule_task(
self._new_task_event.set()
+ if self._task_metrics:
+ total_depth = sum(q.qsize() for q in self._task_queues.values())
+ self._task_metrics.queue_depth.set(total_depth, tags={'deployment': self._deployment_name})
+
return {'request_id': request_id, 'model_id': model_id}
def get_queue_stats(self) -> dict[str, Any]:
diff --git a/src/twinkle/server/utils/template_utils.py b/src/twinkle/server/utils/template_utils.py
new file mode 100644
index 00000000..ad015175
--- /dev/null
+++ b/src/twinkle/server/utils/template_utils.py
@@ -0,0 +1,46 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Template utility functions for model and sampler handlers.
+
+Provides centralized template selection logic for different model types,
+making it easy to maintain and extend template configurations.
+"""
+
+# Maps a substring of the model name to the template class name to use.
+# Patterns are checked in insertion order; the first match wins.
+MODEL_TEMPLATE_MAPPING = {
+    'Qwen3.5': 'Qwen3_5Template',
+    # Add more model-template mappings here as needed
+    # 'ModelName': 'TemplateName',
+}
+
+# Fallback template for any model that matches no pattern above.
+DEFAULT_TEMPLATE = 'Template'
+
+
+def get_template_for_model(model_name: str) -> str:
+    """
+    Resolve the template name to use for a given model.
+
+    The mapping is scanned in insertion order and the first pattern that
+    occurs as a substring of ``model_name`` selects the template; when no
+    pattern matches, ``DEFAULT_TEMPLATE`` is returned.
+
+    Args:
+        model_name: The name or identifier of the model (e.g., 'Qwen3.5-4B')
+
+    Returns:
+        The template name to use (e.g., 'Qwen3_5Template' or 'Template')
+
+    Examples:
+        >>> get_template_for_model('Qwen3.5-4B')
+        'Qwen3_5Template'
+        >>> get_template_for_model('Qwen2-7B')
+        'Template'
+        >>> get_template_for_model('llama-3-8b')
+        'Template'
+    """
+    # First pattern (in insertion order) contained in model_name, else default.
+    candidates = (tmpl for pat, tmpl in MODEL_TEMPLATE_MAPPING.items()
+                  if pat in model_name)
+    return next(candidates, DEFAULT_TEMPLATE)
diff --git a/src/twinkle/server/utils/validation.py b/src/twinkle/server/utils/validation.py
index 96a1f33a..5f5cdaf8 100644
--- a/src/twinkle/server/utils/validation.py
+++ b/src/twinkle/server/utils/validation.py
@@ -11,7 +11,7 @@ async def verify_request_token(request: Request, call_next):
This middleware:
1. Extracts the Bearer token from Authorization header
2. Validates the token
- 3. Extracts X-Ray-Serve-Request-Id for sticky sessions
+ 3. Extracts X-Ray-Serve-Request-Id for sticky sessions (skipped for healthz endpoint)
4. Stores token and request_id in request.state for later use
Args:
@@ -26,11 +26,16 @@ async def verify_request_token(request: Request, call_next):
if not is_token_valid(token):
return JSONResponse(status_code=403, content={'detail': 'Invalid token'})
- request_id = request.headers.get('X-Ray-Serve-Request-Id')
- if not request_id:
- return JSONResponse(
- status_code=400, content={'detail': 'Missing X-Ray-Serve-Request-Id header, required for sticky session'})
- request.state.request_id = request_id
+ # Skip X-Ray-Serve-Request-Id check for healthz endpoint (path ends with /healthz)
+ if not request.url.path.endswith('/healthz'):
+ request_id = request.headers.get('X-Ray-Serve-Request-Id')
+ if not request_id:
+ return JSONResponse(
+ status_code=400,
+ content={'detail': 'Missing X-Ray-Serve-Request-Id header, required for sticky session'})
+ request.state.request_id = request_id
+ else:
+ request.state.request_id = ''
request.state.token = token
request.state.session_id = request.headers.get('X-Twinkle-Session-Id') or ''
response = await call_next(request)
diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py
index 203cd6d4..d3a74c1d 100644
--- a/src/twinkle/template/base.py
+++ b/src/twinkle/template/base.py
@@ -444,18 +444,56 @@ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bo
for k, v in b.items() if v is not None
} for b in msg['content'] if isinstance(b, dict)]
tools = [dict(tool) for tool in trajectory.get('tools', [])]
- if 'tokenize' not in kwargs:
- kwargs['tokenize'] = True
- if 'enable_thinking' not in kwargs:
- kwargs['enable_thinking'] = self.enable_thinking
- inputs = self.processor.apply_chat_template(
- messages,
- tools=tools,
- padding=False,
- return_dict=True,
- add_generation_prompt=add_generation_prompt,
- return_tensors='pt',
- **kwargs)
+
+ # Use inspect to get apply_chat_template signature params
+ sig = inspect.signature(self.processor.apply_chat_template)
+ supported_params = set(sig.parameters.keys())
+
+ # Check if processor_kwargs is supported
+ if 'processor_kwargs' in supported_params:
+ # Separate supported params from processor_kwargs
+ apply_chat_template_kwargs = {}
+ processor_kwargs = {}
+
+ for key in list(kwargs.keys()):
+ if key in supported_params:
+ apply_chat_template_kwargs[key] = kwargs.pop(key)
+
+        # Ensure 'tokenize' defaults to True among the apply_chat_template kwargs
+ if 'tokenize' not in apply_chat_template_kwargs:
+ apply_chat_template_kwargs['tokenize'] = True
+
+        # Default enable_thinking only when the caller supplied no value at all
+        if 'enable_thinking' not in kwargs and 'enable_thinking' not in apply_chat_template_kwargs:
+            processor_kwargs['enable_thinking'] = self.enable_thinking
+
+ # Add remaining kwargs to processor_kwargs
+ processor_kwargs.update(kwargs)
+
+ inputs = self.processor.apply_chat_template(
+ messages,
+ tools=tools,
+ padding=False,
+ return_dict=True,
+ add_generation_prompt=add_generation_prompt,
+ return_tensors='pt',
+ processor_kwargs=processor_kwargs,
+ **apply_chat_template_kwargs)
+ else:
+ # No processor_kwargs support, pass all kwargs directly
+ if 'tokenize' not in kwargs:
+ kwargs['tokenize'] = True
+ if 'enable_thinking' not in kwargs:
+ kwargs['enable_thinking'] = self.enable_thinking
+
+ inputs = self.processor.apply_chat_template(
+ messages,
+ tools=tools,
+ padding=False,
+ return_dict=True,
+ add_generation_prompt=add_generation_prompt,
+ return_tensors='pt',
+ **kwargs)
return inputs
def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool = False, **kwargs) -> InputFeature:
diff --git a/src/twinkle_client/common/serialize.py b/src/twinkle_client/common/serialize.py
index de3ca4bb..b2d1720c 100644
--- a/src/twinkle_client/common/serialize.py
+++ b/src/twinkle_client/common/serialize.py
@@ -2,6 +2,7 @@
import json
from numbers import Number
from peft import LoraConfig
+from pydantic import BaseModel
from typing import Any, Mapping
from twinkle.dataset import DatasetMeta
@@ -56,6 +57,9 @@ def serialize_object(obj) -> str:
}
filtered_dict['_TWINKLE_TYPE_'] = 'LoraConfig'
return json.dumps(filtered_dict, ensure_ascii=False)
+ elif isinstance(obj, BaseModel):
+        # Pydantic models: return a dict so 'requests' can JSON-encode it; NOTE(review): this diverges from the declared -> str return type — confirm callers handle a dict
+ return obj.model_dump(mode='json')
elif isinstance(obj, Mapping):
return json.dumps(obj, ensure_ascii=False)
elif isinstance(obj, basic_types):
diff --git a/src/twinkle_client/utils/patch_tinker.py b/src/twinkle_client/utils/patch_tinker.py
index 5f6d955e..e26140bd 100644
--- a/src/twinkle_client/utils/patch_tinker.py
+++ b/src/twinkle_client/utils/patch_tinker.py
@@ -10,11 +10,12 @@
from __future__ import annotations
import os
-from typing import TYPE_CHECKING, Any, Mapping, Union
+from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
from twinkle_client.http.utils import get_api_key, get_request_id
_patched = False
+_loss_fn_config_patched = False
async def _create_sampling_session(self, model_path: str | None = None, base_model: str | None = None) -> str: