Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions docs/05_bring_your_own_model_provider.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,46 @@ class MyImageQAProvider(ImageQAProvider):
```


### Execution Cost Tracking

The built-in VLM providers include default pricing for supported models. You can override the pricing on any provider by passing `input_cost_per_million_tokens` and `output_cost_per_million_tokens`:

```python
from askui import AgentSettings, ComputerAgent
from askui.model_providers import AnthropicVlmProvider
from askui.reporting import SimpleHtmlReporter

with ComputerAgent(
reporters=[SimpleHtmlReporter()],
settings=AgentSettings(
vlm_provider=AnthropicVlmProvider(
model_id="claude-sonnet-4-6",
input_cost_per_million_tokens=3.0,
output_cost_per_million_tokens=15.0,
),
),
) as agent:
agent.act("Open settings")
```

If you implement a fully custom `VlmProvider`, override the `pricing` property to enable cost tracking:

```python
from askui.model_providers import VlmProvider, ModelPricing

class MyVlmProvider(VlmProvider):
@property
def pricing(self) -> ModelPricing | None:
return ModelPricing(
input_cost_per_million_tokens=1.0,
output_cost_per_million_tokens=5.0,
)

# ... rest of implementation
```

---

## Advanced: Injecting a Custom Client

For full control over HTTP settings (timeouts, proxies, retries), you can inject a pre-configured client:
Expand Down
9 changes: 9 additions & 0 deletions docs/08_reporting.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ This generates an HTML file (typically in the current directory) showing:
SimpleHtmlReporter(output_dir="./execution_reports", filename="agent_run.html")
```

### Execution Cost Tracking

The HTML report automatically shows the estimated API cost when using a `VlmProvider` with pricing information. The built-in Anthropic and AskUI providers include default pricing for supported Claude models.

The report will display:
- Total estimated cost
- Per-token rates used for the calculation
- Input and output token breakdowns

### Custom Reporters

