diff --git a/pyproject.toml b/pyproject.toml
index 75a7d6b5..78cf1fa8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -93,7 +93,7 @@ target-version = "py311"
 [tool.ruff.lint]
 # https://docs.astral.sh/ruff/rules/
 # https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules
-select = ["A", "B0", "C4", "D2", "D4", "E", "F", "I"]
+select = ["A", "B0", "C4", "D2", "D4", "E", "F", "FURB", "I", "PYI", "UP"]
 ignore = ["D203", "D206", "D213", "D400", "D401", "D413", "D415", "E1", "E501"]
 
 [tool.eva]
diff --git a/scripts/run_text_only.py b/scripts/run_text_only.py
index 7d17163b..1362e37b 100644
--- a/scripts/run_text_only.py
+++ b/scripts/run_text_only.py
@@ -25,7 +25,7 @@
 import os
 import shutil
 import sys
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from pathlib import Path
 from typing import Any
 
@@ -514,7 +514,7 @@ async def run_record(
     user_message = record.user_goal["starting_utterance"]
     end_reason = "max_turns"
     turn_count = 0
-    started_at = datetime.now(timezone.utc)
+    started_at = datetime.now(UTC)
 
     for turn in range(max_turns):
         turn_count = turn + 1
@@ -546,7 +546,7 @@ async def run_record(
     else:
         end_reason = "max_turns"
 
-    ended_at = datetime.now(timezone.utc)
+    ended_at = datetime.now(UTC)
     duration = (ended_at - started_at).total_seconds()
     logger.info(f"Conversation ended: {end_reason} ({turn_count} turns, {duration:.1f}s)")
diff --git a/src/eva/assistant/agentic/audio_llm_system.py b/src/eva/assistant/agentic/audio_llm_system.py
index 6a5cdd77..369e0510 100644
--- a/src/eva/assistant/agentic/audio_llm_system.py
+++ b/src/eva/assistant/agentic/audio_llm_system.py
@@ -5,8 +5,9 @@
 audio so the model has full conversational context across turns.
 """
 
+from collections.abc import AsyncGenerator
 from pathlib import Path
-from typing import Any, AsyncGenerator, Optional
+from typing import Any
 
 from eva.assistant.agentic.audit_log import AuditLog
 from eva.assistant.agentic.system import AgenticSystem
@@ -37,7 +38,7 @@ def __init__(
         tool_handler: ToolExecutor,
         audit_log: AuditLog,
         alm_client: ALMvLLMClient,
-        output_dir: Optional[Path] = None,
+        output_dir: Path | None = None,
     ):
         super().__init__(
             current_date_time=current_date_time,
diff --git a/src/eva/assistant/agentic/audit_log.py b/src/eva/assistant/agentic/audit_log.py
index 7fc28438..cc238c0a 100644
--- a/src/eva/assistant/agentic/audit_log.py
+++ b/src/eva/assistant/agentic/audit_log.py
@@ -4,7 +4,7 @@
 import time
 from enum import StrEnum
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 from pydantic import BaseModel
 
@@ -32,11 +32,11 @@ class ConversationMessage(BaseModel):
     role: MessageRole
     content: str
-    tool_calls: Optional[list[dict[str, Any]]] = None
-    tool_call_id: Optional[str] = None
-    name: Optional[str] = None  # For tool messages
-    turn_id: Optional[int] = None  # For associating transcription updates
-    reasoning: Optional[str] = None  # For model reasoning (e.g., from OpenAI o1)
+    tool_calls: list[dict[str, Any]] | None = None
+    tool_call_id: str | None = None
+    name: str | None = None  # For tool messages
+    turn_id: int | None = None  # For associating transcription updates
+    reasoning: str | None = None  # For model reasoning (e.g., from OpenAI o1)
 
     def to_dict(self) -> dict[str, Any]:
         """Convert to a plain dict, excluding None fields and internal tracking fields."""
@@ -47,17 +47,17 @@ class LLMCall(BaseModel):
     """Record of an LLM call."""
 
     messages: list[dict]
-    tools: Optional[list[dict]] = None
-    response: Optional[ConversationMessage] = None
+    tools: list[dict] | None = None
+    response: ConversationMessage | None = None
     duration_seconds: float = 0.0
     start_time: str = ""
    end_time: str = ""
     status: str = "success"
-    model: Optional[str] = None
+    model: str | None = None
     # New fields for enhanced tracking (optional for backward compatibility)
-    latency_ms: Optional[float] = None
-    error_type: Optional[str] = None
-    error_source: Optional[str] = None
+    latency_ms: float | None = None
+    error_type: str | None = None
+    error_source: str | None = None
     retry_attempt: int = 0
@@ -75,13 +75,13 @@ def __init__(self):
         self.conversation_messages: list[ConversationMessage] = []  # Full message sequence for LLM context
         self._tool_calls_count = 0
         self._tools_called: list[str] = []
-        self._last_tool_call: Optional[str] = None  # Track last tool called for matching responses
+        self._last_tool_call: str | None = None  # Track last tool called for matching responses
 
     def append_user_input(
         self,
         content: str,
-        timestamp_ms: Optional[str] = None,
-        turn_id: Optional[int] = None,
+        timestamp_ms: str | None = None,
+        turn_id: int | None = None,
     ) -> None:
         """Record user input.
@@ -189,7 +189,7 @@ def append_assistant_output(
         self,
         content: str,
         tool_calls: list[dict[str, Any]] | None = None,
-        timestamp_ms: Optional[str] = None,
+        timestamp_ms: str | None = None,
     ) -> None:
         """Record assistant output.
@@ -236,7 +236,7 @@ def append_tool_message(self, tool_call_id: str, content: str) -> None:
         )
         logger.debug(f"Audit: tool message for call_id {tool_call_id}")
 
-    def append_llm_call(self, llm_call: LLMCall, agent_name: Optional[str] = None) -> None:
+    def append_llm_call(self, llm_call: LLMCall, agent_name: str | None = None) -> None:
         """Record an LLM call."""
         response_content = llm_call.response.content if llm_call.response else ""
         response_dict = llm_call.response.to_dict() if llm_call.response else None
@@ -276,7 +276,7 @@ def append_tool_call(
         self,
         tool_name: str,
         parameters: dict[str, Any],
-        response: Optional[dict[str, Any]] = None,
+        response: dict[str, Any] | None = None,
     ) -> None:
         """Record a tool call and its response."""
         # Record tool call in transcript
@@ -342,7 +342,7 @@ def append_realtime_tool_call(
 
     def get_conversation_messages(
         self,
-        max_messages: Optional[int] = None,
+        max_messages: int | None = None,
     ) -> list[ConversationMessage]:
         """Get conversation messages for LLM context.
diff --git a/src/eva/assistant/agentic/system.py b/src/eva/assistant/agentic/system.py
index f9df217c..a4125725 100644
--- a/src/eva/assistant/agentic/system.py
+++ b/src/eva/assistant/agentic/system.py
@@ -5,8 +5,9 @@
 import json
 import time
 import warnings
+from collections.abc import AsyncGenerator
 from pathlib import Path
-from typing import Any, AsyncGenerator
+from typing import Any
 
 from eva.assistant.agentic.audit_log import (
     AuditLog,
@@ -220,7 +221,7 @@ async def _run_tool_loop(
             llm_call_response = ConversationMessage(
                 role=MessageRole.ASSISTANT,
                 content=response_content,
-                tool_calls=tool_calls_dicts if tool_calls_dicts else None,
+                tool_calls=tool_calls_dicts or None,
                 reasoning=llm_stats.get("reasoning"),
             )
diff --git a/src/eva/assistant/pipeline/alm_vllm.py b/src/eva/assistant/pipeline/alm_vllm.py
index 657ca4c3..ae41f289 100644
--- a/src/eva/assistant/pipeline/alm_vllm.py
+++ b/src/eva/assistant/pipeline/alm_vllm.py
@@ -10,7 +10,7 @@
 import struct
 import time
 import wave
-from typing import Any, Optional
+from typing import Any
 
 from openai import AsyncOpenAI
 
@@ -163,7 +163,7 @@ def build_audio_user_message(
     async def complete(
         self,
         messages: list[dict[str, Any]],
-        tools: Optional[list[dict]] = None,
+        tools: list[dict] | None = None,
     ) -> tuple[Any, dict[str, Any]]:
         """Chat completion with audio and tool support.
@@ -188,7 +188,7 @@ async def complete(
             kwargs["tools"] = tools
             kwargs["tool_choice"] = "auto"
 
-        last_exception: Optional[Exception] = None
+        last_exception: Exception | None = None
         for attempt in range(self.max_retries + 1):
             try:
                 start_time = time.time()
diff --git a/src/eva/assistant/pipeline/audio_llm_processor.py b/src/eva/assistant/pipeline/audio_llm_processor.py
index bb5b24b3..25d16a7a 100644
--- a/src/eva/assistant/pipeline/audio_llm_processor.py
+++ b/src/eva/assistant/pipeline/audio_llm_processor.py
@@ -23,7 +23,7 @@
 import wave
 from collections.abc import Awaitable
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 from openai import AsyncOpenAI
 from pipecat.frames.frames import (
@@ -81,7 +81,7 @@ def __init__(
         self,
         context,
         user_context_aggregator,
-        pre_speech_secs: Optional[float] = None,
+        pre_speech_secs: float | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -176,7 +176,7 @@ def __init__(
         audit_log: AuditLog,
         alm_client: ALMvLLMClient,
         audio_collector: AudioLLMUserAudioCollector,
-        output_dir: Optional[Path] = None,
+        output_dir: Path | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -195,11 +195,11 @@ def __init__(
         )
 
         # State tracking (mirrors BenchmarkAgentProcessor)
-        self._current_query_task: Optional[asyncio.Task] = None
+        self._current_query_task: asyncio.Task | None = None
         self._interrupted = asyncio.Event()
 
         # Optional callback for transcript saving (set by server.py)
-        self.on_assistant_response: Optional[Awaitable] = None
+        self.on_assistant_response: Awaitable | None = None
 
     async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
         if isinstance(frame, (EndFrame, CancelFrame)):
@@ -410,7 +410,7 @@ def __init__(
         audio_collector: AudioLLMUserAudioCollector,
         model: str = "",
         params: dict[str, Any] = None,
-        system_prompt: Optional[str] = None,
+        system_prompt: str | None = None,
         sample_rate: int = PIPELINE_SAMPLE_RATE,
         **kwargs,
     ):
@@ -426,7 +426,7 @@ def __init__(
         self._client: AsyncOpenAI = AsyncOpenAI(api_key=self._api_key, base_url=base_url)
 
         # Callback for when transcription is ready (set by server.py)
-        self.on_transcription: Optional[Any] = None
+        self.on_transcription: Any | None = None
 
         # Track background transcription tasks so they can complete even during interruptions
         self._transcription_tasks: list[asyncio.Task] = []
@@ -463,7 +463,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
         # Clean up completed tasks
         self._transcription_tasks = [t for t in self._transcription_tasks if not t.done()]
 
-    async def transcribe(self, timestamp: str, turn_id: Optional[int] = None) -> Optional[str]:
+    async def transcribe(self, timestamp: str, turn_id: int | None = None) -> str | None:
         """Transcribe audio from the collector using chat completions.
 
         This method can be called directly from event handlers or via frame processing.
@@ -479,9 +479,7 @@ async def transcribe(self, timestamp: str, turn_id: Optional[int] = None) -> Opt
         audio_data = self._audio_collector.peek_buffered_audio()
         return await self._transcribe_audio(audio_data, timestamp, turn_id)
 
-    async def _transcribe_audio(
-        self, audio_data: bytes, timestamp: str, turn_id: Optional[int] = None
-    ) -> Optional[str]:
+    async def _transcribe_audio(self, audio_data: bytes, timestamp: str, turn_id: int | None = None) -> str | None:
         """Transcribe pre-captured audio data using chat completions.
 
         This method takes audio data directly instead of reading from the collector,
diff --git a/src/eva/assistant/pipeline/nvidia_stt.py b/src/eva/assistant/pipeline/nvidia_stt.py
index 620004a1..1e6b0974 100644
--- a/src/eva/assistant/pipeline/nvidia_stt.py
+++ b/src/eva/assistant/pipeline/nvidia_stt.py
@@ -10,7 +10,7 @@
 import json
 import ssl
 import time
-from typing import AsyncGenerator, Optional
+from collections.abc import AsyncGenerator
 
 import websockets
 from loguru import logger
@@ -48,7 +48,7 @@ def __init__(
         self,
         *,
         url: str = "ws://localhost:8080",
-        api_key: Optional[str] = None,
+        api_key: str | None = None,
         sample_rate: int = 16000,
         verify: bool = True,
         **kwargs,
@@ -58,7 +58,7 @@ def __init__(
         self._api_key = api_key
         self._verify = verify
         self._websocket = None
-        self._receive_task: Optional[asyncio.Task] = None
+        self._receive_task: asyncio.Task | None = None
         self._ready = False
 
     def can_generate_metrics(self) -> bool:
@@ -135,7 +135,7 @@ async def _connect_websocket(self):
         self._websocket = await websockets.connect(
             self._url,
             ssl=ssl_context,
-            additional_headers=extra_headers if extra_headers else None,
+            additional_headers=extra_headers or None,
         )
         self._ready = False
@@ -149,7 +149,7 @@ async def _connect_websocket(self):
             else:
                 logger.warning(f"{self} unexpected initial message: {data}")
                 self._ready = True
-        except asyncio.TimeoutError:
+        except TimeoutError:
             logger.warning(f"{self} timeout waiting for ready, proceeding")
             self._ready = True
diff --git a/src/eva/assistant/pipeline/realtime_llm.py b/src/eva/assistant/pipeline/realtime_llm.py
index b502b4df..9d93c035 100644
--- a/src/eva/assistant/pipeline/realtime_llm.py
+++ b/src/eva/assistant/pipeline/realtime_llm.py
@@ -14,7 +14,7 @@
 import struct
 import time
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any
 
 from pipecat.frames.frames import Frame, InputAudioRawFrame, VADUserStartedSpeakingFrame, VADUserStoppedSpeakingFrame
 from pipecat.processors.frame_processor import FrameDirection
@@ -75,14 +75,14 @@ def __init__(self, *, audit_log: AuditLog, **kwargs: Any) -> None:
         # Assistant response accumulation (across audio_transcript_delta events)
         self._current_assistant_transcript_parts: list[str] = []
-        self._assistant_response_start_wall_ms: Optional[str] = None
+        self._assistant_response_start_wall_ms: str | None = None
 
         # Track whether we're mid-assistant-response (for interruption flushing)
         self._assistant_responding: bool = False
 
         # Track audio frame timing for VAD delay calculation
-        self._last_audio_frame_time: Optional[float] = None
-        self._vad_delay_ms: Optional[int] = None
+        self._last_audio_frame_time: float | None = None
+        self._vad_delay_ms: int | None = None
 
     async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
         """Track audio frame timing before passing to parent.
@@ -295,7 +295,7 @@ def _reset_assistant_state(self) -> None:
         self._assistant_responding = False
 
     @property
-    def last_vad_delay_ms(self) -> Optional[int]:
+    def last_vad_delay_ms(self) -> int | None:
         """Return the most recent VAD delay in milliseconds.
 
         This is the time between when audio frames stopped arriving and when
diff --git a/src/eva/assistant/pipeline/services.py b/src/eva/assistant/pipeline/services.py
index c8ee3eff..30ddc667 100644
--- a/src/eva/assistant/pipeline/services.py
+++ b/src/eva/assistant/pipeline/services.py
@@ -4,7 +4,8 @@
 """
 
 import datetime
-from typing import Any, AsyncGenerator, Optional
+from collections.abc import AsyncGenerator
+from typing import Any
 
 from deepgram import LiveOptions
 from openai import AsyncAzureOpenAI, BadRequestError
@@ -82,8 +83,8 @@
 def create_stt_service(
-    model: Optional[str],
-    params: Optional[dict[str, Any]] = None,
+    model: str | None,
+    params: dict[str, Any] | None = None,
     language_code: str = "en",
 ) -> STTService | None:
     """Create speech-to-text service.
@@ -212,8 +213,8 @@ def create_stt_service(
 
 def create_tts_service(
-    model: Optional[str],
-    params: Optional[dict[str, Any]] = None,
+    model: str | None,
+    params: dict[str, Any] | None = None,
     language_code: str = "en",
 ) -> TTSService | None:
     """Create text-to-speech service.
@@ -360,11 +361,11 @@ def create_tts_service(
 
 def create_realtime_llm_service(
-    model: Optional[str],
-    params: Optional[dict[str, Any]] = None,
-    agent: Optional[AgentConfig] = None,
-    audit_log: Optional[AuditLog] = None,
-    current_date_time: Optional[str] = None,
+    model: str | None,
+    params: dict[str, Any] | None = None,
+    agent: AgentConfig | None = None,
+    audit_log: AuditLog | None = None,
+    current_date_time: str | None = None,
 ) -> LLMService:
     """Create realtime LLM service.
diff --git a/src/eva/assistant/server.py b/src/eva/assistant/server.py
index 4282e894..5e920a15 100644
--- a/src/eva/assistant/server.py
+++ b/src/eva/assistant/server.py
@@ -8,7 +8,7 @@
 import json
 import wave
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 import uvicorn
 from fastapi import FastAPI, WebSocket
@@ -135,7 +135,7 @@ def __init__(
         )
 
         # Wall-clock captured at on_user_turn_started for non-instrumented S2S models
-        self._user_turn_started_wall_ms: Optional[str] = None
+        self._user_turn_started_wall_ms: str | None = None
 
         # Audio buffer for accumulating audio data
         self._audio_buffer = bytearray()
@@ -147,12 +147,12 @@ def __init__(
         self._app = None
         self._server = None
         self._server_task = None
-        self._runner: Optional[PipelineRunner] = None
-        self._task: Optional[PipelineTask] = None
+        self._runner: PipelineRunner | None = None
+        self._task: PipelineTask | None = None
         self._running = False
         self.num_seconds = 0
         self._latency_measurements: list[float] = []
-        self._metrics_observer: Optional[MetricsFileObserver] = None
+        self._metrics_observer: MetricsFileObserver | None = None
         self.non_instrumented_realtime_llm = False
 
     async def start(self) -> None:
@@ -217,7 +217,7 @@ async def stop(self) -> None:
         if self._server_task:
             try:
                 await asyncio.wait_for(self._server_task, timeout=5.0)
-            except asyncio.TimeoutError:
+            except TimeoutError:
                 # Force cancellation if graceful shutdown times out
                 self._server_task.cancel()
                 try:
diff --git a/src/eva/assistant/services/llm.py b/src/eva/assistant/services/llm.py
index 4d941304..9f329c5b 100644
--- a/src/eva/assistant/services/llm.py
+++ b/src/eva/assistant/services/llm.py
@@ -2,7 +2,7 @@
 
 import asyncio
 import time
-from typing import Any, Optional
+from typing import Any
 
 import litellm
 from dotenv import load_dotenv
@@ -37,7 +37,7 @@ def __init__(self, model: str):
     async def complete(
         self,
         messages: list[dict[str, Any]],
-        tools: Optional[list[dict]] = None,
+        tools: list[dict] | None = None,
         max_retries: int = 5,
         initial_delay: float = 1.0,
     ) -> tuple[Any, dict[str, Any]]:
diff --git a/src/eva/assistant/tools/airline_params.py b/src/eva/assistant/tools/airline_params.py
index 7173a74a..93238a7f 100644
--- a/src/eva/assistant/tools/airline_params.py
+++ b/src/eva/assistant/tools/airline_params.py
@@ -10,7 +10,7 @@
 """
 
 from enum import StrEnum
-from typing import Annotated, Optional
+from typing import Annotated
 
 from pydantic import BaseModel, Field, ValidationError
 
@@ -133,8 +133,8 @@ class RebookFlightParams(BaseModel):
     new_journey_id: JourneyIdStr
     rebooking_type: RebookingType
     waive_change_fee: bool
-    new_fare_class: Optional[FareClass] = None
-    flight_number: Optional[FlightNumberStr] = None
+    new_fare_class: FareClass | None = None
+    flight_number: FlightNumberStr | None = None
 
 
 class AddToStandbyParams(BaseModel):
@@ -148,14 +148,14 @@ class AssignSeatParams(BaseModel):
     passenger_id: PassengerIdStr
     journey_id: JourneyIdStr
     seat_preference: SeatPreference
-    flight_number: Optional[FlightNumberStr] = None
+    flight_number: FlightNumberStr | None = None
 
 
 class AddBaggageAllowanceParams(BaseModel):
     confirmation_number: ConfirmationNumber
     journey_id: JourneyIdStr
     num_bags: int = Field(ge=0, le=5)
-    flight_number: Optional[FlightNumberStr] = None
+    flight_number: FlightNumberStr | None = None
 
 
 class AddMealRequestParams(BaseModel):
@@ -163,7 +163,7 @@ class AddMealRequestParams(BaseModel):
     passenger_id: PassengerIdStr
     journey_id: JourneyIdStr
     meal_type: MealType
-    flight_number: Optional[FlightNumberStr] = None
+    flight_number: FlightNumberStr | None = None
 
 
 class IssueTravelCreditParams(BaseModel):
diff --git a/src/eva/assistant/tools/airline_tools.py b/src/eva/assistant/tools/airline_tools.py
index a07fbd95..2aee65a4 100644
--- a/src/eva/assistant/tools/airline_tools.py
+++ b/src/eva/assistant/tools/airline_tools.py
@@ -434,7 +434,7 @@ def rebook_flight(params: dict, db: dict, call_index: int) -> dict:
 
     # Determine fare classes
     original_fare_class = booking.get("fare_class", "main_cabin")
-    target_fare_class = new_fare_class if new_fare_class else original_fare_class
+    target_fare_class = new_fare_class or original_fare_class
 
     # Computations
     is_irrops = "irrops" in rebooking_type
diff --git a/src/eva/assistant/tools/tool_executor.py b/src/eva/assistant/tools/tool_executor.py
index 9722e514..32bd6c57 100644
--- a/src/eva/assistant/tools/tool_executor.py
+++ b/src/eva/assistant/tools/tool_executor.py
@@ -3,8 +3,8 @@
 import copy
 import importlib
 import json
+from collections.abc import Callable
 from pathlib import Path
-from typing import Callable
 
 import yaml
 from pipecat.services.llm_service import FunctionCallParams
diff --git a/src/eva/metrics/base.py b/src/eva/metrics/base.py
index 174273ff..ade4722b 100644
--- a/src/eva/metrics/base.py
+++ b/src/eva/metrics/base.py
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from enum import StrEnum
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 from pydub import AudioSegment
 
@@ -61,13 +61,13 @@ def __init__(
         num_turns: int = 0,
         num_tool_calls: int = 0,
         tools_called: list[str] = None,
-        conversation_ended_reason: Optional[str] = None,
+        conversation_ended_reason: str | None = None,
         duration_seconds: float = 0.0,
         # Paths to files
         output_dir: str = "",
-        audio_assistant_path: Optional[str] = None,
-        audio_user_path: Optional[str] = None,
-        audio_mixed_path: Optional[str] = None,
+        audio_assistant_path: str | None = None,
+        audio_user_path: str | None = None,
+        audio_mixed_path: str | None = None,
         # Processed log data from postprocessor
         transcribed_assistant_turns: dict[int, str] | None = None,
         transcribed_user_turns: dict[int, str] | None = None,
diff --git a/src/eva/metrics/processor.py b/src/eva/metrics/processor.py
index ad4d5d6c..38fd6803 100644
--- a/src/eva/metrics/processor.py
+++ b/src/eva/metrics/processor.py
@@ -4,7 +4,6 @@
 from collections import Counter
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Optional
 
 from eva.assistant.agentic.system import GENERIC_ERROR
 from eva.models.results import ConversationResult
@@ -647,7 +646,7 @@ class _ProcessorContext:
     """Processed log data for metric computation."""
 
     def __init__(self):
-        self.record_id: Optional[str] = None
+        self.record_id: str | None = None
 
         # Per-role turn data (indexed by turn_id, 0-indexed)
         self.transcribed_assistant_turns: dict[int, str] = {}
@@ -670,9 +669,9 @@ def __init__(self):
 
         self.conversation_trace: list[dict] = []
 
-        self.audio_assistant_path: Optional[str] = None
-        self.audio_user_path: Optional[str] = None
-        self.audio_mixed_path: Optional[str] = None
+        self.audio_assistant_path: str | None = None
+        self.audio_user_path: str | None = None
+        self.audio_mixed_path: str | None = None
 
         # Interruption data
         self.assistant_interrupted_turns: set[int] = set()
@@ -680,7 +679,7 @@ def __init__(self):
 
         # Conversation metadata
         self.conversation_finished: bool = False
-        self.conversation_ended_reason: Optional[str] = None
+        self.conversation_ended_reason: str | None = None
         self.is_audio_native: bool = False
 
         # Response latencies from Pipecat's UserBotLatencyObserver
@@ -698,7 +697,7 @@ def process_record(
         result: ConversationResult,
         output_dir: Path,
         is_audio_native: bool = False,
-    ) -> Optional[_ProcessorContext]:
+    ) -> _ProcessorContext | None:
         """Process a single conversation record to create metric context.
 
         Args:
diff --git a/src/eva/metrics/registry.py b/src/eva/metrics/registry.py
index 1eb0c2cd..5c6c87b6 100644
--- a/src/eva/metrics/registry.py
+++ b/src/eva/metrics/registry.py
@@ -1,6 +1,6 @@
 """Registry for managing available metrics."""
 
-from typing import Any, Type
+from typing import Any
 
 from eva.metrics.base import BaseMetric
 from eva.utils.logging import get_logger
@@ -15,9 +15,9 @@ class MetricRegistry:
     """
 
     def __init__(self):
-        self._metrics: dict[str, Type[BaseMetric]] = {}
+        self._metrics: dict[str, type[BaseMetric]] = {}
 
-    def register(self, metric_class: Type[BaseMetric]) -> Type[BaseMetric]:
+    def register(self, metric_class: type[BaseMetric]) -> type[BaseMetric]:
         """Register a metric class.
 
         Can be used as a decorator:
@@ -39,7 +39,7 @@ class MyMetric(BaseMetric):
         logger.debug(f"Registered metric: {name}")
         return metric_class
 
-    def get(self, name: str) -> Type[BaseMetric] | None:
+    def get(self, name: str) -> type[BaseMetric] | None:
         """Get a metric class by name.
 
         Args:
@@ -74,7 +74,7 @@ def list_metrics(self) -> list[str]:
         """Get list of all registered metric names."""
         return list(self._metrics.keys())
 
-    def get_all(self) -> dict[str, Type[BaseMetric]]:
+    def get_all(self) -> dict[str, type[BaseMetric]]:
         """Get all registered metrics."""
         return self._metrics.copy()
 
@@ -88,7 +88,7 @@ def get_global_registry() -> MetricRegistry:
     return _global_registry
 
 
-def register_metric(metric_class: Type[BaseMetric]) -> Type[BaseMetric]:
+def register_metric(metric_class: type[BaseMetric]) -> type[BaseMetric]:
     """Register a metric in the global registry.
 
     Decorator for metric classes.
diff --git a/src/eva/metrics/validation/conversation_finished.py b/src/eva/metrics/validation/conversation_finished.py
index 03e8a030..acd89f05 100644
--- a/src/eva/metrics/validation/conversation_finished.py
+++ b/src/eva/metrics/validation/conversation_finished.py
@@ -38,7 +38,7 @@ async def compute(self, context: MetricContext) -> MetricScore:
                 details={"file_path": str(elevenlabs_events_path)},
             )
 
-        with open(elevenlabs_events_path, "r") as f:
+        with open(elevenlabs_events_path) as f:
             lines = f.readlines()
 
         if not lines:
diff --git a/src/eva/models/agents.py b/src/eva/models/agents.py
index 09f3140d..e08fd030 100644
--- a/src/eva/models/agents.py
+++ b/src/eva/models/agents.py
@@ -1,7 +1,7 @@
 """Agent configuration models."""
 
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 import yaml
 from pydantic import BaseModel, ConfigDict, Field
@@ -12,11 +12,11 @@ class AgentToolParameter(BaseModel):
 
     name: str = Field(..., description="Parameter name")
     type: str = Field("string", description="Parameter type")
-    enum: Optional[list[str]] = Field(None, description="Allowed values for enum types")
+    enum: list[str] | None = Field(None, description="Allowed values for enum types")
     description: str = Field("", description="Parameter description")
-    items: Optional[dict[str, Any]] = Field(None, description="Items schema for array types")
-    properties: Optional[dict[str, Any]] = Field(None, description="Properties schema for object types")
-    additionalProperties: Optional[bool | dict[str, Any]] = Field(
+    items: dict[str, Any] | None = Field(None, description="Items schema for array types")
+    properties: dict[str, Any] | None = Field(None, description="Properties schema for object types")
+    additionalProperties: bool | dict[str, Any] | None = Field(
         None, description="Additional properties for object types"
     )
 
@@ -30,7 +30,7 @@ class AgentTool(BaseModel):
     required_parameters: list[str | AgentToolParameter] = Field(default_factory=list, description="Required parameters")
     optional_parameters: list[str | AgentToolParameter] = Field(default_factory=list, description="Optional parameters")
     invoke_cache_flush: bool = Field(False, description="Whether to flush cache on invocation")
-    tool_type: Optional[str] = Field(None, description="Type of tool: 'read' or 'write'")
+    tool_type: str | None = Field(None, description="Type of tool: 'read' or 'write'")
     metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
 
     model_config = ConfigDict(extra="allow")
@@ -104,7 +104,7 @@ class AgentConfig(BaseModel):
     role: str = Field(..., description="Agent role description")
     instructions: str = Field(..., description="Agent instructions/prompt")
     tools: list[AgentTool] = Field(default_factory=list, description="Tools available to this agent")
-    personality: Optional[str] = Field(None, description="Agent personality description")
+    personality: str | None = Field(None, description="Agent personality description")
     tool_module_path: str = Field(
         description="Python module path for tool implementations (e.g., 'eva.assistant.tools.airline_tools')",
     )
@@ -215,14 +215,14 @@ class AgentsConfig(BaseModel):
 
     agents: list[AgentConfig] = Field(default_factory=list, description="List of available agents")
 
-    def get_agent_by_id(self, agent_id: str) -> Optional[AgentConfig]:
+    def get_agent_by_id(self, agent_id: str) -> AgentConfig | None:
         """Get an agent by its ID."""
         for agent in self.agents:
             if agent.id == agent_id:
                 return agent
         return None
 
-    def get_agent_by_name(self, name: str) -> Optional[AgentConfig]:
+    def get_agent_by_name(self, name: str) -> AgentConfig | None:
         """Get an agent by its name."""
         for agent in self.agents:
             if agent.name == name:
diff --git a/src/eva/models/provenance.py b/src/eva/models/provenance.py
index 7828f164..d5228859 100644
--- a/src/eva/models/provenance.py
+++ b/src/eva/models/provenance.py
@@ -1,7 +1,6 @@
 """Provenance model for tracking run artifacts and environment."""
 
 from datetime import datetime
-from typing import Optional
 
 from pydantic import BaseModel, Field
 
@@ -19,14 +18,14 @@ class BaseProvenance(BaseModel):
     eva_version: str
     simulation_version: str = ""
     metrics_version: str = ""
-    git_commit_sha: Optional[str] = None
-    git_branch: Optional[str] = None
-    git_dirty: Optional[bool] = None
-    git_diff_hash: Optional[str] = None
-    dataset: Optional[ArtifactInfo] = None
-    agent_config: Optional[ArtifactInfo] = None
-    tool_module: Optional[ArtifactInfo] = None
-    scenario_db: Optional[ArtifactInfo] = None
+    git_commit_sha: str | None = None
+    git_branch: str | None = None
+    git_dirty: bool | None = None
+    git_diff_hash: str | None = None
+    dataset: ArtifactInfo | None = None
+    agent_config: ArtifactInfo | None = None
+    tool_module: ArtifactInfo | None = None
+    scenario_db: ArtifactInfo | None = None
     python_version: str = ""
     platform: str = ""
     captured_at: datetime = Field(default_factory=datetime.now)
diff --git a/src/eva/models/record.py b/src/eva/models/record.py
index 57563820..a0ca7a09 100644
--- a/src/eva/models/record.py
+++ b/src/eva/models/record.py
@@ -2,7 +2,7 @@
 
 import json
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 from pydantic import BaseModel, ConfigDict, Field
 
@@ -99,9 +99,9 @@ class GroundTruth(BaseModel):
 class AgentOverride(BaseModel):
     """Override agent configuration for a specific record."""
 
-    instructions: Optional[str] = Field(None, description="Override agent instructions")
-    tools_enabled: Optional[list[str]] = Field(None, description="Subset of tools to enable")
-    personality: Optional[str] = Field(None, description="Override agent personality")
+    instructions: str | None = Field(None, description="Override agent instructions")
+    tools_enabled: list[str] | None = Field(None, description="Subset of tools to enable")
+    personality: str | None = Field(None, description="Override agent personality")
 
 
 class EvaluationRecord(BaseModel):
@@ -123,10 +123,10 @@ class EvaluationRecord(BaseModel):
     ground_truth: GroundTruth = Field(default_factory=GroundTruth, description="Expected outcomes for evaluation")
 
     # Optional overrides
-    agent_override: Optional[AgentOverride] = Field(None, description="Override agent configuration for this record")
+    agent_override: AgentOverride | None = Field(None, description="Override agent configuration for this record")
 
     # Metadata
-    category: Optional[str] = Field(None, description="Category for grouping, e.g., 'hr_pto', 'it_support'")
+    category: str | None = Field(None, description="Category for grouping, e.g., 'hr_pto', 'it_support'")
 
     @classmethod
     def load_dataset(cls, path: Path | str) -> list["EvaluationRecord"]:
@@ -145,5 +145,4 @@ def save_dataset(cls, records: list["EvaluationRecord"], path: Path | str) -> No
         """Save records to JSONL file."""
         path = Path(path)
         with open(path, "w", encoding="utf-8") as f:
-            for record in records:
-                f.write(record.model_dump_json() + "\n")
+            f.writelines(record.model_dump_json() + "\n" for record in records)
diff --git a/src/eva/models/results.py b/src/eva/models/results.py
index a10985f8..ef088813 100644
--- a/src/eva/models/results.py
+++ b/src/eva/models/results.py
@@ -2,7 +2,7 @@
 
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Optional
+from typing import Any
 
 from pydantic import BaseModel, Field
 
@@ -21,7 +21,7 @@ class ErrorDetails(BaseModel):
     retry_count: int = Field(0, description="Number of retry attempts")
     retry_succeeded: bool = Field(False, description="Whether retry succeeded")
     timestamps: list[str] = Field(default_factory=list, description="Timestamp of each attempt")
-    stack_trace: Optional[str] = Field(None, description="Stack trace if available")
+    stack_trace: str | None = Field(None, description="Stack trace if available")
     original_error: str = Field(..., description="Original error message")
 
@@ -40,15 +40,15 @@ class ConversationResult(BaseModel):
 
     record_id: str = Field(..., description="ID of the evaluation record")
     completed: bool = Field(..., description="Whether the conversation completed successfully")
-    error: Optional[str] = Field(None, description="Error message if failed")
-    error_details: Optional[ErrorDetails] = Field(
+    error: str | None = Field(None, description="Error message if failed")
+    error_details: ErrorDetails | None = Field(
         None, description="Detailed error information (new field, optional for backward compatibility)"
     )
 
     # Latency statistics
-    llm_latency: Optional[LatencyStats] = Field(None, description="LLM latency statistics")
-    stt_latency: Optional[LatencyStats] = Field(None, description="STT latency statistics")
-    tts_latency: Optional[LatencyStats] = Field(None, description="TTS latency statistics")
+    llm_latency: LatencyStats | None = Field(None, description="LLM latency statistics")
+    stt_latency: LatencyStats | None = Field(None, description="STT latency statistics")
+    tts_latency: LatencyStats | None = Field(None, description="TTS latency statistics")
 
     # Timing
     started_at: datetime = Field(..., description="When the conversation started")
@@ -57,25 +57,25 @@ class ConversationResult(BaseModel):
 
     # Paths to outputs
     output_dir: str = Field(..., description="Path to output directory for this record")
-    audio_assistant_path: Optional[str] = Field(None, description="Path to assistant audio file")
-    audio_user_path: Optional[str] = Field(None, description="Path to user audio file")
-    audio_mixed_path: Optional[str] = Field(None, description="Path to mixed audio file")
-    transcript_path: Optional[str] = Field(None, description="Path to transcript JSONL file")
-    audit_log_path: Optional[str] = Field(None, description="Path to audit log JSON file")
-    conversation_log_path: Optional[str] = Field(None, description="Path to conversation log file")
-    pipecat_logs_path: Optional[str] = Field(None, description="Path to pipecat logs JSONL file")
-    elevenlabs_logs_path: Optional[str] = Field(None, description="Path to elevenlabs logs JSONL file")
+    audio_assistant_path: str | None = Field(None, description="Path to assistant audio file")
+    audio_user_path: str | None = Field(None, description="Path to user audio file")
+    audio_mixed_path: str | None = Field(None, description="Path to mixed audio file")
+    transcript_path: str | None = Field(None, description="Path to transcript JSONL file")
+    audit_log_path: str | None = Field(None, description="Path to audit log JSON file")
+    conversation_log_path: str | None = Field(None, description="Path to conversation log file")
+    pipecat_logs_path: str | None = Field(None, description="Path to pipecat logs JSONL file")
+    elevenlabs_logs_path: str | None = Field(None, description="Path to elevenlabs logs JSONL file")
 
     # Summary stats (pre-metrics)
     num_turns: int = Field(0, description="Number of conversation turns")
     num_tool_calls: int = Field(0, description="Number of tool calls made")
     tools_called: list[str] = Field(default_factory=list, description="List of tools that were called")
-    conversation_ended_reason: Optional[str] = Field(
+    conversation_ended_reason: str | None = Field(
         None,
         description="Reason conversation ended: 'goodbye', 'timeout', 'transfer', 'error'",
     )
-    initial_scenario_db_hash: Optional[str] = Field(None, description="SHA-256 hash of initial scenario database")
-    final_scenario_db_hash: Optional[str] = Field(None, description="SHA-256 hash of final scenario database")
+    initial_scenario_db_hash: str | None = Field(None, description="SHA-256 hash of initial scenario database")
+    final_scenario_db_hash: str | None = Field(None, description="SHA-256 hash of final scenario database")
 
 
 class MetricScore(BaseModel):
@@ -108,16 +108,14 @@ class RecordMetrics(BaseModel):
     model_config = {"extra": "allow"}  # Allow extra fields for backwards compatibility
 
     record_id: str = Field(..., description="ID of the evaluation record")
-    context: Optional[dict[str, Any]] = Field(
-        default=None, description="MetricContext fields used for computing metrics"
-    )
+    context: dict[str, Any] | None = Field(default=None, description="MetricContext fields used for computing metrics")
     metrics: dict[str, MetricScore] = Field(default_factory=dict, description="Metrics keyed by metric name")
     aggregate_metrics: dict[str, float | None] = Field(
         default_factory=dict,
         description="EVA composite aggregate scores (EVA-A, EVA-X, EVA-overall)",
     )
 
-    def get_score(self, metric_name: str) -> Optional[float]:
+    def get_score(self, metric_name: str) -> float | None:
         """Get the normalized score for a metric, falling back to raw score."""
         if metric_name not in self.metrics:
             return None
@@ -126,7 +124,7 @@ def get_score(self, metric_name: str) -> Optional[float]:
             return None
         return metric.normalized_score if metric.normalized_score is not None else metric.score
 
-    def get_context_field(self, field_name: str) -> Optional[Any]:
+    def get_context_field(self, field_name: str) -> Any | None:
         """Safely get a field from context."""
         if self.context and isinstance(self.context, dict):
             return self.context.get(field_name)
diff --git a/src/eva/orchestrator/port_pool.py b/src/eva/orchestrator/port_pool.py
index 0b118af8..ec86cfb7 100644
--- a/src/eva/orchestrator/port_pool.py
+++ b/src/eva/orchestrator/port_pool.py
@@ -1,7 +1,6 @@
 """Port pool for managing WebSocket server ports."""
 
 import asyncio
-from typing import Optional
 
 from eva.utils.logging import get_logger
 
@@ -46,7 +45,7 @@ async def initialize(self) -> None:
             f"(range: {self.base_port}-{self.base_port + self.pool_size - 1})"
         )
 
-    async def acquire(self, timeout: Optional[float] = None) -> int:
+    async def acquire(self, timeout: float | None = None) -> int:
         """Acquire an available port from the pool.
 
         Args:
@@ -71,7 +70,7 @@ async def acquire(self, timeout: Optional[float] = None) -> int:
             logger.debug(f"Acquired port {port} ({len(self._in_use)} in use)")
             return port
 
-        except asyncio.TimeoutError:
+        except TimeoutError:
             logger.warning(f"Timeout waiting for available port (timeout={timeout}s)")
             raise
 
@@ -107,10 +106,10 @@ def is_port_in_use(self, port: int) -> bool:
 class PortPoolContextManager:
     """Context manager for acquiring and releasing ports."""
 
-    def __init__(self, pool: PortPool, timeout: Optional[float] = None):
+    def __init__(self, pool: PortPool, timeout: float | None = None):
         self.pool = pool
         self.timeout = timeout
-        self.port: Optional[int] = None
+        self.port: int | None = None
 
     async def __aenter__(self) -> int:
         self.port = await self.pool.acquire(timeout=self.timeout)
diff --git a/src/eva/orchestrator/runner.py b/src/eva/orchestrator/runner.py
index ac5a45f3..590d0828 100644
--- a/src/eva/orchestrator/runner.py
+++ b/src/eva/orchestrator/runner.py
@@ -954,16 +954,15 @@ def _save_results_csv(
             f.write("record_id,completed,duration_seconds,num_turns,num_tool_calls,ended_reason,error\n")
 
             # Successful records
-            for output_id, result in successful:
-                f.write(
-                    f"{output_id},true,{result.duration_seconds:.2f},"
-                    f"{result.num_turns},{result.num_tool_calls},"
-                    f"{result.conversation_ended_reason or ''},\n"
-                )
+            f.writelines(
+                f"{output_id},true,{result.duration_seconds:.2f},"
+                f"{result.num_turns},{result.num_tool_calls},"
+                f"{result.conversation_ended_reason or ''},\n"
+                for output_id, result in successful
+            )
 
             # Failed records
-            for record_id in failed_ids:
-                f.write(f"{record_id},false,0,0,0,error,failed\n")
+            f.writelines(f"{record_id},false,0,0,0,error,failed\n" for record_id in failed_ids)
 
     @classmethod
     def from_config_file(cls, config_path: Path | str) -> "BenchmarkRunner":
diff --git a/src/eva/orchestrator/worker.py b/src/eva/orchestrator/worker.py
index bf549d6b..9f96ba2d 100644
--- a/src/eva/orchestrator/worker.py
+++ b/src/eva/orchestrator/worker.py
@@ -5,7 +5,7 @@
 import math
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 from eva.assistant.server import AssistantServer
 from eva.models.agents import AgentConfig
@@ -106,9 +106,9 @@ async def run(self) -> ConversationResult:
 
         logger.info(f"Starting conversation for record {self.record.id} on port {self.port}")
 
-        conversation_ended_reason: Optional[str] = None
-        error: Optional[str] = None
-        error_details: Optional[ErrorDetails] = None
+        conversation_ended_reason: str | None = None
+        error: str | None = None
+        error_details: ErrorDetails | None = None
 
         try:
             # 1. Start assistant server
@@ -126,7 +126,7 @@ async def run(self) -> ConversationResult:
                 timeout=self.config.conversation_timeout_seconds,
             )
             logger.info(f"Conversation {self.record.id} ended: {conversation_ended_reason}")
-        except asyncio.TimeoutError:
+        except TimeoutError:
             conversation_ended_reason = "timeout"
             logger.warning(f"Conversation {self.record.id} timed out")
         except asyncio.CancelledError:
@@ -288,7 +288,7 @@ async def _cleanup(self) -> None:
         if self._user_simulator:
             self._user_simulator = None
 
-    def _calculate_stt_latency(self) -> Optional[LatencyStats]:
+    def _calculate_stt_latency(self) -> LatencyStats | None:
         """Calculate STT latency statistics from Pipecat metrics.
 
         Uses ProcessingMetricsData from pipecat_metrics.jsonl, which measures
@@ -330,7 +330,7 @@ def _calculate_stt_latency(self) -> Optional[LatencyStats]:
             logger.warning(f"Failed to calculate STT latency: {e}")
             return None
 
-    def _calculate_tts_latency(self) -> Optional[LatencyStats]:
+    def _calculate_tts_latency(self) -> LatencyStats | None:
         """Calculate TTS latency statistics from Pipecat metrics.
 
         Uses TTFBMetricsData (Time To First Byte) from pipecat_metrics.jsonl,
@@ -375,7 +375,7 @@ def _calculate_tts_latency(self) -> Optional[LatencyStats]:
             logger.warning(f"Failed to calculate TTS latency: {e}")
             return None
 
-    def _calculate_llm_latency(self) -> Optional[LatencyStats]:
+    def _calculate_llm_latency(self) -> LatencyStats | None:
         """Calculate LLM latency statistics from audit log.
 
         LLM latency = time from LLM call start to response completion
diff --git a/src/eva/user_simulator/audio_interface.py b/src/eva/user_simulator/audio_interface.py
index 2dcee112..b8e0ec0b 100644
--- a/src/eva/user_simulator/audio_interface.py
+++ b/src/eva/user_simulator/audio_interface.py
@@ -10,7 +10,7 @@
 import asyncio
 import base64
 import json
-from typing import Callable, Optional
+from collections.abc import Callable
 
 import websockets
 from websockets.protocol import State as WebSocketState
@@ -71,9 +71,9 @@ def __init__(
         self,
         websocket_uri: str,
         conversation_id: str,
-        record_callback: Optional[Callable[[str, bytes], None]] = None,
+        record_callback: Callable[[str, bytes], None] | None = None,
         event_logger=None,
-        conversation_done_callback: Optional[Callable[[str], None]] = None,
+        conversation_done_callback: Callable[[str], None] | None = None,
     ):
         """Initialize the audio interface.
@@ -447,7 +447,7 @@ async def _continuous_input_stream(self) -> None:
                 chunk = await asyncio.wait_for(self.audio_buffer.get(), timeout=timeout)
                 audio_chunk += chunk
                 consecutive_empty_chunks = 0  # Reset on successful audio
-            except asyncio.TimeoutError:
+            except TimeoutError:
                 # In silence mode with short timeout, keep trying until chunk duration elapsed
                 if consecutive_empty_chunks > 0 and remaining_time > NORMAL_POLL_TIMEOUT_S:
                     continue
@@ -597,7 +597,7 @@ async def _send_to_assistant(self) -> None:
                     # Reset silence timing when transitioning to audio
                     silence_start_time = None
                     silence_chunks_sent = 0
-            except asyncio.TimeoutError:
+            except TimeoutError:
                 pass
 
             # Refresh current_time after queue wait for accurate timing
diff --git a/src/eva/user_simulator/client.py b/src/eva/user_simulator/client.py
index 51946a48..ff8aa29f 100644
--- a/src/eva/user_simulator/client.py
+++ b/src/eva/user_simulator/client.py
@@ -8,7 +8,6 @@
 import json
 import os
 from pathlib import Path
-from typing import Optional
 
 import httpx
 from elevenlabs.client import ElevenLabs
@@ -65,7 +64,7 @@ def __init__(
 
         # State
         self._conversation = None
-        self._audio_interface: Optional[BotToBotAudioInterface] = None
+        self._audio_interface: BotToBotAudioInterface | None = None
         self._end_reason: str = "unknown"
         self._conversation_done = asyncio.Event()
 
@@ -203,7 +202,7 @@ async def _run_elevenlabs_conversation(self, api_key: str) -> str:
         try:
             await asyncio.wait_for(self._conversation_done.wait(), timeout=self.timeout)
             logger.info(f"Conversation ended: {self._end_reason}")
-        except asyncio.TimeoutError:
+        except TimeoutError:
             logger.info(f"Conversation timed out after {self.timeout}s")
             self._end_reason = "timeout"
             self.event_logger.log_event("timeout", {"duration": self.timeout})
diff --git a/src/eva/user_simulator/event_logger.py b/src/eva/user_simulator/event_logger.py
index 2b22d101..c1183e2c 100644
--- a/src/eva/user_simulator/event_logger.py
+++ b/src/eva/user_simulator/event_logger.py
@@ -147,8 +147,7 @@ def save(self) -> None:
         self.output_path.parent.mkdir(parents=True, exist_ok=True)
 
         with open(self.output_path, "w") as f:
-            for event in self._events:
-                f.write(json.dumps(event) + "\n")
+            f.writelines(json.dumps(event) + "\n" for event in self._events)
 
         logger.info(f"Saved {len(self._events)} ElevenLabs events to {self.output_path}")
diff --git a/src/eva/utils/__init__.py b/src/eva/utils/__init__.py
index 179a1af5..273f5d6d 100644
--- a/src/eva/utils/__init__.py
+++ b/src/eva/utils/__init__.py
@@ -2,7 +2,6 @@
 
 import logging
 import sys
-from typing import Optional
 
 
 def get_logger(name: str) -> logging.Logger:
@@ -19,8 +18,8 @@ def get_logger(name: str) -> logging.Logger:
 
 def setup_logging(
     level: str = "INFO",
-    log_file: Optional[str] = None,
-    format_string: Optional[str] = None,
+    log_file: str | None = None,
+    format_string: str | None = None,
 ) -> None:
     """Set up logging configuration for the application.
diff --git a/src/eva/utils/error_handler.py b/src/eva/utils/error_handler.py
index 7cf38746..4c7e8b1b 100644
--- a/src/eva/utils/error_handler.py
+++ b/src/eva/utils/error_handler.py
@@ -7,8 +7,7 @@
 import asyncio
 import traceback
 from dataclasses import dataclass
-from datetime import datetime, timezone
-from typing import Optional
+from datetime import UTC, datetime
 
 from litellm.exceptions import (
     # Network/Connection Errors
@@ -46,7 +45,7 @@ class ErrorInfo:
     error_type: str  # Maps to ErrorDetails.error_type
     error_source: str  # Provider or component name
     is_retryable: bool
-    status_code: Optional[int]
+    status_code: int | None
     original_exception: Exception
 
@@ -353,7 +352,7 @@ def create_error_details(
         is_retryable=error_info.is_retryable,
         retry_count=retry_count,
         retry_succeeded=retry_succeeded,
-        timestamps=[datetime.now(timezone.utc).isoformat()],
+        timestamps=[datetime.now(UTC).isoformat()],
         stack_trace=stack_trace,
         original_error=str(error),
     )
diff --git a/src/eva/utils/llm_client.py b/src/eva/utils/llm_client.py
index dd949982..8d9fb1ad 100644
--- a/src/eva/utils/llm_client.py
+++ b/src/eva/utils/llm_client.py
@@ -3,7 +3,7 @@
 import asyncio
 import itertools
 import random
-from typing import ClassVar, Optional
+from typing import ClassVar
 
 from dotenv import load_dotenv
 
@@ -88,7 +88,7 @@ def _calculate_backoff_delay(self, attempt: int) -> float:
         # Ensure delay is positive
         return max(0, delay)
 
-    async def generate_text(self, messages: list[dict], response_format: Optional[dict] = None) -> str:
+    async def generate_text(self, messages: list[dict], response_format: dict | None = None) -> str:
         """Generate text completion with automatic retries.
 
         Args:
diff --git a/src/eva/utils/logging.py b/src/eva/utils/logging.py
index dfc8b8cb..71c78fd0 100644
--- a/src/eva/utils/logging.py
+++ b/src/eva/utils/logging.py
@@ -3,7 +3,6 @@
 import contextvars
 import logging
 import sys
-from typing import Optional
 
 # ContextVar that tracks which record_id the current asyncio task is processing.
 # Each ConversationWorker sets this at the start of its run() method.
@@ -47,8 +46,8 @@ def get_logger(name: str) -> logging.Logger:
 
 def setup_logging(
     level: str = "INFO",
-    log_file: Optional[str] = None,
-    format_string: Optional[str] = None,
+    log_file: str | None = None,
+    format_string: str | None = None,
 ) -> None:
     """Set up logging configuration for the application.
diff --git a/src/eva/utils/prompt_manager.py b/src/eva/utils/prompt_manager.py
index 56971149..c67db941 100644
--- a/src/eva/utils/prompt_manager.py
+++ b/src/eva/utils/prompt_manager.py
@@ -10,7 +10,7 @@
 """
 
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 import yaml
 
@@ -36,7 +36,7 @@ class PromptManager:
         ...     )
     """
 
-    def __init__(self, prompts_path: Optional[Path | str] = None):
+    def __init__(self, prompts_path: Path | str | None = None):
         """Initialize the prompt manager.
 
         Args:
@@ -63,7 +63,7 @@ def _load_prompts(self) -> None:
     def _load_single_file(self, file_path: Path) -> None:
         """Load prompts from a single YAML file."""
         try:
-            with open(file_path, "r", encoding="utf-8") as f:
+            with open(file_path, encoding="utf-8") as f:
                 data = yaml.safe_load(f) or {}
                 self.prompts.update(data)
                 self.loaded_files.append(file_path)
@@ -126,7 +126,7 @@ def get_prompt(self, path: str, **variables) -> str:
 
 # Global singleton instance
-_prompt_manager: Optional[PromptManager] = None
+_prompt_manager: PromptManager | None = None
 
 
 def get_prompt_manager() -> PromptManager:
diff --git a/src/eva/utils/provenance.py b/src/eva/utils/provenance.py
index c66188ca..259eee71 100644
--- a/src/eva/utils/provenance.py
+++ b/src/eva/utils/provenance.py
@@ -6,7 +6,6 @@
 import subprocess
 import sys
 from pathlib import Path
-from typing import Optional
 
 import eva
 from eva.models.config import RunConfig
@@ -17,7 +16,7 @@
 logger = get_logger(__name__)
 
 
-def _run_git_command(args: list[str]) -> Optional[str]:
+def _run_git_command(args: list[str]) -> str | None:
     """Run a git command and return stripped stdout, or None on failure."""
     try:
         result = subprocess.run(
@@ -35,7 +34,7 @@ def _run_git_command(args: list[str]) -> Optional[str]:
 
 def _get_git_info() -> dict:
     """Collect git state information."""
-    info: dict[str, Optional[str | bool]] = {
+    info: dict[str, str | bool | None] = {
         "git_commit_sha": None,
         "git_branch": None,
         "git_dirty": None,
@@ -59,7 +58,7 @@ def _get_git_info() -> dict:
     return info
 
 
-def _find_project_root() -> Optional[Path]:
+def _find_project_root() -> Path | None:
     """Find project root by searching up from this file for pyproject.toml."""
     current = Path(__file__).resolve().parent
     for parent in [current, *current.parents]:
@@ -68,7 +67,7 @@ def _find_project_root() -> Optional[Path]:
     return None
 
 
-def resolve_tool_module_file(tool_module_path: Optional[str]) -> Optional[Path]:
+def resolve_tool_module_file(tool_module_path: str | None) -> Path | None:
     """Resolve a Python module path to its filesystem path."""
     if not tool_module_path:
         return None
@@ -83,7 +82,7 @@ def resolve_tool_module_file(tool_module_path: Optional[str]) -> Optional[Path]:
 
 def capture_provenance(
     config: RunConfig,
-    tool_module_file: Optional[Path] = None,
+    tool_module_file: Path | None = None,
 ) -> RunProvenance:
     """Capture full provenance for a benchmark run.
 
@@ -130,7 +129,7 @@ def _relative_path(path: Path) -> str:
     else:
         logger.warning("Could not find configs/prompts/ directory for provenance")
 
-    tool_module_info: Optional[ArtifactInfo] = None
+    tool_module_info: ArtifactInfo | None = None
     if tool_module_file is None:
         tool_module_file = resolve_tool_module_file(config.tool_module_path)
     if tool_module_file and tool_module_file.exists():
@@ -173,7 +172,7 @@ def _relative_path(path: Path) -> str:
 
 def capture_metrics_provenance(
     metric_names: list[str],
-    run_config: Optional[dict] = None,
+    run_config: dict | None = None,
 ) -> MetricsProvenance:
     """Capture provenance for a metrics computation run.
 
@@ -188,7 +187,7 @@ def capture_metrics_provenance(
     git_info = _get_git_info()
     project_root = _find_project_root()
 
-    def _make_artifact(path_str: Optional[str], is_dir: bool = False) -> Optional[ArtifactInfo]:
+    def _make_artifact(path_str: str | None, is_dir: bool = False) -> ArtifactInfo | None:
         if not path_str:
             return None
         path = Path(path_str)
@@ -200,10 +199,10 @@ def _make_artifact(path_str: Optional[str], is_dir: bool = False) -> Optional[Ar
         sha = hash_directory(path) if is_dir else hash_file(path)
         return ArtifactInfo(path=path_str, sha256=sha)
 
-    dataset_info: Optional[ArtifactInfo] = None
-    agent_config_info: Optional[ArtifactInfo] = None
-    tool_module_info: Optional[ArtifactInfo] = None
-    scenario_db_info: Optional[ArtifactInfo] = None
+    dataset_info: ArtifactInfo | None = None
+    agent_config_info: ArtifactInfo | None = None
+    tool_module_info: ArtifactInfo | None = None
+    scenario_db_info: ArtifactInfo | None = None
 
     if run_config:
         dataset_info = _make_artifact(run_config.get("dataset_path"))
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 72e0f49b..2ed400e0 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -106,8 +106,7 @@ def dataset_file(temp_dir, sample_record_data):
     """Create a temporary dataset JSONL file."""
     path = temp_dir / "dataset.jsonl"
     with open(path, "w") as f:
-        for record in sample_record_data:
-            f.write(json.dumps(record) + "\n")
+        f.writelines(json.dumps(record) + "\n" for record in sample_record_data)
     return path
diff --git a/tests/unit/metrics/test_conversation_finished.py b/tests/unit/metrics/test_conversation_finished.py
index e370a4be..0e0c247f 100644
--- a/tests/unit/metrics/test_conversation_finished.py
+++ b/tests/unit/metrics/test_conversation_finished.py
@@ -16,8 +16,7 @@ def _write_events(self, tmp_path, events: list[dict]) -> str:
         """Write events to elevenlabs_events.jsonl and return output_dir."""
         events_file = tmp_path / "elevenlabs_events.jsonl"
         with open(events_file, "w") as f:
-            for event in events:
-                f.write(json.dumps(event) + "\n")
+            f.writelines(json.dumps(event) + "\n" for event in events)
         return str(tmp_path)
 
     @pytest.mark.asyncio
diff --git a/tests/unit/utils/test_error_handler.py b/tests/unit/utils/test_error_handler.py
index fc2aef53..498cc59a 100644
--- a/tests/unit/utils/test_error_handler.py
+++ b/tests/unit/utils/test_error_handler.py
@@ -1,7 +1,5 @@
 """Tests for eva.utils.error_handler module."""
 
-import asyncio
-
 import httpx
 import pytest
 from litellm.exceptions import (
@@ -137,7 +135,7 @@ def test_generic_api_error(self):
         assert info.is_retryable is True
 
     def test_asyncio_timeout(self):
-        err = asyncio.TimeoutError()
+        err = TimeoutError()
         info = categorize_error(err)
         assert info.error_type == "timeout_error"
         assert info.error_source == "system"
@@ -201,7 +199,7 @@ def test_llm_provider_attribute(self):
         assert get_error_source(err) == "openai"
 
     def test_asyncio_timeout(self):
-        assert get_error_source(asyncio.TimeoutError()) == "system"
+        assert get_error_source(TimeoutError()) == "system"
 
     def test_tts_providers(self):
         assert get_error_source(Exception("cartesia error")) == "cartesia"
@@ -229,7 +227,7 @@ def test_retryable_litellm_errors(self, cls):
         assert is_retryable_error(err) is True
 
    def test_asyncio_timeout_retryable(self):
-        assert is_retryable_error(asyncio.TimeoutError()) is True
+        assert is_retryable_error(TimeoutError()) is True
 
     def test_tts_stt_retryable(self):
         assert is_retryable_error(Exception("cartesia error")) is True
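For reference, here is a minimal standalone sketch (not part of the patch, and using hypothetical names) of the Python 3.11+ idioms this diff converges on: PEP 604 unions (`X | None`) instead of `typing.Optional`, `datetime.UTC` instead of `timezone.utc`, the builtin `TimeoutError` (which `asyncio.TimeoutError` aliases since 3.11), and `writelines` with a generator instead of a `write` loop:

```python
import asyncio
from datetime import UTC, datetime


def describe(started_at: datetime | None = None) -> str:
    """PEP 604 union syntax replaces Optional[datetime]."""
    # datetime.UTC (new in 3.11) is shorthand for timezone.utc
    started_at = started_at or datetime.now(UTC)
    return started_at.isoformat()


async def wait_briefly() -> None:
    try:
        await asyncio.wait_for(asyncio.sleep(1), timeout=0.01)
    except TimeoutError:  # same class as asyncio.TimeoutError on 3.11+
        pass


records = [{"id": 1}, {"id": 2}]
with open("out.jsonl", "w", encoding="utf-8") as f:
    # writelines accepts any iterable of strings, so a generator
    # expression replaces the explicit for/f.write loop
    f.writelines(f"{r['id']}\n" for r in records)
```

These rewrites are exactly what the newly enabled ruff rule groups (`UP` for pyupgrade, `FURB` for refurb, `PYI` for stub-style typing) flag automatically, which is why the `pyproject.toml` change and the mechanical edits land in the same patch.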