3 changes: 3 additions & 0 deletions pyproject.toml
@@ -56,6 +56,9 @@ verifiers = [
"verifiers",
"openai",
]
gem = [
"gem-llm",
]

[build-system]
requires = ["hatchling"]
68 changes: 68 additions & 0 deletions tinker_cookbook/recipes/gem_rl/README.md
@@ -0,0 +1,68 @@
## RL Training with Tinker / Tinker Cookbook + GEM

[GEM](https://github.com/axon-rl/gem) is a library providing Gym-like environments for LLM RL. This document demonstrates how to use Tinker (or the Tinker Cookbook) to train RL agents on GEM environments.
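
For orientation, here is a minimal sketch of the Gym-style interaction loop that the adapter below builds on (the environment id and the dummy action are purely illustrative; see `tinker_cookbook_adapter.py` for how these calls are actually wired up):

```python
import gem

# Create an environment by id, as in the training commands below.
env = gem.make("math:Math12K", seed=0)

obs, _ = env.reset()  # obs: the task prompt (text)
action = "Some model-generated solution ending in \\boxed{42}"  # dummy action
obs, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
```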

In this guide, we present two implementations leveraging Tinker:
1. Standard Training with Cookbook Utilities:
   We introduce an adapter layer that enables GEM to work seamlessly with the Tinker Cookbook’s training loop.
2. Custom Training via Low-level API:
   For advanced use cases, we provide a basic RL implementation that interacts directly with Tinker’s low-level API, enabling maximum customization.


### Getting Started

Assuming you have already installed `tinker` and `tinker-cookbook`, install `gem` from PyPI and set your Tinker API key:
```bash
# install gem
pip install -U gem-llm

# export tinker api key
export TINKER_API_KEY=<your-tinker-api-key>
```

### Using Tinker Cookbook
> Find the training curves in [[📈 WandB Logs](https://wandb.ai/cameron_chen/gem-tinker-cookbook)].

To train agents with the Tinker Cookbook, we implemented an adaptation layer (`tinker_cookbook_adapter.py`) that exposes GEM environments through the Cookbook's RL interfaces. The entry point is `tinker_cookbook_train.py`.

**Example 1: Training on Math Environments**

```bash
python tinker_cookbook_train.py env_id=math:Math12K groups_per_batch=64 group_size=16 learning_rate=2e-5 max_tokens=2048 model_name=Qwen/Qwen3-8B-Base env_kwargs_json='{"use_mp": false}'
```

Note:
- You may train on different math environments by simply changing the `env_id` argument.
- `env_kwargs_json='{"use_mp": false}'` is only required for math environments.

**Example 2: Training on Reasoning Gym**

```bash
python tinker_cookbook_train.py env_id=rg:simple_equations groups_per_batch=64 group_size=8 learning_rate=2e-5 max_tokens=2048 model_name=Qwen/Qwen3-8B-Base
```

### Using Tinker
> Find the training curves in [[📈 WandB Logs](https://wandb.ai/lkevinzc/gem-tinker_train)].

We can also directly integrate GEM with Tinker, bypassing the abstraction layer defined by tinker-cookbook. The entry point is `tinker_train.py`, which implements REINFORCE with Return Batch Normalization (introduced in our [paper](https://arxiv.org/pdf/2510.01051)).
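
The normalization itself is simple; below is a minimal sketch of the idea behind Return Batch Normalization (per-episode returns centered and scaled across the sampled batch, then used as REINFORCE advantages). The function name and epsilon are illustrative, not taken from `tinker_train.py`; see that script and the paper for the exact implementation.

```python
import numpy as np


def return_batch_norm_advantages(returns: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Sketch: normalize per-episode returns across the batch to get advantages."""
    return (returns - returns.mean()) / (returns.std() + eps)


# Each episode's tokens are then weighted by its normalized return in the
# REINFORCE policy-gradient objective.
```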

**Example 1: Training on Math Environments**

```bash
python tinker_train.py model_name=Qwen/Qwen3-8B-Base env_id=math:DeepScaleR40K num_env=128 max_tokens=8192
```
* In this example we observe the increasing-response-length phenomenon (as in DeepSeek-R1-Zero) even with LoRA training at rank 32!

**Example 2: Training on Math with Python Tool**

```bash
python tinker_train.py env_id=math:Math12K num_env=128 max_tokens=2048 template=no model_name=meta-llama/Llama-3.1-8B-Instruct env_wrappers=python_tool_no_int_reward_last_line_error,concat_chat gamma=1
```

**Example 3: Training on Multi-turn Language Games**

```bash
python tinker_train.py model_name=Qwen/Qwen3-8B-Base num_env=64 env_id=game:Sudoku-v0-easy max_tokens=1024 template=qwen3_game env_wrappers=concat
```

* Training is currently slow because in each turn the agent must think and then act, and this process repeats sequentially across turns. It can be optimized via prefix sharing, asynchronous sampling/learning, etc.
180 changes: 180 additions & 0 deletions tinker_cookbook/recipes/gem_rl/tinker_cookbook_adapter.py
@@ -0,0 +1,180 @@
from __future__ import annotations

import json
import time
from dataclasses import dataclass
from typing import Any, Sequence

import chz
import tinker
from tinker_cookbook import renderers
from tinker_cookbook.completers import StopCondition
from tinker_cookbook.rl.types import (
    Action,
    Env,
    EnvGroupBuilder,
    Metrics,
    Observation,
    RLDataset,
    RLDatasetBuilder,
    StepResult,
)
from tinker_cookbook.tokenizer_utils import get_tokenizer

import gem


def apply_general_prompt(init_obs: str) -> str:
    return (
        f"Question: {init_obs}"
        "\nPlease reason step by step, and put your final answer within \\boxed{}."
    )


def apply_no_template(init_obs: str) -> str:
    return init_obs


PROMPT_FACTORY = {"no": apply_no_template, "general": apply_general_prompt}


class GemTinkerEnv(Env):
    def __init__(
        self,
        env_gem: gem.Env,
        init_obs: Observation,
        renderer: renderers.Renderer,
        prompt_type: str = "general",
        convo_prefix: list[renderers.Message] | None = None,
    ):
        self.env_gem = env_gem
        self.init_obs = init_obs
        self.renderer = renderer
        self.prompt_type = prompt_type
        self.convo: list[renderers.Message] = list(convo_prefix or [])

    @property
    def stop_condition(self):
        return self.renderer.get_stop_sequences()

    async def initial_observation(self) -> tuple[Observation, StopCondition]:
        self.convo.append(
            {"role": "user", "content": PROMPT_FACTORY[self.prompt_type](str(self.init_obs))}
        )
        return self.renderer.build_generation_prompt(self.convo), self.stop_condition

    async def step(self, action: Action) -> StepResult:
        message, parse_success = self.renderer.parse_response(action)
        text = message.get("content", "") if parse_success else ""
        next_obs, reward, terminated, truncated, info = self.env_gem.step(text)
        reward = float(reward)

        metrics: Metrics = {}
        for k, v in (info or {}).items():
            if isinstance(v, (int, float)):
                metrics[k] = v

        done = terminated or truncated
        if done:
            next_ob = tinker.ModelInput.empty()
            next_stop = self.stop_condition
        else:
            self.convo.append({"role": "assistant", "content": text})
            self.convo.append({"role": "user", "content": next_obs})
            next_ob = self.renderer.build_generation_prompt(self.convo)
            next_stop = self.stop_condition

        return StepResult(
            reward=reward,
            episode_done=done,
            next_observation=next_ob,
            next_stop_condition=next_stop,
            metrics=metrics,
        )


@dataclass(frozen=True)
class GemEnvGroupBuilder(EnvGroupBuilder):
    pool: list[gem.Env]
    renderer: renderers.Renderer
    prompt_type: str
    group_size: int
    env_id: str
    convo_prefix: list[renderers.Message] | None = None
    group_index: int = -1  # which env in the pool to use for this group

    async def make_envs(self) -> Sequence[Env]:
        assert 0 <= self.group_index < len(self.pool), (
            "group_index should be within the range of the pool size"
        )
        assert hasattr(self.pool[0], "get_state"), (
            "env must support get_state() to run in GemEnvGroupBuilder"
        )

        # duplicate the env for the group size
        env_parent = self.pool[self.group_index]
        init_obs, _ = env_parent.reset()
        return [
            GemTinkerEnv(
                env_parent.spawn(same_state=True),
                init_obs,
                self.renderer,
                self.prompt_type,
                self.convo_prefix,
            )
            for _ in range(self.group_size)
        ]

    def logging_tags(self) -> list[str]:
        return self.env_id.split(":")


class GemDataset(RLDataset):
    def __init__(self, builder_config: dict[str, Any], groups_per_batch: int, n_batches: int):
        pool = builder_config["pool"]
        assert len(set(env.seed for env in pool)) == len(pool), (
            "All envs in the pool must have different seeds."
        )

        self.builder_config = builder_config
        self.groups_per_batch = groups_per_batch
        self.n_batches = n_batches

    def get_batch(self, index: int) -> list[EnvGroupBuilder]:
        # `index` is unused: each group builder re-resets its parent env in
        # make_envs(), so every batch draws new episodes from the same env pool.
        return [
            GemEnvGroupBuilder(group_index=i, **self.builder_config)
            for i in range(self.groups_per_batch)
        ]

    def __len__(self) -> int:
        return self.n_batches


@chz.chz
class GemDatasetBuilder(RLDatasetBuilder):
    env_id: str
    model_name_for_tokenizer: str
    renderer_name: str
    prompt_type: str = "general"
    group_size: int
    groups_per_batch: int
    n_batches: int = 100
    env_kwargs_json: str | None = None
    convo_prefix: list[renderers.Message] | None = None

    async def __call__(self) -> tuple[RLDataset, RLDataset | None]:
        env_kwargs = json.loads(self.env_kwargs_json) if self.env_kwargs_json else {}
        env_parent = gem.make(self.env_id, seed=int(time.time_ns()), **env_kwargs)
        seed_parent = env_parent.seed  # type: ignore
        pool = [env_parent.spawn(seed=i + seed_parent + 1) for i in range(self.groups_per_batch)]
        tokenizer = get_tokenizer(self.model_name_for_tokenizer)
        renderer = renderers.get_renderer(self.renderer_name, tokenizer=tokenizer)
        builder_config = {
            "pool": pool,
            "renderer": renderer,
            "prompt_type": self.prompt_type,
            "group_size": self.group_size,
            "env_id": self.env_id,
            "convo_prefix": self.convo_prefix,
        }
        return GemDataset(builder_config, self.groups_per_batch, self.n_batches), None
109 changes: 109 additions & 0 deletions tinker_cookbook/recipes/gem_rl/tinker_cookbook_train.py
@@ -0,0 +1,109 @@
"""GEM ❤️ Tinker.

A script for training agents on GEM environments with Tinker-Cookbook.
"""

import asyncio
from datetime import datetime

import chz
from tinker_cookbook import cli_utils, model_info
from tinker_cookbook.rl.train import AsyncConfig, Config, main

from .tinker_cookbook_adapter import GemDatasetBuilder


@chz.chz
class CLIConfig:
    # Model
    model_name: str = "meta-llama/Llama-3.2-1B"
    lora_rank: int = 32
    renderer_name: str | None = None
    load_checkpoint_path: str | None = None

    # GEM env
    env_id: str = "game:GuessTheNumber-v0"
    env_kwargs_json: str | None = None  # e.g., '{"max_turns": 4}'

    # Training
    prompt_type: str = "general"
    group_size: int = 4
    groups_per_batch: int = 64
    n_batches: int = 200
    learning_rate: float = 1e-5
    max_tokens: int = 256
    kl_penalty_coef: float = 0.0
    num_substeps: int = 1

    # Logging
    log_path: str | None = None
    wandb_project: str | None = None
    wandb_name: str | None = None
    compute_post_kl: bool = False
    eval_every: int = 0
    save_every: int = 25

    # Service
    base_url: str | None = None

    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "ask"
    max_steps_off_policy: int | None = None


async def cli_main(cli_config: CLIConfig):
    renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name(
        cli_config.model_name
    )
    model_name_sanitized = cli_config.model_name.replace("/", "-")
    run_name = (
        f"gem-{cli_config.env_id.replace(':', '_')}-{model_name_sanitized}-r{cli_config.lora_rank}-"
        f"lr{cli_config.learning_rate}-g{cli_config.group_size}-b{cli_config.groups_per_batch}-"
        f"{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
    )
    log_path = cli_config.log_path or f"./outputs-tinker/gem/{run_name}"
    wandb_name = cli_config.wandb_name or run_name

    dataset_builder = GemDatasetBuilder(
        env_id=cli_config.env_id,
        model_name_for_tokenizer=cli_config.model_name,
        renderer_name=renderer_name,
        prompt_type=cli_config.prompt_type,
        group_size=cli_config.group_size,
        groups_per_batch=cli_config.groups_per_batch,
        n_batches=cli_config.n_batches,
        env_kwargs_json=cli_config.env_kwargs_json,
    )

    cfg = Config(
        learning_rate=cli_config.learning_rate,
        dataset_builder=dataset_builder,
        model_name=cli_config.model_name,
        lora_rank=cli_config.lora_rank,
        max_tokens=cli_config.max_tokens,
        wandb_project=cli_config.wandb_project,
        wandb_name=wandb_name,
        log_path=log_path,
        base_url=cli_config.base_url,
        load_checkpoint_path=cli_config.load_checkpoint_path,
        compute_post_kl=cli_config.compute_post_kl,
        kl_penalty_coef=cli_config.kl_penalty_coef,
        num_substeps=cli_config.num_substeps,
        eval_every=cli_config.eval_every,
        save_every=cli_config.save_every,
        async_config=(
            AsyncConfig(
                max_steps_off_policy=cli_config.max_steps_off_policy,
                groups_per_batch=cli_config.groups_per_batch,
            )
            if cli_config.max_steps_off_policy is not None
            else None
        ),
    )

    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)
    await main(cfg)


if __name__ == "__main__":
    cfg = chz.entrypoint(CLIConfig)
    asyncio.run(cli_main(cfg))