1212import tiktoken
1313from botocore .config import Config
1414from fastapi import HTTPException
15+ from starlette .concurrency import run_in_threadpool
1516
1617from api .models .base import BaseChatModel , BaseEmbeddingsModel
1718from api .schema import (
@@ -145,7 +146,7 @@ def validate(self, chat_request: ChatRequest):
145146 detail = error ,
146147 )
147148
148- def _invoke_bedrock (self , chat_request : ChatRequest , stream = False ):
149+ async def _invoke_bedrock (self , chat_request : ChatRequest , stream = False ):
149150 """Common logic for invoke bedrock models"""
150151 if DEBUG :
151152 logger .info ("Raw request: " + chat_request .model_dump_json ())
@@ -157,9 +158,11 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
157158
158159 try :
159160 if stream :
160- response = bedrock_runtime .converse_stream (** args )
161+ # Run the blocking boto3 call in a thread pool
162+ response = await run_in_threadpool (bedrock_runtime .converse_stream , ** args )
161163 else :
162- response = bedrock_runtime .converse (** args )
164+ # Run the blocking boto3 call in a thread pool
165+ response = await run_in_threadpool (bedrock_runtime .converse , ** args )
163166 except bedrock_runtime .exceptions .ValidationException as e :
164167 logger .error ("Validation Error: " + str (e ))
165168 raise HTTPException (status_code = 400 , detail = str (e ))
@@ -171,11 +174,11 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
171174 raise HTTPException (status_code = 500 , detail = str (e ))
172175 return response
173176
174- def chat (self , chat_request : ChatRequest ) -> ChatResponse :
177+ async def chat (self , chat_request : ChatRequest ) -> ChatResponse :
175178 """Default implementation for Chat API."""
176179
177180 message_id = self .generate_message_id ()
178- response = self ._invoke_bedrock (chat_request )
181+ response = await self ._invoke_bedrock (chat_request )
179182
180183 output_message = response ["output" ]["message" ]
181184 input_tokens = response ["usage" ]["inputTokens" ]
@@ -194,9 +197,9 @@ def chat(self, chat_request: ChatRequest) -> ChatResponse:
194197 logger .info ("Proxy response :" + chat_response .model_dump_json ())
195198 return chat_response
196199
197- def chat_stream (self , chat_request : ChatRequest ) -> AsyncIterable [bytes ]:
200+ async def chat_stream (self , chat_request : ChatRequest ) -> AsyncIterable [bytes ]:
198201 """Default implementation for Chat Stream API"""
199- response = self ._invoke_bedrock (chat_request , stream = True )
202+ response = await self ._invoke_bedrock (chat_request , stream = True )
200203 message_id = self .generate_message_id ()
201204 stream = response .get ("stream" )
202205 for chunk in stream :
0 commit comments