From 2271a73342accac7a2d6a6112ed4c46635196787 Mon Sep 17 00:00:00 2001 From: CodePothunter Date: Fri, 13 Mar 2026 18:17:02 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E2=9C=A8=E3=80=90=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E3=80=91=EF=BC=9A=E6=96=B0=E5=A2=9E=E6=B7=B7=E5=90=88=E6=A3=80?= =?UTF-8?q?=E7=B4=A2=E7=B3=BB=E7=BB=9F=EF=BC=8C=E6=94=AF=E6=8C=81=20Vector?= =?UTF-8?q?=20+=20BM25=20+=20Substring=20=E5=A4=9A=E8=B7=AF=E5=8F=AC?= =?UTF-8?q?=E5=9B=9E=E4=B8=8E=E8=9E=8D=E5=90=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **tinysearch/base.py** - 新增 Retriever ABC,统一文本级检索接口 - 新增 FusionStrategy ABC,定义多路结果融合策略 - 新增 Reranker ABC,定义重排序接口 **新增文件 tinysearch/retrievers/__init__.py** **新增文件 tinysearch/retrievers/vector_retriever.py** - 新增 VectorRetriever,包装 Embedder + VectorIndexer 为 Retriever 接口 **新增文件 tinysearch/retrievers/bm25_retriever.py** - 新增 BM25Retriever,基于 bm25s 库实现关键词检索 - 支持 jieba 中文分词,可配置 tokenizer **新增文件 tinysearch/retrievers/substring_retriever.py** - 新增 SubstringRetriever,支持正则和纯文本子串匹配 **新增文件 tinysearch/fusion/__init__.py** **新增文件 tinysearch/fusion/rrf.py** - 新增 ReciprocalRankFusion,基于排名的融合策略 **新增文件 tinysearch/fusion/weighted.py** - 新增 WeightedFusion,min-max 归一化 + 加权求和融合 **新增文件 tinysearch/rerankers/__init__.py** **新增文件 tinysearch/rerankers/cross_encoder.py** - 新增 CrossEncoderReranker,基于 FlagEmbedding BGE 模型的交叉编码器重排序 **新增文件 tinysearch/query/hybrid.py** - 新增 HybridQueryEngine,多路召回 → 融合 → 可选重排序 pipeline **tinysearch/cli.py** - 新增 load_retriever / load_retrievers / load_fusion / load_reranker 工厂函数 - 扩展 load_query_engine 支持 method: "hybrid" **tinysearch/config.py** - 新增 retrievers / fusion / reranker 默认配置项 **tinysearch/flow/controller.py** - 新增多路检索器的 build / save / load 支持 **tinysearch/__init__.py** - 新增 retrievers / fusion / rerankers 模块导出 - 版本号升级 0.1.0 → 0.2.0 **setup.py** - 更新项目描述为混合检索系统 - 新增 hybrid / reranker 可选依赖组 **README.md** - 更新架构图,展示 Retriever → Fusion → Reranker 流程 - 新增 Hybrid Search 完整章节(配置、用法、结果格式) - 新增混合检索依赖安装说明 
**影响说明** - 向后兼容:默认 method: "template" 行为不变,无需额外依赖 - 新增可选依赖:bm25s / jieba / FlagEmbedding 均为可选导入 + 优雅降级 Co-Authored-By: Claude Opus 4.6 --- README.md | 184 ++++++++++++++++--- setup.py | 9 +- tinysearch/__init__.py | 11 +- tinysearch/base.py | 84 +++++++++ tinysearch/cli.py | 127 ++++++++++++- tinysearch/config.py | 13 +- tinysearch/flow/controller.py | 69 ++++++- tinysearch/fusion/__init__.py | 11 ++ tinysearch/fusion/rrf.py | 70 +++++++ tinysearch/fusion/weighted.py | 104 +++++++++++ tinysearch/query/__init__.py | 6 +- tinysearch/query/hybrid.py | 76 ++++++++ tinysearch/rerankers/__init__.py | 9 + tinysearch/rerankers/cross_encoder.py | 126 +++++++++++++ tinysearch/retrievers/__init__.py | 13 ++ tinysearch/retrievers/bm25_retriever.py | 162 ++++++++++++++++ tinysearch/retrievers/substring_retriever.py | 148 +++++++++++++++ tinysearch/retrievers/vector_retriever.py | 84 +++++++++ 18 files changed, 1256 insertions(+), 50 deletions(-) create mode 100644 tinysearch/fusion/__init__.py create mode 100644 tinysearch/fusion/rrf.py create mode 100644 tinysearch/fusion/weighted.py create mode 100644 tinysearch/query/hybrid.py create mode 100644 tinysearch/rerankers/__init__.py create mode 100644 tinysearch/rerankers/cross_encoder.py create mode 100644 tinysearch/retrievers/__init__.py create mode 100644 tinysearch/retrievers/bm25_retriever.py create mode 100644 tinysearch/retrievers/substring_retriever.py create mode 100644 tinysearch/retrievers/vector_retriever.py diff --git a/README.md b/README.md index 9641b9a..6764c85 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,17 @@ # TinySearch -**A lightweight vector search and retrieval system for document understanding, semantic search, and text information retrieval.** +**A lightweight hybrid search and retrieval system for document understanding, semantic search, and text information retrieval.** -TinySearch provides an end-to-end solution for converting documents into searchable vector embeddings, with a focus on 
simplicity, flexibility, and efficiency. Perfect for building RAG (Retrieval-Augmented Generation) systems, semantic document search, knowledge bases, and more. +TinySearch provides an end-to-end solution for converting documents into searchable vector embeddings, with support for hybrid multi-retriever search (Vector + BM25 + Substring). Focused on simplicity, flexibility, and efficiency. Perfect for building RAG (Retrieval-Augmented Generation) systems, semantic document search, knowledge bases, and more. ## Key Features - 🧩 **Modular Design**: Plug-and-play components for data processing, text splitting, embedding generation, and vector retrieval - 🔍 **Semantic Search**: Find contextually relevant information beyond keyword matching +- 🔀 **Hybrid Retrieval**: Combine Vector, BM25, and Substring retrievers with configurable fusion strategies (RRF, Weighted) +- 📝 **BM25 Keyword Search**: Fast keyword-based retrieval with jieba Chinese tokenization support +- 🔤 **Substring/Regex Search**: Ctrl+F style exact match and regex pattern search +- 🏆 **Cross-Encoder Reranking**: Optional reranking with BGE Reranker for improved relevance - ⚙️ **Highly Configurable**: Simple YAML configuration to control all aspects of the system - 🔌 **Multiple Input Formats**: Support for TXT, PDF, CSV, Markdown, JSON, and custom adapters - 🤖 **Embedding Models**: Integration with HuggingFace models like Qwen-Embedding and more @@ -29,55 +33,63 @@ flowchart TB subgraph Input Documents["Documents
(PDF, Text, CSV, JSON, etc.)"] end - + subgraph DataProcessing["Data Processing"] DataAdapter["DataAdapter
Extract text from data source"] TextSplitter["TextSplitter
Chunk text into segments"] - Embedder["Embedder
Generate vector embeddings"] end - - subgraph IndexLayer["Index Layer"] - VectorIndexer["VectorIndexer
Build and maintain FAISS index"] - IndexStorage["Index Storage
(FAISS Index + Original Text)"] + + subgraph RetrieverLayer["Retriever Layer"] + VectorRetriever["VectorRetriever
Embedder + FAISS"] + BM25Retriever["BM25Retriever
bm25s + jieba"] + SubstringRetriever["SubstringRetriever
Regex / exact match"] + end + + subgraph FusionLayer["Fusion & Reranking"] + FusionStrategy["FusionStrategy
(RRF / Weighted)"] + Reranker["Reranker (optional)
Cross-Encoder BGE"] end - + subgraph QueryLayer["Query Layer"] UserQuery["User Query"] - QueryEngine["QueryEngine
Process and reformat query"] + QueryEngine["QueryEngine
Template / Hybrid"] SearchResults["Search Results
Ranked by relevance"] end - + subgraph FlowControl["Flow Control"] Config["Configuration"] FlowController["FlowController
Orchestrate data flow"] end - + subgraph API["API Layer"] CLI["Command Line Interface"] FastAPI["FastAPI Web Service"] end - + %% Data Flow - Indexing Documents --> DataAdapter DataAdapter --> TextSplitter - TextSplitter --> Embedder - Embedder --> VectorIndexer - VectorIndexer --> IndexStorage - + TextSplitter --> VectorRetriever + TextSplitter --> BM25Retriever + TextSplitter --> SubstringRetriever + %% Data Flow - Querying UserQuery --> QueryEngine - QueryEngine --> Embedder - Embedder --> VectorIndexer - VectorIndexer --> SearchResults - + QueryEngine --> VectorRetriever + QueryEngine --> BM25Retriever + QueryEngine --> SubstringRetriever + VectorRetriever --> FusionStrategy + BM25Retriever --> FusionStrategy + SubstringRetriever --> FusionStrategy + FusionStrategy --> Reranker + Reranker --> SearchResults + %% Control Flow Config --> FlowController FlowController --> DataAdapter FlowController --> TextSplitter - FlowController --> Embedder - FlowController --> VectorIndexer FlowController --> QueryEngine - + %% API Flow CLI --> FlowController FastAPI --> FlowController @@ -152,9 +164,15 @@ For API documentation, see [API Guide](docs/api.md) and [API Authentication Guid ## Installation ```bash -# Basic installation +# Basic installation (vector search only) pip install tinysearch +# With hybrid search (BM25 + Chinese tokenization) +pip install tinysearch bm25s jieba + +# With cross-encoder reranking +pip install tinysearch FlagEmbedding + # With API support pip install tinysearch[api] @@ -195,7 +213,7 @@ indexer: type: faiss index_path: .cache/index.faiss metric: cosine - + query_engine: method: template template: "Please find information about: {query}" @@ -226,6 +244,118 @@ Then visit http://localhost:8000 in your browser to use the web interface, or se curl -X POST http://localhost:8000/query -H "Content-Type: application/json" -d '{"query": "Your search query", "top_k": 5}' ``` +## Hybrid Search + +TinySearch supports hybrid retrieval that combines multiple 
search strategies for better recall and precision. You can mix Vector (semantic), BM25 (keyword), and Substring (exact match) retrievers, then fuse their results using RRF or Weighted fusion. + +### Hybrid Search Configuration + +To enable hybrid search, set `query_engine.method` to `"hybrid"` and configure the `retrievers` list: + +```yaml +# Embedding + Vector index (required for vector retriever) +embedder: + model: Qwen/Qwen3-Embedding-0.6B + device: cuda +indexer: + type: faiss + index_path: .cache/index.faiss + metric: cosine + +# Multi-retriever configuration +retrievers: + - type: vector # Semantic search (uses embedder + indexer above) + - type: bm25 # Keyword search + tokenizer: jieba # Chinese tokenization (optional, fallback: whitespace) + - type: substring # Exact match / regex + is_regex: false + +# Fusion strategy +fusion: + strategy: weighted # "weighted" or "rrf" + weights: [0.5, 0.4, 0.1] # Weights for each retriever (vector, bm25, substring) + min_score: 0.1 # Drop results below this fused score + +# Optional: Cross-encoder reranking +reranker: + enabled: false + model: BAAI/bge-reranker-v2-m3 + +# Query engine +query_engine: + method: hybrid # "template" (vector only) or "hybrid" + top_k: 20 + recall_multiplier: 2 # Each retriever recalls top_k * 2 candidates before fusion +``` + +### Fusion Strategies + +| Strategy | Description | Best For | +|----------|-------------|----------| +| **weighted** | Min-max normalize scores, then weighted sum | When you know relative retriever importance | +| **rrf** | Reciprocal Rank Fusion: `score = Σ 1/(rank + k)` | When score distributions differ (more robust) | + +### Programmatic Usage + +```python +from tinysearch.base import TextChunk +from tinysearch.retrievers import VectorRetriever, BM25Retriever, SubstringRetriever +from tinysearch.fusion import WeightedFusion, ReciprocalRankFusion +from tinysearch.query import HybridQueryEngine + +# Build retrievers +bm25 = BM25Retriever() +bm25.build(chunks) # 
chunks: List[TextChunk] + +substr = SubstringRetriever(is_regex=False) +substr.build(chunks) + +# Create hybrid engine +engine = HybridQueryEngine( + retrievers=[bm25, substr], + fusion_strategy=WeightedFusion(weights=[0.7, 0.3]), + recall_multiplier=2, +) + +results = engine.retrieve("搜索关键词", top_k=10) +# Each result: {"text", "metadata", "score", "retrieval_method": "hybrid", "scores": {...}} +``` + +### Result Format + +Hybrid search results include per-retriever scores for transparency: + +```python +{ + "text": "...", # Chunk text + "metadata": {...}, # Source metadata + "score": 0.85, # Final fused score [0, 1] + "retrieval_method": "hybrid", + "scores": { # Per-retriever original scores + "vector": 0.92, + "bm25": 0.75, + } +} +``` + +### Optional Dependencies for Hybrid Search + +| Package | Purpose | Required? | +|---------|---------|-----------| +| `bm25s` | BM25 retrieval engine | Only for BM25Retriever | +| `jieba` | Chinese tokenization | Optional (fallback: whitespace split) | +| `FlagEmbedding` | Cross-encoder reranking | Optional (only for reranker) | + +Install with: +```bash +pip install bm25s jieba # For BM25 + Chinese support +pip install FlagEmbedding # For cross-encoder reranking +``` + +### Backward Compatibility + +If you don't configure `retrievers` or keep `query_engine.method: "template"`, TinySearch behaves exactly as before — pure vector search with no additional dependencies. 
+ ## Examples TinySearch includes various example scripts in the `examples/` directory to demonstrate different features: @@ -426,4 +556,4 @@ This project is licensed under the MIT License - see the LICENSE file for detail ## Keywords -vector search, semantic search, document retrieval, embeddings, FAISS, RAG, information retrieval, text search, vector database, document understanding, NLP, natural language processing, AI search \ No newline at end of file +vector search, semantic search, hybrid search, BM25, document retrieval, embeddings, FAISS, RAG, information retrieval, text search, vector database, document understanding, NLP, natural language processing, AI search, reranking, fusion \ No newline at end of file diff --git a/setup.py b/setup.py index e0cbadc..eda9364 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,13 @@ "indexers": [ "faiss-cpu>=1.7.0", ], + "hybrid": [ + "bm25s>=0.1.0", + "jieba>=0.42.0", + ], + "reranker": [ + "FlagEmbedding>=1.0.0", + ], "dev": [ "pytest>=6.0.0", "black>=21.5b2", @@ -64,7 +71,7 @@ version=version.get("__version__", "0.1.0"), author="TinySearch Team", author_email="tinysearch@example.com", - description="A lightweight vector retrieval system", + description="A lightweight hybrid search and retrieval system (Vector + BM25 + Substring)", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/yourusername/tinysearch", diff --git a/tinysearch/__init__.py b/tinysearch/__init__.py index d2a4c08..9a61b24 100644 --- a/tinysearch/__init__.py +++ b/tinysearch/__init__.py @@ -1,13 +1,16 @@ """ -TinySearch: A lightweight vector retrieval system +TinySearch: A lightweight hybrid retrieval system """ -__version__ = "0.1.0" +__version__ = "0.2.0" # Make submodules available for import from . import adapters from . import splitters -from . import embedders +from . import embedders from . import indexers from . import query -from . import flow \ No newline at end of file +from . 
import flow +from . import retrievers +from . import fusion +from . import rerankers \ No newline at end of file diff --git a/tinysearch/base.py b/tinysearch/base.py index cd91871..fb3ca0c 100644 --- a/tinysearch/base.py +++ b/tinysearch/base.py @@ -166,6 +166,90 @@ def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: pass +class Retriever(ABC): + """ + Interface for text-level retrievers that search by query string. + Unlike VectorIndexer (which takes pre-embedded vectors), Retriever + operates on raw text queries, making it suitable for BM25, substring, + and wrapped vector search. + """ + + @abstractmethod + def build(self, chunks: List[TextChunk]) -> None: + """ + Build the retriever index from text chunks + + Args: + chunks: List of TextChunk objects to index + """ + pass + + @abstractmethod + def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + """ + Retrieve relevant chunks for a text query + + Args: + query: Raw query string + top_k: Number of results to return + + Returns: + List of dicts with keys: text, metadata, score, retrieval_method + """ + pass + + @abstractmethod + def save(self, path: Union[str, pathlib.Path]) -> None: + """Save the retriever index to disk""" + pass + + @abstractmethod + def load(self, path: Union[str, pathlib.Path]) -> None: + """Load the retriever index from disk""" + pass + + +class FusionStrategy(ABC): + """ + Interface for fusing results from multiple retrievers into a single ranked list. + """ + + @abstractmethod + def fuse(self, results_list: List[List[Dict[str, Any]]], **kwargs) -> List[Dict[str, Any]]: + """ + Fuse multiple result lists into a single ranked list + + Args: + results_list: List of result lists from different retrievers + **kwargs: Strategy-specific parameters + + Returns: + Fused and ranked list of results + """ + pass + + +class Reranker(ABC): + """ + Interface for rerankers that re-score candidates using a cross-encoder or similar model. 
+ """ + + @abstractmethod + def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]: + """ + Re-rank candidates for a query + + Args: + query: Query string + candidates: List of candidate results to re-rank + top_k: Number of results to return + + Returns: + Re-ranked list of results + """ + pass + + class FlowController(ABC): """ Interface for flow controllers that orchestrate the data pipeline diff --git a/tinysearch/cli.py b/tinysearch/cli.py index 2c126bb..a3c019c 100644 --- a/tinysearch/cli.py +++ b/tinysearch/cli.py @@ -10,13 +10,23 @@ from typing import Dict, Any, List, Optional, Union, Type, cast from .config import Config -from .base import DataAdapter, TextSplitter, Embedder, VectorIndexer, QueryEngine +from .base import ( + DataAdapter, TextSplitter, Embedder, VectorIndexer, QueryEngine, + Retriever, FusionStrategy, Reranker, +) from .adapters import TextAdapter, PDFAdapter, CSVAdapter, MarkdownAdapter, JSONAdapter from .splitters import CharacterTextSplitter from .embedders import HuggingFaceEmbedder # 直接从模块导入 from .indexers.faiss_indexer import FAISSIndexer from .query.template import TemplateQueryEngine +from .query.hybrid import HybridQueryEngine +from .retrievers.vector_retriever import VectorRetriever +from .retrievers.bm25_retriever import BM25Retriever +from .retrievers.substring_retriever import SubstringRetriever +from .fusion.rrf import ReciprocalRankFusion +from .fusion.weighted import WeightedFusion +from .rerankers.cross_encoder import CrossEncoderReranker from .logger import get_logger, configure_logger, log_step, log_progress, log_success, log_error @@ -137,26 +147,133 @@ def load_indexer(config: Config) -> FAISSIndexer: raise ValueError(f"Unsupported indexer type: {indexer_type}") +def load_retriever(config: Config, retriever_config: Dict[str, Any], embedder: Embedder, indexer: FAISSIndexer) -> Retriever: + """ + Load a single retriever based on its config dict. 
+ + Args: + config: Global configuration object + retriever_config: Dict like {"type": "vector"} or {"type": "bm25", "tokenizer": "jieba"} + embedder: Embedder instance (used by vector retriever) + indexer: FAISSIndexer instance (used by vector retriever) + + Returns: + Retriever instance + """ + rtype = retriever_config.get("type", "vector") + + if rtype == "vector": + return VectorRetriever( + embedder=embedder, + indexer=indexer, + query_template=config.get("query_engine.template"), + ) + elif rtype == "bm25": + tokenizer_name = retriever_config.get("tokenizer", "default") + tokenizer = None # use default + if tokenizer_name == "jieba": + try: + import jieba + tokenizer = lambda text: list(jieba.cut(text.lower())) + except ImportError: + pass # fallback to default + return BM25Retriever(tokenizer=tokenizer) + elif rtype == "substring": + return SubstringRetriever( + is_regex=retriever_config.get("is_regex", False), + ) + else: + raise ValueError(f"Unsupported retriever type: {rtype}") + + +def load_retrievers(config: Config, embedder: Embedder, indexer: FAISSIndexer) -> List[Retriever]: + """ + Load all configured retrievers. + + Returns: + List of Retriever instances + """ + retriever_configs = config.get("retrievers", [{"type": "vector"}]) + return [ + load_retriever(config, rc, embedder, indexer) + for rc in retriever_configs + ] + + +def load_fusion(config: Config) -> FusionStrategy: + """ + Load a fusion strategy based on configuration. + + Returns: + FusionStrategy instance + """ + strategy = config.get("fusion.strategy", "weighted") + if strategy == "rrf": + return ReciprocalRankFusion( + k=config.get("fusion.k", 60), + ) + elif strategy == "weighted": + return WeightedFusion( + weights=config.get("fusion.weights"), + min_score=config.get("fusion.min_score", 0.0), + ) + else: + raise ValueError(f"Unsupported fusion strategy: {strategy}") + + +def load_reranker(config: Config) -> Optional[Reranker]: + """ + Load a reranker if enabled in configuration. 
+ + Returns: + Reranker instance or None + """ + if not config.get("reranker.enabled", False): + return None + + return CrossEncoderReranker( + model_name=config.get("reranker.model", "BAAI/bge-reranker-v2-m3"), + device=config.get("reranker.device"), + batch_size=config.get("reranker.batch_size", 64), + max_length=config.get("reranker.max_length", 512), + use_fp16=config.get("reranker.use_fp16", True), + ) + + def load_query_engine(config: Config, embedder: Embedder, indexer: FAISSIndexer) -> QueryEngine: """ - Load a query engine based on configuration - + Load a query engine based on configuration. + + Supports: + - "template": Original TemplateQueryEngine (backward compatible) + - "hybrid": HybridQueryEngine with multi-retriever fusion + Args: config: Configuration object embedder: Embedder instance indexer: FAISSIndexer instance - + Returns: QueryEngine instance """ query_engine_type = config.get("query_engine.method", "template") - + if query_engine_type == "template": return TemplateQueryEngine( embedder=embedder, indexer=indexer, template=config.get("query_engine.template", "请帮我查找:{query}") ) + elif query_engine_type == "hybrid": + retrievers = load_retrievers(config, embedder, indexer) + fusion = load_fusion(config) + reranker = load_reranker(config) + return HybridQueryEngine( + retrievers=retrievers, + fusion_strategy=fusion, + reranker=reranker, + recall_multiplier=config.get("query_engine.recall_multiplier", 2), + ) else: raise ValueError(f"Unsupported query engine type: {query_engine_type}") diff --git a/tinysearch/config.py b/tinysearch/config.py index b50fe68..9bc357f 100644 --- a/tinysearch/config.py +++ b/tinysearch/config.py @@ -42,10 +42,21 @@ def __init__(self, config_path: Optional[Union[str, Path]] = None): "index_path": ".cache/index.faiss", "metric": "cosine" }, + "retrievers": [ + {"type": "vector"} + ], + "fusion": { + "strategy": "weighted", + "weights": [1.0], + }, + "reranker": { + "enabled": False, + }, "query_engine": { "method": 
"template", "template": "请帮我查找:{query}", - "top_k": 5 + "top_k": 5, + "recall_multiplier": 2, }, "flow": { "use_cache": True, diff --git a/tinysearch/flow/controller.py b/tinysearch/flow/controller.py index bd67071..e466d7f 100644 --- a/tinysearch/flow/controller.py +++ b/tinysearch/flow/controller.py @@ -7,9 +7,11 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union, Set, Tuple, cast, Callable -from tinysearch.base import DataAdapter, TextSplitter, Embedder, VectorIndexer, QueryEngine +from tinysearch.base import DataAdapter, TextSplitter, Embedder, VectorIndexer, QueryEngine, Retriever from tinysearch.base import TextChunk, FlowController as FlowControllerBase from tinysearch.flow.hot_update import HotUpdateManager +from tinysearch.query.hybrid import HybridQueryEngine +from tinysearch.retrievers.vector_retriever import VectorRetriever class FlowController(FlowControllerBase): @@ -141,6 +143,9 @@ def process_file(self, file_path: Union[str, Path], force_reprocess: bool = Fals # Add to index self.indexer.build(vectors, chunks) + + # Build retriever indexes for HybridQueryEngine + self._build_retriever_indexes(chunks) def process_directory(self, dir_path: Union[str, Path], extensions: Optional[List[str]] = None, recursive: bool = True, force_reprocess: bool = False) -> None: @@ -200,31 +205,75 @@ def build_index(self, data_path: Union[str, Path], **kwargs) -> None: else: self.process_file(data_path, force_reprocess=force_reprocess) + def _get_hybrid_retrievers(self) -> List[Retriever]: + """Get retrievers from HybridQueryEngine, if applicable""" + if isinstance(self.query_engine, HybridQueryEngine): + return self.query_engine.retrievers + return [] + + def _build_retriever_indexes(self, chunks: List[TextChunk]) -> None: + """Build indexes for non-vector retrievers in HybridQueryEngine""" + for retriever in self._get_hybrid_retrievers(): + # Skip VectorRetriever - it's already handled by self.indexer.build() + if isinstance(retriever, 
VectorRetriever): + continue + retriever.build(chunks) + def save_index(self, path: Optional[Union[str, Path]] = None) -> None: """ - Save the built index to disk - + Save the built index to disk. + Also saves retriever indexes for HybridQueryEngine. + Args: path: Path to save the index to, if None use the config path """ if path is None: path = self.config.get("indexer", {}).get("index_path", "index.faiss") - - # Use cast to ensure type safety for optional path + + # Save the main vector index self.indexer.save(cast(Union[str, Path], path)) - + + # Save non-vector retriever indexes + self._save_retriever_indexes(Path(str(path))) + + def _save_retriever_indexes(self, base_path: Path) -> None: + """Save indexes for non-vector retrievers alongside the main index""" + index_dir = base_path.parent if base_path.suffix else base_path + for retriever in self._get_hybrid_retrievers(): + if isinstance(retriever, VectorRetriever): + continue + # Derive subdirectory name from retriever class + retriever_name = type(retriever).__name__.lower().replace("retriever", "") + retriever_path = index_dir / f"{retriever_name}_index" + retriever.save(retriever_path) + def load_index(self, path: Optional[Union[str, Path]] = None) -> None: """ - Load an index from disk - + Load an index from disk. + Also loads retriever indexes for HybridQueryEngine. 
+ Args: path: Path to load the index from, if None use the config path """ if path is None: path = self.config.get("indexer", {}).get("index_path", "index.faiss") - - # Use cast to ensure type safety for optional path + + # Load the main vector index self.indexer.load(cast(Union[str, Path], path)) + + # Load non-vector retriever indexes + self._load_retriever_indexes(Path(str(path))) + + def _load_retriever_indexes(self, base_path: Path) -> None: + """Load indexes for non-vector retrievers""" + index_dir = base_path.parent if base_path.suffix else base_path + for retriever in self._get_hybrid_retrievers(): + if isinstance(retriever, VectorRetriever): + continue + retriever_name = type(retriever).__name__.lower().replace("retriever", "") + retriever_path = index_dir / f"{retriever_name}_index" + if retriever_path.exists(): + retriever.load(retriever_path) def query(self, query_text: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]: """ diff --git a/tinysearch/fusion/__init__.py b/tinysearch/fusion/__init__.py new file mode 100644 index 0000000..cfa509e --- /dev/null +++ b/tinysearch/fusion/__init__.py @@ -0,0 +1,11 @@ +""" +Fusion strategies for combining multi-retriever results +""" + +from .rrf import ReciprocalRankFusion +from .weighted import WeightedFusion + +__all__ = [ + "ReciprocalRankFusion", + "WeightedFusion", +] diff --git a/tinysearch/fusion/rrf.py b/tinysearch/fusion/rrf.py new file mode 100644 index 0000000..7241773 --- /dev/null +++ b/tinysearch/fusion/rrf.py @@ -0,0 +1,70 @@ +""" +Reciprocal Rank Fusion (RRF) strategy +""" +from collections import defaultdict +from typing import Any, Dict, List + +from tinysearch.base import FusionStrategy + + +class ReciprocalRankFusion(FusionStrategy): + """ + Reciprocal Rank Fusion combines multiple ranked lists using: + score(doc) = sum( 1 / (rank_i + k) ) + + where k is a constant (default 60) that reduces the impact of high-ranked items. 
+ This is a robust, parameter-free fusion method widely used in information retrieval. + """ + + def __init__(self, k: int = 60): + """ + Args: + k: RRF constant. Higher values reduce the gap between ranks. + Default 60 is the standard value from the original RRF paper. + """ + self.k = k + + def fuse(self, results_list: List[List[Dict[str, Any]]], **kwargs) -> List[Dict[str, Any]]: + """ + Fuse multiple result lists using RRF. + + Args: + results_list: List of result lists from different retrievers. + Each result must have 'text' key for deduplication. + + Returns: + Fused list sorted by RRF score descending. + """ + # Track RRF scores and best result per document (keyed by text) + rrf_scores: Dict[str, float] = defaultdict(float) + best_result: Dict[str, Dict[str, Any]] = {} + per_method_scores: Dict[str, Dict[str, float]] = defaultdict(dict) + + for result_list in results_list: + for rank, result in enumerate(result_list): + doc_key = result["text"] + rrf_score = 1.0 / (rank + self.k) + rrf_scores[doc_key] += rrf_score + + method = result.get("retrieval_method", "unknown") + per_method_scores[doc_key][method] = result.get("score", 0.0) + + # Keep the result with the highest original score + if doc_key not in best_result or result.get("score", 0) > best_result[doc_key].get("score", 0): + best_result[doc_key] = result + + # Build fused results + fused = [] + for doc_key, rrf_score in rrf_scores.items(): + base = best_result[doc_key] + fused.append({ + "text": base["text"], + "metadata": base.get("metadata", {}), + "score": rrf_score, + "retrieval_method": "hybrid", + "scores": per_method_scores[doc_key], + }) + + # Sort by fused score descending + fused.sort(key=lambda x: x["score"], reverse=True) + return fused diff --git a/tinysearch/fusion/weighted.py b/tinysearch/fusion/weighted.py new file mode 100644 index 0000000..0d3f464 --- /dev/null +++ b/tinysearch/fusion/weighted.py @@ -0,0 +1,104 @@ +""" +Weighted score fusion strategy +""" +from collections import 
defaultdict +from typing import Any, Dict, List, Optional + +from tinysearch.base import FusionStrategy + + +class WeightedFusion(FusionStrategy): + """ + Weighted fusion normalizes each retriever's scores to [0, 1] via min-max, + then computes a weighted sum: + fusion_score = sum( normalized_score_i * weight_i ) + """ + + def __init__( + self, + weights: Optional[List[float]] = None, + min_score: float = 0.0, + ): + """ + Args: + weights: Weight for each retriever. If None, equal weights are used. + min_score: Minimum fusion score threshold. Results below this are dropped. + """ + self.weights = weights + self.min_score = min_score + + def fuse(self, results_list: List[List[Dict[str, Any]]], **kwargs) -> List[Dict[str, Any]]: + """ + Fuse multiple result lists using weighted score fusion. + + Args: + results_list: List of result lists from different retrievers. + **kwargs: + weights: Override weights (takes precedence over self.weights) + + Returns: + Fused list sorted by fusion score descending. 
+ """ + if not results_list: + return [] + + n_retrievers = len(results_list) + weights = kwargs.get("weights", self.weights) + if weights is None: + weights = [1.0 / n_retrievers] * n_retrievers + + if len(weights) != n_retrievers: + raise ValueError( + f"Number of weights ({len(weights)}) must match " + f"number of result lists ({n_retrievers})" + ) + + # Normalize scores per retriever (min-max to [0, 1]) + normalized_lists = [] + for result_list in results_list: + if not result_list: + normalized_lists.append([]) + continue + scores = [r.get("score", 0.0) for r in result_list] + min_s = min(scores) + max_s = max(scores) + range_s = max_s - min_s if max_s != min_s else 1.0 + + normalized = [] + for r in result_list: + norm_score = (r.get("score", 0.0) - min_s) / range_s + normalized.append({**r, "_norm_score": norm_score}) + normalized_lists.append(normalized) + + # Aggregate by document text + doc_scores: Dict[str, float] = defaultdict(float) + best_result: Dict[str, Dict[str, Any]] = {} + per_method_scores: Dict[str, Dict[str, float]] = defaultdict(dict) + + for i, (result_list, weight) in enumerate(zip(normalized_lists, weights)): + for result in result_list: + doc_key = result["text"] + doc_scores[doc_key] += result["_norm_score"] * weight + + method = result.get("retrieval_method", "unknown") + per_method_scores[doc_key][method] = result.get("score", 0.0) + + if doc_key not in best_result: + best_result[doc_key] = result + + # Build fused results + fused = [] + for doc_key, fusion_score in doc_scores.items(): + if fusion_score < self.min_score: + continue + base = best_result[doc_key] + fused.append({ + "text": base["text"], + "metadata": base.get("metadata", {}), + "score": fusion_score, + "retrieval_method": "hybrid", + "scores": per_method_scores[doc_key], + }) + + fused.sort(key=lambda x: x["score"], reverse=True) + return fused diff --git a/tinysearch/query/__init__.py b/tinysearch/query/__init__.py index 6338169..4f1d3a2 100644 --- 
a/tinysearch/query/__init__.py +++ b/tinysearch/query/__init__.py @@ -3,7 +3,9 @@ """ from .template import TemplateQueryEngine - +from .hybrid import HybridQueryEngine + __all__ = [ - "TemplateQueryEngine" + "TemplateQueryEngine", + "HybridQueryEngine", ] \ No newline at end of file diff --git a/tinysearch/query/hybrid.py b/tinysearch/query/hybrid.py new file mode 100644 index 0000000..5b74f35 --- /dev/null +++ b/tinysearch/query/hybrid.py @@ -0,0 +1,76 @@ +""" +Hybrid query engine - multi-retriever fusion with optional reranking +""" +from typing import Any, Dict, List, Optional + +from tinysearch.base import QueryEngine, Retriever, FusionStrategy, Reranker + + +class HybridQueryEngine(QueryEngine): + """ + Query engine that combines multiple retrievers via a fusion strategy, + with optional reranking. + + Pipeline: + 1. Each retriever recalls top_k * recall_multiplier candidates + 2. FusionStrategy merges and deduplicates results + 3. Optional Reranker re-scores the fused candidates + 4. Return top_k final results + """ + + def __init__( + self, + retrievers: List[Retriever], + fusion_strategy: FusionStrategy, + reranker: Optional[Reranker] = None, + recall_multiplier: int = 2, + ): + """ + Args: + retrievers: List of Retriever instances for multi-path retrieval + fusion_strategy: Strategy to fuse results from multiple retrievers + reranker: Optional reranker for final re-scoring + recall_multiplier: Multiply top_k by this for each retriever's recall + """ + if not retrievers: + raise ValueError("At least one retriever is required") + self.retrievers = retrievers + self.fusion_strategy = fusion_strategy + self.reranker = reranker + self.recall_multiplier = recall_multiplier + + def format_query(self, query: str) -> str: + """Pass-through: hybrid engine doesn't transform queries""" + return query + + def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + """ + Multi-path retrieval with fusion and optional reranking. 
+ + Args: + query: Query string + top_k: Number of final results to return + + Returns: + Fused (and optionally reranked) list of results + """ + recall_k = top_k * self.recall_multiplier + + # Step 1: Recall from each retriever + all_results = [] + for retriever in self.retrievers: + try: + results = retriever.retrieve(query, top_k=recall_k) + all_results.append(results) + except Exception: + # If a retriever fails, skip it rather than failing entirely + all_results.append([]) + + # Step 2: Fuse results + fused = self.fusion_strategy.fuse(all_results) + + # Step 3: Optional reranking + if self.reranker is not None and fused: + fused = self.reranker.rerank(query, fused, top_k=top_k) + + return fused[:top_k] diff --git a/tinysearch/rerankers/__init__.py b/tinysearch/rerankers/__init__.py new file mode 100644 index 0000000..296350e --- /dev/null +++ b/tinysearch/rerankers/__init__.py @@ -0,0 +1,9 @@ +""" +Rerankers for re-scoring retrieval results +""" + +from .cross_encoder import CrossEncoderReranker + +__all__ = [ + "CrossEncoderReranker", +] diff --git a/tinysearch/rerankers/cross_encoder.py b/tinysearch/rerankers/cross_encoder.py new file mode 100644 index 0000000..592bc8b --- /dev/null +++ b/tinysearch/rerankers/cross_encoder.py @@ -0,0 +1,126 @@ +""" +Cross-encoder reranker using FlagEmbedding BGE Reranker +""" +from typing import Any, Dict, List, Optional + +from tinysearch.base import Reranker +from tinysearch.logger import get_logger + +logger = get_logger("CrossEncoderReranker") + +try: + from FlagEmbedding import FlagReranker + FLAGEMBEDDING_AVAILABLE = True +except ImportError: + FLAGEMBEDDING_AVAILABLE = False + + +class CrossEncoderReranker(Reranker): + """ + Cross-encoder reranker using BAAI/bge-reranker-v2-m3 (or compatible model). + + Uses FlagEmbedding for GPU-accelerated cross-encoder inference. + The model is lazily loaded on first use. 
+ """ + + def __init__( + self, + model_name: str = "BAAI/bge-reranker-v2-m3", + device: Optional[str] = None, + batch_size: int = 64, + max_length: int = 512, + use_fp16: bool = True, + ): + """ + Args: + model_name: HuggingFace model name or local path + device: Device to use ("cuda", "cpu", or None for auto) + batch_size: Batch size for inference + max_length: Maximum sequence length + use_fp16: Whether to use FP16 for faster inference + """ + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.max_length = max_length + self.use_fp16 = use_fp16 + self._model = None + + def _ensure_model(self) -> None: + """Lazily load the reranker model""" + if self._model is not None: + return + + if not FLAGEMBEDDING_AVAILABLE: + raise ImportError( + "FlagEmbedding is required for CrossEncoderReranker. " + "Install with: pip install FlagEmbedding" + ) + + # Normalize device for FlagReranker + device = self.device + if device is None: + try: + import torch + device = 0 if torch.cuda.is_available() else "cpu" + except ImportError: + device = "cpu" + elif device == "cuda": + device = 0 + elif device.startswith("cuda:") and device[5:].isdigit(): + device = int(device[5:]) + + logger.info(f"Loading reranker model: {self.model_name} on device={device}") + + self._model = FlagReranker( + self.model_name, + devices=device, + batch_size=self.batch_size, + max_length=self.max_length, + use_fp16=self.use_fp16, + ) + + # Trigger full weight loading with dummy inference + try: + self._model.compute_score([["warmup", "warmup"]]) + except Exception: + pass + + logger.info("Reranker model loaded successfully") + + def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]: + """Re-rank candidates using cross-encoder scoring""" + if not candidates: + return [] + + self._ensure_model() + + # Build query-document pairs + pairs = [] + for item in candidates: + doc_text = item.get("text", "") + pairs.append([query, 
doc_text]) + + # Compute scores + scores = self._model.compute_score(pairs) + + # Handle single result (compute_score returns a float instead of list) + if isinstance(scores, (int, float)): + scores = [scores] + + # Attach rerank scores + reranked = [] + for item, score in zip(candidates, scores): + reranked.append({ + **item, + "rerank_score": float(score), + }) + + # Sort by rerank score descending + reranked.sort(key=lambda x: x["rerank_score"], reverse=True) + return reranked[:top_k] + + @classmethod + def is_available(cls) -> bool: + """Check if FlagEmbedding is installed""" + return FLAGEMBEDDING_AVAILABLE diff --git a/tinysearch/retrievers/__init__.py b/tinysearch/retrievers/__init__.py new file mode 100644 index 0000000..bf71115 --- /dev/null +++ b/tinysearch/retrievers/__init__.py @@ -0,0 +1,13 @@ +""" +Retrievers for text-level search +""" + +from .vector_retriever import VectorRetriever +from .bm25_retriever import BM25Retriever +from .substring_retriever import SubstringRetriever + +__all__ = [ + "VectorRetriever", + "BM25Retriever", + "SubstringRetriever", +] diff --git a/tinysearch/retrievers/bm25_retriever.py b/tinysearch/retrievers/bm25_retriever.py new file mode 100644 index 0000000..92b65e9 --- /dev/null +++ b/tinysearch/retrievers/bm25_retriever.py @@ -0,0 +1,162 @@ +""" +BM25 Retriever - Keyword-based search using bm25s + +Fast BM25 implementation with optional jieba Chinese tokenization. 
+""" +import json +import pickle +import re +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +from tinysearch.base import Retriever, TextChunk +from tinysearch.logger import get_logger + +logger = get_logger("BM25Retriever") + +try: + import bm25s + BM25S_AVAILABLE = True +except ImportError: + BM25S_AVAILABLE = False + +try: + import jieba + JIEBA_AVAILABLE = True +except ImportError: + JIEBA_AVAILABLE = False + + +def _default_tokenizer(text: str) -> List[str]: + """Default tokenizer: jieba if available, else whitespace split""" + text_lower = text.lower() + if JIEBA_AVAILABLE: + return list(jieba.cut(text_lower)) + return re.findall(r'[\w]+', text_lower) + + +class BM25Retriever(Retriever): + """ + BM25 keyword-based retriever using the bm25s library. + + Features: + - Fast indexing and retrieval via bm25s + - Configurable tokenizer (default: jieba for Chinese, fallback to whitespace) + - Persistent index storage + """ + + def __init__( + self, + tokenizer: Optional[Callable[[str], List[str]]] = None, + ): + """ + Args: + tokenizer: Custom tokenizer function. Takes a string, returns list of tokens. + Defaults to jieba (if available) or whitespace splitting. + """ + if not BM25S_AVAILABLE: + logger.warning( + "bm25s not installed. BM25 retrieval will not work. " + "Install with: pip install bm25s" + ) + self.tokenizer = tokenizer or _default_tokenizer + + # Internal state + self._index: Optional[Any] = None # bm25s.BM25 + self._chunks: List[TextChunk] = [] + self._corpus_tokens: List[List[str]] = [] + + def build(self, chunks: List[TextChunk]) -> None: + """Build BM25 index from text chunks""" + if not BM25S_AVAILABLE: + raise ImportError( + "bm25s is required for BM25Retriever. 
Install with: pip install bm25s" + ) + if not chunks: + return + + self._chunks = list(chunks) + + # Tokenize all documents + self._corpus_tokens = [self.tokenizer(chunk.text) for chunk in chunks] + + # Build bm25s index + self._index = bm25s.BM25() + self._index.index(self._corpus_tokens) + + def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + """Retrieve documents using BM25 keyword matching""" + if self._index is None: + return [] + + # Tokenize query + query_tokens = self.tokenizer(query) + if not query_tokens: + return [] + + # Clamp top_k to number of indexed documents + effective_k = min(top_k, len(self._chunks)) + if effective_k == 0: + return [] + + # Retrieve (no corpus → returns indices instead of documents) + result = self._index.retrieve( + [query_tokens], k=effective_k, return_as="tuple" + ) + indices = result.documents[0] + scores = result.scores[0] + + # Build result list + results = [] + for idx, score in zip(indices, scores): + idx = int(idx) + if idx < 0 or idx >= len(self._chunks): + continue + chunk = self._chunks[idx] + results.append({ + "text": chunk.text, + "metadata": chunk.metadata, + "score": float(score), + "retrieval_method": "bm25", + }) + + return results + + def save(self, path: Union[str, Path]) -> None: + """Save BM25 index to disk""" + if self._index is None: + raise ValueError("No index to save. Call build() first.") + + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + # Save bm25s index + self._index.save(str(path), corpus=self._corpus_tokens) + + # Save chunks metadata + chunks_data = [(chunk.text, chunk.metadata) for chunk in self._chunks] + with open(path / "chunks.pkl", "wb") as f: + pickle.dump(chunks_data, f) + + def load(self, path: Union[str, Path]) -> None: + """Load BM25 index from disk""" + if not BM25S_AVAILABLE: + raise ImportError( + "bm25s is required for BM25Retriever. 
Install with: pip install bm25s" + ) + + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"BM25 index directory not found: {path}") + + # Load bm25s index (don't load corpus - we manage chunks separately) + self._index = bm25s.BM25.load(str(path), load_corpus=False) + + # Load chunks metadata + chunks_file = path / "chunks.pkl" + if chunks_file.exists(): + with open(chunks_file, "rb") as f: + chunks_data = pickle.load(f) + self._chunks = [TextChunk(text, metadata) for text, metadata in chunks_data] + else: + self._chunks = [] diff --git a/tinysearch/retrievers/substring_retriever.py b/tinysearch/retrievers/substring_retriever.py new file mode 100644 index 0000000..0038d4d --- /dev/null +++ b/tinysearch/retrievers/substring_retriever.py @@ -0,0 +1,148 @@ +""" +Substring Retriever - Ctrl+F style exact match search + +Fast regex-based substring search. No external dependencies. +""" +import re +from pathlib import Path +from typing import Any, Dict, List, Optional, Union +import json +import pickle + +from tinysearch.base import Retriever, TextChunk + + +class SubstringRetriever(Retriever): + """ + Regex/substring retriever for exact match queries. + + Operates on raw text with regex or plain substring matching. + Useful for finding exact phrases, codes, or patterns. + """ + + def __init__(self, is_regex: bool = False): + """ + Args: + is_regex: If True, treat query as regex pattern. + If False, escape query for literal substring matching. 
+ """ + self.is_regex = is_regex + self._chunks: List[TextChunk] = [] + + def build(self, chunks: List[TextChunk]) -> None: + """Store chunks in memory for substring search""" + self._chunks = list(chunks) + + def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + """Search chunks using regex/substring matching""" + if not self._chunks or not query: + return [] + + results = [] + + try: + if self.is_regex: + pattern = re.compile(query, re.IGNORECASE) + else: + pattern = re.compile(re.escape(query), re.IGNORECASE) + + for chunk in self._chunks: + match = pattern.search(chunk.text) + if match: + score = self._calculate_match_score(match, chunk.text) + results.append({ + "text": chunk.text, + "metadata": chunk.metadata, + "score": score, + "retrieval_method": "substring", + "match_text": match.group(0), + }) + + # Collect extra then sort + if len(results) >= top_k * 2: + break + + except re.error: + # Fallback to plain substring if regex fails + results = self._plain_substring_search(query, top_k * 2) + + # Sort by score descending and take top_k + results.sort(key=lambda x: x["score"], reverse=True) + return results[:top_k] + + def _calculate_match_score(self, match: re.Match, text: str) -> float: + """Calculate match quality score""" + score = 1.0 + + # Bonus for match at start + if match.start() == 0: + score += 2.0 + + # Bonus for longer matches (max +3) + match_len = len(match.group(0)) + score += min(match_len / 10.0, 3.0) + + # Small penalty for matching deep in long text + if len(text) > 500 and match.start() > 100: + score -= 0.5 + + return score + + def _plain_substring_search(self, query: str, limit: int) -> List[Dict[str, Any]]: + """Fallback plain substring search when regex fails""" + results = [] + query_lower = query.lower() + + for chunk in self._chunks: + text_lower = chunk.text.lower() + if query_lower in text_lower: + pos = text_lower.find(query_lower) + score = 1.0 + if pos == 0: + score += 2.0 + score += min(len(query) / 10.0, 
3.0) + + results.append({ + "text": chunk.text, + "metadata": chunk.metadata, + "score": score, + "retrieval_method": "substring", + "match_text": query[:50], + }) + + if len(results) >= limit: + break + + return results + + def save(self, path: Union[str, Path]) -> None: + """Save chunks to disk""" + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + chunks_data = [(chunk.text, chunk.metadata) for chunk in self._chunks] + with open(path / "chunks.pkl", "wb") as f: + pickle.dump(chunks_data, f) + + # Save config + config = {"is_regex": self.is_regex} + with open(path / "config.json", "w") as f: + json.dump(config, f) + + def load(self, path: Union[str, Path]) -> None: + """Load chunks from disk""" + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Substring index directory not found: {path}") + + chunks_file = path / "chunks.pkl" + if chunks_file.exists(): + with open(chunks_file, "rb") as f: + chunks_data = pickle.load(f) + self._chunks = [TextChunk(text, metadata) for text, metadata in chunks_data] + + config_file = path / "config.json" + if config_file.exists(): + with open(config_file, "r") as f: + config = json.load(f) + self.is_regex = config.get("is_regex", False) diff --git a/tinysearch/retrievers/vector_retriever.py b/tinysearch/retrievers/vector_retriever.py new file mode 100644 index 0000000..1583083 --- /dev/null +++ b/tinysearch/retrievers/vector_retriever.py @@ -0,0 +1,84 @@ +""" +Vector retriever - wraps Embedder + VectorIndexer as a Retriever +""" +from typing import Any, Dict, List, Optional, Union +import pathlib +import numpy as np + +from tinysearch.base import Retriever, Embedder, VectorIndexer, TextChunk + + +class VectorRetriever(Retriever): + """ + Wraps an Embedder and VectorIndexer into the Retriever interface. + + This is the bridge between the existing vector search pipeline and + the new Retriever abstraction, ensuring backward compatibility. 
+ """ + + def __init__( + self, + embedder: Embedder, + indexer: VectorIndexer, + query_template: Optional[str] = None, + ): + """ + Args: + embedder: Embedder to convert text to vectors + indexer: VectorIndexer for similarity search + query_template: Optional template for formatting queries (e.g. "请帮我查找:{query}") + """ + self.embedder = embedder + self.indexer = indexer + self.query_template = query_template + + def build(self, chunks: List[TextChunk]) -> None: + """Embed chunks and build the vector index""" + if not chunks: + return + texts = [chunk.text for chunk in chunks] + vectors = self.embedder.embed(texts) + self.indexer.build(vectors, chunks) + + def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + """Embed query and search the vector index""" + # Apply query template if configured + formatted_query = query + if self.query_template: + try: + formatted_query = self.query_template.format(query=query) + except (KeyError, ValueError): + formatted_query = query + + # Embed the query + query_vectors = self.embedder.embed([formatted_query]) + if not query_vectors: + return [] + query_vector = query_vectors[0] + + # Search + raw_results = self.indexer.search(query_vector, top_k) + + # Normalize scores to [0, 1] and add retrieval_method + results = [] + for r in raw_results: + result = { + "text": r["text"], + "metadata": r.get("metadata", {}), + "score": float(r.get("score", 0.0)), + "retrieval_method": "vector", + } + # Preserve embedding if present + if "embedding" in r: + result["embedding"] = r["embedding"] + results.append(result) + + return results + + def save(self, path: Union[str, pathlib.Path]) -> None: + """Save the vector index to disk""" + self.indexer.save(path) + + def load(self, path: Union[str, pathlib.Path]) -> None: + """Load the vector index from disk""" + self.indexer.load(path) From 8d57a5cfc6e59f709e696d312b934001273434ed Mon Sep 17 00:00:00 2001 From: CodePothunter Date: Fri, 20 Mar 2026 17:01:21 +0800 Subject: [PATCH 
2/4] =?UTF-8?q?=E2=9C=A8=E3=80=90=E5=A2=9E=E5=BC=BA?= =?UTF-8?q?=E3=80=91=EF=BC=9AHybridQueryEngine=20=E6=96=B0=E5=A2=9E=20Meta?= =?UTF-8?q?data=20Filtering=E3=80=81=E5=8A=A8=E6=80=81=E6=9D=83=E9=87=8D?= =?UTF-8?q?=E3=80=81min=5Fscore=E3=80=81=E8=AF=A6=E7=BB=86=E7=BB=93?= =?UTF-8?q?=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增功能: - Metadata Filtering:支持精确匹配、列表 OR、callable 谓词的后过滤 - 动态权重:retrieve() 支持 weights kwarg 透传到 FusionStrategy - Per-retriever min_score:各路召回后按阈值过滤低分结果 - retrieve_with_details():返回各阶段中间结果用于调试 Bug 修复: - (Critical) CLI/API 入口现在正确构建、保存、加载 BM25/Substring 索引 - (High) FlowController.query() 透传 **kwargs 到 query_engine.retrieve() - (High) 非向量 retriever 索引路径改为 FAISS 索引目录内,避免路径冲突 - (Medium) 融合去重键改为 text+source+chunk_index,防止不同 chunk 误合并 - (Medium) CLI 目录建索引时按文件逐个提取,确保 source 元数据精确到文件 - (Medium) HybridQueryEngine 召回异常改为 logger.warning,不再静默吞掉 - (Medium) QueryEngine ABC 签名统一添加 **kwargs - (Medium) API /query 端点透传 params 到检索引擎 测试:新增 tests/test_hybrid.py(31 个测试覆盖全部新功能和修复) Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_hybrid.py | 383 ++++++++++++++++++++++++++++++++++ tinysearch/api.py | 14 +- tinysearch/base.py | 7 +- tinysearch/cli.py | 104 ++++++++- tinysearch/config.py | 2 + tinysearch/flow/controller.py | 11 +- tinysearch/fusion/_utils.py | 22 ++ tinysearch/fusion/rrf.py | 3 +- tinysearch/fusion/weighted.py | 3 +- tinysearch/query/hybrid.py | 143 +++++++++++-- tinysearch/query/template.py | 2 +- 11 files changed, 659 insertions(+), 35 deletions(-) create mode 100644 tests/test_hybrid.py create mode 100644 tinysearch/fusion/_utils.py diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py new file mode 100644 index 0000000..43192cc --- /dev/null +++ b/tests/test_hybrid.py @@ -0,0 +1,383 @@ +""" +Tests for hybrid search: HybridQueryEngine, FusionStrategies, and CLI/FlowController integration. 
+""" +import pytest +from unittest.mock import MagicMock, patch +from pathlib import Path + +from tinysearch.base import TextChunk, QueryEngine, Retriever +from tinysearch.retrievers.bm25_retriever import BM25Retriever +from tinysearch.retrievers.substring_retriever import SubstringRetriever +from tinysearch.fusion.weighted import WeightedFusion +from tinysearch.fusion.rrf import ReciprocalRankFusion +from tinysearch.fusion._utils import make_doc_key +from tinysearch.query.hybrid import HybridQueryEngine + + +# ── Fixtures ────────────────────────────────────────────── + +@pytest.fixture +def sample_chunks(): + return [ + TextChunk("Python编程语言", {"source": "a.txt", "grade": "六年级上", "chunk_index": 0}), + TextChunk("Java编程语言", {"source": "b.txt", "grade": "六年级下", "chunk_index": 0}), + TextChunk("TinySearch检索系统", {"source": "c.txt", "grade": "七年级", "chunk_index": 0}), + ] + + +@pytest.fixture +def hybrid_engine(sample_chunks): + bm25 = BM25Retriever() + bm25.build(sample_chunks) + substr = SubstringRetriever() + substr.build(sample_chunks) + return HybridQueryEngine( + [bm25, substr], + WeightedFusion([0.6, 0.4]), + min_scores=[0.0, 0.0], + ) + + +# ── HybridQueryEngine basic ────────────────────────────── + +class TestHybridQueryEngine: + def test_retrieve_returns_results(self, hybrid_engine): + results = hybrid_engine.retrieve("编程", top_k=5) + assert len(results) > 0 + assert all("text" in r and "score" in r for r in results) + + def test_backward_compat_no_kwargs(self, hybrid_engine): + """retrieve() works with no extra kwargs (backward compatible).""" + results = hybrid_engine.retrieve("编程", top_k=5) + assert isinstance(results, list) + + def test_empty_retrievers_raises(self): + with pytest.raises(ValueError, match="At least one retriever"): + HybridQueryEngine([], WeightedFusion()) + + def test_min_scores_length_mismatch(self, sample_chunks): + bm25 = BM25Retriever() + bm25.build(sample_chunks) + with pytest.raises(ValueError, match="min_scores length"): + 
HybridQueryEngine([bm25], WeightedFusion(), min_scores=[0.0, 0.0]) + + +# ── Metadata filtering ─────────────────────────────────── + +class TestMetadataFiltering: + def test_exact_match_filter(self, hybrid_engine): + results = hybrid_engine.retrieve("编程", top_k=5, filters={"source": "a.txt"}) + assert all(r["metadata"]["source"] == "a.txt" for r in results) + + def test_list_or_filter(self, hybrid_engine): + results = hybrid_engine.retrieve( + "编程", top_k=5, + filters={"grade": ["六年级上", "六年级下"]}, + ) + assert len(results) > 0 + assert all(r["metadata"]["grade"].startswith("六年级") for r in results) + + def test_callable_filter(self, hybrid_engine): + results = hybrid_engine.retrieve( + "编程", top_k=5, + filters={"source": lambda v: v in ["a.txt", "b.txt"]}, + ) + assert all(r["metadata"]["source"] in ["a.txt", "b.txt"] for r in results) + + def test_missing_key_excluded(self, hybrid_engine): + """Chunks missing a filter key should be excluded.""" + results = hybrid_engine.retrieve( + "编程", top_k=5, + filters={"nonexistent_key": "value"}, + ) + assert len(results) == 0 + + def test_none_metadata_excluded(self): + assert HybridQueryEngine._match_filters(None, {"key": "val"}) is False + + def test_filter_over_recall(self, hybrid_engine): + """With filters, recall_k should be multiplied by filter_multiplier.""" + # Default filter_multiplier=3, recall_multiplier=2 + # So recall_k = 5 * 2 * 3 = 30 + # This is tested implicitly: we still get results despite filters + results = hybrid_engine.retrieve( + "编程", top_k=5, + filters={"grade": ["六年级上"]}, + ) + assert len(results) > 0 + + +# ── Dynamic weights ─────────────────────────────────────── + +class TestDynamicWeights: + def test_dynamic_weights_override(self, hybrid_engine): + """Passing weights= at query time should override constructor weights.""" + r1 = hybrid_engine.retrieve("编程", top_k=5, weights=[1.0, 0.0]) + r2 = hybrid_engine.retrieve("编程", top_k=5, weights=[0.0, 1.0]) + # Results should differ because weights 
favor different retrievers + # At minimum, both should return results + assert len(r1) > 0 + assert len(r2) > 0 + + +# ── min_scores ──────────────────────────────────────────── + +class TestMinScores: + def test_high_min_scores_filters_everything(self, sample_chunks): + bm25 = BM25Retriever() + bm25.build(sample_chunks) + substr = SubstringRetriever() + substr.build(sample_chunks) + engine = HybridQueryEngine( + [bm25, substr], + WeightedFusion([0.5, 0.5]), + min_scores=[999.0, 999.0], + ) + results = engine.retrieve("编程", top_k=5) + assert len(results) == 0 + + def test_zero_min_scores_passes_all(self, hybrid_engine): + results = hybrid_engine.retrieve("编程", top_k=5) + assert len(results) > 0 + + +# ── retrieve_with_details ───────────────────────────────── + +class TestRetrieveWithDetails: + def test_returns_all_keys(self, hybrid_engine): + d = hybrid_engine.retrieve_with_details("编程", top_k=5) + assert "results" in d + assert "per_retriever" in d + assert "fused_before_rerank" in d + + def test_per_retriever_count_matches(self, hybrid_engine): + d = hybrid_engine.retrieve_with_details("编程", top_k=5) + assert len(d["per_retriever"]) == len(hybrid_engine.retrievers) + + def test_results_match_retrieve(self, hybrid_engine): + """retrieve_with_details().results should equal retrieve().""" + r1 = hybrid_engine.retrieve("编程", top_k=5) + r2 = hybrid_engine.retrieve_with_details("编程", top_k=5)["results"] + assert len(r1) == len(r2) + for a, b in zip(r1, r2): + assert a["text"] == b["text"] + + +# ── Fusion dedup key ────────────────────────────────────── + +class TestFusionDedupKey: + def test_same_text_different_source_not_merged(self): + """Two chunks with same text but different source should stay separate.""" + chunks = [ + TextChunk("Same text", {"source": "fileA.txt", "chunk_index": 0}), + TextChunk("Same text", {"source": "fileB.txt", "chunk_index": 0}), + ] + bm25 = BM25Retriever() + bm25.build(chunks) + substr = SubstringRetriever() + substr.build(chunks) + 
engine = HybridQueryEngine( + [bm25, substr], WeightedFusion([0.5, 0.5]) + ) + results = engine.retrieve("Same text", top_k=5) + sources = {r["metadata"]["source"] for r in results} + assert sources == {"fileA.txt", "fileB.txt"} + + def test_same_text_different_chunk_index_not_merged(self): + """Same text at different positions in same file should stay separate.""" + chunks = [ + TextChunk("Repeated", {"source": "f.txt", "chunk_index": 0}), + TextChunk("Repeated", {"source": "f.txt", "chunk_index": 5}), + ] + bm25 = BM25Retriever() + bm25.build(chunks) + substr = SubstringRetriever() + substr.build(chunks) + engine = HybridQueryEngine( + [bm25, substr], WeightedFusion([0.5, 0.5]) + ) + results = engine.retrieve("Repeated", top_k=5) + indices = {r["metadata"]["chunk_index"] for r in results} + assert indices == {0, 5} + + def test_make_doc_key_different_source(self): + r1 = {"text": "hello", "metadata": {"source": "a"}} + r2 = {"text": "hello", "metadata": {"source": "b"}} + assert make_doc_key(r1) != make_doc_key(r2) + + def test_make_doc_key_same_chunk(self): + """Same chunk from different retrievers should produce the same key.""" + r1 = {"text": "hello", "metadata": {"source": "a", "chunk_index": 0}} + r2 = {"text": "hello", "metadata": {"source": "a", "chunk_index": 0}} + assert make_doc_key(r1) == make_doc_key(r2) + + def test_rrf_dedup_also_fixed(self): + """RRF fusion should also keep different-source same-text chunks separate.""" + chunks = [ + TextChunk("Same text", {"source": "x.txt", "chunk_index": 0}), + TextChunk("Same text", {"source": "y.txt", "chunk_index": 0}), + ] + bm25 = BM25Retriever() + bm25.build(chunks) + substr = SubstringRetriever() + substr.build(chunks) + engine = HybridQueryEngine( + [bm25, substr], ReciprocalRankFusion() + ) + results = engine.retrieve("Same text", top_k=5) + assert len(results) == 2 + + +# ── Silent exception logging ───────────────────────────── + +class TestRetrieverFailureLogging: + def 
test_failing_retriever_logged_not_crash(self, caplog): + """A failing retriever should log a warning and return empty, not crash.""" + class FailRetriever(Retriever): + def build(self, chunks): pass + def retrieve(self, query, top_k=5): + raise RuntimeError("boom") + def save(self, path): pass + def load(self, path): pass + + engine = HybridQueryEngine([FailRetriever()], WeightedFusion()) + import logging + with caplog.at_level(logging.WARNING, logger="tinysearch.query.hybrid"): + results = engine.retrieve("test", top_k=5) + assert results == [] + assert "FailRetriever failed" in caplog.text + + +# ── ABC signature ───────────────────────────────────────── + +class TestABCSignature: + def test_query_engine_retrieve_accepts_kwargs(self): + import inspect + sig = inspect.signature(QueryEngine.retrieve) + params = list(sig.parameters.keys()) + assert "kwargs" in params + + def test_hybrid_engine_is_query_engine(self, hybrid_engine): + assert isinstance(hybrid_engine, QueryEngine) + + +# ── FlowController kwargs forwarding ────────────────────── + +class TestFlowControllerKwargs: + def test_query_forwards_kwargs(self): + """FlowController.query() should forward **kwargs to query_engine.retrieve().""" + mock_engine = MagicMock(spec=QueryEngine) + mock_engine.retrieve.return_value = [] + + from tinysearch.flow.controller import FlowController + # Minimal construction — we only test query(), not build + fc = FlowController.__new__(FlowController) + fc.query_engine = mock_engine + fc.config = {"query_engine": {"top_k": 5}} + + fc.query("test", top_k=5, filters={"source": "a"}, weights=[1.0]) + mock_engine.retrieve.assert_called_once_with( + "test", 5, filters={"source": "a"}, weights=[1.0] + ) + + +# ── CLI helpers ─────────────────────────────────────────── + +class TestCLIHelpers: + def test_get_retriever_index_dir(self): + from tinysearch.cli import _get_retriever_index_dir + assert _get_retriever_index_dir(Path("index.faiss")) == Path("index") + assert 
_get_retriever_index_dir(Path("data/my.faiss")) == Path("data/my") + assert _get_retriever_index_dir(Path("mydir")) == Path("mydir") + + def test_build_hybrid_noop_for_template_engine(self): + """_build_hybrid_retriever_indexes should be a no-op for non-hybrid engines.""" + from tinysearch.cli import _build_hybrid_retriever_indexes + mock_engine = MagicMock() # not a HybridQueryEngine + # Should not raise + _build_hybrid_retriever_indexes(mock_engine, []) + + def test_save_load_hybrid_noop_for_template_engine(self): + from tinysearch.cli import _save_hybrid_retriever_indexes, _load_hybrid_retriever_indexes + mock_engine = MagicMock() + _save_hybrid_retriever_indexes(mock_engine, Path("test.faiss")) + _load_hybrid_retriever_indexes(mock_engine, Path("test.faiss")) + + def test_build_index_directory_per_file_source(self, tmp_path): + """CLI build_index on a directory should assign per-file source metadata.""" + import argparse + from tinysearch.cli import build_index + from tinysearch.config import Config + + # Create two files with the same content + (tmp_path / "file1.txt").write_text("Same content here") + (tmp_path / "file2.txt").write_text("Same content here") + + config = Config() + config.set("query_engine.method", "hybrid") + config.set("retrievers", [{"type": "bm25"}, {"type": "substring"}]) + config.set("fusion.strategy", "weighted") + config.set("fusion.weights", [0.5, 0.5]) + + # We just need to verify the metadata, not actually embed/index. + # Patch embedder and indexer to avoid needing a real model. 
+ captured_chunks = [] + original_split = None + + from tinysearch.splitters import CharacterTextSplitter + orig_split = CharacterTextSplitter.split + + def spy_split(self, texts, metadata=None): + result = orig_split(self, texts, metadata) + captured_chunks.extend(result) + return result + + args = argparse.Namespace(data=str(tmp_path)) + + with patch.object(CharacterTextSplitter, "split", spy_split), \ + patch("tinysearch.cli.load_embedder") as mock_emb, \ + patch("tinysearch.cli.load_indexer") as mock_idx: + mock_emb_inst = MagicMock() + mock_emb_inst.embed.return_value = [[0.0] * 10] * 100 + mock_emb.return_value = mock_emb_inst + mock_idx_inst = MagicMock() + mock_idx.return_value = mock_idx_inst + + build_index(args, config) + + # Each chunk should have a per-file source, not the directory + sources = {c.metadata["source"] for c in captured_chunks} + assert len(sources) == 2, f"Expected 2 distinct sources, got {sources}" + assert all(str(tmp_path) != s for s in sources), \ + f"source should be per-file, not the directory: {sources}" + + +# ── Retriever index save path ───────────────────────────── + +class TestRetrieverIndexPath: + def test_flow_controller_uses_faiss_dir(self): + """FlowController should save retriever indexes inside the FAISS index directory.""" + from tinysearch.flow.controller import FlowController + fc = FlowController.__new__(FlowController) + fc.query_engine = MagicMock(spec=HybridQueryEngine) + + mock_retriever = MagicMock() + mock_retriever.__class__.__name__ = "BM25Retriever" + fc.query_engine.retrievers = [mock_retriever] + + fc._save_retriever_indexes(Path("data/index.faiss")) + mock_retriever.save.assert_called_once_with(Path("data/index") / "bm25_index") + + def test_flow_controller_load_uses_faiss_dir(self): + from tinysearch.flow.controller import FlowController + fc = FlowController.__new__(FlowController) + fc.query_engine = MagicMock(spec=HybridQueryEngine) + + mock_retriever = MagicMock() + mock_retriever.__class__.__name__ 
= "BM25Retriever" + fc.query_engine.retrievers = [mock_retriever] + + # Simulate that the path exists + with patch.object(Path, "exists", return_value=True): + fc._load_retriever_indexes(Path("data/index.faiss")) + mock_retriever.load.assert_called_once_with(Path("data/index") / "bm25_index") diff --git a/tinysearch/api.py b/tinysearch/api.py index f1cd75b..276583a 100644 --- a/tinysearch/api.py +++ b/tinysearch/api.py @@ -19,7 +19,10 @@ import tempfile from .config import Config -from .cli import load_embedder, load_indexer, load_query_engine, load_adapter, load_splitter +from .cli import ( + load_embedder, load_indexer, load_query_engine, load_adapter, load_splitter, + _load_hybrid_retriever_indexes, +) from .flow.controller import FlowController from .logger import get_logger, configure_logger, log_success, log_warning, log_error @@ -177,6 +180,8 @@ def initialize_components(): logger = get_logger("api_init") if Path(index_path).exists(): indexer.load(index_path) + # Load non-vector retriever indexes (BM25, Substring) for hybrid mode + _load_hybrid_retriever_indexes(query_engine, Path(index_path), logger) log_success(f"TinySearch initialized with index: {index_path}") else: log_warning(f"Index not found: {index_path}") @@ -409,8 +414,11 @@ async def query( raise HTTPException(status_code=500, detail="TinySearch not initialized") try: - # Execute query - results = query_engine.retrieve(query_request.query, top_k=query_request.top_k) + # Execute query — forward params (filters, weights, etc.) 
to the engine + extra_params = query_request.params or {} + results = query_engine.retrieve( + query_request.query, top_k=query_request.top_k, **extra_params + ) # Format results matches = [] diff --git a/tinysearch/base.py b/tinysearch/base.py index fb3ca0c..4463a0e 100644 --- a/tinysearch/base.py +++ b/tinysearch/base.py @@ -152,14 +152,15 @@ def format_query(self, query: str) -> str: pass @abstractmethod - def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + def retrieve(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]: """ Retrieve relevant chunks for a query - + Args: query: Query string top_k: Number of results to return - + **kwargs: Engine-specific parameters (e.g. filters, weights) + Returns: List of dictionaries containing text chunks and similarity scores """ diff --git a/tinysearch/cli.py b/tinysearch/cli.py index a3c019c..d2c607e 100644 --- a/tinysearch/cli.py +++ b/tinysearch/cli.py @@ -273,11 +273,73 @@ def load_query_engine(config: Config, embedder: Embedder, indexer: FAISSIndexer) fusion_strategy=fusion, reranker=reranker, recall_multiplier=config.get("query_engine.recall_multiplier", 2), + min_scores=config.get("query_engine.min_scores", None), + filter_multiplier=config.get("query_engine.filter_multiplier", 3), ) else: raise ValueError(f"Unsupported query engine type: {query_engine_type}") +def _get_retriever_index_dir(index_path: Path) -> Path: + """Derive the retriever index directory from the FAISS index path. + + FAISS saves into ``index_path.with_suffix('')`` (e.g. ``index.faiss`` → + ``index/``). Non-vector retriever indexes are stored as subdirectories + inside the same directory so they stay together with the FAISS index. 
+ """ + return index_path.with_suffix('') if index_path.suffix else index_path + + +def _build_hybrid_retriever_indexes( + query_engine: QueryEngine, chunks, logger=None, +) -> None: + """Build BM25 / Substring indexes when the engine is a HybridQueryEngine.""" + if not isinstance(query_engine, HybridQueryEngine): + return + for retriever in query_engine.retrievers: + if isinstance(retriever, VectorRetriever): + continue # already covered by FAISS indexer.build() + rname = type(retriever).__name__ + if logger: + log_step(f"Building {rname} index") + retriever.build(chunks) + + +def _save_hybrid_retriever_indexes( + query_engine: QueryEngine, index_path: Path, logger=None, +) -> None: + """Save non-vector retriever indexes alongside the FAISS index.""" + if not isinstance(query_engine, HybridQueryEngine): + return + index_dir = _get_retriever_index_dir(index_path) + for retriever in query_engine.retrievers: + if isinstance(retriever, VectorRetriever): + continue + rname = type(retriever).__name__.lower().replace("retriever", "") + rpath = index_dir / f"{rname}_index" + if logger: + log_step(f"Saving {type(retriever).__name__} index to {rpath}") + retriever.save(rpath) + + +def _load_hybrid_retriever_indexes( + query_engine: QueryEngine, index_path: Path, logger=None, +) -> None: + """Load non-vector retriever indexes from alongside the FAISS index.""" + if not isinstance(query_engine, HybridQueryEngine): + return + index_dir = _get_retriever_index_dir(index_path) + for retriever in query_engine.retrievers: + if isinstance(retriever, VectorRetriever): + continue + rname = type(retriever).__name__.lower().replace("retriever", "") + rpath = index_dir / f"{rname}_index" + if rpath.exists(): + if logger: + log_step(f"Loading {type(retriever).__name__} index from {rpath}") + retriever.load(rpath) + + def build_index(args: argparse.Namespace, config: Config) -> None: """ Build a search index @@ -294,15 +356,36 @@ def build_index(args: argparse.Namespace, config: Config) -> 
None: splitter = load_splitter(config) embedder = load_embedder(config) indexer = load_indexer(config) + query_engine = load_query_engine(config, embedder, indexer) - # Extract text from data source + # Extract text from data source with per-file source metadata. + # When args.data is a directory, process each file individually so that + # every text gets its actual file path as "source" — not just the + # directory name. This prevents fusion dedup key collisions between + # chunks from different files that happen to share the same text. log_step("Extracting text") - texts = adapter.extract(args.data) + data_path = Path(args.data) + texts: list = [] + metadata: list = [] + if data_path.is_dir(): + for child in sorted(data_path.rglob("*")): + if not child.is_file(): + continue + try: + file_texts = adapter.extract(child) + except Exception: + continue # adapter will skip unsupported extensions + for t in file_texts: + texts.append(t) + metadata.append({"source": str(child)}) + else: + texts = adapter.extract(args.data) + metadata = [{"source": str(data_path)} for _ in range(len(texts))] logger.info(f"📄 Extracted {len(texts)} documents") # Split text into chunks log_step("Splitting text into chunks") - chunks = splitter.split(texts) + chunks = splitter.split(texts, metadata) logger.info(f"✂️ Created {len(chunks)} text chunks") # Generate embeddings @@ -310,15 +393,21 @@ def build_index(args: argparse.Namespace, config: Config) -> None: vectors = embedder.embed([chunk.text for chunk in chunks]) logger.info(f"🧠 Generated {len(vectors)} embedding vectors") - # Build index + # Build FAISS index log_step("Building index") indexer.build(vectors, chunks) - # Save index + # Build non-vector retriever indexes (BM25, Substring) for hybrid mode + _build_hybrid_retriever_indexes(query_engine, chunks, logger) + + # Save FAISS index index_path = Path(config.get("indexer.index_path", "index.faiss")) log_step(f"Saving index to {index_path}") indexer.save(index_path) + # Save 
non-vector retriever indexes + _save_hybrid_retriever_indexes(query_engine, index_path, logger) + log_success("Index built successfully") @@ -337,11 +426,14 @@ def query_index(args: argparse.Namespace, config: Config) -> None: indexer = load_indexer(config) query_engine = load_query_engine(config, embedder, indexer) - # Load index + # Load FAISS index index_path = Path(config.get("indexer.index_path", "index.faiss")) log_step(f"Loading index from {index_path}") indexer.load(index_path) + # Load non-vector retriever indexes (BM25, Substring) for hybrid mode + _load_hybrid_retriever_indexes(query_engine, index_path, logger) + # Process query query = args.q top_k = args.top_k or config.get("query_engine.top_k", 5) diff --git a/tinysearch/config.py b/tinysearch/config.py index 9bc357f..5a2019a 100644 --- a/tinysearch/config.py +++ b/tinysearch/config.py @@ -57,6 +57,8 @@ def __init__(self, config_path: Optional[Union[str, Path]] = None): "template": "请帮我查找:{query}", "top_k": 5, "recall_multiplier": 2, + "min_scores": None, + "filter_multiplier": 3, }, "flow": { "use_cache": True, diff --git a/tinysearch/flow/controller.py b/tinysearch/flow/controller.py index e466d7f..998f940 100644 --- a/tinysearch/flow/controller.py +++ b/tinysearch/flow/controller.py @@ -237,8 +237,9 @@ def save_index(self, path: Optional[Union[str, Path]] = None) -> None: self._save_retriever_indexes(Path(str(path))) def _save_retriever_indexes(self, base_path: Path) -> None: - """Save indexes for non-vector retrievers alongside the main index""" - index_dir = base_path.parent if base_path.suffix else base_path + """Save indexes for non-vector retrievers inside the FAISS index directory""" + # FAISS saves into base_path.with_suffix('') (e.g. 
"index.faiss" → "index/") + index_dir = base_path.with_suffix('') if base_path.suffix else base_path for retriever in self._get_hybrid_retrievers(): if isinstance(retriever, VectorRetriever): continue @@ -265,8 +266,8 @@ def load_index(self, path: Optional[Union[str, Path]] = None) -> None: self._load_retriever_indexes(Path(str(path))) def _load_retriever_indexes(self, base_path: Path) -> None: - """Load indexes for non-vector retrievers""" - index_dir = base_path.parent if base_path.suffix else base_path + """Load indexes for non-vector retrievers from the FAISS index directory""" + index_dir = base_path.with_suffix('') if base_path.suffix else base_path for retriever in self._get_hybrid_retrievers(): if isinstance(retriever, VectorRetriever): continue @@ -291,7 +292,7 @@ def query(self, query_text: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any top_k = self.config.get("query_engine", {}).get("top_k", 5) # Use cast to ensure type safety for optional top_k - return self.query_engine.retrieve(query_text, cast(int, top_k)) + return self.query_engine.retrieve(query_text, cast(int, top_k), **kwargs) def clear_cache(self) -> None: """Clear all cached data""" diff --git a/tinysearch/fusion/_utils.py b/tinysearch/fusion/_utils.py new file mode 100644 index 0000000..aa4edd5 --- /dev/null +++ b/tinysearch/fusion/_utils.py @@ -0,0 +1,22 @@ +""" +Shared utilities for fusion strategies. +""" +from typing import Any, Dict + + +def make_doc_key(result: Dict[str, Any]) -> str: + """ + Create a dedup key from a retrieval result. + + Uses text + metadata (source, chunk_index) to distinguish genuinely + different chunks that happen to share the same text content, while still + merging the same chunk found by different retrievers. 
+ + Disambiguation levels: + - source: separates same text from different files + - chunk_index: separates same text in different positions within one file + """ + meta = result.get("metadata", {}) + source = meta.get("source", "") + chunk_index = meta.get("chunk_index", "") + return f"{result['text']}\x00{source}\x00{chunk_index}" diff --git a/tinysearch/fusion/rrf.py b/tinysearch/fusion/rrf.py index 7241773..e8cf012 100644 --- a/tinysearch/fusion/rrf.py +++ b/tinysearch/fusion/rrf.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List from tinysearch.base import FusionStrategy +from tinysearch.fusion._utils import make_doc_key class ReciprocalRankFusion(FusionStrategy): @@ -42,7 +43,7 @@ def fuse(self, results_list: List[List[Dict[str, Any]]], **kwargs) -> List[Dict[ for result_list in results_list: for rank, result in enumerate(result_list): - doc_key = result["text"] + doc_key = make_doc_key(result) rrf_score = 1.0 / (rank + self.k) rrf_scores[doc_key] += rrf_score diff --git a/tinysearch/fusion/weighted.py b/tinysearch/fusion/weighted.py index 0d3f464..2bad2d4 100644 --- a/tinysearch/fusion/weighted.py +++ b/tinysearch/fusion/weighted.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional from tinysearch.base import FusionStrategy +from tinysearch.fusion._utils import make_doc_key class WeightedFusion(FusionStrategy): @@ -77,7 +78,7 @@ def fuse(self, results_list: List[List[Dict[str, Any]]], **kwargs) -> List[Dict[ for i, (result_list, weight) in enumerate(zip(normalized_lists, weights)): for result in result_list: - doc_key = result["text"] + doc_key = make_doc_key(result) doc_scores[doc_key] += result["_norm_score"] * weight method = result.get("retrieval_method", "unknown") diff --git a/tinysearch/query/hybrid.py b/tinysearch/query/hybrid.py index 5b74f35..dfdea11 100644 --- a/tinysearch/query/hybrid.py +++ b/tinysearch/query/hybrid.py @@ -1,10 +1,16 @@ """ Hybrid query engine - multi-retriever fusion with optional reranking """ -from typing import 
Any, Dict, List, Optional +import logging +from typing import Any, Callable, Dict, List, Optional, Union from tinysearch.base import QueryEngine, Retriever, FusionStrategy, Reranker +logger = logging.getLogger(__name__) + +# Type alias for filter values +FilterValue = Union[str, int, float, bool, List, Callable] + class HybridQueryEngine(QueryEngine): """ @@ -13,9 +19,12 @@ class HybridQueryEngine(QueryEngine): Pipeline: 1. Each retriever recalls top_k * recall_multiplier candidates - 2. FusionStrategy merges and deduplicates results - 3. Optional Reranker re-scores the fused candidates - 4. Return top_k final results + (× filter_multiplier when metadata filters are active) + 2. Per-retriever min_score filtering + 3. Metadata post-filtering (if filters provided) + 4. FusionStrategy merges and deduplicates results + 5. Optional Reranker re-scores the fused candidates + 6. Return top_k final results """ def __init__( @@ -24,6 +33,8 @@ def __init__( fusion_strategy: FusionStrategy, reranker: Optional[Reranker] = None, recall_multiplier: int = 2, + min_scores: Optional[List[float]] = None, + filter_multiplier: int = 3, ): """ Args: @@ -31,46 +42,148 @@ def __init__( fusion_strategy: Strategy to fuse results from multiple retrievers reranker: Optional reranker for final re-scoring recall_multiplier: Multiply top_k by this for each retriever's recall + min_scores: Per-retriever minimum score thresholds (length must match retrievers) + filter_multiplier: Extra recall multiplier when filters are active """ if not retrievers: raise ValueError("At least one retriever is required") + if min_scores is not None and len(min_scores) != len(retrievers): + raise ValueError( + f"min_scores length ({len(min_scores)}) must match " + f"retrievers length ({len(retrievers)})" + ) self.retrievers = retrievers self.fusion_strategy = fusion_strategy self.reranker = reranker self.recall_multiplier = recall_multiplier + self.min_scores = min_scores + self.filter_multiplier = filter_multiplier 
def format_query(self, query: str) -> str: """Pass-through: hybrid engine doesn't transform queries""" return query - def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + def retrieve(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]: """ Multi-path retrieval with fusion and optional reranking. Args: query: Query string top_k: Number of final results to return + **kwargs: + filters: Dict of metadata filters (see _match_filters) + weights: List of floats to override fusion weights dynamically Returns: Fused (and optionally reranked) list of results """ + return self._retrieve_pipeline(query, top_k, **kwargs)["results"] + + def retrieve_with_details( + self, query: str, top_k: int = 5, **kwargs + ) -> Dict[str, Any]: + """ + Like retrieve(), but returns structured details of each pipeline stage. + + Returns: + Dict with keys: + results: Final top_k results + per_retriever: List of per-retriever raw results (after min_score) + fused_before_rerank: Fused results before reranking + """ + return self._retrieve_pipeline(query, top_k, **kwargs) + + def _retrieve_pipeline( + self, query: str, top_k: int, **kwargs + ) -> Dict[str, Any]: + """Core pipeline shared by retrieve() and retrieve_with_details().""" + filters = kwargs.pop("filters", None) + weights = kwargs.pop("weights", None) + + # Compute recall amount — over-recall when filters are active recall_k = top_k * self.recall_multiplier + if filters: + recall_k *= self.filter_multiplier # Step 1: Recall from each retriever - all_results = [] - for retriever in self.retrievers: + per_retriever: List[List[Dict[str, Any]]] = [] + for i, retriever in enumerate(self.retrievers): try: results = retriever.retrieve(query, top_k=recall_k) - all_results.append(results) - except Exception: - # If a retriever fails, skip it rather than failing entirely - all_results.append([]) + except Exception as e: + retriever_name = type(retriever).__name__ + logger.warning( + "Retriever %s failed, 
skipping: %s", retriever_name, e + ) + results = [] + + # Per-retriever min_score filtering + if self.min_scores is not None: + threshold = self.min_scores[i] + results = [r for r in results if r.get("score", 0) >= threshold] - # Step 2: Fuse results - fused = self.fusion_strategy.fuse(all_results) + per_retriever.append(results) - # Step 3: Optional reranking + # Step 2: Metadata post-filtering + if filters: + per_retriever = [ + self._apply_filters(results, filters) for results in per_retriever + ] + + # Step 3: Fuse results (pass dynamic weights if provided) + fuse_kwargs: Dict[str, Any] = {} + if weights is not None: + fuse_kwargs["weights"] = weights + fused = self.fusion_strategy.fuse(per_retriever, **fuse_kwargs) + + fused_before_rerank = list(fused) + + # Step 4: Optional reranking if self.reranker is not None and fused: fused = self.reranker.rerank(query, fused, top_k=top_k) - return fused[:top_k] + final = fused[:top_k] + + return { + "results": final, + "per_retriever": per_retriever, + "fused_before_rerank": fused_before_rerank, + } + + @staticmethod + def _match_filters( + metadata: Optional[Dict[str, Any]], filters: Dict[str, FilterValue] + ) -> bool: + """ + Check whether a metadata dict matches all filter criteria. + + Filter syntax (all keys are AND-ed): + - str/int/float/bool → exact match + - list → match any value in the list (OR) + - callable → predicate function returning bool + + A result with missing metadata or missing a required key does NOT pass. 
+ """ + if metadata is None: + return False + for key, condition in filters.items(): + if key not in metadata: + return False + value = metadata[key] + if callable(condition): + if not condition(value): + return False + elif isinstance(condition, list): + if value not in condition: + return False + else: + if value != condition: + return False + return True + + @classmethod + def _apply_filters( + cls, results: List[Dict[str, Any]], filters: Dict[str, FilterValue] + ) -> List[Dict[str, Any]]: + """Filter a list of results by metadata criteria.""" + return [r for r in results if cls._match_filters(r.get("metadata"), filters)] diff --git a/tinysearch/query/template.py b/tinysearch/query/template.py index 24ec1b7..a4faf63 100644 --- a/tinysearch/query/template.py +++ b/tinysearch/query/template.py @@ -54,7 +54,7 @@ def format_query(self, query: str) -> str: # If all else fails, just concatenate return f"{self.template} {query}" - def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + def retrieve(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]: """ Retrieve relevant chunks for a query From 5f23664b6f98f58ebf14efa2616cffbeea74774b Mon Sep 17 00:00:00 2001 From: CodePothunter Date: Fri, 20 Mar 2026 18:30:05 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E2=9C=A8=E3=80=90=E5=A2=9E=E5=BC=BA?= =?UTF-8?q?=E3=80=91=EF=BC=9A=E6=96=B0=E5=A2=9E=20MetadataIndex=20?= =?UTF-8?q?=E9=A2=84=E8=BF=87=E6=BB=A4=E3=80=81=E9=9B=86=E4=B8=AD=E5=BC=8F?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=8F=91=E7=8E=B0=E3=80=81Adapter=20?= =?UTF-8?q?=E5=8D=95=E6=96=87=E4=BB=B6=E8=81=8C=E8=B4=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **新增文件 tinysearch/utils/__init__.py** **新增文件 tinysearch/utils/file_discovery.py** - 新增 iter_input_files() 集中式目录扫描函数 - 支持 adapter_type 自动匹配默认扩展名、自定义扩展名覆盖、递归/非递归模式 **新增文件 tinysearch/indexers/metadata_index.py** - 新增 MetadataIndex 倒排索引,支持 O(1) 元数据预过滤 - 提供 build/lookup/classify_filters/save/load 全生命周期管理 
**新增文件 tests/test_metadata_index.py** - 新增 MetadataIndex 单元测试 **tinysearch/adapters/text.py, csv.py, json_adapter.py, markdown.py, pdf.py** - 移除各 Adapter 内嵌的目录遍历逻辑 - 传入目录时抛出 ValueError,引导使用 iter_input_files() **tinysearch/query/hybrid.py** - 新增 metadata_index 和 filter_mode 参数("pre"/"post"/"auto") - 重构 retrieve() 流程:自动拆分 indexable/callable 过滤器 - 预过滤空结果时短路返回 **tinysearch/retrievers/bm25_retriever.py, substring_retriever.py, vector_retriever.py** - 统一新增 candidate_ids 参数支持预过滤 - 过召回再截断策略确保预过滤不影响 top_k 准确性 **tinysearch/indexers/faiss_indexer.py** - search() 新增 candidate_ids 参数 - 过召回 3x 再过滤策略 **tinysearch/cli.py** - 统一使用 iter_input_files() 替代手动 rglob 遍历 - build/save/load 流程集成 MetadataIndex **tinysearch/flow/controller.py** - process_directory() 统一使用 iter_input_files() - _build/_save/_load_retriever_indexes() 集成 MetadataIndex **tinysearch/indexers/__init__.py** - 导出 MetadataIndex **tests/test_hybrid.py** - 新增 iter_input_files、Adapter 目录拒绝、source 一致性、预过滤等测试 **README.md** - 新增 Directory Indexing 文档章节 - 更新 Custom Data Adapters 说明 **影响说明** - 质量提升:预过滤大幅减少无效检索计算量 - 架构改进:Adapter 职责单一化,文件发现逻辑集中管理 - 兼容性:向后兼容,metadata_index 默认 None、filter_mode 默认 "auto" Co-Authored-By: Claude Opus 4.6 (1M context) --- README.md | 38 ++- tests/test_hybrid.py | 229 +++++++++++++++++++ tests/test_metadata_index.py | 153 +++++++++++++ tinysearch/adapters/csv.py | 21 +- tinysearch/adapters/json_adapter.py | 21 +- tinysearch/adapters/markdown.py | 23 +- tinysearch/adapters/pdf.py | 21 +- tinysearch/adapters/text.py | 30 +-- tinysearch/cli.py | 56 +++-- tinysearch/flow/controller.py | 42 ++-- tinysearch/indexers/__init__.py | 4 +- tinysearch/indexers/faiss_indexer.py | 24 +- tinysearch/indexers/metadata_index.py | 201 ++++++++++++++++ tinysearch/query/hybrid.py | 78 +++++-- tinysearch/retrievers/bm25_retriever.py | 25 +- tinysearch/retrievers/substring_retriever.py | 17 +- tinysearch/retrievers/vector_retriever.py | 18 +- tinysearch/utils/__init__.py | 6 + tinysearch/utils/file_discovery.py | 55 +++++ 19 
files changed, 917 insertions(+), 145 deletions(-) create mode 100644 tests/test_metadata_index.py create mode 100644 tinysearch/indexers/metadata_index.py create mode 100644 tinysearch/utils/__init__.py create mode 100644 tinysearch/utils/file_discovery.py diff --git a/README.md b/README.md index 6764c85..84c284f 100644 --- a/README.md +++ b/README.md @@ -433,14 +433,16 @@ The TinySearch Web UI provides an intuitive interface for interacting with the s ## Custom Data Adapters -You can create custom data adapters by implementing the `DataAdapter` interface: +You can create custom data adapters by implementing the `DataAdapter` interface. +Each adapter handles **single file** extraction — directory traversal is handled +automatically by the framework: ```python from tinysearch.base import DataAdapter class MyAdapter(DataAdapter): def extract(self, filepath): - # Your code to extract text from the file + # Extract text from a single file return [text1, text2, ...] ``` @@ -456,6 +458,38 @@ adapter: param1: value1 ``` +## Directory Indexing + +When building an index from a directory, TinySearch uses a centralized file +discovery mechanism (`iter_input_files()`). 
Each adapter type has default +file extensions: + +| Adapter | Extensions | +|---------|-----------| +| text | `.txt`, `.text`, `.md`, `.py`, `.js`, `.html`, `.css`, `.json` | +| pdf | `.pdf` | +| csv | `.csv` | +| markdown | `.md`, `.markdown`, `.mdown`, `.mkdn` | +| json | `.json` | + +You can override extensions in `config.yaml`: + +```yaml +adapter: + type: text + params: + extensions: [".txt", ".log"] +``` + +Or use `iter_input_files()` programmatically: + +```python +from tinysearch.utils.file_discovery import iter_input_files + +for file_path in iter_input_files("./data", adapter_type="text"): + texts = adapter.extract(file_path) +``` + ## Modern Logging System TinySearch features a modern logging system powered by [loguru](https://github.com/Delgan/loguru) with beautiful, colorful output and flexible configuration options. diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py index 43192cc..2a22dbd 100644 --- a/tests/test_hybrid.py +++ b/tests/test_hybrid.py @@ -6,12 +6,14 @@ from pathlib import Path from tinysearch.base import TextChunk, QueryEngine, Retriever +from tinysearch.utils.file_discovery import iter_input_files from tinysearch.retrievers.bm25_retriever import BM25Retriever from tinysearch.retrievers.substring_retriever import SubstringRetriever from tinysearch.fusion.weighted import WeightedFusion from tinysearch.fusion.rrf import ReciprocalRankFusion from tinysearch.fusion._utils import make_doc_key from tinysearch.query.hybrid import HybridQueryEngine +from tinysearch.indexers.metadata_index import MetadataIndex # ── Fixtures ────────────────────────────────────────────── @@ -360,6 +362,7 @@ def test_flow_controller_uses_faiss_dir(self): from tinysearch.flow.controller import FlowController fc = FlowController.__new__(FlowController) fc.query_engine = MagicMock(spec=HybridQueryEngine) + fc.query_engine.metadata_index = None mock_retriever = MagicMock() mock_retriever.__class__.__name__ = "BM25Retriever" @@ -372,6 +375,7 @@ def 
test_flow_controller_load_uses_faiss_dir(self): from tinysearch.flow.controller import FlowController fc = FlowController.__new__(FlowController) fc.query_engine = MagicMock(spec=HybridQueryEngine) + fc.query_engine.metadata_index = None mock_retriever = MagicMock() mock_retriever.__class__.__name__ = "BM25Retriever" @@ -381,3 +385,228 @@ def test_flow_controller_load_uses_faiss_dir(self): with patch.object(Path, "exists", return_value=True): fc._load_retriever_indexes(Path("data/index.faiss")) mock_retriever.load.assert_called_once_with(Path("data/index") / "bm25_index") + + +# ── iter_input_files ───────────────────────────────────── + +class TestIterInputFiles: + def test_single_file(self, tmp_path): + f = tmp_path / "doc.txt" + f.write_text("hello") + assert list(iter_input_files(f)) == [f] + + def test_directory_filters_by_adapter_type(self, tmp_path): + (tmp_path / "a.txt").write_text("a") + (tmp_path / "b.pdf").write_text("b") + (tmp_path / "c.py").write_text("c") + files = list(iter_input_files(tmp_path, adapter_type="text")) + suffixes = {f.suffix for f in files} + assert ".txt" in suffixes + assert ".py" in suffixes + assert ".pdf" not in suffixes + + def test_custom_extensions_override(self, tmp_path): + (tmp_path / "a.xyz").write_text("a") + (tmp_path / "b.txt").write_text("b") + files = list(iter_input_files(tmp_path, extensions=[".xyz"])) + assert len(files) == 1 + assert files[0].suffix == ".xyz" + + def test_nonexistent_path_raises(self): + with pytest.raises(FileNotFoundError): + list(iter_input_files(Path("/nonexistent"))) + + def test_sorted_deterministic(self, tmp_path): + for name in ["c.txt", "a.txt", "b.txt"]: + (tmp_path / name).write_text(name) + files = list(iter_input_files(tmp_path, adapter_type="text")) + assert files == sorted(files) + + def test_recursive_finds_nested(self, tmp_path): + sub = tmp_path / "sub" + sub.mkdir() + (tmp_path / "top.txt").write_text("top") + (sub / "nested.txt").write_text("nested") + files = 
list(iter_input_files(tmp_path, adapter_type="text", recursive=True)) + assert len(files) == 2 + + def test_non_recursive_skips_nested(self, tmp_path): + sub = tmp_path / "sub" + sub.mkdir() + (tmp_path / "top.txt").write_text("top") + (sub / "nested.txt").write_text("nested") + files = list(iter_input_files(tmp_path, adapter_type="text", recursive=False)) + assert len(files) == 1 + + +# ── Adapter directory rejection ────────────────────────── + +class TestAdapterRejectsDirectory: + def test_text_adapter_rejects_directory(self, tmp_path): + from tinysearch.adapters.text import TextAdapter + (tmp_path / "a.txt").write_text("hello") + with pytest.raises(ValueError, match="does not accept directories"): + TextAdapter().extract(tmp_path) + + def test_csv_adapter_rejects_directory(self, tmp_path): + from tinysearch.adapters.csv import CSVAdapter + (tmp_path / "a.csv").write_text("col\nval") + with pytest.raises(ValueError, match="does not accept directories"): + CSVAdapter().extract(tmp_path) + + def test_markdown_adapter_rejects_directory(self, tmp_path): + from tinysearch.adapters.markdown import MarkdownAdapter + (tmp_path / "a.md").write_text("# Hello") + with pytest.raises(ValueError, match="does not accept directories"): + MarkdownAdapter().extract(tmp_path) + + def test_json_adapter_rejects_directory(self, tmp_path): + from tinysearch.adapters.json_adapter import JSONAdapter + (tmp_path / "a.json").write_text('{"key": "val"}') + with pytest.raises(ValueError, match="does not accept directories"): + JSONAdapter().extract(tmp_path) + + +# ── Source metadata consistency ────────────────────────── + +class TestSourceMetadataConsistency: + def test_cli_and_flowcontroller_same_sources(self, tmp_path): + """CLI build_index and FlowController should discover the same files.""" + (tmp_path / "file1.txt").write_text("Content A") + (tmp_path / "file2.txt").write_text("Content B") + (tmp_path / "ignore.pdf").write_text("PDF") + + cli_sources = {str(f) for f in 
iter_input_files(tmp_path, adapter_type="text")} + fc_sources = {str(f) for f in iter_input_files(tmp_path, adapter_type="text")} + + assert cli_sources == fc_sources + assert len(cli_sources) == 2 + assert all("ignore.pdf" not in s for s in cli_sources) + + +# ── Pre-filter (MetadataIndex + candidate_ids) ─────────── + +@pytest.fixture +def graded_chunks(): + """Chunks with grade metadata for pre-filter testing.""" + return [ + TextChunk("Python编程语言", {"source": "a.txt", "grade": "六年级上", "chunk_index": 0}), + TextChunk("Java编程语言", {"source": "b.txt", "grade": "六年级下", "chunk_index": 0}), + TextChunk("TinySearch检索系统", {"source": "c.txt", "grade": "七年级", "chunk_index": 0}), + TextChunk("向量数据库", {"source": "d.txt", "grade": "六年级上", "chunk_index": 0}), + ] + + +@pytest.fixture +def prefilter_engine(graded_chunks): + """HybridQueryEngine with MetadataIndex for pre-filter testing.""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + substr = SubstringRetriever() + substr.build(graded_chunks) + + metadata_index = MetadataIndex() + metadata_index.build(graded_chunks) + + return HybridQueryEngine( + [bm25, substr], + WeightedFusion([0.6, 0.4]), + metadata_index=metadata_index, + filter_mode="auto", + ) + + +class TestPreFilter: + def test_bm25_candidate_ids(self, graded_chunks): + """BM25 with candidate_ids only returns results within the set.""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + # Only allow chunks 0 and 3 (grade=六年级上) + results = bm25.retrieve("编程", top_k=5, candidate_ids={0, 3}) + for r in results: + assert r["metadata"]["grade"] == "六年级上" + + def test_substring_candidate_ids(self, graded_chunks): + """SubstringRetriever with candidate_ids restricts search.""" + substr = SubstringRetriever() + substr.build(graded_chunks) + results = substr.retrieve("编程", top_k=5, candidate_ids={0}) + assert len(results) == 1 + assert results[0]["metadata"]["source"] == "a.txt" + + def test_candidate_ids_none_is_noop(self, graded_chunks): + """Passing 
candidate_ids=None returns same as no kwargs.""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + r1 = bm25.retrieve("编程", top_k=5) + r2 = bm25.retrieve("编程", top_k=5, candidate_ids=None) + assert len(r1) == len(r2) + + def test_candidate_ids_empty_returns_empty(self, graded_chunks): + """Passing candidate_ids=set() returns [].""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + results = bm25.retrieve("编程", top_k=5, candidate_ids=set()) + assert results == [] + + def test_filter_mode_pre(self, graded_chunks): + """Pre-filter mode resolves filters via MetadataIndex.""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + substr = SubstringRetriever() + substr.build(graded_chunks) + + metadata_index = MetadataIndex() + metadata_index.build(graded_chunks) + + engine = HybridQueryEngine( + [bm25, substr], + WeightedFusion([0.6, 0.4]), + metadata_index=metadata_index, + filter_mode="pre", + ) + results = engine.retrieve("编程", top_k=5, filters={"grade": "六年级上"}) + assert all(r["metadata"]["grade"] == "六年级上" for r in results) + + def test_filter_mode_auto_mixed(self, prefilter_engine): + """Auto mode: indexable filters pre-filter, callable filters post-filter.""" + results = prefilter_engine.retrieve( + "编程", top_k=5, + filters={ + "grade": "六年级上", + "source": lambda v: v == "a.txt", # callable → post-filter + }, + ) + assert all(r["metadata"]["source"] == "a.txt" for r in results) + + def test_no_metadata_index_backward_compat(self, graded_chunks): + """When metadata_index=None, filters are purely post-applied.""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + substr = SubstringRetriever() + substr.build(graded_chunks) + + engine = HybridQueryEngine( + [bm25, substr], + WeightedFusion([0.6, 0.4]), + metadata_index=None, + ) + results = engine.retrieve("编程", top_k=5, filters={"grade": "六年级上"}) + assert all(r["metadata"]["grade"] == "六年级上" for r in results) + + def test_empty_candidate_short_circuits(self, graded_chunks): + """When lookup returns 
empty set, pipeline returns [] immediately.""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + + metadata_index = MetadataIndex() + metadata_index.build(graded_chunks) + + engine = HybridQueryEngine( + [bm25], + WeightedFusion(), + metadata_index=metadata_index, + filter_mode="pre", + ) + results = engine.retrieve("编程", top_k=5, filters={"grade": "不存在"}) + assert results == [] diff --git a/tests/test_metadata_index.py b/tests/test_metadata_index.py new file mode 100644 index 0000000..ac12cdb --- /dev/null +++ b/tests/test_metadata_index.py @@ -0,0 +1,153 @@ +""" +Tests for MetadataIndex: inverted index over TextChunk metadata. +""" +import pytest +from pathlib import Path + +from tinysearch.base import TextChunk +from tinysearch.indexers.metadata_index import MetadataIndex + + +@pytest.fixture +def sample_chunks(): + return [ + TextChunk("Q1", {"grade": "六年级上", "type": "选择题", "difficulty": 1}), + TextChunk("Q2", {"grade": "六年级下", "type": "填空题", "difficulty": 2}), + TextChunk("Q3", {"grade": "六年级上", "type": "填空题", "difficulty": 3}), + TextChunk("Q4", {"grade": "七年级", "tags": ["动词", "时态"], "difficulty": 1}), + TextChunk("Q5", {"grade": "七年级", "tags": ["名词", "复数"]}), + ] + + +@pytest.fixture +def built_index(sample_chunks): + idx = MetadataIndex() + idx.build(sample_chunks) + return idx + + +# ── Build & Lookup ───────────────────────────── + +class TestBuildAndLookup: + def test_build_basic(self, built_index): + assert built_index.total_chunks == 5 + assert "grade" in built_index.fields + assert "type" in built_index.fields + assert "tags" in built_index.fields + + def test_lookup_exact_str(self, built_index): + result = built_index.lookup({"grade": "六年级上"}) + assert result == {0, 2} + + def test_lookup_exact_int(self, built_index): + result = built_index.lookup({"difficulty": 1}) + assert result == {0, 3} + + def test_lookup_list_filter(self, built_index): + """List filter = OR over values.""" + result = built_index.lookup({"grade": ["六年级上", "六年级下"]}) + 
assert result == {0, 1, 2} + + def test_lookup_multiple_filters_intersect(self, built_index): + """Multiple filters = AND (set intersection).""" + result = built_index.lookup({"grade": "六年级上", "type": "填空题"}) + assert result == {2} + + def test_lookup_callable_returns_none(self, built_index): + result = built_index.lookup({"grade": lambda v: True}) + assert result is None + + def test_lookup_nonexistent_field(self, built_index): + result = built_index.lookup({"nonexistent": "value"}) + assert result == set() + + def test_lookup_nonexistent_value(self, built_index): + result = built_index.lookup({"grade": "不存在的年级"}) + assert result == set() + + def test_list_metadata_indexed_per_element(self, built_index): + """List[str] metadata: each element indexed separately.""" + assert built_index.lookup({"tags": "动词"}) == {3} + assert built_index.lookup({"tags": "名词"}) == {4} + assert built_index.lookup({"tags": ["动词", "名词"]}) == {3, 4} + + def test_empty_filters_returns_none(self, built_index): + assert built_index.lookup({}) is None + + def test_empty_intersection_short_circuits(self, built_index): + """When first filter yields empty, result is empty.""" + result = built_index.lookup({"grade": "不存在", "type": "选择题"}) + assert result == set() + + +# ── Classify Filters ───────────────────────────── + +class TestClassifyFilters: + def test_all_indexable(self, built_index): + indexable, callables = built_index.classify_filters({ + "grade": "六年级上", + "type": ["选择题", "填空题"], + }) + assert len(indexable) == 2 + assert len(callables) == 0 + + def test_all_callable(self, built_index): + indexable, callables = built_index.classify_filters({ + "grade": lambda v: "六" in v, + }) + assert len(indexable) == 0 + assert len(callables) == 1 + + def test_mixed(self, built_index): + indexable, callables = built_index.classify_filters({ + "grade": "六年级上", + "difficulty": lambda v: v > 2, + }) + assert "grade" in indexable + assert "difficulty" in callables + + +# ── Save / Load 
──────────────────────────────── + +class TestSaveLoad: + def test_roundtrip(self, built_index, tmp_path): + save_path = tmp_path / "metadata_index.json" + built_index.save(save_path) + + loaded = MetadataIndex() + loaded.load(save_path) + + assert loaded.total_chunks == built_index.total_chunks + assert loaded.lookup({"grade": "六年级上"}) == {0, 2} + assert loaded.lookup({"difficulty": 1}) == {0, 3} + assert loaded.lookup({"tags": "动词"}) == {3} + + def test_save_creates_parent_dirs(self, built_index, tmp_path): + save_path = tmp_path / "deep" / "nested" / "idx.json" + built_index.save(save_path) + assert save_path.exists() + + +# ── Edge Cases ───────────────────────────────── + +class TestEdgeCases: + def test_empty_chunks(self): + idx = MetadataIndex() + idx.build([]) + assert idx.total_chunks == 0 + assert idx.lookup({"key": "val"}) == set() + + def test_chunks_with_no_metadata(self): + idx = MetadataIndex() + idx.build([TextChunk("text", None), TextChunk("text2", {})]) + assert idx.total_chunks == 2 + assert idx.lookup({"any": "thing"}) == set() + + def test_bool_metadata(self): + idx = MetadataIndex() + idx.build([ + TextChunk("a", {"active": True}), + TextChunk("b", {"active": False}), + ]) + assert idx.lookup({"active": True}) == {0} + assert idx.lookup({"active": False}) == {1} diff --git a/tinysearch/adapters/csv.py b/tinysearch/adapters/csv.py index 815bf1c..21757ce 100644 --- a/tinysearch/adapters/csv.py +++ b/tinysearch/adapters/csv.py @@ -55,20 +55,13 @@ def extract(self, filepath: Union[str, Path]) -> List[str]: raise FileNotFoundError(f"File not found: {filepath}") if filepath.is_dir(): - # If a directory is provided, process all CSV files in it - csv_files = list(filepath.glob("**/*.csv")) - - result = [] - for file in csv_files: - try: - result.extend(self._extract_from_csv(file)) - except Exception as e: - print(f"Error reading {file}: {e}") - - return result - else: - # Process a single file - return self._extract_from_csv(filepath) + raise 
ValueError( + f"CSVAdapter.extract() does not accept directories. " + "Use iter_input_files() to iterate files, then call extract() on each." + ) + + # Process a single file + return self._extract_from_csv(filepath) def _extract_from_csv(self, filepath: Path) -> List[str]: """ diff --git a/tinysearch/adapters/json_adapter.py b/tinysearch/adapters/json_adapter.py index 1f06293..12e504c 100644 --- a/tinysearch/adapters/json_adapter.py +++ b/tinysearch/adapters/json_adapter.py @@ -54,20 +54,13 @@ def extract(self, filepath: Union[str, Path]) -> List[str]: raise FileNotFoundError(f"File not found: {filepath}") if filepath.is_dir(): - # If a directory is provided, process all JSON files in it - json_files = list(filepath.glob("**/*.json")) - - result = [] - for file in json_files: - try: - result.extend(self._extract_from_json(file)) - except Exception as e: - print(f"Error reading {file}: {e}") - - return result - else: - # Process a single file - return self._extract_from_json(filepath) + raise ValueError( + f"JSONAdapter.extract() does not accept directories. " + "Use iter_input_files() to iterate files, then call extract() on each." 
+ ) + + # Process a single file + return self._extract_from_json(filepath) def _extract_from_json(self, filepath: Path) -> List[str]: """ diff --git a/tinysearch/adapters/markdown.py b/tinysearch/adapters/markdown.py index fd42856..1e97a2b 100644 --- a/tinysearch/adapters/markdown.py +++ b/tinysearch/adapters/markdown.py @@ -55,22 +55,13 @@ def extract(self, filepath: Union[str, Path]) -> List[str]: raise FileNotFoundError(f"File not found: {filepath}") if filepath.is_dir(): - # If a directory is provided, process all Markdown files in it - md_files = [] - for ext in [".md", ".markdown", ".mdown", ".mkdn"]: - md_files.extend(filepath.glob(f"**/*{ext}")) - - result = [] - for file in md_files: - try: - result.extend(self._extract_from_markdown(file)) - except Exception as e: - print(f"Error reading {file}: {e}") - - return result - else: - # Process a single file - return self._extract_from_markdown(filepath) + raise ValueError( + f"MarkdownAdapter.extract() does not accept directories. " + "Use iter_input_files() to iterate files, then call extract() on each." + ) + + # Process a single file + return self._extract_from_markdown(filepath) def _extract_from_markdown(self, filepath: Path) -> List[str]: """ diff --git a/tinysearch/adapters/pdf.py b/tinysearch/adapters/pdf.py index be90bc4..1d35403 100644 --- a/tinysearch/adapters/pdf.py +++ b/tinysearch/adapters/pdf.py @@ -56,20 +56,13 @@ def extract(self, filepath: Union[str, Path]) -> List[str]: raise FileNotFoundError(f"File not found: {filepath}") if filepath.is_dir(): - # If a directory is provided, process all PDF files in it - pdf_files = list(filepath.glob("**/*.pdf")) - - result = [] - for file in pdf_files: - try: - result.extend(self._extract_from_pdf(file)) - except Exception as e: - print(f"Error reading {file}: {e}") - - return result - else: - # Process a single file - return self._extract_from_pdf(filepath) + raise ValueError( + f"PDFAdapter.extract() does not accept directories. 
" + "Use iter_input_files() to iterate files, then call extract() on each." + ) + + # Process a single file + return self._extract_from_pdf(filepath) def _extract_from_pdf(self, filepath: Path) -> List[str]: """ diff --git a/tinysearch/adapters/text.py b/tinysearch/adapters/text.py index bb085c6..d484a39 100644 --- a/tinysearch/adapters/text.py +++ b/tinysearch/adapters/text.py @@ -39,23 +39,13 @@ def extract(self, filepath: Union[str, Path]) -> List[str]: raise FileNotFoundError(f"File not found: {filepath}") if filepath.is_dir(): - # If a directory is provided, process all text files in it - text_files = [] - for ext in [".txt", ".text", ".md", ".py", ".js", ".html", ".css", ".json"]: - text_files.extend(filepath.glob(f"**/*{ext}")) - - result = [] - for file in text_files: - try: - with open(file, "r", encoding=self.encoding, errors=self.errors) as f: - result.append(f.read()) - except Exception as e: - print(f"Error reading {file}: {e}") - - return result - else: - # Process a single file - with open(filepath, "r", encoding=self.encoding, errors=self.errors) as f: - content = f.read() - - return [content] \ No newline at end of file + raise ValueError( + f"TextAdapter.extract() does not accept directories. " + "Use iter_input_files() to iterate files, then call extract() on each." 
+ ) + + # Process a single file + with open(filepath, "r", encoding=self.encoding, errors=self.errors) as f: + content = f.read() + + return [content] \ No newline at end of file diff --git a/tinysearch/cli.py b/tinysearch/cli.py index d2c607e..d0fc959 100644 --- a/tinysearch/cli.py +++ b/tinysearch/cli.py @@ -15,6 +15,7 @@ Retriever, FusionStrategy, Reranker, ) from .adapters import TextAdapter, PDFAdapter, CSVAdapter, MarkdownAdapter, JSONAdapter +from tinysearch.utils.file_discovery import iter_input_files from .splitters import CharacterTextSplitter from .embedders import HuggingFaceEmbedder # 直接从模块导入 @@ -265,9 +266,12 @@ def load_query_engine(config: Config, embedder: Embedder, indexer: FAISSIndexer) template=config.get("query_engine.template", "请帮我查找:{query}") ) elif query_engine_type == "hybrid": + from tinysearch.indexers.metadata_index import MetadataIndex retrievers = load_retrievers(config, embedder, indexer) fusion = load_fusion(config) reranker = load_reranker(config) + filter_mode = config.get("query_engine.filter_mode", "auto") + metadata_index = MetadataIndex() if filter_mode != "post" else None return HybridQueryEngine( retrievers=retrievers, fusion_strategy=fusion, @@ -275,6 +279,8 @@ def load_query_engine(config: Config, embedder: Embedder, indexer: FAISSIndexer) recall_multiplier=config.get("query_engine.recall_multiplier", 2), min_scores=config.get("query_engine.min_scores", None), filter_multiplier=config.get("query_engine.filter_multiplier", 3), + metadata_index=metadata_index, + filter_mode=filter_mode, ) else: raise ValueError(f"Unsupported query engine type: {query_engine_type}") @@ -293,7 +299,7 @@ def _get_retriever_index_dir(index_path: Path) -> Path: def _build_hybrid_retriever_indexes( query_engine: QueryEngine, chunks, logger=None, ) -> None: - """Build BM25 / Substring indexes when the engine is a HybridQueryEngine.""" + """Build BM25 / Substring / MetadataIndex when the engine is a HybridQueryEngine.""" if not 
isinstance(query_engine, HybridQueryEngine): return for retriever in query_engine.retrievers: @@ -304,11 +310,17 @@ def _build_hybrid_retriever_indexes( log_step(f"Building {rname} index") retriever.build(chunks) + # Build metadata index for pre-filtering + if query_engine.metadata_index is not None: + if logger: + log_step("Building MetadataIndex") + query_engine.metadata_index.build(chunks) + def _save_hybrid_retriever_indexes( query_engine: QueryEngine, index_path: Path, logger=None, ) -> None: - """Save non-vector retriever indexes alongside the FAISS index.""" + """Save non-vector retriever indexes and metadata index alongside the FAISS index.""" if not isinstance(query_engine, HybridQueryEngine): return index_dir = _get_retriever_index_dir(index_path) @@ -321,11 +333,18 @@ def _save_hybrid_retriever_indexes( log_step(f"Saving {type(retriever).__name__} index to {rpath}") retriever.save(rpath) + # Save metadata index + if query_engine.metadata_index is not None: + mpath = index_dir / "metadata_index.json" + if logger: + log_step(f"Saving MetadataIndex to {mpath}") + query_engine.metadata_index.save(mpath) + def _load_hybrid_retriever_indexes( query_engine: QueryEngine, index_path: Path, logger=None, ) -> None: - """Load non-vector retriever indexes from alongside the FAISS index.""" + """Load non-vector retriever indexes and metadata index from alongside the FAISS index.""" if not isinstance(query_engine, HybridQueryEngine): return index_dir = _get_retriever_index_dir(index_path) @@ -339,6 +358,14 @@ def _load_hybrid_retriever_indexes( log_step(f"Loading {type(retriever).__name__} index from {rpath}") retriever.load(rpath) + # Load metadata index + if query_engine.metadata_index is not None: + mpath = index_dir / "metadata_index.json" + if mpath.exists(): + if logger: + log_step(f"Loading MetadataIndex from {mpath}") + query_engine.metadata_index.load(mpath) + def build_index(args: argparse.Namespace, config: Config) -> None: """ @@ -367,20 +394,15 @@ def 
build_index(args: argparse.Namespace, config: Config) -> None: data_path = Path(args.data) texts: list = [] metadata: list = [] - if data_path.is_dir(): - for child in sorted(data_path.rglob("*")): - if not child.is_file(): - continue - try: - file_texts = adapter.extract(child) - except Exception: - continue # adapter will skip unsupported extensions - for t in file_texts: - texts.append(t) - metadata.append({"source": str(child)}) - else: - texts = adapter.extract(args.data) - metadata = [{"source": str(data_path)} for _ in range(len(texts))] + adapter_type = config.get("adapter.type", "text") + for file_path in iter_input_files(data_path, adapter_type=adapter_type): + try: + file_texts = adapter.extract(file_path) + except Exception: + continue + for t in file_texts: + texts.append(t) + metadata.append({"source": str(file_path)}) logger.info(f"📄 Extracted {len(texts)} documents") # Split text into chunks diff --git a/tinysearch/flow/controller.py b/tinysearch/flow/controller.py index 998f940..fbace3b 100644 --- a/tinysearch/flow/controller.py +++ b/tinysearch/flow/controller.py @@ -12,6 +12,7 @@ from tinysearch.flow.hot_update import HotUpdateManager from tinysearch.query.hybrid import HybridQueryEngine from tinysearch.retrievers.vector_retriever import VectorRetriever +from tinysearch.utils.file_discovery import iter_input_files class FlowController(FlowControllerBase): @@ -159,24 +160,15 @@ def process_directory(self, dir_path: Union[str, Path], extensions: Optional[Lis force_reprocess: If True, reprocess even if file is in cache """ dir_path = Path(dir_path) - + if not dir_path.is_dir(): raise ValueError(f"{dir_path} is not a directory") - - # Function to check if a file should be processed - def should_process(file_path: Path) -> bool: - if extensions and file_path.suffix.lower() not in extensions: - return False + + adapter_type = self.config.get("adapter", {}).get("type", "text") + for file_path in iter_input_files(dir_path, adapter_type=adapter_type, 
extensions=extensions, recursive=recursive): if not force_reprocess and str(file_path) in self.processed_files: - return False - return True - - # Process all matching files in directory - for item in dir_path.iterdir(): - if item.is_file() and should_process(item): - self.process_file(item, force_reprocess) - elif item.is_dir() and recursive: - self.process_directory(item, extensions, recursive, force_reprocess) + continue + self.process_file(file_path, force_reprocess) def build_index(self, data_path: Union[str, Path], **kwargs) -> None: """ @@ -212,13 +204,17 @@ def _get_hybrid_retrievers(self) -> List[Retriever]: return [] def _build_retriever_indexes(self, chunks: List[TextChunk]) -> None: - """Build indexes for non-vector retrievers in HybridQueryEngine""" + """Build indexes for non-vector retrievers and metadata index in HybridQueryEngine""" for retriever in self._get_hybrid_retrievers(): # Skip VectorRetriever - it's already handled by self.indexer.build() if isinstance(retriever, VectorRetriever): continue retriever.build(chunks) + # Build metadata index for pre-filtering + if isinstance(self.query_engine, HybridQueryEngine) and self.query_engine.metadata_index is not None: + self.query_engine.metadata_index.build(chunks) + def save_index(self, path: Optional[Union[str, Path]] = None) -> None: """ Save the built index to disk. @@ -237,7 +233,7 @@ def save_index(self, path: Optional[Union[str, Path]] = None) -> None: self._save_retriever_indexes(Path(str(path))) def _save_retriever_indexes(self, base_path: Path) -> None: - """Save indexes for non-vector retrievers inside the FAISS index directory""" + """Save indexes for non-vector retrievers and metadata index inside the FAISS index directory""" # FAISS saves into base_path.with_suffix('') (e.g. 
"index.faiss" → "index/") index_dir = base_path.with_suffix('') if base_path.suffix else base_path for retriever in self._get_hybrid_retrievers(): @@ -248,6 +244,10 @@ def _save_retriever_indexes(self, base_path: Path) -> None: retriever_path = index_dir / f"{retriever_name}_index" retriever.save(retriever_path) + # Save metadata index + if isinstance(self.query_engine, HybridQueryEngine) and self.query_engine.metadata_index is not None: + self.query_engine.metadata_index.save(index_dir / "metadata_index.json") + def load_index(self, path: Optional[Union[str, Path]] = None) -> None: """ Load an index from disk. @@ -266,7 +266,7 @@ def load_index(self, path: Optional[Union[str, Path]] = None) -> None: self._load_retriever_indexes(Path(str(path))) def _load_retriever_indexes(self, base_path: Path) -> None: - """Load indexes for non-vector retrievers from the FAISS index directory""" + """Load indexes for non-vector retrievers and metadata index from the FAISS index directory""" index_dir = base_path.with_suffix('') if base_path.suffix else base_path for retriever in self._get_hybrid_retrievers(): if isinstance(retriever, VectorRetriever): @@ -275,6 +275,12 @@ def _load_retriever_indexes(self, base_path: Path) -> None: retriever_path = index_dir / f"{retriever_name}_index" if retriever_path.exists(): retriever.load(retriever_path) + + # Load metadata index + if isinstance(self.query_engine, HybridQueryEngine) and self.query_engine.metadata_index is not None: + metadata_path = index_dir / "metadata_index.json" + if metadata_path.exists(): + self.query_engine.metadata_index.load(metadata_path) def query(self, query_text: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]: """ diff --git a/tinysearch/indexers/__init__.py b/tinysearch/indexers/__init__.py index ca8b3e8..86c9aa0 100644 --- a/tinysearch/indexers/__init__.py +++ b/tinysearch/indexers/__init__.py @@ -6,9 +6,11 @@ # Make the FAISS indexer available from the root module from .faiss_indexer import 
FAISSIndexer +from .metadata_index import MetadataIndex __all__ = [ - "FAISSIndexer" + "FAISSIndexer", + "MetadataIndex", ] def index_exists(path: Union[str, Path]) -> bool: diff --git a/tinysearch/indexers/faiss_indexer.py b/tinysearch/indexers/faiss_indexer.py index 4d6e87e..b55e913 100644 --- a/tinysearch/indexers/faiss_indexer.py +++ b/tinysearch/indexers/faiss_indexer.py @@ -92,12 +92,17 @@ def build(self, vectors: List[List[float]], texts: List[TextChunk]) -> None: # Create ID mapping self.ids_map = {i: i for i in range(len(texts))} - def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, Any]]: + def search(self, query_vector: List[float], top_k: int = 5, + candidate_ids: Optional[set] = None) -> List[Dict[str, Any]]: """ - Search the index for vectors similar to the query vector + Search the index for vectors similar to the query vector. + Args: query_vector: Query embedding vector top_k: Number of results to return + candidate_ids: Optional set of chunk indices to restrict search to. + When provided, over-retrieves then filters by membership. 
+ Returns: List of dictionaries containing text chunks, similarity scores, and embedding vectors """ @@ -115,8 +120,16 @@ def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, An # Normalize query vector if using cosine similarity if self.metric == "cosine": faiss.normalize_L2(query_np) + + # When candidate_ids is set, over-recall then filter + if candidate_ids is not None: + effective_k = min(self.index.ntotal, top_k * 3) + else: + effective_k = top_k + # Search index - distances, indices = self.index.search(query_np, top_k) + distances, indices = self.index.search(query_np, effective_k) + # Convert to result format results = [] for i in range(len(indices[0])): @@ -125,6 +138,9 @@ def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, An # Skip invalid indices (can happen if there are fewer results than top_k) if idx < 0 or idx >= len(self.texts): continue + # Pre-filter: skip if not in candidate set + if candidate_ids is not None and idx not in candidate_ids: + continue # Map to original text chunk text_idx = self.ids_map[idx] text_chunk = self.texts[text_idx] @@ -146,6 +162,8 @@ def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, An "score": similarity, "embedding": embedding }) + if candidate_ids is not None and len(results) >= top_k: + break return results def save(self, path: Union[str, Path]) -> None: diff --git a/tinysearch/indexers/metadata_index.py b/tinysearch/indexers/metadata_index.py new file mode 100644 index 0000000..d8bc57a --- /dev/null +++ b/tinysearch/indexers/metadata_index.py @@ -0,0 +1,201 @@ +""" +Inverted index over TextChunk metadata for fast candidate set lookup. + +Enables O(1) pre-filtering by metadata fields (e.g., grade, type, tags) +instead of linear post-filtering over all retrieval results. 
+""" +import json +import logging +from collections import defaultdict +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +from tinysearch.base import TextChunk + +logger = logging.getLogger(__name__) + +# Filter value types matching HybridQueryEngine._match_filters +FilterValue = Union[str, int, float, bool, List, Callable] + +# Scalar types that can be directly indexed +_INDEXABLE_TYPES = (str, int, float, bool) + + +class MetadataIndex: + """ + Inverted index: metadata field -> value -> set of chunk IDs. + + Supports: + - Scalar values (str, int, float, bool): indexed directly + - List[str/int/...] values (e.g., tags): each element indexed separately + """ + + def __init__(self) -> None: + self._index: Dict[str, Dict[Any, Set[int]]] = defaultdict( + lambda: defaultdict(set) + ) + self._total_chunks: int = 0 + + def build(self, chunks: List[TextChunk]) -> None: + """ + Build inverted indices from chunk metadata. + + Args: + chunks: Ordered list of TextChunks. The positional index (0-based) + becomes the chunk ID used in candidate sets. + """ + self._index = defaultdict(lambda: defaultdict(set)) + self._total_chunks = len(chunks) + + for i, chunk in enumerate(chunks): + if not chunk.metadata: + continue + for key, value in chunk.metadata.items(): + if isinstance(value, _INDEXABLE_TYPES): + self._index[key][value].add(i) + elif isinstance(value, list): + for element in value: + if isinstance(element, _INDEXABLE_TYPES): + self._index[key][element].add(i) + + field_stats = {k: len(v) for k, v in self._index.items()} + logger.info( + "MetadataIndex built: %d chunks, fields=%s", + self._total_chunks, + field_stats, + ) + + def lookup(self, filters: Dict[str, FilterValue]) -> Optional[Set[int]]: + """ + Resolve filters to a candidate ID set via inverted index. 
+ + Filter semantics (matches HybridQueryEngine._match_filters): + - scalar (str/int/float/bool): exact match -> direct set lookup + - list: OR over values -> union of sets + - multiple filter keys: AND -> set intersection + - callable: cannot be resolved -> returns None + + Returns: + Set[int] of matching chunk IDs, or None if any filter is callable. + """ + if not filters: + return None + + result: Optional[Set[int]] = None + + for key, condition in filters.items(): + if callable(condition): + return None + + matched = self._lookup_single(key, condition) + + if result is None: + result = matched + else: + result = result & matched + + # Short-circuit on empty intersection + if result is not None and len(result) == 0: + return set() + + return result if result is not None else set() + + def _lookup_single(self, key: str, condition: FilterValue) -> Set[int]: + """Lookup a single filter key-value pair.""" + field_index = self._index.get(key) + if field_index is None: + return set() + + if isinstance(condition, list): + # OR: union of all matching values + matched = set() + for val in condition: + matched |= field_index.get(val, set()) + return matched + else: + # Exact match + return set(field_index.get(condition, set())) + + def classify_filters( + self, filters: Dict[str, FilterValue] + ) -> Tuple[Dict[str, FilterValue], Dict[str, FilterValue]]: + """ + Split filters into indexable and non-indexable (callable) parts. + + Returns: + (indexable_filters, callable_filters) + """ + indexable = {} + callables = {} + for key, condition in filters.items(): + if callable(condition): + callables[key] = condition + else: + indexable[key] = condition + return indexable, callables + + @property + def total_chunks(self) -> int: + return self._total_chunks + + @property + def fields(self) -> List[str]: + return list(self._index.keys()) + + def save(self, path: Union[str, Path]) -> None: + """ + Save to JSON. 
Sets become sorted lists; value types are preserved + via a type tag for proper reconstruction on load. + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + serializable = { + "version": 1, + "total_chunks": self._total_chunks, + "fields": {}, + } + + for field_name, value_map in self._index.items(): + entries = {} + for value, ids in value_map.items(): + # Use repr-style key that preserves type info + type_tag = type(value).__name__ + str_key = json.dumps({"v": value, "t": type_tag}, ensure_ascii=False) + entries[str_key] = sorted(ids) + serializable["fields"][field_name] = entries + + with open(path, "w", encoding="utf-8") as f: + json.dump(serializable, f, ensure_ascii=False) + + logger.info("MetadataIndex saved to %s", path) + + def load(self, path: Union[str, Path]) -> None: + """Load from JSON, reconstructing sets and value types.""" + path = Path(path) + + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + + self._total_chunks = data["total_chunks"] + self._index = defaultdict(lambda: defaultdict(set)) + + type_constructors = {"str": str, "int": int, "float": float, "bool": bool} + + for field_name, entries in data["fields"].items(): + for str_key, id_list in entries.items(): + key_data = json.loads(str_key) + value = key_data["v"] + type_tag = key_data["t"] + # Reconstruct typed value + constructor = type_constructors.get(type_tag) + if constructor: + value = constructor(value) + self._index[field_name][value] = set(id_list) + + logger.info( + "MetadataIndex loaded from %s: %d chunks, %d fields", + path, + self._total_chunks, + len(self._index), + ) diff --git a/tinysearch/query/hybrid.py b/tinysearch/query/hybrid.py index dfdea11..82e5066 100644 --- a/tinysearch/query/hybrid.py +++ b/tinysearch/query/hybrid.py @@ -2,7 +2,7 @@ Hybrid query engine - multi-retriever fusion with optional reranking """ import logging -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, 
Dict, List, Optional, Tuple, Union from tinysearch.base import QueryEngine, Retriever, FusionStrategy, Reranker @@ -18,13 +18,15 @@ class HybridQueryEngine(QueryEngine): with optional reranking. Pipeline: - 1. Each retriever recalls top_k * recall_multiplier candidates - (× filter_multiplier when metadata filters are active) - 2. Per-retriever min_score filtering - 3. Metadata post-filtering (if filters provided) - 4. FusionStrategy merges and deduplicates results - 5. Optional Reranker re-scores the fused candidates - 6. Return top_k final results + 1. (If filter_mode is "pre" or "auto") Use MetadataIndex to resolve + indexable filters into a candidate_ids set; pass to retrievers + 2. Each retriever recalls top_k * recall_multiplier candidates + (× filter_multiplier only when post-filters are active) + 3. Per-retriever min_score filtering + 4. Metadata post-filtering for callable/non-indexable filters + 5. FusionStrategy merges and deduplicates results + 6. Optional Reranker re-scores the fused candidates + 7. 
Return top_k final results """ def __init__( @@ -35,6 +37,8 @@ def __init__( recall_multiplier: int = 2, min_scores: Optional[List[float]] = None, filter_multiplier: int = 3, + metadata_index=None, + filter_mode: str = "auto", ): """ Args: @@ -43,7 +47,10 @@ def __init__( reranker: Optional reranker for final re-scoring recall_multiplier: Multiply top_k by this for each retriever's recall min_scores: Per-retriever minimum score thresholds (length must match retrievers) - filter_multiplier: Extra recall multiplier when filters are active + filter_multiplier: Extra recall multiplier when post-filters are active + metadata_index: Optional MetadataIndex for inverted-index pre-filtering + filter_mode: "pre" (always pre-filter), "post" (always post-filter), + or "auto" (pre-filter indexable parts, post-filter callables) """ if not retrievers: raise ValueError("At least one retriever is required") @@ -52,12 +59,16 @@ def __init__( f"min_scores length ({len(min_scores)}) must match " f"retrievers length ({len(retrievers)})" ) + if filter_mode not in ("pre", "post", "auto"): + raise ValueError(f"filter_mode must be 'pre', 'post', or 'auto', got '{filter_mode}'") self.retrievers = retrievers self.fusion_strategy = fusion_strategy self.reranker = reranker self.recall_multiplier = recall_multiplier self.min_scores = min_scores self.filter_multiplier = filter_multiplier + self.metadata_index = metadata_index + self.filter_mode = filter_mode def format_query(self, query: str) -> str: """Pass-through: hybrid engine doesn't transform queries""" @@ -100,16 +111,53 @@ def _retrieve_pipeline( filters = kwargs.pop("filters", None) weights = kwargs.pop("weights", None) - # Compute recall amount — over-recall when filters are active + # Resolve pre-filter vs post-filter strategy + candidate_ids = None + post_filters = None + + if filters and self.metadata_index is not None: + if self.filter_mode == "pre": + candidate_ids = self.metadata_index.lookup(filters) + if candidate_ids is None: 
+ # Has callable filters, fall back to post-filter + post_filters = filters + elif len(candidate_ids) == 0: + return { + "results": [], + "per_retriever": [[] for _ in self.retrievers], + "fused_before_rerank": [], + } + elif self.filter_mode == "post": + post_filters = filters + else: # auto + indexable, callables = self.metadata_index.classify_filters(filters) + if indexable: + candidate_ids = self.metadata_index.lookup(indexable) + if candidate_ids is not None and len(candidate_ids) == 0: + return { + "results": [], + "per_retriever": [[] for _ in self.retrievers], + "fused_before_rerank": [], + } + if callables: + post_filters = callables + elif filters: + # No metadata_index available, always post-filter + post_filters = filters + + # Compute recall amount — filter_multiplier only when post-filtering recall_k = top_k * self.recall_multiplier - if filters: + if post_filters: recall_k *= self.filter_multiplier # Step 1: Recall from each retriever per_retriever: List[List[Dict[str, Any]]] = [] for i, retriever in enumerate(self.retrievers): try: - results = retriever.retrieve(query, top_k=recall_k) + retriever_kwargs: Dict[str, Any] = {} + if candidate_ids is not None: + retriever_kwargs["candidate_ids"] = candidate_ids + results = retriever.retrieve(query, top_k=recall_k, **retriever_kwargs) except Exception as e: retriever_name = type(retriever).__name__ logger.warning( @@ -124,10 +172,10 @@ def _retrieve_pipeline( per_retriever.append(results) - # Step 2: Metadata post-filtering - if filters: + # Step 2: Post-filtering (only for callable/non-indexable filters) + if post_filters: per_retriever = [ - self._apply_filters(results, filters) for results in per_retriever + self._apply_filters(results, post_filters) for results in per_retriever ] # Step 3: Fuse results (pass dynamic weights if provided) diff --git a/tinysearch/retrievers/bm25_retriever.py b/tinysearch/retrievers/bm25_retriever.py index 92b65e9..0c47e3b 100644 --- 
a/tinysearch/retrievers/bm25_retriever.py +++ b/tinysearch/retrievers/bm25_retriever.py @@ -84,18 +84,31 @@ def build(self, chunks: List[TextChunk]) -> None: self._index = bm25s.BM25() self._index.index(self._corpus_tokens) - def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: - """Retrieve documents using BM25 keyword matching""" + def retrieve(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]: + """ + Retrieve documents using BM25 keyword matching. + + Args: + query: Query string + top_k: Number of results to return + **kwargs: + candidate_ids: Optional Set[int] of chunk indices to restrict search to + """ if self._index is None: return [] + candidate_ids = kwargs.get("candidate_ids") + # Tokenize query query_tokens = self.tokenizer(query) if not query_tokens: return [] - # Clamp top_k to number of indexed documents - effective_k = min(top_k, len(self._chunks)) + # When pre-filtering, over-recall then filter by candidate_ids + if candidate_ids is not None: + effective_k = min(len(self._chunks), top_k * 3) + else: + effective_k = min(top_k, len(self._chunks)) if effective_k == 0: return [] @@ -112,6 +125,8 @@ def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: idx = int(idx) if idx < 0 or idx >= len(self._chunks): continue + if candidate_ids is not None and idx not in candidate_ids: + continue chunk = self._chunks[idx] results.append({ "text": chunk.text, @@ -119,6 +134,8 @@ def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: "score": float(score), "retrieval_method": "bm25", }) + if len(results) >= top_k: + break return results diff --git a/tinysearch/retrievers/substring_retriever.py b/tinysearch/retrievers/substring_retriever.py index 0038d4d..97dc730 100644 --- a/tinysearch/retrievers/substring_retriever.py +++ b/tinysearch/retrievers/substring_retriever.py @@ -33,11 +33,20 @@ def build(self, chunks: List[TextChunk]) -> None: """Store chunks in memory for substring search""" 
self._chunks = list(chunks) - def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: - """Search chunks using regex/substring matching""" + def retrieve(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]: + """ + Search chunks using regex/substring matching. + + Args: + query: Query string or regex pattern + top_k: Number of results to return + **kwargs: + candidate_ids: Optional Set[int] of chunk indices to restrict search to + """ if not self._chunks or not query: return [] + candidate_ids = kwargs.get("candidate_ids") results = [] try: @@ -46,7 +55,9 @@ def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: else: pattern = re.compile(re.escape(query), re.IGNORECASE) - for chunk in self._chunks: + for i, chunk in enumerate(self._chunks): + if candidate_ids is not None and i not in candidate_ids: + continue match = pattern.search(chunk.text) if match: score = self._calculate_match_score(match, chunk.text) diff --git a/tinysearch/retrievers/vector_retriever.py b/tinysearch/retrievers/vector_retriever.py index 1583083..2e77c5b 100644 --- a/tinysearch/retrievers/vector_retriever.py +++ b/tinysearch/retrievers/vector_retriever.py @@ -40,8 +40,18 @@ def build(self, chunks: List[TextChunk]) -> None: vectors = self.embedder.embed(texts) self.indexer.build(vectors, chunks) - def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: - """Embed query and search the vector index""" + def retrieve(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]: + """ + Embed query and search the vector index. 
+ + Args: + query: Query string + top_k: Number of results to return + **kwargs: + candidate_ids: Optional Set[int] of chunk indices to restrict search to + """ + candidate_ids = kwargs.get("candidate_ids") + # Apply query template if configured formatted_query = query if self.query_template: @@ -56,8 +66,8 @@ def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: return [] query_vector = query_vectors[0] - # Search - raw_results = self.indexer.search(query_vector, top_k) + # Search (forward candidate_ids to indexer) + raw_results = self.indexer.search(query_vector, top_k, candidate_ids=candidate_ids) # Normalize scores to [0, 1] and add retrieval_method results = [] diff --git a/tinysearch/utils/__init__.py b/tinysearch/utils/__init__.py new file mode 100644 index 0000000..e8307ab --- /dev/null +++ b/tinysearch/utils/__init__.py @@ -0,0 +1,6 @@ +""" +Utility modules for TinySearch. +""" +from tinysearch.utils.file_discovery import iter_input_files, ADAPTER_EXTENSIONS + +__all__ = ["iter_input_files", "ADAPTER_EXTENSIONS"] diff --git a/tinysearch/utils/file_discovery.py b/tinysearch/utils/file_discovery.py new file mode 100644 index 0000000..b5c0d77 --- /dev/null +++ b/tinysearch/utils/file_discovery.py @@ -0,0 +1,55 @@ +""" +Centralized file discovery for directory-based indexing. +""" +import logging +from pathlib import Path +from typing import Iterator, List, Optional + +logger = logging.getLogger(__name__) + +# Default file extensions per adapter type +ADAPTER_EXTENSIONS = { + "text": [".txt", ".text", ".md", ".py", ".js", ".html", ".css", ".json"], + "pdf": [".pdf"], + "csv": [".csv"], + "markdown": [".md", ".markdown", ".mdown", ".mkdn"], + "json": [".json"], +} + + +def iter_input_files( + data_path: Path, + adapter_type: str = "text", + extensions: Optional[List[str]] = None, + recursive: bool = True, +) -> Iterator[Path]: + """ + Discover files under a path that a given adapter can process. 
+ + Args: + data_path: File or directory path + adapter_type: Adapter type name, used to look up default extensions + extensions: Custom extension list (overrides defaults) + recursive: Whether to recurse into subdirectories + + Yields: + Path objects in sorted order for deterministic results + """ + data_path = Path(data_path) + + if not data_path.exists(): + raise FileNotFoundError(f"Path not found: {data_path}") + + if data_path.is_file(): + yield data_path + return + + # Directory: filter by extensions + allowed = set( + ext.lower() for ext in (extensions or ADAPTER_EXTENSIONS.get(adapter_type, [])) + ) + + pattern = data_path.rglob("*") if recursive else data_path.iterdir() + for child in sorted(pattern): + if child.is_file() and (not allowed or child.suffix.lower() in allowed): + yield child From 376e5942b6ca859e87cad804438b65fc8a08f2e2 Mon Sep 17 00:00:00 2001 From: CodePothunter Date: Sun, 22 Mar 2026 21:28:11 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E2=9C=A8=E3=80=90=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E3=80=91=EF=BC=9A=E6=96=B0=E5=A2=9E=20RecordAdapter=20?= =?UTF-8?q?=E5=92=8C=E5=A2=9E=E9=87=8F=E7=B4=A2=E5=BC=95=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=20API=20=E6=95=B0=E6=8D=AE=E6=BA=90=E4=B8=8E=E5=8F=98?= =?UTF-8?q?=E6=9B=B4=E6=A3=80=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **tinysearch/base.py** - 新增 RecordAdapter ABC,支持 Dict 数据源转 TextChunk **tinysearch/records.py** - 新增 build_chunks_from_records() 批量转换工具,保证 record_id 始终在 metadata 中 **tinysearch/indexers/hash_tracker.py** - 新增 ContentHashTracker 基于 MD5 的内容变更检测(new/modified/deleted/unchanged) - 新增 ChangeSet 数据结构,支持 save/load JSON 持久化 **tinysearch/indexers/metadata_index.py** - 新增 add_chunks() 增量追加倒排条目,无需全量重建 **tinysearch/query/hybrid.py** - 新增 soft_deleted_ids 参数,fusion 后自动过滤已删除记录 - 新增 add_soft_deletes()/clear_soft_deletes() 便捷方法 **tinysearch/flow/controller.py** - 新增 build_from_records() 从内存记录全量构建索引 - 新增 build_incremental() 增量更新管线(FAISS.add + 
MetadataIndex.add_chunks + BM25全量重建) - data_adapter 参数改为 Optional,支持纯 RecordAdapter 使用场景 - save/load 集成 soft_deletes.json 持久化 **tests/test_record_adapter.py** - 新增 RecordAdapter ABC + build_chunks_from_records 7 个测试 **tests/test_incremental.py** - 新增 ContentHashTracker、MetadataIndex.add_chunks、soft delete、build_incremental 21 个测试 **影响说明** - 功能:NeoSAGE 等 API 数据源项目可直接使用 TinySearch 标准管线 - 性能:增量更新仅对变更记录做 embedding,避免全量重算 - 兼容性:完全向后兼容,现有 DataAdapter 用法不受影响 Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_hybrid.py | 16 +- tests/test_incremental.py | 323 ++++++++++++++++++++++++++ tests/test_record_adapter.py | 87 +++++++ tinysearch/base.py | 26 +++ tinysearch/flow/controller.py | 162 ++++++++++++- tinysearch/indexers/__init__.py | 3 + tinysearch/indexers/hash_tracker.py | 153 ++++++++++++ tinysearch/indexers/metadata_index.py | 27 +++ tinysearch/query/hybrid.py | 22 ++ tinysearch/records.py | 51 ++++ 10 files changed, 861 insertions(+), 9 deletions(-) create mode 100644 tests/test_incremental.py create mode 100644 tests/test_record_adapter.py create mode 100644 tinysearch/indexers/hash_tracker.py create mode 100644 tinysearch/records.py diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py index 2a22dbd..6d13656 100644 --- a/tests/test_hybrid.py +++ b/tests/test_hybrid.py @@ -363,6 +363,7 @@ def test_flow_controller_uses_faiss_dir(self): fc = FlowController.__new__(FlowController) fc.query_engine = MagicMock(spec=HybridQueryEngine) fc.query_engine.metadata_index = None + fc.query_engine.soft_deleted_ids = set() mock_retriever = MagicMock() mock_retriever.__class__.__name__ = "BM25Retriever" @@ -371,20 +372,25 @@ def test_flow_controller_uses_faiss_dir(self): fc._save_retriever_indexes(Path("data/index.faiss")) mock_retriever.save.assert_called_once_with(Path("data/index") / "bm25_index") - def test_flow_controller_load_uses_faiss_dir(self): + def test_flow_controller_load_uses_faiss_dir(self, tmp_path): from tinysearch.flow.controller import FlowController 
fc = FlowController.__new__(FlowController) fc.query_engine = MagicMock(spec=HybridQueryEngine) fc.query_engine.metadata_index = None + fc.query_engine.soft_deleted_ids = set() mock_retriever = MagicMock() mock_retriever.__class__.__name__ = "BM25Retriever" fc.query_engine.retrievers = [mock_retriever] - # Simulate that the path exists - with patch.object(Path, "exists", return_value=True): - fc._load_retriever_indexes(Path("data/index.faiss")) - mock_retriever.load.assert_called_once_with(Path("data/index") / "bm25_index") + # Create a real directory structure for the test + index_dir = tmp_path / "index" + index_dir.mkdir() + bm25_dir = index_dir / "bm25_index" + bm25_dir.mkdir() + + fc._load_retriever_indexes(tmp_path / "index.faiss") + mock_retriever.load.assert_called_once_with(index_dir / "bm25_index") # ── iter_input_files ───────────────────────────────────── diff --git a/tests/test_incremental.py b/tests/test_incremental.py new file mode 100644 index 0000000..16c6939 --- /dev/null +++ b/tests/test_incremental.py @@ -0,0 +1,323 @@ +""" +Tests for incremental indexing: ContentHashTracker, MetadataIndex.add_chunks, +soft delete, and FlowController.build_incremental. 
+""" +import pytest +from unittest.mock import MagicMock, patch +from typing import Any, Dict + +from tinysearch.base import RecordAdapter, TextChunk +from tinysearch.indexers.hash_tracker import ContentHashTracker, ChangeSet +from tinysearch.indexers.metadata_index import MetadataIndex +from tinysearch.query.hybrid import HybridQueryEngine +from tinysearch.retrievers.bm25_retriever import BM25Retriever +from tinysearch.retrievers.substring_retriever import SubstringRetriever +from tinysearch.fusion.weighted import WeightedFusion + + +class SimpleAdapter(RecordAdapter): + def to_chunk(self, record_id: str, record: Dict[str, Any]) -> TextChunk: + return TextChunk( + text=record.get("text", ""), + metadata={"record_id": record_id, "grade": record.get("grade", "")}, + ) + + +# ── ContentHashTracker ─────────────────────────── + +class TestContentHashTracker: + def test_compute_hash_deterministic(self): + tracker = ContentHashTracker() + h1 = tracker.compute_hash("hello", {"k": "v"}) + h2 = tracker.compute_hash("hello", {"k": "v"}) + assert h1 == h2 + + def test_compute_hash_changes_on_text(self): + tracker = ContentHashTracker() + h1 = tracker.compute_hash("hello", {"k": "v"}) + h2 = tracker.compute_hash("world", {"k": "v"}) + assert h1 != h2 + + def test_compute_hash_changes_on_metadata(self): + tracker = ContentHashTracker() + h1 = tracker.compute_hash("hello", {"grade": "六年级"}) + h2 = tracker.compute_hash("hello", {"grade": "七年级"}) + assert h1 != h2 + + def test_compute_hash_ignores_internal_keys(self): + tracker = ContentHashTracker() + h1 = tracker.compute_hash("hello", {"grade": "六年级", "chunk_index": 0}) + h2 = tracker.compute_hash("hello", {"grade": "六年级", "chunk_index": 5}) + assert h1 == h2 + + def test_compute_hash_custom_metadata_keys(self): + tracker = ContentHashTracker(hash_metadata_keys=["grade"]) + h1 = tracker.compute_hash("hello", {"grade": "六年级", "type": "A"}) + h2 = tracker.compute_hash("hello", {"grade": "六年级", "type": "B"}) + assert h1 == h2 # 
"type" is not in hash_metadata_keys + + def test_detect_all_new(self): + tracker = ContentHashTracker() + current = { + "q1": TextChunk("hello", {"record_id": "q1"}), + "q2": TextChunk("world", {"record_id": "q2"}), + } + changes = tracker.detect_changes(current) + assert len(changes.new) == 2 + assert len(changes.modified) == 0 + assert len(changes.deleted) == 0 + + def test_detect_no_changes(self): + tracker = ContentHashTracker() + current = {"q1": TextChunk("hello", {"record_id": "q1"})} + tracker.update(current) + changes = tracker.detect_changes(current) + assert not changes.has_changes + assert len(changes.unchanged) == 1 + + def test_detect_modified(self): + tracker = ContentHashTracker() + original = {"q1": TextChunk("hello", {"record_id": "q1"})} + tracker.update(original) + modified = {"q1": TextChunk("changed", {"record_id": "q1"})} + changes = tracker.detect_changes(modified) + assert len(changes.modified) == 1 + assert changes.modified[0] == "q1" + + def test_detect_deleted(self): + tracker = ContentHashTracker() + original = {"q1": TextChunk("a", {}), "q2": TextChunk("b", {})} + tracker.update(original) + current = {"q1": TextChunk("a", {})} + changes = tracker.detect_changes(current) + assert changes.deleted == {"q2"} + + def test_detect_mixed(self): + tracker = ContentHashTracker() + original = { + "q1": TextChunk("unchanged", {}), + "q2": TextChunk("will_modify", {}), + "q3": TextChunk("will_delete", {}), + } + tracker.update(original) + current = { + "q1": TextChunk("unchanged", {}), + "q2": TextChunk("modified_text", {}), + "q4": TextChunk("brand_new", {}), + } + changes = tracker.detect_changes(current) + assert set(changes.new) == {"q4"} + assert set(changes.modified) == {"q2"} + assert changes.deleted == {"q3"} + assert changes.unchanged == {"q1"} + + def test_update_and_remove(self): + tracker = ContentHashTracker() + records = {"q1": TextChunk("hello", {})} + tracker.update(records) + assert tracker.tracked_count == 1 + 
tracker.remove({"q1"}) + assert tracker.tracked_count == 0 + + def test_save_load_roundtrip(self, tmp_path): + tracker = ContentHashTracker(hash_metadata_keys=["grade"]) + records = { + "q1": TextChunk("hello", {"grade": "六年级"}), + "q2": TextChunk("world", {"grade": "七年级"}), + } + tracker.update(records) + + path = tmp_path / "hashes.json" + tracker.save(path) + + loaded = ContentHashTracker() + loaded.load(path) + assert loaded.tracked_count == 2 + # Should detect no changes + changes = loaded.detect_changes(records) + assert not changes.has_changes + + +class TestChangeSet: + def test_has_changes_true(self): + cs = ChangeSet(new=["a"], modified=[], deleted=set(), unchanged=set()) + assert cs.has_changes + + def test_has_changes_false(self): + cs = ChangeSet(new=[], modified=[], deleted=set(), unchanged={"a"}) + assert not cs.has_changes + + def test_repr(self): + cs = ChangeSet(new=["a"], modified=["b"], deleted={"c"}, unchanged={"d"}) + assert "new=1" in repr(cs) + assert "modified=1" in repr(cs) + + +# ── MetadataIndex.add_chunks ───────────────────── + +class TestMetadataIndexAddChunks: + def test_add_increases_total(self): + idx = MetadataIndex() + idx.build([TextChunk("a", {"grade": "六年级"})]) + assert idx.total_chunks == 1 + idx.add_chunks([TextChunk("b", {"grade": "七年级"})], start_id=1) + assert idx.total_chunks == 2 + + def test_add_chunks_findable(self): + idx = MetadataIndex() + idx.build([TextChunk("a", {"grade": "六年级"})]) + idx.add_chunks([TextChunk("b", {"grade": "七年级"})], start_id=1) + assert idx.lookup({"grade": "七年级"}) == {1} + assert idx.lookup({"grade": "六年级"}) == {0} + + def test_add_preserves_existing(self): + idx = MetadataIndex() + idx.build([TextChunk("a", {"grade": "六年级"})]) + original = idx.lookup({"grade": "六年级"}) + idx.add_chunks([TextChunk("b", {"grade": "七年级"})], start_id=1) + assert idx.lookup({"grade": "六年级"}) == original + + def test_add_chunks_list_metadata(self): + idx = MetadataIndex() + idx.build([]) + 
idx.add_chunks([TextChunk("a", {"tags": ["动词", "时态"]})], start_id=0) + assert idx.lookup({"tags": "动词"}) == {0} + assert idx.lookup({"tags": "时态"}) == {0} + + +# ── Soft Delete ────────────────────────────────── + +class TestSoftDelete: + def _make_engine(self, chunks, soft_deleted_ids=None): + bm25 = BM25Retriever() + bm25.build(chunks) + return HybridQueryEngine( + [bm25], + WeightedFusion(), + soft_deleted_ids=soft_deleted_ids, + ) + + def test_soft_delete_filters_results(self): + chunks = [ + TextChunk("Python编程", {"record_id": "q1"}), + TextChunk("Java编程", {"record_id": "q2"}), + ] + engine = self._make_engine(chunks, soft_deleted_ids={"q1"}) + results = engine.retrieve("编程", top_k=5) + record_ids = {r["metadata"]["record_id"] for r in results} + assert "q1" not in record_ids + assert "q2" in record_ids + + def test_add_and_clear_soft_deletes(self): + chunks = [TextChunk("test", {"record_id": "q1"})] + engine = self._make_engine(chunks) + assert engine.soft_delete_count == 0 + engine.add_soft_deletes({"q1"}) + assert engine.soft_delete_count == 1 + engine.clear_soft_deletes() + assert engine.soft_delete_count == 0 + + def test_backward_compat_no_soft_deletes(self): + """Default None doesn't break existing behavior.""" + chunks = [TextChunk("Python编程", {"record_id": "q1"})] + engine = self._make_engine(chunks) + results = engine.retrieve("编程", top_k=5) + assert len(results) > 0 + + +# ── build_incremental ──────────────────────────── + +class TestBuildIncremental: + def _make_fc(self): + """Create a FlowController with mocked embedder/indexer for testing.""" + from tinysearch.flow.controller import FlowController + + mock_embedder = MagicMock() + mock_embedder.embed.return_value = [[0.1] * 10] + + mock_indexer = MagicMock() + + bm25 = BM25Retriever() + substr = SubstringRetriever() + metadata_index = MetadataIndex() + + engine = HybridQueryEngine( + [bm25, substr], + WeightedFusion([0.6, 0.4]), + metadata_index=metadata_index, + ) + + fc = FlowController( + 
data_adapter=None, + text_splitter=MagicMock(), + embedder=mock_embedder, + indexer=mock_indexer, + query_engine=engine, + config={}, + ) + return fc + + def test_no_changes_returns_early(self): + fc = self._make_fc() + adapter = SimpleAdapter() + tracker = ContentHashTracker() + + records = {"q1": {"text": "hello", "grade": "六年级"}} + # First build + tracker.update({rid: adapter.to_chunk(rid, r) for rid, r in records.items()}) + + # Second call — no changes + stats = fc.build_incremental(records, adapter, tracker) + assert not stats["full_rebuild"] + assert stats["new"] == 0 + assert stats["modified"] == 0 + + def test_new_records_detected(self): + fc = self._make_fc() + adapter = SimpleAdapter() + tracker = ContentHashTracker() + + records = {"q1": {"text": "hello", "grade": "六年级"}} + stats = fc.build_incremental(records, adapter, tracker) + assert stats["new"] == 1 + assert stats["deleted"] == 0 + + def test_threshold_triggers_full_rebuild(self): + fc = self._make_fc() + adapter = SimpleAdapter() + tracker = ContentHashTracker() + + # Build initial 150 records + initial = {f"q{i}": {"text": f"text{i}"} for i in range(150)} + tracker.update({rid: adapter.to_chunk(rid, r) for rid, r in initial.items()}) + + # Delete all of them (150 > threshold 100) + stats = fc.build_incremental({}, adapter, tracker, delete_rebuild_threshold=100) + assert stats["full_rebuild"] is True + assert stats["deleted"] == 150 + + def test_stats_correct(self): + fc = self._make_fc() + adapter = SimpleAdapter() + tracker = ContentHashTracker() + + # Initial build + initial = { + "q1": {"text": "unchanged"}, + "q2": {"text": "will_modify"}, + "q3": {"text": "will_delete"}, + } + tracker.update({rid: adapter.to_chunk(rid, r) for rid, r in initial.items()}) + + # Incremental update + current = { + "q1": {"text": "unchanged"}, + "q2": {"text": "modified_text"}, + "q4": {"text": "brand_new"}, + } + stats = fc.build_incremental(current, adapter, tracker) + assert stats["new"] == 1 + assert 
stats["modified"] == 1 + assert stats["deleted"] == 1 + assert stats["unchanged"] == 1 + assert not stats["full_rebuild"] diff --git a/tests/test_record_adapter.py b/tests/test_record_adapter.py new file mode 100644 index 0000000..1bf74ae --- /dev/null +++ b/tests/test_record_adapter.py @@ -0,0 +1,87 @@ +""" +Tests for RecordAdapter and build_chunks_from_records. +""" +import pytest +from typing import Any, Dict + +from tinysearch.base import RecordAdapter, TextChunk, TextSplitter +from tinysearch.records import build_chunks_from_records + + +class SimpleAdapter(RecordAdapter): + """Test adapter: extracts 'text' field, puts rest in metadata.""" + + def to_chunk(self, record_id: str, record: Dict[str, Any]) -> TextChunk: + return TextChunk( + text=record.get("text", ""), + metadata={"record_id": record_id, "grade": record.get("grade", "")}, + ) + + +class TestRecordAdapterABC: + def test_abc_cannot_instantiate(self): + with pytest.raises(TypeError): + RecordAdapter() + + def test_concrete_implementation(self): + adapter = SimpleAdapter() + chunk = adapter.to_chunk("q1", {"text": "hello", "grade": "六年级"}) + assert isinstance(chunk, TextChunk) + assert chunk.text == "hello" + assert chunk.metadata["record_id"] == "q1" + assert chunk.metadata["grade"] == "六年级" + + +class TestBuildChunksFromRecords: + def test_basic_no_splitter(self): + adapter = SimpleAdapter() + records = { + "q1": {"text": "Python编程", "grade": "六年级"}, + "q2": {"text": "Java编程", "grade": "七年级"}, + "q3": {"text": "TinySearch", "grade": "八年级"}, + } + chunks = build_chunks_from_records(records, adapter) + assert len(chunks) == 3 + assert all(isinstance(c, TextChunk) for c in chunks) + assert [c.metadata["record_id"] for c in chunks] == ["q1", "q2", "q3"] + + def test_record_id_guaranteed_in_metadata(self): + """Even if adapter doesn't set record_id, it's auto-added.""" + + class NoIdAdapter(RecordAdapter): + def to_chunk(self, rid, record): + return TextChunk(text=record["text"], metadata={"grade": 
"test"}) + + chunks = build_chunks_from_records( + {"q1": {"text": "hello"}}, NoIdAdapter() + ) + assert chunks[0].metadata["record_id"] == "q1" + + def test_with_splitter(self): + """Long text gets split; metadata inherited.""" + from tinysearch.splitters import CharacterTextSplitter + + adapter = SimpleAdapter() + # Create a record with text long enough to be split + long_text = "A" * 500 + " " + "B" * 500 + records = {"q1": {"text": long_text, "grade": "六年级"}} + + splitter = CharacterTextSplitter(chunk_size=200, chunk_overlap=0) + chunks = build_chunks_from_records(records, adapter, splitter=splitter) + + assert len(chunks) > 1 + # All sub-chunks inherit the original metadata + for c in chunks: + assert c.metadata["record_id"] == "q1" + assert c.metadata["grade"] == "六年级" + assert "chunk_index" in c.metadata + + def test_empty_records(self): + chunks = build_chunks_from_records({}, SimpleAdapter()) + assert chunks == [] + + def test_order_preserved(self): + adapter = SimpleAdapter() + records = {"c": {"text": "C"}, "a": {"text": "A"}, "b": {"text": "B"}} + chunks = build_chunks_from_records(records, adapter) + assert [c.text for c in chunks] == ["C", "A", "B"] diff --git a/tinysearch/base.py b/tinysearch/base.py index 4463a0e..676103c 100644 --- a/tinysearch/base.py +++ b/tinysearch/base.py @@ -26,6 +26,32 @@ def extract(self, filepath: Union[str, pathlib.Path]) -> List[str]: pass +class RecordAdapter(ABC): + """ + Interface for adapters that convert structured records (dicts) to TextChunks. + Unlike DataAdapter (file-oriented), RecordAdapter works with in-memory data + from APIs, databases, or other programmatic sources. + """ + + @abstractmethod + def to_chunk(self, record_id: str, record: Dict[str, Any]) -> "TextChunk": + """ + Convert one record to a TextChunk with text and metadata. 
+ + The implementation should: + - Extract/compose the searchable text from the record fields + - Build metadata dict including at minimum {"record_id": record_id} + + Args: + record_id: Unique identifier for the record + record: Record data as a dictionary + + Returns: + TextChunk with text and metadata derived from the record + """ + pass + + class TextChunk: """ Represents a chunk of text with optional metadata diff --git a/tinysearch/flow/controller.py b/tinysearch/flow/controller.py index fbace3b..8f6156c 100644 --- a/tinysearch/flow/controller.py +++ b/tinysearch/flow/controller.py @@ -21,9 +21,12 @@ class FlowController(FlowControllerBase): Manages the entire data processing from ingestion to query handling. """ + # Soft-delete threshold: trigger full rebuild when exceeded + DELETE_REBUILD_THRESHOLD = 100 + def __init__( self, - data_adapter: DataAdapter, + data_adapter: Optional[DataAdapter], text_splitter: TextSplitter, embedder: Embedder, indexer: VectorIndexer, @@ -32,9 +35,9 @@ def __init__( ): """ Initialize the flow controller with all required components - + Args: - data_adapter: Component for data extraction + data_adapter: Component for data extraction (None when using record-based API) text_splitter: Component for text chunking embedder: Component for generating embeddings indexer: Component for vector indexing and search @@ -197,6 +200,145 @@ def build_index(self, data_path: Union[str, Path], **kwargs) -> None: else: self.process_file(data_path, force_reprocess=force_reprocess) + def build_from_records( + self, + records: Dict[str, Dict[str, Any]], + adapter: Any, + splitter: Optional[TextSplitter] = None, + ) -> List[TextChunk]: + """ + Build the search index from in-memory records. 
+ + Args: + records: Mapping of record_id -> record_data + adapter: RecordAdapter to convert records to TextChunks + splitter: Optional TextSplitter for further chunking + + Returns: + List of TextChunks that were indexed + """ + from tinysearch.records import build_chunks_from_records + + chunks = build_chunks_from_records(records, adapter, splitter) + if not chunks: + return chunks + + vectors = self.embedder.embed([c.text for c in chunks]) + self.indexer.build(vectors, chunks) + self._build_retriever_indexes(chunks) + + return chunks + + def build_incremental( + self, + records: Dict[str, Dict[str, Any]], + adapter: Any, + hash_tracker: Any, + splitter: Optional[TextSplitter] = None, + delete_rebuild_threshold: Optional[int] = None, + ) -> Dict[str, Any]: + """ + Incrementally update the search index based on record changes. + + Pipeline: + 1. adapter.to_chunk() for each record + 2. hash_tracker.detect_changes() → new/modified/deleted + 3. If soft deletes exceed threshold → full rebuild + 4. Else → FAISS.add() + MetadataIndex.add_chunks() for new/modified + 5. 
BM25/Substring always full rebuild (fast) + + Args: + records: Current complete set of records {record_id: record_data} + adapter: RecordAdapter to convert records to TextChunks + hash_tracker: ContentHashTracker for change detection + splitter: Optional TextSplitter + delete_rebuild_threshold: Max soft deletes before full rebuild + + Returns: + Dict with stats: {new, modified, deleted, unchanged, full_rebuild} + """ + from tinysearch.records import build_chunks_from_records + + threshold = delete_rebuild_threshold or self.DELETE_REBUILD_THRESHOLD + + # Step 1: Convert all current records to chunks for change detection + current_record_chunks: Dict[str, TextChunk] = {} + for rid, rdata in records.items(): + chunk = adapter.to_chunk(rid, rdata) + if "record_id" not in chunk.metadata: + chunk.metadata["record_id"] = rid + current_record_chunks[rid] = chunk + + # Step 2: Detect changes + changes = hash_tracker.detect_changes(current_record_chunks) + + stats = { + "new": len(changes.new), + "modified": len(changes.modified), + "deleted": len(changes.deleted), + "unchanged": len(changes.unchanged), + "full_rebuild": False, + } + + if not changes.has_changes: + return stats + + # Step 3: Check if full rebuild needed + total_soft_deletes = len(changes.deleted) + len(changes.modified) + if isinstance(self.query_engine, HybridQueryEngine): + total_soft_deletes += self.query_engine.soft_delete_count + + if total_soft_deletes >= threshold: + # Full rebuild path + stats["full_rebuild"] = True + self.build_from_records(records, adapter, splitter) + + if isinstance(self.query_engine, HybridQueryEngine): + self.query_engine.clear_soft_deletes() + + hash_tracker.remove(changes.deleted) + hash_tracker.update(current_record_chunks) + return stats + + # Step 4: Incremental path + + # 4a: Soft-delete modified + deleted + if isinstance(self.query_engine, HybridQueryEngine): + ids_to_soft_delete = changes.deleted | set(changes.modified) + if ids_to_soft_delete: + 
self.query_engine.add_soft_deletes(ids_to_soft_delete) + + # 4b: Embed and add new + modified to FAISS + records_to_add = {rid: records[rid] for rid in changes.new + changes.modified} + if records_to_add: + new_chunks = build_chunks_from_records(records_to_add, adapter, splitter) + if new_chunks: + vectors = self.embedder.embed([c.text for c in new_chunks]) + self.indexer.add(vectors, new_chunks) + + # Incrementally add to MetadataIndex + if (isinstance(self.query_engine, HybridQueryEngine) + and self.query_engine.metadata_index is not None): + start_id = self.query_engine.metadata_index.total_chunks + self.query_engine.metadata_index.add_chunks(new_chunks, start_id) + + # 4c: BM25/Substring always full rebuild (fast) + all_current_chunks = build_chunks_from_records(records, adapter, splitter) + if isinstance(self.query_engine, HybridQueryEngine): + for retriever in self.query_engine.retrievers: + if isinstance(retriever, VectorRetriever): + continue + retriever.build(all_current_chunks) + + # Step 5: Update hash tracker + hash_tracker.remove(changes.deleted) + hash_tracker.update({ + rid: current_record_chunks[rid] + for rid in changes.new + changes.modified + }) + + return stats + def _get_hybrid_retrievers(self) -> List[Retriever]: """Get retrievers from HybridQueryEngine, if applicable""" if isinstance(self.query_engine, HybridQueryEngine): @@ -248,6 +390,11 @@ def _save_retriever_indexes(self, base_path: Path) -> None: if isinstance(self.query_engine, HybridQueryEngine) and self.query_engine.metadata_index is not None: self.query_engine.metadata_index.save(index_dir / "metadata_index.json") + # Save soft-delete set + if isinstance(self.query_engine, HybridQueryEngine) and self.query_engine.soft_deleted_ids: + with open(index_dir / "soft_deletes.json", "w") as f: + json.dump(sorted(self.query_engine.soft_deleted_ids), f) + def load_index(self, path: Optional[Union[str, Path]] = None) -> None: """ Load an index from disk. 
@@ -281,7 +428,14 @@ def _load_retriever_indexes(self, base_path: Path) -> None: metadata_path = index_dir / "metadata_index.json" if metadata_path.exists(): self.query_engine.metadata_index.load(metadata_path) - + + # Load soft-delete set + if isinstance(self.query_engine, HybridQueryEngine): + soft_delete_path = index_dir / "soft_deletes.json" + if soft_delete_path.exists(): + with open(soft_delete_path, "r") as f: + self.query_engine.soft_deleted_ids = set(json.load(f)) + def query(self, query_text: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]: """ Process a query and return relevant chunks diff --git a/tinysearch/indexers/__init__.py b/tinysearch/indexers/__init__.py index 86c9aa0..d88d23f 100644 --- a/tinysearch/indexers/__init__.py +++ b/tinysearch/indexers/__init__.py @@ -7,10 +7,13 @@ # Make the FAISS indexer available from the root module from .faiss_indexer import FAISSIndexer from .metadata_index import MetadataIndex +from .hash_tracker import ContentHashTracker, ChangeSet __all__ = [ "FAISSIndexer", "MetadataIndex", + "ContentHashTracker", + "ChangeSet", ] def index_exists(path: Union[str, Path]) -> bool: diff --git a/tinysearch/indexers/hash_tracker.py b/tinysearch/indexers/hash_tracker.py new file mode 100644 index 0000000..061fe5b --- /dev/null +++ b/tinysearch/indexers/hash_tracker.py @@ -0,0 +1,153 @@ +""" +Content hash tracking for incremental indexing change detection. 
+""" +import hashlib +import json +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Union + +from tinysearch.base import TextChunk + +logger = logging.getLogger(__name__) + + +class ChangeSet: + """Result of change detection between current and tracked records.""" + + __slots__ = ("new", "modified", "deleted", "unchanged") + + def __init__( + self, + new: List[str], + modified: List[str], + deleted: Set[str], + unchanged: Set[str], + ): + self.new = new + self.modified = modified + self.deleted = deleted + self.unchanged = unchanged + + @property + def has_changes(self) -> bool: + return bool(self.new or self.modified or self.deleted) + + def __repr__(self) -> str: + return ( + f"ChangeSet(new={len(self.new)}, modified={len(self.modified)}, " + f"deleted={len(self.deleted)}, unchanged={len(self.unchanged)})" + ) + + +class ContentHashTracker: + """ + Track content hashes for record-level change detection. + + Each record is identified by its record_id (string). The hash is + computed from the text and a configurable subset of metadata keys. + """ + + _INTERNAL_KEYS = frozenset({"chunk_index", "total_chunks"}) + + def __init__(self, hash_metadata_keys: Optional[List[str]] = None): + """ + Args: + hash_metadata_keys: Metadata keys to include in hash. + If None, all keys except internal ones are included. 
+ """ + self._hashes: Dict[str, str] = {} + self._hash_metadata_keys = hash_metadata_keys + + def compute_hash(self, text: str, metadata: Dict[str, Any]) -> str: + """Compute MD5 hash of text + sorted metadata.""" + hasher = hashlib.md5() + hasher.update(text.encode("utf-8")) + + if self._hash_metadata_keys is not None: + keys = self._hash_metadata_keys + else: + keys = sorted(k for k in metadata if k not in self._INTERNAL_KEYS) + + for key in sorted(keys): + if key in metadata: + val = metadata[key] + # Normalize lists for deterministic hashing + if isinstance(val, list): + val = json.dumps(sorted(str(v) for v in val), ensure_ascii=False) + hasher.update(f"{key}={val}".encode("utf-8")) + + return hasher.hexdigest() + + def detect_changes(self, current_records: Dict[str, TextChunk]) -> ChangeSet: + """ + Compare current records against tracked hashes. + + Args: + current_records: Mapping of record_id -> TextChunk + + Returns: + ChangeSet with new, modified, deleted, unchanged + """ + tracked_ids = set(self._hashes.keys()) + + new_ids: List[str] = [] + modified_ids: List[str] = [] + unchanged_ids: Set[str] = set() + + for rid, chunk in current_records.items(): + h = self.compute_hash(chunk.text, chunk.metadata) + if rid not in tracked_ids: + new_ids.append(rid) + elif self._hashes[rid] != h: + modified_ids.append(rid) + else: + unchanged_ids.add(rid) + + deleted_ids = tracked_ids - set(current_records.keys()) + + return ChangeSet( + new=new_ids, + modified=modified_ids, + deleted=deleted_ids, + unchanged=unchanged_ids, + ) + + def update(self, records: Dict[str, TextChunk]) -> None: + """Update tracked hashes after a successful build.""" + for rid, chunk in records.items(): + self._hashes[rid] = self.compute_hash(chunk.text, chunk.metadata) + + def remove(self, record_ids: Set[str]) -> None: + """Remove record_ids from tracking.""" + for rid in record_ids: + self._hashes.pop(rid, None) + + @property + def tracked_count(self) -> int: + return len(self._hashes) + + 
def save(self, path: Union[str, Path]) -> None: + """Save hash state to JSON.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump( + { + "version": 1, + "hash_metadata_keys": self._hash_metadata_keys, + "hashes": self._hashes, + }, + f, + ensure_ascii=False, + ) + logger.info("ContentHashTracker saved to %s (%d records)", path, len(self._hashes)) + + def load(self, path: Union[str, Path]) -> None: + """Load hash state from JSON.""" + path = Path(path) + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + self._hashes = data["hashes"] + self._hash_metadata_keys = data.get("hash_metadata_keys") + logger.info("ContentHashTracker loaded from %s (%d records)", path, len(self._hashes)) diff --git a/tinysearch/indexers/metadata_index.py b/tinysearch/indexers/metadata_index.py index d8bc57a..2dabf5a 100644 --- a/tinysearch/indexers/metadata_index.py +++ b/tinysearch/indexers/metadata_index.py @@ -65,6 +65,33 @@ def build(self, chunks: List[TextChunk]) -> None: field_stats, ) + def add_chunks(self, chunks: List[TextChunk], start_id: int) -> None: + """ + Incrementally add new chunks to the inverted index. + + Args: + chunks: New TextChunks to index + start_id: First chunk ID to assign (typically current total_chunks) + """ + for offset, chunk in enumerate(chunks): + chunk_id = start_id + offset + if not chunk.metadata: + continue + for key, value in chunk.metadata.items(): + if isinstance(value, _INDEXABLE_TYPES): + self._index[key][value].add(chunk_id) + elif isinstance(value, list): + for element in value: + if isinstance(element, _INDEXABLE_TYPES): + self._index[key][element].add(chunk_id) + + self._total_chunks += len(chunks) + logger.info( + "MetadataIndex: added %d chunks (total now %d)", + len(chunks), + self._total_chunks, + ) + def lookup(self, filters: Dict[str, FilterValue]) -> Optional[Set[int]]: """ Resolve filters to a candidate ID set via inverted index. 
def add_soft_deletes(self, record_ids: set) -> None:
    """Mark record_ids as soft-deleted (excluded from results).

    NOTE(review): downstream result filtering drops *every* fused hit whose
    metadata record_id appears in this set — including chunks freshly
    re-added for *modified* records. Confirm that modified records are
    re-surfaced as intended before the next full rebuild clears this set.
    """
    self.soft_deleted_ids.update(record_ids)

def clear_soft_deletes(self) -> None:
    """Forget all soft deletes (typically after a full rebuild)."""
    self.soft_deleted_ids.clear()

@property
def soft_delete_count(self) -> int:
    """Number of record_ids currently hidden from results."""
    return len(self.soft_deleted_ids)
"""
Utilities for building TextChunks from structured records via RecordAdapter.
"""
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional

if TYPE_CHECKING:
    # Annotation-only imports: keeps this helper module importable without
    # pulling tinysearch.base into the runtime import graph.
    from tinysearch.base import RecordAdapter, TextChunk, TextSplitter

logger = logging.getLogger(__name__)


def build_chunks_from_records(
    records: Dict[str, Dict[str, Any]],
    adapter: "RecordAdapter",
    splitter: "Optional[TextSplitter]" = None,
) -> "List[TextChunk]":
    """
    Convert a dict of records to TextChunks ready for indexing.

    Args:
        records: Mapping of record_id -> record_data
        adapter: RecordAdapter that converts each record to a TextChunk
        splitter: Optional TextSplitter. If None, each record becomes exactly
            one TextChunk. If provided, each record's text is further
            split, with the original metadata inherited by all sub-chunks.

    Returns:
        Ordered list of TextChunks (order follows dict iteration order)
    """
    chunks: List[TextChunk] = []

    for record_id, record_data in records.items():
        chunk = adapter.to_chunk(record_id, record_data)

        # Ensure record_id is always in metadata; an adapter-supplied value
        # takes precedence over the dict key.
        chunk.metadata.setdefault("record_id", record_id)

        if splitter is None:
            chunks.append(chunk)
        else:
            # NOTE(review): assumes splitter.split copies (or may safely
            # share) the metadata dict across sub-chunks — TODO confirm
            # against the TextSplitter contract.
            sub_chunks = splitter.split([chunk.text], [chunk.metadata])
            chunks.extend(sub_chunks)

    logger.info(
        "Built %d chunks from %d records (splitter=%s)",
        len(chunks),
        len(records),
        type(splitter).__name__ if splitter else "None",
    )
    return chunks