Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions temporalio/testing/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import google.protobuf.empty_pb2
from typing_extensions import Self

import temporalio.api.nexus.v1
import temporalio.api.operatorservice.v1
import temporalio.api.testservice.v1
import temporalio.bridge.testing
import temporalio.client
Expand Down Expand Up @@ -401,6 +403,48 @@ def supports_time_skipping(self) -> bool:
"""Whether this environment supports time skipping."""
return False

async def create_nexus_endpoint(
self, endpoint_name: str, task_queue: str
) -> temporalio.api.nexus.v1.Endpoint:
"""Create a Nexus endpoint with the given name and task queue.

Args:
endpoint_name: The name of the Nexus endpoint to create.
task_queue: The task queue to associate with the endpoint.

Returns:
The created Nexus endpoint.
"""
response = await self._client.operator_service.create_nexus_endpoint(
temporalio.api.operatorservice.v1.CreateNexusEndpointRequest(
spec=temporalio.api.nexus.v1.EndpointSpec(
name=endpoint_name,
target=temporalio.api.nexus.v1.EndpointTarget(
worker=temporalio.api.nexus.v1.EndpointTarget.Worker(
namespace=self._client.namespace,
task_queue=task_queue,
)
),
)
)
)
return response.endpoint

async def delete_nexus_endpoint(
self, endpoint: temporalio.api.nexus.v1.Endpoint
) -> None:
"""Delete a Nexus endpoint.

Args:
endpoint: The Nexus endpoint to delete.
"""
await self._client.operator_service.delete_nexus_endpoint(
temporalio.api.operatorservice.v1.DeleteNexusEndpointRequest(
id=endpoint.id,
version=endpoint.version,
)
)

