Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions cadence/_internal/workflow/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from contextlib import contextmanager
from datetime import timedelta
from math import ceil
from typing import Optional, Any, Unpack, Type, cast
from typing import Iterator, Optional, Any, Unpack, Type, cast

from cadence._internal.workflow.statemachine.decision_manager import DecisionManager
from cadence._internal.workflow.decisions_helper import DecisionsHelper
Expand All @@ -15,13 +16,12 @@ class Context(WorkflowContext):
def __init__(
    self,
    info: WorkflowInfo,
    decision_manager: DecisionManager,
):
    """Initialize the workflow-side context.

    Args:
        info: Static metadata describing this workflow execution.
        decision_manager: Shared state-machine manager that accumulates
            decisions produced while processing history events.
    """
    self._info = info
    # Executions begin in replay mode; it is switched off once the
    # event iterator reaches live (non-replayed) history.
    self._replay_mode = True
    self._replay_current_time_milliseconds: Optional[int] = None
    # Helper is owned by the context now, no longer injected by callers.
    self._decision_helper = DecisionsHelper()
    self._decision_manager = decision_manager

def info(self) -> WorkflowInfo:
Expand Down Expand Up @@ -110,6 +110,12 @@ def get_replay_current_time_milliseconds(self) -> Optional[int]:
"""Get the current replay time in milliseconds."""
return self._replay_current_time_milliseconds

@contextmanager
def _activate(self) -> Iterator["Context"]:
    """Install this context as the ambient WorkflowContext for a `with` block.

    Overridden here so callers receive the concrete Context instance
    (rather than the base WorkflowContext), avoiding direct access to
    private attributes from the engine.

    Yields:
        This Context instance.
    """
    token = WorkflowContext._var.set(self)
    try:
        yield self
    finally:
        # Always restore the previous context, even when the body raises;
        # otherwise a failed decision task would leak an active context
        # into subsequent, unrelated executions.
        WorkflowContext._var.reset(token)


def _round_to_nearest_second(delta: timedelta) -> timedelta:
return timedelta(seconds=ceil(delta.total_seconds()))
248 changes: 95 additions & 153 deletions cadence/_internal/workflow/workflow_engine.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import asyncio
import logging
from dataclasses import dataclass
from typing import Any, Optional

from cadence._internal.workflow.context import Context
from cadence._internal.workflow.decisions_helper import DecisionsHelper
from cadence._internal.workflow.decision_events_iterator import DecisionEventsIterator
from cadence._internal.workflow.deterministic_event_loop import DeterministicEventLoop
from cadence._internal.workflow.statemachine.decision_manager import DecisionManager
from cadence._internal.workflow.workflow_intance import WorkflowInstance
from cadence.api.v1.decision_pb2 import Decision
from cadence.api.v1.common_pb2 import Payload
from cadence.api.v1.decision_pb2 import (
CompleteWorkflowExecutionDecisionAttributes,
Decision,
)
from cadence.api.v1.history_pb2 import WorkflowExecutionStartedEventAttributes
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse
from cadence.workflow import WorkflowDefinition, WorkflowInfo

Expand All @@ -23,12 +24,13 @@ class DecisionResult:

class WorkflowEngine:
def __init__(self, info: WorkflowInfo, workflow_definition: WorkflowDefinition):
    """Set up the engine for a single workflow execution.

    Args:
        info: Workflow execution metadata (ids, type, data converter).
        workflow_definition: The user-defined workflow to drive.
    """
    self._workflow_instance = WorkflowInstance(
        workflow_definition, info.data_converter
    )
    # TODO: remove this stateful object and use the context instead
    self._decision_manager = DecisionManager()
    self._context = Context(info, self._decision_manager)

