11from __future__ import annotations
22
3- from collections .abc import AsyncIterator , Callable , Iterator , Sequence
3+ from collections .abc import AsyncIterable , AsyncIterator , Callable , Iterator , Sequence
44from contextlib import AbstractAsyncContextManager , asynccontextmanager , contextmanager
55from contextvars import ContextVar
6+ from dataclasses import dataclass
67from datetime import timedelta
78from typing import Any , Literal , overload
89
10+ from pydantic import ConfigDict , with_config
911from pydantic .errors import PydanticUserError
1012from pydantic_core import PydanticSerializationError
11- from temporalio import workflow
13+ from temporalio import activity , workflow
1214from temporalio .common import RetryPolicy
1315from temporalio .workflow import ActivityConfig
1416from typing_extensions import Never
2123)
2224from pydantic_ai ._run_context import AgentDepsT
2325from pydantic_ai .agent import AbstractAgent , AgentRun , AgentRunResult , EventStreamHandler , RunOutputDataT , WrapperAgent
24- from pydantic_ai .durable_exec .temporal ._run_context import TemporalRunContext
2526from pydantic_ai .exceptions import UserError
2627from pydantic_ai .models import Model
2728from pydantic_ai .output import OutputDataT , OutputSpec
2829from pydantic_ai .result import StreamedRunResult
2930from pydantic_ai .settings import ModelSettings
3031from pydantic_ai .tools import (
3132 DeferredToolResults ,
33+ RunContext ,
3234 Tool ,
3335 ToolFuncEither ,
3436)
3537from pydantic_ai .toolsets import AbstractToolset
3638
3739from ._model import TemporalModel
40+ from ._run_context import TemporalRunContext
3841from ._toolset import TemporalWrapperToolset , temporalize_toolset
3942
4043
44+ @dataclass
45+ @with_config (ConfigDict (arbitrary_types_allowed = True ))
46+ class _EventStreamHandlerParams :
47+ event : _messages .AgentStreamEvent
48+ serialized_run_context : Any
49+
50+
4151class TemporalAgent (WrapperAgent [AgentDepsT , OutputDataT ]):
4252 def __init__ (
4353 self ,
@@ -86,6 +96,10 @@ def __init__(
8696 """
8797 super ().__init__ (wrapped )
8898
99+ self ._name = name
100+ self ._event_stream_handler = event_stream_handler
101+ self .run_context_type = run_context_type
102+
89103 # start_to_close_timeout is required
90104 activity_config = activity_config or ActivityConfig (start_to_close_timeout = timedelta (seconds = 60 ))
91105
@@ -97,13 +111,13 @@ def __init__(
97111 PydanticUserError .__name__ ,
98112 ]
99113 activity_config ['retry_policy' ] = retry_policy
114+ self .activity_config = activity_config
100115
101116 model_activity_config = model_activity_config or {}
102117 toolset_activity_config = toolset_activity_config or {}
103118 tool_activity_config = tool_activity_config or {}
104119
105- self ._name = name or wrapped .name
106- if self ._name is None :
120+ if self .name is None :
107121 raise UserError (
108122 "An agent needs to have a unique `name` in order to be used with Temporal. The name will be used to identify the agent's activities within the workflow."
109123 )
@@ -116,13 +130,33 @@ def __init__(
116130 'An agent needs to have a `model` in order to be used with Temporal, it cannot be set at agent run time.'
117131 )
118132
133+ async def event_stream_handler_activity (params : _EventStreamHandlerParams , deps : AgentDepsT ) -> None :
134+ # We can never get here without an `event_stream_handler`, as `TemporalAgent.run_stream` and `TemporalAgent.iter` raise an error saying to use `TemporalAgent.run` instead,
135+ # and that only ends up calling `event_stream_handler` if it is set.
136+ assert self .event_stream_handler is not None
137+
138+ run_context = self .run_context_type .deserialize_run_context (params .serialized_run_context , deps = deps )
139+
140+ async def streamed_response ():
141+ yield params .event
142+
143+ await self .event_stream_handler (run_context , streamed_response ())
144+
145+ # Set type hint explicitly so that Temporal can take care of serialization and deserialization
146+ event_stream_handler_activity .__annotations__ ['deps' ] = self .deps_type
147+
148+ self .event_stream_handler_activity = activity .defn (name = f'{ activity_name_prefix } __event_stream_handler' )(
149+ event_stream_handler_activity
150+ )
151+ activities .append (self .event_stream_handler_activity )
152+
119153 temporal_model = TemporalModel (
120154 wrapped .model ,
121155 activity_name_prefix = activity_name_prefix ,
122156 activity_config = activity_config | model_activity_config ,
123157 deps_type = self .deps_type ,
124- run_context_type = run_context_type ,
125- event_stream_handler = event_stream_handler or wrapped .event_stream_handler ,
158+ run_context_type = self . run_context_type ,
159+ event_stream_handler = self .event_stream_handler ,
126160 )
127161 activities .extend (temporal_model .temporal_activities )
128162
@@ -139,7 +173,7 @@ def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset
139173 activity_config | toolset_activity_config .get (id , {}),
140174 tool_activity_config .get (id , {}),
141175 self .deps_type ,
142- run_context_type ,
176+ self . run_context_type ,
143177 )
144178 if isinstance (toolset , TemporalWrapperToolset ):
145179 activities .extend (toolset .temporal_activities )
@@ -155,7 +189,7 @@ def temporalize_toolset(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset
155189
156190 @property
157191 def name (self ) -> str | None :
158- return self ._name
192+ return self ._name or super (). name
159193
160194 @name .setter
161195 def name (self , value : str | None ) -> None : # pragma: no cover
@@ -167,6 +201,33 @@ def name(self, value: str | None) -> None: # pragma: no cover
167201 def model (self ) -> Model :
168202 return self ._model
169203
204+ @property
205+ def event_stream_handler (self ) -> EventStreamHandler [AgentDepsT ] | None :
206+ handler = self ._event_stream_handler or super ().event_stream_handler
207+ if handler is None :
208+ return None
209+ elif workflow .in_workflow ():
210+ return self ._call_event_stream_handler_activity
211+ else :
212+ return handler
213+
214+ async def _call_event_stream_handler_activity (
215+ self , ctx : RunContext [AgentDepsT ], stream : AsyncIterable [_messages .AgentStreamEvent ]
216+ ) -> None :
217+ serialized_run_context = self .run_context_type .serialize_run_context (ctx )
218+ async for event in stream :
219+ await workflow .execute_activity ( # pyright: ignore[reportUnknownMemberType]
220+ activity = self .event_stream_handler_activity ,
221+ args = [
222+ _EventStreamHandlerParams (
223+ event = event ,
224+ serialized_run_context = serialized_run_context ,
225+ ),
226+ ctx .deps ,
227+ ],
228+ ** self .activity_config ,
229+ )
230+
170231 @property
171232 def toolsets (self ) -> Sequence [AbstractToolset [AgentDepsT ]]:
172233 with self ._temporal_overrides ():
@@ -296,7 +357,7 @@ async def main():
296357 usage = usage ,
297358 infer_name = infer_name ,
298359 toolsets = toolsets ,
299- event_stream_handler = event_stream_handler ,
360+ event_stream_handler = event_stream_handler or self . event_stream_handler ,
300361 ** _deprecated_kwargs ,
301362 )
302363
0 commit comments