Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions openevolve/process_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@ def _worker_init(config_dict: dict, evaluation_file: str, parent_env: dict = Non
_worker_evaluator = None
_worker_llm_ensemble = None
_worker_prompt_sampler = None
_worker_embedding_client = None


def _lazy_init_worker_components():
"""Lazily initialize expensive components on first use"""
global _worker_evaluator
global _worker_llm_ensemble
global _worker_prompt_sampler
global _worker_embedding_client

if _worker_llm_ensemble is None:
from openevolve.llm.ensemble import LLMEnsemble
Expand Down Expand Up @@ -130,6 +132,72 @@ def _lazy_init_worker_components():
suffix=getattr(_worker_config, "file_suffix", ".py"),
)

if _worker_embedding_client is None:
embedding_model = getattr(_worker_config.database, "embedding_model", None)
if embedding_model:
from openevolve.embedding import EmbeddingClient

_worker_embedding_client = EmbeddingClient(embedding_model)


def _pre_eval_novelty_check(
child_code: str, db_snapshot: Dict[str, Any], island_idx: int
) -> Optional[str]:
"""
Check if the generated code is novel before evaluation.

This runs the embedding-based similarity check in the worker process,
before the expensive program evaluation step. Programs that are too
similar to existing island programs are rejected early, saving compute.

Returns None if the program is novel (or novelty checking is disabled),
or an error string if it should be rejected.
"""
if _worker_embedding_client is None:
return None

similarity_threshold = getattr(_worker_config.database, "similarity_threshold", 0.99)
if similarity_threshold <= 0.0:
return None

try:
import numpy as np

child_embd = _worker_embedding_client.get_embedding(child_code)
if not child_embd:
return None

child_arr = np.array(child_embd, dtype=np.float32)
child_norm = np.linalg.norm(child_arr)
if child_norm == 0:
return None

island_program_ids = db_snapshot["islands"][island_idx]
programs = db_snapshot["programs"]

for pid in island_program_ids:
prog_dict = programs.get(pid)
if prog_dict is None:
continue
other_embd = prog_dict.get("embedding")
if not other_embd:
continue
other_arr = np.array(other_embd, dtype=np.float32)
other_norm = np.linalg.norm(other_arr)
if other_norm == 0:
continue
similarity = float(np.dot(child_arr, other_arr) / (child_norm * other_norm))
if similarity >= similarity_threshold:
return (
f"Pre-evaluation novelty check failed: generated code is too similar "
f"to existing program {pid} (similarity={similarity:.4f} >= "
f"threshold={similarity_threshold}), skipping evaluation"
)
except Exception as e:
logger.warning(f"Pre-evaluation novelty check error (skipping check): {e}")

return None


