26 changes: 25 additions & 1 deletion verifiers/envs/environment.py
@@ -33,6 +33,7 @@
Messages,
MessageType,
ModelResponse,
RolloutCallback,
RolloutInput,
RolloutTiming,
SamplingArgs,
@@ -614,6 +615,7 @@ async def generate(
save_results: bool = False,
save_every: int = -1,
use_tqdm: bool = True,
on_rollout_complete: RolloutCallback | None = None,
) -> GenerateOutputs:
"""
Generate rollouts for a set of inputs by group.
@@ -668,23 +670,43 @@
# process groups as they complete
pbar = None
if use_tqdm:
from tqdm import tqdm
from tqdm.asyncio import tqdm

pbar = tqdm(
total=len(group_list),
desc=f"Processing {len(group_list)} groups ({len(inputs_list)} total rollouts)",
leave=True,
position=0,
dynamic_ncols=True,
)

groups_completed = 0
all_states: list[State] = []
rollout_count = 0
running_reward_sum = 0.0

try:
for coro in asyncio.as_completed(group_tasks.keys()):
group_states = await coro
all_states.extend(group_states)
groups_completed += 1

if on_rollout_complete is not None:
for state in group_states:
rollout_count += 1
running_reward_sum += state.get("reward", 0.0)
await on_rollout_complete(state)

if pbar is not None:
pbar.update(1)
if on_rollout_complete is not None and rollout_count > 0:
avg_reward = running_reward_sum / rollout_count
pbar.set_postfix(
{
"avg_reward": f"{avg_reward:.3f}",
"saved_rollouts": rollout_count,
}
)

# save intermediate results
if (
@@ -790,6 +812,7 @@ async def evaluate(
state_columns: list[str] | None = None,
save_results: bool = False,
save_every: int = -1,
on_rollout_complete: RolloutCallback | None = None,
**kwargs,
) -> GenerateOutputs:
"""
@@ -808,6 +831,7 @@
state_columns=state_columns,
save_results=save_results,
save_every=save_every,
on_rollout_complete=on_rollout_complete,
**kwargs,
)

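For orientation, a minimal sketch of wiring the new `on_rollout_complete` hook into an evaluation. The environment id, model name, and client setup are illustrative assumptions, not part of this diff:

```python
import asyncio

from openai import AsyncOpenAI

import verifiers as vf
from verifiers.types import State


async def main() -> None:
    vf_env = vf.load_environment("example-env")  # hypothetical environment id
    client = AsyncOpenAI()

    async def log_reward(state: State) -> None:
        # Invoked once per finished rollout, as soon as its group completes.
        print(f"example {state.get('example_id')}: reward={state.get('reward', 0.0):.3f}")

    await vf_env.evaluate(
        client=client,
        model="gpt-4.1-mini",  # any OpenAI-compatible model id
        num_examples=10,
        rollouts_per_example=2,
        on_rollout_complete=log_reward,
    )


asyncio.run(main())
```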
2 changes: 2 additions & 0 deletions verifiers/types.py
@@ -127,6 +127,8 @@ def get(self, key: str, default: Any = None) -> Any:
return default


RolloutCallback = Callable[[State], Awaitable[None]]

# oai tools
JsonPrimitive = Literal["string", "number", "integer", "boolean", "array", "object"]

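The alias accepts any async callable taking a `State`, including callable objects that carry their own state. A small sketch, assuming only the alias above:

```python
from verifiers.types import RolloutCallback, State


class RewardTracker:
    """Accumulates a running reward total; instances satisfy RolloutCallback."""

    def __init__(self) -> None:
        self.count = 0
        self.total = 0.0

    async def __call__(self, state: State) -> None:
        self.count += 1
        self.total += state.get("reward", 0.0)


tracker: RolloutCallback = RewardTracker()
```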
141 changes: 111 additions & 30 deletions verifiers/utils/eval_utils.py
@@ -1,3 +1,4 @@
import asyncio
import importlib.util
import json
import logging
@@ -11,15 +12,76 @@
from datasets.utils import logging as ds_logging

import verifiers as vf
from verifiers.types import Endpoints, EvalConfig, GenerateMetadata, GenerateOutputs
from verifiers.types import (
Endpoints,
EvalConfig,
GenerateMetadata,
GenerateOutputs,
State,
)
from verifiers.utils.client_utils import setup_client
from verifiers.utils.logging_utils import print_prompt_completions_sample
from verifiers.utils.message_utils import messages_to_printable, sanitize_tool_calls
from verifiers.utils.message_utils import messages_to_printable
from verifiers.utils.path_utils import get_eval_results_path
from verifiers.utils.rollout_utils import serialize_rollout

logger = logging.getLogger(__name__)


class _StreamWriter:
def __init__(
self,
results_path: Path,
total_rollouts: int,
state_columns: list[str] | None = None,
):
self.results_path = results_path
self.total_rollouts = total_rollouts
self.state_columns = state_columns or []
self.completed_count = 0
self.running_reward_sum = 0.0
self.running_metrics_sum: dict[str, float] = {}

self.results_path.parent.mkdir(parents=True, exist_ok=True)

def log_rollout(self, state: State) -> None:
self.completed_count += 1
reward = state.get("reward", 0.0)
self.running_reward_sum += reward

metrics = state.get("metrics", {})
for k, v in metrics.items():
self.running_metrics_sum[k] = self.running_metrics_sum.get(k, 0.0) + v

def _write_sync(self, json_line: str) -> None:
with open(self.results_path, "a") as f:
f.write(json_line)

async def write_rollout_jsonl(self, state: State) -> None:
rollout_data = serialize_rollout(
state,
state_columns=self.state_columns,
include_timestamp=True,
)
json_line = json.dumps(rollout_data) + "\n"
await asyncio.to_thread(self._write_sync, json_line)

def get_running_stats(self) -> dict[str, float]:
if self.completed_count == 0:
return {}

stats = {
"avg_reward": self.running_reward_sum / self.completed_count,
"completed": self.completed_count,
"total": self.total_rollouts,
}

for k, v in self.running_metrics_sum.items():
stats[f"avg_{k}"] = v / self.completed_count

return stats
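
`write_rollout_jsonl` offloads the blocking file append to a worker thread via `asyncio.to_thread`, so disk latency never stalls the event loop, and each appended line is a self-contained JSON object. Partial results can therefore be inspected while an evaluation is still in flight. A small sketch, where the output path is an assumption (the real one comes from `get_eval_results_path`):

```python
import json
from pathlib import Path

# Hypothetical location; run_evaluation writes next to the other eval artifacts.
results_jsonl = Path("outputs/evals/example-env/results.jsonl")

rewards = []
with open(results_jsonl) as f:
    for line in f:
        rollout = json.loads(line)  # one completed rollout per line
        rewards.append(rollout["reward"])

if rewards:
    print(f"{len(rewards)} rollouts so far, avg reward {sum(rewards) / len(rewards):.3f}")
```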


def load_endpoints(endpoints_path: str):
try:
endpoints_path_obj = Path(endpoints_path)
@@ -115,6 +177,24 @@ async def run_evaluation(config: EvalConfig) -> GenerateOutputs:
logger.info(
f"Configuration: num_examples={config.num_examples}, rollouts_per_example={config.rollouts_per_example}, max_concurrent={config.max_concurrent}"
)

# Set up the streaming handler for incremental saving
streaming_handler = None
rollout_callback = None

if config.save_results:
total_rollouts = config.num_examples * config.rollouts_per_example
results_jsonl_path = results_path / "results.jsonl"
streaming_handler = _StreamWriter(
results_path=results_jsonl_path,
total_rollouts=total_rollouts,
state_columns=config.state_columns,
)

async def rollout_callback(state: State) -> None:
streaming_handler.log_rollout(state)
await streaming_handler.write_rollout_jsonl(state)

start_time = time.time()
results = await vf_env.evaluate(
client=client,
@@ -129,13 +209,28 @@
state_columns=config.state_columns,
save_results=config.save_results,
save_every=config.save_every,
on_rollout_complete=rollout_callback,
)
end_time = time.time()
logger.info(f"Evaluation completed in {end_time - start_time:.2f} seconds")

if config.print_results:
print_results(results)

if config.save_results:
save_rollout_results(results, config.save_to_hf_hub, config.hf_hub_dataset_name)
metadata_dict = sanitize_metadata(results["metadata"])
with open(results_path / "metadata.json", "w") as f:
json.dump(metadata_dict, f)
logger.info(f"Metadata saved to {results_path / 'metadata.json'}")

if config.save_to_hf_hub:
dataset = Dataset.from_json(str(results_path / "results.jsonl"))
dataset_name = config.hf_hub_dataset_name or get_hf_hub_dataset_name(
results
)
dataset.push_to_hub(dataset_name)
logger.info(f"Dataset pushed to Hugging Face Hub: {dataset_name}")

return results
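
Since the JSONL file doubles as the Hub upload source, a pushed dataset round-trips through the standard `datasets` API. The repo id below is an assumption (`get_hf_hub_dataset_name` derives the real one when `hf_hub_dataset_name` is unset):

```python
from datasets import load_dataset

ds = load_dataset("your-username/example-env-results", split="train")  # hypothetical repo id
print(ds.column_names)  # example_id, prompt, completion, task, reward, ...
print(ds[0]["reward"])
```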


@@ -162,35 +257,21 @@ def get_hf_hub_dataset_name(results: GenerateOutputs) -> str:


def make_dataset(results: GenerateOutputs, **kwargs) -> Dataset:
clean_prompts = [messages_to_printable(p) for p in results["prompt"]]
clean_prompts = [sanitize_tool_calls(p) for p in clean_prompts]
clean_completions = [messages_to_printable(c) for c in results["completion"]]
clean_completions = [sanitize_tool_calls(c) for c in clean_completions]
save_info = any(info != {} for info in results["info"])
save_answer = any(answer != "" for answer in results["answer"])
state_columns = results["metadata"]["state_columns"]

serialized_rollouts = [
serialize_rollout(state, state_columns=state_columns, include_timestamp=False)
for state in results["state"]
]

if not serialized_rollouts:
return Dataset.from_dict({})

all_keys = {key for rollout in serialized_rollouts for key in rollout.keys()}

results_dict = {
"example_id": results["example_id"],
"prompt": clean_prompts,
"completion": clean_completions,
"task": results["task"],
"reward": results["reward"],
"generation_ms": [s["timing"]["generation_ms"] for s in results["state"]],
"scoring_ms": [s["timing"]["scoring_ms"] for s in results["state"]],
"total_ms": [s["timing"]["total_ms"] for s in results["state"]],
key: [rollout.get(key) for rollout in serialized_rollouts] for key in all_keys
}
if save_info:
results_dict["info"] = results["info"]
if save_answer:
results_dict["answer"] = results["answer"]
for k in results["metrics"]:
v = results["metrics"][k]
results_dict[k] = v

# Add selected state columns if specified
state_columns = results["metadata"]["state_columns"]
if state_columns:
for col in state_columns:
results_dict[col] = [s.get(col) for s in results["state"]]

return Dataset.from_dict(results_dict)
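
Note that `make_dataset` now takes the union of keys across serialized rollouts, so a field that only some rollouts produced becomes a column padded with `None` for the rest. A toy illustration of that padding rule, standalone rather than real serializer output:

```python
serialized_rollouts = [
    {"example_id": 0, "reward": 1.0, "answer": "42"},
    {"example_id": 1, "reward": 0.0},  # no "answer" key
]

all_keys = {key for rollout in serialized_rollouts for key in rollout}
results_dict = {
    key: [rollout.get(key) for rollout in serialized_rollouts] for key in all_keys
}
assert results_dict["answer"] == ["42", None]
```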

55 changes: 55 additions & 0 deletions verifiers/utils/rollout_utils.py
@@ -0,0 +1,55 @@
from datetime import datetime
from typing import Any

from verifiers.types import State
from verifiers.utils.message_utils import messages_to_printable, sanitize_tool_calls


def serialize_rollout(
state: State,
state_columns: list[str] | None = None,
include_timestamp: bool = False,
) -> dict[str, Any]:
prompt = state.get("prompt", "")
completion = state.get("completion", "")
clean_prompt = sanitize_tool_calls(messages_to_printable(prompt))
clean_completion = sanitize_tool_calls(messages_to_printable(completion))

rollout_data: dict[str, Any] = {
"example_id": state.get("example_id", 0),
"prompt": clean_prompt,
"completion": clean_completion,
"task": state.get("task", ""),
"reward": state.get("reward", 0.0),
}

if include_timestamp:
rollout_data["timestamp"] = datetime.now().isoformat()

timing = state.get("timing", {})
if timing:
rollout_data["generation_ms"] = timing.get("generation_ms", 0.0)
rollout_data["scoring_ms"] = timing.get("scoring_ms", 0.0)
rollout_data["total_ms"] = timing.get("total_ms", 0.0)

metrics = state.get("metrics", {})
for k, v in metrics.items():
rollout_data[k] = v

answer = state.get("answer", "")
if answer:
rollout_data["answer"] = answer

info = state.get("info", {})
if info:
rollout_data["info"] = info

if state_columns:
for col in state_columns:
if col in state:
if col == "responses":
rollout_data[col] = [r.model_dump() for r in state[col]]
else:
rollout_data[col] = state[col]

return rollout_data
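
For reference, a quick sketch of what `serialize_rollout` emits for a minimal state; the field values are invented, and the state is written as a plain mapping for illustration:

```python
state = {
    "example_id": 0,
    "prompt": [{"role": "user", "content": "What is 2 + 2?"}],
    "completion": [{"role": "assistant", "content": "4"}],
    "task": "arithmetic",
    "reward": 1.0,
    "timing": {"generation_ms": 120.0, "scoring_ms": 5.0, "total_ms": 125.0},
    "metrics": {"exact_match": 1.0},
}

row = serialize_rollout(state, include_timestamp=True)
# row includes example_id, prompt, completion, task, reward, a timestamp,
# the three timing fields, and exact_match flattened to a top-level key.
```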