diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4a7c81672..bf920afd1 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -500,16 +500,19 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An for citation in citations["citations"]: filtered_citation: dict[str, Any] = {} if "location" in citation: - location = citation["location"] - filtered_location = {} + location: dict[str, Any] = cast(dict[str, Any], citation["location"]) + filtered_location: dict[str, Any] = {} # Filter location fields to only include Bedrock-supported ones - if "documentIndex" in location: - filtered_location["documentIndex"] = location["documentIndex"] - if "start" in location: - filtered_location["start"] = location["start"] - if "end" in location: - filtered_location["end"] = location["end"] - filtered_citation["location"] = filtered_location + # Handle web-based citations + if "web" in location: + web_data = location["web"] + filtered_location["web"] = {k: v for k, v in web_data.items() if k in ["url", "domain"]} + # Handle document-based citations + for field in ["documentIndex", "start", "end"]: + if field in location: + filtered_location[field] = location[field] + if filtered_location: + filtered_citation["location"] = filtered_location if "sourceContent" in citation: filtered_source_content: list[dict[str, Any]] = [] for source_content in citation["sourceContent"]: @@ -681,8 +684,12 @@ def _stream( logger.debug("got response from model") if streaming: response = self.client.converse_stream(**request) - # Track tool use events to fix stopReason for streaming responses - has_tool_use = False + # Track tool use/result events to fix stopReason for streaming responses + # We need to distinguish server-side tools (already executed) from client-side tools + tool_use_info: dict[str, str] = {} # toolUseId -> type (e.g., "server_tool_use") + tool_result_ids: set[str] = set() # IDs of tools with results + has_client_tools = False + for chunk in response["stream"]: if ( "metadata" in chunk @@ -694,22 +701,41 @@ def _stream( for event in self._generate_redaction_events(): callback(event) - # Track if we see tool use events - if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"): - has_tool_use = True + # Track tool use events with their types + if "contentBlockStart" in chunk: + tool_use_start = chunk["contentBlockStart"].get("start", {}).get("toolUse") + if tool_use_start: + tool_use_id = tool_use_start.get("toolUseId", "") + tool_type = tool_use_start.get("type", "") + tool_use_info[tool_use_id] = tool_type + # Check if it's a client-side tool (not server_tool_use) + if tool_type != "server_tool_use": + has_client_tools = True + + # Track tool result events (for server-side tools that were already executed) + if "contentBlockStart" in chunk: + tool_result_start = chunk["contentBlockStart"].get("start", {}).get("toolResult") + if tool_result_start: + tool_result_ids.add(tool_result_start.get("toolUseId", "")) # Fix stopReason for streaming responses that contain tool use + # BUT: Only override if there are client-side tools without results if ( - has_tool_use - and "messageStop" in chunk + "messageStop" in chunk and (message_stop := chunk["messageStop"]).get("stopReason") == "end_turn" ): - # Create corrected chunk with tool_use stopReason - modified_chunk = chunk.copy() - modified_chunk["messageStop"] = message_stop.copy() - modified_chunk["messageStop"]["stopReason"] = "tool_use" - logger.warning("Override stop reason from end_turn to tool_use") - callback(modified_chunk) + # Check if we have client-side tools that need execution + needs_execution = has_client_tools and not set(tool_use_info.keys()).issubset(tool_result_ids) + + if needs_execution: + # Create corrected chunk with tool_use stopReason + modified_chunk = chunk.copy() + modified_chunk["messageStop"] = message_stop.copy() + modified_chunk["messageStop"]["stopReason"] = "tool_use" + logger.warning("Override stop reason from end_turn to tool_use") + callback(modified_chunk) + else: + callback(chunk) else: callback(chunk) @@ -771,6 +797,43 @@ def _stream( callback() logger.debug("finished streaming response from model") + def _has_client_side_tools_to_execute(self, message_content: list[dict[str, Any]]) -> bool: + """Check if message contains client-side tools that need execution. + + Server-side tools (like nova_grounding) are executed by Bedrock and include + toolResult blocks in the response. We should NOT override stopReason to + "tool_use" for these tools. + + Args: + message_content: The content array from Bedrock response + + Returns: + True if there are client-side tools without results, False otherwise + """ + tool_use_ids = set() + tool_result_ids = set() + has_client_tools = False + + for content in message_content: + if "toolUse" in content: + tool_use = content["toolUse"] + tool_use_ids.add(tool_use["toolUseId"]) + + # Check if it's a server-side tool (Bedrock executes these) + if tool_use.get("type") != "server_tool_use": + has_client_tools = True + + elif "toolResult" in content: + # Track which tools already have results + tool_result_ids.add(content["toolResult"]["toolUseId"]) + + # Only return True if there are client-side tools without results + if not has_client_tools: + return False + + # Check if all tool uses have corresponding results + return not tool_use_ids.issubset(tool_result_ids) + def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]: """Convert a non-streaming response to the streaming format. @@ -838,23 +901,26 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera for citation in content["citationsContent"]["citations"]: # Then emit citation metadata (for structure) - - citation_metadata: CitationsDelta = { - "title": citation["title"], - "location": citation["location"], - "sourceContent": citation["sourceContent"], - } - yield {"contentBlockDelta": {"delta": {"citation": citation_metadata}}} + citation_metadata: dict[str, Any] = {} + if "title" in citation: + citation_metadata["title"] = citation["title"] + if "location" in citation: + citation_metadata["location"] = citation["location"] + if "sourceContent" in citation: + citation_metadata["sourceContent"] = citation["sourceContent"] + yield {"contentBlockDelta": {"delta": {"citation": cast(CitationsDelta, citation_metadata)}}} # Yield contentBlockStop event yield {"contentBlockStop": {}} # Yield messageStop event # Fix stopReason for models that return end_turn when they should return tool_use on non-streaming side + # BUT: Don't override for server-side tools (like nova_grounding) that are already executed current_stop_reason = response["stopReason"] if current_stop_reason == "end_turn": message_content = response["output"]["message"]["content"] - if any("toolUse" in content for content in message_content): + # Only override if there are client-side tools that need execution + if self._has_client_side_tools_to_execute(message_content): current_stop_reason = "tool_use" logger.warning("Override stop reason from end_turn to tool_use") diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py index b0e28f655..b351229e0 100644 --- a/src/strands/types/citations.py +++ b/src/strands/types/citations.py @@ -3,7 +3,7 @@ These types are modeled after the Bedrock API. """ -from typing import List, Union +from typing import List from typing_extensions import TypedDict @@ -18,67 +18,41 @@ class CitationsConfig(TypedDict): enabled: bool -class DocumentCharLocation(TypedDict, total=False): - """Specifies a character-level location within a document. - - Provides precise positioning information for cited content using - start and end character indices. +class WebLocationDetail(TypedDict, total=False): + """Details of a web-based location. Attributes: - documentIndex: The index of the document within the array of documents - provided in the request. Minimum value of 0. - start: The starting character position of the cited content within - the document. Minimum value of 0. - end: The ending character position of the cited content within - the document. Minimum value of 0. + url: The URL of the web page containing the cited content. + domain: The domain of the web page containing the cited content. """ - documentIndex: int - start: int - end: int + url: str + domain: str -class DocumentChunkLocation(TypedDict, total=False): - """Specifies a chunk-level location within a document. +class CitationLocation(TypedDict, total=False): + """Specifies a location for cited content. - Provides positioning information for cited content using logical - document segments or chunks. + Can represent different types of locations depending on which fields are present: + - Document-based citations: Uses documentIndex, start, and end for character/chunk/page positions + - Web-based citations: Uses web dict with url and domain - Attributes: - documentIndex: The index of the document within the array of documents - provided in the request. Minimum value of 0. - start: The starting chunk identifier or index of the cited content - within the document. Minimum value of 0. - end: The ending chunk identifier or index of the cited content - within the document. Minimum value of 0. - """ - - documentIndex: int - start: int - end: int - - -class DocumentPageLocation(TypedDict, total=False): - """Specifies a page-level location within a document. - - Provides positioning information for cited content using page numbers. + All fields are optional as only the relevant subset will be present based on citation type. Attributes: documentIndex: The index of the document within the array of documents - provided in the request. Minimum value of 0. - start: The starting page number of the cited content within - the document. Minimum value of 0. - end: The ending page number of the cited content within - the document. Minimum value of 0. + provided in the request. Used for document-based citations. + start: The starting position (character, chunk, or page) of the cited content. + Used for document-based citations. + end: The ending position (character, chunk, or page) of the cited content. + Used for document-based citations. + web: Web location details containing URL and domain. Used for web-based citations. """ documentIndex: int start: int end: int - - -# Union type for citation locations -CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation] + web: WebLocationDetail class CitationSourceContent(TypedDict, total=False): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2809e8a72..124458e6c 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2070,3 +2070,171 @@ async def test_stream_backward_compatibility_system_prompt(bedrock_client, model "system": [{"text": system_prompt}], } bedrock_client.converse_stream.assert_called_once_with(**expected_request) + + +def test_format_request_message_content_web_citation(model): + """Test that web citations are correctly filtered to include only url and domain.""" + content = { + "citationsContent": { + "citations": [ + { + "title": "Web Citation Example", + "location": { + "web": { + "url": "https://example.com/article", + "domain": "example.com", + "extraField": "should be filtered out", + } + }, + "sourceContent": [{"text": "Example content"}], + } + ], + "content": [{"text": "Generated text with citation"}], + } + } + + result = model._format_request_message_content(content) + + assert "citationsContent" in result + assert "citations" in result["citationsContent"] + assert len(result["citationsContent"]["citations"]) == 1 + + citation = result["citationsContent"]["citations"][0] + assert citation["title"] == "Web Citation Example" + assert "location" in citation + assert "web" in citation["location"] + # Verify only url and domain are included + assert citation["location"]["web"] == { + "url": "https://example.com/article", + "domain": "example.com", + } + assert "extraField" not in citation["location"]["web"] + + +def test_format_request_message_content_document_citation(model): + """Test that document citations preserve documentIndex, start, and end fields.""" + content = { + "citationsContent": { + "citations": [ + { + "title": "Document Citation Example", + "location": { + "documentIndex": 0, + "start": 100, + "end": 200, + "extraField": "should be filtered out", + }, + "sourceContent": [{"text": "Document excerpt"}], + } + ], + "content": [{"text": "Generated text with document citation"}], + } + } + + result = model._format_request_message_content(content) + + citation = result["citationsContent"]["citations"][0] + assert citation["title"] == "Document Citation Example" + assert citation["location"] == { + "documentIndex": 0, + "start": 100, + "end": 200, + } + + +def test_format_request_message_content_mixed_citations(model): + """Test handling of both web and document citations in the same response.""" + content = { + "citationsContent": { + "citations": [ + { + "title": "Web Source", + "location": { + "web": { + "url": "https://example.com", + "domain": "example.com", + } + }, + "sourceContent": [{"text": "Web content"}], + }, + { + "title": "Document Source", + "location": { + "documentIndex": 1, + "start": 50, + "end": 150, + }, + "sourceContent": [{"text": "Document content"}], + }, + ], + "content": [{"text": "Generated text with multiple citations"}], + } + } + + result = model._format_request_message_content(content) + + citations = result["citationsContent"]["citations"] + assert len(citations) == 2 + + # Verify web citation + assert citations[0]["title"] == "Web Source" + assert "web" in citations[0]["location"] + assert citations[0]["location"]["web"]["url"] == "https://example.com" + + # Verify document citation + assert citations[1]["title"] == "Document Source" + assert citations[1]["location"]["documentIndex"] == 1 + assert citations[1]["location"]["start"] == 50 + + +def test_format_request_message_content_citation_partial_fields(model): + """Test that citations with only some fields present are handled correctly.""" + content = { + "citationsContent": { + "citations": [ + { + "title": "Minimal Citation", + "location": { + "web": { + "url": "https://example.com", + # domain is optional + } + }, + # sourceContent is optional + } + ], + "content": [{"text": "Generated text"}], + } + } + + result = model._format_request_message_content(content) + + citation = result["citationsContent"]["citations"][0] + assert citation["title"] == "Minimal Citation" + assert citation["location"]["web"]["url"] == "https://example.com" + assert "domain" not in citation["location"]["web"] + assert "sourceContent" not in citation + + +def test_format_request_message_content_citation_empty_location(model): + """Test that citations with empty or invalid locations are filtered out.""" + content = { + "citationsContent": { + "citations": [ + { + "title": "Citation without valid location", + "location": {"unknownField": "value"}, + "sourceContent": [{"text": "Some content"}], + } + ], + "content": [{"text": "Generated text"}], + } + } + + result = model._format_request_message_content(content) + + citation = result["citationsContent"]["citations"][0] + # Location should not be included if it has no valid fields + assert "location" not in citation + assert citation["title"] == "Citation without valid location" + assert "sourceContent" in citation