Skip to content

Commit 7f81e88

Browse files
committed
refactor(gepa): improve tool extraction robustness and observability
Move tool extraction logic to evaluate() loop for immediate capture. Fix overwrite risk by merging discovered tools with existing config. Improve logging and docstrings for better maintainability.
1 parent b1e4f3d commit 7f81e88

File tree

1 file changed

+86
-38
lines changed

1 file changed

+86
-38
lines changed

dspy/teleprompt/gepa/gepa_utils.py

Lines changed: 86 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import logging
33
import random
4+
from collections.abc import Iterable
45
from typing import Any, Callable, Protocol, TypedDict
56

67
from gepa import EvaluationBatch, GEPAAdapter
@@ -13,7 +14,7 @@
1314
from dspy.adapters.types.tool import Tool
1415
from dspy.evaluate import Evaluate
1516
from dspy.primitives import Example, Prediction
16-
from dspy.teleprompt.bootstrap_trace import TraceData
17+
from dspy.teleprompt.bootstrap_trace import FailedPrediction, TraceData
1718

1819
logger = logging.getLogger(__name__)
1920

@@ -255,6 +256,7 @@ def traverse(o):
255256

256257
for tool_name, tool_config in improved_tools.items():
257258
if tool_name not in all_tools:
259+
logger.warning(f"Skipping updates for tool:'{tool_name}' because it cannot be detected on the student program.")
258260
continue
259261

