From f661a80311f48d1c824f63ec79f36a8fc8756679 Mon Sep 17 00:00:00 2001
From: Kumar Anirudha
Date: Mon, 10 Nov 2025 13:08:26 +0530
Subject: [PATCH] refactor: make embeddings optional in AnswerCorrectness when using pure factuality mode

---
 .../collections/_answer_correctness.py        | 37 +++++++++++++++++--
 .../test_answer_correctness_migration.py      | 14 ++++++-
 2 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/src/ragas/metrics/collections/_answer_correctness.py b/src/ragas/metrics/collections/_answer_correctness.py
index 9712c6970..dbd8ee08a 100644
--- a/src/ragas/metrics/collections/_answer_correctness.py
+++ b/src/ragas/metrics/collections/_answer_correctness.py
@@ -88,12 +88,12 @@ class AnswerCorrectness(BaseMetric):
 
     # Type hints for linter (attributes are set in __init__)
     llm: "InstructorBaseRagasLLM"
-    embeddings: "BaseRagasEmbedding"
+    embeddings: t.Optional["BaseRagasEmbedding"]
 
     def __init__(
         self,
         llm: "InstructorBaseRagasLLM",
-        embeddings: "BaseRagasEmbedding",
+        embeddings: t.Optional["BaseRagasEmbedding"] = None,
         name: str = "answer_correctness",
         weights: List[float] = [0.75, 0.25],
         beta: float = 1.0,
@@ -104,9 +104,21 @@ def __init__(
 
         Args:
             llm: Modern instructor-based LLM for statement generation and classification
-            embeddings: Modern embeddings model for similarity calculation
+            embeddings: Modern embeddings model for similarity calculation. Optional if similarity
+                weight is 0 (pure factuality evaluation). Required if similarity weight > 0.
+            name: The metric name
             weights: [factuality_weight, similarity_weight]. Must sum to > 0.
             beta: F-beta score parameter. β>1 favors recall, β<1 favors precision.
+
+        Raises:
+            ValueError: If weights are invalid or embeddings are missing when needed for similarity scoring.
+
+        Examples:
+            Pure factuality (no embeddings needed):
+            >>> metric = AnswerCorrectness(llm=llm, weights=[1.0, 0.0])
+
+            Factuality + Similarity (embeddings required):
+            >>> metric = AnswerCorrectness(llm=llm, embeddings=embeddings, weights=[0.75, 0.25])
         """
         # Set attributes explicitly before calling super()
         self.llm = llm
@@ -124,6 +136,14 @@ def __init__(
         if not all([w >= 0 for w in weights]):
             raise ValueError("Weights must be non-negative")
 
+        # Validate embeddings availability when similarity weight > 0
+        if weights[1] > 0 and embeddings is None:
+            raise ValueError(
+                "Embeddings are required for semantic similarity scoring. "
+                "Either provide embeddings or set similarity weight to 0 (weights=[1.0, 0.0]) "
+                "for pure factuality-only evaluation."
+            )
+
         # Validate beta
         if not isinstance(beta, float):
             raise ValueError(
@@ -133,6 +153,17 @@ def __init__(
         # Call super() for validation (without passing llm/embeddings in kwargs)
         super().__init__(name=name, **kwargs)
 
+    def _validate_embeddings(self) -> None:
+        """Override base validation to allow optional embeddings.
+
+        AnswerCorrectness metric allows embeddings to be None when using
+        pure factuality evaluation (weights=[1.0, 0.0]). The main validation
+        of embeddings availability happens in __init__ based on weights.
+        """
+        # Only validate embeddings if similarity weight > 0
+        # (validation logic already in __init__)
+        pass
+
     async def ascore(
         self, user_input: str, response: str, reference: str
     ) -> MetricResult:
diff --git a/tests/e2e/metrics_migration/test_answer_correctness_migration.py b/tests/e2e/metrics_migration/test_answer_correctness_migration.py
index 6e8f38b31..003789817 100644
--- a/tests/e2e/metrics_migration/test_answer_correctness_migration.py
+++ b/tests/e2e/metrics_migration/test_answer_correctness_migration.py
@@ -339,7 +339,10 @@ def test_answer_correctness_parameter_validation(self):
         """Test that v2 implementation properly validates parameters."""
         from unittest.mock import Mock
 
-        mock_llm = Mock()
+        from ragas.llms.base import InstructorBaseRagasLLM
+
+        # Create proper mocks that inherit from the required base class
+        mock_llm = Mock(spec=InstructorBaseRagasLLM)
         mock_embeddings = Mock()
 
         # Test invalid weights
@@ -360,6 +363,15 @@ def test_answer_correctness_parameter_validation(self):
         with pytest.raises(ValueError, match="Beta must be a float"):
             AnswerCorrectness(llm=mock_llm, embeddings=mock_embeddings, beta="invalid")  # type: ignore
 
+        # Test optional embeddings - should work with pure factuality (weight=0)
+        metric = AnswerCorrectness(llm=mock_llm, weights=[1.0, 0.0])
+        assert metric.embeddings is None
+        print("✅ Optional embeddings working for pure factuality!")
+
+        # Test embeddings required when similarity weight > 0
+        with pytest.raises(ValueError, match="Embeddings are required"):
+            AnswerCorrectness(llm=mock_llm, embeddings=None, weights=[0.75, 0.25])
+
         print("✅ Parameter validation working correctly!")
 
     def test_answer_correctness_migration_requirements_documented(self):