# MTC GRPO Training with RayCluster

This directory contains configurations for running GRPO training with VERL using [HyperPod Managed Tiered Checkpointing](https://docs.aws.amazon.com/sagemaker/latest/dg/managed-tier-checkpointing.html).

## Files

- `mtc-grpo-cluster.yaml` - RayCluster configuration
- `submit-mtc-grpo.sh` - Script to submit the GRPO training job to the Ray cluster

## Setup

1. Source environment variables:
```bash
# 1. Load environment variables
source setup/env_vars
```

2. Create a service account that grants your pods S3 access. To do this, follow the [IRSA-README.md](../setup/IRSA-README.md).
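
To confirm the service account exists and is annotated for IRSA, a quick check (assuming it is named `ray-s3-sa`, the name referenced in `mtc-grpo-cluster.yaml`):
```bash
# The account should carry an eks.amazonaws.com/role-arn annotation pointing to an IAM role with S3 access
kubectl get serviceaccount ray-s3-sa -o yaml
```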

## Deploy the RayCluster
```bash
envsubst < managed-tiered-checkpointing/mtc-grpo-cluster.yaml | kubectl apply -f -
```
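
To verify the cluster came up, check the RayCluster object and watch its pods (the `ray.io/cluster` label selector assumes standard KubeRay pod labeling):
```bash
# Confirm the RayCluster was created and wait for head/worker pods to reach Running
kubectl get raycluster mtc-grpo-cluster
kubectl get pods -l ray.io/cluster=mtc-grpo-cluster -w
```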

## Clone MTC-enabled VERL Code
Delete the existing verl repo if you have already cloned it:
```bash
rm -rf verl
```

Clone the MTC-enabled VERL code. This is a fork of the main VERL repo with modified checkpointing code to enable managed tiered checkpointing:
```bash
git clone https://github.com/aruncs2005/verl.git
```
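
The submit script expects GSM8K parquet files under `${RAY_DATA_HOME}/data/gsm8k/` (default `/fsx/verl/data/gsm8k/`). If they are not there yet, upstream VERL ships a GSM8K preprocessing script; a sketch assuming the fork keeps the same layout:
```bash
# Generate train.parquet / test.parquet where the submit script looks for them
cd verl
python3 examples/data_preprocess/gsm8k.py --local_dir /fsx/verl/data/gsm8k
```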

## Submit the training job
```bash
./managed-tiered-checkpointing/submit-mtc-grpo.sh
```
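
The script submits to `RAY_ADDRESS`, which defaults to `http://localhost:8265`, so when running from outside the cluster, forward the Ray dashboard port first. A minimal sketch, assuming the default KubeRay head service name:
```bash
# KubeRay exposes the head through <cluster-name>-head-svc; 8265 is the dashboard/job-submission port
kubectl port-forward service/mtc-grpo-cluster-head-svc 8265:8265 &
```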

## Monitoring

- **Ray Dashboard**: http://localhost:8265 (after port forwarding)
- **View logs**: `kubectl logs -f <head-pod-name>`
- **Check job status**: `ray job status <job-id>`
- **Follow job logs**: `ray job logs <job-id> --follow`
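
To check that tiered checkpointing is actually writing, you can peek at the checkpoint log directory mounted into the pods (the path comes from the `checkpoint-logs` mount in `mtc-grpo-cluster.yaml`; the file names inside are not confirmed here):
```bash
# The checkpoint-logs volume is mounted at /var/log/sagemaker_checkpointing in head and worker pods
kubectl exec -it <head-pod-name> -- ls -l /var/log/sagemaker_checkpointing
```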

## Configuration

Edit `submit-mtc-grpo.sh` to modify training parameters (several of them can also be overridden through environment variables, as sketched after this list):

- `train_prompt_bsz` - Training batch size
- `train_prompt_mini_bsz` - Mini batch size for PPO
- `train_prompt_micro_bsz_per_gpu` - Micro batch size per GPU
- `n_resp_per_prompt` - Number of responses per prompt
- `gen_tp` - Tensor parallelism for generation
- Model path, data paths, S3 checkpoint location, etc.
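
Several of these are read from environment variables with defaults (for example `TRAIN_BATCH_SIZE` and `N_RESP_PER_PROMPT`), so they can be overridden at submit time without editing the file; others, such as `train_prompt_mini_bsz` and `gen_tp`, are set directly in the script. A sketch with illustrative values:
```bash
# Override the env-driven knobs, then resubmit (values and bucket below are placeholders)
export TRAIN_BATCH_SIZE=64
export N_RESP_PER_PROMPT=8
export S3_CHECKPOINT_BASE=s3://my-bucket/mtc-checkpoints
./managed-tiered-checkpointing/submit-mtc-grpo.sh
```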

## Cleanup

```bash
# Delete the RayCluster
kubectl delete raycluster mtc-grpo-cluster
```
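
Deleting the RayCluster does not remove checkpoints already written to S3. If you want to clear those as well, something like the following works, assuming `S3_CHECKPOINT_BASE` still points at the location used during training:
```bash
# Remove all checkpoint objects under the configured S3 prefix (destructive; double-check the path first)
aws s3 rm "${S3_CHECKPOINT_BASE}" --recursive
```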

# ---------- mtc-grpo-cluster.yaml ----------

apiVersion: ray.io/v1alpha1
kind: RayCluster
metadata:
  name: mtc-grpo-cluster
  labels:
    controller-tools.k8s.io: "1.0"
  annotations:
    karpenter.sh/do-not-disrupt: "true"
spec:
  # Ray head pod template
  headGroupSpec:
    rayStartParams:
      dashboard-host: '0.0.0.0'
      metrics-export-port: '8080'
    template:
      spec:
        serviceAccountName: ray-s3-sa
        nodeSelector:
          node.kubernetes.io/instance-type: $INSTANCE_TYPE
          sagemaker.amazonaws.com/node-health-status: Schedulable
        securityContext:
          runAsUser: 0
          runAsGroup: 0
          fsGroup: 0
        containers:
          - name: ray-head
            image: ${REGISTRY}${IMAGE}:${TAG}
            env:
              ## PROMETHEUS AND GRAFANA
              - name: RAY_GRAFANA_IFRAME_HOST
                value: http://localhost:3000
              - name: RAY_GRAFANA_HOST
                value: http://prometheus-grafana.prometheus-system.svc:80
              - name: RAY_PROMETHEUS_HOST
                value: http://prometheus-kube-prometheus-prometheus.prometheus-system.svc:9090
              ## EFA AND NCCL CONFIGURATION
              - name: FI_PROVIDER
                value: "efa"
              - name: FI_EFA_USE_DEVICE_RDMA
                value: "1"
              - name: FI_EFA_FORK_SAFE
                value: "1"
              - name: NCCL_PROTO
                value: "simple"
              - name: NCCL_SOCKET_IFNAME
                value: "^docker,lo,veth"
              - name: NCCL_DEBUG
                value: "INFO"
              - name: TORCH_NCCL_DUMP_ON_TIMEOUT
                value: "1"
              - name: TORCH_NCCL_ASYNC_ERROR_HANDLING
                value: "1"
              - name: HF_TOKEN
                value: ${HF_TOKEN}
            lifecycle:
              preStop:
                exec:
                  command: ["/bin/sh","-c","ray stop"]
            resources:
              limits:
                cpu: 8
                memory: 32Gi
              requests:
                cpu: 8
                memory: 32Gi
            ports:
              - containerPort: 6379
                name: gcs-server
              - containerPort: 8265
                name: dashboard
              - containerPort: 10001
                name: client
              - containerPort: 8000
                name: serve
              - containerPort: 8080
                name: metrics
            volumeMounts:
              - name: fsx-storage
                mountPath: /fsx
              - name: ray-logs
                mountPath: /tmp/ray
              - name: checkpoint-logs
                mountPath: /var/log/sagemaker_checkpointing
        volumes:
          - name: ray-logs
            emptyDir: {}
          - name: fsx-storage
            persistentVolumeClaim:
              claimName: fsx-claim
          - name: checkpoint-logs
            hostPath:
              path: /var/logs/sagemaker_checkpointing
              type: DirectoryOrCreate
  workerGroupSpecs:
    - replicas: $NUM_NODES
      minReplicas: 1
      maxReplicas: 10
      groupName: gpu-group
      rayStartParams:
        num-gpus: "$NUM_GPU_PER_NODE"
        metrics-export-port: '8080'
      template:
        spec:
          serviceAccountName: ray-s3-sa
          nodeSelector:
            node.kubernetes.io/instance-type: $INSTANCE_TYPE
            sagemaker.amazonaws.com/node-health-status: Schedulable
          securityContext:
            runAsUser: 0
            runAsGroup: 0
            fsGroup: 0
          containers:
            - name: ray-worker
              image: ${REGISTRY}${IMAGE}:${TAG}
              env:
                - name: FI_PROVIDER
                  value: "efa"
                - name: FI_EFA_USE_DEVICE_RDMA
                  value: "1"
                - name: FI_EFA_FORK_SAFE
                  value: "1"
                - name: NCCL_PROTO
                  value: "simple"
                - name: NCCL_SOCKET_IFNAME
                  value: "^docker,lo,veth"
                - name: NCCL_DEBUG
                  value: "INFO"
                - name: TORCH_NCCL_DUMP_ON_TIMEOUT
                  value: "1"
                - name: TORCH_NCCL_ASYNC_ERROR_HANDLING
                  value: "1"
                - name: HF_TOKEN
                  value: ${HF_TOKEN}
              lifecycle:
                preStop:
                  exec:
                    command: ["/bin/sh","-c","ray stop"]
              resources:
                limits:
                  nvidia.com/gpu: $NUM_GPU_PER_NODE
                requests:
                  nvidia.com/gpu: $NUM_GPU_PER_NODE
              ports:
                - containerPort: 8080
                  name: metrics
              volumeMounts:
                - name: ray-logs
                  mountPath: /tmp/ray
                - name: fsx-storage
                  mountPath: /fsx
                - name: checkpoint-logs
                  mountPath: /var/log/sagemaker_checkpointing
          volumes:
            - name: fsx-storage
              persistentVolumeClaim:
                claimName: fsx-claim
            - name: ray-logs
              emptyDir: {}
            - name: checkpoint-logs
              hostPath:
                path: /var/logs/sagemaker_checkpointing
                type: DirectoryOrCreate

# ---------- submit-mtc-grpo.sh ----------

#!/usr/bin/env bash
set -xeuo pipefail

# Load environment variables
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/../setup/env_vars"

# Project configuration
project_name='verl_grpo_example_gsm8k'
exp_name='qwen3_0.6b_function_rm'

# GRPO Algorithm parameters
adv_estimator=grpo
use_kl_in_reward=False
use_kl_loss=True
kl_loss_coef=0.001
kl_loss_type=low_var_kl
entropy_coeff=0

# Token length configuration
max_prompt_length=512
max_response_length=1024
filter_overlong_prompts=True
truncation='error'

# Training configuration
train_prompt_bsz=${TRAIN_BATCH_SIZE:-32} # Total batch size
gen_prompt_bsz=${GEN_BATCH_SIZE:-$train_prompt_bsz}
n_resp_per_prompt=${N_RESP_PER_PROMPT:-5}
train_prompt_mini_bsz=32 # Must be <= train_batch_size
train_prompt_micro_bsz_per_gpu=1

# Ray configuration
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}

# Cluster configuration
NNODES=${NUM_NODES:-4}
GPUS_PER_NODE=${NUM_GPU_PER_NODE:-4}

# Model and data paths
MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen3-0.6B"}
RAY_DATA_HOME=${RAY_DATA_HOME:-"/fsx/verl"}

# Data files - using GSM8K dataset
TRAIN_FILE="${RAY_DATA_HOME}/data/gsm8k/train.parquet"
TEST_FILE="${RAY_DATA_HOME}/data/gsm8k/test.parquet"

# S3 checkpoint configuration
S3_CHECKPOINT_BASE=${S3_CHECKPOINT_BASE:-"s3://s3-bucket-example"}

# Performance parameters
gen_tp=2
log_prob_micro_bsz_per_gpu=32
gpu_memory_utilization=0.6

# Memory optimization
param_offload=False
optimizer_offload=False
ref_param_offload=True

# Print configuration for verification
echo "=== MTC GRPO Training Configuration ==="
echo "Project: ${project_name}"
echo "Experiment: ${exp_name}"
echo "Model: ${MODEL_PATH}"
echo "Nodes: ${NNODES}"
echo "GPUs per node: ${GPUS_PER_NODE}"
echo "Total GPUs: $((NNODES * GPUS_PER_NODE))"
echo "Data home: ${RAY_DATA_HOME}"
echo "S3 Checkpoints: ${S3_CHECKPOINT_BASE}"
echo "Ray address: ${RAY_ADDRESS}"
echo "=================================="

# Submit Ray job
ray job submit --no-wait \
    --address "${RAY_ADDRESS}" \
    --working-dir "${WORKING_DIR}" \
    -- python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=${adv_estimator} \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=question \
    data.train_batch_size=${train_prompt_bsz} \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.filter_overlong_prompts=${filter_overlong_prompts} \
    data.truncation=${truncation} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_prompt_micro_bsz_per_gpu} \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.kl_loss_type=${kl_loss_type} \
    actor_rollout_ref.actor.entropy_coeff=${entropy_coeff} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${param_offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${optimizer_offload} \
    actor_rollout_ref.actor.checkpoint.s3_base_path=${S3_CHECKPOINT_BASE} \
    actor_rollout_ref.actor.checkpoint.ckpt_namespace=mtc-grpo-$(date +%s) \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${log_prob_micro_bsz_per_gpu} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${log_prob_micro_bsz_per_gpu} \
    actor_rollout_ref.ref.fsdp_config.param_offload=${ref_param_offload} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    trainer.critic_warmup=0 \
    trainer.logger='["console"]' \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=${GPUS_PER_NODE} \
    trainer.nnodes=${NNODES} \
    trainer.save_freq=1 \
    trainer.test_freq=2 \
    trainer.total_epochs=5 \
    trainer.s3_base_path=${S3_CHECKPOINT_BASE}

echo ""
echo "Job submitted! Check status with: ray job status <job-id>"
echo "Or view logs with: ray job logs <job-id> --follow"