diff --git a/examples/01_standalone_sdk/29_llm_streaming.py b/examples/01_standalone_sdk/29_llm_streaming.py
new file mode 100644
index 0000000000..a896956b5d
--- /dev/null
+++ b/examples/01_standalone_sdk/29_llm_streaming.py
@@ -0,0 +1,131 @@
+import os
+import sys
+from typing import Literal
+
+from pydantic import SecretStr
+
+from openhands.sdk import (
+    Conversation,
+    get_logger,
+)
+from openhands.sdk.llm import LLM
+from openhands.sdk.llm.streaming import ModelResponseStream
+from openhands.tools.preset.default import get_default_agent
+
+
+logger = get_logger(__name__)
+
+
+api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY")
+if not api_key:
+    raise RuntimeError("Set LLM_API_KEY or OPENAI_API_KEY in your environment.")
+
+model = os.getenv("LLM_MODEL", "anthropic/claude-sonnet-4-5-20250929")
+base_url = os.getenv("LLM_BASE_URL")
+llm = LLM(
+    model=model,
+    api_key=SecretStr(api_key),
+    base_url=base_url,
+    usage_id="stream-demo",
+    stream=True,
+)
+
+agent = get_default_agent(llm=llm, cli_mode=True)
+
+
+# Define streaming states
+StreamingState = Literal["thinking", "content", "tool_name", "tool_args"]
+# Track state across on_token calls for boundary detection
+_current_state: StreamingState | None = None
+
+
+def on_token(chunk: ModelResponseStream) -> None:
+    """
+    Handle all types of streaming tokens, including content,
+    tool calls, and thinking blocks, with dynamic boundary detection.
+    """
+    global _current_state
+
+    choices = chunk.choices
+    for choice in choices:
+        delta = choice.delta
+        if delta is not None:
+            # Handle thinking blocks (reasoning content)
+            reasoning_content = getattr(delta, "reasoning_content", None)
+            if isinstance(reasoning_content, str) and reasoning_content:
+                if _current_state != "thinking":
+                    if _current_state is not None:
+                        sys.stdout.write("\n")
+                    sys.stdout.write("THINKING: ")
+                    _current_state = "thinking"
+                sys.stdout.write(reasoning_content)
+                sys.stdout.flush()
+
+            # Handle regular content
+            content = getattr(delta, "content", None)
+            if isinstance(content, str) and content:
+                if _current_state != "content":
+                    if _current_state is not None:
+                        sys.stdout.write("\n")
+                    sys.stdout.write("CONTENT: ")
+                    _current_state = "content"
+                sys.stdout.write(content)
+                sys.stdout.flush()
+
+            # Handle tool calls
+            tool_calls = getattr(delta, "tool_calls", None)
+            if tool_calls:
+                for tool_call in tool_calls:
+                    tool_name = (
+                        tool_call.function.name if tool_call.function.name else ""
+                    )
+                    tool_args = (
+                        tool_call.function.arguments
+                        if tool_call.function.arguments
+                        else ""
+                    )
+                    if tool_name:
+                        if _current_state != "tool_name":
+                            if _current_state is not None:
+                                sys.stdout.write("\n")
+                            sys.stdout.write("TOOL NAME: ")
+                            _current_state = "tool_name"
+                        sys.stdout.write(tool_name)
+                        sys.stdout.flush()
+                    if tool_args:
+                        if _current_state != "tool_args":
+                            if _current_state is not None:
+                                sys.stdout.write("\n")
+                            sys.stdout.write("TOOL ARGS: ")
+                            _current_state = "tool_args"
+                        sys.stdout.write(tool_args)
+                        sys.stdout.flush()
+
+
+conversation = Conversation(
+    agent=agent,
+    workspace=os.getcwd(),
+    token_callbacks=[on_token],
+)
+
+story_prompt = (
+    "Tell me a long story about LLM streaming, write it to a file, "
+    "and make sure it has multiple paragraphs."
+)
+conversation.send_message(story_prompt)
+print("Token Streaming:")
+print("-" * 100 + "\n")
+conversation.run()
+
+cleanup_prompt = (
+    "Thank you. Please delete the streaming story file now that I've read it, "
+    "then confirm the deletion."
+) +conversation.send_message(cleanup_prompt) +print("Token Streaming:") +print("-" * 100 + "\n") +conversation.run() + +# Report cost +cost = llm.metrics.accumulated_cost +print(f"EXAMPLE_COST: {cost}") diff --git a/openhands-sdk/openhands/sdk/__init__.py b/openhands-sdk/openhands/sdk/__init__.py index 9f58c0a017..4c6c48af2d 100644 --- a/openhands-sdk/openhands/sdk/__init__.py +++ b/openhands-sdk/openhands/sdk/__init__.py @@ -21,11 +21,13 @@ LLM, ImageContent, LLMRegistry, + LLMStreamChunk, Message, RedactedThinkingBlock, RegistryEvent, TextContent, ThinkingBlock, + TokenCallbackType, ) from openhands.sdk.logger import get_logger from openhands.sdk.mcp import ( @@ -58,6 +60,8 @@ __all__ = [ "LLM", "LLMRegistry", + "LLMStreamChunk", + "TokenCallbackType", "ConversationStats", "RegistryEvent", "Message", diff --git a/openhands-sdk/openhands/sdk/agent/agent.py b/openhands-sdk/openhands/sdk/agent/agent.py index c05143a7a7..036fa6bd4f 100644 --- a/openhands-sdk/openhands/sdk/agent/agent.py +++ b/openhands-sdk/openhands/sdk/agent/agent.py @@ -13,6 +13,7 @@ from openhands.sdk.conversation import ( ConversationCallbackType, ConversationState, + ConversationTokenCallbackType, LocalConversation, ) from openhands.sdk.conversation.state import ConversationExecutionStatus @@ -135,6 +136,7 @@ def step( self, conversation: LocalConversation, on_event: ConversationCallbackType, + on_token: ConversationTokenCallbackType | None = None, ) -> None: state = conversation.state # Check for pending actions (implicit confirmation) @@ -167,7 +169,10 @@ def step( try: llm_response = make_llm_completion( - self.llm, _messages, tools=list(self.tools_map.values()) + self.llm, + _messages, + tools=list(self.tools_map.values()), + on_token=on_token, ) except FunctionCallValidationError as e: logger.warning(f"LLM generated malformed function call: {e}") diff --git a/openhands-sdk/openhands/sdk/agent/base.py b/openhands-sdk/openhands/sdk/agent/base.py index b5619b912d..fe03280e2e 100644 --- a/openhands-sdk/openhands/sdk/agent/base.py +++ b/openhands-sdk/openhands/sdk/agent/base.py @@ -20,7 +20,10 @@ if TYPE_CHECKING: from openhands.sdk.conversation import ConversationState, LocalConversation - from openhands.sdk.conversation.types import ConversationCallbackType + from openhands.sdk.conversation.types import ( + ConversationCallbackType, + ConversationTokenCallbackType, + ) logger = get_logger(__name__) @@ -239,6 +242,7 @@ def step( self, conversation: "LocalConversation", on_event: "ConversationCallbackType", + on_token: "ConversationTokenCallbackType | None" = None, ) -> None: """Taking a step in the conversation. @@ -250,6 +254,9 @@ def step( 4.1 If conversation is finished, set state.execution_status to FINISHED 4.2 Otherwise, just return, Conversation will kick off the next step + If the underlying LLM supports streaming, partial deltas are forwarded to + ``on_token`` before the full response is returned. + NOTE: state will be mutated in-place. 
""" diff --git a/openhands-sdk/openhands/sdk/agent/utils.py b/openhands-sdk/openhands/sdk/agent/utils.py index 80970bca0d..68f5f27cd5 100644 --- a/openhands-sdk/openhands/sdk/agent/utils.py +++ b/openhands-sdk/openhands/sdk/agent/utils.py @@ -12,6 +12,7 @@ from openhands.sdk.context.condenser.base import CondenserBase from openhands.sdk.context.view import View +from openhands.sdk.conversation.types import ConversationTokenCallbackType from openhands.sdk.event.base import Event, LLMConvertibleEvent from openhands.sdk.event.condenser import Condensation from openhands.sdk.llm import LLM, LLMResponse, Message @@ -182,6 +183,7 @@ def make_llm_completion( llm: LLM, messages: list[Message], tools: list[ToolDefinition] | None = None, + on_token: ConversationTokenCallbackType | None = None, ) -> LLMResponse: """Make an LLM completion call with the provided messages and tools. @@ -189,6 +191,7 @@ def make_llm_completion( llm: The LLM instance to use for completion messages: The messages to send to the LLM tools: Optional list of tools to provide to the LLM + on_token: Optional callback for streaming token updates Returns: LLMResponse from the LLM completion call @@ -200,10 +203,12 @@ def make_llm_completion( include=None, store=False, add_security_risk_prediction=True, + on_token=on_token, ) else: return llm.completion( messages=messages, tools=tools or [], add_security_risk_prediction=True, + on_token=on_token, ) diff --git a/openhands-sdk/openhands/sdk/conversation/__init__.py b/openhands-sdk/openhands/sdk/conversation/__init__.py index e411b4b73c..426d57329f 100644 --- a/openhands-sdk/openhands/sdk/conversation/__init__.py +++ b/openhands-sdk/openhands/sdk/conversation/__init__.py @@ -11,7 +11,10 @@ ConversationState, ) from openhands.sdk.conversation.stuck_detector import StuckDetector -from openhands.sdk.conversation.types import ConversationCallbackType +from openhands.sdk.conversation.types import ( + ConversationCallbackType, + ConversationTokenCallbackType, +) from openhands.sdk.conversation.visualizer import ( ConversationVisualizerBase, DefaultConversationVisualizer, @@ -24,6 +27,7 @@ "ConversationState", "ConversationExecutionStatus", "ConversationCallbackType", + "ConversationTokenCallbackType", "DefaultConversationVisualizer", "ConversationVisualizerBase", "SecretRegistry", diff --git a/openhands-sdk/openhands/sdk/conversation/base.py b/openhands-sdk/openhands/sdk/conversation/base.py index 79079a7025..57f2f1280e 100644 --- a/openhands-sdk/openhands/sdk/conversation/base.py +++ b/openhands-sdk/openhands/sdk/conversation/base.py @@ -1,12 +1,16 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping from pathlib import Path -from typing import TYPE_CHECKING, Protocol +from typing import TYPE_CHECKING, Protocol, TypeVar, cast from openhands.sdk.conversation.conversation_stats import ConversationStats from openhands.sdk.conversation.events_list_base import EventsListBase from openhands.sdk.conversation.secret_registry import SecretValue -from openhands.sdk.conversation.types import ConversationCallbackType, ConversationID +from openhands.sdk.conversation.types import ( + ConversationCallbackType, + ConversationID, + ConversationTokenCallbackType, +) from openhands.sdk.llm.llm import LLM from openhands.sdk.llm.message import Message from openhands.sdk.observability.laminar import ( @@ -27,6 +31,13 @@ from openhands.sdk.conversation.state import ConversationExecutionStatus +CallbackType = TypeVar( + "CallbackType", + ConversationCallbackType, + 
ConversationTokenCallbackType, +) + + class ConversationStateProtocol(Protocol): """Protocol defining the interface for conversation state objects.""" @@ -235,9 +246,7 @@ def ask_agent(self, question: str) -> str: ... @staticmethod - def compose_callbacks( - callbacks: Iterable[ConversationCallbackType], - ) -> ConversationCallbackType: + def compose_callbacks(callbacks: Iterable[CallbackType]) -> CallbackType: """Compose multiple callbacks into a single callback function. Args: @@ -252,4 +261,4 @@ def composed(event) -> None: if cb: cb(event) - return composed + return cast(CallbackType, composed) diff --git a/openhands-sdk/openhands/sdk/conversation/conversation.py b/openhands-sdk/openhands/sdk/conversation/conversation.py index 09edcf33b1..4dc70cbf1e 100644 --- a/openhands-sdk/openhands/sdk/conversation/conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/conversation.py @@ -4,7 +4,11 @@ from openhands.sdk.agent.base import AgentBase from openhands.sdk.conversation.base import BaseConversation from openhands.sdk.conversation.secret_registry import SecretValue -from openhands.sdk.conversation.types import ConversationCallbackType, ConversationID +from openhands.sdk.conversation.types import ( + ConversationCallbackType, + ConversationID, + ConversationTokenCallbackType, +) from openhands.sdk.conversation.visualizer import ( ConversationVisualizerBase, DefaultConversationVisualizer, @@ -49,6 +53,7 @@ def __new__( persistence_dir: str | Path | None = None, conversation_id: ConversationID | None = None, callbacks: list[ConversationCallbackType] | None = None, + token_callbacks: list[ConversationTokenCallbackType] | None = None, max_iteration_per_run: int = 500, stuck_detection: bool = True, visualizer: ( @@ -65,6 +70,7 @@ def __new__( workspace: RemoteWorkspace, conversation_id: ConversationID | None = None, callbacks: list[ConversationCallbackType] | None = None, + token_callbacks: list[ConversationTokenCallbackType] | None = None, max_iteration_per_run: int = 500, stuck_detection: bool = True, visualizer: ( @@ -81,6 +87,7 @@ def __new__( persistence_dir: str | Path | None = None, conversation_id: ConversationID | None = None, callbacks: list[ConversationCallbackType] | None = None, + token_callbacks: list[ConversationTokenCallbackType] | None = None, max_iteration_per_run: int = 500, stuck_detection: bool = True, visualizer: ( @@ -104,6 +111,7 @@ def __new__( agent=agent, conversation_id=conversation_id, callbacks=callbacks, + token_callbacks=token_callbacks, max_iteration_per_run=max_iteration_per_run, stuck_detection=stuck_detection, visualizer=visualizer, @@ -115,6 +123,7 @@ def __new__( agent=agent, conversation_id=conversation_id, callbacks=callbacks, + token_callbacks=token_callbacks, max_iteration_per_run=max_iteration_per_run, stuck_detection=stuck_detection, visualizer=visualizer, diff --git a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py index 801b4d25ff..6d3055c61a 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py @@ -4,7 +4,6 @@ from pathlib import Path from openhands.sdk.agent.base import AgentBase -from openhands.sdk.agent.utils import make_llm_completion, prepare_llm_messages from openhands.sdk.context.prompts.prompt import render_template from openhands.sdk.conversation.base import BaseConversation from openhands.sdk.conversation.exceptions import ConversationRunError 
@@ -15,7 +14,11 @@ ) from openhands.sdk.conversation.stuck_detector import StuckDetector from openhands.sdk.conversation.title_utils import generate_conversation_title -from openhands.sdk.conversation.types import ConversationCallbackType, ConversationID +from openhands.sdk.conversation.types import ( + ConversationCallbackType, + ConversationID, + ConversationTokenCallbackType, +) from openhands.sdk.conversation.visualizer import ( ConversationVisualizerBase, DefaultConversationVisualizer, @@ -46,6 +49,7 @@ class LocalConversation(BaseConversation): _state: ConversationState _visualizer: ConversationVisualizerBase | None _on_event: ConversationCallbackType + _on_token: ConversationTokenCallbackType | None max_iteration_per_run: int _stuck_detector: StuckDetector | None llm_registry: LLMRegistry @@ -58,6 +62,7 @@ def __init__( persistence_dir: str | Path | None = None, conversation_id: ConversationID | None = None, callbacks: list[ConversationCallbackType] | None = None, + token_callbacks: list[ConversationTokenCallbackType] | None = None, max_iteration_per_run: int = 500, stuck_detection: bool = True, visualizer: ( @@ -78,6 +83,7 @@ def __init__( be used to identify the conversation. The user might want to suffix their persistent filestore with this ID. callbacks: Optional list of callback functions to handle events + token_callbacks: Optional list of callbacks invoked for streaming deltas max_iteration_per_run: Maximum number of iterations per run visualizer: Visualization configuration. Can be: - ConversationVisualizerBase subclass: Class to instantiate @@ -143,6 +149,12 @@ def _default_callback(e): self._visualizer = None self._on_event = BaseConversation.compose_callbacks(composed_list) + self._on_token = ( + BaseConversation.compose_callbacks(token_callbacks) + if token_callbacks + else None + ) + self.max_iteration_per_run = max_iteration_per_run # Initialize stuck detector @@ -305,8 +317,9 @@ def run(self) -> None: ConversationExecutionStatus.RUNNING ) - # step must mutate the SAME state object - self.agent.step(self, on_event=self._on_event) + self.agent.step( + self, on_event=self._on_event, on_token=self._on_token + ) iteration += 1 # Check for non-finished terminal conditions @@ -436,7 +449,7 @@ def close(self) -> None: executable_tool = tool.as_executable() executable_tool.executor.close() except NotImplementedError: - # Tool has no executor, skip it + # Tool has no executor, skip it without erroring continue except Exception as e: logger.warning(f"Error closing executor for tool '{tool.name}': {e}") @@ -456,6 +469,9 @@ def ask_agent(self, question: str) -> str: Returns: A string response from the agent """ + # Import here to avoid circular imports + from openhands.sdk.agent.utils import make_llm_completion, prepare_llm_messages + template_dir = ( Path(__file__).parent.parent.parent / "context" / "prompts" / "templates" ) diff --git a/openhands-sdk/openhands/sdk/conversation/types.py b/openhands-sdk/openhands/sdk/conversation/types.py index d10b085666..f84c4080c3 100644 --- a/openhands-sdk/openhands/sdk/conversation/types.py +++ b/openhands-sdk/openhands/sdk/conversation/types.py @@ -2,9 +2,14 @@ from collections.abc import Callable from openhands.sdk.event.base import Event +from openhands.sdk.llm.streaming import TokenCallbackType ConversationCallbackType = Callable[[Event], None] +"""Type alias for event callback functions.""" + +ConversationTokenCallbackType = TokenCallbackType +"""Callback type invoked for streaming LLM deltas.""" ConversationID = uuid.UUID """Type alias 
for conversation IDs."""
diff --git a/openhands-sdk/openhands/sdk/llm/__init__.py b/openhands-sdk/openhands/sdk/llm/__init__.py
index fabed357d1..63d8d437e6 100644
--- a/openhands-sdk/openhands/sdk/llm/__init__.py
+++ b/openhands-sdk/openhands/sdk/llm/__init__.py
@@ -12,6 +12,7 @@
     content_to_str,
 )
 from openhands.sdk.llm.router import RouterLLM
+from openhands.sdk.llm.streaming import LLMStreamChunk, TokenCallbackType
 from openhands.sdk.llm.utils.metrics import Metrics, MetricsSnapshot
 from openhands.sdk.llm.utils.unverified_models import (
     UNVERIFIED_MODELS_EXCLUDING_BEDROCK,
@@ -34,6 +35,8 @@
     "RedactedThinkingBlock",
     "ReasoningItemModel",
     "content_to_str",
+    "LLMStreamChunk",
+    "TokenCallbackType",
     "Metrics",
     "MetricsSnapshot",
     "VERIFIED_MODELS",
diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py
index 4fefd69f33..95a00ea076 100644
--- a/openhands-sdk/openhands/sdk/llm/llm.py
+++ b/openhands-sdk/openhands/sdk/llm/llm.py
@@ -40,6 +40,7 @@
 from litellm import (
     ChatCompletionToolParam,
+    CustomStreamWrapper,
     ResponseInputParam,
     completion as litellm_completion,
 )
@@ -72,6 +73,9 @@
 from openhands.sdk.llm.mixins.non_native_fc import NonNativeToolCallingMixin
 from openhands.sdk.llm.options.chat_options import select_chat_options
 from openhands.sdk.llm.options.responses_options import select_responses_options
+from openhands.sdk.llm.streaming import (
+    TokenCallbackType,
+)
 from openhands.sdk.llm.utils.metrics import Metrics, MetricsSnapshot
 from openhands.sdk.llm.utils.model_features import get_default_temperature, get_features
 from openhands.sdk.llm.utils.retry_mixin import RetryMixin
@@ -184,6 +188,14 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin):
     )
     ollama_base_url: str | None = Field(default=None)
+    stream: bool = Field(
+        default=False,
+        description=(
+            "Enable streaming responses from the LLM. "
+            "When enabled, the `on_token` callback passed to `completion()` "
+            "is invoked for each streamed chunk (Responses API not yet supported)."
+        ),
+    )
     drop_params: bool = Field(default=True)
     modify_params: bool = Field(
         default=True,
@@ -447,6 +459,7 @@ def completion(
         tools: Sequence[ToolDefinition] | None = None,
         _return_metrics: bool = False,
         add_security_risk_prediction: bool = False,
+        on_token: TokenCallbackType | None = None,
         **kwargs,
     ) -> LLMResponse:
         """Generate a completion from the language model.
@@ -466,9 +479,11 @@ def completion( >>> response = llm.completion(messages) >>> print(response.content) """ - # Check if streaming is requested - if kwargs.get("stream", False): - raise ValueError("Streaming is not supported") + enable_streaming = bool(kwargs.get("stream", False)) or self.stream + if enable_streaming: + if on_token is None: + raise ValueError("Streaming requires an on_token callback") + kwargs["stream"] = True # 1) serialize messages formatted_messages = self.format_messages_for_llm(messages) @@ -531,7 +546,12 @@ def _one_attempt(**retry_kwargs) -> ModelResponse: self._telemetry.on_request(log_ctx=log_ctx) # Merge retry-modified kwargs (like temperature) with call_kwargs final_kwargs = {**call_kwargs, **retry_kwargs} - resp = self._transport_call(messages=formatted_messages, **final_kwargs) + resp = self._transport_call( + messages=formatted_messages, + **final_kwargs, + enable_streaming=enable_streaming, + on_token=on_token, + ) raw_resp: ModelResponse | None = None if use_mock_tools: raw_resp = copy.deepcopy(resp) @@ -588,15 +608,15 @@ def responses( store: bool | None = None, _return_metrics: bool = False, add_security_risk_prediction: bool = False, + on_token: TokenCallbackType | None = None, **kwargs, ) -> LLMResponse: """Alternative invocation path using OpenAI Responses API via LiteLLM. Maps Message[] -> (instructions, input[]) and returns LLMResponse. - Non-stream only for v1. """ # Streaming not yet supported - if kwargs.get("stream", False): + if kwargs.get("stream", False) or self.stream or on_token is not None: raise ValueError("Streaming is not supported for Responses API yet") # Build instructions + input list using dedicated Responses formatter @@ -707,7 +727,12 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: # Transport + helpers # ========================================================================= def _transport_call( - self, *, messages: list[dict[str, Any]], **kwargs + self, + *, + messages: list[dict[str, Any]], + enable_streaming: bool = False, + on_token: TokenCallbackType | None = None, + **kwargs, ) -> ModelResponse: # litellm.modify_params is GLOBAL; guard it for thread-safety with self._litellm_modify_params_ctx(self.modify_params): @@ -729,6 +754,11 @@ def _transport_call( "ignore", category=UserWarning, ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message="Accessing the 'model_fields' attribute.*", + ) # Extract api_key value with type assertion for type checker api_key_value: str | None = None if self.api_key: @@ -747,6 +777,14 @@ def _transport_call( messages=messages, **kwargs, ) + if enable_streaming and on_token is not None: + assert isinstance(ret, CustomStreamWrapper) + chunks = [] + for chunk in ret: + on_token(chunk) + chunks.append(chunk) + ret = litellm.stream_chunk_builder(chunks, messages=messages) + assert isinstance(ret, ModelResponse), ( f"Expected ModelResponse, got {type(ret)}" ) diff --git a/openhands-sdk/openhands/sdk/llm/router/base.py b/openhands-sdk/openhands/sdk/llm/router/base.py index cd908255e6..20a680c259 100644 --- a/openhands-sdk/openhands/sdk/llm/router/base.py +++ b/openhands-sdk/openhands/sdk/llm/router/base.py @@ -10,6 +10,7 @@ from openhands.sdk.llm.llm import LLM from openhands.sdk.llm.llm_response import LLMResponse from openhands.sdk.llm.message import Message +from openhands.sdk.llm.streaming import TokenCallbackType from openhands.sdk.logger import get_logger from openhands.sdk.tool.tool import ToolDefinition @@ -52,6 +53,7 @@ def completion( tools: 
Sequence[ToolDefinition] | None = None, return_metrics: bool = False, add_security_risk_prediction: bool = False, + on_token: TokenCallbackType | None = None, **kwargs, ) -> LLMResponse: """ @@ -70,6 +72,7 @@ def completion( tools=tools, _return_metrics=return_metrics, add_security_risk_prediction=add_security_risk_prediction, + on_token=on_token, **kwargs, ) diff --git a/openhands-sdk/openhands/sdk/llm/streaming.py b/openhands-sdk/openhands/sdk/llm/streaming.py new file mode 100644 index 0000000000..d160c03037 --- /dev/null +++ b/openhands-sdk/openhands/sdk/llm/streaming.py @@ -0,0 +1,9 @@ +from collections.abc import Callable + +from litellm.types.utils import ModelResponseStream + + +# Type alias for stream chunks +LLMStreamChunk = ModelResponseStream + +TokenCallbackType = Callable[[LLMStreamChunk], None] diff --git a/tests/cross/test_registry_directories.py b/tests/cross/test_registry_directories.py index 505c250b3e..d4549b872d 100644 --- a/tests/cross/test_registry_directories.py +++ b/tests/cross/test_registry_directories.py @@ -10,7 +10,10 @@ from openhands.sdk.agent.base import AgentBase from openhands.sdk.conversation import Conversation, LocalConversation from openhands.sdk.conversation.state import ConversationState -from openhands.sdk.conversation.types import ConversationCallbackType +from openhands.sdk.conversation.types import ( + ConversationCallbackType, + ConversationTokenCallbackType, +) from openhands.sdk.event.llm_convertible import SystemPromptEvent from openhands.sdk.llm import LLM, TextContent from openhands.sdk.tool.registry import resolve_tool @@ -38,7 +41,10 @@ def init_state( on_event(event) def step( - self, conversation: LocalConversation, on_event: ConversationCallbackType + self, + conversation: LocalConversation, + on_event: ConversationCallbackType, + on_token: ConversationTokenCallbackType | None = None, ) -> None: pass diff --git a/tests/sdk/agent/test_agent_utils.py b/tests/sdk/agent/test_agent_utils.py index 238d76118b..4e94344776 100644 --- a/tests/sdk/agent/test_agent_utils.py +++ b/tests/sdk/agent/test_agent_utils.py @@ -278,6 +278,7 @@ def test_make_llm_completion_with_completion_api(mock_llm, sample_messages): messages=sample_messages, tools=[], add_security_risk_prediction=True, + on_token=None, ) mock_llm.responses.assert_not_called() @@ -301,6 +302,7 @@ def test_make_llm_completion_with_responses_api(mock_llm, sample_messages): include=None, store=False, add_security_risk_prediction=True, + on_token=None, ) mock_llm.completion.assert_not_called() @@ -324,6 +326,7 @@ def test_make_llm_completion_with_tools_completion_api( messages=sample_messages, tools=sample_tools, add_security_risk_prediction=True, + on_token=None, ) @@ -348,6 +351,7 @@ def test_make_llm_completion_with_tools_responses_api( include=None, store=False, add_security_risk_prediction=True, + on_token=None, ) @@ -367,6 +371,7 @@ def test_make_llm_completion_with_none_tools(mock_llm, sample_messages): messages=sample_messages, tools=[], add_security_risk_prediction=True, + on_token=None, ) @@ -386,6 +391,7 @@ def test_make_llm_completion_with_empty_tools_list(mock_llm, sample_messages): messages=sample_messages, tools=[], add_security_risk_prediction=True, + on_token=None, ) @@ -405,6 +411,7 @@ def test_make_llm_completion_empty_messages(mock_llm): messages=[], tools=[], add_security_risk_prediction=True, + on_token=None, ) @@ -441,6 +448,7 @@ def test_prepare_llm_messages_and_make_llm_completion_integration( messages=sample_messages, tools=[], add_security_risk_prediction=True, + 
on_token=None, ) @@ -467,6 +475,7 @@ def test_make_llm_completion_api_selection(): messages=messages, tools=[], add_security_risk_prediction=True, + on_token=None, ) mock_llm.responses.assert_not_called() @@ -485,5 +494,6 @@ def test_make_llm_completion_api_selection(): include=None, store=False, add_security_risk_prediction=True, + on_token=None, ) mock_llm.completion.assert_not_called() diff --git a/tests/sdk/conversation/local/test_conversation_default_callback.py b/tests/sdk/conversation/local/test_conversation_default_callback.py index edaf7b0b57..c56b6b9610 100644 --- a/tests/sdk/conversation/local/test_conversation_default_callback.py +++ b/tests/sdk/conversation/local/test_conversation_default_callback.py @@ -3,7 +3,10 @@ from openhands.sdk.agent.base import AgentBase from openhands.sdk.conversation import Conversation, LocalConversation from openhands.sdk.conversation.state import ConversationState -from openhands.sdk.conversation.types import ConversationCallbackType +from openhands.sdk.conversation.types import ( + ConversationCallbackType, + ConversationTokenCallbackType, +) from openhands.sdk.event.llm_convertible import MessageEvent, SystemPromptEvent from openhands.sdk.llm import LLM, Message, TextContent @@ -24,7 +27,10 @@ def init_state( on_event(event) def step( - self, conversation: LocalConversation, on_event: ConversationCallbackType + self, + conversation: LocalConversation, + on_event: ConversationCallbackType, + on_token: ConversationTokenCallbackType | None = None, ) -> None: on_event( MessageEvent( diff --git a/tests/sdk/conversation/local/test_conversation_id.py b/tests/sdk/conversation/local/test_conversation_id.py index bd9f9285ce..721100b048 100644 --- a/tests/sdk/conversation/local/test_conversation_id.py +++ b/tests/sdk/conversation/local/test_conversation_id.py @@ -5,7 +5,10 @@ from openhands.sdk.agent.base import AgentBase from openhands.sdk.conversation import Conversation, LocalConversation from openhands.sdk.conversation.state import ConversationState -from openhands.sdk.conversation.types import ConversationCallbackType +from openhands.sdk.conversation.types import ( + ConversationCallbackType, + ConversationTokenCallbackType, +) from openhands.sdk.event.llm_convertible import SystemPromptEvent from openhands.sdk.llm import LLM, TextContent from openhands.sdk.security.confirmation_policy import AlwaysConfirm, NeverConfirm @@ -27,7 +30,10 @@ def init_state( on_event(event) def step( - self, conversation: LocalConversation, on_event: ConversationCallbackType + self, + conversation: LocalConversation, + on_event: ConversationCallbackType, + on_token: ConversationTokenCallbackType | None = None, ) -> None: pass diff --git a/tests/sdk/conversation/local/test_conversation_send_message.py b/tests/sdk/conversation/local/test_conversation_send_message.py index 74409dd10d..e19f87c334 100644 --- a/tests/sdk/conversation/local/test_conversation_send_message.py +++ b/tests/sdk/conversation/local/test_conversation_send_message.py @@ -3,7 +3,10 @@ from openhands.sdk.agent.base import AgentBase from openhands.sdk.conversation import Conversation, LocalConversation from openhands.sdk.conversation.state import ConversationState -from openhands.sdk.conversation.types import ConversationCallbackType +from openhands.sdk.conversation.types import ( + ConversationCallbackType, + ConversationTokenCallbackType, +) from openhands.sdk.event.llm_convertible import MessageEvent, SystemPromptEvent from openhands.sdk.llm import LLM, Message, TextContent @@ -24,7 +27,10 @@ def 
init_state( on_event(event) def step( - self, conversation: LocalConversation, on_event: ConversationCallbackType + self, + conversation: LocalConversation, + on_event: ConversationCallbackType, + on_token: ConversationTokenCallbackType | None = None, ) -> None: on_event( MessageEvent( diff --git a/tests/sdk/conversation/local/test_run_exception_includes_conversation_id.py b/tests/sdk/conversation/local/test_run_exception_includes_conversation_id.py index 1c56bdcf21..3d01218f8d 100644 --- a/tests/sdk/conversation/local/test_run_exception_includes_conversation_id.py +++ b/tests/sdk/conversation/local/test_run_exception_includes_conversation_id.py @@ -5,11 +5,20 @@ from openhands.sdk.agent.base import AgentBase from openhands.sdk.conversation import Conversation from openhands.sdk.conversation.exceptions import ConversationRunError +from openhands.sdk.conversation.types import ( + ConversationCallbackType, + ConversationTokenCallbackType, +) from openhands.sdk.llm import LLM class FailingAgent(AgentBase): - def step(self, conversation, on_event): # noqa: D401, ARG002 + def step( + self, + conversation, + on_event: ConversationCallbackType, + on_token: ConversationTokenCallbackType | None = None, + ): # noqa: D401, ARG002 """Intentionally fail to simulate an unexpected runtime error.""" raise ValueError("boom") diff --git a/tests/sdk/conversation/local/test_state_serialization.py b/tests/sdk/conversation/local/test_state_serialization.py index 0c068a391a..1b289356bd 100644 --- a/tests/sdk/conversation/local/test_state_serialization.py +++ b/tests/sdk/conversation/local/test_state_serialization.py @@ -15,6 +15,10 @@ ConversationExecutionStatus, ConversationState, ) +from openhands.sdk.conversation.types import ( + ConversationCallbackType, + ConversationTokenCallbackType, +) from openhands.sdk.event.llm_convertible import MessageEvent, SystemPromptEvent from openhands.sdk.llm import LLM, Message, TextContent from openhands.sdk.llm.llm_registry import RegistryEvent @@ -438,7 +442,12 @@ def __init__(self): def init_state(self, state, on_event): pass - def step(self, conversation, on_event): + def step( + self, + conversation, + on_event: ConversationCallbackType, + on_token: ConversationTokenCallbackType | None = None, + ): pass llm = LLM(model="gpt-4o-mini", api_key=SecretStr("test-key"), usage_id="test-llm") diff --git a/tests/sdk/llm/test_llm_completion.py b/tests/sdk/llm/test_llm_completion.py index 2a90b2ac37..de0f482816 100644 --- a/tests/sdk/llm/test_llm_completion.py +++ b/tests/sdk/llm/test_llm_completion.py @@ -2,15 +2,17 @@ from collections.abc import Sequence from typing import ClassVar -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest -from litellm import ChatCompletionMessageToolCall +from litellm import ChatCompletionMessageToolCall, CustomStreamWrapper from litellm.types.utils import ( Choices, + Delta, Function, Message as LiteLLMMessage, ModelResponse, + StreamingChoices, Usage, ) from pydantic import SecretStr @@ -107,16 +109,220 @@ def test_llm_completion_basic(mock_completion): def test_llm_streaming_not_supported(default_config): - """Test that streaming is not supported in the basic LLM class.""" + """Test that streaming requires an on_token callback.""" llm = default_config messages = [Message(role="user", content=[TextContent(text="Hello")])] - # Streaming should raise an error - with pytest.raises(ValueError, match="Streaming is not supported"): + # Streaming without callback should raise an error + with pytest.raises(ValueError, 
match="Streaming requires an on_token callback"): llm.completion(messages=messages, stream=True) +@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm.stream_chunk_builder") +def test_llm_completion_streaming_with_callback(mock_stream_builder, mock_completion): + """Test that streaming with on_token callback works correctly.""" + + # Create stream chunks + chunk1 = ModelResponse( + id="chatcmpl-test", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content="Hello", role="assistant"), + ) + ], + created=1234567890, + model="gpt-4o", + object="chat.completion.chunk", + ) + + chunk2 = ModelResponse( + id="chatcmpl-test", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content=" world!", role=None), + ) + ], + created=1234567890, + model="gpt-4o", + object="chat.completion.chunk", + ) + + chunk3 = ModelResponse( + id="chatcmpl-test", + choices=[ + StreamingChoices( + finish_reason="stop", + index=0, + delta=Delta(content=None, role=None), + ) + ], + created=1234567890, + model="gpt-4o", + object="chat.completion.chunk", + ) + + # Create a mock stream wrapper + mock_stream = MagicMock(spec=CustomStreamWrapper) + mock_stream.__iter__.return_value = iter([chunk1, chunk2, chunk3]) + mock_completion.return_value = mock_stream + + # Mock the stream builder to return a complete response + final_response = create_mock_response("Hello world!") + mock_stream_builder.return_value = final_response + + # Create LLM + llm = LLM( + usage_id="test-llm", + model="gpt-4o", + api_key=SecretStr("test_key"), + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) + + # Track chunks received by callback + received_chunks = [] + + def on_token(chunk): + received_chunks.append(chunk) + + messages = [Message(role="user", content=[TextContent(text="Hello")])] + response = llm.completion(messages=messages, stream=True, on_token=on_token) + + # Verify callback was invoked for each chunk + assert len(received_chunks) == 3 + assert received_chunks[0] == chunk1 + assert received_chunks[1] == chunk2 + assert received_chunks[2] == chunk3 + + # Verify stream builder was called to assemble final response + mock_stream_builder.assert_called_once() + + # Verify final response + assert response.message.role == "assistant" + assert isinstance(response.message.content[0], TextContent) + assert response.message.content[0].text == "Hello world!" 
+ + +@patch("openhands.sdk.llm.llm.litellm_completion") +@patch("openhands.sdk.llm.llm.litellm.stream_chunk_builder") +def test_llm_completion_streaming_with_tools(mock_stream_builder, mock_completion): + """Test streaming completion with tool calls.""" + + # Create stream chunks with tool call + chunk1 = ModelResponse( + id="chatcmpl-test", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + role="assistant", + content=None, + tool_calls=[ + { + "index": 0, + "id": "call_123", + "type": "function", + "function": {"name": "test_tool", "arguments": ""}, + } + ], + ), + ) + ], + created=1234567890, + model="gpt-4o", + object="chat.completion.chunk", + ) + + chunk2 = ModelResponse( + id="chatcmpl-test", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content=None, + tool_calls=[ + { + "index": 0, + "function": {"arguments": '{"param": "value"}'}, + } + ], + ), + ) + ], + created=1234567890, + model="gpt-4o", + object="chat.completion.chunk", + ) + + chunk3 = ModelResponse( + id="chatcmpl-test", + choices=[ + StreamingChoices( + finish_reason="tool_calls", + index=0, + delta=Delta(content=None), + ) + ], + created=1234567890, + model="gpt-4o", + object="chat.completion.chunk", + ) + + # Create mock stream + mock_stream = MagicMock(spec=CustomStreamWrapper) + mock_stream.__iter__.return_value = iter([chunk1, chunk2, chunk3]) + mock_completion.return_value = mock_stream + + # Mock final response with tool call + final_response = create_mock_response("I'll use the tool") + final_response.choices[0].message.tool_calls = [ # type: ignore + ChatCompletionMessageToolCall( + id="call_123", + type="function", + function=Function( + name="test_tool", + arguments='{"param": "value"}', + ), + ) + ] + mock_stream_builder.return_value = final_response + + llm = LLM( + usage_id="test-llm", + model="gpt-4o", + api_key=SecretStr("test_key"), + ) + + received_chunks = [] + + def on_token(chunk): + received_chunks.append(chunk) + + messages = [Message(role="user", content=[TextContent(text="Use test_tool")])] + tools = list(_MockTool.create()) + + response = llm.completion( + messages=messages, tools=tools, stream=True, on_token=on_token + ) + + # Verify chunks were received + assert len(received_chunks) == 3 + + # Verify final response has tool call + assert response.message.tool_calls is not None + assert len(response.message.tool_calls) == 1 + assert response.message.tool_calls[0].name == "test_tool" + + @patch("openhands.sdk.llm.llm.litellm_completion") def test_llm_completion_with_tools(mock_completion): """Test LLM completion with tools.""" diff --git a/tests/sdk/llm/test_responses_parsing_and_kwargs.py b/tests/sdk/llm/test_responses_parsing_and_kwargs.py index 1b2cb8493a..b47e436fc6 100644 --- a/tests/sdk/llm/test_responses_parsing_and_kwargs.py +++ b/tests/sdk/llm/test_responses_parsing_and_kwargs.py @@ -1,6 +1,9 @@ from unittest.mock import patch -from litellm.types.llms.openai import ResponseAPIUsage, ResponsesAPIResponse +from litellm.types.llms.openai import ( + ResponseAPIUsage, + ResponsesAPIResponse, +) from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall from openai.types.responses.response_output_message import ResponseOutputMessage from openai.types.responses.response_output_text import ResponseOutputText @@ -9,7 +12,7 @@ Summary, ) -from openhands.sdk.llm.llm import LLM +from openhands.sdk.llm import LLM from openhands.sdk.llm.message import Message, ReasoningItemModel, TextContent from 
openhands.sdk.llm.options.responses_options import select_responses_options