From 7f260fe4260dd99cf17f94c41c26cd9791ed1271 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 29 May 2026 14:08:02 -0700 Subject: [PATCH 1/8] Add SAA Temporal Nexus operation handling Expose customizable Temporal Nexus operation handlers and client support. Add activity link conversion, operation token handling, and related tests. --- temporalio/client/_activity.py | 3 + temporalio/client/_client.py | 9 + temporalio/client/_impl.py | 22 +- temporalio/client/_interceptor.py | 4 + temporalio/nexus/__init__.py | 2 + temporalio/nexus/_link_conversion.py | 64 +++- temporalio/nexus/_operation_context.py | 72 +++- temporalio/nexus/_operation_handlers.py | 47 ++- temporalio/nexus/_temporal_client.py | 273 ++++++++++++++ temporalio/nexus/_token.py | 26 +- tests/nexus/test_link_conversion.py | 56 +++ tests/nexus/test_nexus_type_errors.py | 120 ++++++- tests/nexus/test_operation_token.py | 99 +++++- tests/nexus/test_temporal_operation.py | 454 +++++++++++++++++++++++- 14 files changed, 1219 insertions(+), 32 deletions(-) diff --git a/temporalio/client/_activity.py b/temporalio/client/_activity.py index 99c9ede31..15001de62 100644 --- a/temporalio/client/_activity.py +++ b/temporalio/client/_activity.py @@ -691,6 +691,8 @@ def __init__( *, run_id: str | None = None, result_type: type | None = None, + start_activity_response: None + | temporalio.api.workflowservice.v1.StartActivityExecutionResponse = None, ) -> None: """Create activity handle.""" self._client = client @@ -700,6 +702,7 @@ def __init__( self._known_outcome: ( temporalio.api.activity.v1.ActivityExecutionOutcome | None ) = None + self._start_activity_response = start_activity_response @functools.cached_property def _data_converter(self) -> temporalio.converter.DataConverter: diff --git a/temporalio/client/_client.py b/temporalio/client/_client.py index 1d8b8e4f2..b11b3c1bd 100644 --- a/temporalio/client/_client.py +++ b/temporalio/client/_client.py @@ -1484,6 +1484,12 @@ async def start_activity( start_delay: timedelta | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, + # The following options should not be considered part of the public API. They + # are deliberately not exposed in overloads, and are not subject to any + # backwards compatibility guarantees. + callbacks: Sequence[Callback] = [], + links: Sequence[temporalio.api.common.v1.Link] = [], + request_id: str | None = None, ) -> ActivityHandle[ReturnType]: """Start an activity and return its handle. @@ -1542,6 +1548,9 @@ async def start_activity( rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout, priority=priority, + callbacks=callbacks, + links=links, + request_id=request_id, ) ) diff --git a/temporalio/client/_impl.py b/temporalio/client/_impl.py index af221865a..e8b9aa3c3 100644 --- a/temporalio/client/_impl.py +++ b/temporalio/client/_impl.py @@ -237,7 +237,7 @@ async def _build_start_workflow_execution_request( # Links are duplicated on request for compatibility with older server versions. req.links.extend(links) - if temporalio.nexus._operation_context._in_nexus_backing_workflow_start_context(): + if temporalio.nexus._operation_context._in_nexus_backing_start_context(): req.on_conflict_options.attach_request_id = True req.on_conflict_options.attach_completion_callbacks = True req.on_conflict_options.attach_links = True @@ -566,6 +566,7 @@ async def start_activity(self, input: StartActivityInput) -> ActivityHandle[Any] input.id, run_id=resp.run_id, result_type=input.result_type, + start_activity_response=resp, ) async def _build_start_activity_execution_request( @@ -609,6 +610,8 @@ async def _build_start_activity_execution_request( ), ) + if input.request_id: + req.request_id = input.request_id if input.schedule_to_close_timeout is not None: req.schedule_to_close_timeout.FromTimedelta(input.schedule_to_close_timeout) if input.start_to_close_timeout is not None: @@ -644,6 +647,23 @@ async def _build_start_activity_execution_request( # Set priority req.priority.CopyFrom(input.priority._to_proto()) + req.completion_callbacks.extend( + temporalio.api.common.v1.Callback( + nexus=temporalio.api.common.v1.Callback.Nexus( + url=callback.url, + header=callback.headers, + ), + links=input.links, + ) + for callback in input.callbacks + ) + req.links.extend(input.links) + + if temporalio.nexus._operation_context._in_nexus_backing_start_context(): + req.on_conflict_options.attach_request_id = True + req.on_conflict_options.attach_completion_callbacks = True + req.on_conflict_options.attach_links = True + return req async def cancel_activity(self, input: CancelActivityInput) -> None: diff --git a/temporalio/client/_interceptor.py b/temporalio/client/_interceptor.py index 587b802d0..56f131445 100644 --- a/temporalio/client/_interceptor.py +++ b/temporalio/client/_interceptor.py @@ -230,6 +230,10 @@ class StartActivityInput: headers: Mapping[str, temporalio.api.common.v1.Payload] rpc_metadata: Mapping[str, str | bytes] rpc_timeout: timedelta | None + # The following options are experimental and unstable. + callbacks: Sequence[Callback] + links: Sequence[temporalio.api.common.v1.Link] + request_id: str | None @dataclass diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index f1a10767d..e117014c8 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -25,6 +25,7 @@ wait_for_worker_shutdown_sync, ) from ._operation_handlers import ( + CancelActivityOptions, CancelWorkflowRunOptions, TemporalNexusOperationHandler, ) @@ -33,6 +34,7 @@ __all__ = ( "workflow_run_operation", + "CancelActivityOptions", "CancelWorkflowRunOptions", "Info", "LoggerAdapter", diff --git a/temporalio/nexus/_link_conversion.py b/temporalio/nexus/_link_conversion.py index d02b543d9..0e4e6fb30 100644 --- a/temporalio/nexus/_link_conversion.py +++ b/temporalio/nexus/_link_conversion.py @@ -23,6 +23,10 @@ r"^/namespaces/(?P[^/]+)/nexus-operations/(?P[^/]+)$" ) +_ACTIVITY_LINK_URL_PATH_REGEX = re.compile( + r"^/namespaces/(?P[^/]+)/activities/(?P[^/]+)$" +) + _WORFKLOW_LINK_URL_PATH_REGEX = re.compile( r"^/namespaces/(?P[^/]+)/workflows/(?P[^/]+)/(?P[^/]+)/history$" ) @@ -31,6 +35,7 @@ class _LinkType(str, Enum): WORKFLOW = temporalio.api.common.v1.Link.WorkflowEvent.DESCRIPTOR.full_name NEXUS_OPERATION = temporalio.api.common.v1.Link.NexusOperation.DESCRIPTOR.full_name + ACTIVITY = temporalio.api.common.v1.Link.Activity.DESCRIPTOR.full_name LINK_EVENT_ID_PARAM_NAME = "eventID" @@ -84,6 +89,9 @@ def nexus_link_to_temporal_link( case _LinkType.NEXUS_OPERATION: return nexus_link_to_nexus_operation_link(nexus_link) + case _LinkType.ACTIVITY: + return nexus_link_to_activity_link(nexus_link) + def temporal_link_to_nexus_link( temporal_link: temporalio.api.common.v1.Link, @@ -99,7 +107,10 @@ def temporal_link_to_nexus_link( case "nexus_operation": return nexus_operation_to_nexus_link(temporal_link.nexus_operation) - case "activity" | "batch_job": + case "activity": + return activity_link_to_nexus_link(temporal_link.activity) + + case "batch_job": raise NotImplementedError("only workflow links are supported") case None: @@ -165,6 +176,25 @@ def nexus_operation_to_nexus_link( return nexusrpc.Link(url=url, type=_LinkType.NEXUS_OPERATION.value) +def activity_link_to_nexus_link( + activity: temporalio.api.common.v1.Link.Activity, +) -> nexusrpc.Link: + """Convert an Activity link into a nexusrpc link.""" + scheme = "temporal" + namespace = urllib.parse.quote(activity.namespace, safe="") + activity_id = urllib.parse.quote(activity.activity_id, safe="") + path = f"/namespaces/{namespace}/activities/{activity_id}" + + if activity.run_id: + query_params = urllib.parse.urlencode({LINK_RUN_ID_PARAM_NAME: activity.run_id}) + else: + query_params = "" + + url = f"{scheme}://{urllib.parse.urlunparse(('', '', path, '', query_params, ''))}" + + return nexusrpc.Link(url=url, type=_LinkType.ACTIVITY.value) + + def nexus_link_to_workflow_event_link( link: nexusrpc.Link, ) -> temporalio.api.common.v1.Link | None: @@ -250,6 +280,38 @@ def nexus_link_to_nexus_operation_link( return temporalio.api.common.v1.Link(nexus_operation=nexus_op_link) +def nexus_link_to_activity_link( + nexus_link: nexusrpc.Link, +) -> temporalio.api.common.v1.Link | None: + """Convert a nexus link into a Temporal Activity link.""" + url = urllib.parse.urlparse(nexus_link.url) + match = _ACTIVITY_LINK_URL_PATH_REGEX.match(url.path) + if not match: + logger.warning( + f"Invalid Nexus link: {nexus_link}. Expected path to match {_ACTIVITY_LINK_URL_PATH_REGEX.pattern}" + ) + return None + + query_params = urllib.parse.parse_qs(url.query, keep_blank_values=True) + + match query_params.get(LINK_RUN_ID_PARAM_NAME): + case [run_id_param]: + run_id = run_id_param + case _: + logger.warning( + f"Invalid Nexus link: {nexus_link}. Expected {LINK_RUN_ID_PARAM_NAME} to have exactly 1 value" + ) + return None + + groups = match.groupdict() + activity_link = temporalio.api.common.v1.Link.Activity( + namespace=urllib.parse.unquote(groups["namespace"]), + activity_id=urllib.parse.unquote(groups["activity_id"]), + run_id=run_id, + ) + return temporalio.api.common.v1.Link(activity=activity_link) + + def _event_reference_to_query_params( event_ref: temporalio.api.common.v1.Link.WorkflowEvent.EventReference, ) -> str: diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index e8ead61fe..542df7b5d 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -43,6 +43,7 @@ ) from ._link_conversion import ( + activity_link_to_nexus_link, nexus_link_to_temporal_link, workflow_event_to_nexus_link, workflow_execution_started_event_link_from_workflow_handle, @@ -64,12 +65,12 @@ ContextVar("temporal-cancel-operation-context") ) -# A Nexus start handler might start zero or more workflows as usual using a Temporal client. In -# addition, it may start one "nexus-backing" workflow, using -# WorkflowRunOperationContext.start_workflow. This context is active while the latter is being done. +# A Nexus start handler might start zero or async Temporal actions as usual using a Temporal client. In +# addition, it may start one "nexus-backing" async Temporal action, using +# WorkflowRunOperationContext.start_workflow or methods from TemporalNexusClient. This context is active while the latter is being done. # It is thus a narrower context than _temporal_start_operation_context. -_temporal_nexus_backing_workflow_start_context: ContextVar[bool] = ContextVar( - "temporal-nexus-backing-workflow-start-context" +_temporal_nexus_backing_start_context: ContextVar[bool] = ContextVar( + "temporal-nexus-backing-start-context" ) @@ -168,16 +169,16 @@ def _try_temporal_context() -> ( @contextmanager -def _nexus_backing_workflow_start_context() -> Generator[None]: - token = _temporal_nexus_backing_workflow_start_context.set(True) +def _nexus_backing_start_context() -> Generator[None]: + token = _temporal_nexus_backing_start_context.set(True) try: yield finally: - _temporal_nexus_backing_workflow_start_context.reset(token) + _temporal_nexus_backing_start_context.reset(token) -def _in_nexus_backing_workflow_start_context() -> bool: # type:ignore[reportUnusedClass] - return _temporal_nexus_backing_workflow_start_context.get(False) +def _in_nexus_backing_start_context() -> bool: # type:ignore[reportUnusedClass] + return _temporal_nexus_backing_start_context.get(False) _OperationCtxT = TypeVar("_OperationCtxT", bound=OperationContext) @@ -243,13 +244,13 @@ def _get_callbacks( def _get_links( self, ) -> list[temporalio.api.common.v1.Link]: - event_links: list[temporalio.api.common.v1.Link] = [] + links: list[temporalio.api.common.v1.Link] = [] for inbound_link in self.nexus_context.inbound_links: if link := nexus_link_to_temporal_link(inbound_link): - event_links.append(link) - return event_links + links.append(link) + return links - def _add_outbound_links( + def _add_outbound_workflow_links( self, workflow_handle: temporalio.client.WorkflowHandle[Any, Any] ): # If links were not sent in StartWorkflowExecutionResponse then construct them. @@ -279,6 +280,45 @@ def _add_outbound_links( ) return workflow_handle + def _add_outbound_activity_links( + self, activity_handle: temporalio.client.ActivityHandle[Any] + ): + activity_links: list[temporalio.api.common.v1.Link.Activity] = [] + try: + if isinstance( + activity_handle._start_activity_response, + temporalio.api.workflowservice.v1.StartActivityExecutionResponse, + ): + if activity_handle._start_activity_response.HasField("link"): + if activity_handle._start_activity_response.link.HasField( + "activity" + ): + activity_links.append( + activity_handle._start_activity_response.link.activity + ) + if not activity_links: + activity_run_id = activity_handle.run_id + if activity_run_id is None: + raise ValueError( + f"Activity handle {activity_handle} has no run ID. " + "Cannot create Activity link." + ) + activity_links.append( + temporalio.api.common.v1.Link.Activity( + namespace=activity_handle._client.namespace, + activity_id=activity_handle.id, + run_id=activity_run_id, + ) + ) + self.nexus_context.outbound_links.extend( + activity_link_to_nexus_link(link) for link in activity_links + ) + except Exception as e: + logger.warning( + f"Failed to create Activity link for activity {activity_handle}: {e}" + ) + return activity_handle + class WorkflowRunOperationContext(StartOperationContext): """Context received by a workflow run operation.""" @@ -642,7 +682,7 @@ async def _start_nexus_backing_workflow( # namespace to deliver the result to the caller namespace when the workflow reaches a # terminal state) and inbound links to the caller workflow (attached to history events of # the workflow started in the handler namespace, and displayed in the UI). - with _nexus_backing_workflow_start_context(): + with _nexus_backing_start_context(): wf_handle = await temporal_context.client.start_workflow( # type: ignore workflow=workflow, arg=arg, @@ -674,6 +714,6 @@ async def _start_nexus_backing_workflow( request_id=temporal_context.nexus_context.request_id, ) - temporal_context._add_outbound_links(wf_handle) + temporal_context._add_outbound_workflow_links(wf_handle) return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index c3e4b2e5e..efb895406 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -138,6 +138,21 @@ class CancelWorkflowRunOptions: """The ID of the workflow to cancel.""" +@dataclass(frozen=True) +class CancelActivityOptions: + """Options for cancelling the activity backing a Nexus operation. + + These options are built by :py:class:`TemporalNexusOperationHandler` and passed to + :py:meth:`TemporalNexusOperationHandler.cancel_activity`. + + .. warning:: + This API is experimental and unstable. + """ + + activity_id: str + """The ID of the activity to cancel.""" + + class TemporalNexusOperationHandler(OperationHandler[InputT, OutputT], ABC): """Operation handler for Nexus operations that interact with Temporal. Implementations override the start_operation method. @@ -183,6 +198,7 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None: raise HandlerError( "Unable to decode operation token to cancel", type=HandlerErrorType.INTERNAL, + retryable_override=False, ) from err cancel_ctx = TemporalNexusCancelOperationContext._from_cancel_operation_context( @@ -190,10 +206,26 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None: ) match operation_token.type: case OperationTokenType.WORKFLOW: - options = CancelWorkflowRunOptions( + if not operation_token.workflow_id: + raise HandlerError( + "Invalid workflow run operation token: missing workflow ID", + type=HandlerErrorType.NOT_FOUND, + ) + wf_cancel_opts = CancelWorkflowRunOptions( workflow_id=operation_token.workflow_id ) - await self.cancel_workflow_run(cancel_ctx, options) + await self.cancel_workflow_run(cancel_ctx, wf_cancel_opts) + + case OperationTokenType.ACTIVITY: + if not operation_token.activity_id: + raise HandlerError( + "Invalid activity operation token: missing activity ID", + type=HandlerErrorType.NOT_FOUND, + ) + activity_cancel_opts = CancelActivityOptions( + activity_id=operation_token.activity_id + ) + await self.cancel_activity(cancel_ctx, activity_cancel_opts) async def cancel_workflow_run( self, @@ -209,3 +241,14 @@ async def cancel_workflow_run( options.workflow_id ) await workflow_handle.cancel() + + async def cancel_activity( + self, + ctx: TemporalNexusCancelOperationContext, # pyright: ignore[reportUnusedParameter] + options: CancelActivityOptions, + ): + """Requests cancellation of the standalone activity identified by activity_id.""" + activity_handle = temporalio.nexus.client().get_activity_handle( + options.activity_id + ) + await activity_handle.cancel() diff --git a/temporalio/nexus/_temporal_client.py b/temporalio/nexus/_temporal_client.py index 08204d89b..128f1aa6f 100644 --- a/temporalio/nexus/_temporal_client.py +++ b/temporalio/nexus/_temporal_client.py @@ -21,10 +21,16 @@ import temporalio.common from temporalio.nexus._operation_context import ( + _nexus_backing_start_context, _start_nexus_backing_workflow, _TemporalStartOperationContext, ) +from temporalio.nexus._token import OperationToken, OperationTokenType from temporalio.types import ( + CallableAsyncNoParam, + CallableAsyncSingleParam, + CallableSyncNoParam, + CallableSyncSingleParam, MethodAsyncNoParam, MethodAsyncSingleParam, MultiParamSpec, @@ -279,6 +285,199 @@ async def start_workflow( """ ... + # async no-param activity + @overload + async def start_activity( + self, + activity: CallableAsyncNoParam[ReturnType], + *, + id: str, + task_queue: str | None = None, + schedule_to_start_timeout: timedelta | None = None, + schedule_to_close_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # sync no-param activity + @overload + async def start_activity( + self, + activity: CallableSyncNoParam[ReturnType], + *, + id: str, + task_queue: str | None = None, + schedule_to_start_timeout: timedelta | None = None, + schedule_to_close_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # async single-param activity + @overload + async def start_activity( + self, + activity: CallableAsyncSingleParam[ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str | None = None, + schedule_to_start_timeout: timedelta | None = None, + schedule_to_close_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # sync single-param activity + @overload + async def start_activity( + self, + activity: CallableSyncSingleParam[ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str | None = None, + schedule_to_start_timeout: timedelta | None = None, + schedule_to_close_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # async multi-param activity + @overload + async def start_activity( + self, + activity: Callable[..., Awaitable[ReturnType]], + *, + args: Sequence[Any], + id: str, + task_queue: str | None = None, + schedule_to_start_timeout: timedelta | None = None, + schedule_to_close_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # sync multi-param activity + @overload + async def start_activity( + self, + activity: Callable[..., ReturnType], + *, + args: Sequence[Any], + id: str, + task_queue: str | None = None, + schedule_to_start_timeout: timedelta | None = None, + schedule_to_close_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # string-name activity + @overload + async def start_activity( + self, + activity: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str | None = None, + result_type: type[ReturnType] | None = None, + schedule_to_start_timeout: timedelta | None = None, + schedule_to_close_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + @abstractmethod + async def start_activity( + self, + activity: ( + str | Callable[..., Awaitable[ReturnType]] | Callable[..., ReturnType] + ), + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str | None = None, + result_type: type | None = None, + schedule_to_start_timeout: timedelta | None = None, + schedule_to_close_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: + """Start a standalone activity that will deliver the Nexus operation result. + + If ``task_queue`` is not specified, the Nexus worker's task queue is used. + See :py:meth:`temporalio.client.Client.start_activity` for all other arguments. + """ + ... + class _TemporalNexusClient(TemporalNexusClient): # pyright: ignore[reportUnusedClass] """Nexus-aware wrapper around a Temporal Client. @@ -377,3 +576,77 @@ async def start_workflow( ) return TemporalOperationResult.async_token(wf_handle.to_token()) + + async def start_activity( + self, + activity: ( + str | Callable[..., Awaitable[ReturnType]] | Callable[..., ReturnType] + ), + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str | None = None, + result_type: type | None = None, + schedule_to_start_timeout: timedelta | None = None, + schedule_to_close_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.ActivityIDReusePolicy = temporalio.common.ActivityIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.ActivityIDConflictPolicy = temporalio.common.ActivityIDConflictPolicy.FAIL, + retry_policy: temporalio.common.RetryPolicy | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> TemporalOperationResult[ReturnType]: + """Start a standalone activity that will deliver the Nexus operation result. + + If ``task_queue`` is not specified, the Nexus worker's task queue is used. + See :py:meth:`temporalio.client.Client.start_activity` for all other arguments. + """ + with self._reserve_async_start(): + # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, + # but these are deliberately not exposed in overloads, hence the type-check + # violation. + + # Here we are starting a "nexus-backing" standalone activity. The start request + # carries the Nexus completion callback so the activity result is delivered to + # the Nexus caller when the activity reaches a terminal state. + with _nexus_backing_start_context(): + activity_handle: temporalio.client.ActivityHandle[ + ReturnType + ] = await self._temporal_context.client.start_activity( # type: ignore + activity=activity, + arg=arg, + args=args, + id=id, + task_queue=task_queue or self._temporal_context.info().task_queue, + result_type=result_type, + schedule_to_start_timeout=schedule_to_start_timeout, + schedule_to_close_timeout=schedule_to_close_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + search_attributes=search_attributes, + summary=summary, + priority=priority, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + callbacks=self._temporal_context._get_callbacks(), + links=self._temporal_context._get_links(), + request_id=self._temporal_context.nexus_context.request_id, + ) + + self._temporal_context._add_outbound_activity_links(activity_handle) + + activity_token = OperationToken( + type=OperationTokenType.ACTIVITY, + namespace=self._temporal_context.client.namespace, + activity_id=activity_handle.id, + ) + + return TemporalOperationResult.async_token(activity_token.encode()) diff --git a/temporalio/nexus/_token.py b/temporalio/nexus/_token.py index d52b54180..b9de54e82 100644 --- a/temporalio/nexus/_token.py +++ b/temporalio/nexus/_token.py @@ -14,6 +14,7 @@ class OperationTokenType(IntEnum): """Type discriminator for Nexus operation tokens.""" WORKFLOW = 1 + ACTIVITY = 2 if TYPE_CHECKING: @@ -27,15 +28,19 @@ class OperationToken: version: int | None = None type: OperationTokenType namespace: str - workflow_id: str + workflow_id: str | None = None + activity_id: str | None = None def encode(self) -> str: """Convert handle to a base64url-encoded token string.""" token_details: dict[str, Any] = { "t": self.type, "ns": self.namespace, - "wid": self.workflow_id, } + if self.workflow_id is not None: + token_details["wid"] = self.workflow_id + if self.activity_id is not None: + token_details["aid"] = self.activity_id if self.version is not None: token_details["v"] = self.version return _base64url_encode_no_padding( @@ -83,7 +88,7 @@ def decode(cls, token: str) -> Self: ) workflow_id = token_details.get("wid") - if not isinstance(workflow_id, str): + if workflow_id is not None and not isinstance(workflow_id, str): raise TypeError( f"invalid token: expected workflow id to be a string, got {type(workflow_id)}" ) @@ -93,6 +98,17 @@ def decode(cls, token: str) -> Self: "invalid token: expected non-empty workflow id for token type `WORKFLOW`" ) + activity_id = token_details.get("aid") + if activity_id is not None and not isinstance(activity_id, str): + raise TypeError( + f"invalid token: expected activity id to be a string, got {type(activity_id)}" + ) + + if token_type == OperationTokenType.ACTIVITY and not activity_id: + raise TypeError( + "invalid token: expected non-empty activity id for token type `ACTIVITY`" + ) + namespace = token_details.get("ns") if not isinstance(namespace, str): # Allow empty string for ns, but it must be present and a string @@ -104,6 +120,7 @@ def decode(cls, token: str) -> Self: type=OperationTokenType(token_type), namespace=namespace, workflow_id=workflow_id, + activity_id=activity_id, version=version, ) @@ -168,6 +185,9 @@ def from_token(cls, token: str) -> WorkflowHandle[OutputT]: f"invalid workflow token type: {op_token.type}, expected: {OperationTokenType.WORKFLOW}" ) + if not op_token.workflow_id: + raise TypeError("invalid workflow token: missing workflow id.") + if op_token.version is not None and op_token.version != 0: raise TypeError( "invalid workflow token: 'v' field, if present, must be 0 or null/absent" diff --git a/tests/nexus/test_link_conversion.py b/tests/nexus/test_link_conversion.py index 345d4f4e3..a875a273d 100644 --- a/tests/nexus/test_link_conversion.py +++ b/tests/nexus/test_link_conversion.py @@ -285,6 +285,62 @@ def test_nexus_operation_link_with_duplicate_run_id_is_ignored(): assert temporalio.nexus._link_conversion.nexus_link_to_temporal_link(link) is None +@pytest.mark.parametrize( + ["link", "expected_link"], + [ + ( + nexusrpc.Link( + type=temporalio.api.common.v1.Link.Activity.DESCRIPTOR.full_name, + url="temporal:///namespaces/ns/activities/act-id?runID=run-id", + ), + temporalio.api.common.v1.Link( + activity=temporalio.api.common.v1.Link.Activity( + namespace="ns", + activity_id="act-id", + run_id="run-id", + ), + ), + ), + ( + nexusrpc.Link( + type=temporalio.api.common.v1.Link.Activity.DESCRIPTOR.full_name, + url="temporal:///namespaces/ns%2F/activities/act-id%2F?runID=run-id%3E", + ), + temporalio.api.common.v1.Link( + activity=temporalio.api.common.v1.Link.Activity( + namespace="ns/", + activity_id="act-id/", + run_id="run-id>", + ), + ), + ), + ], +) +def test_link_conversion_nexus_link_to_activity_link( + link: nexusrpc.Link, + expected_link: temporalio.api.common.v1.Link, +): + from_activity_link = temporalio.nexus._link_conversion.activity_link_to_nexus_link( + expected_link.activity + ) + assert link == from_activity_link + + from_temporal_link = temporalio.nexus._link_conversion.temporal_link_to_nexus_link( + expected_link + ) + assert link == from_temporal_link + + actual_activity = temporalio.nexus._link_conversion.nexus_link_to_activity_link( + link + ) + assert expected_link == actual_activity + + actual_temporal_link = ( + temporalio.nexus._link_conversion.nexus_link_to_temporal_link(link) + ) + assert expected_link == actual_temporal_link + + def test_link_conversion_utilities(): p2c = temporalio.nexus._link_conversion._event_type_pascal_case_to_constant_case c2p = temporalio.nexus._link_conversion._event_type_constant_case_to_pascal_case diff --git a/tests/nexus/test_nexus_type_errors.py b/tests/nexus/test_nexus_type_errors.py index ffdb60c65..749f1e841 100644 --- a/tests/nexus/test_nexus_type_errors.py +++ b/tests/nexus/test_nexus_type_errors.py @@ -11,7 +11,7 @@ import nexusrpc import temporalio.nexus -from temporalio import workflow +from temporalio import activity, workflow from temporalio.client import Client, NexusOperationHandle from temporalio.nexus import TemporalNexusOperationStartHandlerFunc from temporalio.service import ServiceClient @@ -71,6 +71,40 @@ async def run( pass +@activity.defn +async def my_no_arg_activity() -> None: + pass + + +@activity.defn +async def my_one_arg_activity(_input: MyInput) -> None: + pass + + +@activity.defn +async def my_two_arg_activity(_input: MyInput, _arg2: int) -> None: + pass + + +@activity.defn +async def my_three_arg_activity(_input: MyInput, _arg2: int, _arg3: int) -> None: + pass + + +@activity.defn +async def my_four_arg_activity( + _input: MyInput, _arg2: int, _arg3: int, _arg4: int +) -> None: + pass + + +@activity.defn +async def my_five_arg_activity( + _input: MyInput, _arg2: int, _arg3: int, _arg4: int, _arg5: int +) -> None: + pass + + @nexusrpc.service class MyService: my_sync_operation: nexusrpc.Operation[MyInput, MyOutput] @@ -105,8 +139,9 @@ async def my_temporal_operation( input: int, ) -> temporalio.nexus.TemporalOperationResult[None]: """ - Typed proc workflow starts from a generic Temporal Nexus operation handler - infer TemporalOperationResult[None] for 0 to 5 workflow parameters. + Typed proc workflow and activity starts from a generic Temporal Nexus + operation handler infer TemporalOperationResult[None] for 0 to 5 + workflow or activity parameters. """ if input == 0: result_0: temporalio.nexus.TemporalOperationResult[ @@ -154,6 +189,85 @@ async def my_temporal_operation( id="proc-5", ) return result_5 + + # Typed activity starts infer TemporalOperationResult[None] for 0 to 5 + # activity parameters. Activities require a start_to_close_timeout. + if input == 6: + activity_result_0: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_activity( + my_no_arg_activity, + id="activity-0", + start_to_close_timeout=timedelta(seconds=5), + ) + return activity_result_0 + if input == 7: + activity_result_1: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_activity( + my_one_arg_activity, + MyInput(), + id="activity-1", + start_to_close_timeout=timedelta(seconds=5), + ) + return activity_result_1 + if input == 8: + activity_result_2: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_activity( + my_two_arg_activity, + args=[MyInput(), 2], + id="activity-2", + start_to_close_timeout=timedelta(seconds=5), + ) + return activity_result_2 + if input == 9: + activity_result_3: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_activity( + my_three_arg_activity, + args=[MyInput(), 2, 3], + id="activity-3", + start_to_close_timeout=timedelta(seconds=5), + ) + return activity_result_3 + if input == 10: + activity_result_4: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_activity( + my_four_arg_activity, + args=[MyInput(), 2, 3, 4], + id="activity-4", + start_to_close_timeout=timedelta(seconds=5), + ) + return activity_result_4 + if input == 11: + activity_result_5: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_activity( + my_five_arg_activity, + args=[MyInput(), 2, 3, 4, 5], + id="activity-5", + start_to_close_timeout=timedelta(seconds=5), + ) + return activity_result_5 + if input == 12: + # omitting the required start_to_close_timeout is a type error + # assert-type-error-pyright: 'No overloads for "start_activity" match' + return await client.start_activity( # type: ignore + my_no_arg_activity, + id="activity-missing-timeout", + ) + if input == 13: + # assert-type-error-pyright: 'No overloads for "start_activity" match' + return await client.start_activity( # type: ignore + my_one_arg_activity, + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter' + "wrong-input-type", # type: ignore + id="activity-wrong-input", + start_to_close_timeout=timedelta(seconds=5), + ) + # assert-type-error-pyright: 'No overloads for "start_workflow" match' return await client.start_workflow( # type: ignore MyOneArgProcWorkflow.run, diff --git a/tests/nexus/test_operation_token.py b/tests/nexus/test_operation_token.py index 385f4f872..028bc67cb 100644 --- a/tests/nexus/test_operation_token.py +++ b/tests/nexus/test_operation_token.py @@ -19,6 +19,11 @@ def _encode_bytes(value: bytes) -> str: return base64.urlsafe_b64encode(value).decode("utf-8").rstrip("=") +def _decode_bytes(value: str) -> bytes: + padding = "=" * (-len(value) % 4) + return base64.urlsafe_b64decode(value + padding) + + def test_operation_token_encode_decode_round_trip(): token = OperationToken( type=OperationTokenType.WORKFLOW, @@ -36,6 +41,37 @@ def test_operation_token_encode_decode_round_trip(): ) +def test_operation_token_activity_encode_decode_round_trip(): + token = OperationToken( + type=OperationTokenType.ACTIVITY, + namespace="default", + activity_id="activity-id", + version=0, + ).encode() + + assert "=" not in token + assert OperationToken.decode(token) == OperationToken( + type=OperationTokenType.ACTIVITY, + namespace="default", + activity_id="activity-id", + version=0, + ) + + +def test_operation_token_activity_encode_uses_activity_id_and_omits_workflow_id(): + token = OperationToken( + type=OperationTokenType.ACTIVITY, + namespace="default", + activity_id="activity-id", + ).encode() + + assert json.loads(_decode_bytes(token)) == { + "t": 2, + "ns": "default", + "aid": "activity-id", + } + + def test_workflow_handle_to_from_token_round_trip(): handle = WorkflowHandle[str](namespace="default", workflow_id="workflow-id") @@ -80,6 +116,42 @@ def test_workflow_handle_to_from_token_round_trip(): version=0, ), ), + # Activity tokens + ( + _encode_json_token({"t": 2, "ns": "default", "aid": "activity-id"}), + OperationToken( + type=OperationTokenType.ACTIVITY, + namespace="default", + activity_id="activity-id", + ), + ), + ( + _encode_json_token({"t": 2, "ns": "", "aid": "activity-id"}), + OperationToken( + type=OperationTokenType.ACTIVITY, + namespace="", + activity_id="activity-id", + ), + ), + ( + _encode_json_token( + {"t": 2, "ns": "default", "aid": "activity-id", "v": None} + ), + OperationToken( + type=OperationTokenType.ACTIVITY, + namespace="default", + activity_id="activity-id", + ), + ), + ( + _encode_json_token({"t": 2, "ns": "default", "aid": "activity-id", "v": 0}), + OperationToken( + type=OperationTokenType.ACTIVITY, + namespace="default", + activity_id="activity-id", + version=0, + ), + ), ], ) def test_operation_token_decode_accepts_valid_tokens( @@ -110,7 +182,7 @@ def test_operation_token_decode_accepts_valid_tokens( ), ( _encode_json_token({"t": 1, "ns": "default"}), - "expected workflow id to be a string", + "expected non-empty workflow id for token type `WORKFLOW`", ), ( _encode_json_token({"t": 1, "ns": "default", "wid": 123}), @@ -118,7 +190,7 @@ def test_operation_token_decode_accepts_valid_tokens( ), ( _encode_json_token({"t": 1, "ns": "default", "wid": ""}), - "expected non-empty workflow id", + "expected non-empty workflow id for token type `WORKFLOW`", ), ( _encode_json_token({"t": 1, "wid": "workflow-id"}), @@ -134,6 +206,29 @@ def test_operation_token_decode_accepts_valid_tokens( ), "expected version to be an int or null", ), + # Activity tokens + ( + _encode_json_token({"t": 2, "ns": "default"}), + "expected non-empty activity id for token type `ACTIVITY`", + ), + ( + _encode_json_token({"t": 2, "ns": "default", "aid": ""}), + "expected non-empty activity id for token type `ACTIVITY`", + ), + ( + _encode_json_token({"t": 2, "ns": "default", "aid": 123}), + "expected activity id to be a string", + ), + ( + _encode_json_token({"t": 2, "aid": "activity-id"}), + "expected namespace to be a string", + ), + ( + _encode_json_token( + {"t": 2, "ns": "default", "aid": "activity-id", "v": "0"} + ), + "expected version to be an int or null", + ), ], ) def test_operation_token_decode_rejects_invalid_tokens(token: str, message: str): diff --git a/tests/nexus/test_temporal_operation.py b/tests/nexus/test_temporal_operation.py index 8bdfa267e..370d4d2c5 100644 --- a/tests/nexus/test_temporal_operation.py +++ b/tests/nexus/test_temporal_operation.py @@ -1,23 +1,57 @@ import asyncio import uuid from dataclasses import dataclass +from datetime import timedelta import nexusrpc import pytest from nexusrpc import HandlerErrorType, Operation, service -from nexusrpc.handler import operation_handler, service_handler +from nexusrpc.handler import ( + CancelOperationContext, + OperationTaskCancellation, + operation_handler, + service_handler, +) from typing_extensions import override import temporalio.exceptions -from temporalio import nexus, workflow -from temporalio.client import Client, WorkflowExecutionStatus, WorkflowFailureError -from temporalio.common import NexusOperationExecutionStatus, WorkflowIDConflictPolicy +from temporalio import activity, nexus, workflow +from temporalio.client import ( + ActivityExecutionStatus, + Client, + NexusOperationFailureError, + WorkflowExecutionStatus, + WorkflowFailureError, +) +from temporalio.common import ( + NexusOperationExecutionStatus, + RetryPolicy, + WorkflowIDConflictPolicy, +) +from temporalio.nexus._token import OperationToken, OperationTokenType from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers import EventType, assert_event_subsequence, assert_eventually from tests.helpers.nexus import make_nexus_endpoint_name +class FakeNexusTaskCancellation(OperationTaskCancellation): + def is_cancelled(self) -> bool: + return False + + def cancellation_reason(self) -> str | None: + return None + + def wait_until_cancelled_sync(self, timeout: float | None = None) -> bool: + return False + + async def wait_until_cancelled(self) -> None: + return None + + def cancel(self, _reason: str) -> bool: + return False + + @dataclass class Input: value: str @@ -53,6 +87,29 @@ async def run(self, input: Input) -> str: return input.value +@activity.defn +async def echo_activity(input: Input) -> str: + return input.value + + +@activity.defn +async def raise_error_activity() -> None: + raise temporalio.exceptions.ApplicationError( + "test-activity-error-message", + type="test-activity-error-type", + non_retryable=True, + ) + + +@activity.defn +async def wait_for_cancel_activity() -> None: + # Heartbeat in a loop so the activity receives cancellation. Letting the + # resulting CancelledError bubble out transitions the activity to CANCELED. + while True: + await asyncio.sleep(0.3) + activity.heartbeat() + + @service class TestService: echo: Operation[Input, str] @@ -62,6 +119,12 @@ class TestService: retry_after_failed_start: Operation[Input, str] sync_result: Operation[Input, str] custom_cancel: Operation[str, None] + echo_activity: Operation[Input, str] + error_activity: Operation[Input, None] + blocking_activity: Operation[str, None] + custom_cancel_activity: Operation[str, None] + double_start_activity: Operation[Input, None] + mixed_start: Operation[Input, None] @service_handler(service=TestService) @@ -71,6 +134,8 @@ class TestServiceHandler: def __init__(self) -> None: self.started_custom_cancel_workflow = asyncio.Event() + self.started_custom_cancel_activity = asyncio.Event() + self.custom_cancel_activity_called = asyncio.Event() @nexus.temporal_operation async def echo( @@ -217,6 +282,130 @@ async def cancel_workflow_run( return CustomCancelNexusOpHandler() + @nexus.temporal_operation + async def echo_activity( + self, + _ctx: nexus.TemporalNexusStartOperationContext, + client: nexus.TemporalNexusClient, + input: Input, + ) -> nexus.TemporalOperationResult[str]: + return await client.start_activity( + echo_activity, + input, + id=f"echo_activity-{uuid.uuid4()}", + start_to_close_timeout=timedelta(seconds=5), + ) + + @nexus.temporal_operation + async def error_activity( + self, + _ctx: nexus.TemporalNexusStartOperationContext, + client: nexus.TemporalNexusClient, + _input: Input, + ) -> nexus.TemporalOperationResult[None]: + # The activity raises immediately. With a single permitted attempt it + # fails the backing activity, which in turn fails the Nexus operation. + return await client.start_activity( + raise_error_activity, + id=f"error_activity-{uuid.uuid4()}", + start_to_close_timeout=timedelta(seconds=5), + retry_policy=RetryPolicy(maximum_attempts=1), + ) + + @nexus.temporal_operation + async def blocking_activity( + self, + _ctx: nexus.TemporalNexusStartOperationContext, + client: nexus.TemporalNexusClient, + input: str, + ) -> nexus.TemporalOperationResult[None]: + return await client.start_activity( + wait_for_cancel_activity, + id=input, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=1), + ) + + @nexus.temporal_operation + async def double_start_activity( + self, + _ctx: nexus.TemporalNexusStartOperationContext, + client: nexus.TemporalNexusClient, + input: Input, + ) -> nexus.TemporalOperationResult[None]: + await client.start_activity( + echo_activity, + input, + id=f"double-start-activity-{uuid.uuid4()}", + start_to_close_timeout=timedelta(seconds=5), + ) + await client.start_activity( + echo_activity, + input, + id=f"double-start-activity-{uuid.uuid4()}", + start_to_close_timeout=timedelta(seconds=5), + ) + return nexus.TemporalOperationResult.sync(None) + + @nexus.temporal_operation + async def mixed_start( + self, + _ctx: nexus.TemporalNexusStartOperationContext, + client: nexus.TemporalNexusClient, + input: Input, + ) -> nexus.TemporalOperationResult[None]: + # Starting a workflow reserves the single async start, so the subsequent + # start_activity must hit the same guard and raise a BAD_REQUEST error. + await client.start_workflow( + EchoWorkflow.run, input, id=f"mixed-start-{uuid.uuid4()}" + ) + await client.start_activity( + echo_activity, + input, + id=f"mixed-start-{uuid.uuid4()}", + start_to_close_timeout=timedelta(seconds=5), + ) + return nexus.TemporalOperationResult.sync(None) + + @operation_handler + def custom_cancel_activity(self) -> nexus.TemporalNexusOperationHandler[str, None]: + started = self.started_custom_cancel_activity + cancel_called = self.custom_cancel_activity_called + + class CustomCancelActivityNexusOpHandler( + nexus.TemporalNexusOperationHandler[str, None] + ): + @override + async def start_operation( + self, + ctx: nexus.TemporalNexusStartOperationContext, + client: nexus.TemporalNexusClient, + input: str, + ) -> nexus.TemporalOperationResult[None]: + result = await client.start_activity( + wait_for_cancel_activity, + id=input, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=1), + ) + started.set() + return result + + @override + async def cancel_activity( + self, + ctx: nexus.TemporalNexusCancelOperationContext, + options: nexus.CancelActivityOptions, + ): + # record that the custom override ran + cancel_called.set() + + # get a handle to the activity and cancel it + handle = nexus.client().get_activity_handle(options.activity_id) + await handle.cancel() + + return CustomCancelActivityNexusOpHandler() + @workflow.defn class EchoWorkflowCaller: @@ -259,6 +448,50 @@ async def test_temporal_operation_start_workflow( ) +async def test_temporal_operation_cancel_rejects_unknown_tokens(): + service_handler = TestServiceHandler() + + # Use a factory style operation form the handler to allow calling cancel directly + op_handler = service_handler.custom_cancel() + + cancel_ctx = CancelOperationContext( + service="TestService", + operation="echo", + headers={}, + task_cancellation=FakeNexusTaskCancellation(), + ) + + # Invalid token type + token = OperationToken(type=30, namespace="default") # type: ignore + with pytest.raises(nexusrpc.HandlerError) as err: + await op_handler.cancel(cancel_ctx, token.encode()) + assert err.value.type == HandlerErrorType.INTERNAL + assert not err.value.retryable + underlying = err.value.__cause__ + assert isinstance(underlying, TypeError) + assert "unknown token type, got 30" in str(underlying) + + # Workflow ID missing from workflow type + token = OperationToken(type=OperationTokenType.WORKFLOW, namespace="default") + with pytest.raises(nexusrpc.HandlerError) as err: + await op_handler.cancel(cancel_ctx, token.encode()) + assert err.value.type == HandlerErrorType.INTERNAL + assert not err.value.retryable + underlying = err.value.__cause__ + assert isinstance(underlying, TypeError) + assert "expected non-empty workflow id for token type `WORKFLOW`" in str(underlying) + + # Activity ID missing from activity type + token = OperationToken(type=OperationTokenType.ACTIVITY, namespace="default") + with pytest.raises(nexusrpc.HandlerError) as err: + await op_handler.cancel(cancel_ctx, token.encode()) + assert err.value.type == HandlerErrorType.INTERNAL + assert not err.value.retryable + underlying = err.value.__cause__ + assert isinstance(underlying, TypeError) + assert "expected non-empty activity id for token type `ACTIVITY`" in str(underlying) + + @workflow.defn class BlockingWorkflow: def __init__(self) -> None: @@ -486,6 +719,41 @@ async def test_temporal_operation_failed_start_allows_retry( await conflict_handle.cancel() +async def test_temporal_operation_mixed_start_raises_handler_err( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip( + "Standalone Nexus Operation tests don't work with time-skipping server" + ) + + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[TestServiceHandler()], + workflows=[EchoWorkflow], + activities=[echo_activity], + ): + nexus_client = client.create_nexus_client(TestService, endpoint_name) + + with pytest.raises(NexusOperationFailureError) as err: + await nexus_client.execute_operation( + TestService.mixed_start, + Input(value="test", task_queue=task_queue), + id=str(uuid.uuid4()), + ) + + assert isinstance(err.value.cause, nexusrpc.HandlerError) + assert err.value.cause.type == HandlerErrorType.BAD_REQUEST + assert ( + "Only one async operation can be started per operation handler invocation" + in err.value.cause.message + ) + + @workflow.defn class SyncResultCaller: @workflow.run @@ -525,6 +793,184 @@ async def test_temporal_operation_sync_result(client: Client, env: WorkflowEnvir ) +async def test_temporal_operation_start_activity( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip( + "Standalone Nexus Operation tests don't work with time-skipping server" + ) + + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[TestServiceHandler()], + activities=[echo_activity], + ): + nexus_client = client.create_nexus_client(TestService, endpoint_name) + + result = await nexus_client.execute_operation( + TestService.echo_activity, + Input(value="test", task_queue=task_queue), + id=str(uuid.uuid4()), + ) + assert result == "test" + + +async def test_temporal_operation_start_activity_raises_error( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip( + "Standalone Nexus Operation tests don't work with time-skipping server" + ) + + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[TestServiceHandler()], + activities=[raise_error_activity], + ): + nexus_client = client.create_nexus_client(TestService, endpoint_name) + + with pytest.raises(NexusOperationFailureError) as err: + await nexus_client.execute_operation( + TestService.error_activity, + Input(value="test", task_queue=task_queue), + id=str(uuid.uuid4()), + ) + + operation_err = err.value.__cause__ + assert isinstance(operation_err, temporalio.exceptions.ApplicationError) + assert operation_err.type == "OperationError" + assert "nexus operation completed unsuccessfully" in str(operation_err) + + application_err = operation_err.__cause__ + assert isinstance(application_err, temporalio.exceptions.ApplicationError) + + assert application_err.type == "test-activity-error-type" + assert "test-activity-error-message" in str(application_err) + assert application_err.__cause__ is None + + +async def test_temporal_operation_cancel_activity( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip( + "Standalone Nexus Operation tests don't work with time-skipping server" + ) + + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[TestServiceHandler()], + activities=[wait_for_cancel_activity], + ): + nexus_client = client.create_nexus_client(TestService, endpoint_name) + + activity_id = f"blocking-activity-{uuid.uuid4()}" + op_handle = await nexus_client.start_operation( + TestService.blocking_activity, activity_id, id=str(uuid.uuid4()) + ) + + await op_handle.cancel() + + activity_handle = client.get_activity_handle(activity_id) + + async def check_cancelled(): + op_desc = await op_handle.describe() + assert op_desc.status is NexusOperationExecutionStatus.CANCELED + activity_desc = await activity_handle.describe() + assert activity_desc.status is ActivityExecutionStatus.CANCELED + + await assert_eventually(check_cancelled) + + +async def test_customized_temporal_operation_cancel_activity( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip( + "Standalone Nexus Operation tests don't work with time-skipping server" + ) + + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + + service_handler = TestServiceHandler() + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[service_handler], + activities=[wait_for_cancel_activity], + ): + nexus_client = client.create_nexus_client(TestService, endpoint_name) + + activity_id = f"custom-cancel-activity-{uuid.uuid4()}" + op_handle = await nexus_client.start_operation( + TestService.custom_cancel_activity, activity_id, id=str(uuid.uuid4()) + ) + await service_handler.started_custom_cancel_activity.wait() + + await op_handle.cancel() + + activity_handle = client.get_activity_handle(activity_id) + + async def check_cancelled(): + assert service_handler.custom_cancel_activity_called.is_set() + op_desc = await op_handle.describe() + assert op_desc.status is NexusOperationExecutionStatus.CANCELED + activity_desc = await activity_handle.describe() + assert activity_desc.status is ActivityExecutionStatus.CANCELED + + await assert_eventually(check_cancelled) + + +async def test_temporal_operation_double_start_activity_raises_handler_err( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip( + "Standalone Nexus Operation tests don't work with time-skipping server" + ) + + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[TestServiceHandler()], + activities=[echo_activity], + ): + nexus_client = client.create_nexus_client(TestService, endpoint_name) + + with pytest.raises(NexusOperationFailureError) as err: + await nexus_client.execute_operation( + TestService.double_start_activity, + Input(value="test", task_queue=task_queue), + id=str(uuid.uuid4()), + ) + + assert isinstance(err.value.cause, nexusrpc.HandlerError) + assert err.value.cause.type == HandlerErrorType.BAD_REQUEST + assert ( + "Only one async operation can be started per operation handler invocation" + in err.value.cause.message + ) + + @dataclass class TemporalOperationOverloadTestValue: value: int From 6217763176eb1cb8915630235851387d176e5b2f Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 29 May 2026 14:20:23 -0700 Subject: [PATCH 2/8] Update error message when an invalid link is used --- temporalio/nexus/_link_conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/temporalio/nexus/_link_conversion.py b/temporalio/nexus/_link_conversion.py index 0e4e6fb30..dfdc49273 100644 --- a/temporalio/nexus/_link_conversion.py +++ b/temporalio/nexus/_link_conversion.py @@ -111,7 +111,7 @@ def temporal_link_to_nexus_link( return activity_link_to_nexus_link(temporal_link.activity) case "batch_job": - raise NotImplementedError("only workflow links are supported") + raise NotImplementedError("batch_job links are not supported") case None: logger.warning("Invalid Temporal link: missing variant") From d071fa5497d8757f26e20352a00e089e936f5fa7 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 29 May 2026 14:44:14 -0700 Subject: [PATCH 3/8] Update nexus operation and activity link formats to match new expected server shape --- temporalio/nexus/_link_conversion.py | 56 +++++----------------------- tests/nexus/test_link_conversion.py | 29 +++----------- 2 files changed, 15 insertions(+), 70 deletions(-) diff --git a/temporalio/nexus/_link_conversion.py b/temporalio/nexus/_link_conversion.py index dfdc49273..1b5c8329b 100644 --- a/temporalio/nexus/_link_conversion.py +++ b/temporalio/nexus/_link_conversion.py @@ -20,11 +20,11 @@ logger = logging.getLogger(__name__) _NEXUS_OPERATION_LINK_URL_PATH_REGEX = re.compile( - r"^/namespaces/(?P[^/]+)/nexus-operations/(?P[^/]+)$" + r"^/namespaces/(?P[^/]+)/nexus-operations/(?P[^/]+)/(?P[^/]+)/details$" ) _ACTIVITY_LINK_URL_PATH_REGEX = re.compile( - r"^/namespaces/(?P[^/]+)/activities/(?P[^/]+)$" + r"^/namespaces/(?P[^/]+)/activities/(?P[^/]+)/(?P[^/]+)/details$" ) _WORFKLOW_LINK_URL_PATH_REGEX = re.compile( @@ -42,7 +42,6 @@ class _LinkType(str, Enum): LINK_EVENT_TYPE_PARAM_NAME = "eventType" LINK_REQUEST_ID_PARAM_NAME = "requestID" LINK_REFERENCE_TYPE_PARAM_NAME = "referenceType" -LINK_RUN_ID_PARAM_NAME = "runID" EVENT_REFERENCE_TYPE = "EventReference" REQUEST_ID_REFERENCE_TYPE = "RequestIdReference" @@ -160,18 +159,11 @@ def nexus_operation_to_nexus_link( scheme = "temporal" namespace = urllib.parse.quote(op_link.namespace, safe="") operation_id = urllib.parse.quote(op_link.operation_id, safe="") - path = f"/namespaces/{namespace}/nexus-operations/{operation_id}" - - query_params = "" - if op_link.run_id: - query_params = urllib.parse.urlencode( - { - LINK_RUN_ID_PARAM_NAME: op_link.run_id, - }, - ) + run_id = urllib.parse.quote(op_link.run_id, safe="") + path = f"/namespaces/{namespace}/nexus-operations/{operation_id}/{run_id}/details" # urllib will omit '//' from the url if netloc is empty so we add the scheme manually - url = f"{scheme}://{urllib.parse.urlunparse(('', '', path, '', query_params, ''))}" + url = f"{scheme}://{urllib.parse.urlunparse(('', '', path, '', '', ''))}" return nexusrpc.Link(url=url, type=_LinkType.NEXUS_OPERATION.value) @@ -183,14 +175,10 @@ def activity_link_to_nexus_link( scheme = "temporal" namespace = urllib.parse.quote(activity.namespace, safe="") activity_id = urllib.parse.quote(activity.activity_id, safe="") - path = f"/namespaces/{namespace}/activities/{activity_id}" + run_id = urllib.parse.quote(activity.run_id, safe="") + path = f"/namespaces/{namespace}/activities/{activity_id}/{run_id}/details" - if activity.run_id: - query_params = urllib.parse.urlencode({LINK_RUN_ID_PARAM_NAME: activity.run_id}) - else: - query_params = "" - - url = f"{scheme}://{urllib.parse.urlunparse(('', '', path, '', query_params, ''))}" + url = f"{scheme}://{urllib.parse.urlunparse(('', '', path, '', '', ''))}" return nexusrpc.Link(url=url, type=_LinkType.ACTIVITY.value) @@ -258,24 +246,11 @@ def nexus_link_to_nexus_operation_link( ) return None - query_params = urllib.parse.parse_qs(url.query) - - match query_params.get(LINK_RUN_ID_PARAM_NAME): - case [run_id_param]: - run_id = run_id_param - case [] | None: - run_id = "" - case _: - logger.warning( - f"Invalid Nexus link: {nexus_link}. Expected {LINK_RUN_ID_PARAM_NAME} to have at most 1 value" - ) - return None - groups = match.groupdict() nexus_op_link = temporalio.api.common.v1.Link.NexusOperation( namespace=urllib.parse.unquote(groups["namespace"]), operation_id=urllib.parse.unquote(groups["operation_id"]), - run_id=run_id, + run_id=urllib.parse.unquote(groups["run_id"]), ) return temporalio.api.common.v1.Link(nexus_operation=nexus_op_link) @@ -292,22 +267,11 @@ def nexus_link_to_activity_link( ) return None - query_params = urllib.parse.parse_qs(url.query, keep_blank_values=True) - - match query_params.get(LINK_RUN_ID_PARAM_NAME): - case [run_id_param]: - run_id = run_id_param - case _: - logger.warning( - f"Invalid Nexus link: {nexus_link}. Expected {LINK_RUN_ID_PARAM_NAME} to have exactly 1 value" - ) - return None - groups = match.groupdict() activity_link = temporalio.api.common.v1.Link.Activity( namespace=urllib.parse.unquote(groups["namespace"]), activity_id=urllib.parse.unquote(groups["activity_id"]), - run_id=run_id, + run_id=urllib.parse.unquote(groups["run_id"]), ) return temporalio.api.common.v1.Link(activity=activity_link) diff --git a/tests/nexus/test_link_conversion.py b/tests/nexus/test_link_conversion.py index a875a273d..8bc46b597 100644 --- a/tests/nexus/test_link_conversion.py +++ b/tests/nexus/test_link_conversion.py @@ -222,19 +222,7 @@ def test_link_conversion_workflow_event_to_link_and_back( ), nexusrpc.Link( type=temporalio.api.common.v1.Link.NexusOperation.DESCRIPTOR.full_name, - url="temporal:///namespaces/ns/nexus-operations/op-id?runID=run-id", - ), - ), - ( - temporalio.api.common.v1.Link( - nexus_operation=temporalio.api.common.v1.Link.NexusOperation( - namespace="ns", - operation_id="op-id", - ) - ), - nexusrpc.Link( - type=temporalio.api.common.v1.Link.NexusOperation.DESCRIPTOR.full_name, - url="temporal:///namespaces/ns/nexus-operations/op-id", + url="temporal:///namespaces/ns/nexus-operations/op-id/run-id/details", ), ), ( @@ -242,11 +230,12 @@ def test_link_conversion_workflow_event_to_link_and_back( nexus_operation=temporalio.api.common.v1.Link.NexusOperation( namespace="ns", operation_id="op/id", + run_id="run/id", ) ), nexusrpc.Link( type=temporalio.api.common.v1.Link.NexusOperation.DESCRIPTOR.full_name, - url="temporal:///namespaces/ns/nexus-operations/op%2Fid", + url="temporal:///namespaces/ns/nexus-operations/op%2Fid/run%2Fid/details", ), ), ], @@ -277,21 +266,13 @@ def test_link_conversion_nexus_operation_to_link_and_back( ) -def test_nexus_operation_link_with_duplicate_run_id_is_ignored(): - link = nexusrpc.Link( - type=temporalio.api.common.v1.Link.NexusOperation.DESCRIPTOR.full_name, - url="temporal:///namespaces/ns/nexus-operations/op-id?runID=one&runID=two", - ) - assert temporalio.nexus._link_conversion.nexus_link_to_temporal_link(link) is None - - @pytest.mark.parametrize( ["link", "expected_link"], [ ( nexusrpc.Link( type=temporalio.api.common.v1.Link.Activity.DESCRIPTOR.full_name, - url="temporal:///namespaces/ns/activities/act-id?runID=run-id", + url="temporal:///namespaces/ns/activities/act-id/run-id/details", ), temporalio.api.common.v1.Link( activity=temporalio.api.common.v1.Link.Activity( @@ -304,7 +285,7 @@ def test_nexus_operation_link_with_duplicate_run_id_is_ignored(): ( nexusrpc.Link( type=temporalio.api.common.v1.Link.Activity.DESCRIPTOR.full_name, - url="temporal:///namespaces/ns%2F/activities/act-id%2F?runID=run-id%3E", + url="temporal:///namespaces/ns%2F/activities/act-id%2F/run-id%3E/details", ), temporalio.api.common.v1.Link( activity=temporalio.api.common.v1.Link.Activity( From c8217393a5f7924ce9e7c0f11f809c2063ec88cd Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 29 May 2026 16:03:53 -0700 Subject: [PATCH 4/8] simplify logic in add outbound links methods --- temporalio/nexus/_operation_context.py | 99 ++++++++++++-------------- 1 file changed, 46 insertions(+), 53 deletions(-) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 542df7b5d..e8d9c1676 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -43,9 +43,8 @@ ) from ._link_conversion import ( - activity_link_to_nexus_link, nexus_link_to_temporal_link, - workflow_event_to_nexus_link, + temporal_link_to_nexus_link, workflow_execution_started_event_link_from_workflow_handle, ) from ._token import WorkflowHandle @@ -253,71 +252,65 @@ def _get_links( def _add_outbound_workflow_links( self, workflow_handle: temporalio.client.WorkflowHandle[Any, Any] ): - # If links were not sent in StartWorkflowExecutionResponse then construct them. - wf_event_links: list[temporalio.api.common.v1.Link.WorkflowEvent] = [] - try: - if isinstance( - workflow_handle._start_workflow_response, - temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse, - ): - if workflow_handle._start_workflow_response.HasField("link"): - if link := workflow_handle._start_workflow_response.link: - if link.HasField("workflow_event"): - wf_event_links.append(link.workflow_event) - if not wf_event_links: - wf_event_links = [ - workflow_execution_started_event_link_from_workflow_handle( - workflow_handle, - self.nexus_context.request_id, - ) - ] - self.nexus_context.outbound_links.extend( - workflow_event_to_nexus_link(link) for link in wf_event_links + response = workflow_handle._start_workflow_response + if isinstance( + response, temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse + ) and response.HasField("link"): + link = response.link + elif isinstance( + response, + temporalio.api.workflowservice.v1.SignalWithStartWorkflowExecutionResponse, + ) and response.HasField("signal_link"): + link = response.signal_link + else: + # If a link was not sent in response then construct it. + link = temporalio.api.common.v1.Link( + workflow_event=workflow_execution_started_event_link_from_workflow_handle( + workflow_handle, + self.nexus_context.request_id, + ) ) + + try: + if nexus_link := temporal_link_to_nexus_link(link): + self.nexus_context.outbound_links.append(nexus_link) except Exception as e: logger.warning( - f"Failed to create WorkflowExecutionStarted event links for workflow {workflow_handle}: {e}" + f"Failed to create event links for workflow {workflow_handle}: {e}" ) - return workflow_handle def _add_outbound_activity_links( self, activity_handle: temporalio.client.ActivityHandle[Any] ): - activity_links: list[temporalio.api.common.v1.Link.Activity] = [] - try: - if isinstance( - activity_handle._start_activity_response, - temporalio.api.workflowservice.v1.StartActivityExecutionResponse, - ): - if activity_handle._start_activity_response.HasField("link"): - if activity_handle._start_activity_response.link.HasField( - "activity" - ): - activity_links.append( - activity_handle._start_activity_response.link.activity - ) - if not activity_links: - activity_run_id = activity_handle.run_id - if activity_run_id is None: - raise ValueError( - f"Activity handle {activity_handle} has no run ID. " - "Cannot create Activity link." - ) - activity_links.append( - temporalio.api.common.v1.Link.Activity( - namespace=activity_handle._client.namespace, - activity_id=activity_handle.id, - run_id=activity_run_id, - ) + + if ( + activity_handle._start_activity_response + and activity_handle._start_activity_response.HasField("link") + ): + link = activity_handle._start_activity_response.link + else: + activity_run_id = activity_handle.run_id + if activity_run_id is None: + logger.warning( + "Failed to create Activity link. " + f"Activity handle {activity_handle} has no run ID. " ) - self.nexus_context.outbound_links.extend( - activity_link_to_nexus_link(link) for link in activity_links + return + link = temporalio.api.common.v1.Link( + activity=temporalio.api.common.v1.Link.Activity( + namespace=activity_handle._client.namespace, + activity_id=activity_handle.id, + run_id=activity_run_id, + ), ) + + try: + if nexus_link := temporal_link_to_nexus_link(link): + self.nexus_context.outbound_links.append(nexus_link) except Exception as e: logger.warning( f"Failed to create Activity link for activity {activity_handle}: {e}" ) - return activity_handle class WorkflowRunOperationContext(StartOperationContext): From f792a2baca603b540af1b6150b11c061a1fdf6fe Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 29 May 2026 16:12:36 -0700 Subject: [PATCH 5/8] update docstrings --- temporalio/nexus/_operation_context.py | 2 +- temporalio/nexus/_temporal_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index e8d9c1676..b2e3b11b6 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -666,7 +666,7 @@ async def _start_nexus_backing_workflow( priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: temporalio.common.VersioningOverride | None = None, ) -> WorkflowHandle[ReturnType]: - # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, + # We must pass nexus_completion_callbacks, links, and request_id, # but these are deliberately not exposed in overloads, hence the type-check # violation. diff --git a/temporalio/nexus/_temporal_client.py b/temporalio/nexus/_temporal_client.py index 128f1aa6f..b423326bb 100644 --- a/temporalio/nexus/_temporal_client.py +++ b/temporalio/nexus/_temporal_client.py @@ -607,7 +607,7 @@ async def start_activity( See :py:meth:`temporalio.client.Client.start_activity` for all other arguments. """ with self._reserve_async_start(): - # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, + # We must pass nexus_completion_callbacks, links, and request_id, # but these are deliberately not exposed in overloads, hence the type-check # violation. From 611c8ec12c44970308534ae142eec764e202c3aa Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 29 May 2026 16:26:42 -0700 Subject: [PATCH 6/8] minor test cleanup --- tests/nexus/test_operation_token.py | 23 +++++++-------- tests/nexus/test_temporal_operation.py | 39 +++++++++++++------------- 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/tests/nexus/test_operation_token.py b/tests/nexus/test_operation_token.py index 028bc67cb..d443e3479 100644 --- a/tests/nexus/test_operation_token.py +++ b/tests/nexus/test_operation_token.py @@ -1,4 +1,3 @@ -import base64 import json from typing import Any @@ -8,20 +7,15 @@ OperationToken, OperationTokenType, WorkflowHandle, + _base64url_decode_no_padding, + _base64url_encode_no_padding, ) def _encode_json_token(value: Any) -> str: - return _encode_bytes(json.dumps(value, separators=(",", ":")).encode("utf-8")) - - -def _encode_bytes(value: bytes) -> str: - return base64.urlsafe_b64encode(value).decode("utf-8").rstrip("=") - - -def _decode_bytes(value: str) -> bytes: - padding = "=" * (-len(value) % 4) - return base64.urlsafe_b64decode(value + padding) + return _base64url_encode_no_padding( + json.dumps(value, separators=(",", ":")).encode("utf-8") + ) def test_operation_token_encode_decode_round_trip(): @@ -65,7 +59,7 @@ def test_operation_token_activity_encode_uses_activity_id_and_omits_workflow_id( activity_id="activity-id", ).encode() - assert json.loads(_decode_bytes(token)) == { + assert json.loads(_base64url_decode_no_padding(token)) == { "t": 2, "ns": "default", "aid": "activity-id", @@ -166,7 +160,10 @@ def test_operation_token_decode_accepts_valid_tokens( [ ("", "invalid token: token is empty"), ("not+a-base64url-token", "failed to decode token as base64url"), - (_encode_bytes(b"not json"), "failed to unmarshal operation token"), + ( + _base64url_encode_no_padding(b"not json"), + "failed to unmarshal operation token", + ), (_encode_json_token(["not", "a", "dict"]), "expected dict"), ( _encode_json_token({"ns": "default", "wid": "workflow-id"}), diff --git a/tests/nexus/test_temporal_operation.py b/tests/nexus/test_temporal_operation.py index 370d4d2c5..f16da76de 100644 --- a/tests/nexus/test_temporal_operation.py +++ b/tests/nexus/test_temporal_operation.py @@ -35,23 +35,6 @@ from tests.helpers.nexus import make_nexus_endpoint_name -class FakeNexusTaskCancellation(OperationTaskCancellation): - def is_cancelled(self) -> bool: - return False - - def cancellation_reason(self) -> str | None: - return None - - def wait_until_cancelled_sync(self, timeout: float | None = None) -> bool: - return False - - async def wait_until_cancelled(self) -> None: - return None - - def cancel(self, _reason: str) -> bool: - return False - - @dataclass class Input: value: str @@ -449,10 +432,21 @@ async def test_temporal_operation_start_workflow( async def test_temporal_operation_cancel_rejects_unknown_tokens(): - service_handler = TestServiceHandler() + class FakeNexusTaskCancellation(OperationTaskCancellation): + def is_cancelled(self) -> bool: + return False - # Use a factory style operation form the handler to allow calling cancel directly - op_handler = service_handler.custom_cancel() + def cancellation_reason(self) -> str | None: + return None + + def wait_until_cancelled_sync(self, timeout: float | None = None) -> bool: + return False + + async def wait_until_cancelled(self) -> None: + return None + + def cancel(self, _reason: str) -> bool: + return False cancel_ctx = CancelOperationContext( service="TestService", @@ -461,6 +455,11 @@ async def test_temporal_operation_cancel_rejects_unknown_tokens(): task_cancellation=FakeNexusTaskCancellation(), ) + service_handler = TestServiceHandler() + + # Use a factory style operation form the handler to allow calling cancel directly + op_handler = service_handler.custom_cancel() + # Invalid token type token = OperationToken(type=30, namespace="default") # type: ignore with pytest.raises(nexusrpc.HandlerError) as err: From 3f66527c400187a29ecd15a08149ae0c0d4eda0c Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 29 May 2026 16:31:15 -0700 Subject: [PATCH 7/8] remove type error that was added as a result of an incorrect required param --- tests/nexus/test_nexus_type_errors.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/nexus/test_nexus_type_errors.py b/tests/nexus/test_nexus_type_errors.py index 749f1e841..ea3d2c9b7 100644 --- a/tests/nexus/test_nexus_type_errors.py +++ b/tests/nexus/test_nexus_type_errors.py @@ -252,13 +252,6 @@ async def my_temporal_operation( ) return activity_result_5 if input == 12: - # omitting the required start_to_close_timeout is a type error - # assert-type-error-pyright: 'No overloads for "start_activity" match' - return await client.start_activity( # type: ignore - my_no_arg_activity, - id="activity-missing-timeout", - ) - if input == 13: # assert-type-error-pyright: 'No overloads for "start_activity" match' return await client.start_activity( # type: ignore my_one_arg_activity, From 5004be7b7026da898d1c72e6482cd5dbbcd4b133 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 29 May 2026 17:16:01 -0700 Subject: [PATCH 8/8] Add run_id to operation token, include it in activity tokens and use it in cancel_activity --- temporalio/nexus/_operation_context.py | 3 +- temporalio/nexus/_operation_handlers.py | 18 +++++++--- temporalio/nexus/_temporal_client.py | 1 + temporalio/nexus/_token.py | 10 ++++++ tests/nexus/test_operation_token.py | 46 ++++++++++++++++++++++--- 5 files changed, 66 insertions(+), 12 deletions(-) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index b2e3b11b6..b11e353f9 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -64,7 +64,7 @@ ContextVar("temporal-cancel-operation-context") ) -# A Nexus start handler might start zero or async Temporal actions as usual using a Temporal client. In +# A Nexus start handler might start zero or more async Temporal actions as usual using a Temporal client. In # addition, it may start one "nexus-backing" async Temporal action, using # WorkflowRunOperationContext.start_workflow or methods from TemporalNexusClient. This context is active while the latter is being done. # It is thus a narrower context than _temporal_start_operation_context. @@ -282,7 +282,6 @@ def _add_outbound_workflow_links( def _add_outbound_activity_links( self, activity_handle: temporalio.client.ActivityHandle[Any] ): - if ( activity_handle._start_activity_response and activity_handle._start_activity_response.HasField("link") diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index efb895406..239e2996d 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -150,7 +150,10 @@ class CancelActivityOptions: """ activity_id: str - """The ID of the activity to cancel.""" + """The activity ID of the activity to cancel.""" + + run_id: str | None + """The run ID of the activity to cancel.""" class TemporalNexusOperationHandler(OperationHandler[InputT, OutputT], ABC): @@ -223,7 +226,8 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None: type=HandlerErrorType.NOT_FOUND, ) activity_cancel_opts = CancelActivityOptions( - activity_id=operation_token.activity_id + activity_id=operation_token.activity_id, + run_id=operation_token.run_id, ) await self.cancel_activity(cancel_ctx, activity_cancel_opts) @@ -246,9 +250,13 @@ async def cancel_activity( self, ctx: TemporalNexusCancelOperationContext, # pyright: ignore[reportUnusedParameter] options: CancelActivityOptions, - ): - """Requests cancellation of the standalone activity identified by activity_id.""" + ) -> None: + """Requests cancellation of the standalone activity identified by activity_id. + + .. warning:: + This API is experimental and unstable. + """ activity_handle = temporalio.nexus.client().get_activity_handle( - options.activity_id + options.activity_id, run_id=options.run_id ) await activity_handle.cancel() diff --git a/temporalio/nexus/_temporal_client.py b/temporalio/nexus/_temporal_client.py index b423326bb..22f063353 100644 --- a/temporalio/nexus/_temporal_client.py +++ b/temporalio/nexus/_temporal_client.py @@ -647,6 +647,7 @@ async def start_activity( type=OperationTokenType.ACTIVITY, namespace=self._temporal_context.client.namespace, activity_id=activity_handle.id, + run_id=activity_handle.run_id, ) return TemporalOperationResult.async_token(activity_token.encode()) diff --git a/temporalio/nexus/_token.py b/temporalio/nexus/_token.py index b9de54e82..e804aafca 100644 --- a/temporalio/nexus/_token.py +++ b/temporalio/nexus/_token.py @@ -30,6 +30,7 @@ class OperationToken: namespace: str workflow_id: str | None = None activity_id: str | None = None + run_id: str | None = None def encode(self) -> str: """Convert handle to a base64url-encoded token string.""" @@ -41,6 +42,8 @@ def encode(self) -> str: token_details["wid"] = self.workflow_id if self.activity_id is not None: token_details["aid"] = self.activity_id + if self.run_id is not None: + token_details["rid"] = self.run_id if self.version is not None: token_details["v"] = self.version return _base64url_encode_no_padding( @@ -104,6 +107,12 @@ def decode(cls, token: str) -> Self: f"invalid token: expected activity id to be a string, got {type(activity_id)}" ) + run_id = token_details.get("rid") + if run_id is not None and not isinstance(run_id, str): + raise TypeError( + f"invalid token: expected run id to be a string, got {type(run_id)}" + ) + if token_type == OperationTokenType.ACTIVITY and not activity_id: raise TypeError( "invalid token: expected non-empty activity id for token type `ACTIVITY`" @@ -121,6 +130,7 @@ def decode(cls, token: str) -> Self: namespace=namespace, workflow_id=workflow_id, activity_id=activity_id, + run_id=run_id, version=version, ) diff --git a/tests/nexus/test_operation_token.py b/tests/nexus/test_operation_token.py index d443e3479..4431e7d97 100644 --- a/tests/nexus/test_operation_token.py +++ b/tests/nexus/test_operation_token.py @@ -40,6 +40,7 @@ def test_operation_token_activity_encode_decode_round_trip(): type=OperationTokenType.ACTIVITY, namespace="default", activity_id="activity-id", + run_id="run-id", version=0, ).encode() @@ -48,6 +49,7 @@ def test_operation_token_activity_encode_decode_round_trip(): type=OperationTokenType.ACTIVITY, namespace="default", activity_id="activity-id", + run_id="run-id", version=0, ) @@ -112,11 +114,14 @@ def test_workflow_handle_to_from_token_round_trip(): ), # Activity tokens ( - _encode_json_token({"t": 2, "ns": "default", "aid": "activity-id"}), + _encode_json_token( + {"t": 2, "ns": "default", "aid": "activity-id", "rid": "run-id"} + ), OperationToken( type=OperationTokenType.ACTIVITY, namespace="default", activity_id="activity-id", + run_id="run-id", ), ), ( @@ -129,20 +134,41 @@ def test_workflow_handle_to_from_token_round_trip(): ), ( _encode_json_token( - {"t": 2, "ns": "default", "aid": "activity-id", "v": None} + {"t": 2, "ns": "", "aid": "activity-id", "rid": "run-id"} + ), + OperationToken( + type=OperationTokenType.ACTIVITY, + namespace="", + activity_id="activity-id", + run_id="run-id", + ), + ), + ( + _encode_json_token( + { + "t": 2, + "ns": "default", + "aid": "activity-id", + "rid": "run-id", + "v": None, + } ), OperationToken( type=OperationTokenType.ACTIVITY, namespace="default", activity_id="activity-id", + run_id="run-id", ), ), ( - _encode_json_token({"t": 2, "ns": "default", "aid": "activity-id", "v": 0}), + _encode_json_token( + {"t": 2, "ns": "default", "aid": "activity-id", "rid": "run-id", "v": 0} + ), OperationToken( type=OperationTokenType.ACTIVITY, namespace="default", activity_id="activity-id", + run_id="run-id", version=0, ), ), @@ -217,12 +243,22 @@ def test_operation_token_decode_accepts_valid_tokens( "expected activity id to be a string", ), ( - _encode_json_token({"t": 2, "aid": "activity-id"}), + _encode_json_token({"t": 2, "aid": "activity-id", "rid": 123}), + "expected run id to be a string", + ), + ( + _encode_json_token({"t": 2, "aid": "activity-id", "rid": "run-id"}), "expected namespace to be a string", ), ( _encode_json_token( - {"t": 2, "ns": "default", "aid": "activity-id", "v": "0"} + { + "t": 2, + "ns": "default", + "aid": "activity-id", + "rid": "run-id", + "v": "0", + } ), "expected version to be an int or null", ),