Skip to content

Commit ff442dd

Browse files
xingyaoww, enyst, openhands-agent
authored
Implement streaming for Chat Completions (#1270)
Co-authored-by: Engel Nyst <engel.nyst@gmail.com> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com> Co-authored-by: openhands <openhands@all-hands.dev>
1 parent bbf8cff commit ff442dd

23 files changed

+548
-39
lines changed
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import os
2+
import sys
3+
from typing import Literal
4+
5+
from pydantic import SecretStr
6+
7+
from openhands.sdk import (
8+
Conversation,
9+
get_logger,
10+
)
11+
from openhands.sdk.llm import LLM
12+
from openhands.sdk.llm.streaming import ModelResponseStream
13+
from openhands.tools.preset.default import get_default_agent
14+
15+
16+
logger = get_logger(__name__)


# Resolve the API credential from the environment; either variable works.
api_key = os.environ.get("LLM_API_KEY") or os.environ.get("OPENAI_API_KEY")
if not api_key:
    raise RuntimeError("Set LLM_API_KEY or OPENAI_API_KEY in your environment.")

# Model and endpoint are overridable through environment variables.
model = os.environ.get("LLM_MODEL", "anthropic/claude-sonnet-4-5-20250929")
base_url = os.environ.get("LLM_BASE_URL")

# Streaming is enabled so partial deltas are forwarded to token callbacks.
llm = LLM(
    model=model,
    api_key=SecretStr(api_key),
    base_url=base_url,
    usage_id="stream-demo",
    stream=True,
)

agent = get_default_agent(llm=llm, cli_mode=True)
34+
35+
36+
# Define streaming states: the four kinds of token segments a chunk can carry.
StreamingState = Literal["thinking", "content", "tool_name", "tool_args"]
# Track state across on_token calls for boundary detection.
# None until the first token of the stream arrives.
_current_state: StreamingState | None = None


def _write_segment(state: StreamingState, label: str, text: str) -> None:
    """Write one streamed fragment to stdout with boundary detection.

    When the stream transitions into a new segment type, a separating
    newline (unless this is the very first segment) and the segment's
    label are printed before the fragment itself. Output is flushed on
    every call so tokens appear immediately.
    """
    global _current_state
    if _current_state != state:
        if _current_state is not None:
            sys.stdout.write("\n")
        sys.stdout.write(label)
        _current_state = state
    sys.stdout.write(text)
    sys.stdout.flush()


def on_token(chunk: ModelResponseStream) -> None:
    """
    Handle all types of streaming tokens including content,
    tool calls, and thinking blocks with dynamic boundary detection.
    """
    for choice in chunk.choices:
        delta = choice.delta
        if delta is None:
            continue

        # Handle thinking blocks (reasoning content).
        reasoning_content = getattr(delta, "reasoning_content", None)
        if isinstance(reasoning_content, str) and reasoning_content:
            _write_segment("thinking", "THINKING: ", reasoning_content)

        # Handle regular assistant content.
        content = getattr(delta, "content", None)
        if isinstance(content, str) and content:
            _write_segment("content", "CONTENT: ", content)

        # Handle tool calls: name and argument fragments arrive
        # incrementally across chunks, possibly in separate deltas.
        for tool_call in getattr(delta, "tool_calls", None) or []:
            tool_name = tool_call.function.name or ""
            tool_args = tool_call.function.arguments or ""
            if tool_name:
                _write_segment("tool_name", "TOOL NAME: ", tool_name)
            if tool_args:
                _write_segment("tool_args", "TOOL ARGS: ", tool_args)
103+
104+
105+
# Wire the streaming callback into the conversation: every partial delta
# from the LLM is forwarded to on_token before the full response returns.
conversation = Conversation(
    agent=agent,
    workspace=os.getcwd(),
    token_callbacks=[on_token],
)

# First task: have the agent stream a multi-paragraph story to a file.
# NOTE: fixed prompt grammar ("write it a file" -> "write it to a file").
story_prompt = (
    "Tell me a long story about LLM streaming, write it to a file, "
    "make sure it has multiple paragraphs. "
)
conversation.send_message(story_prompt)
print("Token Streaming:")
print("-" * 100 + "\n")
conversation.run()

# Second task: ask the agent to clean up the file it created.
cleanup_prompt = (
    "Thank you. Please delete the streaming story file now that I've read it, "
    "then confirm the deletion."
)
conversation.send_message(cleanup_prompt)
print("Token Streaming:")
print("-" * 100 + "\n")
conversation.run()

# Report cost accumulated across both runs for this usage_id.
cost = llm.metrics.accumulated_cost
print(f"EXAMPLE_COST: {cost}")

openhands-sdk/openhands/sdk/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
LLM,
2222
ImageContent,
2323
LLMRegistry,
24+
LLMStreamChunk,
2425
Message,
2526
RedactedThinkingBlock,
2627
RegistryEvent,
2728
TextContent,
2829
ThinkingBlock,
30+
TokenCallbackType,
2931
)
3032
from openhands.sdk.logger import get_logger
3133
from openhands.sdk.mcp import (
@@ -58,6 +60,8 @@
5860
__all__ = [
5961
"LLM",
6062
"LLMRegistry",
63+
"LLMStreamChunk",
64+
"TokenCallbackType",
6165
"ConversationStats",
6266
"RegistryEvent",
6367
"Message",

openhands-sdk/openhands/sdk/agent/agent.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from openhands.sdk.conversation import (
1414
ConversationCallbackType,
1515
ConversationState,
16+
ConversationTokenCallbackType,
1617
LocalConversation,
1718
)
1819
from openhands.sdk.conversation.state import ConversationExecutionStatus
@@ -135,6 +136,7 @@ def step(
135136
self,
136137
conversation: LocalConversation,
137138
on_event: ConversationCallbackType,
139+
on_token: ConversationTokenCallbackType | None = None,
138140
) -> None:
139141
state = conversation.state
140142
# Check for pending actions (implicit confirmation)
@@ -167,7 +169,10 @@ def step(
167169

168170
try:
169171
llm_response = make_llm_completion(
170-
self.llm, _messages, tools=list(self.tools_map.values())
172+
self.llm,
173+
_messages,
174+
tools=list(self.tools_map.values()),
175+
on_token=on_token,
171176
)
172177
except FunctionCallValidationError as e:
173178
logger.warning(f"LLM generated malformed function call: {e}")

openhands-sdk/openhands/sdk/agent/base.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020

2121
if TYPE_CHECKING:
2222
from openhands.sdk.conversation import ConversationState, LocalConversation
23-
from openhands.sdk.conversation.types import ConversationCallbackType
23+
from openhands.sdk.conversation.types import (
24+
ConversationCallbackType,
25+
ConversationTokenCallbackType,
26+
)
2427

2528

2629
logger = get_logger(__name__)
@@ -239,6 +242,7 @@ def step(
239242
self,
240243
conversation: "LocalConversation",
241244
on_event: "ConversationCallbackType",
245+
on_token: "ConversationTokenCallbackType | None" = None,
242246
) -> None:
243247
"""Taking a step in the conversation.
244248
@@ -250,6 +254,9 @@ def step(
250254
4.1 If conversation is finished, set state.execution_status to FINISHED
251255
4.2 Otherwise, just return, Conversation will kick off the next step
252256
257+
If the underlying LLM supports streaming, partial deltas are forwarded to
258+
``on_token`` before the full response is returned.
259+
253260
NOTE: state will be mutated in-place.
254261
"""
255262

openhands-sdk/openhands/sdk/agent/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from openhands.sdk.context.condenser.base import CondenserBase
1414
from openhands.sdk.context.view import View
15+
from openhands.sdk.conversation.types import ConversationTokenCallbackType
1516
from openhands.sdk.event.base import Event, LLMConvertibleEvent
1617
from openhands.sdk.event.condenser import Condensation
1718
from openhands.sdk.llm import LLM, LLMResponse, Message
@@ -182,13 +183,15 @@ def make_llm_completion(
182183
llm: LLM,
183184
messages: list[Message],
184185
tools: list[ToolDefinition] | None = None,
186+
on_token: ConversationTokenCallbackType | None = None,
185187
) -> LLMResponse:
186188
"""Make an LLM completion call with the provided messages and tools.
187189
188190
Args:
189191
llm: The LLM instance to use for completion
190192
messages: The messages to send to the LLM
191193
tools: Optional list of tools to provide to the LLM
194+
on_token: Optional callback for streaming token updates
192195
193196
Returns:
194197
LLMResponse from the LLM completion call
@@ -200,10 +203,12 @@ def make_llm_completion(
200203
include=None,
201204
store=False,
202205
add_security_risk_prediction=True,
206+
on_token=on_token,
203207
)
204208
else:
205209
return llm.completion(
206210
messages=messages,
207211
tools=tools or [],
208212
add_security_risk_prediction=True,
213+
on_token=on_token,
209214
)

openhands-sdk/openhands/sdk/conversation/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
ConversationState,
1212
)
1313
from openhands.sdk.conversation.stuck_detector import StuckDetector
14-
from openhands.sdk.conversation.types import ConversationCallbackType
14+
from openhands.sdk.conversation.types import (
15+
ConversationCallbackType,
16+
ConversationTokenCallbackType,
17+
)
1518
from openhands.sdk.conversation.visualizer import (
1619
ConversationVisualizerBase,
1720
DefaultConversationVisualizer,
@@ -24,6 +27,7 @@
2427
"ConversationState",
2528
"ConversationExecutionStatus",
2629
"ConversationCallbackType",
30+
"ConversationTokenCallbackType",
2731
"DefaultConversationVisualizer",
2832
"ConversationVisualizerBase",
2933
"SecretRegistry",

openhands-sdk/openhands/sdk/conversation/base.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from abc import ABC, abstractmethod
22
from collections.abc import Iterable, Mapping
33
from pathlib import Path
4-
from typing import TYPE_CHECKING, Protocol
4+
from typing import TYPE_CHECKING, Protocol, TypeVar, cast
55

66
from openhands.sdk.conversation.conversation_stats import ConversationStats
77
from openhands.sdk.conversation.events_list_base import EventsListBase
88
from openhands.sdk.conversation.secret_registry import SecretValue
9-
from openhands.sdk.conversation.types import ConversationCallbackType, ConversationID
9+
from openhands.sdk.conversation.types import (
10+
ConversationCallbackType,
11+
ConversationID,
12+
ConversationTokenCallbackType,
13+
)
1014
from openhands.sdk.llm.llm import LLM
1115
from openhands.sdk.llm.message import Message
1216
from openhands.sdk.observability.laminar import (
@@ -27,6 +31,13 @@
2731
from openhands.sdk.conversation.state import ConversationExecutionStatus
2832

2933

34+
CallbackType = TypeVar(
35+
"CallbackType",
36+
ConversationCallbackType,
37+
ConversationTokenCallbackType,
38+
)
39+
40+
3041
class ConversationStateProtocol(Protocol):
3142
"""Protocol defining the interface for conversation state objects."""
3243

@@ -235,9 +246,7 @@ def ask_agent(self, question: str) -> str:
235246
...
236247

237248
@staticmethod
238-
def compose_callbacks(
239-
callbacks: Iterable[ConversationCallbackType],
240-
) -> ConversationCallbackType:
249+
def compose_callbacks(callbacks: Iterable[CallbackType]) -> CallbackType:
241250
"""Compose multiple callbacks into a single callback function.
242251
243252
Args:
@@ -252,4 +261,4 @@ def composed(event) -> None:
252261
if cb:
253262
cb(event)
254263

255-
return composed
264+
return cast(CallbackType, composed)

openhands-sdk/openhands/sdk/conversation/conversation.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from openhands.sdk.agent.base import AgentBase
55
from openhands.sdk.conversation.base import BaseConversation
66
from openhands.sdk.conversation.secret_registry import SecretValue
7-
from openhands.sdk.conversation.types import ConversationCallbackType, ConversationID
7+
from openhands.sdk.conversation.types import (
8+
ConversationCallbackType,
9+
ConversationID,
10+
ConversationTokenCallbackType,
11+
)
812
from openhands.sdk.conversation.visualizer import (
913
ConversationVisualizerBase,
1014
DefaultConversationVisualizer,
@@ -49,6 +53,7 @@ def __new__(
4953
persistence_dir: str | Path | None = None,
5054
conversation_id: ConversationID | None = None,
5155
callbacks: list[ConversationCallbackType] | None = None,
56+
token_callbacks: list[ConversationTokenCallbackType] | None = None,
5257
max_iteration_per_run: int = 500,
5358
stuck_detection: bool = True,
5459
visualizer: (
@@ -65,6 +70,7 @@ def __new__(
6570
workspace: RemoteWorkspace,
6671
conversation_id: ConversationID | None = None,
6772
callbacks: list[ConversationCallbackType] | None = None,
73+
token_callbacks: list[ConversationTokenCallbackType] | None = None,
6874
max_iteration_per_run: int = 500,
6975
stuck_detection: bool = True,
7076
visualizer: (
@@ -81,6 +87,7 @@ def __new__(
8187
persistence_dir: str | Path | None = None,
8288
conversation_id: ConversationID | None = None,
8389
callbacks: list[ConversationCallbackType] | None = None,
90+
token_callbacks: list[ConversationTokenCallbackType] | None = None,
8491
max_iteration_per_run: int = 500,
8592
stuck_detection: bool = True,
8693
visualizer: (
@@ -104,6 +111,7 @@ def __new__(
104111
agent=agent,
105112
conversation_id=conversation_id,
106113
callbacks=callbacks,
114+
token_callbacks=token_callbacks,
107115
max_iteration_per_run=max_iteration_per_run,
108116
stuck_detection=stuck_detection,
109117
visualizer=visualizer,
@@ -115,6 +123,7 @@ def __new__(
115123
agent=agent,
116124
conversation_id=conversation_id,
117125
callbacks=callbacks,
126+
token_callbacks=token_callbacks,
118127
max_iteration_per_run=max_iteration_per_run,
119128
stuck_detection=stuck_detection,
120129
visualizer=visualizer,

0 commit comments

Comments
 (0)