Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions temporalio/worker/_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ class StartNexusOperationInput(Generic[InputT, OutputT]):
operation: nexusrpc.Operation[InputT, OutputT] | str | Callable[..., Any]
input: InputT
schedule_to_close_timeout: timedelta | None
schedule_to_start_timeout: timedelta | None
start_to_close_timeout: timedelta | None
Comment on lines +306 to +307

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep new interceptor timeout fields optional

StartNexusOperationInput is part of the public interceptor API, but these newly added fields are required constructor args, which is a backwards-incompatible break for existing interceptors that instantiate this dataclass (or pass positional args) using the prior signature. After this change, older code will either raise TypeError for missing args or mis-bind positional values, so adding defaults (e.g., None) is needed to preserve compatibility while introducing the new timeouts.

Useful? React with 👍 / 👎.

cancellation_type: temporalio.workflow.NexusOperationCancellationType
headers: Mapping[str, str] | None
summary: str | None
Expand Down
10 changes: 10 additions & 0 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1596,6 +1596,8 @@ async def workflow_start_nexus_operation(
input: Any,
output_type: type[OutputT] | None,
schedule_to_close_timeout: timedelta | None,
schedule_to_start_timeout: timedelta | None,
start_to_close_timeout: timedelta | None,
cancellation_type: temporalio.workflow.NexusOperationCancellationType,
headers: Mapping[str, str] | None,
summary: str | None,
Expand All @@ -1609,6 +1611,8 @@ async def workflow_start_nexus_operation(
input=input,
output_type=output_type,
schedule_to_close_timeout=schedule_to_close_timeout,
schedule_to_start_timeout=schedule_to_start_timeout,
start_to_close_timeout=start_to_close_timeout,
cancellation_type=cancellation_type,
headers=headers,
summary=summary,
Expand Down Expand Up @@ -3340,6 +3344,12 @@ def _apply_schedule_command(self) -> None:
v.schedule_to_close_timeout.FromTimedelta(
self._input.schedule_to_close_timeout
)
if self._input.schedule_to_start_timeout is not None:
v.schedule_to_start_timeout.FromTimedelta(
self._input.schedule_to_start_timeout
)
if self._input.start_to_close_timeout is not None:
v.start_to_close_timeout.FromTimedelta(self._input.start_to_close_timeout)
v.cancellation_type = cast(
temporalio.bridge.proto.nexus.NexusOperationCancellationType.ValueType,
int(self._input.cancellation_type),
Expand Down
42 changes: 42 additions & 0 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,8 @@ async def workflow_start_nexus_operation(
input: Any,
output_type: type[OutputT] | None,
schedule_to_close_timeout: timedelta | None,
schedule_to_start_timeout: timedelta | None,
start_to_close_timeout: timedelta | None,
cancellation_type: temporalio.workflow.NexusOperationCancellationType,
headers: Mapping[str, str] | None,
summary: str | None,
Expand Down Expand Up @@ -5418,6 +5420,8 @@ async def start_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5433,6 +5437,8 @@ async def start_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5451,6 +5457,8 @@ async def start_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5469,6 +5477,8 @@ async def start_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5487,6 +5497,8 @@ async def start_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5504,6 +5516,8 @@ async def start_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5517,6 +5531,8 @@ async def start_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5528,6 +5544,8 @@ async def start_operation(
input: The Nexus operation input.
output_type: The Nexus operation output type.
schedule_to_close_timeout: Timeout for the entire operation attempt.
schedule_to_start_timeout: Timeout for the operation to be started.
start_to_close_timeout: Timeout for async operations to complete after starting.
headers: Headers to send with the Nexus HTTP request.

Returns:
Expand All @@ -5548,6 +5566,8 @@ async def execute_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5563,6 +5583,8 @@ async def execute_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5581,6 +5603,8 @@ async def execute_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5599,6 +5623,8 @@ async def execute_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5617,6 +5643,8 @@ async def execute_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5635,6 +5663,8 @@ async def execute_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5648,6 +5678,8 @@ async def execute_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5659,6 +5691,8 @@ async def execute_operation(
input: The Nexus operation input.
output_type: The Nexus operation output type.
schedule_to_close_timeout: Timeout for the entire operation attempt.
schedule_to_start_timeout: Timeout for the operation to be started.
start_to_close_timeout: Timeout for async operations to complete after starting.
headers: Headers to send with the Nexus HTTP request.

Returns:
Expand Down Expand Up @@ -5701,6 +5735,8 @@ async def start_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5713,6 +5749,8 @@ async def start_operation(
input=input,
output_type=output_type,
schedule_to_close_timeout=schedule_to_close_timeout,
schedule_to_start_timeout=schedule_to_start_timeout,
start_to_close_timeout=start_to_close_timeout,
cancellation_type=cancellation_type,
headers=headers,
summary=summary,
Expand All @@ -5726,6 +5764,8 @@ async def execute_operation(
*,
output_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Mapping[str, str] | None = None,
summary: str | None = None,
Expand All @@ -5735,6 +5775,8 @@ async def execute_operation(
input,
output_type=output_type,
schedule_to_close_timeout=schedule_to_close_timeout,
schedule_to_start_timeout=schedule_to_start_timeout,
start_to_close_timeout=start_to_close_timeout,
cancellation_type=cancellation_type,
headers=headers,
summary=summary,
Expand Down
144 changes: 144 additions & 0 deletions tests/nexus/test_workflow_caller_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ApplicationError,
NexusOperationError,
TimeoutError,
TimeoutType,
)
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
Expand Down Expand Up @@ -323,6 +324,149 @@ async def test_error_raised_by_timeout_of_nexus_start_operation(
assert capturer.find_log("unexpected cancellation reason") is None


# Schedule to start timeout test
@service_handler
class ScheduleToStartTimeoutTestService:
@sync_operation
async def expect_schedule_to_start_timeout(
self, ctx: StartOperationContext, _input: None
) -> None:
try:
await asyncio.wait_for(ctx.task_cancellation.wait_until_cancelled(), 1)
except asyncio.TimeoutError:
raise ApplicationError("expected cancel", non_retryable=True)


@workflow.defn
class ScheduleToStartTimeoutTestCallerWorkflow:
@workflow.init
def __init__(self):
self.nexus_client = workflow.create_nexus_client(
service=ScheduleToStartTimeoutTestService,
endpoint=make_nexus_endpoint_name(workflow.info().task_queue),
)

@workflow.run
async def run(self) -> None:
await self.nexus_client.execute_operation(
ScheduleToStartTimeoutTestService.expect_schedule_to_start_timeout,
None,
output_type=None,
schedule_to_start_timeout=timedelta(seconds=0.1),
)


async def test_error_raised_by_schedule_to_start_timeout_of_nexus_operation(
client: Client, env: WorkflowEnvironment
):
if env.supports_time_skipping:
pytest.skip("Nexus tests don't work with time-skipping server")

task_queue = str(uuid.uuid4())
async with Worker(
client,
nexus_service_handlers=[ScheduleToStartTimeoutTestService()],
workflows=[ScheduleToStartTimeoutTestCallerWorkflow],
task_queue=task_queue,
nexus_task_executor=concurrent.futures.ThreadPoolExecutor(),
):
await env.create_nexus_endpoint(
make_nexus_endpoint_name(task_queue), task_queue
)
try:
await client.execute_workflow(
ScheduleToStartTimeoutTestCallerWorkflow.run,
id=str(uuid.uuid4()),
task_queue=task_queue,
)
except Exception as err:
assert isinstance(err, WorkflowFailureError)
assert isinstance(err.__cause__, NexusOperationError)
assert isinstance(err.__cause__.__cause__, TimeoutError)
timeout_err = err.__cause__.__cause__
assert timeout_err.type == TimeoutType.SCHEDULE_TO_START
else:
pytest.fail(
"Expected exception due to schedule to start timeout of nexus operation"
)


# Start to close timeout test


class OperationThatExpectsStartToCloseTimeoutAsync(OperationHandler[None, None]):
async def start(
self, ctx: StartOperationContext, input: None
) -> StartOperationResultAsync:
return StartOperationResultAsync("fake-token")

async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
pass


@service_handler
class StartToCloseTimeoutTestService:
@operation_handler
def expect_start_to_close_timeout(self) -> OperationHandler[None, None]:
return OperationThatExpectsStartToCloseTimeoutAsync()


@workflow.defn
class StartToCloseTimeoutTestCallerWorkflow:
@workflow.init
def __init__(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any particular reason to create the nexus client as a member? Not that it matters too much in the test, but if we're viewing all code as AI consumed, is this a good pattern to suggest?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's actually what all the tests in this file do so i just copied them for consistency.

self,
):
self.nexus_client = workflow.create_nexus_client(
service=StartToCloseTimeoutTestService,
endpoint=make_nexus_endpoint_name(workflow.info().task_queue),
)

@workflow.run
async def run(self) -> None:
op_handle = await self.nexus_client.start_operation(
StartToCloseTimeoutTestService.expect_start_to_close_timeout,
None,
start_to_close_timeout=timedelta(seconds=0.1),
)
await op_handle


async def test_error_raised_by_start_to_close_timeout_of_nexus_operation(
client: Client, env: WorkflowEnvironment
):
if env.supports_time_skipping:
pytest.skip("Nexus tests don't work with time-skipping server")

task_queue = str(uuid.uuid4())
async with Worker(
client,
nexus_service_handlers=[StartToCloseTimeoutTestService()],
workflows=[StartToCloseTimeoutTestCallerWorkflow],
task_queue=task_queue,
nexus_task_executor=concurrent.futures.ThreadPoolExecutor(),
):
await env.create_nexus_endpoint(
make_nexus_endpoint_name(task_queue), task_queue
)
try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer pytest.raises, but no big deal.

await client.execute_workflow(
StartToCloseTimeoutTestCallerWorkflow.run,
id=str(uuid.uuid4()),
task_queue=task_queue,
)
except Exception as err:
assert isinstance(err, WorkflowFailureError)
assert isinstance(err.__cause__, NexusOperationError)
timeout_err = err.__cause__.__cause__
assert isinstance(timeout_err, TimeoutError)
assert timeout_err.type == TimeoutType.START_TO_CLOSE
else:
pytest.fail(
"Expected exception due to start to close timeout of nexus operation"
)


# Cancellation timeout test


Expand Down
Loading