diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index c647b19a8..f88ec22c2 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -22,6 +22,7 @@ wait_for_worker_shutdown_sync, ) from ._token import WorkflowHandle +from ._util import is_async_callable __all__ = ( "workflow_run_operation", @@ -32,6 +33,7 @@ "client", "in_operation", "info", + "is_async_callable", "is_worker_shutdown", "logger", "metric_meter", diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index 48d3ad644..4fc4b352b 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -7,6 +7,7 @@ from collections.abc import Awaitable, Callable from typing import ( Any, + TypeGuard, TypeVar, ) @@ -153,8 +154,10 @@ def set_operation_factory( # # Copyright (c) 2024 Anthropic, PBC. # +# Modified to use TypeGuard. +# # This file is licensed under the MIT License. -def is_async_callable(obj: Any) -> bool: +def is_async_callable(obj: Any) -> TypeGuard[Callable[..., Awaitable[Any]]]: """Return True if ``obj`` is an async callable. Supports partials of async callable class instances. diff --git a/temporalio/worker/_nexus.py b/temporalio/worker/_nexus.py index eb54dde30..698ae53d4 100644 --- a/temporalio/worker/_nexus.py +++ b/temporalio/worker/_nexus.py @@ -8,20 +8,16 @@ import threading from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass -from functools import reduce from typing import ( Any, NoReturn, ParamSpec, TypeVar, - cast, ) import nexusrpc.handler -from nexusrpc import LazyValue -from nexusrpc.handler import CancelOperationContext, Handler, StartOperationContext +from nexusrpc.handler import CancelOperationContext, StartOperationContext -import temporalio.api.common.v1 import temporalio.api.nexus.v1 import temporalio.bridge.proto.nexus import temporalio.bridge.worker @@ -40,11 +36,9 @@ from temporalio.service import RPCError, RPCStatusCode from ._interceptor import ( - ExecuteNexusOperationCancelInput, - ExecuteNexusOperationStartInput, Interceptor, - NexusOperationInboundInterceptor, ) +from ._nexus_handler import _TemporalNexusHandler _TEMPORAL_FAILURE_PROTO_TYPE = "temporal.api.failure.v1.Failure" @@ -77,13 +71,12 @@ def __init__( self._task_queue = task_queue self._metric_meter = metric_meter - middleware = _NexusMiddlewareForInterceptors(interceptors) # If an executor is provided, we wrap the executor with one that will # copy the contextvars.Context to the thread on submit handler_executor = _ContextPropagatingExecutor(executor) if executor else None - self._handler = Handler( - service_handlers, handler_executor, middleware=[middleware] + self._handler = _TemporalNexusHandler( + service_handlers, interceptors, data_converter, handler_executor ) self._data_converter = data_converter @@ -360,16 +353,8 @@ async def _start_operation( _runtime_metric_meter=self._metric_meter, _worker_shutdown_event=self._worker_shutdown_event, ).set() - input = LazyValue( - serializer=_DummyPayloadSerializer( - data_converter=self._data_converter, - payload=start_request.payload, - ), - headers={}, - stream=None, - ) try: - result = await self._handler.start_operation(ctx, input) + result = await self._handler.start_operation(ctx, start_request.payload) links = [ temporalio.api.nexus.v1.Link(url=link.url, type=link.type) for link in ctx.outbound_links @@ -415,45 +400,6 @@ async def _start_operation( return response -@dataclass -class _DummyPayloadSerializer: - data_converter: temporalio.converter.DataConverter - payload: temporalio.api.common.v1.Payload - - async def serialize(self, value: Any) -> nexusrpc.Content: # type:ignore[reportUnusedParameter] - raise NotImplementedError( - "The serialize method of the Serializer is not used by handlers" - ) - - async def deserialize( - self, - content: nexusrpc.Content, # type:ignore[reportUnusedParameter] - as_type: type[Any] | None = None, - ) -> Any: - payload = self.payload - if self.data_converter.payload_codec: - try: - [payload] = await self.data_converter.payload_codec.decode([payload]) - except Exception as err: - raise nexusrpc.HandlerError( - "Payload codec failed to decode Nexus operation input", - type=nexusrpc.HandlerErrorType.INTERNAL, - ) from err - - try: - [input] = self.data_converter.payload_converter.from_payloads( - [payload], - type_hints=[as_type] if as_type else None, - ) - return input - except Exception as err: - raise nexusrpc.HandlerError( - "Payload converter failed to decode Nexus operation input", - type=nexusrpc.HandlerErrorType.BAD_REQUEST, - retryable_override=False, - ) from err - - def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError: # Based on sdk-typescript's convertKnownErrors: # https://github.com/temporalio/sdk-typescript/blob/nexus/packages/worker/src/nexus.ts @@ -569,69 +515,6 @@ def cancel(self, reason: str) -> bool: return True -class _NexusOperationHandlerForInterceptor( - nexusrpc.handler.MiddlewareSafeOperationHandler -): - def __init__(self, next_interceptor: NexusOperationInboundInterceptor): - self._next_interceptor = next_interceptor - - async def start( - self, ctx: nexusrpc.handler.StartOperationContext, input: Any - ) -> ( - nexusrpc.handler.StartOperationResultSync[Any] - | nexusrpc.handler.StartOperationResultAsync - ): - return await self._next_interceptor.execute_nexus_operation_start( - ExecuteNexusOperationStartInput(ctx, input) - ) - - async def cancel( - self, ctx: nexusrpc.handler.CancelOperationContext, token: str - ) -> None: - return await self._next_interceptor.execute_nexus_operation_cancel( - ExecuteNexusOperationCancelInput(ctx, token) - ) - - -class _NexusOperationInboundInterceptorImpl(NexusOperationInboundInterceptor): - def __init__(self, handler: nexusrpc.handler.MiddlewareSafeOperationHandler): # pyright: ignore[reportMissingSuperCall] - self._handler = handler - - async def execute_nexus_operation_start( - self, input: ExecuteNexusOperationStartInput - ) -> ( - nexusrpc.handler.StartOperationResultSync[Any] - | nexusrpc.handler.StartOperationResultAsync - ): - return await self._handler.start(input.ctx, input.input) - - async def execute_nexus_operation_cancel( - self, input: ExecuteNexusOperationCancelInput - ) -> None: - return await self._handler.cancel(input.ctx, input.token) - - -class _NexusMiddlewareForInterceptors(nexusrpc.handler.OperationHandlerMiddleware): - def __init__(self, interceptors: Sequence[Interceptor]) -> None: - self._interceptors = interceptors - - def intercept( - self, - ctx: nexusrpc.handler.OperationContext, - next: nexusrpc.handler.MiddlewareSafeOperationHandler, - ) -> nexusrpc.handler.MiddlewareSafeOperationHandler: - inbound = reduce( - lambda impl, _next: _next.intercept_nexus_operation(impl), - reversed(self._interceptors), - cast( - NexusOperationInboundInterceptor, - _NexusOperationInboundInterceptorImpl(next), - ), - ) - - return _NexusOperationHandlerForInterceptor(inbound) - - _P = ParamSpec("_P") _T = TypeVar("_T") diff --git a/temporalio/worker/_nexus_handler.py b/temporalio/worker/_nexus_handler.py new file mode 100644 index 000000000..db004a380 --- /dev/null +++ b/temporalio/worker/_nexus_handler.py @@ -0,0 +1,204 @@ +"""Temporal Nexus handler. + +Replaces nexusrpc.handler.Handler with a Temporal-specific implementation that +uses Temporal interceptors instead of nexusrpc middleware, and deserializes +Nexus operation input using the Temporal data converter directly (without the +nexusrpc Serializer/LazyValue/Content abstractions). +""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +from collections.abc import Awaitable, Mapping, Sequence +from functools import reduce +from typing import Any, cast + +import nexusrpc +from nexusrpc.handler import ( + CancelOperationContext, + OperationHandler, + StartOperationContext, + StartOperationResultAsync, + StartOperationResultSync, +) +from nexusrpc.handler._core import ServiceHandler + +import temporalio.api.common.v1 +import temporalio.converter + +from temporalio.worker import ( + ExecuteNexusOperationCancelInput, + ExecuteNexusOperationStartInput, + Interceptor, + NexusOperationInboundInterceptor, +) +from temporalio.nexus import is_async_callable + +OperationHandlerResult = StartOperationResultSync[Any] | StartOperationResultAsync + + +class _NexusOperationInboundInterceptorImpl(NexusOperationInboundInterceptor): + """Terminal interceptor that delegates to the actual OperationHandler.""" + + def __init__( # pyright: ignore[reportMissingSuperCall] + self, + handler: OperationHandler[Any, Any], + executor: concurrent.futures.Executor | None, + ) -> None: + self._handler = handler + self._executor = executor + + async def execute_nexus_operation_start( + self, input: ExecuteNexusOperationStartInput + ) -> StartOperationResultSync[Any] | StartOperationResultAsync: + if is_async_callable(self._handler.start): + return await self._handler.start(input.ctx, input.input) + else: + assert self._executor + return await cast( + Awaitable[OperationHandlerResult], + asyncio.get_event_loop().run_in_executor( + self._executor, self._handler.start, input.ctx, input.input + ), + ) + + async def execute_nexus_operation_cancel( + self, input: ExecuteNexusOperationCancelInput + ) -> None: + if is_async_callable(self._handler.cancel): + await self._handler.cancel(input.ctx, input.token) + else: + assert self._executor + self._executor.submit(self._handler.cancel, input.ctx, input.token).result() + + +class _TemporalNexusHandler: # type:ignore[reportUnusedClass] + """Temporal-specific Nexus handler. + + Replaces nexusrpc.handler.Handler. Uses Temporal interceptors instead of + nexusrpc middleware, and deserializes input using the Temporal data + converter directly. + """ + + def __init__( + self, + user_service_handlers: Sequence[Any], + interceptors: Sequence[Interceptor], + data_converter: temporalio.converter.DataConverter, + executor: concurrent.futures.Executor | None, + ) -> None: + self._interceptors = interceptors + self._data_converter = data_converter + self._executor = executor + self._service_handlers = self._register_service_handlers(user_service_handlers) + if not self._executor: + self._validate_all_operation_handlers_are_async() + + def _register_service_handlers( + self, user_service_handlers: Sequence[Any] + ) -> Mapping[str, ServiceHandler]: + service_handlers: dict[str, ServiceHandler] = {} + for sh in user_service_handlers: + if isinstance(sh, type): + raise TypeError( + f"Expected a service instance, but got a class: {type(sh)}. " + "Nexus service handlers must be supplied as instances, not classes." + ) + if not isinstance(sh, ServiceHandler): + sh = ServiceHandler.from_user_instance(sh) + if sh.service.name in service_handlers: + raise RuntimeError( + f"Service '{sh.service.name}' has already been registered." + ) + service_handlers[sh.service.name] = sh + return service_handlers + + def _get_service_handler(self, service_name: str) -> ServiceHandler: + service = self._service_handlers.get(service_name) + if service is None: + raise nexusrpc.HandlerError( + f"No handler for service '{service_name}'.", + type=nexusrpc.HandlerErrorType.NOT_FOUND, + ) + return service + + def _validate_all_operation_handlers_are_async(self) -> None: + for service_handler in self._service_handlers.values(): + for op_handler in service_handler.operation_handlers.values(): + for method in [op_handler.start, op_handler.cancel]: + if not is_async_callable(method): + raise RuntimeError( + f"Operation handler method {method} is not an `async def` method, " + f"but you have not supplied an executor." + ) + + async def start_operation( + self, + ctx: StartOperationContext, + payload: temporalio.api.common.v1.Payload, + ) -> StartOperationResultSync[Any] | StartOperationResultAsync: + service_handler = self._get_service_handler(ctx.service) + op_handler = service_handler.get_operation_handler(ctx.operation) + op_defn = service_handler.service.operation_definitions[ctx.operation] + + deserialized_input = await self._deserialize_nexus_input( + payload, op_defn.input_type + ) + + inbound = self._build_interceptor_chain(op_handler) + return await inbound.execute_nexus_operation_start( + ExecuteNexusOperationStartInput(ctx, deserialized_input) + ) + + async def cancel_operation(self, ctx: CancelOperationContext, token: str) -> None: + service_handler = self._get_service_handler(ctx.service) + op_handler = service_handler.get_operation_handler(ctx.operation) + + inbound = self._build_interceptor_chain(op_handler) + return await inbound.execute_nexus_operation_cancel( + ExecuteNexusOperationCancelInput(ctx, token) + ) + + def _build_interceptor_chain( + self, op_handler: OperationHandler[Any, Any] + ) -> NexusOperationInboundInterceptor: + return reduce( + lambda impl, interceptor: interceptor.intercept_nexus_operation(impl), + reversed(self._interceptors), + cast( + NexusOperationInboundInterceptor, + _NexusOperationInboundInterceptorImpl(op_handler, self._executor), + ), + ) + + async def _deserialize_nexus_input( + self, + payload: temporalio.api.common.v1.Payload, + input_type: type[Any] | None, + ) -> Any: + """Deserialize a Nexus operation input payload using the Temporal data converter. + + Applies the payload codec (if configured) and then the payload converter. + """ + if self._data_converter.payload_codec: + try: + [payload] = await self._data_converter.payload_codec.decode([payload]) + except Exception as err: + raise nexusrpc.HandlerError( + "Payload codec failed to decode Nexus operation input", + type=nexusrpc.HandlerErrorType.INTERNAL, + ) from err + + try: + [result] = self._data_converter.payload_converter.from_payloads( + [payload], + type_hints=[input_type] if input_type else None, + ) + return result + except Exception as err: + raise nexusrpc.HandlerError( + "Payload converter failed to decode Nexus operation input", + type=nexusrpc.HandlerErrorType.BAD_REQUEST, + retryable_override=False, + ) from err