diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index 111ef61e3b..f8a85a261c 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -437,6 +437,13 @@ async def _run_core( yield event elif checkpoint_id is not None: + # Restore the prior workflow state from the checkpoint. Shared + # state (e.g. accumulated conversation history maintained by the + # workflow's executors) survives across turns because Workflow.run + # no longer wipes state per call. Callers who want to deliver a + # new user message after restore should make a second + # `workflow.run(message=...)` call - they are NOT mutually + # exclusive on the same instance, but each must be its own call. if streaming: async for event in self.workflow.run( stream=True, diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index d58d9b99a9..51a3312e2b 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -278,7 +278,12 @@ async def restore_from_checkpoint( "Please rebuild the original workflow before resuming." ) - # Restore state + # Restore state. Clear first so import_state (which merges) does + # not leak stale keys from a prior run on this Workflow instance. + # This matters more now that Workflow.run() no longer wipes state + # per call - the only reset point for shared state on a reused + # instance is at restore time. 
+ self._state.clear() self._state.import_state(checkpoint.state) # Restore executor states using the restored state await self._restore_executor_states() diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index c452f62bc2..adea4bda20 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -299,7 +299,7 @@ def get_executors_list(self) -> list[Executor]: async def _run_workflow_with_tracing( self, initial_executor_fn: Callable[[], Awaitable[None]] | None = None, - reset_context: bool = True, + is_continuation: bool = False, streaming: bool = False, function_invocation_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, client_kwargs: Mapping[str, Mapping[str, Any]] | Mapping[str, Any] | None = None, @@ -310,13 +310,19 @@ async def _run_workflow_with_tracing( of external callers to maintain context across different workflow runs. Args: - initial_executor_fn: Optional function to execute initial executor - reset_context: Whether to reset the context for a new run - streaming: Whether to enable streaming mode for agents + initial_executor_fn: Optional function to execute initial executor. + is_continuation: True when this run is a continuation of prior + work (a checkpoint restore or a responses-only replay) rather + than a fresh new turn delivered via the start executor with + ``message=...``. Continuations preserve per-run accounting + (iteration counter and run kwargs) from the prior turn; + fresh-message runs reset them. Shared workflow state is + preserved in both cases. + streaming: Whether to enable streaming mode for agents. function_invocation_kwargs: Optional kwargs to store in State for function - invocations in subagents + invocations in subagents. 
client_kwargs: Optional kwargs to store in State for chat client - invocations in subagents + invocations in subagents. Yields: WorkflowEvent: The events generated during the workflow execution. @@ -345,16 +351,26 @@ async def _run_workflow_with_tracing( in_progress = WorkflowEvent.status(WorkflowRunState.IN_PROGRESS) yield in_progress # noqa: RUF070 - # Reset context for a new run if supported - if reset_context: + # Per-run reset for fresh-message runs only. We deliberately + # do NOT clear shared workflow state (`_state.clear()`) or the + # runner context's in-flight messages (`reset_for_new_run()`) + # here - state and pending work persist across `run()` calls + # so that a `WorkflowAgent` can deliver multi-turn input on + # the same instance and have prior turns' context survive. + # Iteration counting and per-run kwargs ARE per-run though, + # so they're reset here. + if not is_continuation: self._runner.reset_iteration_count() - self._runner.context.reset_for_new_run() - self._state.clear() # Store run kwargs in State so executors can access them. - # Only overwrite when new kwargs are explicitly provided or state was - # just cleared (fresh run). On continuation (reset_context=False) with - # no new kwargs, preserve the kwargs from the original run. + # Per-run kwargs semantics: + # - On a fresh message run, prior kwargs go away (set to {} + # by default, or to the new kwargs if provided). This + # prevents stale kwargs from a prior turn leaking into the + # current turn. + # - On a continuation (checkpoint restore or responses), the + # prior run's kwargs are preserved unless the caller + # explicitly provides new kwargs. 
if function_invocation_kwargs is not None or client_kwargs is not None: combined_kwargs: dict[str, Any] = {} if function_invocation_kwargs is not None: @@ -366,11 +382,12 @@ async def _run_workflow_with_tracing( client_kwargs, "client_kwargs" ) self._state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs) - elif reset_context: + elif not is_continuation: self._state.set(WORKFLOW_RUN_KWARGS_KEY, {}) self._state.commit() # Commit immediately so kwargs are available - # Set streaming mode after reset + # Set streaming mode (always set explicitly per run since + # reset_for_new_run() no longer runs to clear it). self._runner_context.set_streaming(streaming) # Execute initial setup if provided @@ -585,13 +602,33 @@ async def _run_core( if checkpoint_storage is not None: self._runner.context.set_runtime_checkpoint_storage(checkpoint_storage) - initial_executor_fn, reset_context = self._resolve_execution_mode( + # Async validation: a fresh-message run is only allowed when the + # runner context has fully drained from any prior run. If it still + # has in-flight executor messages, the prior run didn't complete - + # the caller must either resume from a checkpoint or wait for the + # prior run to drain. (Pending request_info events are intentionally + # NOT blocked here: a follow-up run with message=... is the normal + # way to deliver a response to those pending requests, e.g. via + # WorkflowAgent._process_pending_requests.) + # NOTE: _validate_run_params already enforces that ``message`` is + # mutually exclusive with both ``checkpoint_id`` and ``responses``, + # so we don't need to re-check those here. + if message is not None and await self._runner.context.has_messages(): + raise RuntimeError( + "Cannot start a new run with 'message' while in-flight executor " + "messages remain from a prior run. Resume from a checkpoint " + "(checkpoint_id=...) or wait for the prior run to complete. 
" + "Workflows that need to recover from a mid-run failure must use " + "checkpointing; there is no in-process recovery path." + ) + + initial_executor_fn = self._resolve_execution_mode( message, responses, checkpoint_id, checkpoint_storage ) async for event in self._run_workflow_with_tracing( initial_executor_fn=initial_executor_fn, - reset_context=reset_context, + is_continuation=(message is None), streaming=streaming, function_invocation_kwargs=function_invocation_kwargs, client_kwargs=client_kwargs, @@ -674,12 +711,8 @@ def _resolve_execution_mode( responses: Mapping[str, Any] | None, checkpoint_id: str | None, checkpoint_storage: CheckpointStorage | None, - ) -> tuple[Callable[[], Awaitable[None]], bool]: - """Determine the initial executor function and reset_context flag based on parameters. - - Returns: - A tuple of (initial_executor_fn, reset_context). - """ + ) -> Callable[[], Awaitable[None]]: + """Determine the initial executor function based on parameters.""" if responses is not None: if checkpoint_id is not None: # Combined: restore checkpoint then send responses @@ -689,13 +722,11 @@ def _resolve_execution_mode( else: # Send responses only (requires pending requests in workflow state) initial_executor_fn = functools.partial(self._send_responses_internal, responses) - return initial_executor_fn, False + return initial_executor_fn # Regular run or checkpoint restoration - initial_executor_fn = functools.partial( + return functools.partial( self._execute_with_message_or_checkpoint, message, checkpoint_id, checkpoint_storage ) - reset_context = message is not None and checkpoint_id is None - return initial_executor_fn, reset_context async def _restore_and_send_responses( self, diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index f338ce94f6..30e81d8fe6 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ 
-488,8 +488,13 @@ async def handle_message( await ctx.yield_output(existing_messages.copy()) # type: ignore -async def test_workflow_multiple_runs_no_state_collision(): - """Test that running the same workflow instance multiple times doesn't have state collision.""" +async def test_workflow_multiple_runs_preserve_state(): + """Test that running the same workflow instance multiple times preserves shared state. + + State preservation is the new default - calling ``Workflow.run`` repeatedly + on the same instance behaves like a chat agent maintaining memory across + turns. Callers that want fresh state should rebuild the Workflow. + """ with tempfile.TemporaryDirectory() as temp_dir: storage = FileCheckpointStorage(temp_dir) @@ -503,29 +508,45 @@ async def test_workflow_multiple_runs_no_state_collision(): .build() ) - # Run 1: Should only see messages from run 1 + # Run 1: Single record from run 1 result1 = await workflow.run(StateTrackingMessage(data="message1", run_id="run1")) assert result1.get_final_state() == WorkflowRunState.IDLE outputs1 = result1.get_outputs() assert outputs1[0] == ["run1:message1"] - # Run 2: Should only see messages from run 2, not run 1 + # Run 2: State from run 1 persists; run 2's record appends. result2 = await workflow.run(StateTrackingMessage(data="message2", run_id="run2")) assert result2.get_final_state() == WorkflowRunState.IDLE outputs2 = result2.get_outputs() - assert outputs2[0] == ["run2:message2"] # Should NOT contain run1 data + assert outputs2[0] == ["run1:message1", "run2:message2"] - # Run 3: Should only see messages from run 3 + # Run 3: Same - all three accumulate. 
result3 = await workflow.run(StateTrackingMessage(data="message3", run_id="run3")) assert result3.get_final_state() == WorkflowRunState.IDLE outputs3 = result3.get_outputs() - assert outputs3[0] == ["run3:message3"] # Should NOT contain run1 or run2 data + assert outputs3[0] == ["run1:message1", "run2:message2", "run3:message3"] + + +async def test_workflow_multiple_runs_no_state_collision_after_rebuild(): + """Rebuilding the Workflow gives a fresh shared-state slate.""" + with tempfile.TemporaryDirectory() as temp_dir: + storage = FileCheckpointStorage(temp_dir) + + def _build(): + executor = StateTrackingExecutor(id="state_executor") + return ( + WorkflowBuilder(start_executor=executor, checkpoint_storage=storage) + .add_edge(executor, executor) + .build() + ) - # Verify that each run only processed its own message - # This confirms that the checkpointable context properly resets between runs - assert outputs1[0] != outputs2[0] - assert outputs2[0] != outputs3[0] - assert outputs1[0] != outputs3[0] + wf1 = _build() + result1 = await wf1.run(StateTrackingMessage(data="message1", run_id="run1")) + assert result1.get_outputs()[0] == ["run1:message1"] + + wf2 = _build() + result2 = await wf2.run(StateTrackingMessage(data="message2", run_id="run2")) + assert result2.get_outputs()[0] == ["run2:message2"] async def test_workflow_checkpoint_runtime_only_configuration( @@ -932,6 +953,31 @@ async def test_agent_streaming_vs_non_streaming() -> None: assert accumulated_text == "Hello World", f"Expected 'Hello World', got '{accumulated_text}'" +async def test_workflow_run_inflight_messages_guard(simple_executor: Executor) -> None: + """``run(message=...)`` must reject in-flight executor messages from a prior run. + + Workflows preserve state and pending messages across :meth:`Workflow.run` + calls. If a prior run aborted before the runner drained those pending + messages (e.g. 
it raised :class:`WorkflowConvergenceException`), the next + fresh-message call should fail loudly instead of silently mixing the + leftover messages with the new turn. The supported recovery path is to + resume from a checkpoint; there is no in-process recovery hatch. + """ + workflow = WorkflowBuilder(start_executor=simple_executor).add_edge(simple_executor, simple_executor).build() + test_message = WorkflowMessage(data="test", source_id="test", target_id=None) + + # Simulate an aborted prior run by leaving a message in the runner context. + workflow._runner.context._messages["test"] = [test_message] + assert await workflow._runner.context.has_messages() + + with pytest.raises(RuntimeError, match="in-flight executor messages"): + await workflow.run(test_message) + + with pytest.raises(RuntimeError, match="in-flight executor messages"): + async for _ in workflow.run(test_message, stream=True): + pass + + async def test_workflow_run_parameter_validation(simple_executor: Executor) -> None: """Test that stream properly validate parameter combinations.""" workflow = WorkflowBuilder(start_executor=simple_executor).add_edge(simple_executor, simple_executor).build() @@ -942,13 +988,15 @@ async def test_workflow_run_parameter_validation(simple_executor: Executor) -> N result = await workflow.run(test_message) assert result.get_final_state() == WorkflowRunState.IDLE - # Invalid: both message and checkpoint_id + # Invalid: message + checkpoint_id (mutually exclusive). Multi-turn + # state preservation is handled by Workflow.run preserving state across + # calls, so the host pattern is two separate calls (restore-then-run), + # not a single combined call. 
with pytest.raises(ValueError, match="Cannot provide both 'message' and 'checkpoint_id'"): - await workflow.run(test_message, checkpoint_id="fake_id") + await workflow.run(test_message, checkpoint_id="some-checkpoint") - # Invalid: both message and checkpoint_id (streaming) with pytest.raises(ValueError, match="Cannot provide both 'message' and 'checkpoint_id'"): - async for _ in workflow.run(test_message, checkpoint_id="fake_id", stream=True): + async for _ in workflow.run(test_message, checkpoint_id="some-checkpoint", stream=True): pass # Invalid: none of message or checkpoint_id diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index e7af9fde9a..c05af09ddc 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -32,10 +32,12 @@ from collections.abc import Mapping from dataclasses import dataclass from decimal import Decimal as _Decimal +from enum import Enum from typing import Any, Literal, cast from agent_framework import ( Executor, + Message, WorkflowContext, ) from agent_framework._workflows._state import State @@ -120,7 +122,20 @@ def _make_powerfx_safe(value: Any) -> Any: Returns: A PowerFx-safe representation of the value """ - if value is None or isinstance(value, _POWERFX_SAFE_TYPES): + if value is None: + return value + + # Enum coercion must run BEFORE the primitive type check: many MAF + # enums (e.g. MessageRole) are ``str``-subclass enums, so they pass + # ``isinstance(v, str)`` but pythonnet refuses to convert them to + # ``System.String`` and raises ``'MessageRole' value cannot be + # converted to System.<T>'`` for every PowerFx primitive type. Reduce + # to the underlying value (or its string form) so PowerFx sees a + # plain ``str``/``int``. 
+ if isinstance(value, Enum): + return _make_powerfx_safe(value.value) + + if isinstance(value, _POWERFX_SAFE_TYPES): return value if isinstance(value, dict): @@ -197,6 +212,16 @@ def get_state_data(self) -> DeclarativeStateData: result = self._state.get(DECLARATIVE_STATE_KEY) return cast(DeclarativeStateData, result) + def is_initialized(self) -> bool: + """Return True when declarative state has been initialized. + + Useful for distinguishing a fresh start from a continuation: when + Workflow state preserves data across run() calls (multi-turn + scenarios), the start executor needs to avoid calling initialize() + and clobbering the prior turn's Conversation/Local/System data. + """ + return self._state.get(DECLARATIVE_STATE_KEY) is not None + def set_state_data(self, data: DeclarativeStateData) -> None: """Set the full state data dict in state.""" self._state.set(DECLARATIVE_STATE_KEY, data) @@ -873,6 +898,20 @@ async def _ensure_state_initialized( Follows .NET's DefaultTransform pattern - accepts any input type: - dict/Mapping: Used directly as workflow.inputs - str: Converted to {"input": value} + - list[Message]: Treated as the agent-facing message contract + (e.g. from WorkflowAgent / as_agent()). The prior conversation + history is stored in ``Conversation.messages``/ + ``Conversation.history`` and mirrored to + ``System.conversations.{id}.messages`` so workflows that + reference ``=Conversation.messages`` (e.g. InvokeAzureAgent) see + assistant turns and other earlier messages, including non-text + content. At the start of a turn this history excludes the current + user message; that message's text is instead used as the string + input (``Inputs.input``) and surfaced via ``System.LastMessage*`` + for backward compatibility with simple text-only workflows. Agent + executors are responsible for appending the current user message + to ``Conversation.messages`` immediately before invoking the + inner agent. 
- DeclarativeMessage: Internal message, no initialization needed - Any other type: Converted via str() to {"input": str(value)} @@ -888,6 +927,104 @@ async def _ensure_state_initialized( if isinstance(trigger, dict): # Structured inputs - use directly state.initialize(trigger) # type: ignore + elif isinstance(trigger, list) and all(isinstance(m, Message) for m in trigger): # pyright: ignore[reportUnknownVariableType] + # list[Message] (e.g. from WorkflowAgent / as_agent()). + messages_list = cast(list[Message], trigger) + + # Detect continuation: if the workflow's shared state already + # carries declarative data from a prior turn (because the host + # restored a checkpoint and dispatched this run with + # reset_context=False), we MUST NOT call state.initialize() - + # that would wipe Conversation.messages, Local.*, System.* etc. + # Instead, treat the trigger as the new turn's user input only: + # update Inputs.input, append the new user message to existing + # Conversation history, and refresh System.LastMessage*. + # + # Continuation = declarative state already exists in the workflow's + # shared state (either left over in-memory from a prior turn on + # the same instance, or restored from a checkpoint just before + # this run). In that case state.initialize() would wipe Local.*, + # System.*, Conversation.* etc., destroying the cross-turn + # context we're trying to preserve. + is_continuation = state.is_initialized() + + # Locate the trailing user message in the trigger. 
+ last_user_index = -1 + for idx in range(len(messages_list) - 1, -1, -1): + if str(messages_list[idx].role).lower() == "user": + last_user_index = idx + break + + if last_user_index >= 0: + last_user_msg = messages_list[last_user_index] + last_user_text = last_user_msg.text or "" + last_user_id = getattr(last_user_msg, "message_id", "") or "" + history_messages = ( + messages_list[:last_user_index] + messages_list[last_user_index + 1:] + ) + else: + history_messages = list(messages_list) + tail = messages_list[-1] if messages_list else None + last_user_text = (tail.text or "") if tail is not None else "" + last_user_id = ( + getattr(tail, "message_id", "") or "" if tail is not None else "" + ) + + if is_continuation: + # Continuation turn: keep prior Conversation.messages intact. + # Refresh inputs and surface the new user message via the + # System.LastMessage* fields. We deliberately do NOT append + # the new user message to Conversation.messages here: agent + # executors append the live user input themselves before + # invoking the inner agent (matching the first-turn + # contract where Conversation.messages holds prior turns + # only). + # + # Note: ``state.set("Inputs.input", ...)`` would route to + # the Custom namespace (Inputs is not a recognized top-level + # writable namespace - see DeclarativeWorkflowState.set). + # PowerFx expressions like ``=Workflow.Inputs.input`` / + # ``=inputs.input`` read state_data["Inputs"] directly, so + # we update that dict in place via get_state_data / + # set_state_data. + state_data = state.get_state_data() + inputs_dict = state_data.get("Inputs") + if not isinstance(inputs_dict, dict): + inputs_dict = {} + state_data["Inputs"] = inputs_dict + inputs_dict["input"] = last_user_text + state.set_state_data(state_data) + # Trailing non-user messages (e.g. tool results) sandwiched + # before the new user message in the trigger are still + # appended so later actions see them. 
+ for msg in history_messages: + state.append("Conversation.messages", msg) + state.append("Conversation.history", msg) + conversation_id = state.get("System.ConversationId") + if conversation_id: + conv_path = f"System.conversations.{conversation_id}.messages" + for msg in history_messages: + state.append(conv_path, msg) + state.set("System.LastMessage", {"Text": last_user_text, "Id": last_user_id}) + state.set("System.LastMessageText", last_user_text) + state.set("System.LastMessageId", last_user_id) + else: + # First turn: full initialization. + state.initialize({"input": last_user_text}) + + for msg in history_messages: + state.append("Conversation.messages", msg) + state.append("Conversation.history", msg) + + conversation_id = state.get("System.ConversationId") + if conversation_id: + conv_path = f"System.conversations.{conversation_id}.messages" + for msg in history_messages: + state.append(conv_path, msg) + + state.set("System.LastMessage", {"Text": last_user_text, "Id": last_user_id}) + state.set("System.LastMessageText", last_user_text) + state.set("System.LastMessageId", last_user_id) elif isinstance(trigger, str): # String input - wrap in dict and populate System.LastMessage.Text # so YAML expressions like =System.LastMessage.Text see the user input @@ -895,10 +1032,11 @@ async def _ensure_state_initialized( state.set("System.LastMessage", {"Text": trigger, "Id": ""}) state.set("System.LastMessageText", trigger) elif not isinstance( - trigger, (ActionTrigger, ActionComplete, ConditionResult, LoopIterationResult, LoopControl) + trigger, + (ActionTrigger, ActionComplete, ConditionResult, LoopIterationResult, LoopControl), # pyright: ignore[reportUnknownArgumentType] ): # Any other type - convert to string like .NET's DefaultTransform - input_str = str(trigger) + input_str = str(cast(Any, trigger)) state.initialize({"input": input_str}) state.set("System.LastMessage", {"Text": input_str, "Id": ""}) state.set("System.LastMessageText", input_str) diff --git 
a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_control_flow.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_control_flow.py index 0aa660b3ea..f5baf80a9d 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_control_flow.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_control_flow.py @@ -17,6 +17,7 @@ from typing import Any, cast from agent_framework import ( + Message, WorkflowContext, handler, ) @@ -492,7 +493,13 @@ class JoinExecutor(DeclarativeActionExecutor): @handler async def handle_action( self, - trigger: dict[str, Any] | str | ActionTrigger | ActionComplete | ConditionResult | LoopIterationResult, + trigger: dict[str, Any] + | str + | list[Message] + | ActionTrigger + | ActionComplete + | ConditionResult + | LoopIterationResult, ctx: WorkflowContext[ActionComplete], ) -> None: """Simply pass through to continue the workflow.""" diff --git a/python/packages/declarative/tests/test_powerfx_safe.py b/python/packages/declarative/tests/test_powerfx_safe.py new file mode 100644 index 0000000000..fccbd72b28 --- /dev/null +++ b/python/packages/declarative/tests/test_powerfx_safe.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Regression tests for ``_make_powerfx_safe``. + +PowerFx (via pythonnet) only accepts plain primitives, dicts, and lists. +``Enum`` instances - especially ``str``- and ``int``-subclass enums like +MAF's ``MessageRole`` - silently pass ``isinstance(v, str)`` / +``isinstance(v, int)`` checks but blow up later inside pythonnet with +``'<Enum>' value cannot be converted to System.<T>``. These tests +pin down the Enum coercion branch so we don't regress that interop fix. 
+""" + +from enum import Enum, IntEnum + +from agent_framework_declarative._workflows._declarative_base import _make_powerfx_safe + + +class _StrRole(str, Enum): + USER = "user" + SYSTEM = "system" + + +class _IntCode(IntEnum): + ONE = 1 + TWO = 2 + + +class _PlainEnum(Enum): + X = "x" + Y = 42 + + +def test_str_subclass_enum_reduces_to_str(): + assert _make_powerfx_safe(_StrRole.USER) == "user" + assert type(_make_powerfx_safe(_StrRole.USER)) is str + + +def test_int_subclass_enum_reduces_to_int(): + assert _make_powerfx_safe(_IntCode.ONE) == 1 + assert type(_make_powerfx_safe(_IntCode.ONE)) is int + + +def test_plain_enum_reduces_to_underlying_value(): + assert _make_powerfx_safe(_PlainEnum.X) == "x" + assert _make_powerfx_safe(_PlainEnum.Y) == 42 + + +def test_enum_inside_dict_is_coerced(): + safe = _make_powerfx_safe({"role": _StrRole.USER, "code": _IntCode.TWO}) + assert safe == {"role": "user", "code": 2} + assert type(safe["role"]) is str + assert type(safe["code"]) is int + + +def test_enum_inside_list_is_coerced(): + safe = _make_powerfx_safe([_StrRole.USER, _IntCode.ONE]) + assert safe == ["user", 1] + assert type(safe[0]) is str + assert type(safe[1]) is int diff --git a/python/packages/declarative/tests/test_workflow_factory.py b/python/packages/declarative/tests/test_workflow_factory.py index f08f5993e5..809747a037 100644 --- a/python/packages/declarative/tests/test_workflow_factory.py +++ b/python/packages/declarative/tests/test_workflow_factory.py @@ -228,6 +228,98 @@ async def test_entry_join_executor_initializes_workflow_inputs_string(self): outputs = result.get_outputs() assert any("hello-world" in str(o) for o in outputs), f"Expected 'hello-world' in outputs but got: {outputs}" + async def test_as_agent_round_trip_with_last_message_text(self): + """Regression test: a declarative workflow built via WorkflowFactory must be + consumable as an AIAgent via Workflow.as_agent(). 
+ + Specifically, the declarative start executor must accept list[Message] + (the input passed by WorkflowAgent) and populate System.LastMessageText + so =System.LastMessageText is resolvable in the YAML. + """ + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml(""" +name: as-agent-roundtrip-test +actions: + - kind: SetVariable + variable: Local.echo + value: =System.LastMessageText + - kind: SendActivity + activity: + text: =Local.echo +""") + + agent = workflow.as_agent(name="echo-agent") + response = await agent.run("Hello there") + + assert "Hello there" in response.text, ( + f"Expected 'Hello there' in agent response text but got: {response.text!r}" + ) + + async def test_as_agent_continuation_preserves_prior_state(self): + """Regression test for the ``is_continuation`` branch in + ``DeclarativeWorkflowExecutor._ensure_state_initialized``. + + Verifies, end-to-end via ``Workflow.as_agent()``: + * Turn 1 initializes the declarative state via ``state.initialize``. + * Turn 2 takes the *continuation* branch (skips ``state.initialize``), + so any non-Inputs/non-System state stamped on turn 1 survives. + * Turn 2 still refreshes ``Inputs.input`` and + ``System.LastMessage*`` to the new user message. + + Without state preservation, ``Workflow.run`` would clear shared state + on entry and ``state.initialize`` would re-run on every turn, + wiping the marker we stamped between calls. + """ + from agent_framework_declarative._workflows._declarative_base import DECLARATIVE_STATE_KEY + + factory = WorkflowFactory() + workflow = factory.create_workflow_from_yaml(""" +name: as-agent-continuation-test +actions: + - kind: SendActivity + activity: + text: =System.LastMessageText +""") + + agent = workflow.as_agent(name="continuation-agent") + + first = await agent.run("turn-1-msg") + assert first.text == "turn-1-msg", ( + f"Expected turn-1 echo 'turn-1-msg', got: {first.text!r}" + ) + + # Stamp a marker into the declarative state between turns. 
The + # continuation branch must preserve it; a state-clearing run would + # wipe ``DECLARATIVE_STATE_KEY`` and force re-initialization. + state_data = workflow._state.get(DECLARATIVE_STATE_KEY) + assert isinstance(state_data, dict), ( + "Expected declarative state to be initialized after turn 1" + ) + state_data["Local"] = {"persisted_marker": "kept-from-turn-1"} + workflow._state.set(DECLARATIVE_STATE_KEY, state_data) + workflow._state.commit() + + second = await agent.run("turn-2-msg") + assert second.text == "turn-2-msg", ( + f"Expected System.LastMessageText to refresh to 'turn-2-msg', got: {second.text!r}" + ) + + # The continuation branch in ``_ensure_state_initialized`` must: + # 1. preserve the cross-turn marker we stamped above + # 2. refresh Inputs.input and System.LastMessage* to the new turn + post_state = workflow._state.get(DECLARATIVE_STATE_KEY) + assert isinstance(post_state, dict), "declarative state vanished between turns" + local = post_state.get("Local", {}) + assert local.get("persisted_marker") == "kept-from-turn-1", ( + f"Cross-turn marker was wiped (state was reset). 
post_state Local={local!r}" + ) + assert post_state.get("Inputs", {}).get("input") == "turn-2-msg", ( + f"Inputs.input not refreshed on turn 2: {post_state.get('Inputs')!r}" + ) + assert post_state.get("System", {}).get("LastMessageText") == "turn-2-msg", ( + f"System.LastMessageText not refreshed on turn 2: {post_state.get('System')!r}" + ) + class TestWorkflowFactoryAgentRegistration: """Tests for agent registration.""" diff --git a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py index be04a1f397..7da9e8413a 100644 --- a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py +++ b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py @@ -272,50 +272,86 @@ async def _handle_inner_workflow( if not isinstance(self._agent, WorkflowAgent): raise RuntimeError("Agent is not a workflow agent.") - # Restore from the latest checkpoint if available, otherwise start with an empty history + # Determine the latest checkpoint (if any) so we can resume the + # workflow's prior state for this turn. The directory is keyed by + # the inbound context id (conversation_id when set, otherwise + # previous_response_id). Multi-turn declarative workflows need the + # workflow's internal state (e.g. Conversation.messages, + # intermediate Local.* variables) to survive across user turns; + # the only place that state lives is the workflow checkpoint, so + # on every turn we restore the latest checkpoint and feed the new + # input back into the start executor as a continuation rather than + # a fresh run. 
+ latest_checkpoint_id: str | None = None + restore_storage: FileCheckpointStorage | None = None if context_id is not None: - checkpoint_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id)) - latest_checkpoint = await checkpoint_storage.get_latest(workflow_name=self._agent.workflow.name) + restore_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id)) + latest_checkpoint = await restore_storage.get_latest(workflow_name=self._agent.workflow.name) if latest_checkpoint is not None: - if not is_streaming_request: - _ = await self._agent.run( - stream=False, - checkpoint_id=latest_checkpoint.checkpoint_id, - checkpoint_storage=checkpoint_storage, - ) - else: - # Consume the streaming or the invocation will result in a no-op - async for _ in self._agent.run( - stream=True, - checkpoint_id=latest_checkpoint.checkpoint_id, - checkpoint_storage=checkpoint_storage, - ): - pass + latest_checkpoint_id = latest_checkpoint.checkpoint_id + + # Storage that will receive checkpoints written during this turn. + # When the caller chains with previous_response_id, the next turn + # will reference the current response_id as its previous_response_id, + # so new checkpoints must land under the current response_id (or the + # conversation_id when set). When conversation_id is set, this + # matches restore_storage; when only previous_response_id was + # supplied, restore_storage points at the *prior* response's + # directory and write_storage points at the *current* response's. + write_context_id = context.conversation_id or context.response_id + write_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, write_context_id)) + + # Multi-turn pattern: when we have a prior checkpoint, restore it + # first (drive the workflow back to idle with prior state intact), + # then make a separate call that delivers the new user input. This + # depends on Workflow.run preserving shared state across calls. 
The + # restore-only call may yield events from any pending in-flight + # work in the checkpoint; we consume those internally here so they + # don't surface to the response stream as duplicates. + # + # If the restored checkpoint had pending request_info events, the + # restore-only call replays them through + # ``WorkflowAgent._convert_workflow_event_to_agent_response_updates`` + # and populates ``self._agent.pending_requests``. That is the correct + # state: those requests are genuinely outstanding, and the next + # ``run(input_messages, ...)`` call may contain ``function_call_output`` + # items (carried as FunctionResult/FunctionApprovalResponse content) + # that fulfill them via :meth:`WorkflowAgent._process_pending_requests`. + if latest_checkpoint_id is not None: + if is_streaming_request: + async for _ in self._agent.run( + stream=True, + checkpoint_id=latest_checkpoint_id, + checkpoint_storage=restore_storage, + ): + pass + else: + await self._agent.run( + stream=False, + checkpoint_id=latest_checkpoint_id, + checkpoint_storage=restore_storage, + ) # Now run the agent with the latest input response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model) - # Create a new checkpoint storage for this response based on the following rules: - # - If no previous response ID or conversation ID is provided, - # create a new checkpoint storage for this response - # - If a previous response ID is provided, create a new checkpoint storage for this response - # - If a conversation ID is provided, reuse the existing checkpoint storage for the conversation - context_id = context.conversation_id or context.response_id - checkpoint_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id)) - yield response_event_stream.emit_created() yield response_event_stream.emit_in_progress() if not is_streaming_request: - # Run the agent in non-streaming mode - response = await self._agent.run(input_messages, stream=False, 
checkpoint_storage=checkpoint_storage) + # Run the agent in non-streaming mode with the new user input. + response = await self._agent.run( + input_messages, + stream=False, + checkpoint_storage=write_storage, + ) for message in response.messages: for content in message.contents: async for item in _to_outputs(response_event_stream, content): yield item - await self._delete_not_latest_checkpoints(checkpoint_storage, self._agent.workflow.name) + await self._delete_not_latest_checkpoints(write_storage, self._agent.workflow.name) yield response_event_stream.emit_completed() return @@ -323,8 +359,12 @@ async def _handle_inner_workflow( # lazily created on matching content, closed when a different type arrives. tracker = _OutputItemTracker(response_event_stream) - # Run the workflow agent in streaming mode - async for update in self._agent.run(input_messages, stream=True, checkpoint_storage=checkpoint_storage): + # Run the workflow agent in streaming mode with the new user input. + async for update in self._agent.run( + input_messages, + stream=True, + checkpoint_storage=write_storage, + ): for content in update.contents: for event in tracker.handle(content): yield event @@ -337,7 +377,7 @@ async def _handle_inner_workflow( for event in tracker.close(): yield event - await self._delete_not_latest_checkpoints(checkpoint_storage, self._agent.workflow.name) + await self._delete_not_latest_checkpoints(write_storage, self._agent.workflow.name) yield response_event_stream.emit_completed() @staticmethod diff --git a/python/packages/foundry_hosting/pyproject.toml b/python/packages/foundry_hosting/pyproject.toml index f4031f9df3..09c4b6a819 100644 --- a/python/packages/foundry_hosting/pyproject.toml +++ b/python/packages/foundry_hosting/pyproject.toml @@ -24,9 +24,9 @@ classifiers = [ ] dependencies = [ "agent-framework-core>=1.2.1,<2", - "azure-ai-agentserver-core==2.0.0b3", - "azure-ai-agentserver-responses==1.0.0b5", - "azure-ai-agentserver-invocations==1.0.0b3", + 
"azure-ai-agentserver-core>=2.0.0b3,<3", + "azure-ai-agentserver-responses>=1.0.0b5,<2", + "azure-ai-agentserver-invocations>=1.0.0b3,<2", ] [tool.uv] diff --git a/python/uv.lock b/python/uv.lock index 8fb84afa44..e9b9b4b048 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -536,9 +536,9 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "agent-framework-core", editable = "packages/core" }, - { name = "azure-ai-agentserver-core", specifier = "==2.0.0b3" }, - { name = "azure-ai-agentserver-invocations", specifier = "==1.0.0b3" }, - { name = "azure-ai-agentserver-responses", specifier = "==1.0.0b5" }, + { name = "azure-ai-agentserver-core", specifier = ">=2.0.0b3,<3" }, + { name = "azure-ai-agentserver-invocations", specifier = ">=1.0.0b3,<2" }, + { name = "azure-ai-agentserver-responses", specifier = ">=1.0.0b5,<2" }, ] [[package]]