diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index 374fcd15..068e160a 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -16,6 +16,7 @@ from starlette.concurrency import run_in_threadpool
 
 from api.models.base import BaseChatModel, BaseEmbeddingsModel
+from api.models.model_manager import ModelManager
 from api.schema import (
     AssistantMessage,
     ChatRequest,
@@ -173,22 +174,21 @@ def list_bedrock_models() -> dict:
     return model_list
 
 
-# Initialize the model list.
-bedrock_model_list = list_bedrock_models()
-
-
 class BedrockModel(BaseChatModel):
+    def __init__(self):
+        """Populate the model manager with the Bedrock model list."""
+        self._model_manager = ModelManager()
+        for name, spec in list_bedrock_models().items():
+            self._model_manager.add_model({name: spec})
+
     def list_models(self) -> list[str]:
-        """Always refresh the latest model list"""
-        global bedrock_model_list
-        bedrock_model_list = list_bedrock_models()
-        return list(bedrock_model_list.keys())
+        """Return the keys of the model list built at start-up."""
+        return self._model_manager.model_keys
 
     def validate(self, chat_request: ChatRequest):
         """Perform basic validation on requests"""
         error = ""
         # check if model is supported
-        if chat_request.model not in bedrock_model_list.keys():
+        if chat_request.model not in self._model_manager.model_keys:
             error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
             logger.error("Unsupported model: %s", chat_request.model)
@@ -575,6 +575,10 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
         if chat_request.extra_body:
             # reasoning_config will not be used
             args["additionalModelRequestFields"] = chat_request.extra_body
+
+        # Pass the session ID through; agent invocations use it, foundation models drop it
+        args["session_id"] = chat_request.session_id
+
         return args
 
     def _create_response(
diff --git a/src/api/models/bedrock_agents.py b/src/api/models/bedrock_agents.py
new file mode 100644
index 00000000..270b5ef1
--- /dev/null
+++ b/src/api/models/bedrock_agents.py
@@ -0,0 +1,365 @@
+# Original Credit: GitHub user dhapola
+import base64
+import uuid
+import json
+import logging
+import re
+import time
+from abc import ABC
+from typing import AsyncIterable
+
+import boto3
+from botocore.exceptions import EventStreamError
+from botocore.config import Config
+import numpy as np
+import requests
+import tiktoken
+from fastapi import HTTPException
+from api.models.model_manager import ModelManager
+
+from api.models.bedrock import (
+    BedrockModel,
+    bedrock_client,
+    bedrock_runtime)
+
+from api.schema import (
+    ChatResponse,
+    ChatRequest,
+    ChatResponseMessage,
+    ChatStreamResponse,
+    ChoiceDelta
+)
+
+from api.setting import (DEBUG, AWS_REGION, AGENT_PREFIX)
+
+from api.models.md import MetaData
+
+logger = logging.getLogger(__name__)
+
+config = Config(
+    connect_timeout=60,       # Connection timeout: 60 seconds
+    read_timeout=900,         # Read timeout: 15 minutes (suitable for long streaming responses)
+    retries={
+        'max_attempts': 8,    # Maximum retry attempts
+        'mode': 'adaptive'    # Adaptive retry mode
+    },
+    max_pool_connections=50   # Maximum connection pool size
+)
+
+bedrock_agent = boto3.client(
+    service_name="bedrock-agent",
+    region_name=AWS_REGION,
+    config=config,
+)
+
+def get_agent_runtime():
+    return boto3.client(
+        service_name="bedrock-agent-runtime",
+        region_name=AWS_REGION,
+        config=config,
+    )
+
+bedrock_agent_runtime = get_agent_runtime()
+
+
+class BedrockAgents(BedrockModel):
+
+    def __init__(self):
+        """Build the foundation-model list, then append available agents."""
+        super().__init__()
+        self.get_agents()
+
+    def get_latest_agent_aliases(self, client, agent_id):  # , limit=2
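+        # Returns a dict of this agent's PREPARED aliases keyed by alias name
+        # (newest createdAt first), with the console test alias surfaced as DRAFT;
+        # returns None when the agent has no aliases at all.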
+        # List all aliases for the agent
+        response = client.list_agent_aliases(
+            agentId=agent_id,
+            maxResults=100  # Adjust based on your needs
+        )
+
+        if not response.get('agentAliasSummaries'):
+            return None
+
+        # Sort aliases by createdAt descending
+        aliases = response.get('agentAliasSummaries', [])
+
+        sorted_aliases = sorted(
+            [a for a in aliases if a.get('agentAliasName')],
+            key=lambda a: a['createdAt'],
+            reverse=True
+        )
+
+        # Collect PREPARED aliases keyed by alias name
+        result = {}
+
+        for alias in sorted_aliases:
+            if alias.get('agentAliasStatus') == 'PREPARED':
+                name = alias.get('agentAliasName').replace('AgentTestAlias', 'DRAFT')
+                result[name] = alias
+
+                #if len(result) >= limit:
+                #    break
+
+        return result
+
+    def get_agents(self):
+        # List agents with the shared bedrock-agent client
+        response = bedrock_agent.list_agents(maxResults=100)
+
+        # Register every prepared agent alias as a selectable model
+        for agent in response['agentSummaries']:
+
+            if agent['agentStatus'] != 'PREPARED':
+                continue
+
+            agentId = agent['agentId']
+
+            all_latest_aliases = self.get_latest_agent_aliases(bedrock_agent, agentId)
+            if not all_latest_aliases:
+                continue
+
+            for alias_name, latest_alias in all_latest_aliases.items():
+                name = f"{AGENT_PREFIX}{agent['agentName']}-{alias_name}"
+
+                val = {
+                    "system": False,           # System prompts are configured on the agent itself, not per request
+                    "multimodal": True,        # Capable of processing both text and images
+                    "tool_call": False,        # Tool Use not required for Agents
+                    "stream_tool_call": True,
+                    "agent_id": agentId,
+                    "alias_id": latest_alias['agentAliasId']
+                }
+
+                self._model_manager.add_model({name: val})
+
+    async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
+        """Common logic for invoking Bedrock foundation models"""
+
+        # convert OpenAI chat request to Bedrock SDK request
+        args = self._parse_request(chat_request)
+        del args["session_id"]  # Not used for foundation models
+        if DEBUG:
+            logger.info("Bedrock request: " + json.dumps(str(args)))
+
+        try:
+            if stream:
+                response = bedrock_runtime.converse_stream(**args)
+            else:
+                response = bedrock_runtime.converse(**args)
+        except bedrock_client.exceptions.ValidationException as e:
+            logger.error("Validation Error: " + str(e))
+            raise HTTPException(status_code=400, detail=str(e))
+        except Exception as e:
+            logger.error(e)
+            raise HTTPException(status_code=500, detail=str(e))
+        return response
+
+    async def chat(self, chat_request: ChatRequest) -> ChatResponse:
+        """Default implementation for Chat API."""
+        message_id = self.generate_message_id()
+
+        if chat_request.model.startswith(AGENT_PREFIX):
+            response = self._invoke_agent(chat_request)
+            output = ""
+
+            for event in response["completion"]:
+                output += event["chunk"]["bytes"].decode("utf-8")
+
+            # Minimal response (stop reason, token I/O counts not returned)
+            chat_response = self._create_response(
+                model=chat_request.model,
+                message_id=message_id,
+                content=[{"text": output}],
+                finish_reason="",
+                input_tokens=0,
+                output_tokens=0
+            )
+        else:
+            # Foundation models: defer to the base implementation
+            chat_response = await super().chat(chat_request)
+
+        return chat_response
+
+    async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+        """Default implementation for Chat Stream API"""
+
+        message_id = self.generate_message_id()
+
+        if chat_request.model.startswith(AGENT_PREFIX):
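+            # Agent path: invoke_agent returns an event stream under the
+            # "completion" key; an initial empty assistant-role delta is sent
+            # first, then each chunk's bytes are forwarded as OpenAI-style deltas.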
+            response = self._invoke_agent(chat_request, stream=True)
+
+            _event_stream = response["completion"]
+
+            chunk_count = 1
+            message = ChatResponseMessage(
+                role="assistant",
+                content="",
+            )
+            stream_response = ChatStreamResponse(
+                id=message_id,
+                model=chat_request.model,
+                choices=[
+                    ChoiceDelta(
+                        index=0,
+                        delta=message,
+                        logprobs=None,
+                        finish_reason=None,
+                    )
+                ],
+                usage=None,
+            )
+            yield self.stream_response_to_bytes(stream_response)
+
+            for event in _event_stream:
+                chunk_count += 1
+                if "chunk" in event:
+                    _data = event["chunk"]["bytes"].decode("utf8")
+                    message = ChatResponseMessage(content=_data)
+
+                    stream_response = ChatStreamResponse(
+                        id=message_id,
+                        model=chat_request.model,
+                        choices=[
+                            ChoiceDelta(
+                                index=0,
+                                delta=message,
+                                logprobs=None,
+                                finish_reason=None,
+                            )
+                        ],
+                        usage=None,
+                    )
+                    yield self.stream_response_to_bytes(stream_response)
+
+            # return an [DONE] message at the end.
+            yield self.stream_response_to_bytes()
+            return
+        else:
+            response = await self._invoke_bedrock(chat_request, stream=True)
+
+            stream = response.get("stream")
+            for chunk in stream:
+                stream_response = self._create_response_stream(
+                    model_id=chat_request.model, message_id=message_id, chunk=chunk
+                )
+                if not stream_response:
+                    continue
+                if DEBUG:
+                    logger.info("Proxy response :" + stream_response.model_dump_json())
+                if stream_response.choices:
+                    yield self.stream_response_to_bytes(stream_response)
+                elif (
+                    chat_request.stream_options
+                    and chat_request.stream_options.include_usage
+                ):
+                    # An empty choices for Usage as per OpenAI doc below:
+                    # if you set stream_options: {"include_usage": true},
+                    # an additional chunk will be streamed before the data: [DONE] message.
+                    # The usage field on this chunk shows the token usage statistics for the entire request,
+                    # and the choices field will always be an empty array.
+                    # All other chunks will also include a usage field, but with a null value.
+                    yield self.stream_response_to_bytes(stream_response)
+
+        # return an [DONE] message at the end.
+        yield self.stream_response_to_bytes()
+
+    def _invoke_agent(self, chat_request: ChatRequest, stream=False, retry=False):
+        """Common logic for invoking a Bedrock agent"""
+        if DEBUG:
+            logger.info("BedrockAgents._invoke_agent: Raw request: " + chat_request.model_dump_json())
+
+        # convert OpenAI chat request to Bedrock SDK request
+        args = self._parse_request(chat_request)
+
+        if DEBUG:
+            logger.info("Bedrock request: " + json.dumps(str(args)))
+
+        model = self._model_manager.get_all_models()[chat_request.model]
+
+        global bedrock_agent_runtime
+
+        try:
+            # Use the latest message in the conversation as the agent input
+            messages = args['messages']
+            query = messages[-1]['content'][0]['text']
+
+            # Sanitize variants of double quotes
+            query = query.translate(str.maketrans({'“': '"', '”': '"', '„': '"', '‟': '"'}))
+
+            md = MetaData(query)
+            md_args = {}
+            session_state = {}
+            # Session IDs must match [0-9a-zA-Z._:-]+; fall back to a random ID if none was supplied
+            session_id = (args["session_id"] or str(uuid.uuid4())).replace(" ", "_")
+
+            if md.has_metadata:
+                md_args = md.get_metadata_args()
+                session_id = str(uuid.uuid4())
+                logger.info(md_args)
+                query = md.get_clean_query()
+                kb_id = "D3Q2K57HXU"  # TODO: Don't hard-wire
+
+                session_state['knowledgeBaseConfigurations'] = [{
+                    'knowledgeBaseId': kb_id,
+                    'retrievalConfiguration': {
+                        'vectorSearchConfiguration': {
+                            'filter': md_args
+                        }
+                    }
+                }]
+
+            # Build the invoke_agent request
+            # TODO: Session state
+            request_params = {
+                'agentId': model['agent_id'],
+                'agentAliasId': model['alias_id'],
+                'sessionId': session_id,
+                'inputText': query,
+            }
+
+            # Append KB config if present
+            if session_state:
+                request_params['sessionState'] = session_state
+
+            # Apply streaming if desired
+            if stream:
+                request_params['streamingConfigurations'] = {
+                    'streamFinalResponse': True,
+                    'applyGuardrailInterval': 123
+                }
+
+            # Invoke the agent
+            response = bedrock_agent_runtime.invoke_agent(**request_params)
+            return response
+        except EventStreamError as ese:
+            if not retry:
+                # Reinstantiate the client to pick up current credentials, then retry once
+                logger.info("Refreshing client to get current credentials")
+                bedrock_agent_runtime = get_agent_runtime()
+                return self._invoke_agent(chat_request, stream, retry=True)
+            raise ese
+        except Exception as e:
+            logger.error(e)
+            raise HTTPException(status_code=500, detail=str(e))
\ No newline at end of file
diff --git a/src/api/models/md.py b/src/api/models/md.py
new file mode 100644
index 00000000..59dd37c0
--- /dev/null
+++ b/src/api/models/md.py
@@ -0,0 +1,37 @@
+import re
+
+class MetaData(object):
+
+    @property
+    def has_metadata(self):
+        # True when the prompt contains at least one "key"="value" pair
+        return re.search(r'"[^"]*"\s*=\s*"[^"]*"', self._prompt) is not None
+
+    def __init__(self, prompt: str):
+        self._prompt = prompt
+
+    def get_metadata_args(self):
+        outer_key = "orAll"
+        md_args = {outer_key: []}
+
+        pattern = r'"([^"]*)"\s*=\s*"([^"]*)"'  # TODO: DRY on pattern
+        matches = re.findall(pattern, self._prompt)
+
+        for k, v in dict(matches).items():
+            sub_map = {"equals": {"key": k, "value": v}}
+            md_args[outer_key].append(sub_map)
+
+        # Can't have orAll with just one filter, so unwrap the single condition
+        if len(matches) == 1:
+            md_args = md_args[outer_key][0]
+
+        return md_args
+
+    def get_clean_query(self):
+        return re.sub(r'"[^"]*"\s*=\s*"[^"]*"', '', self._prompt).strip()
+
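+# Example of the filter shape (hypothetical keys and values): a prompt that
+# starts with  "Year"="2024" "Region"="EMEA"  produces
+#   {"orAll": [{"equals": {"key": "Year", "value": "2024"}},
+#              {"equals": {"key": "Region", "value": "EMEA"}}]}
+# while a single pair is unwrapped to the bare "equals" filter, as in the
+# demo below.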
+if __name__ == "__main__":
+    md = MetaData('"OE_Number"="111" Tell me about the event.')
+    prompt = md.get_clean_query()
+    filters = md.get_metadata_args()
+
+    print(f"Prompt: {prompt}\nFilters: {filters}")
\ No newline at end of file
diff --git a/src/api/models/model_manager.py b/src/api/models/model_manager.py
new file mode 100644
index 00000000..2331250b
--- /dev/null
+++ b/src/api/models/model_manager.py
@@ -0,0 +1,26 @@
+# Original Credit: GitHub user dhapola
+
+
+class ModelManager:
+
+    @property
+    def model_keys(self):
+        return list(self.get_all_models().keys())
+
+    def __init__(self, *args, **kwargs):
+        self._models = {}
+
+    def get_all_models(self):
+        return self._models
+
+    def add_model(self, model):
+        """Merge a {name: spec} mapping into the model list."""
+        self._models.update(model)
+
+    def clear_models(self):
+        """Clear the list of models."""
+        self._models.clear()
+
+    def __repr__(self):
+        return f"ModelManager(models={self._models})"
\ No newline at end of file
diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py
index 530f75d6..42e26cad 100644
--- a/src/api/routers/chat.py
+++ b/src/api/routers/chat.py
@@ -5,6 +5,7 @@
 
 from api.auth import api_key_auth
 from api.models.bedrock import BedrockModel
+from api.models.bedrock_agents import BedrockAgents
 from api.schema import ChatRequest, ChatResponse, ChatStreamResponse, Error
 from api.setting import DEFAULT_MODEL
 
@@ -36,10 +37,11 @@ async def chat_completions(
 ):
     if chat_request.model.lower().startswith("gpt-"):
         chat_request.model = DEFAULT_MODEL
+
+    model = BedrockAgents()
+    model.validate(chat_request)  # Exception will be raised if model not supported.
 
-    model = BedrockModel()
-    model.validate(chat_request)
     if chat_request.stream:
         return StreamingResponse(content=model.chat_stream(chat_request), media_type="text/event-stream")
     return await model.chat(chat_request)
diff --git a/src/api/routers/model.py b/src/api/routers/model.py
index e1de1553..0f8d5b33 100644
--- a/src/api/routers/model.py
+++ b/src/api/routers/model.py
@@ -3,7 +3,7 @@
 from fastapi import APIRouter, Depends, HTTPException, Path
 
 from api.auth import api_key_auth
-from api.models.bedrock import BedrockModel
+from api.models.bedrock_agents import BedrockAgents
 from api.schema import Model, Models
 
 router = APIRouter(
@@ -12,7 +12,7 @@
     # responses={404: {"description": "Not found"}},
 )
 
-chat_model = BedrockModel()
+chat_model = BedrockAgents()
 
 
 async def validate_model_id(model_id: str):
diff --git a/src/api/schema.py b/src/api/schema.py
index 233e1139..2f137c28 100644
--- a/src/api/schema.py
+++ b/src/api/schema.py
@@ -99,7 +99,8 @@ class ChatRequest(BaseModel):
     stream_options: StreamOptions | None = None
     temperature: float | None = Field(default=1.0, le=2.0, ge=0.0)
     top_p: float | None = Field(default=1.0, le=1.0, ge=0.0)
-    user: str | None = None  # Not used
+    user: str | None = None
+    session_id: str | None = None  # Caller-supplied ID (e.g. the user name) for a stable agent session
     max_tokens: int | None = 2048
     max_completion_tokens: int | None = None
     reasoning_effort: Literal["low", "medium", "high"] | None = None
diff --git a/src/api/setting.py b/src/api/setting.py
index 4e0a7bbd..8db65a5b 100644
--- a/src/api/setting.py
+++ b/src/api/setting.py
@@ -17,3 +17,6 @@
 DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
 ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
+
+# Prefix used to expose Bedrock Agents as chat model IDs (added for agent support)
+AGENT_PREFIX = 'agent-'
\ No newline at end of file
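
Usage sketch (illustrative): one way a client could exercise the agent path through the proxy once this change is deployed. The base URL, API key, and the agent-MyAgent-prod model name are assumptions; only the "agent-" prefix, the top-level session_id field, and the "key"="value" metadata filter syntax come from the change above.

from openai import OpenAI

# Point the standard OpenAI client at the proxy (placeholder URL and key).
client = OpenAI(base_url="http://localhost:8000/api/v1", api_key="placeholder-key")

completion = client.chat.completions.create(
    model="agent-MyAgent-prod",  # AGENT_PREFIX + agent name + "-" + alias name (hypothetical agent)
    messages=[{
        "role": "user",
        # A leading "key"="value" pair becomes a knowledge-base metadata filter.
        "content": '"OE_Number"="111" Tell me about the event.',
    }],
    # extra_body is merged into the request body, populating ChatRequest.session_id.
    extra_body={"session_id": "user-42"},
)
print(completion.choices[0].message.content)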