260262
tool = all_tools[tool_name]
@@ -301,6 +303,10 @@ def evaluate(self, batch, candidate, capture_traces=False):
301303
if hasattr(score, "score"):
302304
score = score["score"]
303305
scores.append(score)
306+
307+
if self.enable_tool_optimization:
308+
self._update_candidate_tools(candidate, program, trajs)
309+
304310
return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajs)
305311
else:
306312
evaluator = Evaluate(
@@ -322,34 +328,17 @@ def evaluate(self, batch, candidate, capture_traces=False):
322328
def make_reflective_dataset(
323329
self, candidate, eval_batch, components_to_update
324330
) -> dict[str, list[ReflectiveExample]]:
325-
from dspy.teleprompt.bootstrap_trace import FailedPrediction
326331
program = self.build_program(candidate)
327332

328333
ret_d: dict[str, list[ReflectiveExample]] = {}
329334

330-
# collect unique tools from traces for each tool-using predictor, serialize to candidate at end
331-
tools_by_predictor: dict[str, dict[str, Tool]] = {}
332-
333335
for pred_name in components_to_update:
334336
# Extract predictor name from component key
335337
if pred_name.startswith(REACT_MODULE_PREFIX):
336338
target_name = pred_name.removeprefix(f"{REACT_MODULE_PREFIX}:")
337339

338340
elif pred_name.startswith(TOOL_MODULE_PREFIX):
339341
target_name = pred_name.removeprefix(f"{TOOL_MODULE_PREFIX}:")
340-
tools_by_predictor[pred_name] = {}
341-
342-
# Helper function for extracting tools (only needed for tool modules)
343-
def extract_tools_from_value(value, tools_dict):
344-
"""Extract Tool objects from value (handles single, list, dict)."""
345-
if isinstance(value, Tool):
346-
tools_dict[value.name] = value
347-
elif isinstance(value, (list, tuple, set)):
348-
for item in value:
349-
extract_tools_from_value(item, tools_dict)
350-
elif isinstance(value, dict):
351-
for item in value.values():
352-
extract_tools_from_value(item, tools_dict)
353342

354343
else:
355344
target_name = pred_name
@@ -378,13 +367,6 @@ def extract_tools_from_value(value, tools_dict):
378367
if len(trace_instances) == 0:
379368
continue
380369

381-
# Extract tools that are used in the trace instances
382-
if pred_name.startswith(TOOL_MODULE_PREFIX):
383-
for t in trace_instances:
384-
trace_inputs = t[1]
385-
for input_value in trace_inputs.values():
386-
extract_tools_from_value(input_value, tools_by_predictor[pred_name])
387-
388370
# TODO: Workaround for ReAct's multiple predictor calls with partial trajectories.
389371
# Using last trace ensures full aggregated trajectory (same as extract predictor).
390372
# After PR #8999 merges (https://github.com/stanfordnlp/dspy/pull/8999), test if we can
@@ -478,25 +460,91 @@ def extract_tools_from_value(value, tools_dict):
478460

479461
ret_d[pred_name] = items
480462

481-
# Update candidate configs with extracted tools (after all traces processed)
482-
for pred_name, tools_dict in tools_by_predictor.items():
463+
if len(ret_d) == 0:
464+
raise Exception("No valid predictions found for any module.")
465+
466+
return ret_d
467+
468+
def _update_candidate_tools(self, candidate, program, trajectories) -> None:
469+
"""Extract dspy.Tool objects from traces for tool modules and update candidate["tools"]."""
470+
471+
tools_by_predictor: dict[str, dict[str, Tool]] = {}
472+
473+
def extract_tools_from_value(value: Any, tools_dict: dict[str, Tool]) -> None:
474+
"""Recursively collect dspy.Tool instances from arbitrary input structures.
475+
Traverses nested containers (lists, dicts, etc.) to find all dspy.Tool objects passed as input arguments, populating the provided tools_dict.
476+
"""
477+
478+
if isinstance(value, Tool):
479+
tools_dict[value.name] = value
480+
return
481+
482+
# For mappings, recurse over values only.
483+
if isinstance(value, dict):
484+
for v in value.values():
485+
extract_tools_from_value(v, tools_dict)
486+
return
487+
488+
# For other iterables (including list, tuple, set, dict_values, etc.), recurse over elements.
489+
# Skip strings/bytes to avoid treating them as iterables of characters.
490+
if isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
491+
for item in value:
492+
extract_tools_from_value(item, tools_dict)
493+
494+
# We iterate over all candidate keys to find tool modules
495+
for component_key in candidate.keys():
496+
if not component_key.startswith(TOOL_MODULE_PREFIX):
497+
continue
498+
499+
target_name = component_key.removeprefix(f"{TOOL_MODULE_PREFIX}:")
500+
tools_by_predictor[component_key] = {}
501+
502+
# Find the predictor object
503+
module = None
504+
for name, m in program.named_predictors():
505+
if name == target_name:
506+
module = m
507+
break
508+
if module is None:
509+
logger.warning(f"Predictor not found for tool module {target_name}")
510+
continue
511+
512+
for data in trajectories or []:
513+
trace = data["trace"]
514+
515+
trace_instances = [t for t in trace if t[0].signature.equals(module.signature)]
516+
if not self.add_format_failure_as_feedback:
517+
trace_instances = [t for t in trace_instances if not isinstance(t[2], FailedPrediction)]
518+
519+
if len(trace_instances) == 0:
520+
continue
521+
522+
for t in trace_instances:
523+
trace_inputs = t[1]
524+
525+
for input_value in trace_inputs.values():
526+
# Recursively collect dspy.Tool objects from input values
527+
extract_tools_from_value(input_value, tools_by_predictor[component_key])
528+
529+
# Update candidate["tools"] with tools found in traces
530+
for component_key, tools_dict in tools_by_predictor.items():
483531
if not tools_dict:
532+
logger.debug(f"No tools extracted from traces for {component_key} (eval_batch.trajectories may be missing tool calls)")
484533
continue
485534

486-
config = json.loads(candidate[pred_name])
487-
config["tools"] = {
488-
tool_name: {
535+
config = json.loads(candidate[component_key])
536+
537+
# Initialize tools dict from existing config if present, otherwise empty
538+
tools_config = config.get("tools", {})
539+
540+
# Update with tools found in traces (this updates existing entries or adds new ones)
541+
for tool_name, tool in tools_dict.items():
542+
tools_config[tool_name] = {
489543
"desc": tool.desc,
490544
"args": tool.args,
491545
}
492-
for tool_name, tool in tools_dict.items()
493-
}
494-
candidate[pred_name] = json.dumps(config, indent=2)
495-
496-
if len(ret_d) == 0:
497-
raise Exception("No valid predictions found for any module.")
498-
499-
return ret_d
546+
config["tools"] = tools_config
547+
candidate[component_key] = json.dumps(config, indent=2)
500548

501549
# TODO: The current DSPyAdapter implementation uses the GEPA default propose_new_texts.
502550
# We can potentially override this, to use the instruction proposal similar to MIPROv2.

0 commit comments

Comments
 (0)