Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 96 additions & 30 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,16 +500,19 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
for citation in citations["citations"]:
filtered_citation: dict[str, Any] = {}
if "location" in citation:
location = citation["location"]
filtered_location = {}
location: dict[str, Any] = cast(dict[str, Any], citation["location"])
filtered_location: dict[str, Any] = {}
# Filter location fields to only include Bedrock-supported ones
if "documentIndex" in location:
filtered_location["documentIndex"] = location["documentIndex"]
if "start" in location:
filtered_location["start"] = location["start"]
if "end" in location:
filtered_location["end"] = location["end"]
filtered_citation["location"] = filtered_location
# Handle web-based citations
if "web" in location:
web_data = location["web"]
filtered_location["web"] = {k: v for k, v in web_data.items() if k in ["url", "domain"]}
# Handle document-based citations
for field in ["documentIndex", "start", "end"]:
if field in location:
filtered_location[field] = location[field]
if filtered_location:
filtered_citation["location"] = filtered_location
if "sourceContent" in citation:
filtered_source_content: list[dict[str, Any]] = []
for source_content in citation["sourceContent"]:
Expand Down Expand Up @@ -681,8 +684,12 @@ def _stream(
logger.debug("got response from model")
if streaming:
response = self.client.converse_stream(**request)
# Track tool use events to fix stopReason for streaming responses
has_tool_use = False
# Track tool use/result events to fix stopReason for streaming responses
# We need to distinguish server-side tools (already executed) from client-side tools
tool_use_info: dict[str, str] = {} # toolUseId -> type (e.g., "server_tool_use")
tool_result_ids: set[str] = set() # IDs of tools with results
has_client_tools = False

for chunk in response["stream"]:
if (
"metadata" in chunk
Expand All @@ -694,22 +701,41 @@ def _stream(
for event in self._generate_redaction_events():
callback(event)

# Track if we see tool use events
if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"):
has_tool_use = True
# Track tool use events with their types
if "contentBlockStart" in chunk:
tool_use_start = chunk["contentBlockStart"].get("start", {}).get("toolUse")
if tool_use_start:
tool_use_id = tool_use_start.get("toolUseId", "")
tool_type = tool_use_start.get("type", "")
tool_use_info[tool_use_id] = tool_type
# Check if it's a client-side tool (not server_tool_use)
if tool_type != "server_tool_use":
has_client_tools = True

# Track tool result events (for server-side tools that were already executed)
if "contentBlockStart" in chunk:
tool_result_start = chunk["contentBlockStart"].get("start", {}).get("toolResult")
if tool_result_start:
tool_result_ids.add(tool_result_start.get("toolUseId", ""))

# Fix stopReason for streaming responses that contain tool use
# BUT: Only override if there are client-side tools without results
if (
has_tool_use
and "messageStop" in chunk
"messageStop" in chunk
and (message_stop := chunk["messageStop"]).get("stopReason") == "end_turn"
):
# Create corrected chunk with tool_use stopReason
modified_chunk = chunk.copy()
modified_chunk["messageStop"] = message_stop.copy()
modified_chunk["messageStop"]["stopReason"] = "tool_use"
logger.warning("Override stop reason from end_turn to tool_use")
callback(modified_chunk)
# Check if we have client-side tools that need execution
needs_execution = has_client_tools and not set(tool_use_info.keys()).issubset(tool_result_ids)

if needs_execution:
# Create corrected chunk with tool_use stopReason
modified_chunk = chunk.copy()
modified_chunk["messageStop"] = message_stop.copy()
modified_chunk["messageStop"]["stopReason"] = "tool_use"
logger.warning("Override stop reason from end_turn to tool_use")
callback(modified_chunk)
else:
callback(chunk)
else:
callback(chunk)

Expand Down Expand Up @@ -771,6 +797,43 @@ def _stream(
callback()
logger.debug("finished streaming response from model")

def _has_client_side_tools_to_execute(self, message_content: list[dict[str, Any]]) -> bool:
"""Check if message contains client-side tools that need execution.
Server-side tools (like nova_grounding) are executed by Bedrock and include
toolResult blocks in the response. We should NOT override stopReason to
"tool_use" for these tools.
Args:
message_content: The content array from Bedrock response
Returns:
True if there are client-side tools without results, False otherwise
"""
tool_use_ids = set()
tool_result_ids = set()
has_client_tools = False

for content in message_content:
if "toolUse" in content:
tool_use = content["toolUse"]
tool_use_ids.add(tool_use["toolUseId"])

# Check if it's a server-side tool (Bedrock executes these)
if tool_use.get("type") != "server_tool_use":
has_client_tools = True

elif "toolResult" in content:
# Track which tools already have results
tool_result_ids.add(content["toolResult"]["toolUseId"])

# Only return True if there are client-side tools without results
if not has_client_tools:
return False

# Check if all tool uses have corresponding results
return not tool_use_ids.issubset(tool_result_ids)

def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
"""Convert a non-streaming response to the streaming format.
Expand Down Expand Up @@ -838,23 +901,26 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera

for citation in content["citationsContent"]["citations"]:
# Then emit citation metadata (for structure)

citation_metadata: CitationsDelta = {
"title": citation["title"],
"location": citation["location"],
"sourceContent": citation["sourceContent"],
}
yield {"contentBlockDelta": {"delta": {"citation": citation_metadata}}}
citation_metadata: dict[str, Any] = {}
if "title" in citation:
citation_metadata["title"] = citation["title"]
if "location" in citation:
citation_metadata["location"] = citation["location"]
if "sourceContent" in citation:
citation_metadata["sourceContent"] = citation["sourceContent"]
yield {"contentBlockDelta": {"delta": {"citation": cast(CitationsDelta, citation_metadata)}}}

# Yield contentBlockStop event
yield {"contentBlockStop": {}}

# Yield messageStop event
# Fix stopReason for models that return end_turn when they should return tool_use on non-streaming side
# BUT: Don't override for server-side tools (like nova_grounding) that are already executed
current_stop_reason = response["stopReason"]
if current_stop_reason == "end_turn":
message_content = response["output"]["message"]["content"]
if any("toolUse" in content for content in message_content):
# Only override if there are client-side tools that need execution
if self._has_client_side_tools_to_execute(message_content):
current_stop_reason = "tool_use"
logger.warning("Override stop reason from end_turn to tool_use")

Expand Down
66 changes: 20 additions & 46 deletions src/strands/types/citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
These types are modeled after the Bedrock API.
"""

from typing import List, Union
from typing import List

from typing_extensions import TypedDict

Expand All @@ -18,67 +18,41 @@ class CitationsConfig(TypedDict):
enabled: bool


class DocumentCharLocation(TypedDict, total=False):
"""Specifies a character-level location within a document.

Provides precise positioning information for cited content using
start and end character indices.
class WebLocationDetail(TypedDict, total=False):
"""Details of a web-based location.

Attributes:
documentIndex: The index of the document within the array of documents
provided in the request. Minimum value of 0.
start: The starting character position of the cited content within
the document. Minimum value of 0.
end: The ending character position of the cited content within
the document. Minimum value of 0.
url: The URL of the web page containing the cited content.
domain: The domain of the web page containing the cited content.
"""

documentIndex: int
start: int
end: int
url: str
domain: str


class DocumentChunkLocation(TypedDict, total=False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, we have to preserve backwards compatibility, so we cannot simply remove these classes.

These are TypedDicts, so I see the argument that it would not have been possible for someone to use
CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation] for isinstance checks in the first place. But this still needs to be workshopped to avoid breaking changes.

"""Specifies a chunk-level location within a document.
class CitationLocation(TypedDict, total=False):
"""Specifies a location for cited content.

Provides positioning information for cited content using logical
document segments or chunks.
Can represent different types of locations depending on which fields are present:
- Document-based citations: Uses documentIndex, start, and end for character/chunk/page positions
- Web-based citations: Uses web dict with url and domain

Attributes:
documentIndex: The index of the document within the array of documents
provided in the request. Minimum value of 0.
start: The starting chunk identifier or index of the cited content
within the document. Minimum value of 0.
end: The ending chunk identifier or index of the cited content
within the document. Minimum value of 0.
"""

documentIndex: int
start: int
end: int


class DocumentPageLocation(TypedDict, total=False):
"""Specifies a page-level location within a document.

Provides positioning information for cited content using page numbers.
All fields are optional as only the relevant subset will be present based on citation type.

Attributes:
documentIndex: The index of the document within the array of documents
provided in the request. Minimum value of 0.
start: The starting page number of the cited content within
the document. Minimum value of 0.
end: The ending page number of the cited content within
the document. Minimum value of 0.
provided in the request. Used for document-based citations.
start: The starting position (character, chunk, or page) of the cited content.
Used for document-based citations.
end: The ending position (character, chunk, or page) of the cited content.
Used for document-based citations.
web: Web location details containing URL and domain. Used for web-based citations.
"""

documentIndex: int
start: int
end: int


# Union type for citation locations
CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation]
web: WebLocationDetail


class CitationSourceContent(TypedDict, total=False):
Expand Down
Loading