diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/aio/client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/aio/client.py index 8a2790638..223a71b04 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/aio/client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/aio/client.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import logging +import time import uuid from datetime import datetime from typing import Any, Optional, Sequence, Union @@ -33,6 +35,7 @@ TOutput, WorkflowIdReusePolicy, WorkflowState, + _TransientTimeout, new_orchestration_state, ) from google.protobuf import wrappers_pb2 @@ -120,34 +123,33 @@ async def get_orchestration_state( return new_orchestration_state(req.instanceId, res) async def wait_for_orchestration_start( - self, instance_id: str, *, fetch_payloads: bool = False, timeout: int = 0 + self, instance_id: str, *, fetch_payloads: bool = False, timeout: Optional[int] = 0 ) -> Optional[WorkflowState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) - try: - grpc_timeout = None if timeout == 0 else timeout - self._logger.info( - f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start." - ) + self._logger.info( + f"Waiting {'indefinitely' if timeout in (0, None) else f'up to {timeout}s'} for instance '{instance_id}' to start." + ) + + async def _call(grpc_timeout): res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart( req, timeout=grpc_timeout ) return new_orchestration_state(req.instanceId, res) - except grpc.RpcError as rpc_error: - if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore - # Replace gRPC error with the built-in TimeoutError - raise TimeoutError('Timed-out waiting for the orchestration to start') - else: - raise + + try: + return await self._call_with_transient_retry(instance_id, timeout, _call) + except _TransientTimeout: + raise TimeoutError('Timed-out waiting for the orchestration to start') async def wait_for_orchestration_completion( - self, instance_id: str, *, fetch_payloads: bool = True, timeout: int = 0 + self, instance_id: str, *, fetch_payloads: bool = True, timeout: Optional[int] = 0 ) -> Optional[WorkflowState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) - try: - grpc_timeout = None if timeout == 0 else timeout - self._logger.info( - f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete." - ) + self._logger.info( + f"Waiting {'indefinitely' if timeout in (0, None) else f'up to {timeout}s'} for instance '{instance_id}' to complete." + ) + + async def _call(grpc_timeout): res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion( req, timeout=grpc_timeout ) @@ -167,14 +169,87 @@ async def wait_for_orchestration_completion( self._logger.info(f"Instance '{instance_id}' was terminated.") elif state.runtime_status == OrchestrationStatus.COMPLETED: self._logger.info(f"Instance '{instance_id}' completed.") - return state - except grpc.RpcError as rpc_error: - if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore - # Replace gRPC error with the built-in TimeoutError - raise TimeoutError('Timed-out waiting for the orchestration to complete') - else: - raise + + try: + return await self._call_with_transient_retry(instance_id, timeout, _call) + except _TransientTimeout: + raise TimeoutError('Timed-out waiting for the orchestration to complete') + + # Transient gRPC codes that indicate the workflow runtime is temporarily + # unable to locate the workflow actor — typically immediately after a Dapr + # sidecar restart (e.g. recovery from chaos). The placement service has the + # actor registration, but local daprd hasn't received the dissemination yet. + # Without retry, every poll fails permanently with FAILED_PRECONDITION even + # though the workflow runtime state is intact. + _TRANSIENT_RPC_CODES = ( + grpc.StatusCode.FAILED_PRECONDITION, + grpc.StatusCode.UNAVAILABLE, + ) + + # See TaskHubGrpcClient._MAX_TRANSIENT_RETRY_SECONDS — same grace window for + # unbounded (timeout=0) callers so a down sidecar surfaces the original + # error instead of retrying forever. + _MAX_TRANSIENT_RETRY_SECONDS = 30.0 + + async def _call_with_transient_retry(self, instance_id, timeout, call_fn): + """Async mirror of TaskHubGrpcClient._call_with_transient_retry. + Retries FAILED_PRECONDITION/UNAVAILABLE with capped exponential + backoff while clamping sleep and per-call gRPC timeout to the + remaining budget. The first call uses the caller's timeout unchanged + (``None`` when unbounded) so callers observe identical behavior on a + healthy runtime. In unbounded + mode, continuous transient retries are capped at + ``_MAX_TRANSIENT_RETRY_SECONDS`` before the original error propagates. + """ + unbounded = timeout in (0, None) + deadline = None if unbounded else time.monotonic() + timeout + grpc_timeout = None if unbounded else timeout + backoff = 0.5 + transient_deadline = None # unbounded mode only; anchored on first transient + while True: + try: + return await call_fn(grpc_timeout) + except grpc.RpcError as rpc_error: + code = rpc_error.code() # type: ignore + if code == grpc.StatusCode.DEADLINE_EXCEEDED: + raise _TransientTimeout() + if code not in self._TRANSIENT_RPC_CODES: + raise + + now = time.monotonic() + + if unbounded: + if transient_deadline is None: + transient_deadline = now + self._MAX_TRANSIENT_RETRY_SECONDS + elif now >= transient_deadline: + raise + + if deadline is None: + remaining = None + else: + remaining = deadline - now + if remaining <= 0: + raise _TransientTimeout() + + sleep_for = min(backoff, 5.0) + if remaining is not None: + sleep_for = min(sleep_for, remaining) + if transient_deadline is not None: + sleep_for = min(sleep_for, transient_deadline - now) + self._logger.warning( + f"Transient gRPC error {code.name} waiting on instance '{instance_id}'; " + f'retrying in {sleep_for:.2f}s' + ) + await asyncio.sleep(sleep_for) + backoff = min(backoff * 2, 5.0) + + if deadline is None: + grpc_timeout = None + else: + grpc_timeout = deadline - time.monotonic() + if grpc_timeout <= 0: + raise _TransientTimeout() async def raise_orchestration_event( self, instance_id: str, event_name: str, *, data: Optional[Any] = None diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/client.py index 72fcb2f89..e80634f2c 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/client.py @@ -10,6 +10,7 @@ # limitations under the License. import logging +import time import uuid from dataclasses import dataclass from datetime import datetime @@ -25,6 +26,12 @@ from dapr.ext.workflow._durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl from google.protobuf import wrappers_pb2 + +class _TransientTimeout(Exception): + """Internal sentinel: the retry loop exhausted the user-provided timeout + budget. Callers convert this to a public ``TimeoutError``.""" + + TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') @@ -217,32 +224,31 @@ def get_orchestration_state( return new_orchestration_state(req.instanceId, res) def wait_for_orchestration_start( - self, instance_id: str, *, fetch_payloads: bool = False, timeout: int = 0 + self, instance_id: str, *, fetch_payloads: bool = False, timeout: Optional[int] = 0 ) -> Optional[WorkflowState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) - try: - grpc_timeout = None if timeout == 0 else timeout - self._logger.info( - f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start." - ) + self._logger.info( + f"Waiting {'indefinitely' if timeout in (0, None) else f'up to {timeout}s'} for instance '{instance_id}' to start." + ) + + def _call(grpc_timeout): res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=grpc_timeout) return new_orchestration_state(req.instanceId, res) - except grpc.RpcError as rpc_error: - if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore - # Replace gRPC error with the built-in TimeoutError - raise TimeoutError('Timed-out waiting for the orchestration to start') - else: - raise + + try: + return self._call_with_transient_retry(instance_id, timeout, _call) + except _TransientTimeout: + raise TimeoutError('Timed-out waiting for the orchestration to start') def wait_for_orchestration_completion( - self, instance_id: str, *, fetch_payloads: bool = True, timeout: int = 0 + self, instance_id: str, *, fetch_payloads: bool = True, timeout: Optional[int] = 0 ) -> Optional[WorkflowState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) - try: - grpc_timeout = None if timeout == 0 else timeout - self._logger.info( - f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete." - ) + self._logger.info( + f"Waiting {'indefinitely' if timeout in (0, None) else f'up to {timeout}s'} for instance '{instance_id}' to complete." + ) + + def _call(grpc_timeout): res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion( req, timeout=grpc_timeout ) @@ -262,14 +268,100 @@ def wait_for_orchestration_completion( self._logger.info(f"Instance '{instance_id}' was terminated.") elif state.runtime_status == OrchestrationStatus.COMPLETED: self._logger.info(f"Instance '{instance_id}' completed.") - return state - except grpc.RpcError as rpc_error: - if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore - # Replace gRPC error with the built-in TimeoutError - raise TimeoutError('Timed-out waiting for the orchestration to complete') - else: - raise + + try: + return self._call_with_transient_retry(instance_id, timeout, _call) + except _TransientTimeout: + raise TimeoutError('Timed-out waiting for the orchestration to complete') + + # Transient gRPC codes that indicate the workflow runtime is temporarily + # unable to locate the workflow actor — typically immediately after a Dapr + # sidecar restart (e.g. recovery from chaos). The placement service has the + # actor registration, but local daprd hasn't received the dissemination yet. + # Without retry, every poll fails permanently with FAILED_PRECONDITION even + # though the workflow runtime state is intact. + _TRANSIENT_RPC_CODES = ( + grpc.StatusCode.FAILED_PRECONDITION, + grpc.StatusCode.UNAVAILABLE, + ) + + # When the caller sets no timeout (timeout=0), bound how long we keep + # retrying *consecutive* transient errors so a permanently-unavailable + # sidecar surfaces the original error instead of hanging forever. This + # window comfortably covers placement re-dissemination after a restart; + # a slow-but-healthy workflow never enters this path (it just blocks in + # the long-poll), so its indefinite wait is preserved. + _MAX_TRANSIENT_RETRY_SECONDS = 30.0 + + def _call_with_transient_retry(self, instance_id, timeout, call_fn): + """Run a gRPC wait call, retrying transient errors until the user + timeout deadline. Re-raises non-transient errors immediately. + timeout in (0, None) means unbounded; transients are still retried, + but only for up to ``_MAX_TRANSIENT_RETRY_SECONDS`` of continuous + failures, after which the original transient error propagates. + + The first call passes the caller's ``grpc_timeout`` (``None`` when + unbounded) to ``call_fn`` so callers observe identical behavior to a + non-retrying client when no transient occurs (preserves prior public + behavior). On a retry, both the sleep + and the per-call gRPC deadline are clamped to the remaining budget so + the helper never sleeps past ``timeout`` or starts a gRPC call with + no time left. + """ + unbounded = timeout in (0, None) + deadline = None if unbounded else time.monotonic() + timeout + grpc_timeout = None if unbounded else timeout + backoff = 0.5 + transient_deadline = None # unbounded mode only; anchored on first transient + while True: + try: + return call_fn(grpc_timeout) + except grpc.RpcError as rpc_error: + code = rpc_error.code() # type: ignore + if code == grpc.StatusCode.DEADLINE_EXCEEDED: + raise _TransientTimeout() + if code not in self._TRANSIENT_RPC_CODES: + raise + + now = time.monotonic() + + # In unbounded mode the user budget can't end the loop, so cap + # continuous transient retries and re-raise the original error + # (matching pre-retry behavior) once the grace window elapses. + if unbounded: + if transient_deadline is None: + transient_deadline = now + self._MAX_TRANSIENT_RETRY_SECONDS + elif now >= transient_deadline: + raise + + # Compute remaining budget once and reuse so the sleep and the + # next per-call grpc_timeout agree on "how much time is left". + if deadline is None: + remaining = None + else: + remaining = deadline - now + if remaining <= 0: + raise _TransientTimeout() + + sleep_for = min(backoff, 5.0) + if remaining is not None: + sleep_for = min(sleep_for, remaining) + if transient_deadline is not None: + sleep_for = min(sleep_for, transient_deadline - now) + self._logger.warning( + f"Transient gRPC error {code.name} waiting on instance '{instance_id}'; " + f'retrying in {sleep_for:.2f}s' + ) + time.sleep(sleep_for) + backoff = min(backoff * 2, 5.0) + + if deadline is None: + grpc_timeout = None + else: + grpc_timeout = deadline - time.monotonic() + if grpc_timeout <= 0: + raise _TransientTimeout() def raise_orchestration_event( self, instance_id: str, event_name: str, *, data: Optional[Any] = None diff --git a/ext/dapr-ext-workflow/tests/durabletask/test_orchestration_wait.py b/ext/dapr-ext-workflow/tests/durabletask/test_orchestration_wait.py index f550cf4f9..66c768b88 100644 --- a/ext/dapr-ext-workflow/tests/durabletask/test_orchestration_wait.py +++ b/ext/dapr-ext-workflow/tests/durabletask/test_orchestration_wait.py @@ -1,5 +1,7 @@ +import time from unittest.mock import Mock +import grpc import pytest from dapr.ext.workflow._durabletask.client import TaskHubGrpcClient @@ -66,3 +68,121 @@ def test_wait_for_orchestration_completion_timeout(timeout): assert kwargs.get('timeout') is None else: assert kwargs.get('timeout') == timeout + + +def _make_rpc_error(code: grpc.StatusCode) -> grpc.RpcError: + err = grpc.RpcError() + err.code = lambda: code # type: ignore[method-assign] + err.details = lambda: f'simulated {code.name}' # type: ignore[method-assign] + return err + + +@pytest.mark.parametrize( + 'transient_code', [grpc.StatusCode.FAILED_PRECONDITION, grpc.StatusCode.UNAVAILABLE] +) +def test_wait_for_orchestration_start_retries_transient_then_succeeds(transient_code, monkeypatch): + """Transient gRPC error on the first call → backoff → next call succeeds.""" + instance_id = 'test-instance' + + from dapr.ext.workflow._durabletask.internal.protos import ( + ORCHESTRATION_STATUS_RUNNING, + GetInstanceResponse, + WorkflowState, + ) + + response = GetInstanceResponse() + state = WorkflowState() + state.instanceId = instance_id + state.workflowStatus = ORCHESTRATION_STATUS_RUNNING + response.workflowState.CopyFrom(state) + + sleeps = [] + monkeypatch.setattr( + 'dapr.ext.workflow._durabletask.client.time.sleep', lambda s: sleeps.append(s) + ) + + calls = {'n': 0} + + def fake_call(*args, **kwargs): + calls['n'] += 1 + if calls['n'] == 1: + raise _make_rpc_error(transient_code) + return response + + c = TaskHubGrpcClient() + c._stub = Mock() + c._stub.WaitForInstanceStart.side_effect = fake_call + + # The point of this test is the retry behavior, not the response payload — + # the second call returns successfully (no exception), the first transient + # is absorbed, and exactly one backoff sleep happens between them. + c.wait_for_orchestration_start(instance_id, timeout=10) + assert calls['n'] == 2 + assert len(sleeps) == 1 and sleeps[0] > 0 + + +def test_wait_for_orchestration_start_transient_exhaustion_raises_timeout(monkeypatch): + """Transient gRPC errors keep returning until the user budget runs out + → public TimeoutError, not the raw RpcError.""" + instance_id = 'test-instance' + + # Advance monotonic time on every call so the deadline is reached quickly. + fake_time = [0.0] + + def fake_monotonic(): + fake_time[0] += 0.6 # 0.0, 0.6, 1.2, ... + return fake_time[0] + + monkeypatch.setattr('dapr.ext.workflow._durabletask.client.time.monotonic', fake_monotonic) + monkeypatch.setattr('dapr.ext.workflow._durabletask.client.time.sleep', lambda s: None) + + c = TaskHubGrpcClient() + c._stub = Mock() + c._stub.WaitForInstanceStart.side_effect = _make_rpc_error(grpc.StatusCode.UNAVAILABLE) + + with pytest.raises(TimeoutError): + c.wait_for_orchestration_start(instance_id, timeout=1) + + +def test_wait_for_orchestration_start_non_transient_propagates(monkeypatch): + """Non-transient gRPC errors must NOT be retried — propagate directly.""" + instance_id = 'test-instance' + monkeypatch.setattr(time, 'sleep', lambda s: None) + + c = TaskHubGrpcClient() + c._stub = Mock() + c._stub.WaitForInstanceStart.side_effect = _make_rpc_error(grpc.StatusCode.PERMISSION_DENIED) + + with pytest.raises(grpc.RpcError): + c.wait_for_orchestration_start(instance_id, timeout=10) + assert c._stub.WaitForInstanceStart.call_count == 1 + + +def test_wait_for_orchestration_start_unbounded_transient_gives_up_with_rpc_error(monkeypatch): + """With timeout=0 (unbounded), persistent transient errors are retried only + for the grace window, then the original RpcError propagates — NOT a hang and + NOT a TimeoutError, preserving the pre-retry contract that timeout=0 surfaces + the gRPC error rather than TimeoutError.""" + instance_id = 'test-instance' + + # Advance well past _MAX_TRANSIENT_RETRY_SECONDS on each transient so the + # grace window is exhausted within a couple of retries. + fake_time = [0.0] + + def fake_monotonic(): + fake_time[0] += 20.0 # 20, 40, 60, ... — anchors at 20, deadline 50 + return fake_time[0] + + monkeypatch.setattr('dapr.ext.workflow._durabletask.client.time.monotonic', fake_monotonic) + monkeypatch.setattr('dapr.ext.workflow._durabletask.client.time.sleep', lambda s: None) + + c = TaskHubGrpcClient() + c._stub = Mock() + c._stub.WaitForInstanceStart.side_effect = _make_rpc_error(grpc.StatusCode.UNAVAILABLE) + + with pytest.raises(grpc.RpcError) as exc_info: + c.wait_for_orchestration_start(instance_id, timeout=0) + assert not isinstance(exc_info.value, TimeoutError) + # Retried at least once before giving up (proves it didn't fail-fast like the + # non-transient path, and didn't loop forever). + assert c._stub.WaitForInstanceStart.call_count >= 2