-
Notifications
You must be signed in to change notification settings - Fork 860
feat: make BedrockModel._format_request and _convert_non_streaming_to… #2315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -234,6 +234,28 @@ def get_config(self) -> BedrockConfig: | |
| """ | ||
| return resolve_config_metadata(self.config, self.config.get("model_id", "")) | ||
|
|
||
| def format_request( | ||
| self, | ||
| messages: Messages, | ||
| tool_specs: list[ToolSpec] | None = None, | ||
| system_prompt_content: list[SystemContentBlock] | None = None, | ||
| tool_choice: ToolChoice | None = None, | ||
| ) -> dict[str, Any]: | ||
| """Format a Bedrock converse stream request. | ||
|
|
||
| Args: | ||
| messages: List of message objects to be processed by the model. | ||
| tool_specs: List of tool specifications to make available to the model. | ||
| tool_choice: Selection strategy for tool invocation. | ||
| system_prompt_content: System prompt content blocks to provide context to the model. | ||
|
|
||
| Returns: | ||
| A Bedrock converse stream request. | ||
| """ | ||
| with warnings.catch_warnings(): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The flow is a little odd right now; could we switch to:
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we're also okay breaking users of the `_`-prefixed methods, given that they're marked as private. That's what we went with for #2370 |
||
| warnings.simplefilter("ignore", DeprecationWarning) | ||
| return self._format_request(messages, tool_specs, system_prompt_content, tool_choice) | ||
|
|
||
| def _format_request( | ||
| self, | ||
| messages: Messages, | ||
|
|
@@ -243,6 +265,9 @@ def _format_request( | |
| ) -> dict[str, Any]: | ||
| """Format a Bedrock converse stream request. | ||
|
|
||
| .. deprecated:: | ||
| Use :meth:`format_request` instead. This will be removed in September 2026. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't call out a specific date; it's very subject to change. |
||
|
|
||
| Args: | ||
| messages: List of message objects to be processed by the model. | ||
| tool_specs: List of tool specifications to make available to the model. | ||
|
|
@@ -252,6 +277,12 @@ def _format_request( | |
| Returns: | ||
| A Bedrock converse stream request. | ||
| """ | ||
| warnings.warn( | ||
| "_format_request is on the deprecation path, use format_request instead. " | ||
| "This will be removed in September 2026.", | ||
| DeprecationWarning, | ||
| stacklevel=2, | ||
| ) | ||
| if not tool_specs: | ||
| has_tool_content = any( | ||
| any("toolUse" in block or "toolResult" in block for block in msg.get("content", [])) for msg in messages | ||
|
|
@@ -830,7 +861,9 @@ async def count_tokens( | |
| if system_prompt and system_prompt_content is None: | ||
| system_prompt_content = [{"text": system_prompt}] | ||
|
|
||
| request = self._format_request(messages, tool_specs, system_prompt_content) | ||
| with warnings.catch_warnings(): | ||
| warnings.simplefilter("ignore", DeprecationWarning) | ||
| request = self._format_request(messages, tool_specs, system_prompt_content) | ||
| converse_input: dict[str, Any] = {} | ||
| if "messages" in request: | ||
| converse_input["messages"] = request["messages"] | ||
|
|
@@ -852,13 +885,9 @@ async def count_tokens( | |
| logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens) | ||
| return total_tokens | ||
| except Exception as e: | ||
| if ( | ||
| isinstance(e, ClientError) | ||
| and e.response.get("Error", {}).get("Code") == "AccessDeniedException" | ||
| ): | ||
| if isinstance(e, ClientError) and e.response.get("Error", {}).get("Code") == "AccessDeniedException": | ||
| logger.warning( | ||
| "model_id=<%s> | bedrock:CountTokens permission denied," | ||
| " falling back to heuristic estimation: %s", | ||
| "model_id=<%s> | bedrock:CountTokens permission denied, falling back to heuristic estimation: %s", | ||
| model_id, | ||
| e, | ||
| ) | ||
|
|
@@ -964,7 +993,9 @@ def _stream( | |
| """ | ||
| try: | ||
| logger.debug("formatting request") | ||
| request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice) | ||
| with warnings.catch_warnings(): | ||
| warnings.simplefilter("ignore", DeprecationWarning) | ||
| request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice) | ||
| logger.debug("request=<%s>", request) | ||
|
|
||
| logger.debug("invoking model") | ||
|
|
@@ -988,8 +1019,10 @@ def _stream( | |
|
|
||
| else: | ||
| response = self.client.converse(**request) | ||
| for event in self._convert_non_streaming_to_streaming(response): | ||
| callback(event) | ||
| with warnings.catch_warnings(): | ||
| warnings.simplefilter("ignore", DeprecationWarning) | ||
| for event in self._convert_non_streaming_to_streaming(response): | ||
| callback(event) | ||
|
|
||
| if ( | ||
| "trace" in response | ||
|
|
@@ -1044,15 +1077,38 @@ def _stream( | |
| callback() | ||
| logger.debug("finished streaming response from model") | ||
|
|
||
def convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
    """Convert a non-streaming response to the streaming format.

    Public counterpart of ``_convert_non_streaming_to_streaming``. It delegates to that
    method while hiding the ``DeprecationWarning`` the private method emits.

    The suppression is active only while the underlying generator is being advanced and is
    lifted again before each event is handed to the caller. ``warnings.catch_warnings``
    swaps module-global filter state, so keeping it open across a ``yield`` would silence
    unrelated deprecation warnings raised by the consumer between events, and could restore
    filters out of order if this generator is abandoned or interleaved with other
    ``catch_warnings`` blocks.

    Args:
        response: The non-streaming response from the Bedrock model.

    Returns:
        An iterable of response events in the streaming format.
    """
    events = iter(self._convert_non_streaming_to_streaming(response))
    try:
        while True:
            # Scope the filter change to a single step of the private generator only.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                try:
                    event = next(events)
                except StopIteration:
                    return
            # Filters are already restored here, so the consumer sees its own warnings.
            yield event
    finally:
        # Mirror ``yield from``: closing this generator also closes the delegated one.
        close = getattr(events, "close", None)
        if close is not None:
            close()
|
|
||
| def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]: | ||
| """Convert a non-streaming response to the streaming format. | ||
|
|
||
| .. deprecated:: | ||
| Use :meth:`convert_non_streaming_to_streaming` instead. This will be removed in September 2026. | ||
|
|
||
| Args: | ||
| response: The non-streaming response from the Bedrock model. | ||
|
|
||
| Returns: | ||
| An iterable of response events in the streaming format. | ||
| """ | ||
| warnings.warn( | ||
| "_convert_non_streaming_to_streaming is on the deprecation path, " | ||
| "use convert_non_streaming_to_streaming instead. " | ||
| "This will be removed in September 2026.", | ||
| DeprecationWarning, | ||
| stacklevel=2, | ||
| ) | ||
| # Yield messageStart event | ||
| yield {"messageStart": {"role": response["output"]["message"]["role"]}} | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add kwargs here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The change in https://github.com/strands-agents/sdk-python/pull/2093/changes#r3306308887 is an example of an additive parameter, FWIW.