Create custom reporters by implementing the `Reporter` interface:
Expand Down
7 changes: 6 additions & 1 deletion src/askui/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,12 @@ def __init__(
# Create conversation with speakers and model providers
speakers = Speakers()
_callbacks = list(callbacks or [])
_callbacks.append(UsageTrackingCallback(reporter=self._reporter))
_callbacks.append(
UsageTrackingCallback(
reporter=self._reporter,
pricing=self._vlm_provider.pricing,
)
)
self._conversation = Conversation(
speakers=speakers,
vlm_provider=self._vlm_provider,
Expand Down
2 changes: 2 additions & 0 deletions src/askui/model_providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from askui.model_providers.google_image_qa_provider import GoogleImageQAProvider
from askui.model_providers.image_qa_provider import ImageQAProvider
from askui.model_providers.vlm_provider import VlmProvider
from askui.utils.model_pricing import ModelPricing

__all__ = [
"AnthropicImageQAProvider",
Expand All @@ -33,5 +34,6 @@
"DetectionProvider",
"GoogleImageQAProvider",
"ImageQAProvider",
"ModelPricing",
"VlmProvider",
]
18 changes: 18 additions & 0 deletions src/askui/model_providers/anthropic_vlm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from askui.models.shared.prompts import SystemPrompt
from askui.models.shared.tools import ToolCollection
from askui.utils.model_pricing import ModelPricing

_DEFAULT_MODEL_ID = "claude-sonnet-4-6"

Expand All @@ -38,6 +39,11 @@ class AnthropicVlmProvider(VlmProvider):
`\"claude-sonnet-4-6\"`.
client (Anthropic | None, optional): Pre-configured Anthropic client.
If provided, other connection parameters are ignored.
input_cost_per_million_tokens (float | None, optional): Override
cost in USD per 1M input tokens. Both cost params must be set
to override the built-in defaults.
output_cost_per_million_tokens (float | None, optional): Override
cost in USD per 1M output tokens.

Example:
```python
Expand All @@ -60,6 +66,8 @@ def __init__(
auth_token: str | None = None,
model_id: str | None = None,
client: Anthropic | None = None,
input_cost_per_million_tokens: float | None = None,
output_cost_per_million_tokens: float | None = None,
) -> None:
self._model_id_value = (
model_id or os.environ.get("VLM_PROVIDER_MODEL_ID") or _DEFAULT_MODEL_ID
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mlikasam-askui and @philipph-askui ,

we need to discuss what the proper way of reading and validating environment variables is.

Expand All @@ -72,12 +80,22 @@ def __init__(
base_url=base_url,
auth_token=auth_token,
)
self._pricing = ModelPricing.for_model(
self._model_id_value,
input_cost_per_million_tokens=input_cost_per_million_tokens,
output_cost_per_million_tokens=output_cost_per_million_tokens,
)

@property
@override
def model_id(self) -> str:
    """The model identifier in use.

    Resolved in ``__init__`` from the explicit ``model_id`` argument,
    the ``VLM_PROVIDER_MODEL_ID`` environment variable, or the built-in
    default, in that order.
    """
    return self._model_id_value

@property
@override
def pricing(self) -> ModelPricing | None:
    """Pricing information for the configured model.

    Resolved in ``__init__`` via ``ModelPricing.for_model`` (built-in
    defaults, optionally overridden by the cost constructor arguments).
    May be ``None`` when no pricing is known for the model — TODO confirm
    against ``ModelPricing.for_model``.
    """
    return self._pricing

@cached_property
def _messages_api(self) -> AnthropicMessagesApi:
"""Lazily initialise the AnthropicMessagesApi on first use."""
Expand Down
1 change: 0 additions & 1 deletion src/askui/model_providers/askui_vlm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class AskUIVlmProvider(VlmProvider):
`"claude-sonnet-4-6"`.
client (Anthropic | None, optional): Pre-configured Anthropic client.
If provided, `workspace_id` and `token` are ignored.

Example:
```python
from askui import AgentSettings, ComputerAgent
Expand Down
10 changes: 10 additions & 0 deletions src/askui/model_providers/vlm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from askui.models.shared.prompts import SystemPrompt
from askui.models.shared.tools import ToolCollection
from askui.utils.model_pricing import ModelPricing


class VlmProvider(ABC):
Expand Down Expand Up @@ -43,6 +44,15 @@ class VlmProvider(ABC):
def model_id(self) -> str:
"""The model identifier used by this provider."""

@property
def pricing(self) -> ModelPricing | None:
    """Pricing information for this provider's model.

    Returns ``None`` (the default) if no pricing information is
    available, which leaves cost tracking disabled for this provider.
    Override in subclasses to provide model-specific pricing.
    """
    return None

@abstractmethod
def create_message(
self,
Expand Down
105 changes: 87 additions & 18 deletions src/askui/models/shared/usage_tracking_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,69 @@
from typing import TYPE_CHECKING

from opentelemetry import trace
from pydantic import BaseModel
from typing_extensions import override

from askui.models.shared.agent_message_param import UsageParam
from askui.models.shared.conversation_callback import ConversationCallback
from askui.reporting import NULL_REPORTER, Reporter
from askui.reporting import NULL_REPORTER

if TYPE_CHECKING:
from askui.models.shared.agent_message_param import UsageParam
from askui.models.shared.conversation import Conversation
from askui.reporting import Reporter
from askui.speaker.speaker import SpeakerResult
from askui.utils.model_pricing import ModelPricing


class UsageSummary(BaseModel):
    """Accumulated token usage and optional cost breakdown for a conversation.

    Token fields are running totals accumulated across steps; cost fields
    are populated only when pricing information is available.

    Args:
        input_tokens (int | None): Total input tokens sent to the API.
        output_tokens (int | None): Total output tokens generated.
        cache_creation_input_tokens (int | None): Tokens used for cache creation.
        cache_read_input_tokens (int | None): Tokens read from cache.
        input_cost (float | None): Computed input cost in `currency`.
        output_cost (float | None): Computed output cost in `currency`.
        total_cost (float | None): Sum of `input_cost` and `output_cost`.
        currency (str | None): ISO 4217 currency code (e.g. ``"USD"``).
        input_cost_per_million_tokens (float | None): Rate used to compute
            `input_cost`.
        output_cost_per_million_tokens (float | None): Rate used to compute
            `output_cost`.
    """

    input_tokens: int | None = None
    output_tokens: int | None = None
    cache_creation_input_tokens: int | None = None
    cache_read_input_tokens: int | None = None
    input_cost: float | None = None
    output_cost: float | None = None
    total_cost: float | None = None
    currency: str | None = None
    input_cost_per_million_tokens: float | None = None
    output_cost_per_million_tokens: float | None = None


class UsageTrackingCallback(ConversationCallback):
"""Tracks token usage per step and reports a summary at conversation end.

Args:
reporter: Reporter to write the final usage summary to.
pricing: Pricing information for cost calculation. If ``None``,
no cost data is included in the usage summary.
"""

def __init__(self, reporter: Reporter = NULL_REPORTER) -> None:
def __init__(
    self,
    reporter: Reporter = NULL_REPORTER,
    pricing: ModelPricing | None = None,
) -> None:
    """Initialize the callback.

    Args:
        reporter (Reporter): Reporter to write the final usage summary to.
        pricing (ModelPricing | None): Pricing information for cost
            calculation. If ``None``, no cost data is included in the
            usage summary.
    """
    self._reporter = reporter
    self._pricing = pricing
    # All accumulated state lives in the summary model; the legacy
    # `_accumulated_usage` UsageParam is no longer read anywhere and is
    # therefore not initialized.
    self._summary = UsageSummary()

@override
def on_conversation_start(self, conversation: Conversation) -> None:
    """Reset accumulated usage so back-to-back conversations don't mix totals.

    Args:
        conversation (Conversation): The conversation that is starting.
    """
    # Only the summary model holds state; the stale reset of the removed
    # `_accumulated_usage` attribute is intentionally dropped.
    self._summary = UsageSummary()

@override
def on_step_end(
Expand All @@ -43,27 +81,29 @@ def on_step_end(

@override
def on_conversation_end(self, conversation: Conversation) -> None:
    """Report the accumulated usage summary exactly once at conversation end.

    Args:
        conversation (Conversation): The conversation that just ended.
    """
    # Pass the UsageSummary model itself (not a dict dump) so the reporter
    # receives cost fields alongside token totals; a single call avoids
    # double-reporting.
    self._reporter.add_usage_summary(self._summary)

@property
def accumulated_usage(self) -> UsageSummary:
    """Current accumulated usage statistics.

    Returns:
        UsageSummary: Running token totals and, when pricing is configured,
            the computed cost breakdown.
    """
    # The merged diff left two `def` lines and two `return` statements here
    # (a syntax error); this is the coherent post-refactor version that
    # exposes the UsageSummary model.
    return self._summary

def _accumulate(self, step_usage: UsageParam) -> None:
self._accumulated_usage.input_tokens = (
self._accumulated_usage.input_tokens or 0
) + (step_usage.input_tokens or 0)
self._accumulated_usage.output_tokens = (
self._accumulated_usage.output_tokens or 0
) + (step_usage.output_tokens or 0)
self._accumulated_usage.cache_creation_input_tokens = (
self._accumulated_usage.cache_creation_input_tokens or 0
# Add step tokens to running totals (None counts as 0)
self._summary.input_tokens = (self._summary.input_tokens or 0) + (
step_usage.input_tokens or 0
)
self._summary.output_tokens = (self._summary.output_tokens or 0) + (
step_usage.output_tokens or 0
)
self._summary.cache_creation_input_tokens = (
self._summary.cache_creation_input_tokens or 0
) + (step_usage.cache_creation_input_tokens or 0)
self._accumulated_usage.cache_read_input_tokens = (
self._accumulated_usage.cache_read_input_tokens or 0
self._summary.cache_read_input_tokens = (
self._summary.cache_read_input_tokens or 0
) + (step_usage.cache_read_input_tokens or 0)

# Record per-step token counts on the current OTel span
current_span = trace.get_current_span()
current_span.set_attributes(
{
Expand All @@ -75,3 +115,32 @@ def _accumulate(self, step_usage: UsageParam) -> None:
"cache_read_input_tokens": (step_usage.cache_read_input_tokens or 0),
}
)

# Update costs from updated totals if pricing values are set
if not (
self._pricing
and self._pricing.input_cost_per_million_tokens
and self._pricing.output_cost_per_million_tokens
):
return

input_cost = (
self._summary.input_tokens
* self._pricing.input_cost_per_million_tokens
/ 1e6
)
output_cost = (
self._summary.output_tokens
* self._pricing.output_cost_per_million_tokens
/ 1e6
)
self._summary.input_cost = input_cost
self._summary.output_cost = output_cost
self._summary.total_cost = input_cost + output_cost
self._summary.currency = self._pricing.currency
self._summary.input_cost_per_million_tokens = (
self._pricing.input_cost_per_million_tokens
)
self._summary.output_cost_per_million_tokens = (
self._pricing.output_cost_per_million_tokens
)
Loading
Loading