# From 75b6d31ea04052edc6885b26f1b6d072b3036b29 Mon Sep 17 00:00:00 2001
# From: "wuqingfu.528"
# Date: Fri, 6 Mar 2026 16:37:50 +0800
# Subject: [PATCH] feat: integrate vanna training
# ---
#  veadk/tools/vanna_tools/agent_memory.py | 127 +++
#  .../example_with_vikingdb_training.py | 213 +++++
#  veadk/tools/vanna_tools/vanna_toolset.py | 21 +-
#  veadk/tools/vanna_tools/vanna_trainer.py | 284 +++++++
#  .../vanna_tools/vikingdb_agent_memory.py | 759 ++++++++++++++++++
#  5 files changed, 1402 insertions(+), 2 deletions(-)
#  create mode 100644 veadk/tools/vanna_tools/examples/example_with_vikingdb_training.py
#  create mode 100644 veadk/tools/vanna_tools/vanna_trainer.py
#  create mode 100644 veadk/tools/vanna_tools/vikingdb_agent_memory.py
#
# diff --git a/veadk/tools/vanna_tools/agent_memory.py b/veadk/tools/vanna_tools/agent_memory.py
# index 88d64ff7..8943d079 100644
# --- a/veadk/tools/vanna_tools/agent_memory.py
# +++ b/veadk/tools/vanna_tools/agent_memory.py
# @@ -337,3 +337,130 @@ async def run_async(
#  return str(result.result_for_llm)
#  except Exception as e:
#  return f"Error saving text memory: {str(e)}"


class SearchTextMemoriesTool(BaseTool):
    """Search stored documentation and DDL schemas based on a query.

    The tool embeds no search logic of its own: it checks the caller's
    groups, builds a Vanna tool context and delegates to
    ``agent_memory.search_text_memories``.
    """

    # Used when the caller omits ``limit`` or supplies an unusable value.
    _DEFAULT_LIMIT = 10
    # Minimum similarity score forwarded to the memory backend.
    _SIMILARITY_THRESHOLD = 0.7

    def __init__(
        self,
        agent_memory,
        access_groups: Optional[List[str]] = None,
    ):
        """
        Initialize the search text memories tool with custom agent_memory.

        Args:
            agent_memory: A Vanna agent memory instance (e.g., VikingDBAgentMemory)
            access_groups: List of user groups that can access this tool
                (e.g., ['admin', 'user']). Defaults to ['admin', 'user'].
        """
        self.agent_memory = agent_memory
        self.access_groups = access_groups or ["admin", "user"]

        super().__init__(
            name="search_text_memories",
            description=(
                "Search for relevant documentation and DDL schemas based on a query. "
                "This retrieves stored documentation, contextual information, and "
                "table schemas that are relevant to the question."
            ),
        )

    def _get_declaration(self) -> types.FunctionDeclaration:
        """Describe the tool's arguments (``query`` required; ``limit`` and ``include_ddl`` optional)."""
        return types.FunctionDeclaration(
            name=self.name,
            description=self.description,
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={
                    "query": types.Schema(
                        type=types.Type.STRING,
                        description="The query to search for in documentation and DDL schemas",
                    ),
                    "limit": types.Schema(
                        type=types.Type.INTEGER,
                        description="Maximum number of results to return (default: 10)",
                    ),
                    "include_ddl": types.Schema(
                        type=types.Type.BOOLEAN,
                        description="Whether to include DDL schema results (default: true)",
                    ),
                },
                required=["query"],
            ),
        )

    def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
        """Get user groups from context (``["user"]`` when the state has none)."""
        return tool_context.state.get("user_groups", ["user"])

    def _check_access(self, user_groups: List[str]) -> bool:
        """Check if user has access to this tool (any shared group is enough)."""
        return any(group in self.access_groups for group in user_groups)

    def _create_vanna_context(
        self, tool_context: ToolContext, user_groups: List[str]
    ) -> VannaToolContext:
        """Create Vanna context from Veadk ToolContext."""
        user_id = tool_context.user_id
        session_id = tool_context.session.id
        user_email = tool_context.state.get("user_email", "user@example.com")

        vanna_user = User(
            id=user_id + "_" + session_id,
            email=user_email,
            group_memberships=user_groups,
        )

        # The session id doubles as conversation id and request id.
        return VannaToolContext(
            user=vanna_user,
            conversation_id=session_id,
            request_id=session_id,
            agent_memory=self.agent_memory,
        )

    @classmethod
    def _coerce_limit(cls, raw_limit: Any) -> int:
        """Turn a model-supplied ``limit`` into a positive int, else the default."""
        try:
            limit = int(raw_limit)
        except (TypeError, ValueError):
            return cls._DEFAULT_LIMIT
        return limit if limit > 0 else cls._DEFAULT_LIMIT

    def _supports_include_ddl(self) -> bool:
        """Report whether the memory's ``search_text_memories`` takes ``include_ddl``.

        ``include_ddl`` is an extra keyword of VikingDBAgentMemory. Other
        agent memories (presumably including the DemoAgentMemory that
        VannaToolSet uses by default - verify against vanna) may only accept
        query/context/limit/similarity_threshold, in which case passing the
        extra keyword would raise ``TypeError`` on every call.
        """
        import inspect

        try:
            parameters = inspect.signature(
                self.agent_memory.search_text_memories
            ).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: keep the historical behaviour.
            return True

        return "include_ddl" in parameters or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()
        )

    async def run_async(
        self, *, args: Dict[str, Any], tool_context: ToolContext
    ) -> str:
        """Search for text memories including documentation and DDL.

        Args:
            args: Tool arguments; ``query`` (required), ``limit`` and
                ``include_ddl`` (optional).
            tool_context: Veadk tool context providing user, session and state.

        Returns:
            A numbered, human-readable list of matches, or a message starting
            with ``"Error"`` / ``"No relevant"`` when nothing can be returned.
            Exceptions from the search are reported in the returned string
            rather than raised.
        """
        # ``or ""`` also covers an explicit ``query: null`` from the model.
        query = str(args.get("query") or "").strip()
        if not query:
            return "Error: No query provided"

        limit = self._coerce_limit(args.get("limit"))
        include_ddl = args.get("include_ddl", True)

        try:
            user_groups = self._get_user_groups(tool_context)

            if not self._check_access(user_groups):
                return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}"

            vanna_context = self._create_vanna_context(tool_context, user_groups)

            search_kwargs: Dict[str, Any] = {
                "query": query,
                "context": vanna_context,
                "limit": limit,
                "similarity_threshold": self._SIMILARITY_THRESHOLD,
            }
            # Only forward the extension keyword to memories that understand it.
            if self._supports_include_ddl():
                search_kwargs["include_ddl"] = include_ddl

            results = await self.agent_memory.search_text_memories(**search_kwargs)

            if not results:
                return f"No relevant documentation or DDL found for query: {query}"

            # Format results for LLM
            formatted_results = [
                f"{idx}. [Score: {result.similarity_score:.2f}]\n{result.memory.content}\n"
                for idx, result in enumerate(results, 1)
            ]

            return f"Found {len(results)} relevant results:\n\n" + "\n".join(
                formatted_results
            )

        except Exception as e:
            return f"Error searching text memories: {str(e)}"


