diff --git a/pyproject.toml b/pyproject.toml
index b7601b7..97e281b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,6 +56,9 @@ verifiers = [
     "verifiers",
     "openai",
 ]
+gem = [
+    "gem-llm",
+]
 
 [build-system]
 requires = ["hatchling"]
diff --git a/tinker_cookbook/recipes/gem_rl/README.md b/tinker_cookbook/recipes/gem_rl/README.md
new file mode 100644
index 0000000..a3faac7
--- /dev/null
+++ b/tinker_cookbook/recipes/gem_rl/README.md
@@ -0,0 +1,68 @@
+## RL Training with Tinker / Tinker Cookbook + GEM
+
+[GEM](https://github.com/axon-rl/gem) is a library providing Gym-like environments for LLM RL. This document demonstrates how one can leverage Tinker (or the Tinker Cookbook) to train RL agents on GEM environments.
+
+In this guide, we present two implementations leveraging Tinker:
+1. Standard Training with Cookbook Utilities:
+   We introduce an adapter layer that enables GEM to work seamlessly with the Tinker Cookbook's training loop.
+2. Custom Training via Low-level API:
+   For advanced use cases, we provide a basic RL implementation that interacts directly with Tinker's low-level API, enabling maximum customization.
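+
+For orientation, GEM environments expose the familiar Gym-style `reset()`/`step()` loop with text observations and actions. A minimal interaction sketch (the environment id below is simply the default used by these recipes; rewards and termination depend on the chosen environment):
+
+```python
+import gem
+
+env = gem.make("game:GuessTheNumber-v0", seed=0)  # any GEM env_id works the same way
+obs, info = env.reset()                           # obs: text prompt for the policy
+next_obs, reward, terminated, truncated, info = env.step("I guess 7")  # action: text
+done = terminated or truncated
+```
+
+Both integrations below build on this interface.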
+
+
+### Getting Started
+
+Assuming you have already installed `tinker` and `tinker-cookbook`, install `gem` from PyPI and set your Tinker API key:
+```bash
+# install gem
+pip install -U gem-llm
+
+# export tinker api key
+export TINKER_API_KEY=
+```
+
+### Using Tinker Cookbook
+> Find the training curves in [[📈 WandB Logs](https://wandb.ai/cameron_chen/gem-tinker-cookbook)].
+
+To train agents with the Tinker Cookbook, we implemented an adaptation layer (`tinker_cookbook_adapter.py`) that exposes GEM environments through the Cookbook's RL interfaces. The entry point is `tinker_cookbook_train.py`.
+
+**Example 1: Training on Math Environments**
+
+```bash
+python tinker_cookbook_train.py env_id=math:Math12K groups_per_batch=64 group_size=16 learning_rate=2e-5 max_tokens=2048 model_name=Qwen/Qwen3-8B-Base env_kwargs_json='{"use_mp": false}'
+```
+
+Note:
+- You may train on different math environments by simply changing the `env_id` argument.
+- `env_kwargs_json='{"use_mp": false}'` is only required for math environments.
+
+**Example 2: Training on Reasoning Gym**
+
+```bash
+python tinker_cookbook_train.py env_id=rg:simple_equations groups_per_batch=64 group_size=8 learning_rate=2e-5 max_tokens=2048 model_name=Qwen/Qwen3-8B-Base
+```
+
+### Using Tinker
+> Find the training curves in [[📈 WandB Logs](https://wandb.ai/lkevinzc/gem-tinker_train)].
+
+We can also integrate GEM with Tinker directly, bypassing the abstractions provided by tinker-cookbook. The entry point is `tinker_train.py`, which implements REINFORCE with Return Batch Normalization (introduced in our [paper](https://arxiv.org/pdf/2510.01051)).
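+
+Return Batch Normalization standardizes the Monte Carlo returns across all transitions in a batch before they are used as advantages. A minimal sketch of that normalization step (mirroring the `use_rebn` branch in `tinker_train.py`; the return values below are illustrative):
+
+```python
+import torch
+
+returns = torch.tensor([1.0, 0.0, 0.5, 1.0])  # illustrative per-transition returns
+rebn = (returns - returns.mean()) / (returns.std() + 1e-9)  # used as advantages
+```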
+
+**Example 1: Training on Math Environments**
+
+```bash
+python tinker_train.py model_name=Qwen/Qwen3-8B-Base env_id=math:DeepScaleR40K num_env=128 max_tokens=8192
+```
+* In this example we observe the increasing response-length phenomenon (as in DeepSeek-R1-Zero), even with rank-32 LoRA training!
+
+**Example 2: Training on Math with Python Tool**
+
+```bash
+python tinker_train.py env_id=math:Math12K num_env=128 max_tokens=2048 template=no model_name=meta-llama/Llama-3.1-8B-Instruct env_wrappers=python_tool_no_int_reward_last_line_error,concat_chat gamma=1
+```
+
+**Example 3: Training on Multi-turn Language Games**
+
+```bash
+python tinker_train.py model_name=Qwen/Qwen3-8B-Base num_env=64 env_id=game:Sudoku-v0-easy max_tokens=1024 template=qwen3_game env_wrappers=concat
+```
+
+* Training is currently slow because at each turn the agent must reason and then act, and turns are processed sequentially. This can be optimized via prefix sharing, asynchronous sampling/learning, etc.
diff --git a/tinker_cookbook/recipes/gem_rl/tinker_cookbook_adapter.py b/tinker_cookbook/recipes/gem_rl/tinker_cookbook_adapter.py
new file mode 100644
index 0000000..5a50082
--- /dev/null
+++ b/tinker_cookbook/recipes/gem_rl/tinker_cookbook_adapter.py
@@ -0,0 +1,180 @@
+from __future__ import annotations
+
+import json
+import time
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+import chz
+import tinker
+from tinker_cookbook import renderers
+from tinker_cookbook.completers import StopCondition
+from tinker_cookbook.rl.types import (
+    Action,
+    Env,
+    EnvGroupBuilder,
+    Metrics,
+    Observation,
+    RLDataset,
+    RLDatasetBuilder,
+    StepResult,
+)
+from tinker_cookbook.tokenizer_utils import get_tokenizer
+
+import gem
+
+
+def apply_general_prompt(init_obs: str) -> str:
+    return (
+        f"Question: {init_obs}"
+        "\nPlease reason step by step, and put your final answer within \\boxed{}."
+    )
+
+
+def apply_no_template(init_obs: str) -> str:
+    return init_obs
+
+
+PROMPT_FACTORY = {"no": apply_no_template, "general": apply_general_prompt}
+
+
+class GemTinkerEnv(Env):
+    def __init__(
+        self,
+        env_gem: gem.Env,
+        init_obs: Observation,
+        renderer: renderers.Renderer,
+        prompt_type: str = "general",
+        convo_prefix: list[renderers.Message] | None = None,
+    ):
+        self.env_gem = env_gem
+        self.init_obs = init_obs
+        self.renderer = renderer
+        self.prompt_type = prompt_type
+        self.convo: list[renderers.Message] = list(convo_prefix or [])
+
+    @property
+    def stop_condition(self):
+        return self.renderer.get_stop_sequences()
+
+    async def initial_observation(self) -> tuple[Observation, StopCondition]:
+        self.convo.append(
+            {"role": "user", "content": PROMPT_FACTORY[self.prompt_type](str(self.init_obs))}
+        )
+        return self.renderer.build_generation_prompt(self.convo), self.stop_condition
+
+    async def step(self, action: Action) -> StepResult:
+        message, parse_success = self.renderer.parse_response(action)
+        text = message.get("content", "") if parse_success else ""
+        next_obs, reward, terminated, truncated, info = self.env_gem.step(text)
+        reward = float(reward)
+
+        metrics: Metrics = {}
+        for k, v in (info or {}).items():
+            if isinstance(v, (int, float)):
+                metrics[k] = v
+
+        done = terminated or truncated
+        if done:
+            next_ob = tinker.ModelInput.empty()
+            next_stop = self.stop_condition
+        else:
+            self.convo.append({"role": "assistant", "content": text})
+            self.convo.append({"role": "user", "content": next_obs})
+            next_ob = self.renderer.build_generation_prompt(self.convo)
+            next_stop = self.stop_condition
+
+        return StepResult(
+            reward=reward,
+            episode_done=done,
+            next_observation=next_ob,
+            next_stop_condition=next_stop,
+            metrics=metrics,
+        )
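+
+
+# Each GemEnvGroupBuilder owns one environment from the pool: make_envs() resets that
+# parent env once and spawns `group_size` copies sharing the same state, so all
+# rollouts in a group start from the same initial observation.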
+@dataclass(frozen=True)
+class GemEnvGroupBuilder(EnvGroupBuilder):
+    pool: list[gem.Env]
+    renderer: renderers.Renderer
+    prompt_type: str
+    group_size: int
+    env_id: str
+    convo_prefix: list[renderers.Message] | None = None
+    group_index: int = -1  # which env in the pool to use for this group
+
+    async def make_envs(self) -> Sequence[Env]:
+        assert 0 <= self.group_index < len(self.pool), (
+            "group_index should be within the range of the pool size"
+        )
+        assert hasattr(self.pool[0], "get_state"), (
+            "env must support get_state() to run in GemEnvGroupBuilder"
+        )
+
+        # duplicate the env for the group size
+        env_parent = self.pool[self.group_index]
+        init_obs, _ = env_parent.reset()
+        return [
+            GemTinkerEnv(
+                env_parent.spawn(same_state=True),
+                init_obs,
+                self.renderer,
+                self.prompt_type,
+                self.convo_prefix,
+            )
+            for _ in range(self.group_size)
+        ]
+
+    def logging_tags(self) -> list[str]:
+        return self.env_id.split(":")
+
+
+class GemDataset(RLDataset):
+    def __init__(self, builder_config: dict[str, Any], groups_per_batch: int, n_batches: int):
+        pool = builder_config["pool"]
+        assert len(set(env.seed for env in pool)) == len(pool), (
+            "All envs in the pool must have different seeds."
+        )
+
+        self.builder_config = builder_config
+        self.groups_per_batch = groups_per_batch
+        self.n_batches = n_batches
+
+    def get_batch(self, index: int) -> list[EnvGroupBuilder]:
+        return [
+            GemEnvGroupBuilder(group_index=i, **self.builder_config)
+            for i in range(self.groups_per_batch)
+        ]
+
+    def __len__(self) -> int:
+        return self.n_batches
+
+
+@chz.chz
+class GemDatasetBuilder(RLDatasetBuilder):
+    env_id: str
+    model_name_for_tokenizer: str
+    renderer_name: str
+    prompt_type: str = "general"
+    group_size: int
+    groups_per_batch: int
+    n_batches: int = 100
+    env_kwargs_json: str | None = None
+    convo_prefix: list[renderers.Message] | None = None
+
+    async def __call__(self) -> tuple[RLDataset, RLDataset | None]:
+        env_kwargs = json.loads(self.env_kwargs_json) if self.env_kwargs_json else {}
+        env_parent = gem.make(self.env_id, seed=int(time.time_ns()), **env_kwargs)
+        seed_parent = env_parent.seed  # type: ignore
+        pool = [env_parent.spawn(seed=i + seed_parent + 1) for i in range(self.groups_per_batch)]
+        tokenizer = get_tokenizer(self.model_name_for_tokenizer)
+        renderer = renderers.get_renderer(self.renderer_name, tokenizer=tokenizer)
+        builder_config = {
+            "pool": pool,
+            "renderer": renderer,
+            "prompt_type": self.prompt_type,
+            "group_size": self.group_size,
+            "env_id": self.env_id,
+            "convo_prefix": self.convo_prefix,
+        }
+        return GemDataset(builder_config, self.groups_per_batch, self.n_batches), None
diff --git a/tinker_cookbook/recipes/gem_rl/tinker_cookbook_train.py b/tinker_cookbook/recipes/gem_rl/tinker_cookbook_train.py
new file mode 100644
index 0000000..1b9b630
--- /dev/null
+++ b/tinker_cookbook/recipes/gem_rl/tinker_cookbook_train.py
@@ -0,0 +1,109 @@
+"""GEM ❤️ Tinker.
+
+A script for training agents on GEM environments with Tinker-Cookbook.
+"""
+
+import asyncio
+from datetime import datetime
+
+import chz
+from tinker_cookbook import cli_utils, model_info
+from tinker_cookbook.rl.train import AsyncConfig, Config, main
+
+from .tinker_cookbook_adapter import GemDatasetBuilder
+
+
+@chz.chz
+class CLIConfig:
+    # Model
+    model_name: str = "meta-llama/Llama-3.2-1B"
+    lora_rank: int = 32
+    renderer_name: str | None = None
+    load_checkpoint_path: str | None = None
+
+    # GEM env
+    env_id: str = "game:GuessTheNumber-v0"
+    env_kwargs_json: str | None = None  # e.g., '{"max_turns": 4}'
+
+    # Training
+    prompt_type: str = "general"
+    group_size: int = 4
+    groups_per_batch: int = 64
+    n_batches: int = 200
+    learning_rate: float = 1e-5
+    max_tokens: int = 256
+    kl_penalty_coef: float = 0.0
+    num_substeps: int = 1
+
+    # Logging
+    log_path: str | None = None
+    wandb_project: str | None = None
+    wandb_name: str | None = None
+    compute_post_kl: bool = False
+    eval_every: int = 0
+    save_every: int = 25
+
+    # Service
+    base_url: str | None = None
+
+    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "ask"
+    max_steps_off_policy: int | None = None
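+
+
+# Note: every CLIConfig field can be overridden from the command line as a key=value
+# pair (e.g. env_id=math:Math12K group_size=16), as used in the README examples.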
+async def cli_main(cli_config: CLIConfig):
+    renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name(
+        cli_config.model_name
+    )
+    model_name_sanitized = cli_config.model_name.replace("/", "-")
+    run_name = (
+        f"gem-{cli_config.env_id.replace(':', '_')}-{model_name_sanitized}-r{cli_config.lora_rank}-"
+        f"lr{cli_config.learning_rate}-g{cli_config.group_size}-b{cli_config.groups_per_batch}-"
+        f"{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
+    )
+    log_path = cli_config.log_path or f"./outputs-tinker/gem/{run_name}"
+    wandb_name = cli_config.wandb_name or run_name
+
+    dataset_builder = GemDatasetBuilder(
+        env_id=cli_config.env_id,
+        model_name_for_tokenizer=cli_config.model_name,
+        renderer_name=renderer_name,
+        prompt_type=cli_config.prompt_type,
+        group_size=cli_config.group_size,
+        groups_per_batch=cli_config.groups_per_batch,
+        n_batches=cli_config.n_batches,
+        env_kwargs_json=cli_config.env_kwargs_json,
+    )
+
+    cfg = Config(
+        learning_rate=cli_config.learning_rate,
+        dataset_builder=dataset_builder,
+        model_name=cli_config.model_name,
+        lora_rank=cli_config.lora_rank,
+        max_tokens=cli_config.max_tokens,
+        wandb_project=cli_config.wandb_project,
+        wandb_name=wandb_name,
+        log_path=log_path,
+        base_url=cli_config.base_url,
+        load_checkpoint_path=cli_config.load_checkpoint_path,
+        compute_post_kl=cli_config.compute_post_kl,
+        kl_penalty_coef=cli_config.kl_penalty_coef,
+        num_substeps=cli_config.num_substeps,
+        eval_every=cli_config.eval_every,
+        save_every=cli_config.save_every,
+        async_config=(
+            AsyncConfig(
+                max_steps_off_policy=cli_config.max_steps_off_policy,
+                groups_per_batch=cli_config.groups_per_batch,
+            )
+            if cli_config.max_steps_off_policy is not None
+            else None
+        ),
+    )
+
+    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)
+    await main(cfg)
+
+
+if __name__ == "__main__":
+    cfg = chz.entrypoint(CLIConfig)
+    asyncio.run(cli_main(cfg))
diff --git a/tinker_cookbook/recipes/gem_rl/tinker_train.py b/tinker_cookbook/recipes/gem_rl/tinker_train.py
new file mode 100644
index 0000000..cd33aff
--- /dev/null
+++ b/tinker_cookbook/recipes/gem_rl/tinker_train.py
@@ -0,0 +1,485 @@
+"""GEM ❤️ Tinker.
+
+A basic RL implementation to train agents on GEM environments using Tinker backends.
+"""
+
+import asyncio
+import json
+import logging
+import os
+import pprint
+import time
+from datetime import datetime
+from typing import Any, List, Literal, Dict
+
+import chz
+import numpy as np
+import tinker
+import torch
+import wandb
+from termcolor import colored
+from tinker import types
+from tinker.types.tensor_data import TensorData
+from transformers.models.auto.tokenization_auto import AutoTokenizer
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+import gem
+from gem.wrappers.wrapper_factory import get_wrapper_fns
+
+logger = logging.getLogger(__name__)
+logging.getLogger("httpx").setLevel(logging.WARN)
+
+
+@chz.chz
+class Config:
+    model_name: str = "Qwen/Qwen3-8B-Base"
+    batch_size: int = 128
+    learning_rate: float = 4e-5
+    lora_rank: int = 32
+    max_tokens: int = 2048
+    seed: int = 0
+    max_steps: int = 200
+    save_every: int = -1
+
+    env_id: str = "rg:simple_equations"
+    num_env: int = 4  # number of parallel environments
+    env_wrappers: str = "concat"  # wrappers are typically used to concat chat history, etc.
+    template: Literal["qwen3_general", "qwen3_game", "no"] = "qwen3_general"
+
+    gamma: float = 0.9
+    use_rebn: bool = True
+    loss_fn: Literal["importance_sampling", "ppo"] = "importance_sampling"
+
+    eval_env_id: str = "eval:AIME24"
+    eval_max_tokens: int = 8192
+    eval_n: int = 32
+    eval_temperature: float = 0.6
+    eval_top_p: float = 0.95
+    eval_every: int = -1
+
+    wandb_entity: str | None = None
+    wandb_project: str | None = None
+    wandb_name: str | None = None
+    log_dir: str | None = None
+
+
+# Define a lightweight renderer following tinker's renderer logic
+def apply_qwen3_game_template(observation: str) -> str:
+    return (
+        f"<|im_start|>user\nYou are playing language games. Make valid actions to win.\nObservation: {observation}"
+        "\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )
+
+
+def apply_qwen3_game_no_think_template(observation: str) -> str:
+    return (
+        f"<|im_start|>user\nYou are playing language games. Make valid actions to win.\nObservation: {observation}"
+        "\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )
+
+
+def apply_qwen3_general_template(question: str) -> str:
+    return (
+        f"<|im_start|>user\nQuestion: {question}"
+        "\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )
+
+
+def apply_no_template(observation: str) -> str:
+    return observation
+
+
+TEMPLATE_FACTORY = {
+    "qwen3_game": apply_qwen3_game_template,
+    "qwen3_general": apply_qwen3_general_template,
+    "no": apply_no_template,
+}
+
+
+def get_tokenizer(model_name: str) -> PreTrainedTokenizer:
+    # Avoid gating of Llama 3 models:
+    if model_name.startswith("meta-llama/Llama-3"):
+        model_name = "baseten/Meta-Llama-3-tokenizer"
+    return AutoTokenizer.from_pretrained(model_name, use_fast=True)
+
+
+async def save_checkpoint_async(
+    training_client: tinker.TrainingClient,
+    name: str,
+    log_path: str,
+    loop_state: dict[str, Any],
+    kind: Literal["state", "sampler", "both"] = "state",
+) -> dict[str, str]:
+    """Save model checkpoint.
+
+    Args:
+        training_client: Training client to save from
+        name: Name for the checkpoint
+        log_path: Path to the log directory, where we can find the checkpoints.jsonl file
+        loop_state: Extra loop state to record alongside the checkpoint paths
+        kind: Which checkpoint(s) to save: "state", "sampler", or "both"
+
+    Returns:
+        Dict mapping checkpoint kind (e.g. "state_path", "sampler_path") to the saved path
+    """
+    futures = {}
+    if kind in ["state", "both"]:
+        futures["state"] = await training_client.save_state_async(name)
+    if kind in ["sampler", "both"]:
+        futures["sampler"] = await training_client.save_weights_for_sampler_async(name)
+
+    results = {k: await v.result_async() for k, v in futures.items()}
+    paths = {k + "_path": v.path for k, v in results.items()}
+    logger.info(f"Saved checkpoints: {paths}")
+    full_dict = {"name": name, **loop_state, **paths}
+    with open(os.path.join(log_path, "checkpoints.jsonl"), "a") as f:
+        f.write(json.dumps(full_dict) + "\n")
+
+    return paths
+
+
+def prepare_training_datums(
+    transitions: List[dict], advantage_scaling: float = 1.0
+) -> List[types.Datum]:
+    training_datums = []
+    for transition in transitions:
+        ob_len_m1 = len(transition["obs_tokens"]) - 1  # -1 due to shifting
+        tokens = transition["obs_tokens"] + transition["act_tokens"]
+
+        input_tokens = tokens[:-1]
+        target_tokens = tokens[1:]
+        all_logprobs = [0.0] * ob_len_m1 + transition["act_logprobs"]
+        all_advantages = [0.0] * ob_len_m1 + [transition["return"]] * (
+            len(input_tokens) - ob_len_m1
+        )
+        assert (
+            len(input_tokens) == len(target_tokens) == len(all_logprobs) == len(all_advantages)
+        ), (
+            f"len(input_tokens): {len(input_tokens)}, len(target_tokens): {len(target_tokens)}, len(all_logprobs): {len(all_logprobs)}, len(all_advantages): {len(all_advantages)}"
+        )
+
+        datum = types.Datum(
+            model_input=types.ModelInput.from_ints(tokens=input_tokens),
+            loss_fn_inputs={
+                "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)),
+                "logprobs": TensorData.from_torch(torch.tensor(all_logprobs)),
+                "advantages": TensorData.from_torch(
+                    torch.tensor(all_advantages) * advantage_scaling
+                ),
+            },
+        )
+        training_datums.append(datum)
+    return training_datums
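+
+
+# Illustration of the alignment in prepare_training_datums (hypothetical sizes): with
+# obs_tokens = [p1, p2, p3] and act_tokens = [a1, a2], input_tokens = [p1, p2, p3, a1]
+# and target_tokens = [p2, p3, a1, a2]; the first ob_len_m1 = 2 positions (prompt
+# targets) carry zero logprob/advantage, while positions whose targets are sampled
+# tokens carry the sampled logprobs and the (scaled) episode return as the advantage.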
+
+
+async def collect_episode(
+    sampling_client, sampling_params, env: gem.Env, config, tokenizer, reset_idx=None
+):
+    transitions = []
+    # reset_idx may legitimately be 0, so check against None rather than truthiness
+    kwargs = {"idx": reset_idx} if reset_idx is not None else {}
+    obs, _ = env.reset(**kwargs)
+    while True:
+        # 1) prepare observation
+        obs = TEMPLATE_FACTORY[config.template](obs)  # templated string
+        obs_tokens = tokenizer.encode(obs, add_special_tokens=False)
+
+        # 2) sample an action from the policy
+        try:
+            sample_result = await sampling_client.sample_async(
+                prompt=types.ModelInput.from_ints(tokens=obs_tokens),
+                num_samples=1,
+                sampling_params=sampling_params,
+            )
+        except Exception:
+            transitions = []
+            break
+        sampled_tokens = sample_result.sequences[0].tokens
+        sampled_logprobs = sample_result.sequences[0].logprobs
+        action = tokenizer.decode(sampled_tokens)
+        unfinished = sample_result.sequences[0].stop_reason == "length"
+
+        # 3) step the environment
+        next_obs, reward, terminated, truncated, info = env.step(action)
+        done = terminated | truncated
+        obs = next_obs
+
+        # 4) save into buffer
+        transitions.append(
+            {
+                "obs_tokens": obs_tokens,
+                "act_tokens": sampled_tokens,
+                "act_logprobs": sampled_logprobs,
+                "obs_text": tokenizer.decode(obs_tokens),
+                "act_text": tokenizer.decode(sampled_tokens),
+                "reward": reward,
+                "done": done,
+                "unfinished": unfinished,
+                "info": info,
+            }
+        )
+
+        if done:
+            break
+    return transitions
+
+
+async def main(config: Config):
+    # Setup logging
+    wandb_name = config.wandb_name or config.model_name.split("/")[-1] + f"_{config.env_id}"
+    wandb_name += "_" + datetime.now().strftime("%m%dT%H:%M:%S")
+    save_path = os.path.join("./tinker_output", wandb_name)
+    os.makedirs(save_path, exist_ok=True)
+
+    wandb.init(
+        entity=config.wandb_entity,
+        project=config.wandb_project,
+        config=chz.asdict(config),
+        dir=str(config.log_dir) if config.log_dir else None,
+        name=wandb_name,
+    )
+
+    # Get tokenizer
+    tokenizer = get_tokenizer(config.model_name)
+
+    # Setup environment for training
+    wrappers = get_wrapper_fns(config.env_wrappers, tokenizer=tokenizer)
+    # Init one env first and check whether it has a dataset; if so, we avoid loading from HF
+    # multiple times by directly providing the dataset when creating the remaining envs
+    # (we could also use the gem.Env.spawn API).
+    envs = [gem.make(config.env_id, seed=int(time.time_ns()), use_mp=False)]
+    for i in range(config.num_env - 1):
+        dataset = envs[0].dataset if hasattr(envs[0], "dataset") else None  # type: ignore
+        envs.append(
+            gem.make(
+                config.env_id,
+                seed=int(time.time_ns()) * i,
+                dataset=dataset,
+                use_mp=False,
+            )
+        )
+    for i in range(len(envs)):
+        for wrapper in wrappers:
+            envs[i] = wrapper(envs[i])
+
+    # Setup environment for in-distribution eval
+    eval_envs = [gem.make(config.eval_env_id, seed=int(time.time_ns()), use_mp=False, eval=True)]
+    skip_eval = not hasattr(envs[0], "dataset")
+    eval_data_size = 0
+    if not skip_eval:
+        eval_data_size = len(eval_envs[0].dataset)  # type: ignore
+    for i in range((config.eval_n * eval_data_size) - 1):
+        eval_envs.append(
+            gem.make(
+                config.eval_env_id,
+                seed=int(time.time_ns()) * i,
+                dataset=eval_envs[0].dataset,  # type: ignore
+                use_mp=False,
+                eval=True,
+            )
+        )
+    for i in range(len(eval_envs)):
+        for wrapper in wrappers:
+            eval_envs[i] = wrapper(eval_envs[i])
+
+    # Setup agent (tinker training client)
+    service_client = tinker.ServiceClient()
+    training_client = await service_client.create_lora_training_client_async(
+        base_model=config.model_name, rank=config.lora_rank
+    )
+    sampling_params = tinker.types.SamplingParams(
+        max_tokens=config.max_tokens,
+    )
+    eval_sampling_params = tinker.types.SamplingParams(
+        max_tokens=config.eval_max_tokens,
+        temperature=config.eval_temperature,
+        top_p=config.eval_top_p,
+    )
+    adam_params = types.AdamParams(
+        learning_rate=config.learning_rate, beta1=0.9, beta2=0.95, eps=1e-8
+    )
+
+    # Start agent-environment loop (Algo: https://arxiv.org/pdf/2510.01051#page=15.10):
+    for policy_iteration_step in range(config.max_steps):
+        print("=" * 10 + f" Step {policy_iteration_step} " + "=" * 10)
+        metrics: Dict[str, Any] = {"step": policy_iteration_step}
+
+        # create sampler
+        sampling_path = (
+            training_client.save_weights_for_sampler(name=f"{policy_iteration_step:06d}")
+            .result()
+            .path
+        )
+        sampling_client = service_client.create_sampling_client(model_path=sampling_path)
+
+        # save model
+        if (
+            config.save_every > 0
+            and policy_iteration_step > 0
+            and policy_iteration_step % config.save_every == 0
+        ):
+            await save_checkpoint_async(
+                training_client,
+                f"{policy_iteration_step:06d}",
+                log_path=save_path,
+                kind="state",
+                loop_state={"policy_iteration_step": policy_iteration_step},
+            )
+
+        # eval model
+        if config.eval_every > 0 and policy_iteration_step % config.eval_every == 0:
+            if skip_eval:
+                print("⚠️ Evaluation environment doesn't have .dataset attribute, skipping eval.")
+            else:
+                print(f"🔎 Start evaluation at step {policy_iteration_step}")
+                st = time.time()
+                eval_episodes = await asyncio.gather(
+                    *[
+                        collect_episode(
+                            sampling_client,
+                            eval_sampling_params,
+                            env,
+                            config,
+                            tokenizer,
+                            i,
+                        )
+                        for env, i in zip(eval_envs, list(range(eval_data_size)) * config.eval_n)
+                    ]
+                )
+                eval_episodes = [x for x in eval_episodes if x != []]
+                for drop_key in ["obs_tokens", "act_tokens", "act_logprobs"]:
+                    _ = [t.pop(drop_key) for ep in eval_episodes for t in ep]
+                json.dump(
+                    eval_episodes,
+                    open(
+                        os.path.join(save_path, f"eval-{policy_iteration_step:06d}.json"),
+                        "w",
+                    ),
+                    indent=4,
+                )
+                metrics["time/eval"] = time.time() - st
+                metrics["eval/episode_return"] = np.mean(
+                    [
+                        sum(transition["reward"] for transition in episode)
+                        for episode in eval_episodes
+                    ]
+                )
+                metrics["eval/support"] = len(eval_episodes)
+
+        # collect episodes with parallel environments
+        print(f"🎲 Start collecting episodes at step {policy_iteration_step}")
+        st = time.time()
+        episodes_buffer = []
+        while True:
+            batch_episodes = await asyncio.gather(
+                *[
+                    collect_episode(sampling_client, sampling_params, env, config, tokenizer)
+                    for env in envs
+                ]
+            )
+            batch_episodes = [x for x in batch_episodes if x != []]
+            episodes_buffer.extend(batch_episodes)
+            if sum([len(ep) for ep in episodes_buffer]) >= config.batch_size:
+                break
+        metrics["time/sample"] = time.time() - st
+        metrics["sampler/unfinished_rollout"] = np.mean(
+            [
+                np.mean([transition["unfinished"] for transition in episode])
+                for episode in episodes_buffer
+            ]
+        )
+        metrics["sampler/episode_return"] = np.mean(
+            [sum(transition["reward"] for transition in episode) for episode in episodes_buffer]
+        )
+        metrics["sampler/num_turns_per_episode"] = np.mean(
+            [len(episode) for episode in episodes_buffer]
+        )
+        gen_tokens_lens = [
+            sum(len(transition["act_tokens"]) for transition in episode)
+            for episode in episodes_buffer
+        ]
+        metrics["sampler/action_num_tokens"] = np.mean(gen_tokens_lens)
+        metrics["sampler/num_episodes"] = len(episodes_buffer)
+
+        # print at most two episodes for debugging purposes
+        for n, episode in enumerate(episodes_buffer):
+            print(f"----- episode {n} -----")
+            for t, transition in enumerate(episode):
+                obs = tokenizer.decode(transition["obs_tokens"])
+                act = tokenizer.decode(transition["act_tokens"])
+                obs = obs[:196] + "\n...\n" + obs[-200:] if len(obs) > 396 else obs
+                act = act[:196] + "\n...\n" + act[-200:] if len(act) > 396 else act
+                print(f"turn={t + 1}")
+                print(colored(obs, "blue"))
+                print(colored(act, "light_red", attrs=["bold"]))
+                print(
+                    colored(
+                        "reward=" + str(transition["reward"]),
+                        "light_magenta",
+                        attrs=["bold"],
+                    )
+                )
+            if n > 0:
+                break
+
+        # prepare transitions
+        transitions = []
+        for episode in episodes_buffer:
+            # One transition typically consists of (s, a, r).
+            # Here we augment it with a Monte Carlo return to
+            # serve as the advantage estimate.
+            rewards = [transition["reward"] for transition in episode]
+            # Compute discounted returns: G_t = r_t + gamma * G_{t+1}
+            cur = 0.0
+            for i in reversed(range(len(rewards))):
+                cur = rewards[i] + config.gamma * cur
+                episode[i]["return"] = cur
+            transitions.extend(episode)
+
+        # return batch normalization (https://arxiv.org/pdf/2510.01051#page=5.73 shows it's effective)
+        if config.use_rebn:
+            returns = torch.tensor([transition["return"] for transition in transitions]).float()
+            returns = (returns - returns.mean()) / (returns.std() + 1e-9)
+            for i, transition in enumerate(transitions):
+                transition["return"] = returns[i].item()
+
+        # prepare training datums compatible with the Tinker API
+        training_datums = prepare_training_datums(transitions, 1 / len(transitions))
+
+        # training step
+        print(f"🎈 Start training at step {policy_iteration_step}")
+        st = time.time()
+        fwd_bwd_future = await training_client.forward_backward_async(
+            training_datums,
+            loss_fn=config.loss_fn,
+        )
+        optim_step_future = await training_client.optim_step_async(adam_params)
+        fwd_bwd_result = await fwd_bwd_future.result_async()
+        _ = await optim_step_future.result_async()
+        metrics["time/train"] = time.time() - st
+        metrics["train/n_samples"] = len(training_datums)
+
+        # compute policy entropy and sampler-learner difference
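+        # kl_sample_train_v1 and kl_sample_train_v2 below are the usual k1 (mean log-ratio)
+        # and k2 (half mean squared log-ratio) KL estimators between the sampling policy
+        # and the learner's forward pass, quantifying sampler-learner mismatch.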
+        act_token_logprobs = []
+        act_token_diffs = []
+        for i in range(config.batch_size):
+            transition = transitions[i]
+            train_output = fwd_bwd_result.loss_fn_outputs[i]
+            act_token_logprobs.extend(transition["act_logprobs"])
+            act_token_diffs.append(
+                torch.tensor(transition["act_logprobs"])
+                - train_output["logprobs"].to_torch()[-len(transition["act_logprobs"]) :]
+            )
+        act_token_diffs = torch.cat(act_token_diffs)
+        kl_sample_train_v1 = act_token_diffs.mean().item()
+        kl_sample_train_v2 = 0.5 * (act_token_diffs**2).mean().item()
+        metrics["sampler/token_entropy"] = -torch.tensor(act_token_logprobs).mean().item()
+        metrics["train/kl_sample_train_v1"] = kl_sample_train_v1
+        metrics["train/kl_sample_train_v2"] = kl_sample_train_v2
+        metrics.update(**{f"train/{k}": v for k, v in fwd_bwd_result.metrics.items()})
+
+        pprint.pprint(metrics)
+        wandb.log(metrics)
+
+    wandb.finish()
+
+
+if __name__ == "__main__":
+    asyncio.run(main(chz.entrypoint(Config)))