Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions cadence/_internal/workflow/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from contextlib import contextmanager
from datetime import timedelta
from math import ceil
from typing import Optional, Any, Unpack, Type, cast
from typing import Iterator, Optional, Any, Unpack, Type, cast

from cadence._internal.workflow.statemachine.decision_manager import DecisionManager
from cadence._internal.workflow.decisions_helper import DecisionsHelper
Expand All @@ -15,13 +16,12 @@ class Context(WorkflowContext):
def __init__(
self,
info: WorkflowInfo,
decision_helper: DecisionsHelper,
decision_manager: DecisionManager,
):
self._info = info
self._replay_mode = True
self._replay_current_time_milliseconds: Optional[int] = None
self._decision_helper = decision_helper
self._decision_helper = DecisionsHelper()
self._decision_manager = decision_manager

def info(self) -> WorkflowInfo:
Expand Down Expand Up @@ -110,6 +110,12 @@ def get_replay_current_time_milliseconds(self) -> Optional[int]:
"""Get the current replay time in milliseconds."""
return self._replay_current_time_milliseconds

@contextmanager
def _activate(self) -> Iterator["Context"]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overridden so I can capture the activated context rather than expose WorkflowContext directly. Maybe this is not necessary, but I'm trying to avoid direct access to self._context inside WorkflowEngine.

token = WorkflowContext._var.set(self)
yield self
WorkflowContext._var.reset(token)


def _round_to_nearest_second(delta: timedelta) -> timedelta:
return timedelta(seconds=ceil(delta.total_seconds()))
3 changes: 0 additions & 3 deletions cadence/_internal/workflow/decision_events_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from cadence._internal.workflow.history_event_iterator import HistoryEventsIterator
from cadence.api.v1.history_pb2 import HistoryEvent
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse


@dataclass
Expand Down Expand Up @@ -44,10 +43,8 @@ class DecisionEventsIterator(Iterator[DecisionEvents]):

def __init__(
self,
decision_task: PollForDecisionTaskResponse,
events: List[HistoryEvent],
):
self._decision_task = decision_task
self._events: HistoryEventsIterator = HistoryEventsIterator(events)
self._next_decision_event_id: Optional[int] = None
self._replay_current_time_milliseconds: Optional[int] = None
Expand Down
269 changes: 107 additions & 162 deletions cadence/_internal/workflow/workflow_engine.py

Large diffs are not rendered by default.

37 changes: 32 additions & 5 deletions cadence/_internal/workflow/workflow_intance.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,38 @@
from asyncio import Task
from typing import Any, Optional
from cadence._internal.workflow.deterministic_event_loop import DeterministicEventLoop
from cadence.api.v1.common_pb2 import Payload
from cadence.data_converter import DataConverter
from cadence.workflow import WorkflowDefinition


class WorkflowInstance:
def __init__(self, workflow_definition: WorkflowDefinition):
def __init__(
self, workflow_definition: WorkflowDefinition, data_converter: DataConverter
):
self._definition = workflow_definition
self._instance = workflow_definition.cls().__init__()
self._data_converter = data_converter
self._instance = workflow_definition.cls() # construct a new workflow object
self._loop = DeterministicEventLoop()
self._task: Optional[Task] = None

async def run(self, *args):
run_method = self._definition.get_run_method(self._instance)
return run_method(*args)
def start(self, input: Payload):
if self._task is None:
run_method = self._definition.get_run_method(self._instance)
# TODO handle multiple inputs
workflow_input = self._data_converter.from_data(input, [Any])
self._task = self._loop.create_task(run_method(*workflow_input))

def run_once(self):
self._loop.run_until_yield()

def is_done(self) -> bool:
return self._task is not None and self._task.done()

# TODO: consider cache result to avoid multiple data conversions
def get_result(self) -> Payload:
if self._task is None:
raise RuntimeError("Workflow is not started yet")
result = self._task.result()
# TODO: handle result with multiple outputs
return self._data_converter.to_data([result])
3 changes: 1 addition & 2 deletions cadence/worker/_decision_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ async def _handle_task_implementation(
workflow_run_id=run_id,
workflow_task_list=self.task_list,
data_converter=self._client.data_converter,
workflow_events=workflow_events,
)