# diff --git a/veadk/tools/vanna_tools/examples/example_with_vikingdb_training.py b/veadk/tools/vanna_tools/examples/example_with_vikingdb_training.py
# new file mode 100644
# index 00000000..3bc41e66
# --- /dev/null
# +++ b/veadk/tools/vanna_tools/examples/example_with_vikingdb_training.py
# @@ -0,0 +1,213 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example: Training Vanna with VikingDB and using it with VeADK Agent

This example demonstrates:
1. Training a Vanna model using VikingDB as the backend (Vanna 1.0 style)
2. Using the trained model with VeADK Agent and VannaToolSet

Steps 1 and 2 run at module level (they call VikingDB); step 3 only runs when
the file is executed as a script.
"""

from veadk import Agent, Runner
from veadk.tools.vanna_tools.vanna_toolset import VannaToolSet
from veadk.tools.vanna_tools.vanna_trainer import VannaTrainer


# Step 1: Train the Vanna model with VikingDB
print("=" * 80)
print("STEP 1: Training Vanna with VikingDB")
print("=" * 80)

# Initialize the trainer with VikingDB backend
trainer = VannaTrainer(
    collection_prefix="chinook_vanna",  # Unique name for your project
    region="cn-beijing",
)

# Train with DDL (table schemas)
print("\nTraining with DDL...")
trainer.train(
    ddl="""
CREATE TABLE [Album]
(
    [AlbumId] INTEGER NOT NULL,
    [Title] NVARCHAR(160) NOT NULL,
    [ArtistId] INTEGER NOT NULL,
    CONSTRAINT [PK_Album] PRIMARY KEY ([AlbumId]),
    FOREIGN KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId])
)
"""
)

trainer.train(
    ddl="""
CREATE TABLE [Artist]
(
    [ArtistId] INTEGER NOT NULL,
    [Name] NVARCHAR(120),
    CONSTRAINT [PK_Artist] PRIMARY KEY ([ArtistId])
)
"""
)

trainer.train(
    ddl="""
