diff --git a/tinker_cookbook/recipes/verifiers_rl/evaluate.py b/tinker_cookbook/recipes/verifiers_rl/evaluate.py
index 2949702..4364d33 100644
--- a/tinker_cookbook/recipes/verifiers_rl/evaluate.py
+++ b/tinker_cookbook/recipes/verifiers_rl/evaluate.py
@@ -57,22 +57,45 @@ def log_results(
     print(out)
 
 
-def evaluate(
+async def evaluate(
     vf_env_id: str,
     vf_env_args: dict,
-    model_name: str,
+    model_name: str | None,
     num_examples: int,
     rollouts_per_example: int,
     max_concurrent: int,
     max_tokens: int,
     temperature: float,
+    model_path: str | None = None,
 ):
+    service = tinker.ServiceClient()
+
+    # If model_path is provided, get the base model from the training run
+    if model_path is not None:
+        rest_client = service.create_rest_client()
+        training_run = await rest_client.get_training_run_by_tinker_path_async(model_path)
+        if model_name:
+            if model_name != training_run.base_model:
+                raise ValueError(
+                    f"Model name {model_name} does not match training run base model {training_run.base_model}"
+                )
+        else:
+            model_name = training_run.base_model
+
+    if model_name is None:
+        raise ValueError("model_name or model_path must be provided")
+
     env = vf.load_environment(vf_env_id, **vf_env_args)
     tokenizer = get_tokenizer(model_name)
     renderer_name = model_info.get_recommended_renderer_name(model_name)
     renderer = renderers.get_renderer(renderer_name, tokenizer)
-    service = tinker.ServiceClient()
-    sampling = service.create_sampling_client(base_model=model_name)
+
+    # Create sampling client from checkpoint path or base model
+    if model_path:
+        sampling = service.create_sampling_client(model_path=model_path, base_model=model_name)
+    else:
+        sampling = service.create_sampling_client(base_model=model_name)
+
     client = TinkerAsyncOpenAIClient(sampling, renderer, tokenizer)
     start_time = time.time()
     results = env.evaluate_sync(
@@ -95,11 +118,13 @@ def evaluate(
         rollouts_per_example,
         end_time - start_time,
     )
+    return results
 
 
 @chz.chz
 class CLIConfig:
-    model_name: str = "Qwen/Qwen3-4B-Instruct-2507"
+    model_name: str | None = None  # Base model name (auto-detected from checkpoint if not provided)
+    model_path: str | None = None  # Path to checkpoint (e.g., from checkpoints.jsonl sampler_path)
     vf_env_id: str = "reverse-text"
     vf_env_args: str | None = None  # JSON string
     num_examples: int = 5
@@ -111,7 +136,7 @@ class CLIConfig:
 
 async def cli_main(cfg: CLIConfig):
     env_args = json.loads(cfg.vf_env_args) if cfg.vf_env_args else {}
-    return evaluate(
+    return await evaluate(
         vf_env_id=cfg.vf_env_id,
         vf_env_args=env_args,
         model_name=cfg.model_name,
@@ -120,6 +145,7 @@ async def cli_main(cfg: CLIConfig):
         max_concurrent=cfg.max_concurrent,
         max_tokens=cfg.max_tokens,
         temperature=cfg.temperature,
+        model_path=cfg.model_path,
     )