Skip to content

Commit f267ccc

Browse files
committed
refactor(gepa): simplify initialization logic
Move helper function outside loop and simplify predictor deduplication check by validating keys before parsing JSON.
1 parent 7f81e88 commit f267ccc

File tree

1 file changed

+15
-17
lines changed

1 file changed

+15
-17
lines changed

dspy/teleprompt/gepa/gepa.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -619,36 +619,34 @@ def feedback_fn(
619619
"extract instructions, tool descriptions, and tool argument descriptions."
620620
)
621621

622+
# Detect tool-using predictors via type checking
623+
def is_tool_field(annotation) -> bool:
624+
"""Check if a field annotation is Tool or contains Tool."""
625+
if annotation is Tool:
626+
return True
627+
origin = get_origin(annotation)
628+
if origin is not None:
629+
args = get_args(annotation)
630+
for arg in args:
631+
if is_tool_field(arg): # Recursive for nested types
632+
return True
633+
return False
634+
622635
# Then, process individual predictors (skip if already part of a module config)
623636
for name, pred in student.named_predictors():
624637
if self.enable_tool_optimization:
625638
# Skip if predictor is part of a module config (e.g., ReAct)
626639
found = False
627-
for val in base_program.values():
628-
try:
640+
for key, val in base_program.items():
641+
if key.startswith((REACT_MODULE_PREFIX, TOOL_MODULE_PREFIX)):
629642
config = json.loads(val)
630643
if name in config:
631644
found = True
632645
break
633-
except (json.JSONDecodeError, TypeError, ValueError):
634-
pass
635646

636647
if found:
637648
continue
638649

639-
# Detect tool-using predictors via type checking
640-
def is_tool_field(annotation) -> bool:
641-
"""Check if a field annotation is Tool or contains Tool."""
642-
if annotation is Tool:
643-
return True
644-
origin = get_origin(annotation)
645-
if origin is not None:
646-
args = get_args(annotation)
647-
for arg in args:
648-
if is_tool_field(arg): # Recursive for nested types
649-
return True
650-
return False
651-
652650
# Add tool module if predictor uses tools
653651
if any(is_tool_field(field.annotation) for field in pred.signature.input_fields.values()):
654652
module_key = f"{TOOL_MODULE_PREFIX}:{name}"

0 commit comments

Comments
 (0)