CREATE TABLE [Track]
(
    [TrackId] INTEGER NOT NULL,
    [Name] NVARCHAR(200) NOT NULL,
    [AlbumId] INTEGER,
    [MediaTypeId] INTEGER NOT NULL,
    [GenreId] INTEGER,
    [Composer] NVARCHAR(220),
    [Milliseconds] INTEGER NOT NULL,
    [Bytes] INTEGER,
    [UnitPrice] NUMERIC(10,2) NOT NULL,
    CONSTRAINT [PK_Track] PRIMARY KEY ([TrackId])
)
"""
)

# Train with documentation
print("\nTraining with documentation...")
trainer.train(
    documentation="The Chinook database represents a digital media store, including tables for artists, albums, media tracks, invoices and customers."
)

trainer.train(
    documentation="The Album table contains album information and has a foreign key relationship with Artist table through ArtistId."
)

# Train with question-SQL pairs
print("\nTraining with question-SQL pairs...")
trainer.train(
    question="Get all the tracks in the album 'Balls to the Wall'",
    sql="SELECT * FROM Track WHERE AlbumId = (SELECT AlbumId FROM Album WHERE Title = 'Balls to the Wall')",
)

trainer.train(
    question="How many tracks are there in each album?",
    sql="SELECT a.Title, COUNT(t.TrackId) as TrackCount FROM Album a JOIN Track t ON a.AlbumId = t.AlbumId GROUP BY a.AlbumId",
)

trainer.train(
    question="List all artists with their album count",
    sql="SELECT ar.Name, COUNT(al.AlbumId) as AlbumCount FROM Artist ar LEFT JOIN Album al ON ar.ArtistId = al.ArtistId GROUP BY ar.ArtistId",
)

# Bulk training example
print("\nBulk training...")
trainer.train_bulk(
    question_sql_pairs=[
        (
            "What are the top 5 longest tracks?",
            "SELECT Name, Milliseconds FROM Track ORDER BY Milliseconds DESC LIMIT 5",
        ),
        ("How many artists are there?", "SELECT COUNT(*) as TotalArtists FROM Artist"),
    ]
)

print("\n✅ Training completed! Data stored in VikingDB.")


# Step 2: Use the trained model with VeADK Agent
print("\n" + "=" * 80)
print("STEP 2: Using trained model with VeADK Agent")
print("=" * 80)

# Get the trained agent memory
agent_memory = trainer.get_agent_memory()

# Create VannaToolSet with the trained memory
vanna_toolset = VannaToolSet(
    connection_string="sqlite:///tmp/Chinook.sqlite",
    file_storage="/tmp/vanna_files",
    agent_memory=agent_memory,  # Use the trained VikingDB memory
)

# Define the VeADK Agent with the trained toolset
agent = Agent(
    name="vanna_sql_agent_with_vikingdb",
    description="An intelligent agent that can query databases using trained VikingDB knowledge.",
    instruction="""
    You are a helpful assistant that can answer questions about data in the Chinook database.
    You have been trained with:
    - Database schemas (DDL)
    - Documentation about the database
    - Example question-SQL pairs

    When answering questions:
    1. First search for similar questions using search_saved_correct_tool_uses
    2. Use the retrieved DDL and documentation to understand the schema
    3. Generate and execute appropriate SQL queries
    4. Present results in a clear format
    """,
    tools=[vanna_toolset],
    model_extra_config={"extra_body": {"thinking": {"type": "disabled"}}},
)

print("\n✅ Agent initialized with trained VikingDB knowledge.")
print("\nYou can now run the agent with queries like:")
print(" - How many albums are there in total?")
print(" - Show me the top 10 longest tracks")
print(" - Which artist has the most albums?")


# Step 3: Test the agent
print("\n" + "=" * 80)
print("STEP 3: Testing the agent")
print("=" * 80)


async def test_agent():
    """Test the agent with a sample query and print its response."""

    # Create runner
    runner = Runner(
        agent=agent,
    )

    # Test query
    test_question = "How many albums are there in total?"
    print(f"\nQuery: {test_question}")
    print("-" * 80)

    # NOTE(review): confirm that Runner.run accepts a `new_message` keyword;
    # verify against the veadk Runner signature.
    response = await runner.run(
        new_message=test_question,
    )

    print(f"\nResponse: {response}")
    print("-" * 80)


# Run the test
if __name__ == "__main__":
    import asyncio

    print("\nRunning test query...")
    asyncio.run(test_agent())

    print("\n" + "=" * 80)
    print("✅ Example completed successfully!")
    print("=" * 80)
    print("\nNext steps:")
    print("1. The trained knowledge is stored in VikingDB")
    print("2. You can add more training data anytime using trainer.train()")
    print(
        "3. The agent will automatically use the trained knowledge to answer questions"
    )
    print(
        "4. Similar questions will be retrieved from VikingDB for better SQL generation"
    )


# diff --git a/veadk/tools/vanna_tools/vanna_toolset.py b/veadk/tools/vanna_tools/vanna_toolset.py
# index ab19171f..b10ecb9c 100644
# --- a/veadk/tools/vanna_tools/vanna_toolset.py
# +++ b/veadk/tools/vanna_tools/vanna_toolset.py
# @@ -30,6 +30,7 @@
#  SaveQuestionToolArgsTool,
#  SearchSavedCorrectToolUsesTool,
#  SaveTextMemoryTool,
# + SearchTextMemoriesTool,
#  )
#  from veadk.tools.vanna_tools.run_sql import RunSqlTool
#  from veadk.tools.vanna_tools.visualize_data import VisualizeDataTool
# @@ -41,10 +42,16 @@
#  class VannaToolSet(BaseToolset):
# - def __init__(self, connection_string: str, file_storage: str = "/tmp/data"):
# + def __init__(
# + self,
# + connection_string: str,
# + file_storage: str = "/tmp/data",
# + agent_memory=None,
# + ):
#  super().__init__()
#  self.connection_string = connection_string
#  self.file_storage = file_storage
# + self.custom_agent_memory = agent_memory
#  self._post_init()
#
#  def _post_init(self):
# @@ -58,6 +65,8 @@ def _post_init(self):
#  - postgresql://user:password@host:port/database
#  - mysql://user:password@host:port/database
#  file_storage (str, optional): The directory to store files. Defaults to "/tmp/data".
# + agent_memory (optional): Custom agent memory instance (e.g., VikingDBAgentMemory).
# + If not provided, defaults to DemoAgentMemory.
#  """
#
#  from vanna.integrations.sqlite import SqliteRunner
# @@ -137,7 +146,12 @@ def _post_init(self):
#  os.makedirs(self.file_storage, exist_ok=True)
#
#  self.file_system = LocalFileSystem(working_directory=self.file_storage)
# - self.agent_memory = DemoAgentMemory(max_items=1000)
# +
# + # Use custom agent_memory if provided, otherwise use default DemoAgentMemory
# + if self.custom_agent_memory is not None:
# + self.agent_memory = self.custom_agent_memory
# + else:
# + self.agent_memory = DemoAgentMemory(max_items=1000)
#
#  self._tools = {
#  "SaveQuestionToolArgsTool": SaveQuestionToolArgsTool(
# @@ -149,6 +163,9 @@ def _post_init(self):
#  "SaveTextMemoryTool": SaveTextMemoryTool(
#  agent_memory=self.agent_memory,
#  ),
# + "SearchTextMemoriesTool": SearchTextMemoriesTool(
# + agent_memory=self.agent_memory,
# + ),
#  "WriteFileTool": WriteFileTool(
#  file_system=self.file_system,
#  agent_memory=self.agent_memory,
#
# diff --git a/veadk/tools/vanna_tools/vanna_trainer.py b/veadk/tools/vanna_tools/vanna_trainer.py
# new file mode 100644
# index 00000000..99009d57
# --- /dev/null
# +++ b/veadk/tools/vanna_tools/vanna_trainer.py
# @@ -0,0 +1,284 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, List
from veadk.tools.vanna_tools.vikingdb_agent_memory import VikingDBAgentMemory
from veadk.utils.logger import get_logger

