Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions tinker_cookbook/recipes/verifiers_rl/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,45 @@ def log_results(
print(out)


def evaluate(
async def evaluate(
vf_env_id: str,
vf_env_args: dict,
model_name: str,
model_name: str | None,
num_examples: int,
rollouts_per_example: int,
max_concurrent: int,
max_tokens: int,
temperature: float,
model_path: str | None = None,
):
service = tinker.ServiceClient()

# If model_path is provided, get the base model from the training run
if model_path is not None:
rest_client = service.create_rest_client()
training_run = await rest_client.get_training_run_by_tinker_path_async(model_path)
if model_name:
if model_name != training_run.base_model:
raise ValueError(
f"Model name {model_name} does not match training run base model {training_run.base_model}"
)
else:
model_name = training_run.base_model

if model_name is None:
raise ValueError("model_name or model_path must be provided")

env = vf.load_environment(vf_env_id, **vf_env_args)
tokenizer = get_tokenizer(model_name)
renderer_name = model_info.get_recommended_renderer_name(model_name)
renderer = renderers.get_renderer(renderer_name, tokenizer)
service = tinker.ServiceClient()
sampling = service.create_sampling_client(base_model=model_name)

# Create sampling client from checkpoint path or base model
if model_path:
sampling = service.create_sampling_client(model_path=model_path, base_model=model_name)
else:
sampling = service.create_sampling_client(base_model=model_name)

client = TinkerAsyncOpenAIClient(sampling, renderer, tokenizer)
start_time = time.time()
results = env.evaluate_sync(
Expand All @@ -95,11 +118,13 @@ def evaluate(
rollouts_per_example,
end_time - start_time,
)
return results


@chz.chz
class CLIConfig:
model_name: str = "Qwen/Qwen3-4B-Instruct-2507"
model_name: str | None = None # Base model name (auto-detected from checkpoint if not provided)
model_path: str | None = None # Path to checkpoint (e.g., from checkpoints.jsonl sampler_path)
vf_env_id: str = "reverse-text"
vf_env_args: str | None = None # JSON string
num_examples: int = 5
Expand All @@ -111,7 +136,7 @@ class CLIConfig:

async def cli_main(cfg: CLIConfig):
env_args = json.loads(cfg.vf_env_args) if cfg.vf_env_args else {}
return evaluate(
return await evaluate(
vf_env_id=cfg.vf_env_id,
vf_env_args=env_args,
model_name=cfg.model_name,
Expand All @@ -120,6 +145,7 @@ async def cli_main(cfg: CLIConfig):
max_concurrent=cfg.max_concurrent,
max_tokens=cfg.max_tokens,
temperature=cfg.temperature,
model_path=cfg.model_path,
)


Expand Down