Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

**Core**

- Instructor library (`instructor>=1.14.0`) as core dependency for structured LLM output handling with automatic validation and retries. (PR: #49)
- `response_model` parameter on `ModelAdapter.chat()` — pass a Pydantic `BaseModel` class to get validated structured outputs via `ChatResponse.structured_response`. Supported on OpenAI, Anthropic, Google GenAI, and LiteLLM adapters. (PR: #49)
- `maseval.core.instructor` module with `create_instructor_client()` and `flatten_model_schema()` helpers for creating instructor-patched clients and generating provider-compatible JSON schemas. (PR: #49)

### Changed

**Core**

- Simulators (`ToolLLMSimulator`, `UserLLMSimulator`, `AgenticUserLLMSimulator`) now use instructor for structured output parsing with automatic validation and retries, replacing manual JSON extraction and retry loops. (PR: #49)

**Benchmarks**

- Tau2 benchmark uses `flatten_model_schema()` from `maseval.core.instructor` for tool parameter schema generation, replacing the manual `_flatten_schema()` function. (PR: #49)

- Usage and cost tracking via `Usage` and `TokenUsage` data classes. `ModelAdapter` tracks token usage automatically after each `chat()` call. Components that implement `UsageTrackableMixin` are collected via `gather_usage()`. Live totals available during benchmark runs via `benchmark.usage` (grand total) and `benchmark.usage_by_component` (per-component breakdowns). Post-hoc analysis via `UsageReporter.from_reports(benchmark.reports)` with breakdowns by task, component, or model. (PR: #45)
- Pluggable cost calculation via `CostCalculator` protocol. `StaticPricingCalculator` computes cost from user-supplied per-token rates. `LiteLLMCostCalculator` in `maseval.interface.usage` for automatic pricing via LiteLLM's model database (supports `custom_pricing` overrides and `model_id_map`; requires `litellm`). Pass a `cost_calculator` to `ModelAdapter` or `AgentAdapter` to compute `Usage.cost`. Provider-reported cost always takes precedence. (PR: #45)
- `AgentAdapter` now accepts `cost_calculator` and `model_id` parameters. For smolagents, CAMEL, and LlamaIndex, both are auto-detected from the framework's agent object (`LiteLLMCostCalculator` if litellm is installed). LangGraph requires explicit `model_id` since graphs can contain multiple models. Explicit parameters always override auto-detection. (PR: #45)
Expand Down
11 changes: 11 additions & 0 deletions maseval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
SimulatorError,
ToolSimulatorError,
UserSimulatorError,
ToolSimulatorResponse,
UserSimulatorResponse,
AgenticUserSimulatorResponse,
)
from .core.instructor import create_instructor_client, flatten_model_schema
from .core.model import ModelAdapter, ChatResponse
from .core.user import User, LLMUser, AgenticLLMUser, TerminationReason
from .core.evaluator import Evaluator
Expand Down Expand Up @@ -108,6 +112,13 @@
# Model adapters
"ModelAdapter",
"ChatResponse",
# Instructor integration
"create_instructor_client",
"flatten_model_schema",
# Simulator response models
"ToolSimulatorResponse",
"UserSimulatorResponse",
"AgenticUserSimulatorResponse",
# Exceptions and validation
"MASEvalError",
"AgentError",
Expand Down
58 changes: 3 additions & 55 deletions maseval/benchmark/tau2/domains/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,59 +132,6 @@ def decorator(func: Callable) -> Callable:
return decorator


# =============================================================================
# JSON Schema Helpers (for get_tool_metadata)
# =============================================================================


def _resolve_node(node: Any, defs: Dict[str, Any]) -> Any:
"""Resolve a JSON schema node, inlining ``$ref`` references and simplifying ``anyOf``."""
if not isinstance(node, dict):
return node

if "$ref" in node:
ref_name = node["$ref"].rsplit("/", 1)[-1]
if ref_name in defs:
return _resolve_node(dict(defs[ref_name]), defs)
return node

# Simplify anyOf (typically Optional[X] -> X with nullable flag)
if "anyOf" in node:
variants = node["anyOf"]
non_null = [v for v in variants if not (isinstance(v, dict) and v.get("type") == "null")]
if len(non_null) == 1:
resolved = _resolve_node(non_null[0], defs)
resolved["nullable"] = True
if "description" in node and "description" not in resolved:
resolved["description"] = node["description"]
return resolved
if non_null:
return _resolve_node(non_null[0], defs)

out: Dict[str, Any] = {}
for key, value in node.items():
if key in ("$defs", "title", "default"):
continue
if isinstance(value, dict):
out[key] = _resolve_node(value, defs)
elif isinstance(value, list):
out[key] = [_resolve_node(v, defs) if isinstance(v, dict) else v for v in value]
else:
out[key] = value
return out


def _resolve_schema_properties(schema: Dict[str, Any]) -> Dict[str, Any]:
"""Extract resolved per-parameter schemas from a JSON Schema ``properties`` block."""
defs = schema.get("$defs", {})
properties = schema.get("properties", {})

resolved: Dict[str, Any] = {}
for name, prop in properties.items():
resolved[name] = _resolve_node(prop, defs)
return resolved


# =============================================================================
# ToolKit Base Class
# =============================================================================
Expand Down Expand Up @@ -413,10 +360,11 @@ def get_tool_metadata(self, tool_name: str) -> Dict[str, Any]:
model_fields[param_name] = (anno, default)

params_model = create_model("parameters", **model_fields)
schema = params_model.model_json_schema()

# Resolve $ref/$defs and extract per-parameter schemas
inputs = _resolve_schema_properties(schema)
from maseval.core.instructor import flatten_model_schema

inputs = flatten_model_schema(params_model).get("properties", {})

return {
"description": description,
Expand Down
61 changes: 3 additions & 58 deletions maseval/benchmark/tau2/tau2.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def get_model_adapter(self, model_id, **kwargs):
from maseval.core.exceptions import UserExhaustedError
from maseval.core.seeding import DefaultSeedGenerator, SeedGenerator

from maseval.core.instructor import flatten_model_schema
from maseval.benchmark.tau2.environment import Tau2Environment
from maseval.benchmark.tau2.evaluator import Tau2Evaluator

Expand Down Expand Up @@ -778,62 +779,6 @@ def execution_loop( # type: ignore[override]
""".strip()


def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
"""Flatten a JSON schema by inlining ``$ref`` references and removing unsupported keys.

Pydantic v2's ``model_json_schema()`` emits ``$ref`` / ``$defs`` for nested
models and ``anyOf`` for ``Optional`` fields. Google GenAI rejects all of
these. This helper recursively resolves them so the resulting schema is
self-contained and compatible with every provider.

Args:
schema: A JSON schema dict (typically from ``Model.model_json_schema()``).

Returns:
A flattened copy with ``$ref``, ``$defs``, ``anyOf``, ``title``,
``default``, and ``additionalProperties`` removed / inlined.
"""
_STRIP_KEYS = {"$defs", "additionalProperties", "title", "default"}

def _resolve(node: Any, defs: Dict[str, Any]) -> Any:
if not isinstance(node, dict):
return node

# Inline $ref
if "$ref" in node:
ref_name = node["$ref"].rsplit("/", 1)[-1]
if ref_name in defs:
return _resolve(dict(defs[ref_name]), defs)
return node

# Simplify anyOf (Optional[X] → X with nullable)
if "anyOf" in node:
variants = node["anyOf"]
non_null = [v for v in variants if not (isinstance(v, dict) and v.get("type") == "null")]
if len(non_null) == 1:
resolved = _resolve(non_null[0], defs)
resolved["nullable"] = True
if "description" in node and "description" not in resolved:
resolved["description"] = node["description"]
return resolved
if non_null:
return _resolve(non_null[0], defs)

out: Dict[str, Any] = {}
for key, value in node.items():
if key in _STRIP_KEYS or key == "anyOf":
continue
if isinstance(value, dict):
out[key] = _resolve(value, defs)
elif isinstance(value, list):
out[key] = [_resolve(v, defs) if isinstance(v, dict) else v for v in value]
else:
out[key] = value
return out

return _resolve(schema, schema.get("$defs", {}))


def _build_tool_definitions(tools: Dict[str, Callable]) -> List[Dict[str, Any]]:
"""Build OpenAI-format tool definitions from a dict of callables.

Expand Down Expand Up @@ -894,7 +839,7 @@ def _build_tool_definitions(tools: Dict[str, Callable]) -> List[Dict[str, Any]]:
"function": {
"name": name,
"description": description,
"parameters": _flatten_schema(params_model.model_json_schema()),
"parameters": flatten_model_schema(params_model),
},
}
)
Expand Down Expand Up @@ -1228,7 +1173,7 @@ def _get_tool_definitions(self) -> List[Dict[str, Any]]:
"function": {
"name": name,
"description": description,
"parameters": _flatten_schema(params_model.model_json_schema()),
"parameters": flatten_model_schema(params_model),
},
}
)
Expand Down
117 changes: 117 additions & 0 deletions maseval/core/instructor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Instructor library integration for structured LLM outputs.

Provides helpers to create instructor-patched clients from provider SDK clients
and to generate flattened JSON schemas from Pydantic models.

Instructor adds ``response_model`` support with automatic validation and retries
to any supported LLM provider.

Example:
```python
from maseval.core.instructor import create_instructor_client

# Wrap an OpenAI client
import openai
client = openai.OpenAI()
instructor_client = create_instructor_client(client, provider="openai")
```
"""

from __future__ import annotations

from typing import Any, Optional, Dict


def create_instructor_client(
    client: Any,
    provider: str,
    mode: Optional[str] = None,
) -> Any:
    """Create an instructor-patched client from a provider SDK client.

    All patched clients expose a unified API:
    ``client.chat.completions.create(response_model=..., messages=...)``.

    Args:
        client: The provider SDK client instance (e.g., ``openai.OpenAI()``,
            ``anthropic.Anthropic()``). For LiteLLM, pass ``litellm.completion``.
        provider: Provider name. One of: ``"openai"``, ``"litellm"``.
            For other providers, use ``instructor.from_provider()`` directly.
        mode: Optional instructor mode override. If None, uses the default
            for the provider.

    Returns:
        An instructor-patched client supporting ``response_model``.

    Raises:
        ValueError: If provider is not recognized.
    """
    import instructor

    patch_kwargs: Dict[str, Any] = {}
    if mode is not None:
        # Resolve the mode name against instructor.Mode; fall back to the
        # raw value when no matching enum member exists.
        patch_kwargs["mode"] = getattr(instructor.Mode, mode.upper(), mode)

    # Dispatch table of supported provider -> instructor factory.
    factories = {
        "openai": instructor.from_openai,
        "litellm": instructor.from_litellm,
    }
    factory = factories.get(provider)
    if factory is None:
        raise ValueError(f"Unsupported provider: {provider!r}. Use instructor.from_provider() directly for other providers.")
    return factory(client, **patch_kwargs)


def flatten_model_schema(model: type) -> Dict[str, Any]:
    """Generate a flattened JSON schema from a Pydantic model.

    Uses instructor's ``openai_schema`` to produce a clean schema, then
    applies additional flattening to remove ``anyOf`` (for ``Optional``
    fields) and other constructs that some providers reject.

    Args:
        model: A Pydantic BaseModel subclass.

    Returns:
        A flat JSON schema dict suitable for LLM tool parameters.
    """
    import instructor

    wrapped = instructor.openai_schema(model)  # ty: ignore[invalid-argument-type]
    raw_parameters = wrapped.openai_schema["parameters"]

    # instructor's openai_schema still produces anyOf for Optional fields.
    # Flatten those for provider compatibility (especially Google GenAI).
    return _resolve_schema(raw_parameters)


def _resolve_schema(node: Any) -> Any:
"""Recursively resolve anyOf and strip unsupported keys from a schema."""
if not isinstance(node, dict):
return node

_STRIP_KEYS = {"$defs", "additionalProperties", "title", "default"}

# Simplify anyOf (Optional[X] -> X with nullable)
if "anyOf" in node:
variants = node["anyOf"]
non_null = [v for v in variants if not (isinstance(v, dict) and v.get("type") == "null")]
if len(non_null) == 1:
resolved = _resolve_schema(non_null[0])
if isinstance(resolved, dict):
resolved["nullable"] = True
if "description" in node and "description" not in resolved:
resolved["description"] = node["description"]
return resolved
if non_null:
return _resolve_schema(non_null[0])

out: Dict[str, Any] = {}
for key, value in node.items():
if key in _STRIP_KEYS or key == "anyOf":
continue
if isinstance(value, dict):
out[key] = _resolve_schema(value)
elif isinstance(value, list):
out[key] = [_resolve_schema(v) if isinstance(v, dict) else v for v in value]
else:
out[key] = value
return out
Loading