logger = get_logger(__name__)


class VannaTrainer:
    """
    Vanna 1.0 style trainer for VeADK.

    This class provides a simple API for training Vanna models using VikingDB as the backend.
    It mimics the Vanna 1.0 `train()` method interface.

    Example:
        ```python
        from veadk.tools.vanna_tools.vanna_trainer import VannaTrainer

        # Initialize trainer
        trainer = VannaTrainer(
            collection_prefix="my_vanna_project",
            region="cn-beijing"
        )

        # Train with DDL
        trainer.train(ddl="CREATE TABLE customers (id INT, name VARCHAR(100), email VARCHAR(100))")

        # Train with documentation
        trainer.train(documentation="The customers table contains all customer information")

        # Train with question-SQL pairs
        trainer.train(
            question="Who are the top 10 customers by sales?",
            sql="SELECT name, SUM(sales) as total FROM customers GROUP BY name ORDER BY total DESC LIMIT 10"
        )

        # Get the agent memory for use in VannaToolSet
        agent_memory = trainer.get_agent_memory()
        ```
    """

    def __init__(
        self,
        volcengine_access_key: Optional[str] = None,
        volcengine_secret_key: Optional[str] = None,
        session_token: str = "",
        region: str = "cn-beijing",
        host: Optional[str] = None,
        collection_prefix: str = "vanna_train",
        embedding_model: str = "doubao-embedding",
        cloud_provider: str = "volces",
    ):
        """
        Initialize VannaTrainer with VikingDB backend.

        All arguments are forwarded unchanged to VikingDBAgentMemory.

        Args:
            volcengine_access_key: Volcengine access key (defaults to env var)
            volcengine_secret_key: Volcengine secret key (defaults to env var)
            session_token: Optional session token for temporary credentials
            region: VikingDB region (defaults to cn-beijing)
            host: VikingDB host (auto-generated from region if not provided)
            collection_prefix: Prefix for collection names (default: "vanna_train")
            embedding_model: Embedding model to use (default: "doubao-embedding")
            cloud_provider: Cloud provider (volces or byteplus)
        """
        self.agent_memory = VikingDBAgentMemory(
            volcengine_access_key=volcengine_access_key,
            volcengine_secret_key=volcengine_secret_key,
            session_token=session_token,
            region=region,
            host=host,
            collection_prefix=collection_prefix,
            embedding_model=embedding_model,
            cloud_provider=cloud_provider,
        )

        logger.info(
            f"VannaTrainer initialized with collection_prefix='{collection_prefix}'"
        )

    def train(
        self,
        question: Optional[str] = None,
        sql: Optional[str] = None,
        ddl: Optional[str] = None,
        documentation: Optional[str] = None,
    ) -> str:
        """
        Train Vanna with different types of data (Vanna 1.0 style API).

        This method mimics the Vanna 1.0 `train()` method interface. You can call it with:
        - `ddl`: Train with table schema
        - `documentation`: Train with contextual information
        - `question` + `sql`: Train with question-SQL pairs

        Exactly one kind of data is stored per call. If several are passed,
        the first of documentation, question + sql, ddl wins and a warning is
        logged for the rest; call `train()` once per item (or use
        `train_bulk()`) to store them all.

        Args:
            question: User question (must be provided with sql)
            sql: SQL query (a placeholder question is generated when given alone)
            ddl: DDL statement for table schema
            documentation: Documentation or contextual information

        Returns:
            ID of the saved training data

        Raises:
            ValueError: If `question` is given without `sql`, or if no
                training data is provided at all

        Examples:
            ```python
            # Train with DDL
            trainer.train(ddl="CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100))")

            # Train with documentation
            trainer.train(documentation="The users table stores user profile information")

            # Train with question-SQL pair
            trainer.train(
                question="How many users do we have?",
                sql="SELECT COUNT(*) FROM users"
            )
            ```
        """
        # Validate arguments
        if question and not sql:
            raise ValueError(
                "Please also provide a SQL query when training with a question"
            )

        if sql and not question:
            logger.warning(
                "SQL provided without question - generating default question"
            )
            question = f"Query: {sql[:50]}..."

        # Listed in dispatch order, so provided[0] is what actually gets stored.
        provided = [
            label
            for label, value in (
                ("documentation", documentation),
                ("question + sql", sql),
                ("ddl", ddl),
            )
            if value
        ]
        if len(provided) > 1:
            logger.warning(
                f"Multiple training inputs provided ({', '.join(provided)}); "
                f"only '{provided[0]}' is stored by this call"
            )

        # Train based on provided data type
        if documentation:
            logger.info("Training with documentation...")
            return self.agent_memory.train_documentation(documentation)

        if sql and question:
            logger.info(f"Training with question-SQL pair: {question[:50]}...")
            return self.agent_memory.train_question_sql(question, sql)

        if ddl:
            logger.info(f"Training with DDL: {ddl[:50]}...")
            return self.agent_memory.train_ddl(ddl)

        raise ValueError(
            "You must provide one of the following:\n"
            "- ddl: for table schemas\n"
            "- documentation: for contextual information\n"
            "- question + sql: for training examples"
        )

    def _train_each(self, items, train_one, label: str) -> int:
        """Apply `train_one` to every item, logging failures; return the success count."""
        trained = 0
        for item in items or ():
            try:
                train_one(item)
            except Exception as e:
                # Best-effort: one bad item must not abort the rest of the batch.
                logger.error(f"Failed to train {label}: {e}")
            else:
                trained += 1
        return trained

    def train_bulk(
        self,
        ddls: Optional[List[str]] = None,
        documentations: Optional[List[str]] = None,
        question_sql_pairs: Optional[List[tuple[str, str]]] = None,
    ) -> dict:
        """
        Train with multiple items in bulk.

        Training is best-effort: an item that fails (including a malformed
        question/SQL pair) is logged and skipped, and the remaining items are
        still processed.

        Args:
            ddls: List of DDL statements
            documentations: List of documentation strings
            question_sql_pairs: List of (question, sql) tuples

        Returns:
            Dictionary with counts of successfully trained items, keyed by
            "ddl_count", "documentation_count" and "question_sql_count"

        Example:
            ```python
            trainer.train_bulk(
                ddls=[
                    "CREATE TABLE customers (...)",
                    "CREATE TABLE orders (...)",
                ],
                documentations=[
                    "The customers table contains...",
                    "The orders table contains...",
                ],
                question_sql_pairs=[
                    ("Who are the top customers?", "SELECT * FROM customers..."),
                    ("What are recent orders?", "SELECT * FROM orders..."),
                ]
            )
            ```
        """

        def train_pair(pair):
            # Unpacking happens inside the guarded call so that a pair with
            # the wrong shape is skipped instead of aborting the whole run.
            question, sql = pair
            self.agent_memory.train_question_sql(question, sql)

        results = {
            "ddl_count": self._train_each(ddls, self.agent_memory.train_ddl, "DDL"),
            "documentation_count": self._train_each(
                documentations,
                self.agent_memory.train_documentation,
                "documentation",
            ),
            "question_sql_count": self._train_each(
                question_sql_pairs, train_pair, "question-SQL pair"
            ),
        }

        logger.info(f"Bulk training completed: {results}")
        return results

    def get_agent_memory(self) -> VikingDBAgentMemory:
        """
        Get the underlying VikingDB agent memory instance.

        This can be used to initialize VannaToolSet with the trained data.

        Returns:
            VikingDBAgentMemory instance

        Example:
            ```python
            trainer = VannaTrainer(collection_prefix="my_project")
            trainer.train(ddl="...")

            # Use the trained memory in VannaToolSet
            vanna_toolset = VannaToolSet(
                connection_string="sqlite:///db.sqlite",
                agent_memory=trainer.get_agent_memory()
            )
            ```
        """
        return self.agent_memory

    def get_related_ddl(self, question: str, limit: int = 5) -> List[str]:
        """
        Get related DDL for a question (for debugging/testing).

        Args:
            question: User question
            limit: Maximum results

        Returns:
            List of related DDL statements
        """
        return self.agent_memory.get_related_ddl(question, limit)

    def get_related_documentation(self, question: str, limit: int = 5) -> List[str]:
        """
        Get related documentation for a question (for debugging/testing).

        Args:
            question: User question
            limit: Maximum results

        Returns:
            List of related documentation
        """
        return self.agent_memory.get_related_documentation(question, limit)


