diff --git a/.gitignore b/.gitignore index b01d78f..fbc1b1f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ __pycache__/ venv/ env/ backend/backend/ +.venv/ # Node modules node_modules/ diff --git a/backend/banking_app.py b/backend/banking_app.py index 514d7e2..6af0264 100644 --- a/backend/banking_app.py +++ b/backend/banking_app.py @@ -1,5 +1,6 @@ # import urllib.parse import uuid +import asyncio from datetime import datetime import json import time @@ -11,19 +12,18 @@ from flask_cors import CORS from flask_sqlalchemy import SQLAlchemy from dotenv import load_dotenv -from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI -from langchain_community.vectorstores.utils import DistanceStrategy -from langchain_sqlserver import SQLServer_VectorStore -from langchain_core.messages import HumanMessage, AIMessage, ToolMessage -from langgraph.store.memory import InMemoryStore +from agent_framework import ChatAgent +from agent_framework.azure import AzureOpenAIChatClient +from openai import AsyncOpenAI from shared.connection_manager import sqlalchemy_connection_creator, connection_manager from shared.utils import get_user_id import requests # For calling analytics service -from langgraph.prebuilt import create_react_agent from shared.utils import _serialize_messages from init_data import check_and_ingest_data # Load Environment variables and initialize app import os +from azure.identity import AzureCliCredential + load_dotenv(override=True) app = Flask(__name__) @@ -31,32 +31,24 @@ global fixed_user_id fixed_user_id = get_user_id() # For simplicity, using a fixed user ID -# --- Azure OpenAI Configuration --- -AZURE_OPENAI_KEY = os.getenv("AZURE_OPENAI_KEY") -AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT") -AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT") -AZURE_OPENAI_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT") - # Analytics service URL ANALYTICS_SERVICE_URL = "http://127.0.0.1:5002" -if not all([AZURE_OPENAI_KEY, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_DEPLOYMENT, AZURE_OPENAI_EMBEDDING_DEPLOYMENT]): - print("⚠️ Warning: One or more Azure OpenAI environment variables are not set.") - ai_client = None - embeddings_client = None -else: - ai_client = AzureChatOpenAI( - azure_endpoint=AZURE_OPENAI_ENDPOINT, - api_version="2024-10-21", - api_key=AZURE_OPENAI_KEY, - azure_deployment=AZURE_OPENAI_DEPLOYMENT - ) - embeddings_client = AzureOpenAIEmbeddings( - azure_deployment=AZURE_OPENAI_EMBEDDING_DEPLOYMENT, - openai_api_version="2024-10-21", - azure_endpoint=AZURE_OPENAI_ENDPOINT, - api_key=AZURE_OPENAI_KEY, +# Initialize OpenAI client for Agent Framework +chat_client = None + +try: + + # Create chat client for Agent Framework + chat_client = AzureOpenAIChatClient( + deployment_name=os.environ["AZURE_OPENAI_DEPLOYMENT"], + endpoint=os.environ["AZURE_OPENAI_ENDPOINT"], + credential=AzureCliCredential() ) + print("✅ Agent Framework OpenAI client initialized successfully") +except Exception as e: + print(f"❌ Failed to initialize OpenAI client: {e}") + chat_client = None # Database configuration for Azure SQL (banking data) app.config['SQLALCHEMY_DATABASE_URI'] = "mssql+pyodbc://" @@ -77,15 +69,8 @@ connection_url = f"mssql+pyodbc:///?odbc_connect={connection_string}" -vector_store = None -if embeddings_client: - vector_store = SQLServer_VectorStore( - connection_string=connection_url, - table_name="DocsChunks_Embeddings", - embedding_function=embeddings_client, - embedding_length=1536, - distance_strategy=DistanceStrategy.COSINE, - ) +# Vector search will be implemented as Agent Framework tools +# Agent Framework doesn't use direct vector store integration def to_dict_helper(instance): d = {} @@ -97,63 +82,6 @@ def to_dict_helper(instance): d[column.name] = value return d -from langgraph.checkpoint.memory import MemorySaver -from langchain_core.messages import HumanMessage, AIMessage, ToolMessage -from collections import defaultdict - -def reconstruct_messages_from_history(history_data): - """Converts DB history into LangChain message objects, sorted by trace_id and message order.""" - messages = [] - print("Reconstructing messages from history data:", history_data) - - if not history_data: - return MemorySaver(), [] - - # Group messages by trace_id - traces = defaultdict(list) - for msg_data in history_data: - trace_id = msg_data.get('trace_id') - if trace_id: - traces[trace_id].append(msg_data) - - # Sort trace_ids chronologically - sorted_trace_ids = sorted(traces.keys()) - - # Process each trace in chronological order - for trace_id in sorted_trace_ids: - trace_messages = traces[trace_id] - - # Sort messages within each trace by message type priority - message_priority = { - 'human': 1, - 'ai': 2 - } - - trace_messages.sort(key=lambda x: ( - message_priority.get(x.get('message_type'), 5), - x.get('trace_end', ''), - )) - - # Convert to LangChain message objects - for msg_data in trace_messages: - try: - message_type = msg_data.get('message_type') - content = msg_data.get('content', '') - - if message_type == 'human': - messages.append(HumanMessage(content=content)) - elif message_type == 'ai': - messages.append(AIMessage(content=content)) - - except Exception as e: - print(f"Error processing message in trace {trace_id}: {e}") - continue - - print(f"Reconstructed {len(messages)} messages from {len(sorted_trace_ids)} traces") - - # Return both the memory saver and the historical messages - return MemorySaver(), messages - # Banking Database Models class User(db.Model): __tablename__ = 'users' @@ -268,21 +196,9 @@ def get_transactions_summary(user_id: str = fixed_user_id, time_period: str = 't def search_support_documents(user_question: str) -> str: """Searches the knowledge base for answers to customer support questions using vector search.""" - if not vector_store: - return "The vector store is not configured." - try: - results = vector_store.similarity_search_with_score(user_question, k=3) - relevant_docs = [doc.page_content for doc, score in results if score < 0.5] - print("-------------> ", relevant_docs) - if not relevant_docs: - return "No relevant support documents found to answer this question." - - context = "\n\n---\n\n".join(relevant_docs) - return context - - except Exception as e: - print(f"ERROR in search_support_documents: {e}") - return "An error occurred while searching for support documents." + # TODO: Implement vector search using Agent Framework tools or Azure AI Search + # For now, return a placeholder response + return f"Document search functionality is being updated. Your question about '{user_question}' has been noted. Please contact customer support for immediate assistance." def create_new_account(user_id: str = fixed_user_id, account_type: str = 'checking', name: str = None, balance: float = 0.0) -> str: """Creates a new bank account for the user.""" @@ -363,83 +279,77 @@ def handle_transactions(): return jsonify(result), status_code @app.route('/api/chatbot', methods=['POST']) def chatbot(): - if not ai_client: - return jsonify({"error": "Azure OpenAI client is not configured."}), 503 + if not chat_client: + return jsonify({"error": "Agent Framework chat client is not configured."}), 503 data = request.json messages = data.get("messages", []) session_id = data.get("session_id") user_id = fixed_user_id - # session_id_temp = "session_74a4b39c-72d9-4b30-b8b4-f317e4366e1e" - # Fetch chat history from the analytics service - history_data = call_analytics_service(f"chat/history/{session_id}", method='GET') - - # Reconstruct messages and session memory - session_memory, historical_messages = reconstruct_messages_from_history(history_data) - - - # Print debugging info - print("\n--- Context being passed to the agent ---") - print(f"History data received: {len(history_data) if history_data else 0} messages") - print(f"Historical messages reconstructed: {len(historical_messages)}") - for i, msg in enumerate(historical_messages): - print(f" {i+1}. [{msg.__class__.__name__}] {msg.content[:50]}...") - print("-----------------------------------------\n") - # Extract current user message user_message = messages[-1].get("content", "") - tools = [get_user_accounts, get_transactions_summary, - search_support_documents, create_new_account, - transfer_money] - - # Initialize banking agent - banking_agent = create_react_agent( - model=ai_client, - tools=tools, - checkpointer=session_memory, - prompt=""" - - You are a customer support agent. - - You can use the provided tools to answer user questions and perform tasks. - - If you were unable to find an answer, inform the user. - - Do not use your general knowledge to answer questions.""", - name = "banking_agent_v1" - ) - # Thread config for session management - thread_config = {"configurable": {"thread_id": session_id}} - all_messages = historical_messages + [HumanMessage(content=user_message)] - - trace_start_time = time.time() - response = banking_agent.invoke( - {"messages": all_messages}, - config=thread_config - ) - end_time = time.time() - trace_duration = int((end_time - trace_start_time) * 1000) - - print("################### TRACE RESPONSE ######################") - all_messages = response['messages'] - historical_count = len(historical_messages) - final_messages = all_messages[historical_count:] - - for msg in final_messages: - print(f"[{msg.__class__.__name__}] {msg.content}") - - analytics_data = { - "session_id": session_id, - "user_id": user_id, - "messages": _serialize_messages(final_messages), - "trace_duration": trace_duration, - } - - # calling analytics service to capture this trace - call_analytics_service("chat/log-trace", data=analytics_data) - return jsonify({ - "response": final_messages[-1].content, - "session_id": session_id, - "tools_used": [] - }) + # Run the agent asynchronously + try: + response_text = asyncio.run(run_banking_agent(user_message, session_id)) + + # Log to analytics service + analytics_data = { + "session_id": session_id, + "user_id": user_id, + "messages": [ + {"type": "human", "content": user_message}, + {"type": "ai", "content": response_text} + ], + "trace_duration": 0, # Agent Framework handles timing internally + } + + call_analytics_service("chat/log-trace", data=analytics_data) + + return jsonify({ + "response": response_text, + "session_id": session_id, + "tools_used": [] + }) + + except Exception as e: + print(f"Error in chatbot: {e}") + return jsonify({"error": "An error occurred processing your request."}), 500 + +async def run_banking_agent(user_message: str, session_id: str) -> str: + """Run the banking agent using Agent Framework""" + if not chat_client: + return "Chat client is not available." + + try: + # Create agent with tools + agent = chat_client.create_agent( + name="BankingAgent", + instructions="""You are a helpful banking customer support agent. + - Help users with their banking questions and tasks + - Use the provided tools to access account information and perform operations + - Be helpful and professional + - If you cannot find information, inform the user politely""", + tools=[ + get_user_accounts, + get_transactions_summary, + search_support_documents, + create_new_account, + transfer_money + ], + ) + + # Get or create thread for session persistence + thread = agent.get_new_thread() # You could implement session-based thread retrieval here + + # Run the agent + result = await agent.run(user_message, thread=thread) + return result.text + + except Exception as e: + print(f"Error running banking agent: {e}") + return "I apologize, but I'm having trouble processing your request right now. Please try again later." def initialize_banking_app(): """Initialize banking app when called from combined launcher.""" diff --git a/backend/chat_data_model.py b/backend/chat_data_model.py index 4ad3106..efdbec2 100644 --- a/backend/chat_data_model.py +++ b/backend/chat_data_model.py @@ -7,16 +7,14 @@ # Global variables that will be set by the main app db = None ChatHistory = None -Threads = None -Runs = None -Users = None -ToolCalls = None +ChatSession = None +ToolUsage = None ToolDefinition = None ChatHistoryManager = None def init_chat_db(database): """Initialize the database reference and create models""" - global db, ChatHistory, Threads, Runs, Users, ToolCalls, ToolDefinition, ChatHistoryManager, AgentDefinition + global db, ChatHistory, ChatSession, ToolUsage, ToolDefinition, ChatHistoryManager, AgentDefinition db = database # Helper function to convert model instances to dictionaries @@ -32,9 +30,8 @@ def to_dict_helper(instance): class AgentDefinition(db.Model): __tablename__ = 'agent_definitions' - agent_id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) + agent_id = db.Column(db.String(255), primary_key=True, default=lambda: f"agent_{uuid.uuid4()}") name = db.Column(db.String(255), unique=True, nullable=False) - version = db.Column(db.NCHAR(10)) description = db.Column(db.Text) llm_config = db.Column(db.JSON, nullable=False) prompt_template = db.Column(db.Text, nullable=False) @@ -42,68 +39,46 @@ class AgentDefinition(db.Model): def to_dict(self): return to_dict_helper(self) - class Threads(db.Model): - __tablename__ = 'threads' - thread_id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) - user_id = db.Column(db.Integer, nullable=False) - agent_id = db.Column(db.BigInteger, nullable=True) + class ChatSession(db.Model): + __tablename__ = 'chat_sessions' + session_id = db.Column(db.String(255), primary_key=True, default=lambda: f"session_{uuid.uuid4()}") + user_id = db.Column(db.String(255), nullable=False) + title = db.Column(db.String(500)) created_at = db.Column(db.DateTime, default=datetime.now()) updated_at = db.Column(db.DateTime, default=datetime.now(), onupdate=datetime.now()) def to_dict(self): return to_dict_helper(self) - - class Runs(db.Model): - __tablename__ = 'runs' - run_id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) - thread_id = db.Column(db.BigInteger, db.ForeignKey('threads.thread_id'), nullable=False) - input = db.Column(db.Text, nullable=False) - output = db.Column(db.Text, nullable=False) - start_time = db.Column(db.DateTime, nullable=False) - end_time = db.Column(db.DateTime, nullable=True) - status = db.Column(db.String(10), nullable=True) - total_tokens = db.Column(db.Integer, nullable=True) - input_tokens = db.Column(db.Integer, nullable=True) - output_tokens = db.Column(db.Integer, nullable=True) - - def to_dict(self): - return to_dict_helper(self) class ToolDefinition(db.Model): __tablename__ = 'tool_definitions' - tool_id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) - name = db.Column(db.String(255), nullable=False) - description = db.Column(db.Text, nullable=True) - input_schema = db.Column(db.Text, nullable=False) - version = db.Column(db.String(50), nullable=True) - is_active = db.Column(db.Boolean, nullable=True) - created_at = db.Column(db.DateTime, nullable=True) - updated_at = db.Column(db.DateTime, nullable=True) - - def to_dict(self): - return to_dict_helper(self) - - class Users(db.Model): - __tablename__ = 'users' - user_id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) - user_guid = db.Column(db.String(255), nullable=False) - description = db.Column(db.String(500), nullable=True) - user_name = db.Column(db.String(500), nullable=False) + tool_id = db.Column(db.String(255), primary_key=True, default=lambda: f"tooldef_{uuid.uuid4()}") + name = db.Column(db.String(255), unique=True, nullable=False) + description = db.Column(db.Text) + input_schema = db.Column(db.JSON, nullable=False) + version = db.Column(db.String(50), default='1.0.0') + is_active = db.Column(db.Boolean, default=True) + cost_per_call_cents = db.Column(db.Integer, default=0) + created_at = db.Column(db.DateTime, default=datetime.now()) + updated_at = db.Column(db.DateTime, default=datetime.now(), onupdate=datetime.now()) def to_dict(self): return to_dict_helper(self) - class ToolCalls(db.Model): - __tablename__ = 'tool_calls' - tool_call_id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) - tool_id = db.Column(db.BigInteger, nullable=False) - run_id = db.Column(db.BigInteger, db.ForeignKey('runs.run_id'), nullable=False) - start_time = db.Column(db.DateTime, nullable=True) - end_time = db.Column(db.DateTime, nullable=True) - status = db.Column(db.String(50), nullable=True) - attributes = db.Column(db.Text, nullable=True) - input = db.Column(db.Text, nullable=True) - output = db.Column(db.Text, nullable=True) + class ToolUsage(db.Model): + __tablename__ = 'tool_usage' + tool_call_id = db.Column(db.String(255), primary_key=True, default=lambda: f"tool_{uuid.uuid4()}") + session_id = db.Column(db.String(255), nullable=False) + trace_id = db.Column(db.String(255)) + tool_id = db.Column(db.String(255), db.ForeignKey('tool_definitions.tool_id'), nullable=False) + tool_name = db.Column(db.String(255), nullable=False) + tool_input = db.Column(db.JSON, nullable=False) + tool_output = db.Column(db.JSON) + tool_message = db.Column(db.Text) + status = db.Column(db.String(50)) + + # Additional tracking fields + tokens_used = db.Column(db.Integer) def to_dict(self): return to_dict_helper(self) @@ -111,10 +86,10 @@ def to_dict(self): class ChatHistory(db.Model): __tablename__ = 'chat_history' message_id = db.Column(db.String(255), primary_key=True, default=lambda: f"msg_{uuid.uuid4()}") - session_id = db.Column(db.BigInteger, db.ForeignKey('threads.thread_id')) + session_id = db.Column(db.String(255), db.ForeignKey('chat_sessions.session_id')) trace_id = db.Column(db.String(255), nullable=False) user_id = db.Column(db.String(255), nullable=False) - agent_id = db.Column(db.BigInteger, nullable=True) + agent_id = db.Column(db.String(255), nullable=True) message_type = db.Column(db.String(50), nullable=False) # 'human', 'ai', 'system', 'tool_call', 'tool_result' content = db.Column(db.Text) @@ -146,14 +121,14 @@ def __init__(self, session_id: str, user_id: str = 'user_1'): def _ensure_session_exists(self): """Ensure the chat session exists in the database""" - session = Threads.query.filter_by(thread_id=self.session_id).first() + session = ChatSession.query.filter_by(session_id=self.session_id).first() if not session: - session = Threads( - thread_id=self.session_id, + session = ChatSession( + session_id=self.session_id, title= "New Session", user_id=self.user_id, ) - print("-----------------> New chat session created: ", session.thread_id) + print("-----------------> New chat session created: ", session.session_id) db.session.add(session) db.session.commit() def add_trace_messages(self, serialized_messages: str, @@ -277,7 +252,7 @@ def add_tool_result_message(self, message: dict, trace_id: str): def update_session_timestamp(self): """Update the session's updated_at timestamp""" - session = Threads.query.filter_by(thread_id=self.session_id).first() + session = ChatSession.query.filter_by(session_id=self.session_id).first() if session: session.updated_at = datetime.now() db.session.commit() @@ -285,7 +260,7 @@ def update_session_timestamp(self): def log_tool_usage(self, tool_info: dict, trace_id: str): """Log detailed tool usage metrics""" - existing = ToolCalls.query.filter_by(tool_call_id=tool_info.get("tool_call_id")).first() + existing = ToolUsage.query.filter_by(tool_call_id=tool_info.get("tool_call_id")).first() tool_msg = '' if(type(tool_info.get("tool_output")) is dict): tool_msg = tool_info.get("tool_output").get('message', '') @@ -309,7 +284,7 @@ def log_tool_usage(self, tool_info: dict, trace_id: str): existing.tokens_used = tool_info.get("total_tokens") db.session.commit() else: - tool_usage = ToolCalls( + tool_usage = ToolUsage( session_id=self.session_id, trace_id=trace_id, tool_call_id=tool_info.get("tool_call_id"), @@ -336,10 +311,8 @@ def get_conversation_history(self, limit: int = 50): # Make classes available globally in this module globals()['ChatHistory'] = ChatHistory - globals()['Threads'] = Threads - globals()['Runs'] = Runs - globals()['Users'] = Users - globals()['ToolCalls'] = ToolCalls + globals()['ChatSession'] = ChatSession + globals()['ToolUsage'] = ToolUsage globals()['ToolDefinition'] = ToolDefinition globals()['ChatHistoryManager'] = ChatHistoryManager @@ -349,13 +322,13 @@ def handle_chat_sessions(request): user_id = get_user_id() # In production, get from auth if request.method == 'GET': - sessions = Threads.query.filter_by(user_id=user_id).order_by(Threads.updated_at.desc()).all() + sessions = ChatSession.query.filter_by(user_id=user_id).order_by(ChatSession.updated_at.desc()).all() return jsonify([session.to_dict() for session in sessions]) if request.method == 'POST': data = request.json - session = Threads( - thread_id = data.get('thread_id'), + session = ChatSession( + session_id = data.get('session_id'), user_id=user_id, title=data.get('title', 'New Chat Session'), ) @@ -369,9 +342,9 @@ def clear_chat_history(): """Clear all chat history data - USE WITH CAUTION""" try: # Delete in order to respect foreign key constraints - ToolCalls.query.delete() + ToolUsage.query.delete() ChatHistory.query.delete() - Threads.query.delete() + ChatSession.query.delete() db.session.commit() return jsonify({"message": "All chat history cleared successfully"}), 200 @@ -384,9 +357,9 @@ def clear_session_data(session_id): """Clear chat history for a specific session""" try: # Delete in order to respect foreign key constraints - ToolCalls.query.filter_by(session_id=session_id).delete() + ToolUsage.query.filter_by(session_id=session_id).delete() ChatHistory.query.filter_by(session_id=session_id).delete() - Threads.query.filter_by(thread_id=session_id).delete() + ChatSession.query.filter_by(session_id=session_id).delete() db.session.commit() return jsonify({"message": f"Session {session_id} data cleared successfully"}), 200 @@ -486,4 +459,4 @@ def initialize_agent_definitions(): agent_def = AgentDefinition(**agent) db.session.add(agent_def) - db.session.commit() \ No newline at end of file + db.session.commit()