 import json
 import logging
 import random
+from collections.abc import Iterable
 from typing import Any, Callable, Protocol, TypedDict

 from gepa import EvaluationBatch, GEPAAdapter
 from dspy.adapters.types.tool import Tool
 from dspy.evaluate import Evaluate
 from dspy.primitives import Example, Prediction
-from dspy.teleprompt.bootstrap_trace import TraceData
+from dspy.teleprompt.bootstrap_trace import FailedPrediction, TraceData

 logger = logging.getLogger(__name__)

@@ -255,6 +256,7 @@ def traverse(o):

         for tool_name, tool_config in improved_tools.items():
             if tool_name not in all_tools:
+                logger.warning(f"Skipping updates for tool:'{tool_name}' because it cannot be detected on the student program.")
                 continue

             tool = all_tools[tool_name]
@@ -301,6 +303,10 @@ def evaluate(self, batch, candidate, capture_traces=False):
                 if hasattr(score, "score"):
                     score = score["score"]
                 scores.append(score)
+
+            if self.enable_tool_optimization:
+                self._update_candidate_tools(candidate, program, trajs)
+
             return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajs)
         else:
             evaluator = Evaluate(
@@ -322,34 +328,17 @@ def evaluate(self, batch, candidate, capture_traces=False):
     def make_reflective_dataset(
         self, candidate, eval_batch, components_to_update
     ) -> dict[str, list[ReflectiveExample]]:
-        from dspy.teleprompt.bootstrap_trace import FailedPrediction
         program = self.build_program(candidate)

         ret_d: dict[str, list[ReflectiveExample]] = {}

-        # collect unique tools from traces for each tool-using predictor, serialize to candidate at end
-        tools_by_predictor: dict[str, dict[str, Tool]] = {}
-
         for pred_name in components_to_update:
             # Extract predictor name from component key
             if pred_name.startswith(REACT_MODULE_PREFIX):
                 target_name = pred_name.removeprefix(f"{REACT_MODULE_PREFIX}:")

             elif pred_name.startswith(TOOL_MODULE_PREFIX):
                 target_name = pred_name.removeprefix(f"{TOOL_MODULE_PREFIX}:")
-                tools_by_predictor[pred_name] = {}
-
-                # Helper function for extracting tools (only needed for tool modules)
-                def extract_tools_from_value(value, tools_dict):
-                    """Extract Tool objects from value (handles single, list, dict)."""
-                    if isinstance(value, Tool):
-                        tools_dict[value.name] = value
-                    elif isinstance(value, (list, tuple, set)):
-                        for item in value:
-                            extract_tools_from_value(item, tools_dict)
-                    elif isinstance(value, dict):
-                        for item in value.values():
-                            extract_tools_from_value(item, tools_dict)

             else:
                 target_name = pred_name
@@ -378,13 +367,6 @@ def extract_tools_from_value(value, tools_dict):
                 if len(trace_instances) == 0:
                     continue

-                # Extract tools that are used in the trace instances
-                if pred_name.startswith(TOOL_MODULE_PREFIX):
-                    for t in trace_instances:
-                        trace_inputs = t[1]
-                        for input_value in trace_inputs.values():
-                            extract_tools_from_value(input_value, tools_by_predictor[pred_name])
-
                 # TODO: Workaround for ReAct's multiple predictor calls with partial trajectories.
                 # Using last trace ensures full aggregated trajectory (same as extract predictor).
                 # After PR #8999 merges (https://github.com/stanfordnlp/dspy/pull/8999), test if we can
@@ -478,25 +460,91 @@ def extract_tools_from_value(value, tools_dict):

             ret_d[pred_name] = items

-        # Update candidate configs with extracted tools (after all traces processed)
-        for pred_name, tools_dict in tools_by_predictor.items():
+        if len(ret_d) == 0:
+            raise Exception("No valid predictions found for any module.")
+
+        return ret_d
+
+    def _update_candidate_tools(self, candidate, program, trajectories) -> None:
+        """Extract dspy.Tool objects from traces for tool modules and update candidate["tools"]."""
+
+        tools_by_predictor: dict[str, dict[str, Tool]] = {}
+
+        def extract_tools_from_value(value: Any, tools_dict: dict[str, Tool]) -> None:
+            """Recursively collect dspy.Tool instances from arbitrary input structures.
+
+            Traverses nested containers (lists, dicts, etc.) to find all dspy.Tool objects
+            passed as input arguments, populating the provided tools_dict.
+            """
+            if isinstance(value, Tool):
+                tools_dict[value.name] = value
+                return
+
+            # For mappings, recurse over values only.
+            if isinstance(value, dict):
+                for v in value.values():
+                    extract_tools_from_value(v, tools_dict)
+                return
+
+            # For other iterables (including list, tuple, set, dict_values, etc.), recurse over elements.
+            # Skip strings/bytes to avoid treating them as iterables of characters.
+            if isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
+                for item in value:
+                    extract_tools_from_value(item, tools_dict)
+
+        # We iterate over all candidate keys to find tool modules
+        for component_key in candidate.keys():
+            if not component_key.startswith(TOOL_MODULE_PREFIX):
+                continue
+
+            target_name = component_key.removeprefix(f"{TOOL_MODULE_PREFIX}:")
+            tools_by_predictor[component_key] = {}
+
+            # Find the predictor object
+            module = None
+            for name, m in program.named_predictors():
+                if name == target_name:
+                    module = m
+                    break
+            if module is None:
+                logger.warning(f"Predictor not found for tool module {target_name}")
+                continue
+
+            for data in trajectories or []:
+                trace = data["trace"]
+
+                trace_instances = [t for t in trace if t[0].signature.equals(module.signature)]
+                if not self.add_format_failure_as_feedback:
+                    trace_instances = [t for t in trace_instances if not isinstance(t[2], FailedPrediction)]
+
+                if len(trace_instances) == 0:
+                    continue
+
+                for t in trace_instances:
+                    trace_inputs = t[1]
+
+                    for input_value in trace_inputs.values():
+                        # Recursively collect dspy.Tool objects from input values
+                        extract_tools_from_value(input_value, tools_by_predictor[component_key])
+
+        # Update candidate["tools"] with tools found in traces
+        for component_key, tools_dict in tools_by_predictor.items():
             if not tools_dict:
+                logger.debug(f"No tools extracted from traces for {component_key} (eval_batch.trajectories may be missing tool calls)")
                 continue

-            config = json.loads(candidate[pred_name])
-            config["tools"] = {
-                tool_name: {
+            config = json.loads(candidate[component_key])
+
+            # Initialize tools dict from existing config if present, otherwise empty
+            tools_config = config.get("tools", {})
+
+            # Update with tools found in traces (this updates existing entries or adds new ones)
+            for tool_name, tool in tools_dict.items():
+                tools_config[tool_name] = {
                     "desc": tool.desc,
                     "args": tool.args,
                 }
-                for tool_name, tool in tools_dict.items()
-            }
-            candidate[pred_name] = json.dumps(config, indent=2)
-
-        if len(ret_d) == 0:
-            raise Exception("No valid predictions found for any module.")
-
-        return ret_d
+            config["tools"] = tools_config
+            candidate[component_key] = json.dumps(config, indent=2)

     # TODO: The current DSPyAdapter implementation uses the GEPA default propose_new_texts.
     # We can potentially override this, to use the instruction proposal similar to MIPROv2.
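
Not part of the diff above, but for reference: a minimal, self-contained sketch of the traversal semantics the new `extract_tools_from_value` helper implements (record a Tool and stop, recurse into dict values, recurse into other non-string iterables, skip strings/bytes). `DummyTool` is a hypothetical stand-in for `dspy.Tool` so the sketch runs without dspy installed; only the `name` attribute matters here.

```python
from collections.abc import Iterable
from dataclasses import dataclass, field


@dataclass
class DummyTool:
    """Hypothetical stand-in for dspy.Tool; only the attributes used below."""
    name: str
    desc: str = ""
    args: dict = field(default_factory=dict)


def extract_tools_from_value(value, tools_dict):
    # Tool instance: record it by name and stop.
    if isinstance(value, DummyTool):
        tools_dict[value.name] = value
        return
    # Mappings: recurse over values only.
    if isinstance(value, dict):
        for v in value.values():
            extract_tools_from_value(v, tools_dict)
        return
    # Other iterables: recurse over elements, skipping strings/bytes so they
    # are not walked character by character.
    if isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
        for item in value:
            extract_tools_from_value(item, tools_dict)


tools: dict[str, DummyTool] = {}
trace_inputs = {
    "question": "What is 2 + 2?",                                # plain string -> ignored
    "tools": [DummyTool("search"), (DummyTool("calculator"),)],  # nested containers
    "config": {"fallback": DummyTool("python_repl")},            # tool inside a dict value
}
for input_value in trace_inputs.values():
    extract_tools_from_value(input_value, tools)

print(sorted(tools))  # ['calculator', 'python_repl', 'search']
```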
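Likewise, a sketch of the candidate-config merge at the end of `_update_candidate_tools`: tools already present in the serialized candidate JSON are kept, entries for tools seen in traces are overwritten or added, and unrelated keys survive. The config layout and tool names here are invented for illustration; the real code reads `.desc` and `.args` off `dspy.Tool` objects.

```python
import json

# Hypothetical serialized candidate entry for a tool module.
candidate_entry = json.dumps({
    "instructions": "Use the available tools to answer the question.",
    "tools": {"search": {"desc": "old description", "args": {}}},
})

# Tool metadata observed in traces (stands in for dspy.Tool objects).
extracted = {
    "search": {"desc": "Search the web.", "args": {"query": {"type": "string"}}},
    "calculator": {"desc": "Evaluate arithmetic expressions.", "args": {"expr": {"type": "string"}}},
}

config = json.loads(candidate_entry)
tools_config = config.get("tools", {})   # start from tools already in the config
for tool_name, meta in extracted.items():
    tools_config[tool_name] = meta       # overwrite existing entries, add new ones
config["tools"] = tools_config

print(json.dumps(config, indent=2))
# "search" now carries the trace-derived desc/args, "calculator" is added,
# and "instructions" is untouched.
```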