diff --git a/docs/05_bring_your_own_model_provider.md b/docs/05_bring_your_own_model_provider.md index 66cb5013..04f17d48 100644 --- a/docs/05_bring_your_own_model_provider.md +++ b/docs/05_bring_your_own_model_provider.md @@ -135,6 +135,46 @@ class MyImageQAProvider(ImageQAProvider): ``` +### Execution Cost Tracking + +The built-in VLM providers include default pricing for supported models. You can override the pricing on any provider by passing `input_cost_per_million_tokens` and `output_cost_per_million_tokens`: + +```python +from askui import AgentSettings, ComputerAgent +from askui.model_providers import AnthropicVlmProvider +from askui.reporting import SimpleHtmlReporter + +with ComputerAgent( + reporters=[SimpleHtmlReporter()], + settings=AgentSettings( + vlm_provider=AnthropicVlmProvider( + model_id="claude-sonnet-4-6", + input_cost_per_million_tokens=3.0, + output_cost_per_million_tokens=15.0, + ), + ), +) as agent: + agent.act("Open settings") +``` + +If you implement a fully custom `VlmProvider`, override the `pricing` property to enable cost tracking: + +```python +from askui.model_providers import VlmProvider, ModelPricing + +class MyVlmProvider(VlmProvider): + @property + def pricing(self) -> ModelPricing | None: + return ModelPricing( + input_cost_per_million_tokens=1.0, + output_cost_per_million_tokens=5.0, + ) + + # ... 
rest of implementation +``` + +--- + ## Advanced: Injecting a Custom Client For full control over HTTP settings (timeouts, proxies, retries), you can inject a pre-configured client: diff --git a/docs/08_reporting.md b/docs/08_reporting.md index 10a43f32..6041fafd 100644 --- a/docs/08_reporting.md +++ b/docs/08_reporting.md @@ -32,6 +32,15 @@ This generates an HTML file (typically in the current directory) showing: SimpleHtmlReporter(output_dir="./execution_reports", filename="agent_run.html") ``` +### Execution Cost Tracking + +The HTML report automatically shows the estimated API cost when using a `VlmProvider` with pricing information. The built-in Anthropic and AskUI providers include default pricing for supported Claude models. + +The report will display: +- Total estimated cost +- Per-million-token input and output rates used for the calculation +- Input and output token breakdowns (as before) + ### Custom Reporters Create custom reporters by implementing the `Reporter` interface: diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index 3f72d84c..5df3e1a2 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -75,7 +75,12 @@ def __init__( # Create conversation with speakers and model providers speakers = Speakers() _callbacks = list(callbacks or []) - _callbacks.append(UsageTrackingCallback(reporter=self._reporter)) + _callbacks.append( + UsageTrackingCallback( + reporter=self._reporter, + pricing=self._vlm_provider.pricing, + ) + ) self._conversation = Conversation( speakers=speakers, vlm_provider=self._vlm_provider, diff --git a/src/askui/model_providers/__init__.py b/src/askui/model_providers/__init__.py index a4f08f0d..add59506 100644 --- a/src/askui/model_providers/__init__.py +++ b/src/askui/model_providers/__init__.py @@ -23,6 +23,7 @@ from askui.model_providers.google_image_qa_provider import GoogleImageQAProvider from askui.model_providers.image_qa_provider import ImageQAProvider from askui.model_providers.vlm_provider import VlmProvider +from
askui.utils.model_pricing import ModelPricing __all__ = [ "AnthropicImageQAProvider", @@ -33,5 +34,6 @@ "DetectionProvider", "GoogleImageQAProvider", "ImageQAProvider", + "ModelPricing", "VlmProvider", ] diff --git a/src/askui/model_providers/anthropic_vlm_provider.py b/src/askui/model_providers/anthropic_vlm_provider.py index b0e5d186..cc7525ec 100644 --- a/src/askui/model_providers/anthropic_vlm_provider.py +++ b/src/askui/model_providers/anthropic_vlm_provider.py @@ -16,6 +16,7 @@ ) from askui.models.shared.prompts import SystemPrompt from askui.models.shared.tools import ToolCollection +from askui.utils.model_pricing import ModelPricing _DEFAULT_MODEL_ID = "claude-sonnet-4-6" @@ -38,6 +39,11 @@ class AnthropicVlmProvider(VlmProvider): `\"claude-sonnet-4-6\"`. client (Anthropic | None, optional): Pre-configured Anthropic client. If provided, other connection parameters are ignored. + input_cost_per_million_tokens (float | None, optional): Override + cost in USD per 1M input tokens. Both cost params must be set + to override the built-in defaults. + output_cost_per_million_tokens (float | None, optional): Override + cost in USD per 1M output tokens. 
Example: ```python @@ -60,6 +66,8 @@ def __init__( auth_token: str | None = None, model_id: str | None = None, client: Anthropic | None = None, + input_cost_per_million_tokens: float | None = None, + output_cost_per_million_tokens: float | None = None, ) -> None: self._model_id_value = ( model_id or os.environ.get("VLM_PROVIDER_MODEL_ID") or _DEFAULT_MODEL_ID @@ -72,12 +80,22 @@ def __init__( base_url=base_url, auth_token=auth_token, ) + self._pricing = ModelPricing.for_model( + self._model_id_value, + input_cost_per_million_tokens=input_cost_per_million_tokens, + output_cost_per_million_tokens=output_cost_per_million_tokens, + ) @property @override def model_id(self) -> str: return self._model_id_value + @property + @override + def pricing(self) -> ModelPricing | None: + return self._pricing + @cached_property def _messages_api(self) -> AnthropicMessagesApi: """Lazily initialise the AnthropicMessagesApi on first use.""" diff --git a/src/askui/model_providers/askui_vlm_provider.py b/src/askui/model_providers/askui_vlm_provider.py index 5dfc9d29..d149deff 100644 --- a/src/askui/model_providers/askui_vlm_provider.py +++ b/src/askui/model_providers/askui_vlm_provider.py @@ -37,7 +37,6 @@ class AskUIVlmProvider(VlmProvider): `"claude-sonnet-4-6"`. client (Anthropic | None, optional): Pre-configured Anthropic client. If provided, `workspace_id` and `token` are ignored. 
- Example: ```python from askui import AgentSettings, ComputerAgent diff --git a/src/askui/model_providers/vlm_provider.py b/src/askui/model_providers/vlm_provider.py index fc1f046c..1e98b972 100644 --- a/src/askui/model_providers/vlm_provider.py +++ b/src/askui/model_providers/vlm_provider.py @@ -10,6 +10,7 @@ ) from askui.models.shared.prompts import SystemPrompt from askui.models.shared.tools import ToolCollection +from askui.utils.model_pricing import ModelPricing class VlmProvider(ABC): @@ -43,6 +44,15 @@ class VlmProvider(ABC): def model_id(self) -> str: """The model identifier used by this provider.""" + @property + def pricing(self) -> ModelPricing | None: + """Pricing information for this provider's model. + + Returns ``None`` if no pricing information is available. + Override in subclasses to provide model-specific pricing. + """ + return None + @abstractmethod def create_message( self, diff --git a/src/askui/models/shared/usage_tracking_callback.py b/src/askui/models/shared/usage_tracking_callback.py index 3bc58206..5e245434 100644 --- a/src/askui/models/shared/usage_tracking_callback.py +++ b/src/askui/models/shared/usage_tracking_callback.py @@ -5,15 +5,46 @@ from typing import TYPE_CHECKING from opentelemetry import trace +from pydantic import BaseModel from typing_extensions import override -from askui.models.shared.agent_message_param import UsageParam from askui.models.shared.conversation_callback import ConversationCallback -from askui.reporting import NULL_REPORTER, Reporter +from askui.reporting import NULL_REPORTER if TYPE_CHECKING: + from askui.models.shared.agent_message_param import UsageParam from askui.models.shared.conversation import Conversation + from askui.reporting import Reporter from askui.speaker.speaker import SpeakerResult + from askui.utils.model_pricing import ModelPricing + + +class UsageSummary(BaseModel): + """Accumulated token usage and optional cost breakdown for a conversation. 
+ + Args: + input_tokens (int | None): Total input tokens sent to the API. + output_tokens (int | None): Total output tokens generated. + cache_creation_input_tokens (int | None): Tokens used for cache creation. + cache_read_input_tokens (int | None): Tokens read from cache. + input_cost (float | None): Computed input cost in `currency`. + output_cost (float | None): Computed output cost in `currency`. + total_cost (float | None): Sum of `input_cost` and `output_cost`. + currency (str | None): ISO 4217 currency code (e.g. ``"USD"``). + input_cost_per_million_tokens (float | None): Rate used to compute `input_cost`. + output_cost_per_million_tokens (float|None): Rate used to compute `output_cost`. + """ + + input_tokens: int | None = None + output_tokens: int | None = None + cache_creation_input_tokens: int | None = None + cache_read_input_tokens: int | None = None + input_cost: float | None = None + output_cost: float | None = None + total_cost: float | None = None + currency: str | None = None + input_cost_per_million_tokens: float | None = None + output_cost_per_million_tokens: float | None = None class UsageTrackingCallback(ConversationCallback): @@ -21,15 +52,22 @@ class UsageTrackingCallback(ConversationCallback): Args: reporter: Reporter to write the final usage summary to. + pricing: Pricing information for cost calculation. If ``None``, + no cost data is included in the usage summary. 
""" - def __init__(self, reporter: Reporter = NULL_REPORTER) -> None: + def __init__( + self, + reporter: Reporter = NULL_REPORTER, + pricing: ModelPricing | None = None, + ) -> None: self._reporter = reporter - self._accumulated_usage = UsageParam() + self._pricing = pricing + self._summary = UsageSummary() @override def on_conversation_start(self, conversation: Conversation) -> None: - self._accumulated_usage = UsageParam() + self._summary = UsageSummary() @override def on_step_end( @@ -43,27 +81,29 @@ def on_step_end( @override def on_conversation_end(self, conversation: Conversation) -> None: - self._reporter.add_usage_summary(self._accumulated_usage.model_dump()) + self._reporter.add_usage_summary(self._summary) @property - def accumulated_usage(self) -> UsageParam: + def accumulated_usage(self) -> UsageSummary: """Current accumulated usage statistics.""" - return self._accumulated_usage + return self._summary def _accumulate(self, step_usage: UsageParam) -> None: - self._accumulated_usage.input_tokens = ( - self._accumulated_usage.input_tokens or 0 - ) + (step_usage.input_tokens or 0) - self._accumulated_usage.output_tokens = ( - self._accumulated_usage.output_tokens or 0 - ) + (step_usage.output_tokens or 0) - self._accumulated_usage.cache_creation_input_tokens = ( - self._accumulated_usage.cache_creation_input_tokens or 0 + # Add step tokens to running totals (None counts as 0) + self._summary.input_tokens = (self._summary.input_tokens or 0) + ( + step_usage.input_tokens or 0 + ) + self._summary.output_tokens = (self._summary.output_tokens or 0) + ( + step_usage.output_tokens or 0 + ) + self._summary.cache_creation_input_tokens = ( + self._summary.cache_creation_input_tokens or 0 ) + (step_usage.cache_creation_input_tokens or 0) - self._accumulated_usage.cache_read_input_tokens = ( - self._accumulated_usage.cache_read_input_tokens or 0 + self._summary.cache_read_input_tokens = ( + self._summary.cache_read_input_tokens or 0 ) + 
(step_usage.cache_read_input_tokens or 0) + # Record per-step token counts on the current OTel span current_span = trace.get_current_span() current_span.set_attributes( { @@ -75,3 +115,32 @@ def _accumulate(self, step_usage: UsageParam) -> None: "cache_read_input_tokens": (step_usage.cache_read_input_tokens or 0), } ) + + # Recompute costs from running totals; 0.0 is a valid rate, so test for None + if ( + self._pricing is None + or self._pricing.input_cost_per_million_tokens is None + or self._pricing.output_cost_per_million_tokens is None + ): + return + + input_cost = ( + self._summary.input_tokens + * self._pricing.input_cost_per_million_tokens + / 1e6 + ) + output_cost = ( + self._summary.output_tokens + * self._pricing.output_cost_per_million_tokens + / 1e6 + ) + self._summary.input_cost = input_cost + self._summary.output_cost = output_cost + self._summary.total_cost = input_cost + output_cost + self._summary.currency = self._pricing.currency + self._summary.input_cost_per_million_tokens = ( + self._pricing.input_cost_per_million_tokens + ) + self._summary.output_cost_per_million_tokens = ( + self._pricing.output_cost_per_million_tokens + ) diff --git a/src/askui/reporting.py b/src/askui/reporting.py index 4388a2cf..ec601286 100644 --- a/src/askui/reporting.py +++ b/src/askui/reporting.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import io import json @@ -9,14 +11,18 @@ from importlib.metadata import distributions from io import BytesIO from pathlib import Path -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from jinja2 import Template -from PIL import Image from typing_extensions import TypedDict, override from askui.utils.annotated_image import AnnotatedImage +if TYPE_CHECKING: + from PIL import Image + + from askui.models.shared.usage_tracking_callback import UsageSummary + def normalize_to_pil_images( image: Image.Image | list[Image.Image] | AnnotatedImage | None, @@ -80,15 +86,14 @@ def add_message( raise
NotImplementedError @abstractmethod - def add_usage_summary(self, usage: dict[str, int | None]) -> None: + def add_usage_summary(self, usage: UsageSummary) -> None: """Add usage statistics summary to the report. - Called at the end of an act() execution with accumulated token usage. + Called at the end of an ``act()`` execution with accumulated token + usage and optional cost breakdown. Args: - usage (dict[str, int | None]): Accumulated usage statistics containing: - - input_tokens: Total input tokens sent to API - - output_tokens: Total output tokens generated + usage (UsageSummary): Accumulated usage statistics. """ raise NotImplementedError @@ -134,7 +139,7 @@ def add_message( pass @override - def add_usage_summary(self, usage: dict[str, int | None]) -> None: + def add_usage_summary(self, usage: UsageSummary) -> None: pass @override @@ -177,7 +182,7 @@ def add_message( reporter.add_message(role, content, image) @override - def add_usage_summary(self, usage: dict[str, int | None]) -> None: + def add_usage_summary(self, usage: UsageSummary) -> None: """Add usage summary to all reporters.""" for reporter in self._reporters: reporter.add_usage_summary(usage) @@ -215,7 +220,7 @@ def __init__(self, report_dir: str = "reports") -> None: self.report_dir = Path(report_dir) self.messages: list[dict[str, Any]] = [] self.system_info = self._collect_system_info() - self.usage_summary: dict[str, int | None] | None = None + self.usage_summary: UsageSummary | None = None self.cache_original_usage: dict[str, int | None] | None = None self._start_time: datetime | None = None @@ -264,7 +269,7 @@ def add_message( self.messages.append(message) @override - def add_usage_summary(self, usage: dict[str, int | None]) -> None: + def add_usage_summary(self, usage: UsageSummary) -> None: """Store usage summary for inclusion in the report.""" self.usage_summary = usage @@ -790,14 +795,14 @@ def generate(self) -> None: {% endif %} {% if usage_summary is not none %} - {% if 
usage_summary.get('input_tokens') is not none %} + {% if usage_summary.input_tokens is not none %}