diff --git a/docs/11_callbacks.md b/docs/11_callbacks.md index d9b354fa..ce5c8a6f 100644 --- a/docs/11_callbacks.md +++ b/docs/11_callbacks.md @@ -2,6 +2,8 @@ Callbacks provide hooks into the agent's conversation lifecycle, similar to PyTorch Lightning's callback system. Use them for logging, monitoring, custom metrics, or extending agent behavior. +All callbacks live in `askui.callbacks` and can be imported from there. + ## Usage Subclass `ConversationCallback` and override the hooks you need: @@ -71,6 +73,10 @@ with ComputerAgent(callbacks=[TimingCallback()]) as agent: agent.act("Search for documents") ``` +## Built-in Callbacks + +(we will add built-in callbacks at a later stage) + ## Multiple Callbacks Pass multiple callbacks to combine behaviors: diff --git a/src/askui/__init__.py b/src/askui/__init__.py index 56d9ffaa..aa79894f 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -9,6 +9,7 @@ from .agent_base import Agent from .agent_settings import AgentSettings +from .callbacks import ConversationCallback from .computer_agent import ComputerAgent, VisionAgent from .locators import Locator from .models import ( @@ -30,7 +31,6 @@ ToolUseBlockParam, UrlImageSourceParam, ) -from .models.shared.conversation_callback import ConversationCallback from .models.shared.settings import ( DEFAULT_GET_RESOLUTION, DEFAULT_LOCATE_RESOLUTION, diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index 5df3e1a2..5d1d4ee4 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -10,11 +10,11 @@ from typing_extensions import Self from askui.agent_settings import AgentSettings +from askui.callbacks import ConversationCallback, UsageTrackingCallback from askui.container import telemetry from askui.locators.locators import Locator from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.conversation import Conversation, Speakers -from askui.models.shared.conversation_callback import ConversationCallback from askui.models.shared.settings import ( ActSettings, CacheWritingSettings, @@ -23,7 +23,6 @@ LocateSettings, ) from askui.models.shared.tools import Tool, ToolCollection -from askui.models.shared.usage_tracking_callback import UsageTrackingCallback from askui.prompts.act_prompts import CACHE_USE_PROMPT, create_default_prompt from askui.tools.agent_os import AgentOs from askui.tools.android.agent_os import AndroidAgentOs diff --git a/src/askui/android_agent.py b/src/askui/android_agent.py index f9779821..b4fb0182 100644 --- a/src/askui/android_agent.py +++ b/src/askui/android_agent.py @@ -6,10 +6,10 @@ from askui.agent_base import Agent from askui.agent_settings import AgentSettings +from askui.callbacks import ConversationCallback from askui.container import telemetry from askui.locators.locators import Locator from askui.models.models import Point -from askui.models.shared.conversation_callback import ConversationCallback from askui.models.shared.settings import ActSettings, MessageSettings from askui.models.shared.tools import Tool from askui.prompts.act_prompts import create_android_agent_prompt diff --git a/src/askui/callbacks/__init__.py b/src/askui/callbacks/__init__.py new file mode 100644 index 00000000..29eb7029 --- /dev/null +++ b/src/askui/callbacks/__init__.py @@ -0,0 +1,7 @@ +from .conversation_callback import ConversationCallback +from .usage_tracking_callback import UsageTrackingCallback + +__all__ = [ + "ConversationCallback", + "UsageTrackingCallback", +] diff --git a/src/askui/models/shared/conversation_callback.py b/src/askui/callbacks/conversation_callback.py similarity index 100% rename from src/askui/models/shared/conversation_callback.py rename to src/askui/callbacks/conversation_callback.py diff --git a/src/askui/models/shared/usage_tracking_callback.py b/src/askui/callbacks/usage_tracking_callback.py similarity index 98% rename from src/askui/models/shared/usage_tracking_callback.py rename to src/askui/callbacks/usage_tracking_callback.py index 5e245434..a04f02d4 100644 --- a/src/askui/models/shared/usage_tracking_callback.py +++ b/src/askui/callbacks/usage_tracking_callback.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from typing_extensions import override -from askui.models.shared.conversation_callback import ConversationCallback +from askui.callbacks.conversation_callback import ConversationCallback from askui.reporting import NULL_REPORTER if TYPE_CHECKING: diff --git a/src/askui/computer_agent.py b/src/askui/computer_agent.py index a71ef63c..d35a97a7 100644 --- a/src/askui/computer_agent.py +++ b/src/askui/computer_agent.py @@ -6,10 +6,10 @@ from askui.agent_base import Agent from askui.agent_settings import AgentSettings +from askui.callbacks import ConversationCallback from askui.container import telemetry from askui.locators.locators import Locator from askui.models.models import Point -from askui.models.shared.conversation_callback import ConversationCallback from askui.models.shared.settings import ActSettings, LocateSettings, MessageSettings from askui.models.shared.tools import Tool from askui.prompts.act_prompts import ( diff --git a/src/askui/models/shared/conversation.py b/src/askui/models/shared/conversation.py index 167df491..9205d259 100644 --- a/src/askui/models/shared/conversation.py +++ b/src/askui/models/shared/conversation.py @@ -22,7 +22,7 @@ from askui.tools.switch_speaker_tool import SwitchSpeakerTool if TYPE_CHECKING: - from askui.models.shared.conversation_callback import ConversationCallback + from askui.callbacks import ConversationCallback from askui.utils.caching.cache_manager import CacheManager logger = logging.getLogger(__name__) @@ -207,9 +207,21 @@ def _execute_control_loop(self) -> None: self._step_index = 0 continue_execution = True while continue_execution: + if self._is_max_steps_reached(): + break continue_execution = self._execute_step() self._on_control_loop_end() + def _is_max_steps_reached(self) -> bool: + if self.settings.max_steps is None: + return False + if self._step_index >= self.settings.max_steps: + msg = ( + f"Reached max_steps limit {self.settings.max_steps}, stopping execution" + ) + raise ConversationException(msg) + return False + @tracer.start_as_current_span("_teardown_control_loop") def _teardown_control_loop(self) -> None: # Finish recording if cache_manager is active and not executing from cache diff --git a/src/askui/models/shared/settings.py b/src/askui/models/shared/settings.py index 3fc7624a..293d7fb2 100644 --- a/src/askui/models/shared/settings.py +++ b/src/askui/models/shared/settings.py @@ -89,6 +89,7 @@ class ActSettings(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) messages: MessageSettings = Field(default_factory=MessageSettings) + max_steps: int | None = None class GetSettings(BaseModel): diff --git a/src/askui/reporting.py b/src/askui/reporting.py index ec601286..2feea108 100644 --- a/src/askui/reporting.py +++ b/src/askui/reporting.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from PIL import Image - from askui.models.shared.usage_tracking_callback import UsageSummary + from askui.callbacks.usage_tracking_callback import UsageSummary def normalize_to_pil_images( diff --git a/tests/unit/model_providers/test_model_pricing.py b/tests/unit/model_providers/test_model_pricing.py index 7fa00c6d..e1b2103b 100644 --- a/tests/unit/model_providers/test_model_pricing.py +++ b/tests/unit/model_providers/test_model_pricing.py @@ -4,11 +4,11 @@ import pytest -from askui.models.shared.agent_message_param import UsageParam -from askui.models.shared.usage_tracking_callback import ( +from askui.callbacks.usage_tracking_callback import ( UsageSummary, UsageTrackingCallback, ) +from askui.models.shared.agent_message_param import UsageParam from askui.utils.model_pricing import ModelPricing