diff --git a/temporalio/testing/_workflow.py b/temporalio/testing/_workflow.py index 9966df76f..979222dea 100644 --- a/temporalio/testing/_workflow.py +++ b/temporalio/testing/_workflow.py @@ -15,6 +15,8 @@ import google.protobuf.empty_pb2 from typing_extensions import Self +import temporalio.api.nexus.v1 +import temporalio.api.operatorservice.v1 import temporalio.api.testservice.v1 import temporalio.bridge.testing import temporalio.client @@ -401,6 +403,48 @@ def supports_time_skipping(self) -> bool: """Whether this environment supports time skipping.""" return False + async def create_nexus_endpoint( + self, endpoint_name: str, task_queue: str + ) -> temporalio.api.nexus.v1.Endpoint: + """Create a Nexus endpoint with the given name and task queue. + + Args: + endpoint_name: The name of the Nexus endpoint to create. + task_queue: The task queue to associate with the endpoint. + + Returns: + The created Nexus endpoint. + """ + response = await self._client.operator_service.create_nexus_endpoint( + temporalio.api.operatorservice.v1.CreateNexusEndpointRequest( + spec=temporalio.api.nexus.v1.EndpointSpec( + name=endpoint_name, + target=temporalio.api.nexus.v1.EndpointTarget( + worker=temporalio.api.nexus.v1.EndpointTarget.Worker( + namespace=self._client.namespace, + task_queue=task_queue, + ) + ), + ) + ) + ) + return response.endpoint + + async def delete_nexus_endpoint( + self, endpoint: temporalio.api.nexus.v1.Endpoint + ) -> None: + """Delete a Nexus endpoint. + + Args: + endpoint: The Nexus endpoint to delete. + """ + await self._client.operator_service.delete_nexus_endpoint( + temporalio.api.operatorservice.v1.DeleteNexusEndpointRequest( + id=endpoint.id, + version=endpoint.version, + ) + ) + @contextmanager def auto_time_skipping_disabled(self) -> Iterator[None]: """Disable any automatic time skipping if this is a time-skipping diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 7a19afcc8..84aa5646f 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -102,7 +102,7 @@ ResearchManager, ) from tests.helpers import assert_eventually, new_worker -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name def hello_mock_model(): @@ -489,7 +489,9 @@ async def test_nexus_tool_workflow( NexusToolsWorkflow, nexus_service_handlers=[WeatherServiceHandler()], ) as worker: - await create_nexus_endpoint(worker.task_queue, client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(worker.task_queue), worker.task_queue + ) workflow_handle = await client.start_workflow( NexusToolsWorkflow.run, diff --git a/tests/contrib/opentelemetry/test_opentelemetry.py b/tests/contrib/opentelemetry/test_opentelemetry.py index 7e21c8935..2c3293fd6 100644 --- a/tests/contrib/opentelemetry/test_opentelemetry.py +++ b/tests/contrib/opentelemetry/test_opentelemetry.py @@ -38,7 +38,7 @@ from temporalio.testing import WorkflowEnvironment from temporalio.worker import UnsandboxedWorkflowRunner, Worker from tests.helpers import LogCapturer -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name @dataclass @@ -475,7 +475,7 @@ async def test_opentelemetry_tracing_nexus(client: Client, env: WorkflowEnvironm client = Client(**client_config) task_queue = f"task-queue-{uuid.uuid4()}" - await create_nexus_endpoint(task_queue, client) + await env.create_nexus_endpoint(make_nexus_endpoint_name(task_queue), task_queue) async with Worker( client, task_queue=task_queue, diff --git a/tests/contrib/opentelemetry/test_opentelemetry_plugin.py b/tests/contrib/opentelemetry/test_opentelemetry_plugin.py index 64073cf9c..b06bb8bc7 100644 --- a/tests/contrib/opentelemetry/test_opentelemetry_plugin.py +++ b/tests/contrib/opentelemetry/test_opentelemetry_plugin.py @@ -25,7 +25,7 @@ # Import the dump_spans function from the original opentelemetry test from tests.contrib.opentelemetry.test_opentelemetry import dump_spans from tests.helpers import new_worker -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name logger = logging.getLogger(__name__) @@ -272,7 +272,9 @@ async def test_opentelemetry_comprehensive_tracing( max_cached_workflows=0, ) as worker: # Create Nexus endpoint for this task queue - await create_nexus_endpoint(worker.task_queue, new_client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(worker.task_queue), worker.task_queue + ) with get_tracer(__name__).start_as_current_span("ComprehensiveTest") as span: span.set_attribute("test.type", "comprehensive") diff --git a/tests/helpers/nexus.py b/tests/helpers/nexus.py index 904b4422a..688a40273 100644 --- a/tests/helpers/nexus.py +++ b/tests/helpers/nexus.py @@ -5,10 +5,7 @@ from urllib.parse import urlparse import temporalio.api.failure.v1 -import temporalio.api.nexus.v1 -import temporalio.api.operatorservice.v1 import temporalio.workflow -from temporalio.client import Client from temporalio.converter import FailureConverter, PayloadConverter from temporalio.testing import WorkflowEnvironment @@ -22,27 +19,6 @@ def make_nexus_endpoint_name(task_queue: str) -> str: return f"nexus-endpoint-{task_queue}" -# TODO(nexus-preview): How do we recommend that users create endpoints in their own tests? -# See https://github.com/temporalio/sdk-typescript/pull/1708/files?show-viewed-files=true&file-filters%5B%5D=&w=0#r2082549085 -async def create_nexus_endpoint( - task_queue: str, client: Client -) -> temporalio.api.operatorservice.v1.CreateNexusEndpointResponse: - name = make_nexus_endpoint_name(task_queue) - return await client.operator_service.create_nexus_endpoint( - temporalio.api.operatorservice.v1.CreateNexusEndpointRequest( - spec=temporalio.api.nexus.v1.EndpointSpec( - name=name, - target=temporalio.api.nexus.v1.EndpointTarget( - worker=temporalio.api.nexus.v1.EndpointTarget.Worker( - namespace=client.namespace, - task_queue=task_queue, - ) - ), - ) - ) - ) - - @dataclass class ServiceClient: server_address: str # E.g. http://127.0.0.1:7243 diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index 95ff1c986..f0070fc8a 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -10,7 +10,7 @@ from temporalio.nexus._util import get_operation_factory from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers.nexus import ServiceClient, create_nexus_endpoint +from tests.helpers.nexus import ServiceClient, make_nexus_endpoint_name @workflow.defn @@ -80,7 +80,11 @@ async def test_run_nexus_service_from_programmatically_created_service_handler( service_name = service_handler.service.name - endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id + endpoint = ( + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) + ).id async with Worker( client, task_queue=task_queue, @@ -146,7 +150,11 @@ async def test_dynamic_creation_of_user_handler_classes( assert (service_defn := nexusrpc.get_service_definition(service_cls)) service_name = service_defn.name - endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id + endpoint = ( + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) + ).id async with Worker( client, task_queue=task_queue, diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index 407e61caf..21bff0563 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -56,8 +56,8 @@ from tests.helpers.nexus import ( Failure, ServiceClient, - create_nexus_endpoint, dataclass_as_dict, + make_nexus_endpoint_name, ) @@ -608,7 +608,11 @@ async def _test_start_operation_with_service_definition( if test_case.skip: pytest.skip(test_case.skip) task_queue = str(uuid.uuid4()) - endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + endpoint = ( + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) + ).id service_client = ServiceClient( server_address=ServiceClient.default_server_address(env), endpoint=endpoint, @@ -642,7 +646,11 @@ async def _test_start_operation_without_service_definition( if test_case.skip: pytest.skip(test_case.skip) task_queue = str(uuid.uuid4()) - endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + endpoint = ( + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) + ).id service_client = ServiceClient( server_address=ServiceClient.default_server_address(env), endpoint=endpoint, @@ -730,7 +738,11 @@ async def test_start_operation_without_type_annotations( if test_case.skip: pytest.skip(test_case.skip) task_queue = str(uuid.uuid4()) - endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + endpoint = ( + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) + ).id service_client = ServiceClient( server_address=ServiceClient.default_server_address(env), endpoint=endpoint, @@ -776,8 +788,11 @@ async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: A task_queue = str(uuid.uuid4()) service_name = MyService.__name__ operation_name = "log" - resp = await create_nexus_endpoint(task_queue, env.client) - endpoint = resp.endpoint.id + endpoint = ( + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) + ).id service_client = ServiceClient( server_address=ServiceClient.default_server_address(env), endpoint=endpoint, @@ -929,7 +944,11 @@ async def test_cancel_operation_with_invalid_token(env: WorkflowEnvironment): """Verify that canceling an operation with an invalid token fails correctly.""" task_queue = str(uuid.uuid4()) - endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + endpoint = ( + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) + ).id service_client = ServiceClient( server_address=ServiceClient.default_server_address(env), endpoint=endpoint, @@ -961,7 +980,11 @@ async def test_request_id_is_received_by_sync_operation( pytest.skip("Nexus tests don't work with time-skipping server") task_queue = str(uuid.uuid4()) - endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + endpoint = ( + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) + ).id service_client = ServiceClient( server_address=ServiceClient.default_server_address(env), endpoint=endpoint, @@ -1035,7 +1058,11 @@ async def test_request_id_becomes_start_workflow_request_id(env: WorkflowEnviron # demonstrating that the Nexus Start Operation request ID has become the # StartWorkflow request ID. task_queue = str(uuid.uuid4()) - endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + endpoint = ( + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) + ).id service_client = ServiceClient( server_address=ServiceClient.default_server_address(env), endpoint=endpoint, diff --git a/tests/nexus/test_handler_async_operation.py b/tests/nexus/test_handler_async_operation.py index d25c8a072..13ffb9842 100644 --- a/tests/nexus/test_handler_async_operation.py +++ b/tests/nexus/test_handler_async_operation.py @@ -24,7 +24,7 @@ from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers.nexus import ServiceClient, create_nexus_endpoint +from tests.helpers.nexus import ServiceClient, make_nexus_endpoint_name @dataclass @@ -113,7 +113,11 @@ async def test_async_operation_lifecycle( task_executor = await TaskExecutor.connect() task_queue = str(uuid.uuid4()) - endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + endpoint = ( + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) + ).id service_client = ServiceClient( ServiceClient.default_server_address(env), endpoint, diff --git a/tests/nexus/test_nexus_worker_shutdown.py b/tests/nexus/test_nexus_worker_shutdown.py index 0286b46c4..bd9063237 100644 --- a/tests/nexus/test_nexus_worker_shutdown.py +++ b/tests/nexus/test_nexus_worker_shutdown.py @@ -20,7 +20,6 @@ from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ( - create_nexus_endpoint, make_nexus_endpoint_name, ) @@ -127,7 +126,9 @@ async def test_nexus_worker_shutdown(env: WorkflowEnvironment): # Use separate task queues for caller and handler workers handler_task_queue = str(uuid.uuid4()) caller_task_queue = str(uuid.uuid4()) - await create_nexus_endpoint(handler_task_queue, env.client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(handler_task_queue), handler_task_queue + ) operation_started = asyncio.Event() @@ -180,7 +181,9 @@ async def test_nexus_worker_shutdown_graceful(env: WorkflowEnvironment): # Use separate task queues for caller and handler workers handler_task_queue = str(uuid.uuid4()) caller_task_queue = str(uuid.uuid4()) - await create_nexus_endpoint(handler_task_queue, env.client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(handler_task_queue), handler_task_queue + ) operation_started = asyncio.Event() @@ -234,7 +237,9 @@ async def test_sync_nexus_operation_worker_shutdown_graceful(env: WorkflowEnviro # Use separate task queues for caller and handler workers handler_task_queue = str(uuid.uuid4()) caller_task_queue = str(uuid.uuid4()) - await create_nexus_endpoint(handler_task_queue, env.client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(handler_task_queue), handler_task_queue + ) sync_operation_started = threading.Event() @@ -292,7 +297,9 @@ async def test_is_worker_shutdown(env: WorkflowEnvironment): # Use separate task queues for caller and handler workers handler_task_queue = str(uuid.uuid4()) caller_task_queue = str(uuid.uuid4()) - await create_nexus_endpoint(handler_task_queue, env.client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(handler_task_queue), handler_task_queue + ) operation_started = asyncio.Event() handler = ShutdownTestServiceHandler(operation_started) diff --git a/tests/nexus/test_use_existing_conflict_policy.py b/tests/nexus/test_use_existing_conflict_policy.py index ecefa4b05..95925ded0 100644 --- a/tests/nexus/test_use_existing_conflict_policy.py +++ b/tests/nexus/test_use_existing_conflict_policy.py @@ -12,7 +12,7 @@ from temporalio.common import WorkflowIDConflictPolicy from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name @dataclass @@ -103,7 +103,9 @@ async def test_multiple_operation_invocations_can_connect_to_same_handler_workfl workflows=[CallerWorkflow, HandlerWorkflow], task_queue=task_queue, ): - await create_nexus_endpoint(task_queue, client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) caller_handle = await client.start_workflow( CallerWorkflow.run, args=[ diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 3eaf8de29..70417625d 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -68,7 +68,7 @@ ) from tests.helpers import find_free_port, new_worker from tests.helpers.metrics import PromMetricMatcher -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name # TODO(nexus-preview): test worker shutdown, wait_all_completed, drain etc @@ -603,7 +603,8 @@ async def test_sync_operation_happy_path(client: Client, env: WorkflowEnvironmen task_queue=task_queue, workflow_failure_exception_types=[Exception], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) wf_output = await client.execute_workflow( CallerWorkflow.run, args=[ @@ -640,7 +641,8 @@ async def test_workflow_run_operation_happy_path( task_queue=task_queue, workflow_failure_exception_types=[Exception], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) wf_output = await client.execute_workflow( CallerWorkflow.run, args=[ @@ -797,7 +799,8 @@ async def test_start_operation_headers( task_queue=task_queue, interceptors=[HeaderAddingOutboundInterceptor(), inbound_interceptor], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) workflow_headers = {"x-custom-from-workflow": "workflow-value"} result = await client.execute_workflow( @@ -842,7 +845,9 @@ async def test_workflow_run_operation_headers( workflows=[WorkflowRunHeaderTestCallerWorkflow, HeaderEchoWorkflow], task_queue=task_queue, ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + result = await client.execute_workflow( WorkflowRunHeaderTestCallerWorkflow.run, WorkflowRunHeaderTestCallerWfInput( @@ -876,7 +881,8 @@ async def test_cancel_operation_headers( task_queue=task_queue, interceptors=[inbound_interceptor], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) workflow_headers = {"x-custom-cancel": "cancel-value"} await client.execute_workflow( @@ -930,7 +936,8 @@ async def test_sync_response( task_queue=task_queue, workflow_failure_exception_types=[Exception], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) caller_wf_handle = await client.start_workflow( CallerWorkflow.run, args=[ @@ -1004,6 +1011,7 @@ async def test_async_response( workflow_failure_exception_types=[Exception], ): caller_wf_handle, handler_wf_handle = await _start_wf_and_nexus_op( + env, client, task_queue, exception_in_operation_start, @@ -1076,6 +1084,7 @@ async def test_async_response( async def _start_wf_and_nexus_op( + env: WorkflowEnvironment, client: Client, task_queue: str, exception_in_operation_start: bool, @@ -1089,7 +1098,8 @@ async def _start_wf_and_nexus_op( """ Start the caller workflow and wait until the Nexus operation has started. """ - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) operation_workflow_id = str(uuid.uuid4()) # Start the caller workflow and wait until it confirms the Nexus operation has started. @@ -1174,7 +1184,8 @@ async def test_untyped_caller( op_definition_type=op_definition_type, exception_in_operation_start=exception_in_operation_start, ) - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) caller_wf_handle = await client.start_workflow( UntypedCallerWorkflow.run, args=[ @@ -1335,7 +1346,8 @@ async def test_service_interface_and_implementation_names( task_queue=task_queue, workflow_failure_exception_types=[Exception], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) assert await client.execute_workflow( ServiceInterfaceAndImplCallerWorkflow.run, args=(CallerReference.INTERFACE, NameOverride.YES, task_queue), @@ -1451,7 +1463,8 @@ async def test_workflow_run_operation_can_execute_workflow_before_starting_backi ], task_queue=task_queue, ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) result = await client.execute_workflow( WorkflowCallingNexusOperationThatExecutesWorkflowBeforeStartingBackingWorkflow.run, args=("result-1", task_queue), @@ -1503,7 +1516,8 @@ async def test_nexus_operation_summary( ], task_queue=task_queue, ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) wf_id = f"wf-{uuid.uuid4()}" handle = await client.start_workflow( ExecuteNexusOperationWithSummaryWorkflow.run, @@ -1799,7 +1813,8 @@ async def test_workflow_run_operation_overloads( ], nexus_service_handlers=[OverloadTestServiceHandler()], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) res = await client.execute_workflow( OverloadTestCallerWorkflow.run, args=[op, OverloadTestValue(value=2)], @@ -1859,7 +1874,8 @@ async def test_workflow_caller_custom_metrics(client: Client, env: WorkflowEnvir pytest.skip("Nexus tests don't work with time-skipping server") task_queue = str(uuid.uuid4()) - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) # Create new runtime with Prom server prom_addr = f"127.0.0.1:{find_free_port()}" @@ -1952,7 +1968,8 @@ async def test_workflow_caller_buffered_metrics( runtime=runtime, ) task_queue = str(uuid.uuid4()) - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) async with new_worker( client, CustomMetricsWorkflow, @@ -2129,7 +2146,8 @@ async def test_task_executor_operation_cancel_method( ], nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) caller_wf_handle = await client.start_workflow( CancelTestCallerWorkflow.run, diff --git a/tests/nexus/test_workflow_caller_cancellation_types.py b/tests/nexus/test_workflow_caller_cancellation_types.py index fd5aa84a7..9cbb95e77 100644 --- a/tests/nexus/test_workflow_caller_cancellation_types.py +++ b/tests/nexus/test_workflow_caller_cancellation_types.py @@ -20,7 +20,7 @@ from temporalio.common import WorkflowIDConflictPolicy from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name @dataclass @@ -256,7 +256,9 @@ async def test_cancellation_type( workflows=[CallerWorkflow, HandlerWorkflow], nexus_service_handlers=[ServiceHandler()], ) as worker: - await create_nexus_endpoint(worker.task_queue, client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(worker.task_queue), worker.task_queue + ) # Start the caller workflow, wait for the nexus op to have started and retrieve the nexus op # token diff --git a/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py b/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py index 37aa986ed..2e4ef401c 100644 --- a/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py +++ b/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py @@ -23,7 +23,7 @@ from temporalio.common import WorkflowIDConflictPolicy from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name from tests.nexus.test_workflow_caller_cancellation_types import ( assert_event_subsequence, get_event_time, @@ -222,7 +222,9 @@ async def test_cancellation_type( workflows=[CallerWorkflow, HandlerWorkflow], nexus_service_handlers=[ServiceHandler()], ) as worker: - await create_nexus_endpoint(worker.task_queue, client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(worker.task_queue), worker.task_queue + ) # Start the caller workflow, wait for the nexus op to have started and retrieve the nexus op # token diff --git a/tests/nexus/test_workflow_caller_error_chains.py b/tests/nexus/test_workflow_caller_error_chains.py index 07a575d60..9258834c7 100644 --- a/tests/nexus/test_workflow_caller_error_chains.py +++ b/tests/nexus/test_workflow_caller_error_chains.py @@ -23,7 +23,7 @@ ) from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name error_conversion_test_cases: dict[str, type[ErrorConversionTestCase]] = {} @@ -413,7 +413,9 @@ async def test_errors_raised_by_nexus_operation( workflows=[ErrorTestCallerWorkflow], task_queue=task_queue, ): - await create_nexus_endpoint(task_queue, client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) await client.execute_workflow( ErrorTestCallerWorkflow.invoke_nexus_op_and_assert_error, ErrorTestInput( diff --git a/tests/nexus/test_workflow_caller_errors.py b/tests/nexus/test_workflow_caller_errors.py index 31353c5a9..2cfc3a36d 100644 --- a/tests/nexus/test_workflow_caller_errors.py +++ b/tests/nexus/test_workflow_caller_errors.py @@ -34,7 +34,7 @@ from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers import LogCapturer, assert_eq_eventually -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name operation_invocation_counts = Counter[str]() @@ -148,7 +148,9 @@ async def test_nexus_operation_is_retried( workflows=[CallerWorkflow], task_queue=input.task_queue, ): - await create_nexus_endpoint(input.task_queue, client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(input.task_queue), input.task_queue + ) asyncio.create_task( client.execute_workflow( CallerWorkflow.run, @@ -211,7 +213,9 @@ async def test_nexus_operation_fails_without_retry_as_handler_error( workflows=[CallerWorkflow], task_queue=input.task_queue, ): - await create_nexus_endpoint(input.task_queue, client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(input.task_queue), input.task_queue + ) try: await client.execute_workflow( CallerWorkflow.run, @@ -283,7 +287,9 @@ async def test_error_raised_by_timeout_of_nexus_start_operation( task_queue=task_queue, nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), ): - await create_nexus_endpoint(task_queue, client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) try: await client.execute_workflow( StartTimeoutTestCallerWorkflow.run, @@ -377,7 +383,9 @@ async def test_error_raised_by_timeout_of_nexus_cancel_operation( task_queue=task_queue, ): with LogCapturer().logs_captured(logger) as capturer: - await create_nexus_endpoint(task_queue, client) + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) try: await client.execute_workflow( CancellationTimeoutTestCallerWorkflow.run, diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index 7d284412c..a5abc582c 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -22,8 +22,8 @@ from tests.helpers.nexus import ( Failure, ServiceClient, - create_nexus_endpoint, dataclass_as_dict, + make_nexus_endpoint_name, ) @@ -94,7 +94,11 @@ async def test_workflow_run_operation( pytest.skip("Nexus tests don't work with time-skipping server") task_queue = str(uuid.uuid4()) - endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id + endpoint = ( + await env.create_nexus_endpoint( + make_nexus_endpoint_name(task_queue), task_queue + ) + ).id assert (service_defn := nexusrpc.get_service_definition(service_handler_cls)) service_client = ServiceClient( server_address=ServiceClient.default_server_address(env), diff --git a/tests/test_serialization_context.py b/tests/test_serialization_context.py index c346c46bb..d3ce022f5 100644 --- a/tests/test_serialization_context.py +++ b/tests/test_serialization_context.py @@ -51,7 +51,7 @@ from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name @dataclass @@ -1751,7 +1751,8 @@ async def test_nexus_payload_codec_operations_lack_context( workflows=[NexusOperationTestWorkflow], nexus_service_handlers=[NexusOperationTestServiceHandler()], ) as worker: - await create_nexus_endpoint(worker.task_queue, client) + endpoint_name = make_nexus_endpoint_name(worker.task_queue) + await env.create_nexus_endpoint(endpoint_name, worker.task_queue) await client.execute_workflow( NexusOperationTestWorkflow.run, "workflow-data", diff --git a/tests/worker/test_interceptor.py b/tests/worker/test_interceptor.py index 16605876e..431f8280d 100644 --- a/tests/worker/test_interceptor.py +++ b/tests/worker/test_interceptor.py @@ -35,7 +35,7 @@ WorkflowInterceptorClassInput, WorkflowOutboundInterceptor, ) -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name interceptor_traces: list[tuple[str, Any]] = [] @@ -276,7 +276,7 @@ async def test_worker_interceptor(client: Client, env: WorkflowEnvironment): "Java test server: https://github.com/temporalio/sdk-java/issues/1424" ) task_queue = f"task-queue-{uuid.uuid4()}" - await create_nexus_endpoint(task_queue, client) + await env.create_nexus_endpoint(make_nexus_endpoint_name(task_queue), task_queue) async with Worker( client, diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index b1b65112d..d410c2128 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -63,7 +63,7 @@ new_worker, ) from tests.helpers.fork import _ForkTestResult, _TestFork -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name def test_load_default_worker_binary_id(): @@ -468,7 +468,8 @@ def reserve_asserts(self, ctx: SlotReserveContext) -> None: tuner=tuner, identity="myworker", ) as w: - await create_nexus_endpoint(w.task_queue, client) + endpoint_name = make_nexus_endpoint_name(w.task_queue) + await env.create_nexus_endpoint(endpoint_name, w.task_queue) wf1 = await client.start_workflow( CustomSlotSupplierWorkflow.run, id=f"custom-slot-supplier-{uuid.uuid4()}",