@contextmanager
def auto_time_skipping_disabled(self) -> Iterator[None]:
"""Disable any automatic time skipping if this is a time-skipping
Expand Down
6 changes: 4 additions & 2 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
ResearchManager,
)
from tests.helpers import assert_eventually, new_worker
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
from tests.helpers.nexus import make_nexus_endpoint_name


def hello_mock_model():
Expand Down Expand Up @@ -489,7 +489,9 @@ async def test_nexus_tool_workflow(
NexusToolsWorkflow,
nexus_service_handlers=[WeatherServiceHandler()],
) as worker:
await create_nexus_endpoint(worker.task_queue, client)
await env.create_nexus_endpoint(
make_nexus_endpoint_name(worker.task_queue), worker.task_queue
)

workflow_handle = await client.start_workflow(
NexusToolsWorkflow.run,
Expand Down
4 changes: 2 additions & 2 deletions tests/contrib/opentelemetry/test_opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import UnsandboxedWorkflowRunner, Worker
from tests.helpers import LogCapturer
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
from tests.helpers.nexus import make_nexus_endpoint_name


@dataclass
Expand Down Expand Up @@ -475,7 +475,7 @@ async def test_opentelemetry_tracing_nexus(client: Client, env: WorkflowEnvironm
client = Client(**client_config)

task_queue = f"task-queue-{uuid.uuid4()}"
await create_nexus_endpoint(task_queue, client)
await env.create_nexus_endpoint(make_nexus_endpoint_name(task_queue), task_queue)
async with Worker(
client,
task_queue=task_queue,
Expand Down
6 changes: 4 additions & 2 deletions tests/contrib/opentelemetry/test_opentelemetry_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# Import the dump_spans function from the original opentelemetry test
from tests.contrib.opentelemetry.test_opentelemetry import dump_spans
from tests.helpers import new_worker
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
from tests.helpers.nexus import make_nexus_endpoint_name

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -272,7 +272,9 @@ async def test_opentelemetry_comprehensive_tracing(
max_cached_workflows=0,
) as worker:
# Create Nexus endpoint for this task queue
await create_nexus_endpoint(worker.task_queue, new_client)
await env.create_nexus_endpoint(
make_nexus_endpoint_name(worker.task_queue), worker.task_queue
)

with get_tracer(__name__).start_as_current_span("ComprehensiveTest") as span:
span.set_attribute("test.type", "comprehensive")
Expand Down
24 changes: 0 additions & 24 deletions tests/helpers/nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
from urllib.parse import urlparse

import temporalio.api.failure.v1
import temporalio.api.nexus.v1
import temporalio.api.operatorservice.v1
import temporalio.workflow
from temporalio.client import Client
from temporalio.converter import FailureConverter, PayloadConverter
from temporalio.testing import WorkflowEnvironment

Expand All @@ -22,27 +19,6 @@ def make_nexus_endpoint_name(task_queue: str) -> str:
return f"nexus-endpoint-{task_queue}"


# TODO(nexus-preview): How do we recommend that users create endpoints in their own tests?
# See https://github.com/temporalio/sdk-typescript/pull/1708/files?show-viewed-files=true&file-filters%5B%5D=&w=0#r2082549085
async def create_nexus_endpoint(
task_queue: str, client: Client
) -> temporalio.api.operatorservice.v1.CreateNexusEndpointResponse:
name = make_nexus_endpoint_name(task_queue)
return await client.operator_service.create_nexus_endpoint(
temporalio.api.operatorservice.v1.CreateNexusEndpointRequest(
spec=temporalio.api.nexus.v1.EndpointSpec(
name=name,
target=temporalio.api.nexus.v1.EndpointTarget(
worker=temporalio.api.nexus.v1.EndpointTarget.Worker(
namespace=client.namespace,
task_queue=task_queue,
)
),
)
)
)


@dataclass
class ServiceClient:
server_address: str # E.g. http://127.0.0.1:7243
Expand Down
14 changes: 11 additions & 3 deletions tests/nexus/test_dynamic_creation_of_user_handler_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from temporalio.nexus._util import get_operation_factory
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
from tests.helpers.nexus import ServiceClient, create_nexus_endpoint
from tests.helpers.nexus import ServiceClient, make_nexus_endpoint_name


@workflow.defn
Expand Down Expand Up @@ -80,7 +80,11 @@ async def test_run_nexus_service_from_programmatically_created_service_handler(

service_name = service_handler.service.name

endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id
endpoint = (
await env.create_nexus_endpoint(
make_nexus_endpoint_name(task_queue), task_queue
)
).id
async with Worker(
client,
task_queue=task_queue,
Expand Down Expand Up @@ -146,7 +150,11 @@ async def test_dynamic_creation_of_user_handler_classes(
assert (service_defn := nexusrpc.get_service_definition(service_cls))
service_name = service_defn.name

endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id
endpoint = (
await env.create_nexus_endpoint(
make_nexus_endpoint_name(task_queue), task_queue
)
).id
async with Worker(
client,
task_queue=task_queue,
Expand Down
45 changes: 36 additions & 9 deletions tests/nexus/test_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@
from tests.helpers.nexus import (
Failure,
ServiceClient,
create_nexus_endpoint,
dataclass_as_dict,
make_nexus_endpoint_name,
)


Expand Down Expand Up @@ -608,7 +608,11 @@ async def _test_start_operation_with_service_definition(
if test_case.skip:
pytest.skip(test_case.skip)
task_queue = str(uuid.uuid4())
endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id
endpoint = (
await env.create_nexus_endpoint(
make_nexus_endpoint_name(task_queue), task_queue
)
).id
service_client = ServiceClient(
server_address=ServiceClient.default_server_address(env),
endpoint=endpoint,
Expand Down Expand Up @@ -642,7 +646,11 @@ async def _test_start_operation_without_service_definition(
if test_case.skip:
pytest.skip(test_case.skip)
task_queue = str(uuid.uuid4())
endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id
endpoint = (
await env.create_nexus_endpoint(
make_nexus_endpoint_name(task_queue), task_queue
)
).id
service_client = ServiceClient(
server_address=ServiceClient.default_server_address(env),
endpoint=endpoint,
Expand Down Expand Up @@ -730,7 +738,11 @@ async def test_start_operation_without_type_annotations(
if test_case.skip:
pytest.skip(test_case.skip)
task_queue = str(uuid.uuid4())
endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id
endpoint = (
await env.create_nexus_endpoint(
make_nexus_endpoint_name(task_queue), task_queue
)
).id
service_client = ServiceClient(
server_address=ServiceClient.default_server_address(env),
endpoint=endpoint,
Expand Down Expand Up @@ -776,8 +788,11 @@ async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: A
task_queue = str(uuid.uuid4())
service_name = MyService.__name__
operation_name = "log"
resp = await create_nexus_endpoint(task_queue, env.client)
endpoint = resp.endpoint.id
endpoint = (
await env.create_nexus_endpoint(
make_nexus_endpoint_name(task_queue), task_queue
)
).id
service_client = ServiceClient(
server_address=ServiceClient.default_server_address(env),
endpoint=endpoint,
Expand Down Expand Up @@ -929,7 +944,11 @@ async def test_cancel_operation_with_invalid_token(env: WorkflowEnvironment):

"""Verify that canceling an operation with an invalid token fails correctly."""
task_queue = str(uuid.uuid4())
endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id
endpoint = (
await env.create_nexus_endpoint(
make_nexus_endpoint_name(task_queue), task_queue
)
).id
service_client = ServiceClient(
server_address=ServiceClient.default_server_address(env),
endpoint=endpoint,
Expand Down Expand Up @@ -961,7 +980,11 @@ async def test_request_id_is_received_by_sync_operation(
pytest.skip("Nexus tests don't work with time-skipping server")

task_queue = str(uuid.uuid4())
endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id
endpoint = (
await env.create_nexus_endpoint(
make_nexus_endpoint_name(task_queue), task_queue
)
).id
service_client = ServiceClient(
server_address=ServiceClient.default_server_address(env),
endpoint=endpoint,
Expand Down Expand Up @@ -1035,7 +1058,11 @@ async def test_request_id_becomes_start_workflow_request_id(env: WorkflowEnviron
# demonstrating that the Nexus Start Operation request ID has become the
# StartWorkflow request ID.
task_queue = str(uuid.uuid4())
endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id
endpoint = (
await env.create_nexus_endpoint(
make_nexus_endpoint_name(task_queue), task_queue
)
).id
service_client = ServiceClient(
server_address=ServiceClient.default_server_address(env),
endpoint=endpoint,
Expand Down
8 changes: 6 additions & 2 deletions tests/nexus/test_handler_async_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
from tests.helpers.nexus import ServiceClient, create_nexus_endpoint
from tests.helpers.nexus import ServiceClient, make_nexus_endpoint_name


@dataclass
Expand Down Expand Up @@ -113,7 +113,11 @@ async def test_async_operation_lifecycle(

task_executor = await TaskExecutor.connect()
task_queue = str(uuid.uuid4())
endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id
endpoint = (
await env.create_nexus_endpoint(
make_nexus_endpoint_name(task_queue), task_queue
)
).id
service_client = ServiceClient(
ServiceClient.default_server_address(env),
endpoint,
Expand Down
17 changes: 12 additions & 5 deletions tests/nexus/test_nexus_worker_shutdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
from tests.helpers.nexus import (
create_nexus_endpoint,
make_nexus_endpoint_name,
)

Expand Down Expand Up @@ -127,7 +126,9 @@ async def test_nexus_worker_shutdown(env: WorkflowEnvironment):
# Use separate task queues for caller and handler workers
handler_task_queue = str(uuid.uuid4())
caller_task_queue = str(uuid.uuid4())
await create_nexus_endpoint(handler_task_queue, env.client)
await env.create_nexus_endpoint(
make_nexus_endpoint_name(handler_task_queue), handler_task_queue
)

operation_started = asyncio.Event()

Expand Down Expand Up @@ -180,7 +181,9 @@ async def test_nexus_worker_shutdown_graceful(env: WorkflowEnvironment):
# Use separate task queues for caller and handler workers
handler_task_queue = str(uuid.uuid4())
caller_task_queue = str(uuid.uuid4())
await create_nexus_endpoint(handler_task_queue, env.client)
await env.create_nexus_endpoint(
make_nexus_endpoint_name(handler_task_queue), handler_task_queue
)

operation_started = asyncio.Event()

Expand Down Expand Up @@ -234,7 +237,9 @@ async def test_sync_nexus_operation_worker_shutdown_graceful(env: WorkflowEnviro
# Use separate task queues for caller and handler workers
handler_task_queue = str(uuid.uuid4())
caller_task_queue = str(uuid.uuid4())
await create_nexus_endpoint(handler_task_queue, env.client)
await env.create_nexus_endpoint(
make_nexus_endpoint_name(handler_task_queue), handler_task_queue
)

sync_operation_started = threading.Event()

Expand Down Expand Up @@ -292,7 +297,9 @@ async def test_is_worker_shutdown(env: WorkflowEnvironment):
# Use separate task queues for caller and handler workers
handler_task_queue = str(uuid.uuid4())
caller_task_queue = str(uuid.uuid4())
await create_nexus_endpoint(handler_task_queue, env.client)
await env.create_nexus_endpoint(
make_nexus_endpoint_name(handler_task_queue), handler_task_queue
)

operation_started = asyncio.Event()
handler = ShutdownTestServiceHandler(operation_started)
Expand Down
6 changes: 4 additions & 2 deletions tests/nexus/test_use_existing_conflict_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from temporalio.common import WorkflowIDConflictPolicy
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
from tests.helpers.nexus import make_nexus_endpoint_name


@dataclass
Expand Down Expand Up @@ -103,7 +103,9 @@ async def test_multiple_operation_invocations_can_connect_to_same_handler_workfl
workflows=[CallerWorkflow, HandlerWorkflow],
task_queue=task_queue,
):
await create_nexus_endpoint(task_queue, client)
await env.create_nexus_endpoint(
make_nexus_endpoint_name(task_queue), task_queue
)
caller_handle = await client.start_workflow(
CallerWorkflow.run,
args=[
Expand Down
Loading