# Use thread-safe cache to get or create workflow engine
Expand All @@ -136,7 +135,7 @@ async def _handle_task_implementation(
self._workflow_engines[cache_key] = workflow_engine

decision_result = await asyncio.get_running_loop().run_in_executor(
self._executor, workflow_engine.process_decision, task
self._executor, workflow_engine.process_decision, workflow_events
)

# Clean up completed workflows from cache to prevent memory leaks
Expand Down
21 changes: 10 additions & 11 deletions cadence/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from datetime import timedelta
from typing import (
Callable,
List,
cast,
Optional,
Union,
Expand All @@ -15,10 +14,10 @@
Type,
Unpack,
Any,
Generic,
)
import inspect

from cadence.api.v1.history_pb2 import HistoryEvent
from cadence.data_converter import DataConverter

ResultType = TypeVar("ResultType")
Expand All @@ -44,6 +43,7 @@ async def execute_activity(


T = TypeVar("T", bound=Callable[..., Any])
C = TypeVar("C")


class WorkflowDefinitionOptions(TypedDict, total=False):
Expand All @@ -52,16 +52,16 @@ class WorkflowDefinitionOptions(TypedDict, total=False):
name: str


class WorkflowDefinition:
class WorkflowDefinition(Generic[C]):
"""
Definition of a workflow class with metadata.
Similar to ActivityDefinition but for workflow classes.
Provides type safety and metadata for workflow classes.
"""

def __init__(self, cls: Type, name: str, run_method_name: str):
self._cls = cls
def __init__(self, cls: Type[C], name: str, run_method_name: str):
self._cls: Type[C] = cls
self._name = name
self._run_method_name = run_method_name

Expand All @@ -71,7 +71,7 @@ def name(self) -> str:
return self._name

@property
def cls(self) -> Type:
def cls(self) -> Type[C]:
"""Get the workflow class."""
return self._cls

Expand Down Expand Up @@ -151,7 +151,7 @@ def decorator(f: T) -> T:
raise ValueError(f"Workflow run method '{f.__name__}' must be async")

# Attach metadata to the function
f._workflow_run = True # type: ignore
setattr(f, "_workflow_run", None)
return f

# Support both @workflow.run and @workflow.run()
Expand All @@ -163,14 +163,13 @@ def decorator(f: T) -> T:
return decorator(func)


@dataclass
@dataclass(frozen=True)
class WorkflowInfo:
workflow_type: str
workflow_domain: str
workflow_id: str
workflow_run_id: str
workflow_task_list: str
workflow_events: List[HistoryEvent]
data_converter: DataConverter


Expand All @@ -193,9 +192,9 @@ async def execute_activity(
) -> ResultType: ...

@contextmanager
def _activate(self) -> Iterator[None]:
def _activate(self) -> Iterator["WorkflowContext"]:
token = WorkflowContext._var.set(self)
yield None
yield self
WorkflowContext._var.reset(token)

@staticmethod
Expand Down
32 changes: 3 additions & 29 deletions tests/cadence/_internal/workflow/test_decision_events_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
Tests for Decision Events Iterator.
"""

import pytest
from typing import List
import pytest

from cadence.api.v1.history_pb2 import HistoryEvent, History
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse
from cadence.api.v1.common_pb2 import WorkflowExecution

from cadence._internal.workflow.decision_events_iterator import (
DecisionEventsIterator,
)
from cadence.api.v1.history_pb2 import HistoryEvent


class TestDecisionEventsIterator:
Expand Down Expand Up @@ -95,8 +93,7 @@ class TestDecisionEventsIterator:
)
def test_successful_cases(self, name, event_types, expected):
events = create_mock_history_event(event_types)
decision_task = create_mock_decision_task(events)
iterator = DecisionEventsIterator(decision_task, events)
iterator = DecisionEventsIterator(events)

batches = [decision_events for decision_events in iterator]
assert len(expected) == len(batches)
Expand Down Expand Up @@ -166,26 +163,3 @@ def create_mock_history_event(event_types: List[str]) -> List[HistoryEvent]:
events.append(event)

return events


def create_mock_decision_task(
    events: List[HistoryEvent], next_page_token: "bytes | None" = None
) -> PollForDecisionTaskResponse:
    """Create a mock decision task for testing.

    Args:
        events: History events copied into the task's ``history`` field.
        next_page_token: Optional pagination token; only assigned when truthy,
            so an empty token leaves the proto field unset.

    Returns:
        A ``PollForDecisionTaskResponse`` with history and workflow execution
        populated with fixed test identifiers.
    """
    task = PollForDecisionTaskResponse()

    # Mock history: CopyFrom is required because proto message fields
    # cannot be assigned directly.
    history = History()
    history.events.extend(events)
    task.history.CopyFrom(history)

    # Mock workflow execution with fixed test identifiers.
    workflow_execution = WorkflowExecution()
    workflow_execution.workflow_id = "test-workflow"
    workflow_execution.run_id = "test-run"
    task.workflow_execution.CopyFrom(workflow_execution)

    if next_page_token:
        task.next_page_token = next_page_token

    return task
94 changes: 94 additions & 0 deletions tests/cadence/_internal/workflow/test_workflow_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/env python3
from typing import List
import pytest
from cadence.api.v1.common_pb2 import Payload
from cadence.api.v1.history_pb2 import (
DecisionTaskCompletedEventAttributes,
DecisionTaskScheduledEventAttributes,
DecisionTaskStartedEventAttributes,
HistoryEvent,
WorkflowExecutionCompletedEventAttributes,
WorkflowExecutionStartedEventAttributes,
)
from cadence._internal.workflow.workflow_engine import WorkflowEngine
from cadence import workflow
from cadence.data_converter import DefaultDataConverter
from cadence.workflow import WorkflowInfo, WorkflowDefinition, WorkflowDefinitionOptions


class TestWorkflow:
    """Minimal workflow used as a fixture: echoes its input back."""

    @workflow.run
    async def echo(self, input_data):
        # Equivalent to f"echo: {input_data}" — formats the input into the reply.
        return "echo: {}".format(input_data)


class TestWorkflowEngine:
    """Unit tests for WorkflowEngine."""

    @pytest.fixture
    def echo_workflow_definition(self) -> WorkflowDefinition:
        """Create a mock workflow definition."""
        opts = WorkflowDefinitionOptions(name="test_workflow")
        return WorkflowDefinition.wrap(TestWorkflow, opts)

    @pytest.fixture
    def simple_workflow_events(self) -> List[HistoryEvent]:
        """Five-event history: start, decision scheduled/started/completed, completed."""
        started = HistoryEvent(
            event_id=1,
            workflow_execution_started_event_attributes=WorkflowExecutionStartedEventAttributes(
                input=Payload(data=b'"test-input"')
            ),
        )
        scheduled = HistoryEvent(
            event_id=2,
            decision_task_scheduled_event_attributes=DecisionTaskScheduledEventAttributes(),
        )
        task_started = HistoryEvent(
            event_id=3,
            decision_task_started_event_attributes=DecisionTaskStartedEventAttributes(
                scheduled_event_id=2
            ),
        )
        task_completed = HistoryEvent(
            event_id=4,
            decision_task_completed_event_attributes=DecisionTaskCompletedEventAttributes(
                scheduled_event_id=2,
            ),
        )
        completed = HistoryEvent(
            event_id=5,
            workflow_execution_completed_event_attributes=WorkflowExecutionCompletedEventAttributes(
                result=Payload(data=b'"echo: test-input"')
            ),
        )
        return [started, scheduled, task_started, task_completed, completed]

    def test_process_simple_workflow(
        self,
        echo_workflow_definition: WorkflowDefinition,
        simple_workflow_events: List[HistoryEvent],
    ):
        # Feed only the first three events (up to decision-task started);
        # the engine should emit a single complete-workflow decision.
        engine = create_workflow_engine(echo_workflow_definition)
        result = engine.process_decision(simple_workflow_events[:3])
        assert len(result.decisions) == 1
        attrs = result.decisions[0].complete_workflow_execution_decision_attributes
        assert attrs.result == Payload(data=b'"echo: test-input"')


def create_workflow_engine(workflow_definition: WorkflowDefinition) -> WorkflowEngine:
    """Create workflow engine."""
    # Fixed test identifiers; only the data converter and definition matter here.
    info = WorkflowInfo(
        workflow_type="test_workflow",
        workflow_domain="test-domain",
        workflow_id="test-workflow-id",
        workflow_run_id="test-run-id",
        workflow_task_list="test-task-list",
        data_converter=DefaultDataConverter(),
    )
    return WorkflowEngine(info=info, workflow_definition=workflow_definition)
Loading