From f25be2b63c6aac282bde56f7d4b3d8a522bc4525 Mon Sep 17 00:00:00 2001 From: Liz <91279165+lizradway@users.noreply.github.com> Date: Fri, 3 Apr 2026 15:46:00 -0400 Subject: [PATCH 1/2] feat: add metadata field to messages for stateful context tracking (#1532) Attach usage and metrics from model responses directly to assistant messages, enabling downstream features like smart truncation and per-message cost analysis. --- src/strands/agent/agent.py | 2 +- src/strands/event_loop/event_loop.py | 7 + src/strands/event_loop/streaming.py | 3 + src/strands/types/content.py | 27 +++- .../strands/agent/hooks/test_agent_events.py | 13 +- tests/strands/agent/test_agent.py | 16 +- .../strands/agent/test_agent_cancellation.py | 3 +- tests/strands/agent/test_agent_hooks.py | 6 +- tests/strands/event_loop/test_event_loop.py | 17 ++- .../event_loop/test_event_loop_metadata.py | 141 ++++++++++++++++++ tests/strands/types/test_message_metadata.py | 37 +++++ 11 files changed, 250 insertions(+), 22 deletions(-) create mode 100644 tests/strands/event_loop/test_event_loop_metadata.py create mode 100644 tests/strands/types/test_message_metadata.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 37fa5fc00..e8ea3c9bc 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -1025,7 +1025,7 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: # Check if all item in input list are dictionaries elif all(isinstance(item, dict) for item in prompt): # Check if all items are messages - if all(all(key in item for key in Message.__annotations__.keys()) for item in prompt): + if all(all(key in item for key in Message.__required_keys__) for item in prompt): # Messages input - add all messages to conversation messages = cast(Messages, prompt) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index b4af16058..bf1cc7a84 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -354,6 +354,13 @@ async def _handle_model_execution( stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) + # Attach metadata to the assistant message immediately so it's + # available to all downstream consumers (hooks, events, state). + message["metadata"] = { + "usage": usage, + "metrics": metrics, + } + after_model_call_event = AfterModelCallEvent( agent=agent, invocation_state=invocation_state, diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 0a1161135..76eda48bf 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -488,6 +488,9 @@ async def stream_messages( logger.debug("model=<%s> | streaming messages", model) messages = _normalize_messages(messages) + # Whitelist only role and content before sending to the model provider. + # This ensures metadata (and any future non-model fields) never leak to providers. + messages = [Message(role=msg["role"], content=msg["content"]) for msg in messages] start_time = time.time() chunks = model.stream( diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 2b0714bee..7ddd4db9e 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -6,11 +6,12 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Literal +from typing import Any, Literal from typing_extensions import NotRequired, TypedDict from .citations import CitationsContentBlock +from .event_loop import Metrics, Usage from .media import DocumentContent, ImageContent, VideoContent from .tools import ToolResult, ToolUse @@ -177,17 +178,41 @@ class ContentBlockStop(TypedDict): """ +class MessageMetadata(TypedDict, total=False): + """Optional metadata attached to a message. + + Not sent to model providers — explicitly stripped before model calls. + Persisted alongside the message in session storage. + + Attributes: + usage: Token usage information from the model response. + metrics: Performance metrics from the model response. + custom: Arbitrary user/framework metadata (e.g. compression provenance). + """ + + usage: Usage + metrics: Metrics + custom: dict[str, Any] + + class Message(TypedDict): """A message in a conversation with the agent. Attributes: content: The message content. role: The role of the message sender. + metadata: Optional metadata, stripped before model calls. """ content: list[ContentBlock] role: Role + metadata: NotRequired[MessageMetadata] Messages = list[Message] """A list of messages representing a conversation.""" + + +def get_message_metadata(message: Message) -> MessageMetadata: + """Get metadata for a message, returning empty dict if not present.""" + return message.get("metadata", {}) diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 02c367ccc..1f09579b0 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -147,6 +147,7 @@ async def test_stream_e2e_success(alist): {"toolUse": {"input": {}, "name": "normal_tool", "toolUseId": "123"}}, ], "role": "assistant", + "metadata": ANY, } }, { @@ -205,6 +206,7 @@ async def test_stream_e2e_success(alist): {"toolUse": {"input": {}, "name": "async_tool", "toolUseId": "1234"}}, ], "role": "assistant", + "metadata": ANY, } }, { @@ -263,6 +265,7 @@ async def test_stream_e2e_success(alist): {"toolUse": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}}, ], "role": "assistant", + "metadata": ANY, } }, { @@ -307,11 +310,11 @@ async def test_stream_e2e_success(alist): }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "end_turn"}}}, - {"message": {"content": [{"text": "I invoked the tools!"}], "role": "assistant"}}, + {"message": {"content": [{"text": "I invoked the tools!"}], "role": "assistant", "metadata": ANY}}, { "result": AgentResult( stop_reason="end_turn", - message={"content": [{"text": "I invoked the tools!"}], "role": "assistant"}, + message={"content": [{"text": "I invoked the tools!"}], "role": "assistant", "metadata": ANY}, metrics=ANY, state={}, ), @@ -371,11 +374,11 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep): }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "guardrail_intervened"}}}, - {"message": {"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"}}, + {"message": {"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant", "metadata": ANY}}, { "result": AgentResult( stop_reason="guardrail_intervened", - message={"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"}, + message={"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant", "metadata": ANY}, metrics=ANY, state={}, ), @@ -442,6 +445,7 @@ async def test_stream_e2e_reasoning_redacted_content(alist): {"text": "Response with redacted reasoning"}, ], "role": "assistant", + "metadata": ANY, } }, { @@ -453,6 +457,7 @@ async def test_stream_e2e_reasoning_redacted_content(alist): {"text": "Response with redacted reasoning"}, ], "role": "assistant", + "metadata": ANY, }, metrics=ANY, state={}, diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 0057c50a3..1e27274a1 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -336,7 +336,7 @@ def test_agent__call__( "stop_reason": result.stop_reason, } exp_result = { - "message": {"content": [{"text": "test text"}], "role": "assistant"}, + "message": {"content": [{"text": "test text"}], "role": "assistant", "metadata": unittest.mock.ANY}, "state": {}, "stop_reason": "end_turn", } @@ -781,6 +781,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, {"text": "value"}, ], + "metadata": unittest.mock.ANY, }, ), unittest.mock.call( @@ -793,6 +794,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, {"text": "value"}, ], + "metadata": unittest.mock.ANY, }, metrics=unittest.mock.ANY, state={}, @@ -817,7 +819,7 @@ async def test_agent__call__in_async_context(mock_model, agent, agenerator): result = agent("test") tru_message = result.message - exp_message = {"content": [{"text": "abc"}], "role": "assistant"} + exp_message = {"content": [{"text": "abc"}], "role": "assistant", "metadata": unittest.mock.ANY} assert tru_message == exp_message @@ -837,7 +839,7 @@ async def test_agent_invoke_async(mock_model, agent, agenerator): result = await agent.invoke_async("test") tru_message = result.message - exp_message = {"content": [{"text": "abc"}], "role": "assistant"} + exp_message = {"content": [{"text": "abc"}], "role": "assistant", "metadata": unittest.mock.ANY} assert tru_message == exp_message @@ -1128,7 +1130,7 @@ async def test_stream_async_multi_modal_input(mock_model, agent, agenerator, ali tru_message = agent.messages exp_message = [ {"content": prompt, "role": "user"}, - {"content": [{"text": "I see text and an image"}], "role": "assistant"}, + {"content": [{"text": "I see text and an image"}], "role": "assistant", "metadata": unittest.mock.ANY}, ] assert tru_message == exp_message @@ -1966,7 +1968,11 @@ def shell(command: str): } # And that it continued to the LLM call - assert agent.messages[-1] == {"content": [{"text": "I invoked a tool!"}], "role": "assistant"} + assert agent.messages[-1] == { + "content": [{"text": "I invoked a tool!"}], + "role": "assistant", + "metadata": unittest.mock.ANY, + } def test_agent_string_system_prompt(): diff --git a/tests/strands/agent/test_agent_cancellation.py b/tests/strands/agent/test_agent_cancellation.py index 6af153f4a..756e96485 100644 --- a/tests/strands/agent/test_agent_cancellation.py +++ b/tests/strands/agent/test_agent_cancellation.py @@ -2,6 +2,7 @@ import asyncio import threading +from unittest.mock import ANY import pytest @@ -31,7 +32,7 @@ async def test_agent_cancel_before_invocation(): result = await agent.invoke_async("Hello") assert result.stop_reason == "cancelled" - assert result.message == {"role": "assistant", "content": [{"text": "Cancelled by user"}]} + assert result.message == {"role": "assistant", "content": [{"text": "Cancelled by user"}], "metadata": ANY} @pytest.mark.asyncio diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 3a40d69a8..2c61ee966 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -173,6 +173,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u message={ "content": [{"toolUse": tool_use}], "role": "assistant", + "metadata": ANY, }, stop_reason="tool_use", ), @@ -199,7 +200,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u agent=agent, invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( - message=mock_model.agent_responses[1], + message={"role": "assistant", "content": [{"text": "I invoked a tool!"}], "metadata": ANY}, stop_reason="end_turn", ), exception=None, @@ -246,6 +247,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m message={ "content": [{"toolUse": tool_use}], "role": "assistant", + "metadata": ANY, }, stop_reason="tool_use", ), @@ -272,7 +274,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m agent=agent, invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( - message=mock_model.agent_responses[1], + message={"role": "assistant", "content": [{"text": "I invoked a tool!"}], "metadata": ANY}, stop_reason="end_turn", ), exception=None, diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index f91f7c2af..871371f5f 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -193,7 +193,7 @@ async def test_event_loop_cycle_text_response( tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} + exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state @@ -225,7 +225,7 @@ async def test_event_loop_cycle_text_response_throttling( tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} + exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state @@ -264,7 +264,7 @@ async def test_event_loop_cycle_exponential_backoff( # Verify the final response assert tru_stop_reason == "end_turn" - assert tru_message == {"role": "assistant", "content": [{"text": "test text"}]} + assert tru_message == {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} assert tru_request_state == {} # Verify that sleep was called with increasing delays @@ -354,7 +354,7 @@ async def test_event_loop_cycle_tool_result( tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} + exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state @@ -389,7 +389,6 @@ async def test_event_loop_cycle_tool_result( }, ], }, - {"role": "assistant", "content": [{"text": "test text"}]}, ], tool_registry.get_all_tool_specs(), "p1", @@ -484,6 +483,7 @@ async def test_event_loop_cycle_stop( } } ], + "metadata": ANY, } exp_request_state = {"stop_event_loop": True} @@ -946,14 +946,14 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, agent=agent, invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( - message={"content": [{"text": "test text"}], "role": "assistant"}, stop_reason="end_turn" + message={"content": [{"text": "test text"}], "role": "assistant", "metadata": ANY}, stop_reason="end_turn" ), exception=None, ) # Final message assert next(events) == MessageAddedEvent( - agent=agent, message={"content": [{"text": "test text"}], "role": "assistant"} + agent=agent, message={"content": [{"text": "test text"}], "role": "assistant", "metadata": ANY} ) @@ -997,6 +997,7 @@ def interrupt_callback(event): }, ], "role": "assistant", + "metadata": ANY, }, }, "interrupts": { @@ -1131,7 +1132,7 @@ async def test_invalid_tool_names_adds_tool_uses(agent, model, alist): # ensure that we got end_turn and not tool_use assert events[-1] == EventLoopStopEvent( stop_reason="end_turn", - message={"content": [{"text": "I invoked a tool!"}], "role": "assistant"}, + message={"content": [{"text": "I invoked a tool!"}], "role": "assistant", "metadata": ANY}, metrics=ANY, request_state={}, ) diff --git a/tests/strands/event_loop/test_event_loop_metadata.py b/tests/strands/event_loop/test_event_loop_metadata.py new file mode 100644 index 000000000..e6fe97f39 --- /dev/null +++ b/tests/strands/event_loop/test_event_loop_metadata.py @@ -0,0 +1,141 @@ +"""Tests for metadata population on assistant messages in the event loop.""" + +import threading +import unittest.mock + +import pytest + +import strands +import strands.event_loop.event_loop +from strands import Agent +from strands.event_loop._retry import ModelRetryStrategy +from strands.hooks import HookRegistry +from strands.interrupt import _InterruptState +from strands.telemetry.metrics import EventLoopMetrics +from strands.tools.executors import SequentialToolExecutor +from strands.tools.registry import ToolRegistry + + +@pytest.fixture +def model(): + return unittest.mock.Mock() + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "Hello"}]}] + + +@pytest.fixture +def hook_registry(): + registry = HookRegistry() + retry_strategy = ModelRetryStrategy() + retry_strategy.register_hooks(registry) + return registry + + +@pytest.fixture +def tool_registry(): + return ToolRegistry() + + +@pytest.fixture +def agent(model, messages, tool_registry, hook_registry): + mock = unittest.mock.Mock(name="agent") + mock.__class__ = Agent + mock.config.cache_points = [] + mock.model = model + mock.system_prompt = "test" + mock.messages = messages + mock.tool_registry = tool_registry + mock.thread_pool = None + mock.event_loop_metrics = EventLoopMetrics() + mock.event_loop_metrics.reset_usage_metrics() + mock.hooks = hook_registry + mock.tool_executor = SequentialToolExecutor() + mock._interrupt_state = _InterruptState() + mock._cancel_signal = threading.Event() + mock.trace_attributes = {} + mock.retry_strategy = ModelRetryStrategy() + return mock + + +@pytest.mark.asyncio +async def test_metadata_populated_on_assistant_message(agent, model, agenerator, alist): + """After a model response, the assistant message should have metadata with usage and metrics.""" + model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "response"}}}, + {"contentBlockStop": {}}, + { + "metadata": { + "usage": {"inputTokens": 42, "outputTokens": 10, "totalTokens": 52}, + "metrics": {"latencyMs": 200}, + } + }, + ] + ) + + stream = strands.event_loop.event_loop.event_loop_cycle(agent=agent, invocation_state={}) + await alist(stream) + + # The assistant message should be in agent.messages + assistant_msg = agent.messages[-1] + assert assistant_msg["role"] == "assistant" + assert "metadata" in assistant_msg + + meta = assistant_msg["metadata"] + assert meta["usage"]["inputTokens"] == 42 + assert meta["usage"]["outputTokens"] == 10 + assert meta["usage"]["totalTokens"] == 52 + assert meta["metrics"]["latencyMs"] == 200 + + +@pytest.mark.asyncio +async def test_metadata_has_default_usage_when_no_metadata_event(agent, model, agenerator, alist): + """When no metadata event is in the stream, metadata should still be set with defaults.""" + model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "response"}}}, + {"contentBlockStop": {}}, + ] + ) + + stream = strands.event_loop.event_loop.event_loop_cycle(agent=agent, invocation_state={}) + await alist(stream) + + assistant_msg = agent.messages[-1] + assert "metadata" in assistant_msg + assert assistant_msg["metadata"]["usage"]["inputTokens"] == 0 + assert assistant_msg["metadata"]["usage"]["outputTokens"] == 0 + assert assistant_msg["metadata"]["metrics"]["latencyMs"] == 0 + + +@pytest.mark.asyncio +async def test_metadata_stripped_before_model_call(agent, model, agenerator, alist): + """Metadata from previous messages should be stripped before sending to the model.""" + # Pre-populate a message with metadata (simulating a previous turn) + agent.messages.append( + { + "role": "assistant", + "content": [{"text": "previous response"}], + "metadata": {"usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}}, + } + ) + agent.messages.append({"role": "user", "content": [{"text": "follow up"}]}) + + model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "response"}}}, + {"contentBlockStop": {}}, + ] + ) + + stream = strands.event_loop.event_loop.event_loop_cycle(agent=agent, invocation_state={}) + await alist(stream) + + # Verify that messages passed to model.stream() have no metadata key + call_args = model.stream.call_args + messages_sent = call_args[0][0] + for msg in messages_sent: + assert "metadata" not in msg, f"metadata leaked to model: {msg}" diff --git a/tests/strands/types/test_message_metadata.py b/tests/strands/types/test_message_metadata.py new file mode 100644 index 000000000..a7f93f690 --- /dev/null +++ b/tests/strands/types/test_message_metadata.py @@ -0,0 +1,37 @@ +"""Tests for MessageMetadata and get_message_metadata.""" + +from strands.types.content import Message, MessageMetadata, get_message_metadata + + +def test_message_without_metadata(): + msg: Message = {"role": "assistant", "content": [{"text": "hello"}]} + assert get_message_metadata(msg) == {} + + +def test_message_with_metadata(): + meta: MessageMetadata = { + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + msg: Message = {"role": "assistant", "content": [{"text": "hello"}], "metadata": meta} + assert get_message_metadata(msg) == meta + assert get_message_metadata(msg)["usage"]["inputTokens"] == 10 + + +def test_message_with_custom_metadata(): + meta: MessageMetadata = { + "custom": {"source": "summarization", "original_turns": [5, 6, 7]}, + } + msg: Message = {"role": "assistant", "content": [{"text": "summary"}], "metadata": meta} + result = get_message_metadata(msg) + assert result["custom"]["source"] == "summarization" + + +def test_metadata_does_not_affect_role_and_content(): + msg: Message = { + "role": "assistant", + "content": [{"text": "hello"}], + "metadata": {"usage": {"inputTokens": 1, "outputTokens": 1, "totalTokens": 2}}, + } + assert msg["role"] == "assistant" + assert msg["content"] == [{"text": "hello"}] From 68f5400920862007849147975ff2263579ebc61e Mon Sep 17 00:00:00 2001 From: Liz <91279165+lizradway@users.noreply.github.com> Date: Tue, 14 Apr 2026 15:50:12 -0400 Subject: [PATCH 2/2] fix: update doc string + add metadata to recover_message_on_max_tokens_reached dict --- .../_recover_message_on_max_tokens_reached.py | 5 +++- src/strands/telemetry/tracer.py | 4 +-- src/strands/types/content.py | 5 +++- ...t_recover_message_on_max_tokens_reached.py | 28 +++++++++++++++++++ .../tools/mcp/test_mcp_client_tasks.py | 8 ++---- 5 files changed, 39 insertions(+), 11 deletions(-) diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py index ab6fb4abe..dc073ba07 100644 --- a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py +++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py @@ -68,4 +68,7 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message: } ) - return {"content": valid_content, "role": message["role"]} + recovered: Message = {"content": valid_content, "role": message["role"]} + if "metadata" in message: + recovered["metadata"] = message["metadata"] + return recovered diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index d5d399f95..37c16d3ae 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -527,9 +527,7 @@ def start_event_loop_cycle_span( event_loop_cycle_id = str(invocation_state.get("event_loop_cycle_id")) parent_span = parent_span if parent_span else invocation_state.get("event_loop_parent_span") - attributes: dict[str, AttributeValue] = self._get_common_attributes( - operation_name="execute_event_loop_cycle" - ) + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="execute_event_loop_cycle") attributes["event_loop.cycle_id"] = event_loop_cycle_id if custom_trace_attributes: diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 7ddd4db9e..8db1d1d98 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -214,5 +214,8 @@ class Message(TypedDict): def get_message_metadata(message: Message) -> MessageMetadata: - """Get metadata for a message, returning empty dict if not present.""" + """Get metadata for a message, returning empty dict if not present. + + Individual fields (usage, metrics, custom) may not be present. Use .get() to safely access them. + """ return message.get("metadata", {}) diff --git a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py index 402e90966..6dff0fc29 100644 --- a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py +++ b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py @@ -224,6 +224,34 @@ def test_recover_message_on_max_tokens_reached_multiple_incomplete_tools(): assert "incomplete due to maximum token limits" in result["content"][2]["text"] +def test_recover_message_on_max_tokens_reached_preserves_metadata(): + """Test that metadata is preserved through recovery.""" + message: Message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "calculator", "input": {}, "toolUseId": "123"}}, + ], + "metadata": {"usage": {"inputTokens": 42, "outputTokens": 10, "totalTokens": 52}}, + } + + result = recover_message_on_max_tokens_reached(message) + + assert "metadata" in result + assert result["metadata"]["usage"]["inputTokens"] == 42 + + +def test_recover_message_on_max_tokens_reached_without_metadata(): + """Test that recovery works fine when no metadata is present.""" + message: Message = { + "role": "assistant", + "content": [{"text": "some text"}], + } + + result = recover_message_on_max_tokens_reached(message) + + assert "metadata" not in result + + def test_recover_message_on_max_tokens_reached_preserves_user_role(): """Test that the function preserves the original message role.""" incomplete_message: Message = { diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py index c21db9e28..d566ac6f5 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -251,9 +251,7 @@ def test_call_tool_sync_forwards_meta_to_task(self, mock_transport, mock_session with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: client.list_tools_sync() - client.call_tool_sync( - tool_use_id="test-id", name="meta_tool", arguments={"param": "value"}, meta=meta - ) + client.call_tool_sync(tool_use_id="test-id", name="meta_tool", arguments={"param": "value"}, meta=meta) experimental.call_tool_as_task.assert_called_once() call_kwargs = experimental.call_tool_as_task.call_args @@ -281,9 +279,7 @@ def test_call_tool_sync_forwards_none_meta_to_task(self, mock_transport, mock_se with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: client.list_tools_sync() - client.call_tool_sync( - tool_use_id="test-id", name="no_meta_tool", arguments={"param": "value"} - ) + client.call_tool_sync(tool_use_id="test-id", name="no_meta_tool", arguments={"param": "value"}) experimental.call_tool_as_task.assert_called_once() call_kwargs = experimental.call_tool_as_task.call_args