# diff --git a/veadk/tools/vanna_tools/vikingdb_agent_memory.py b/veadk/tools/vanna_tools/vikingdb_agent_memory.py
# new file mode 100644
# index 00000000..900b1daa
# --- /dev/null
# +++ b/veadk/tools/vanna_tools/vikingdb_agent_memory.py
# @@ -0,0 +1,759 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import uuid +from typing import Any, Dict, List, Optional +from datetime import datetime + +from volcengine.viking_db import ( + VikingDBService, + Field, + FieldType, + VectorIndexParams, + DistanceType, + IndexType, + QuantType, + RawData, + EmbModel, + Data, +) + +from vanna.capabilities.agent_memory import ( + AgentMemory, + TextMemory, + TextMemorySearchResult, + ToolMemory, + ToolMemorySearchResult, +) +from vanna.core.tool import ToolContext + +from veadk.utils.logger import get_logger +from veadk.auth.veauth.utils import get_credential_from_vefaas_iam + +logger = get_logger(__name__) + + +class VikingDBAgentMemory(AgentMemory): + """ + VikingDB-based implementation of AgentMemory for Vanna training data. + + This stores three types of training data: + 1. DDL (table schemas) + 2. Documentation (contextual information) + 3. Question-SQL pairs (training examples) + + Each type is stored in a separate VikingDB collection for efficient retrieval. 
+ + Args: + volcengine_access_key: Volcengine access key (defaults to env var) + volcengine_secret_key: Volcengine secret key (defaults to env var) + session_token: Optional session token for temporary credentials + region: VikingDB region (defaults to cn-beijing) + host: VikingDB host (auto-generated from region if not provided) + collection_prefix: Prefix for collection names (default: "vanna_train") + embedding_model: Embedding model to use (default: "bge-large-zh") + """ + + def __init__( + self, + volcengine_access_key: Optional[str] = None, + volcengine_secret_key: Optional[str] = None, + session_token: str = "", + region: str = "cn-beijing", + host: Optional[str] = None, + collection_prefix: str = "vanna_train", + embedding_model: str = "doubao-embedding", + cloud_provider: str = "volces", + ): + self.volcengine_access_key = volcengine_access_key or os.getenv( + "VOLCENGINE_ACCESS_KEY" + ) + self.volcengine_secret_key = volcengine_secret_key or os.getenv( + "VOLCENGINE_SECRET_KEY" + ) + self.session_token = session_token + self.region = region + self.cloud_provider = cloud_provider.lower() + + # Auto-generate host based on cloud provider + if not host: + if self.cloud_provider == "byteplus": + self.host = "api-vikingdb.mlp.ap-mya.byteplus.com" + else: + if region == "cn-beijing": + self.host = "api-vikingdb.volces.com" + else: + self.host = "api-vikingdb.mlp.cn-shanghai.volces.com" + else: + self.host = host + + self.collection_prefix = collection_prefix + self.embedding_model = embedding_model + + # Collection names for different training data types + self.ddl_collection = f"{collection_prefix}_ddl" + self.doc_collection = f"{collection_prefix}_doc" + self.sql_collection = f"{collection_prefix}_sql" + + self._client = None + self._initialize_client() + + def _initialize_client(self): + """Initialize VikingDB client with authentication.""" + ak = self.volcengine_access_key + sk = self.volcengine_secret_key + sts_token = self.session_token + + # Try to get 
credentials from VeFaaS IAM if not provided + if not (ak and sk): + try: + cred = get_credential_from_vefaas_iam() + ak = cred.access_key_id + sk = cred.secret_access_key + sts_token = cred.session_token + logger.info("Using VeFaaS IAM credentials for VikingDB") + except Exception as e: + logger.warning(f"Failed to get VeFaaS credentials: {e}") + + if not (ak and sk): + raise ValueError( + "Volcengine credentials not found. Please set VOLCENGINE_ACCESS_KEY " + "and VOLCENGINE_SECRET_KEY environment variables." + ) + + self._client = VikingDBService( + host=self.host, + region=self.region, + ak=ak, + sk=sk, + scheme="https", + ) + self._client.set_session_token(session_token=sts_token) + + # Ensure collections exist + self._ensure_collections() + + def _ensure_collections(self): + """Create collections if they don't exist.""" + collections_to_create = [ + (self.ddl_collection, "DDL and table schemas"), + (self.doc_collection, "Documentation and context"), + (self.sql_collection, "Question-SQL training pairs"), + ] + + for collection_name, description in collections_to_create: + try: + # Check if collection exists + self._client.get_collection(collection_name) + logger.info(f"Collection {collection_name} already exists") + except Exception: + # Create collection + logger.info(f"Creating collection: {collection_name}") + try: + self._client.create_collection( + collection_name=collection_name, + fields=[ + Field( + field_name="id", + field_type=FieldType.String, + default_val="", + ), + Field( + field_name="content", + field_type=FieldType.Text, + ), + Field( + field_name="vector", + field_type=FieldType.Vector, + dim=2048, + ), + Field( + field_name="metadata", + field_type=FieldType.String, + default_val="{}", + ), + ], + description=description, + ) + logger.info(f"Successfully created collection: {collection_name}") + + vector_index = VectorIndexParams( + distance=DistanceType.COSINE, + index_type=IndexType.HNSW, + quant=QuantType.Float, + ) + + try: + 
self._client.create_index( + collection_name, + f"{collection_name}_index", + vector_index, + cpu_quota=2, + description=description, + ) + except Exception as e: + logger.error( + f"Failed to create index for collection {collection_name}: {e}" + ) + raise + except Exception as e: + logger.error(f"Failed to create collection {collection_name}: {e}") + raise + + def _generate_embedding(self, text: str) -> List[float]: + """Generate embedding for text using VikingDB embedding service.""" + try: + raw_data = [RawData("text", text)] + response = self._client.embedding_v2( + emb_model=EmbModel(self.embedding_model), raw_data=raw_data + ) + if response and response.get("sentence_dense_embedding"): + return response["sentence_dense_embedding"][0] + else: + raise ValueError(f"Invalid embedding response: {response}") + except Exception as e: + logger.error(f"Failed to generate embedding: {e}") + raise + + async def save_tool_usage( + self, + question: str, + tool_name: str, + args: Dict[str, Any], + context: ToolContext, + success: bool = True, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Save a tool usage pattern. 
+ + - If tool is run_sql: saves to sql_collection (question-SQL pairs) + - If tool is other: saves to doc_collection (general tool usage documentation) + + Args: + question: The user question + tool_name: Name of the tool + args: Tool arguments + context: Tool execution context + success: Whether execution was successful + metadata: Additional metadata + """ + # Generate ID + doc_id = str(uuid.uuid4()) + + # Prepare metadata + meta = metadata or {} + meta.update( + { + "question": question, + "tool_name": tool_name, + "success": success, + "timestamp": datetime.now().isoformat(), + } + ) + + # Handle run_sql tool separately + if tool_name == "run_sql" and "sql" in args: + sql = args["sql"] + + # Create content for embedding (combine question and SQL) + content = json.dumps({"question": question, "sql": sql}, ensure_ascii=False) + + meta["sql"] = sql + meta["type"] = "question_sql" + + # Generate embedding + vector = self._generate_embedding(content) + + # Insert into SQL collection + field = { + "id": doc_id, + "content": content, + "vector": vector, + "metadata": json.dumps(meta, ensure_ascii=False), + } + + data = Data(field) + + try: + collection = self._client.get_collection(self.sql_collection) + collection.upsert_data(data=[data]) + logger.info( + f"Saved question-SQL pair to sql_collection: {question[:50]}..." 
+ ) + except Exception as e: + logger.error(f"Failed to save SQL tool usage: {e}") + raise + else: + # For other tools, save to doc_collection as documentation + # Create a descriptive content + content = json.dumps( + { + "question": question, + "tool_name": tool_name, + "args": args, + "description": f"User asked '{question}' and used tool '{tool_name}' with args: {args}", + }, + ensure_ascii=False, + ) + + meta["args"] = args + meta["type"] = "tool_usage" + + # Generate embedding + vector = self._generate_embedding(content) + + # Insert into documentation collection + field = { + "id": doc_id, + "content": content, + "vector": vector, + "metadata": json.dumps(meta, ensure_ascii=False), + } + + data = Data(field) + + try: + collection = self._client.get_collection(self.doc_collection) + collection.upsert_data(data=[data]) + logger.info( + f"Saved tool usage to doc_collection: {tool_name} for question '{question[:50]}...'" + ) + except Exception as e: + logger.error(f"Failed to save tool usage to doc_collection: {e}") + raise + + async def search_similar_usage( + self, + question: str, + context: ToolContext, + *, + limit: int = 10, + similarity_threshold: float = 0.7, + tool_name_filter: Optional[str] = None, + ) -> List[ToolMemorySearchResult]: + """ + Search for similar question-SQL pairs. 
+ + Args: + question: The question to search for + context: Tool execution context + limit: Maximum number of results + similarity_threshold: Minimum similarity score + tool_name_filter: Filter by tool name + + Returns: + List of similar tool usage patterns + """ + # Generate embedding for query + vector = self._generate_embedding(question) + + # Search in VikingDB + try: + index = self._client.get_index( + self.sql_collection, f"{self.sql_collection}_index" + ) + response = index.search_by_vector( + vector=vector, + limit=limit, + ) + + results = [] + for idx, item in enumerate(response): + score = item.score + + # Apply similarity threshold + if score < similarity_threshold: + continue + + # Parse metadata + metadata = json.loads(item.fields.get("metadata", "{}")) + + # Apply tool name filter + if tool_name_filter and metadata.get("tool_name") != tool_name_filter: + continue + + # Create ToolMemory object + tool_memory = ToolMemory( + memory_id=item.fields.get("id"), + question=metadata.get("question", ""), + tool_name=metadata.get("tool_name", "run_sql"), + args={"sql": metadata.get("sql", "")}, + success=metadata.get("success", True), + ) + + results.append( + ToolMemorySearchResult( + memory=tool_memory, + similarity_score=score, + rank=idx + 1, + ) + ) + + return results + except Exception as e: + logger.error(f"Failed to search similar usage: {e}") + return [] + + async def save_text_memory(self, content: str, context: ToolContext) -> TextMemory: + """Save a text memory.""" + + # Generate ID + doc_id = str(uuid.uuid4()) + + # Generate embedding + vector = self._generate_embedding(content) + + # Prepare metadata + metadata = { + "timestamp": datetime.now().isoformat(), + "type": "documentation", + } + + # Insert into VikingDB + field = { + "id": doc_id, + "content": content, + "vector": vector, + "metadata": json.dumps(metadata, ensure_ascii=False), + } + + data = Data(field) + + try: + collection = self._client.get_collection(self.doc_collection) + 
collection.upsert_data(data=[data]) + logger.info(f"Saved documentation: {content[:50]}...") + + return TextMemory( + memory_id=doc_id, + content=content, + timestamp=datetime.now().isoformat(), + ) + except Exception as e: + logger.error(f"Failed to save text memory: {e}") + raise + + async def search_text_memories( + self, + query: str, + context: ToolContext, + *, + limit: int = 10, + similarity_threshold: float = 0.7, + include_ddl: bool = True, + ) -> List[TextMemorySearchResult]: + """ + Search documentation and DDL memories. + + This method searches both doc_collection and ddl_collection to provide + comprehensive context including table schemas and documentation. + + Args: + query: Query string + context: Tool execution context + limit: Maximum results per collection + similarity_threshold: Minimum similarity score + include_ddl: Whether to include DDL results (default: True) + + Returns: + List of matching text memories (documentation + DDL) + """ + results = [] + vector = self._generate_embedding(query) + + # Search documentation collection + try: + doc_index = self._client.get_index( + self.doc_collection, f"{self.doc_collection}_index" + ) + doc_response = doc_index.search_by_vector( + vector=vector, + limit=limit, + ) + + for idx, item in enumerate(doc_response): + score = item.score + + if score < similarity_threshold: + continue + + content = item.fields.get("content", "") + memory_id = item.fields.get("id") + + text_memory = TextMemory( + memory_id=memory_id, + content=content, + timestamp=None, + ) + + results.append( + TextMemorySearchResult( + memory=text_memory, + similarity_score=score, + rank=len(results) + 1, + ) + ) + except Exception as e: + logger.error(f"Failed to search documentation: {e}") + + # Search DDL collection if requested + if include_ddl: + try: + ddl_index = self._client.get_index( + self.ddl_collection, f"{self.ddl_collection}_index" + ) + ddl_response = ddl_index.search_by_vector( + vector=vector, + limit=limit, + ) + + for 
idx, item in enumerate(ddl_response): + score = item.score + + if score < similarity_threshold: + continue + + content = item.fields.get("content", "") + memory_id = item.fields.get("id") + + # Mark DDL content with a prefix for clarity + text_memory = TextMemory( + memory_id=memory_id, + content=f"[DDL Schema]\n{content}", + timestamp=None, + ) + + results.append( + TextMemorySearchResult( + memory=text_memory, + similarity_score=score, + rank=len(results) + 1, + ) + ) + except Exception as e: + logger.error(f"Failed to search DDL: {e}") + + # Sort by similarity score (descending) + results.sort(key=lambda x: x.similarity_score, reverse=True) + + # Update ranks after sorting + for idx, result in enumerate(results): + result.rank = idx + 1 + + return results[ + : limit * 2 + ] # Return up to limit*2 results (from both collections) + + def train_ddl(self, ddl: str) -> str: + """ + Train with DDL (table schema). + + Args: + ddl: DDL statement + + Returns: + ID of the saved DDL + """ + doc_id = str(uuid.uuid4()) + vector = self._generate_embedding(ddl) + + metadata = { + "timestamp": datetime.now().isoformat(), + "type": "ddl", + } + + field = { + "id": doc_id, + "content": ddl, + "vector": vector, + "metadata": json.dumps(metadata, ensure_ascii=False), + } + + data = Data(field) + + try: + collection = self._client.get_collection(self.ddl_collection) + collection.upsert_data(data=[data]) + logger.info(f"Trained DDL: {ddl[:50]}...") + return doc_id + except Exception as e: + logger.error(f"Failed to train DDL: {e}") + raise + + def train_documentation(self, documentation: str) -> str: + """ + Train with documentation. 
+ + Args: + documentation: Documentation text + + Returns: + ID of the saved documentation + """ + doc_id = str(uuid.uuid4()) + vector = self._generate_embedding(documentation) + + metadata = { + "timestamp": datetime.now().isoformat(), + "type": "documentation", + } + + field = { + "id": doc_id, + "content": documentation, + "vector": vector, + "metadata": json.dumps(metadata, ensure_ascii=False), + } + + data = Data(field) + + try: + collection = self._client.get_collection(self.doc_collection) + collection.upsert_data(data=[data]) + logger.info(f"Trained documentation: {documentation[:50]}...") + return doc_id + except Exception as e: + logger.error(f"Failed to train documentation: {e}") + raise + + def train_question_sql(self, question: str, sql: str) -> str: + """ + Train with question-SQL pair. + + Args: + question: User question + sql: SQL query + + Returns: + ID of the saved pair + """ + content = json.dumps({"question": question, "sql": sql}, ensure_ascii=False) + + doc_id = str(uuid.uuid4()) + vector = self._generate_embedding(content) + + metadata = { + "question": question, + "sql": sql, + "timestamp": datetime.now().isoformat(), + "type": "question_sql", + } + + field = { + "id": doc_id, + "content": content, + "vector": vector, + "metadata": json.dumps(metadata, ensure_ascii=False), + } + + data = Data(field) + + try: + collection = self._client.get_collection(self.sql_collection) + collection.upsert_data(data=[data]) + logger.info(f"Trained question-SQL: {question[:50]}...") + return doc_id + except Exception as e: + logger.error(f"Failed to train question-SQL: {e}") + raise + + def get_related_ddl(self, question: str, limit: int = 5) -> List[str]: + """ + Get related DDL for a question. 
+ + Args: + question: User question + limit: Maximum results + + Returns: + List of related DDL statements + """ + vector = self._generate_embedding(question) + + try: + index = self._client.get_index( + self.ddl_collection, f"{self.ddl_collection}_index" + ) + response = index.search_by_vector( + vector=vector, + limit=limit, + ) + return [ + item.get("fields", {}).get("content", "") + for item in response.get("items", []) + ] + except Exception as e: + logger.error(f"Failed to get related DDL: {e}") + return [] + + def get_related_documentation(self, question: str, limit: int = 5) -> List[str]: + """ + Get related documentation for a question. + + Args: + question: User question + limit: Maximum results + + Returns: + List of related documentation + """ + vector = self._generate_embedding(question) + + try: + index = self._client.get_index( + self.doc_collection, f"{self.doc_collection}_index" + ) + response = index.search_by_vector( + vector=vector, + limit=limit, + ) + return [ + item.get("fields", {}).get("content", "") + for item in response.get("items", []) + ] + except Exception as e: + logger.error(f"Failed to get related documentation: {e}") + return [] + + async def get_recent_memories( + self, context: ToolContext, limit: int = 10 + ) -> List[ToolMemory]: + """Get recent tool memories (not implemented for VikingDB).""" + return [] + + async def get_recent_text_memories( + self, context: ToolContext, limit: int = 10 + ) -> List[TextMemory]: + """Get recent text memories (not implemented for VikingDB).""" + return [] + + async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool: + """Delete memory by ID (not implemented for VikingDB).""" + return False + + async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool: + """Delete text memory by ID (not implemented for VikingDB).""" + return False + + async def clear_memories( + self, + context: ToolContext, + tool_name: Optional[str] = None, + before_date: Optional[str] 
= None, + ) -> int: + """Clear memories (not implemented for VikingDB).""" + return 0