def process_decision(
self, decision_task: PollForDecisionTaskResponse
Expand All @@ -46,54 +48,58 @@ def process_decision(
DecisionResult containing the list of decisions
"""
try:
# Log decision task processing start with full context (matches Java ReplayDecisionTaskHandler)
logger.info(
"Processing decision task for workflow",
extra={
"workflow_type": self._context.info().workflow_type,
"workflow_id": self._context.info().workflow_id,
"run_id": self._context.info().workflow_run_id,
"started_event_id": decision_task.started_event_id,
"attempt": decision_task.attempt,
},
)

# Activate workflow context for the entire decision processing
with self._context._activate():
with self._context._activate() as ctx:
# Log decision task processing start with full context (matches Java ReplayDecisionTaskHandler)
logger.info(
"Processing decision task for workflow",
extra={
"workflow_type": ctx.info().workflow_type,
"workflow_id": ctx.info().workflow_id,
"run_id": ctx.info().workflow_run_id,
"started_event_id": decision_task.started_event_id,
"attempt": decision_task.attempt,
},
)

# Create DecisionEventsIterator for structured event processing
events_iterator = DecisionEventsIterator(
decision_task, self._context.info().workflow_events
decision_task, ctx.info().workflow_events
)

# Process decision events using iterator-driven approach
self._process_decision_events(events_iterator, decision_task)
self._process_decision_events(ctx, events_iterator, decision_task)

# Collect all pending decisions from state machines
decisions = self._decision_manager.collect_pending_decisions()

# Log decision task completion with metrics (matches Java ReplayDecisionTaskHandler)
logger.debug(
"Decision task completed",
extra={
"workflow_type": self._context.info().workflow_type,
"workflow_id": self._context.info().workflow_id,
"run_id": self._context.info().workflow_run_id,
"started_event_id": decision_task.started_event_id,
"decisions_count": len(decisions),
"replay_mode": self._context.is_replay_mode(),
},
)
# complete workflow if it is done
try:
if self._workflow_instance.is_done():
result = self._workflow_instance.get_result()
decisions.append(
Decision(
complete_workflow_execution_decision_attributes=CompleteWorkflowExecutionDecisionAttributes(
result=result
)
)
)
return DecisionResult(decisions=decisions)

return DecisionResult(decisions=decisions)
except Exception:
# TODO: handle CancellationError
# TODO: handle WorkflowError
# TODO: handle unknown error, fail decision task and try again instead of breaking the engine
raise

except Exception as e:
# Log decision task failure with full context (matches Java ReplayDecisionTaskHandler)
logger.error(
"Decision task processing failed",
extra={
"workflow_type": self._context.info().workflow_type,
"workflow_id": self._context.info().workflow_id,
"run_id": self._context.info().workflow_run_id,
"workflow_type": ctx.info().workflow_type,
"workflow_id": ctx.info().workflow_id,
"run_id": ctx.info().workflow_run_id,
"started_event_id": decision_task.started_event_id,
"attempt": decision_task.attempt,
"error_type": type(e).__name__,
Expand All @@ -104,10 +110,11 @@ def process_decision(
raise

def is_done(self) -> bool:
    """Report whether the underlying workflow instance has run to completion."""
    finished: bool = self._workflow_instance.is_done()
    return finished

def _process_decision_events(
self,
ctx: Context,
events_iterator: DecisionEventsIterator,
decision_task: PollForDecisionTaskResponse,
) -> None:
Expand All @@ -131,7 +138,7 @@ def _process_decision_events(
logger.debug(
"Processing decision events batch",
extra={
"workflow_id": self._context.info().workflow_id,
"workflow_id": ctx.info().workflow_id,
"events_count": len(decision_events.get_events()),
"markers_count": len(decision_events.get_markers()),
"replay_mode": decision_events.is_replay(),
Expand All @@ -140,109 +147,55 @@ def _process_decision_events(
)

# Update context with replay information
self._context.set_replay_mode(decision_events.is_replay())
ctx.set_replay_mode(decision_events.is_replay())
if decision_events.replay_current_time_milliseconds:
self._context.set_replay_current_time_milliseconds(
ctx.set_replay_current_time_milliseconds(
decision_events.replay_current_time_milliseconds
)

# Phase 1: Process markers first for deterministic replay
for marker_event in decision_events.get_markers():
try:
logger.debug(
"Processing marker event",
extra={
"workflow_id": self._context.info().workflow_id,
"marker_name": getattr(
marker_event, "marker_name", "unknown"
),
"event_id": getattr(marker_event, "event_id", None),
"replay_mode": self._context.is_replay_mode(),
},
)
# Process through state machines (DecisionsHelper now delegates to DecisionManager)
self._decision_manager.handle_history_event(marker_event)
except Exception as e:
# Warning for unexpected markers (matches Java ClockDecisionContext)
logger.warning(
"Unexpected marker event encountered",
extra={
"workflow_id": self._context.info().workflow_id,
"marker_name": getattr(
marker_event, "marker_name", "unknown"
),
"event_id": getattr(marker_event, "event_id", None),
"error_type": type(e).__name__,
},
exc_info=True,
)

# Phase 2: Process regular events to update workflow state
for event in decision_events.get_events():
try:
logger.debug(
"Processing history event",
extra={
"workflow_id": self._context.info().workflow_id,
"event_type": getattr(event, "event_type", "unknown"),
"event_id": getattr(event, "event_id", None),
"replay_mode": self._context.is_replay_mode(),
},
)
# Process through state machines (DecisionsHelper now delegates to DecisionManager)
self._decision_manager.handle_history_event(event)
except Exception as e:
logger.warning(
"Error processing history event",
extra={
"workflow_id": self._context.info().workflow_id,
"event_type": getattr(event, "event_type", "unknown"),
"event_id": getattr(event, "event_id", None),
"error_type": type(e).__name__,
},
exc_info=True,
)

# Phase 3: Execute workflow logic
self._execute_workflow_once(decision_task)

def _execute_workflow_once(
self, decision_task: PollForDecisionTaskResponse
) -> None:
"""
Execute the workflow function to generate new decisions.
# Phase 1: Process markers first
for marker_event in decision_events.markers:
logger.debug(
"Processing marker event",
extra={
"workflow_id": ctx.info().workflow_id,
"marker_name": getattr(marker_event, "marker_name", "unknown"),
"event_id": getattr(marker_event, "event_id", None),
"replay_mode": ctx.is_replay_mode(),
},
)
# Process through state machines (DecisionsHelper now delegates to DecisionManager)
self._decision_manager.handle_history_event(marker_event)

This blocks until the workflow schedules an activity or completes.
# Phase 2: Process regular input events
for event in decision_events.input:
logger.debug(
"Processing history event",
extra={
"workflow_id": ctx.info().workflow_id,
"event_type": getattr(event, "event_type", "unknown"),
"event_id": getattr(event, "event_id", None),
"replay_mode": ctx.is_replay_mode(),
},
)
# Process through state machines (DecisionsHelper now delegates to DecisionManager)
self._decision_manager.handle_history_event(event)

Args:
decision_task: The decision task containing workflow context
"""
try:
# Extract workflow input from history
if self._task is None:
workflow_input = self._extract_workflow_input(decision_task)
self._task = self._loop.create_task(
self._workflow_instance.run(workflow_input)
# Phase 3: Execute workflow logic
if not self._workflow_instance.is_started():
self._workflow_instance.start(
self._extract_workflow_input(decision_task)
)

self._loop.run_until_yield()
self._workflow_instance.run_once()

except Exception as e:
logger.error(
"Error executing workflow function",
extra={
"workflow_type": self._context.info().workflow_type,
"workflow_id": self._context.info().workflow_id,
"run_id": self._context.info().workflow_run_id,
"error_type": type(e).__name__,
},
exc_info=True,
)
raise
# Phase 4: update state machine with output events
for event in decision_events.output:
self._decision_manager.handle_history_event(event)

def _extract_workflow_input(
    self, decision_task: PollForDecisionTaskResponse
) -> Payload:
    """Extract the raw workflow input payload from the decision task history.

    Args:
        decision_task: The decision task whose history is searched.

    Returns:
        The input Payload carried by the WorkflowExecutionStarted event.

    Raises:
        ValueError: If the history is missing/empty or contains no
            WorkflowExecutionStarted event.
    """
    if not decision_task.history or not decision_task.history.events:
        raise ValueError("No history events found in decision task")

    # The WorkflowExecutionStarted event is normally the first event, but
    # scan defensively. NOTE: hasattr() is always True for declared proto
    # fields, so it cannot distinguish event types; HasField() checks
    # actual presence of the attributes submessage.
    for event in decision_task.history.events:
        if event.HasField("workflow_execution_started_event_attributes"):
            started_attrs: WorkflowExecutionStartedEventAttributes = (
                event.workflow_execution_started_event_attributes
            )
            return started_attrs.input

    raise ValueError("No WorkflowExecutionStarted event found in history")
Loading
Loading