def _run_iteration_worker(
iteration: int, db_snapshot: Dict[str, Any], parent_id: str, inspiration_ids: List[str]
Expand Down Expand Up @@ -285,6 +353,12 @@ def _run_iteration_worker(
iteration=iteration,
)

# Pre-evaluation novelty check using embedding similarity.
# This avoids wasting compute evaluating programs that are too similar to existing ones.
novelty_check_result = _pre_eval_novelty_check(child_code, db_snapshot, parent_island)
if novelty_check_result is not None:
return SerializableResult(error=novelty_check_result, iteration=iteration)

# Evaluate the child program
import uuid

Expand Down
137 changes: 137 additions & 0 deletions tests/test_pre_eval_novelty_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
Tests for the pre-evaluation novelty check in the worker process.

Issue #439: Novelty check runs after evaluation, wasting compute on rejected programs.
This test verifies that the embedding-based similarity check is performed before
the expensive evaluation step.
"""

import os
import unittest
from unittest.mock import MagicMock, patch

# Set dummy API key for testing
os.environ.setdefault("OPENAI_API_KEY", "test")


class TestPreEvalNoveltyCheck(unittest.TestCase):
"""Tests for _pre_eval_novelty_check function in process_parallel.py"""

def _make_snapshot(self, island_programs, all_programs=None):
"""Helper: build a minimal db_snapshot dict."""
if all_programs is None:
all_programs = island_programs

programs_dict = {}
for p in all_programs:
programs_dict[p["id"]] = p

return {
"programs": programs_dict,
"islands": [list(p["id"] for p in island_programs)],
"current_island": 0,
"feature_dimensions": [],
"artifacts": {},
}

def _get_check_fn(self):
"""Import the function under test with mocked worker globals."""
import openevolve.process_parallel as pp

# Reset globals to a clean state for each test
pp._worker_config = MagicMock()
pp._worker_config.database.similarity_threshold = 0.99
return pp._pre_eval_novelty_check

def test_returns_none_when_no_embedding_client(self):
"""When _worker_embedding_client is None (novelty disabled), skip check."""
import openevolve.process_parallel as pp

fn = self._get_check_fn()
pp._worker_embedding_client = None
snapshot = self._make_snapshot([])
result = fn("def foo(): pass", snapshot, 0)
self.assertIsNone(result)

def test_returns_none_when_threshold_zero(self):
"""When threshold <= 0, novelty checking is effectively disabled."""
import openevolve.process_parallel as pp

fn = self._get_check_fn()
mock_client = MagicMock()
pp._worker_embedding_client = mock_client
pp._worker_config.database.similarity_threshold = 0.0

snapshot = self._make_snapshot([])
result = fn("def foo(): pass", snapshot, 0)
self.assertIsNone(result)
mock_client.get_embedding.assert_not_called()

def test_returns_none_when_island_has_no_embeddings(self):
"""When existing programs have no embeddings, novel by default."""
import openevolve.process_parallel as pp

fn = self._get_check_fn()
mock_client = MagicMock()
mock_client.get_embedding.return_value = [1.0] + [0.0] * 9
pp._worker_embedding_client = mock_client
pp._worker_config.database.similarity_threshold = 0.99

existing = {"id": "p1", "code": "def bar(): pass", "embedding": None}
snapshot = self._make_snapshot([existing])
result = fn("def foo(): pass", snapshot, 0)
self.assertIsNone(result)

def test_returns_error_when_code_too_similar(self):
"""When cosine similarity exceeds threshold, return an error string."""
import openevolve.process_parallel as pp

fn = self._get_check_fn()
# Both use the same embedding vector → similarity = 1.0
shared_embd = [1.0] + [0.0] * 9
mock_client = MagicMock()
mock_client.get_embedding.return_value = shared_embd
pp._worker_embedding_client = mock_client
pp._worker_config.database.similarity_threshold = 0.99

existing = {"id": "p1", "code": "def bar(): pass", "embedding": shared_embd}
snapshot = self._make_snapshot([existing])
result = fn("def foo(): pass", snapshot, 0)
self.assertIsNotNone(result)
self.assertIn("novelty check failed", result)
self.assertIn("p1", result)

def test_returns_none_when_code_sufficiently_different(self):
"""When similarity is below threshold, program is novel."""
import openevolve.process_parallel as pp

fn = self._get_check_fn()
child_embd = [1.0] + [0.0] * 9
existing_embd = [0.0, 1.0] + [0.0] * 8 # orthogonal → similarity = 0
mock_client = MagicMock()
mock_client.get_embedding.return_value = child_embd
pp._worker_embedding_client = mock_client
pp._worker_config.database.similarity_threshold = 0.99

existing = {"id": "p1", "code": "def bar(): pass", "embedding": existing_embd}
snapshot = self._make_snapshot([existing])
result = fn("def foo(): pass", snapshot, 0)
self.assertIsNone(result)

def test_returns_none_on_embedding_error(self):
"""When embedding client raises an exception, skip the check gracefully."""
import openevolve.process_parallel as pp

fn = self._get_check_fn()
mock_client = MagicMock()
mock_client.get_embedding.side_effect = RuntimeError("API error")
pp._worker_embedding_client = mock_client
pp._worker_config.database.similarity_threshold = 0.99

snapshot = self._make_snapshot([])
result = fn("def foo(): pass", snapshot, 0)
self.assertIsNone(result)


if __name__ == "__main__":
unittest.main()
Loading