Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions src/ragas/metrics/collections/_answer_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ class AnswerCorrectness(BaseMetric):

# Type hints for linter (attributes are set in __init__)
llm: "InstructorBaseRagasLLM"
embeddings: "BaseRagasEmbedding"
embeddings: t.Optional["BaseRagasEmbedding"]

def __init__(
self,
llm: "InstructorBaseRagasLLM",
embeddings: "BaseRagasEmbedding",
embeddings: t.Optional["BaseRagasEmbedding"] = None,
name: str = "answer_correctness",
weights: List[float] = [0.75, 0.25],
beta: float = 1.0,
Expand All @@ -104,9 +104,21 @@ def __init__(

Args:
llm: Modern instructor-based LLM for statement generation and classification
embeddings: Modern embeddings model for similarity calculation
embeddings: Modern embeddings model for similarity calculation. Optional if similarity
weight is 0 (pure factuality evaluation). Required if similarity weight > 0.
name: The metric name
weights: [factuality_weight, similarity_weight]. Must sum to > 0.
beta: F-beta score parameter. β>1 favors recall, β<1 favors precision.

Raises:
ValueError: If weights are invalid or embeddings are missing when needed for similarity scoring.

Examples:
Pure factuality (no embeddings needed):
>>> metric = AnswerCorrectness(llm=llm, weights=[1.0, 0.0])

Factuality + Similarity (embeddings required):
>>> metric = AnswerCorrectness(llm=llm, embeddings=embeddings, weights=[0.75, 0.25])
"""
# Set attributes explicitly before calling super()
self.llm = llm
Expand All @@ -124,6 +136,14 @@ def __init__(
if not all([w >= 0 for w in weights]):
raise ValueError("Weights must be non-negative")

# Validate embeddings availability when similarity weight > 0
if weights[1] > 0 and embeddings is None:
raise ValueError(
"Embeddings are required for semantic similarity scoring. "
"Either provide embeddings or set similarity weight to 0 (weights=[1.0, 0.0]) "
"for pure factuality-only evaluation."
)

# Validate beta
if not isinstance(beta, float):
raise ValueError(
Expand All @@ -133,6 +153,17 @@ def __init__(
# Call super() for validation (without passing llm/embeddings in kwargs)
super().__init__(name=name, **kwargs)

def _validate_embeddings(self) -> None:
"""Override base validation to allow optional embeddings.

AnswerCorrectness metric allows embeddings to be None when using
pure factuality evaluation (weights=[1.0, 0.0]). The main validation
of embeddings availability happens in __init__ based on weights.
"""
# Only validate embeddings if similarity weight > 0
# (validation logic already in __init__)
pass

async def ascore(
self, user_input: str, response: str, reference: str
) -> MetricResult:
Expand Down
14 changes: 13 additions & 1 deletion tests/e2e/metrics_migration/test_answer_correctness_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,10 @@ def test_answer_correctness_parameter_validation(self):
"""Test that v2 implementation properly validates parameters."""
from unittest.mock import Mock

mock_llm = Mock()
from ragas.llms.base import InstructorBaseRagasLLM

# Create proper mocks that inherit from the required base class
mock_llm = Mock(spec=InstructorBaseRagasLLM)
mock_embeddings = Mock()

# Test invalid weights
Expand All @@ -360,6 +363,15 @@ def test_answer_correctness_parameter_validation(self):
with pytest.raises(ValueError, match="Beta must be a float"):
AnswerCorrectness(llm=mock_llm, embeddings=mock_embeddings, beta="invalid") # type: ignore

# Test optional embeddings - should work with pure factuality (weight=0)
metric = AnswerCorrectness(llm=mock_llm, weights=[1.0, 0.0])
assert metric.embeddings is None
print("✅ Optional embeddings working for pure factuality!")

# Test embeddings required when similarity weight > 0
with pytest.raises(ValueError, match="Embeddings are required"):
AnswerCorrectness(llm=mock_llm, embeddings=None, weights=[0.75, 0.25])

print("✅ Parameter validation working correctly!")

def test_answer_correctness_migration_requirements_documented(self):
Expand Down
Loading