diff --git a/README.md b/README.md index 9641b9a..84c284f 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,17 @@ # TinySearch -**A lightweight vector search and retrieval system for document understanding, semantic search, and text information retrieval.** +**A lightweight hybrid search and retrieval system for document understanding, semantic search, and text information retrieval.** -TinySearch provides an end-to-end solution for converting documents into searchable vector embeddings, with a focus on simplicity, flexibility, and efficiency. Perfect for building RAG (Retrieval-Augmented Generation) systems, semantic document search, knowledge bases, and more. +TinySearch provides an end-to-end solution for converting documents into searchable vector embeddings, with support for hybrid multi-retriever search (Vector + BM25 + Substring). Focused on simplicity, flexibility, and efficiency. Perfect for building RAG (Retrieval-Augmented Generation) systems, semantic document search, knowledge bases, and more. 
## Key Features - ๐Ÿงฉ **Modular Design**: Plug-and-play components for data processing, text splitting, embedding generation, and vector retrieval - ๐Ÿ” **Semantic Search**: Find contextually relevant information beyond keyword matching +- ๐Ÿ”€ **Hybrid Retrieval**: Combine Vector, BM25, and Substring retrievers with configurable fusion strategies (RRF, Weighted) +- ๐Ÿ“ **BM25 Keyword Search**: Fast keyword-based retrieval with jieba Chinese tokenization support +- ๐Ÿ”ค **Substring/Regex Search**: Ctrl+F style exact match and regex pattern search +- ๐Ÿ† **Cross-Encoder Reranking**: Optional reranking with BGE Reranker for improved relevance - โš™๏ธ **Highly Configurable**: Simple YAML configuration to control all aspects of the system - ๐Ÿ”Œ **Multiple Input Formats**: Support for TXT, PDF, CSV, Markdown, JSON, and custom adapters - ๐Ÿค– **Embedding Models**: Integration with HuggingFace models like Qwen-Embedding and more @@ -29,55 +33,63 @@ flowchart TB subgraph Input Documents["Documents
(PDF, Text, CSV, JSON, etc.)"] end - + subgraph DataProcessing["Data Processing"] DataAdapter["DataAdapter
Extract text from data source"] TextSplitter["TextSplitter
Chunk text into segments"] - Embedder["Embedder
Generate vector embeddings"] end - - subgraph IndexLayer["Index Layer"] - VectorIndexer["VectorIndexer
Build and maintain FAISS index"] - IndexStorage["Index Storage
(FAISS Index + Original Text)"] + + subgraph RetrieverLayer["Retriever Layer"] + VectorRetriever["VectorRetriever
Embedder + FAISS"] + BM25Retriever["BM25Retriever
bm25s + jieba"] + SubstringRetriever["SubstringRetriever
Regex / exact match"] end - + + subgraph FusionLayer["Fusion & Reranking"] + FusionStrategy["FusionStrategy
(RRF / Weighted)"] + Reranker["Reranker (optional)
Cross-Encoder BGE"] + end + subgraph QueryLayer["Query Layer"] UserQuery["User Query"] - QueryEngine["QueryEngine
Process and reformat query"] + QueryEngine["QueryEngine
Template / Hybrid"] SearchResults["Search Results
Ranked by relevance"] end - + subgraph FlowControl["Flow Control"] Config["Configuration"] FlowController["FlowController
Orchestrate data flow"] end - + subgraph API["API Layer"] CLI["Command Line Interface"] FastAPI["FastAPI Web Service"] end - + %% Data Flow - Indexing Documents --> DataAdapter DataAdapter --> TextSplitter - TextSplitter --> Embedder - Embedder --> VectorIndexer - VectorIndexer --> IndexStorage - + TextSplitter --> VectorRetriever + TextSplitter --> BM25Retriever + TextSplitter --> SubstringRetriever + %% Data Flow - Querying UserQuery --> QueryEngine - QueryEngine --> Embedder - Embedder --> VectorIndexer - VectorIndexer --> SearchResults - + QueryEngine --> VectorRetriever + QueryEngine --> BM25Retriever + QueryEngine --> SubstringRetriever + VectorRetriever --> FusionStrategy + BM25Retriever --> FusionStrategy + SubstringRetriever --> FusionStrategy + FusionStrategy --> Reranker + Reranker --> SearchResults + %% Control Flow Config --> FlowController FlowController --> DataAdapter FlowController --> TextSplitter - FlowController --> Embedder - FlowController --> VectorIndexer FlowController --> QueryEngine - + %% API Flow CLI --> FlowController FastAPI --> FlowController @@ -152,9 +164,15 @@ For API documentation, see [API Guide](docs/api.md) and [API Authentication Guid ## Installation ```bash -# Basic installation +# Basic installation (vector search only) pip install tinysearch +# With hybrid search (BM25 + Chinese tokenization) +pip install tinysearch bm25s jieba + +# With cross-encoder reranking +pip install tinysearch FlagEmbedding + # With API support pip install tinysearch[api] @@ -195,7 +213,7 @@ indexer: type: faiss index_path: .cache/index.faiss metric: cosine - + query_engine: method: template template: "Please find information about: {query}" @@ -226,6 +244,118 @@ Then visit http://localhost:8000 in your browser to use the web interface, or se curl -X POST http://localhost:8000/query -H "Content-Type: application/json" -d '{"query": "Your search query", "top_k": 5}' ``` +## Hybrid Search + +TinySearch supports hybrid retrieval that combines multiple 
search strategies for better recall and precision. You can mix Vector (semantic), BM25 (keyword), and Substring (exact match) retrievers, then fuse their results using RRF or Weighted fusion. + +### Hybrid Search Configuration + +To enable hybrid search, set `query_engine.method` to `"hybrid"` and configure the `retrievers` list: + +```yaml +# Embedding + Vector index (required for vector retriever) +embedder: + model: Qwen/Qwen3-Embedding-0.6B + device: cuda +indexer: + type: faiss + index_path: .cache/index.faiss + metric: cosine + +# Multi-retriever configuration +retrievers: + - type: vector # Semantic search (uses embedder + indexer above) + - type: bm25 # Keyword search + tokenizer: jieba # Chinese tokenization (optional, fallback: whitespace) + - type: substring # Exact match / regex + is_regex: false + +# Fusion strategy +fusion: + strategy: weighted # "weighted" or "rrf" + weights: [0.5, 0.4, 0.1] # Weights for each retriever (vector, bm25, substring) + min_score: 0.1 # Drop results below this fused score + +# Optional: Cross-encoder reranking +reranker: + enabled: false + model: BAAI/bge-reranker-v2-m3 + +# Query engine +query_engine: + method: hybrid # "template" (vector only) or "hybrid" + top_k: 20 + recall_multiplier: 2 # Each retriever recalls top_k * 2 candidates before fusion +``` + +### Fusion Strategies + +| Strategy | Description | Best For | +|----------|-------------|----------| +| **weighted** | Min-max normalize scores, then weighted sum | When you know relative retriever importance | +| **rrf** | Reciprocal Rank Fusion: `score = ฮฃ 1/(rank + k)` | When score distributions differ (more robust) | + +### Programmatic Usage + +```python +from tinysearch.base import TextChunk +from tinysearch.retrievers import VectorRetriever, BM25Retriever, SubstringRetriever +from tinysearch.fusion import WeightedFusion, ReciprocalRankFusion +from tinysearch.query import HybridQueryEngine + +# Build retrievers +bm25 = BM25Retriever() +bm25.build(chunks) # 
chunks: List[TextChunk] + +substr = SubstringRetriever(is_regex=False) +substr.build(chunks) + +# Create hybrid engine +engine = HybridQueryEngine( + retrievers=[bm25, substr], + fusion_strategy=WeightedFusion(weights=[0.7, 0.3]), + recall_multiplier=2, +) + +results = engine.retrieve("ๆœ็ดขๅ…ณ้”ฎ่ฏ", top_k=10) +# Each result: {"text", "metadata", "score", "retrieval_method": "hybrid", "scores": {...}} +``` + +### Result Format + +Hybrid search results include per-retriever scores for transparency: + +```python +{ + "text": "...", # Chunk text + "metadata": {...}, # Source metadata + "score": 0.85, # Final fused score [0, 1] + "retrieval_method": "hybrid", + "scores": { # Per-retriever original scores + "vector": 0.92, + "bm25": 0.75, + } +} +``` + +### Optional Dependencies for Hybrid Search + +| Package | Purpose | Required? | +|---------|---------|-----------| +| `bm25s` | BM25 retrieval engine | Only for BM25Retriever | +| `jieba` | Chinese tokenization | Optional (fallback: whitespace split) | +| `FlagEmbedding` | Cross-encoder reranking | Optional (only for reranker) | + +Install with: +```bash +pip install bm25s jieba # For BM25 + Chinese support +pip install FlagEmbedding # For cross-encoder reranking +``` + +### Backward Compatibility + +If you don't configure `retrievers` or keep `query_engine.method: "template"`, TinySearch behaves exactly as before โ€” pure vector search with no additional dependencies. + ## Examples TinySearch includes various example scripts in the `examples/` directory to demonstrate different features: @@ -303,14 +433,16 @@ The TinySearch Web UI provides an intuitive interface for interacting with the s ## Custom Data Adapters -You can create custom data adapters by implementing the `DataAdapter` interface: +You can create custom data adapters by implementing the `DataAdapter` interface. 
+Each adapter handles **single file** extraction โ€” directory traversal is handled +automatically by the framework: ```python from tinysearch.base import DataAdapter class MyAdapter(DataAdapter): def extract(self, filepath): - # Your code to extract text from the file + # Extract text from a single file return [text1, text2, ...] ``` @@ -326,6 +458,38 @@ adapter: param1: value1 ``` +## Directory Indexing + +When building an index from a directory, TinySearch uses a centralized file +discovery mechanism (`iter_input_files()`). Each adapter type has default +file extensions: + +| Adapter | Extensions | +|---------|-----------| +| text | `.txt`, `.text`, `.md`, `.py`, `.js`, `.html`, `.css`, `.json` | +| pdf | `.pdf` | +| csv | `.csv` | +| markdown | `.md`, `.markdown`, `.mdown`, `.mkdn` | +| json | `.json` | + +You can override extensions in `config.yaml`: + +```yaml +adapter: + type: text + params: + extensions: [".txt", ".log"] +``` + +Or use `iter_input_files()` programmatically: + +```python +from tinysearch.utils.file_discovery import iter_input_files + +for file_path in iter_input_files("./data", adapter_type="text"): + texts = adapter.extract(file_path) +``` + ## Modern Logging System TinySearch features a modern logging system powered by [loguru](https://github.com/Delgan/loguru) with beautiful, colorful output and flexible configuration options. 
@@ -426,4 +590,4 @@ This project is licensed under the MIT License - see the LICENSE file for detail ## Keywords -vector search, semantic search, document retrieval, embeddings, FAISS, RAG, information retrieval, text search, vector database, document understanding, NLP, natural language processing, AI search \ No newline at end of file +vector search, semantic search, hybrid search, BM25, document retrieval, embeddings, FAISS, RAG, information retrieval, text search, vector database, document understanding, NLP, natural language processing, AI search, reranking, fusion \ No newline at end of file diff --git a/setup.py b/setup.py index e0cbadc..eda9364 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,13 @@ "indexers": [ "faiss-cpu>=1.7.0", ], + "hybrid": [ + "bm25s>=0.1.0", + "jieba>=0.42.0", + ], + "reranker": [ + "FlagEmbedding>=1.0.0", + ], "dev": [ "pytest>=6.0.0", "black>=21.5b2", @@ -64,7 +71,7 @@ version=version.get("__version__", "0.1.0"), author="TinySearch Team", author_email="tinysearch@example.com", - description="A lightweight vector retrieval system", + description="A lightweight hybrid search and retrieval system (Vector + BM25 + Substring)", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/yourusername/tinysearch", diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py new file mode 100644 index 0000000..6d13656 --- /dev/null +++ b/tests/test_hybrid.py @@ -0,0 +1,618 @@ +""" +Tests for hybrid search: HybridQueryEngine, FusionStrategies, and CLI/FlowController integration. 
+""" +import pytest +from unittest.mock import MagicMock, patch +from pathlib import Path + +from tinysearch.base import TextChunk, QueryEngine, Retriever +from tinysearch.utils.file_discovery import iter_input_files +from tinysearch.retrievers.bm25_retriever import BM25Retriever +from tinysearch.retrievers.substring_retriever import SubstringRetriever +from tinysearch.fusion.weighted import WeightedFusion +from tinysearch.fusion.rrf import ReciprocalRankFusion +from tinysearch.fusion._utils import make_doc_key +from tinysearch.query.hybrid import HybridQueryEngine +from tinysearch.indexers.metadata_index import MetadataIndex + + +# โ”€โ”€ Fixtures โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +@pytest.fixture +def sample_chunks(): + return [ + TextChunk("Python็ผ–็จ‹่ฏญ่จ€", {"source": "a.txt", "grade": "ๅ…ญๅนด็บงไธŠ", "chunk_index": 0}), + TextChunk("Java็ผ–็จ‹่ฏญ่จ€", {"source": "b.txt", "grade": "ๅ…ญๅนด็บงไธ‹", "chunk_index": 0}), + TextChunk("TinySearchๆฃ€็ดข็ณป็ปŸ", {"source": "c.txt", "grade": "ไธƒๅนด็บง", "chunk_index": 0}), + ] + + +@pytest.fixture +def hybrid_engine(sample_chunks): + bm25 = BM25Retriever() + bm25.build(sample_chunks) + substr = SubstringRetriever() + substr.build(sample_chunks) + return HybridQueryEngine( + [bm25, substr], + WeightedFusion([0.6, 0.4]), + min_scores=[0.0, 0.0], + ) + + +# โ”€โ”€ HybridQueryEngine basic โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestHybridQueryEngine: + def test_retrieve_returns_results(self, hybrid_engine): + results = hybrid_engine.retrieve("็ผ–็จ‹", top_k=5) + assert len(results) > 0 + assert all("text" in r and "score" in r for r in results) + + def test_backward_compat_no_kwargs(self, hybrid_engine): + """retrieve() works with no extra kwargs (backward compatible).""" + results = hybrid_engine.retrieve("็ผ–็จ‹", top_k=5) + assert isinstance(results, 
list) + + def test_empty_retrievers_raises(self): + with pytest.raises(ValueError, match="At least one retriever"): + HybridQueryEngine([], WeightedFusion()) + + def test_min_scores_length_mismatch(self, sample_chunks): + bm25 = BM25Retriever() + bm25.build(sample_chunks) + with pytest.raises(ValueError, match="min_scores length"): + HybridQueryEngine([bm25], WeightedFusion(), min_scores=[0.0, 0.0]) + + +# โ”€โ”€ Metadata filtering โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestMetadataFiltering: + def test_exact_match_filter(self, hybrid_engine): + results = hybrid_engine.retrieve("็ผ–็จ‹", top_k=5, filters={"source": "a.txt"}) + assert all(r["metadata"]["source"] == "a.txt" for r in results) + + def test_list_or_filter(self, hybrid_engine): + results = hybrid_engine.retrieve( + "็ผ–็จ‹", top_k=5, + filters={"grade": ["ๅ…ญๅนด็บงไธŠ", "ๅ…ญๅนด็บงไธ‹"]}, + ) + assert len(results) > 0 + assert all(r["metadata"]["grade"].startswith("ๅ…ญๅนด็บง") for r in results) + + def test_callable_filter(self, hybrid_engine): + results = hybrid_engine.retrieve( + "็ผ–็จ‹", top_k=5, + filters={"source": lambda v: v in ["a.txt", "b.txt"]}, + ) + assert all(r["metadata"]["source"] in ["a.txt", "b.txt"] for r in results) + + def test_missing_key_excluded(self, hybrid_engine): + """Chunks missing a filter key should be excluded.""" + results = hybrid_engine.retrieve( + "็ผ–็จ‹", top_k=5, + filters={"nonexistent_key": "value"}, + ) + assert len(results) == 0 + + def test_none_metadata_excluded(self): + assert HybridQueryEngine._match_filters(None, {"key": "val"}) is False + + def test_filter_over_recall(self, hybrid_engine): + """With filters, recall_k should be multiplied by filter_multiplier.""" + # Default filter_multiplier=3, recall_multiplier=2 + # So recall_k = 5 * 2 * 3 = 30 + # This is tested implicitly: we still get results despite filters + results = hybrid_engine.retrieve( + "็ผ–็จ‹", top_k=5, + 
filters={"grade": ["ๅ…ญๅนด็บงไธŠ"]}, + ) + assert len(results) > 0 + + +# โ”€โ”€ Dynamic weights โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestDynamicWeights: + def test_dynamic_weights_override(self, hybrid_engine): + """Passing weights= at query time should override constructor weights.""" + r1 = hybrid_engine.retrieve("็ผ–็จ‹", top_k=5, weights=[1.0, 0.0]) + r2 = hybrid_engine.retrieve("็ผ–็จ‹", top_k=5, weights=[0.0, 1.0]) + # Results should differ because weights favor different retrievers + # At minimum, both should return results + assert len(r1) > 0 + assert len(r2) > 0 + + +# โ”€โ”€ min_scores โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestMinScores: + def test_high_min_scores_filters_everything(self, sample_chunks): + bm25 = BM25Retriever() + bm25.build(sample_chunks) + substr = SubstringRetriever() + substr.build(sample_chunks) + engine = HybridQueryEngine( + [bm25, substr], + WeightedFusion([0.5, 0.5]), + min_scores=[999.0, 999.0], + ) + results = engine.retrieve("็ผ–็จ‹", top_k=5) + assert len(results) == 0 + + def test_zero_min_scores_passes_all(self, hybrid_engine): + results = hybrid_engine.retrieve("็ผ–็จ‹", top_k=5) + assert len(results) > 0 + + +# โ”€โ”€ retrieve_with_details โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestRetrieveWithDetails: + def test_returns_all_keys(self, hybrid_engine): + d = hybrid_engine.retrieve_with_details("็ผ–็จ‹", top_k=5) + assert "results" in d + assert "per_retriever" in d + assert "fused_before_rerank" in d + + def test_per_retriever_count_matches(self, hybrid_engine): + d = hybrid_engine.retrieve_with_details("็ผ–็จ‹", top_k=5) + assert len(d["per_retriever"]) == len(hybrid_engine.retrievers) + + def test_results_match_retrieve(self, hybrid_engine): + 
"""retrieve_with_details().results should equal retrieve().""" + r1 = hybrid_engine.retrieve("็ผ–็จ‹", top_k=5) + r2 = hybrid_engine.retrieve_with_details("็ผ–็จ‹", top_k=5)["results"] + assert len(r1) == len(r2) + for a, b in zip(r1, r2): + assert a["text"] == b["text"] + + +# โ”€โ”€ Fusion dedup key โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestFusionDedupKey: + def test_same_text_different_source_not_merged(self): + """Two chunks with same text but different source should stay separate.""" + chunks = [ + TextChunk("Same text", {"source": "fileA.txt", "chunk_index": 0}), + TextChunk("Same text", {"source": "fileB.txt", "chunk_index": 0}), + ] + bm25 = BM25Retriever() + bm25.build(chunks) + substr = SubstringRetriever() + substr.build(chunks) + engine = HybridQueryEngine( + [bm25, substr], WeightedFusion([0.5, 0.5]) + ) + results = engine.retrieve("Same text", top_k=5) + sources = {r["metadata"]["source"] for r in results} + assert sources == {"fileA.txt", "fileB.txt"} + + def test_same_text_different_chunk_index_not_merged(self): + """Same text at different positions in same file should stay separate.""" + chunks = [ + TextChunk("Repeated", {"source": "f.txt", "chunk_index": 0}), + TextChunk("Repeated", {"source": "f.txt", "chunk_index": 5}), + ] + bm25 = BM25Retriever() + bm25.build(chunks) + substr = SubstringRetriever() + substr.build(chunks) + engine = HybridQueryEngine( + [bm25, substr], WeightedFusion([0.5, 0.5]) + ) + results = engine.retrieve("Repeated", top_k=5) + indices = {r["metadata"]["chunk_index"] for r in results} + assert indices == {0, 5} + + def test_make_doc_key_different_source(self): + r1 = {"text": "hello", "metadata": {"source": "a"}} + r2 = {"text": "hello", "metadata": {"source": "b"}} + assert make_doc_key(r1) != make_doc_key(r2) + + def test_make_doc_key_same_chunk(self): + """Same chunk from different retrievers should produce the same key.""" + r1 = 
{"text": "hello", "metadata": {"source": "a", "chunk_index": 0}} + r2 = {"text": "hello", "metadata": {"source": "a", "chunk_index": 0}} + assert make_doc_key(r1) == make_doc_key(r2) + + def test_rrf_dedup_also_fixed(self): + """RRF fusion should also keep different-source same-text chunks separate.""" + chunks = [ + TextChunk("Same text", {"source": "x.txt", "chunk_index": 0}), + TextChunk("Same text", {"source": "y.txt", "chunk_index": 0}), + ] + bm25 = BM25Retriever() + bm25.build(chunks) + substr = SubstringRetriever() + substr.build(chunks) + engine = HybridQueryEngine( + [bm25, substr], ReciprocalRankFusion() + ) + results = engine.retrieve("Same text", top_k=5) + assert len(results) == 2 + + +# โ”€โ”€ Silent exception logging โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestRetrieverFailureLogging: + def test_failing_retriever_logged_not_crash(self, caplog): + """A failing retriever should log a warning and return empty, not crash.""" + class FailRetriever(Retriever): + def build(self, chunks): pass + def retrieve(self, query, top_k=5): + raise RuntimeError("boom") + def save(self, path): pass + def load(self, path): pass + + engine = HybridQueryEngine([FailRetriever()], WeightedFusion()) + import logging + with caplog.at_level(logging.WARNING, logger="tinysearch.query.hybrid"): + results = engine.retrieve("test", top_k=5) + assert results == [] + assert "FailRetriever failed" in caplog.text + + +# โ”€โ”€ ABC signature โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestABCSignature: + def test_query_engine_retrieve_accepts_kwargs(self): + import inspect + sig = inspect.signature(QueryEngine.retrieve) + params = list(sig.parameters.keys()) + assert "kwargs" in params + + def test_hybrid_engine_is_query_engine(self, hybrid_engine): + assert isinstance(hybrid_engine, QueryEngine) + + +# โ”€โ”€ FlowController kwargs forwarding 
โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestFlowControllerKwargs: + def test_query_forwards_kwargs(self): + """FlowController.query() should forward **kwargs to query_engine.retrieve().""" + mock_engine = MagicMock(spec=QueryEngine) + mock_engine.retrieve.return_value = [] + + from tinysearch.flow.controller import FlowController + # Minimal construction โ€” we only test query(), not build + fc = FlowController.__new__(FlowController) + fc.query_engine = mock_engine + fc.config = {"query_engine": {"top_k": 5}} + + fc.query("test", top_k=5, filters={"source": "a"}, weights=[1.0]) + mock_engine.retrieve.assert_called_once_with( + "test", 5, filters={"source": "a"}, weights=[1.0] + ) + + +# โ”€โ”€ CLI helpers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestCLIHelpers: + def test_get_retriever_index_dir(self): + from tinysearch.cli import _get_retriever_index_dir + assert _get_retriever_index_dir(Path("index.faiss")) == Path("index") + assert _get_retriever_index_dir(Path("data/my.faiss")) == Path("data/my") + assert _get_retriever_index_dir(Path("mydir")) == Path("mydir") + + def test_build_hybrid_noop_for_template_engine(self): + """_build_hybrid_retriever_indexes should be a no-op for non-hybrid engines.""" + from tinysearch.cli import _build_hybrid_retriever_indexes + mock_engine = MagicMock() # not a HybridQueryEngine + # Should not raise + _build_hybrid_retriever_indexes(mock_engine, []) + + def test_save_load_hybrid_noop_for_template_engine(self): + from tinysearch.cli import _save_hybrid_retriever_indexes, _load_hybrid_retriever_indexes + mock_engine = MagicMock() + _save_hybrid_retriever_indexes(mock_engine, Path("test.faiss")) + _load_hybrid_retriever_indexes(mock_engine, Path("test.faiss")) + + def test_build_index_directory_per_file_source(self, tmp_path): + """CLI build_index on a directory should assign per-file source 
metadata.""" + import argparse + from tinysearch.cli import build_index + from tinysearch.config import Config + + # Create two files with the same content + (tmp_path / "file1.txt").write_text("Same content here") + (tmp_path / "file2.txt").write_text("Same content here") + + config = Config() + config.set("query_engine.method", "hybrid") + config.set("retrievers", [{"type": "bm25"}, {"type": "substring"}]) + config.set("fusion.strategy", "weighted") + config.set("fusion.weights", [0.5, 0.5]) + + # We just need to verify the metadata, not actually embed/index. + # Patch embedder and indexer to avoid needing a real model. + captured_chunks = [] + original_split = None + + from tinysearch.splitters import CharacterTextSplitter + orig_split = CharacterTextSplitter.split + + def spy_split(self, texts, metadata=None): + result = orig_split(self, texts, metadata) + captured_chunks.extend(result) + return result + + args = argparse.Namespace(data=str(tmp_path)) + + with patch.object(CharacterTextSplitter, "split", spy_split), \ + patch("tinysearch.cli.load_embedder") as mock_emb, \ + patch("tinysearch.cli.load_indexer") as mock_idx: + mock_emb_inst = MagicMock() + mock_emb_inst.embed.return_value = [[0.0] * 10] * 100 + mock_emb.return_value = mock_emb_inst + mock_idx_inst = MagicMock() + mock_idx.return_value = mock_idx_inst + + build_index(args, config) + + # Each chunk should have a per-file source, not the directory + sources = {c.metadata["source"] for c in captured_chunks} + assert len(sources) == 2, f"Expected 2 distinct sources, got {sources}" + assert all(str(tmp_path) != s for s in sources), \ + f"source should be per-file, not the directory: {sources}" + + +# โ”€โ”€ Retriever index save path โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestRetrieverIndexPath: + def test_flow_controller_uses_faiss_dir(self): + """FlowController should save retriever indexes inside the FAISS index directory.""" + from 
tinysearch.flow.controller import FlowController + fc = FlowController.__new__(FlowController) + fc.query_engine = MagicMock(spec=HybridQueryEngine) + fc.query_engine.metadata_index = None + fc.query_engine.soft_deleted_ids = set() + + mock_retriever = MagicMock() + mock_retriever.__class__.__name__ = "BM25Retriever" + fc.query_engine.retrievers = [mock_retriever] + + fc._save_retriever_indexes(Path("data/index.faiss")) + mock_retriever.save.assert_called_once_with(Path("data/index") / "bm25_index") + + def test_flow_controller_load_uses_faiss_dir(self, tmp_path): + from tinysearch.flow.controller import FlowController + fc = FlowController.__new__(FlowController) + fc.query_engine = MagicMock(spec=HybridQueryEngine) + fc.query_engine.metadata_index = None + fc.query_engine.soft_deleted_ids = set() + + mock_retriever = MagicMock() + mock_retriever.__class__.__name__ = "BM25Retriever" + fc.query_engine.retrievers = [mock_retriever] + + # Create a real directory structure for the test + index_dir = tmp_path / "index" + index_dir.mkdir() + bm25_dir = index_dir / "bm25_index" + bm25_dir.mkdir() + + fc._load_retriever_indexes(tmp_path / "index.faiss") + mock_retriever.load.assert_called_once_with(index_dir / "bm25_index") + + +# โ”€โ”€ iter_input_files โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestIterInputFiles: + def test_single_file(self, tmp_path): + f = tmp_path / "doc.txt" + f.write_text("hello") + assert list(iter_input_files(f)) == [f] + + def test_directory_filters_by_adapter_type(self, tmp_path): + (tmp_path / "a.txt").write_text("a") + (tmp_path / "b.pdf").write_text("b") + (tmp_path / "c.py").write_text("c") + files = list(iter_input_files(tmp_path, adapter_type="text")) + suffixes = {f.suffix for f in files} + assert ".txt" in suffixes + assert ".py" in suffixes + assert ".pdf" not in suffixes + + def test_custom_extensions_override(self, tmp_path): + (tmp_path / 
"a.xyz").write_text("a") + (tmp_path / "b.txt").write_text("b") + files = list(iter_input_files(tmp_path, extensions=[".xyz"])) + assert len(files) == 1 + assert files[0].suffix == ".xyz" + + def test_nonexistent_path_raises(self): + with pytest.raises(FileNotFoundError): + list(iter_input_files(Path("/nonexistent"))) + + def test_sorted_deterministic(self, tmp_path): + for name in ["c.txt", "a.txt", "b.txt"]: + (tmp_path / name).write_text(name) + files = list(iter_input_files(tmp_path, adapter_type="text")) + assert files == sorted(files) + + def test_recursive_finds_nested(self, tmp_path): + sub = tmp_path / "sub" + sub.mkdir() + (tmp_path / "top.txt").write_text("top") + (sub / "nested.txt").write_text("nested") + files = list(iter_input_files(tmp_path, adapter_type="text", recursive=True)) + assert len(files) == 2 + + def test_non_recursive_skips_nested(self, tmp_path): + sub = tmp_path / "sub" + sub.mkdir() + (tmp_path / "top.txt").write_text("top") + (sub / "nested.txt").write_text("nested") + files = list(iter_input_files(tmp_path, adapter_type="text", recursive=False)) + assert len(files) == 1 + + +# โ”€โ”€ Adapter directory rejection โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestAdapterRejectsDirectory: + def test_text_adapter_rejects_directory(self, tmp_path): + from tinysearch.adapters.text import TextAdapter + (tmp_path / "a.txt").write_text("hello") + with pytest.raises(ValueError, match="does not accept directories"): + TextAdapter().extract(tmp_path) + + def test_csv_adapter_rejects_directory(self, tmp_path): + from tinysearch.adapters.csv import CSVAdapter + (tmp_path / "a.csv").write_text("col\nval") + with pytest.raises(ValueError, match="does not accept directories"): + CSVAdapter().extract(tmp_path) + + def test_markdown_adapter_rejects_directory(self, tmp_path): + from tinysearch.adapters.markdown import MarkdownAdapter + (tmp_path / "a.md").write_text("# Hello") + with pytest.raises(ValueError, 
match="does not accept directories"): + MarkdownAdapter().extract(tmp_path) + + def test_json_adapter_rejects_directory(self, tmp_path): + from tinysearch.adapters.json_adapter import JSONAdapter + (tmp_path / "a.json").write_text('{"key": "val"}') + with pytest.raises(ValueError, match="does not accept directories"): + JSONAdapter().extract(tmp_path) + + +# โ”€โ”€ Source metadata consistency โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestSourceMetadataConsistency: + def test_cli_and_flowcontroller_same_sources(self, tmp_path): + """CLI build_index and FlowController should discover the same files.""" + (tmp_path / "file1.txt").write_text("Content A") + (tmp_path / "file2.txt").write_text("Content B") + (tmp_path / "ignore.pdf").write_text("PDF") + + cli_sources = {str(f) for f in iter_input_files(tmp_path, adapter_type="text")} + fc_sources = {str(f) for f in iter_input_files(tmp_path, adapter_type="text")} + + assert cli_sources == fc_sources + assert len(cli_sources) == 2 + assert all("ignore.pdf" not in s for s in cli_sources) + + +# โ”€โ”€ Pre-filter (MetadataIndex + candidate_ids) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +@pytest.fixture +def graded_chunks(): + """Chunks with grade metadata for pre-filter testing.""" + return [ + TextChunk("Python็ผ–็จ‹่ฏญ่จ€", {"source": "a.txt", "grade": "ๅ…ญๅนด็บงไธŠ", "chunk_index": 0}), + TextChunk("Java็ผ–็จ‹่ฏญ่จ€", {"source": "b.txt", "grade": "ๅ…ญๅนด็บงไธ‹", "chunk_index": 0}), + TextChunk("TinySearchๆฃ€็ดข็ณป็ปŸ", {"source": "c.txt", "grade": "ไธƒๅนด็บง", "chunk_index": 0}), + TextChunk("ๅ‘้‡ๆ•ฐๆฎๅบ“", {"source": "d.txt", "grade": "ๅ…ญๅนด็บงไธŠ", "chunk_index": 0}), + ] + + +@pytest.fixture +def prefilter_engine(graded_chunks): + """HybridQueryEngine with MetadataIndex for pre-filter testing.""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + substr = SubstringRetriever() + substr.build(graded_chunks) + + metadata_index = MetadataIndex() + 
metadata_index.build(graded_chunks) + + return HybridQueryEngine( + [bm25, substr], + WeightedFusion([0.6, 0.4]), + metadata_index=metadata_index, + filter_mode="auto", + ) + + +class TestPreFilter: + def test_bm25_candidate_ids(self, graded_chunks): + """BM25 with candidate_ids only returns results within the set.""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + # Only allow chunks 0 and 3 (grade=ๅ…ญๅนด็บงไธŠ) + results = bm25.retrieve("็ผ–็จ‹", top_k=5, candidate_ids={0, 3}) + for r in results: + assert r["metadata"]["grade"] == "ๅ…ญๅนด็บงไธŠ" + + def test_substring_candidate_ids(self, graded_chunks): + """SubstringRetriever with candidate_ids restricts search.""" + substr = SubstringRetriever() + substr.build(graded_chunks) + results = substr.retrieve("็ผ–็จ‹", top_k=5, candidate_ids={0}) + assert len(results) == 1 + assert results[0]["metadata"]["source"] == "a.txt" + + def test_candidate_ids_none_is_noop(self, graded_chunks): + """Passing candidate_ids=None returns same as no kwargs.""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + r1 = bm25.retrieve("็ผ–็จ‹", top_k=5) + r2 = bm25.retrieve("็ผ–็จ‹", top_k=5, candidate_ids=None) + assert len(r1) == len(r2) + + def test_candidate_ids_empty_returns_empty(self, graded_chunks): + """Passing candidate_ids=set() returns [].""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + results = bm25.retrieve("็ผ–็จ‹", top_k=5, candidate_ids=set()) + assert results == [] + + def test_filter_mode_pre(self, graded_chunks): + """Pre-filter mode resolves filters via MetadataIndex.""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + substr = SubstringRetriever() + substr.build(graded_chunks) + + metadata_index = MetadataIndex() + metadata_index.build(graded_chunks) + + engine = HybridQueryEngine( + [bm25, substr], + WeightedFusion([0.6, 0.4]), + metadata_index=metadata_index, + filter_mode="pre", + ) + results = engine.retrieve("็ผ–็จ‹", top_k=5, filters={"grade": "ๅ…ญๅนด็บงไธŠ"}) + assert 
all(r["metadata"]["grade"] == "ๅ…ญๅนด็บงไธŠ" for r in results) + + def test_filter_mode_auto_mixed(self, prefilter_engine): + """Auto mode: indexable filters pre-filter, callable filters post-filter.""" + results = prefilter_engine.retrieve( + "็ผ–็จ‹", top_k=5, + filters={ + "grade": "ๅ…ญๅนด็บงไธŠ", + "source": lambda v: v == "a.txt", # callable โ†’ post-filter + }, + ) + assert all(r["metadata"]["source"] == "a.txt" for r in results) + + def test_no_metadata_index_backward_compat(self, graded_chunks): + """When metadata_index=None, filters are purely post-applied.""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + substr = SubstringRetriever() + substr.build(graded_chunks) + + engine = HybridQueryEngine( + [bm25, substr], + WeightedFusion([0.6, 0.4]), + metadata_index=None, + ) + results = engine.retrieve("็ผ–็จ‹", top_k=5, filters={"grade": "ๅ…ญๅนด็บงไธŠ"}) + assert all(r["metadata"]["grade"] == "ๅ…ญๅนด็บงไธŠ" for r in results) + + def test_empty_candidate_short_circuits(self, graded_chunks): + """When lookup returns empty set, pipeline returns [] immediately.""" + bm25 = BM25Retriever() + bm25.build(graded_chunks) + + metadata_index = MetadataIndex() + metadata_index.build(graded_chunks) + + engine = HybridQueryEngine( + [bm25], + WeightedFusion(), + metadata_index=metadata_index, + filter_mode="pre", + ) + results = engine.retrieve("็ผ–็จ‹", top_k=5, filters={"grade": "ไธๅญ˜ๅœจ"}) + assert results == [] diff --git a/tests/test_incremental.py b/tests/test_incremental.py new file mode 100644 index 0000000..16c6939 --- /dev/null +++ b/tests/test_incremental.py @@ -0,0 +1,323 @@ +""" +Tests for incremental indexing: ContentHashTracker, MetadataIndex.add_chunks, +soft delete, and FlowController.build_incremental. 
+""" +import pytest +from unittest.mock import MagicMock, patch +from typing import Any, Dict + +from tinysearch.base import RecordAdapter, TextChunk +from tinysearch.indexers.hash_tracker import ContentHashTracker, ChangeSet +from tinysearch.indexers.metadata_index import MetadataIndex +from tinysearch.query.hybrid import HybridQueryEngine +from tinysearch.retrievers.bm25_retriever import BM25Retriever +from tinysearch.retrievers.substring_retriever import SubstringRetriever +from tinysearch.fusion.weighted import WeightedFusion + + +class SimpleAdapter(RecordAdapter): + def to_chunk(self, record_id: str, record: Dict[str, Any]) -> TextChunk: + return TextChunk( + text=record.get("text", ""), + metadata={"record_id": record_id, "grade": record.get("grade", "")}, + ) + + +# โ”€โ”€ ContentHashTracker โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestContentHashTracker: + def test_compute_hash_deterministic(self): + tracker = ContentHashTracker() + h1 = tracker.compute_hash("hello", {"k": "v"}) + h2 = tracker.compute_hash("hello", {"k": "v"}) + assert h1 == h2 + + def test_compute_hash_changes_on_text(self): + tracker = ContentHashTracker() + h1 = tracker.compute_hash("hello", {"k": "v"}) + h2 = tracker.compute_hash("world", {"k": "v"}) + assert h1 != h2 + + def test_compute_hash_changes_on_metadata(self): + tracker = ContentHashTracker() + h1 = tracker.compute_hash("hello", {"grade": "ๅ…ญๅนด็บง"}) + h2 = tracker.compute_hash("hello", {"grade": "ไธƒๅนด็บง"}) + assert h1 != h2 + + def test_compute_hash_ignores_internal_keys(self): + tracker = ContentHashTracker() + h1 = tracker.compute_hash("hello", {"grade": "ๅ…ญๅนด็บง", "chunk_index": 0}) + h2 = tracker.compute_hash("hello", {"grade": "ๅ…ญๅนด็บง", "chunk_index": 5}) + assert h1 == h2 + + def test_compute_hash_custom_metadata_keys(self): + tracker = ContentHashTracker(hash_metadata_keys=["grade"]) + h1 = tracker.compute_hash("hello", {"grade": "ๅ…ญๅนด็บง", "type": "A"}) + 
h2 = tracker.compute_hash("hello", {"grade": "ๅ…ญๅนด็บง", "type": "B"}) + assert h1 == h2 # "type" is not in hash_metadata_keys + + def test_detect_all_new(self): + tracker = ContentHashTracker() + current = { + "q1": TextChunk("hello", {"record_id": "q1"}), + "q2": TextChunk("world", {"record_id": "q2"}), + } + changes = tracker.detect_changes(current) + assert len(changes.new) == 2 + assert len(changes.modified) == 0 + assert len(changes.deleted) == 0 + + def test_detect_no_changes(self): + tracker = ContentHashTracker() + current = {"q1": TextChunk("hello", {"record_id": "q1"})} + tracker.update(current) + changes = tracker.detect_changes(current) + assert not changes.has_changes + assert len(changes.unchanged) == 1 + + def test_detect_modified(self): + tracker = ContentHashTracker() + original = {"q1": TextChunk("hello", {"record_id": "q1"})} + tracker.update(original) + modified = {"q1": TextChunk("changed", {"record_id": "q1"})} + changes = tracker.detect_changes(modified) + assert len(changes.modified) == 1 + assert changes.modified[0] == "q1" + + def test_detect_deleted(self): + tracker = ContentHashTracker() + original = {"q1": TextChunk("a", {}), "q2": TextChunk("b", {})} + tracker.update(original) + current = {"q1": TextChunk("a", {})} + changes = tracker.detect_changes(current) + assert changes.deleted == {"q2"} + + def test_detect_mixed(self): + tracker = ContentHashTracker() + original = { + "q1": TextChunk("unchanged", {}), + "q2": TextChunk("will_modify", {}), + "q3": TextChunk("will_delete", {}), + } + tracker.update(original) + current = { + "q1": TextChunk("unchanged", {}), + "q2": TextChunk("modified_text", {}), + "q4": TextChunk("brand_new", {}), + } + changes = tracker.detect_changes(current) + assert set(changes.new) == {"q4"} + assert set(changes.modified) == {"q2"} + assert changes.deleted == {"q3"} + assert changes.unchanged == {"q1"} + + def test_update_and_remove(self): + tracker = ContentHashTracker() + records = {"q1": 
TextChunk("hello", {})} + tracker.update(records) + assert tracker.tracked_count == 1 + tracker.remove({"q1"}) + assert tracker.tracked_count == 0 + + def test_save_load_roundtrip(self, tmp_path): + tracker = ContentHashTracker(hash_metadata_keys=["grade"]) + records = { + "q1": TextChunk("hello", {"grade": "ๅ…ญๅนด็บง"}), + "q2": TextChunk("world", {"grade": "ไธƒๅนด็บง"}), + } + tracker.update(records) + + path = tmp_path / "hashes.json" + tracker.save(path) + + loaded = ContentHashTracker() + loaded.load(path) + assert loaded.tracked_count == 2 + # Should detect no changes + changes = loaded.detect_changes(records) + assert not changes.has_changes + + +class TestChangeSet: + def test_has_changes_true(self): + cs = ChangeSet(new=["a"], modified=[], deleted=set(), unchanged=set()) + assert cs.has_changes + + def test_has_changes_false(self): + cs = ChangeSet(new=[], modified=[], deleted=set(), unchanged={"a"}) + assert not cs.has_changes + + def test_repr(self): + cs = ChangeSet(new=["a"], modified=["b"], deleted={"c"}, unchanged={"d"}) + assert "new=1" in repr(cs) + assert "modified=1" in repr(cs) + + +# โ”€โ”€ MetadataIndex.add_chunks โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestMetadataIndexAddChunks: + def test_add_increases_total(self): + idx = MetadataIndex() + idx.build([TextChunk("a", {"grade": "ๅ…ญๅนด็บง"})]) + assert idx.total_chunks == 1 + idx.add_chunks([TextChunk("b", {"grade": "ไธƒๅนด็บง"})], start_id=1) + assert idx.total_chunks == 2 + + def test_add_chunks_findable(self): + idx = MetadataIndex() + idx.build([TextChunk("a", {"grade": "ๅ…ญๅนด็บง"})]) + idx.add_chunks([TextChunk("b", {"grade": "ไธƒๅนด็บง"})], start_id=1) + assert idx.lookup({"grade": "ไธƒๅนด็บง"}) == {1} + assert idx.lookup({"grade": "ๅ…ญๅนด็บง"}) == {0} + + def test_add_preserves_existing(self): + idx = MetadataIndex() + idx.build([TextChunk("a", {"grade": "ๅ…ญๅนด็บง"})]) + original = idx.lookup({"grade": "ๅ…ญๅนด็บง"}) + idx.add_chunks([TextChunk("b", 
{"grade": "ไธƒๅนด็บง"})], start_id=1) + assert idx.lookup({"grade": "ๅ…ญๅนด็บง"}) == original + + def test_add_chunks_list_metadata(self): + idx = MetadataIndex() + idx.build([]) + idx.add_chunks([TextChunk("a", {"tags": ["ๅŠจ่ฏ", "ๆ—ถๆ€"]})], start_id=0) + assert idx.lookup({"tags": "ๅŠจ่ฏ"}) == {0} + assert idx.lookup({"tags": "ๆ—ถๆ€"}) == {0} + + +# โ”€โ”€ Soft Delete โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestSoftDelete: + def _make_engine(self, chunks, soft_deleted_ids=None): + bm25 = BM25Retriever() + bm25.build(chunks) + return HybridQueryEngine( + [bm25], + WeightedFusion(), + soft_deleted_ids=soft_deleted_ids, + ) + + def test_soft_delete_filters_results(self): + chunks = [ + TextChunk("Python็ผ–็จ‹", {"record_id": "q1"}), + TextChunk("Java็ผ–็จ‹", {"record_id": "q2"}), + ] + engine = self._make_engine(chunks, soft_deleted_ids={"q1"}) + results = engine.retrieve("็ผ–็จ‹", top_k=5) + record_ids = {r["metadata"]["record_id"] for r in results} + assert "q1" not in record_ids + assert "q2" in record_ids + + def test_add_and_clear_soft_deletes(self): + chunks = [TextChunk("test", {"record_id": "q1"})] + engine = self._make_engine(chunks) + assert engine.soft_delete_count == 0 + engine.add_soft_deletes({"q1"}) + assert engine.soft_delete_count == 1 + engine.clear_soft_deletes() + assert engine.soft_delete_count == 0 + + def test_backward_compat_no_soft_deletes(self): + """Default None doesn't break existing behavior.""" + chunks = [TextChunk("Python็ผ–็จ‹", {"record_id": "q1"})] + engine = self._make_engine(chunks) + results = engine.retrieve("็ผ–็จ‹", top_k=5) + assert len(results) > 0 + + +# โ”€โ”€ build_incremental โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestBuildIncremental: + def _make_fc(self): + """Create a FlowController with mocked embedder/indexer for testing.""" + from tinysearch.flow.controller import FlowController + + 
mock_embedder = MagicMock() + mock_embedder.embed.return_value = [[0.1] * 10] + + mock_indexer = MagicMock() + + bm25 = BM25Retriever() + substr = SubstringRetriever() + metadata_index = MetadataIndex() + + engine = HybridQueryEngine( + [bm25, substr], + WeightedFusion([0.6, 0.4]), + metadata_index=metadata_index, + ) + + fc = FlowController( + data_adapter=None, + text_splitter=MagicMock(), + embedder=mock_embedder, + indexer=mock_indexer, + query_engine=engine, + config={}, + ) + return fc + + def test_no_changes_returns_early(self): + fc = self._make_fc() + adapter = SimpleAdapter() + tracker = ContentHashTracker() + + records = {"q1": {"text": "hello", "grade": "ๅ…ญๅนด็บง"}} + # First build + tracker.update({rid: adapter.to_chunk(rid, r) for rid, r in records.items()}) + + # Second call โ€” no changes + stats = fc.build_incremental(records, adapter, tracker) + assert not stats["full_rebuild"] + assert stats["new"] == 0 + assert stats["modified"] == 0 + + def test_new_records_detected(self): + fc = self._make_fc() + adapter = SimpleAdapter() + tracker = ContentHashTracker() + + records = {"q1": {"text": "hello", "grade": "ๅ…ญๅนด็บง"}} + stats = fc.build_incremental(records, adapter, tracker) + assert stats["new"] == 1 + assert stats["deleted"] == 0 + + def test_threshold_triggers_full_rebuild(self): + fc = self._make_fc() + adapter = SimpleAdapter() + tracker = ContentHashTracker() + + # Build initial 150 records + initial = {f"q{i}": {"text": f"text{i}"} for i in range(150)} + tracker.update({rid: adapter.to_chunk(rid, r) for rid, r in initial.items()}) + + # Delete all of them (150 > threshold 100) + stats = fc.build_incremental({}, adapter, tracker, delete_rebuild_threshold=100) + assert stats["full_rebuild"] is True + assert stats["deleted"] == 150 + + def test_stats_correct(self): + fc = self._make_fc() + adapter = SimpleAdapter() + tracker = ContentHashTracker() + + # Initial build + initial = { + "q1": {"text": "unchanged"}, + "q2": {"text": 
"will_modify"}, + "q3": {"text": "will_delete"}, + } + tracker.update({rid: adapter.to_chunk(rid, r) for rid, r in initial.items()}) + + # Incremental update + current = { + "q1": {"text": "unchanged"}, + "q2": {"text": "modified_text"}, + "q4": {"text": "brand_new"}, + } + stats = fc.build_incremental(current, adapter, tracker) + assert stats["new"] == 1 + assert stats["modified"] == 1 + assert stats["deleted"] == 1 + assert stats["unchanged"] == 1 + assert not stats["full_rebuild"] diff --git a/tests/test_metadata_index.py b/tests/test_metadata_index.py new file mode 100644 index 0000000..ac12cdb --- /dev/null +++ b/tests/test_metadata_index.py @@ -0,0 +1,153 @@ +""" +Tests for MetadataIndex: inverted index over TextChunk metadata. +""" +import pytest +from pathlib import Path + +from tinysearch.base import TextChunk +from tinysearch.indexers.metadata_index import MetadataIndex + + +@pytest.fixture +def sample_chunks(): + return [ + TextChunk("Q1", {"grade": "ๅ…ญๅนด็บงไธŠ", "type": "้€‰ๆ‹ฉ้ข˜", "difficulty": 1}), + TextChunk("Q2", {"grade": "ๅ…ญๅนด็บงไธ‹", "type": "ๅกซ็ฉบ้ข˜", "difficulty": 2}), + TextChunk("Q3", {"grade": "ๅ…ญๅนด็บงไธŠ", "type": "ๅกซ็ฉบ้ข˜", "difficulty": 3}), + TextChunk("Q4", {"grade": "ไธƒๅนด็บง", "tags": ["ๅŠจ่ฏ", "ๆ—ถๆ€"], "difficulty": 1}), + TextChunk("Q5", {"grade": "ไธƒๅนด็บง", "tags": ["ๅ่ฏ", "ๅคๆ•ฐ"]}), + ] + + +@pytest.fixture +def built_index(sample_chunks): + idx = MetadataIndex() + idx.build(sample_chunks) + return idx + + +# โ”€โ”€ Build & Lookup โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestBuildAndLookup: + def test_build_basic(self, built_index): + assert built_index.total_chunks == 5 + assert "grade" in built_index.fields + assert "type" in built_index.fields + assert "tags" in built_index.fields + + def test_lookup_exact_str(self, built_index): + result = built_index.lookup({"grade": "ๅ…ญๅนด็บงไธŠ"}) + assert result == {0, 2} + + def test_lookup_exact_int(self, 
built_index): + result = built_index.lookup({"difficulty": 1}) + assert result == {0, 3} + + def test_lookup_list_filter(self, built_index): + """List filter = OR over values.""" + result = built_index.lookup({"grade": ["ๅ…ญๅนด็บงไธŠ", "ๅ…ญๅนด็บงไธ‹"]}) + assert result == {0, 1, 2} + + def test_lookup_multiple_filters_intersect(self, built_index): + """Multiple filters = AND (set intersection).""" + result = built_index.lookup({"grade": "ๅ…ญๅนด็บงไธŠ", "type": "ๅกซ็ฉบ้ข˜"}) + assert result == {2} + + def test_lookup_callable_returns_none(self, built_index): + result = built_index.lookup({"grade": lambda v: True}) + assert result is None + + def test_lookup_nonexistent_field(self, built_index): + result = built_index.lookup({"nonexistent": "value"}) + assert result == set() + + def test_lookup_nonexistent_value(self, built_index): + result = built_index.lookup({"grade": "ไธๅญ˜ๅœจ็š„ๅนด็บง"}) + assert result == set() + + def test_list_metadata_indexed_per_element(self, built_index): + """List[str] metadata: each element indexed separately.""" + assert built_index.lookup({"tags": "ๅŠจ่ฏ"}) == {3} + assert built_index.lookup({"tags": "ๅ่ฏ"}) == {4} + assert built_index.lookup({"tags": ["ๅŠจ่ฏ", "ๅ่ฏ"]}) == {3, 4} + + def test_empty_filters_returns_none(self, built_index): + assert built_index.lookup({}) is None + + def test_empty_intersection_short_circuits(self, built_index): + """When first filter yields empty, result is empty.""" + result = built_index.lookup({"grade": "ไธๅญ˜ๅœจ", "type": "้€‰ๆ‹ฉ้ข˜"}) + assert result == set() + + +# โ”€โ”€ Classify Filters โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestClassifyFilters: + def test_all_indexable(self, built_index): + indexable, callables = built_index.classify_filters({ + "grade": "ๅ…ญๅนด็บงไธŠ", + "type": ["้€‰ๆ‹ฉ้ข˜", "ๅกซ็ฉบ้ข˜"], + }) + assert len(indexable) == 2 + assert len(callables) == 0 + + def test_all_callable(self, built_index): + indexable, 
callables = built_index.classify_filters({ + "grade": lambda v: "ๅ…ญ" in v, + }) + assert len(indexable) == 0 + assert len(callables) == 1 + + def test_mixed(self, built_index): + indexable, callables = built_index.classify_filters({ + "grade": "ๅ…ญๅนด็บงไธŠ", + "difficulty": lambda v: v > 2, + }) + assert "grade" in indexable + assert "difficulty" in callables + + +# โ”€โ”€ Save / Load โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestSaveLoad: + def test_roundtrip(self, built_index, tmp_path): + save_path = tmp_path / "metadata_index.json" + built_index.save(save_path) + + loaded = MetadataIndex() + loaded.load(save_path) + + assert loaded.total_chunks == built_index.total_chunks + assert loaded.lookup({"grade": "ๅ…ญๅนด็บงไธŠ"}) == {0, 2} + assert loaded.lookup({"difficulty": 1}) == {0, 3} + assert loaded.lookup({"tags": "ๅŠจ่ฏ"}) == {3} + + def test_save_creates_parent_dirs(self, built_index, tmp_path): + save_path = tmp_path / "deep" / "nested" / "idx.json" + built_index.save(save_path) + assert save_path.exists() + + +# โ”€โ”€ Edge Cases โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestEdgeCases: + def test_empty_chunks(self): + idx = MetadataIndex() + idx.build([]) + assert idx.total_chunks == 0 + assert idx.lookup({"key": "val"}) == set() + + def test_chunks_with_no_metadata(self): + idx = MetadataIndex() + idx.build([TextChunk("text", None), TextChunk("text2", {})]) + assert idx.total_chunks == 2 + assert idx.lookup({"any": "thing"}) == set() + + def test_bool_metadata(self): + idx = MetadataIndex() + idx.build([ + TextChunk("a", {"active": True}), + TextChunk("b", {"active": False}), + ]) + assert idx.lookup({"active": True}) == {0} + assert idx.lookup({"active": False}) == {1} diff --git a/tests/test_record_adapter.py b/tests/test_record_adapter.py new file mode 100644 index 0000000..1bf74ae --- /dev/null +++ 
b/tests/test_record_adapter.py @@ -0,0 +1,87 @@ +""" +Tests for RecordAdapter and build_chunks_from_records. +""" +import pytest +from typing import Any, Dict + +from tinysearch.base import RecordAdapter, TextChunk, TextSplitter +from tinysearch.records import build_chunks_from_records + + +class SimpleAdapter(RecordAdapter): + """Test adapter: extracts 'text' field, puts rest in metadata.""" + + def to_chunk(self, record_id: str, record: Dict[str, Any]) -> TextChunk: + return TextChunk( + text=record.get("text", ""), + metadata={"record_id": record_id, "grade": record.get("grade", "")}, + ) + + +class TestRecordAdapterABC: + def test_abc_cannot_instantiate(self): + with pytest.raises(TypeError): + RecordAdapter() + + def test_concrete_implementation(self): + adapter = SimpleAdapter() + chunk = adapter.to_chunk("q1", {"text": "hello", "grade": "ๅ…ญๅนด็บง"}) + assert isinstance(chunk, TextChunk) + assert chunk.text == "hello" + assert chunk.metadata["record_id"] == "q1" + assert chunk.metadata["grade"] == "ๅ…ญๅนด็บง" + + +class TestBuildChunksFromRecords: + def test_basic_no_splitter(self): + adapter = SimpleAdapter() + records = { + "q1": {"text": "Python็ผ–็จ‹", "grade": "ๅ…ญๅนด็บง"}, + "q2": {"text": "Java็ผ–็จ‹", "grade": "ไธƒๅนด็บง"}, + "q3": {"text": "TinySearch", "grade": "ๅ…ซๅนด็บง"}, + } + chunks = build_chunks_from_records(records, adapter) + assert len(chunks) == 3 + assert all(isinstance(c, TextChunk) for c in chunks) + assert [c.metadata["record_id"] for c in chunks] == ["q1", "q2", "q3"] + + def test_record_id_guaranteed_in_metadata(self): + """Even if adapter doesn't set record_id, it's auto-added.""" + + class NoIdAdapter(RecordAdapter): + def to_chunk(self, rid, record): + return TextChunk(text=record["text"], metadata={"grade": "test"}) + + chunks = build_chunks_from_records( + {"q1": {"text": "hello"}}, NoIdAdapter() + ) + assert chunks[0].metadata["record_id"] == "q1" + + def test_with_splitter(self): + """Long text gets split; metadata 
inherited.""" + from tinysearch.splitters import CharacterTextSplitter + + adapter = SimpleAdapter() + # Create a record with text long enough to be split + long_text = "A" * 500 + " " + "B" * 500 + records = {"q1": {"text": long_text, "grade": "ๅ…ญๅนด็บง"}} + + splitter = CharacterTextSplitter(chunk_size=200, chunk_overlap=0) + chunks = build_chunks_from_records(records, adapter, splitter=splitter) + + assert len(chunks) > 1 + # All sub-chunks inherit the original metadata + for c in chunks: + assert c.metadata["record_id"] == "q1" + assert c.metadata["grade"] == "ๅ…ญๅนด็บง" + assert "chunk_index" in c.metadata + + def test_empty_records(self): + chunks = build_chunks_from_records({}, SimpleAdapter()) + assert chunks == [] + + def test_order_preserved(self): + adapter = SimpleAdapter() + records = {"c": {"text": "C"}, "a": {"text": "A"}, "b": {"text": "B"}} + chunks = build_chunks_from_records(records, adapter) + assert [c.text for c in chunks] == ["C", "A", "B"] diff --git a/tinysearch/__init__.py b/tinysearch/__init__.py index d2a4c08..9a61b24 100644 --- a/tinysearch/__init__.py +++ b/tinysearch/__init__.py @@ -1,13 +1,16 @@ """ -TinySearch: A lightweight vector retrieval system +TinySearch: A lightweight hybrid retrieval system """ -__version__ = "0.1.0" +__version__ = "0.2.0" # Make submodules available for import from . import adapters from . import splitters -from . import embedders +from . import embedders from . import indexers from . import query -from . import flow \ No newline at end of file +from . import flow +from . import retrievers +from . import fusion +from . 
import rerankers \ No newline at end of file diff --git a/tinysearch/adapters/csv.py b/tinysearch/adapters/csv.py index 815bf1c..21757ce 100644 --- a/tinysearch/adapters/csv.py +++ b/tinysearch/adapters/csv.py @@ -55,20 +55,13 @@ def extract(self, filepath: Union[str, Path]) -> List[str]: raise FileNotFoundError(f"File not found: {filepath}") if filepath.is_dir(): - # If a directory is provided, process all CSV files in it - csv_files = list(filepath.glob("**/*.csv")) - - result = [] - for file in csv_files: - try: - result.extend(self._extract_from_csv(file)) - except Exception as e: - print(f"Error reading {file}: {e}") - - return result - else: - # Process a single file - return self._extract_from_csv(filepath) + raise ValueError( + f"CSVAdapter.extract() does not accept directories. " + "Use iter_input_files() to iterate files, then call extract() on each." + ) + + # Process a single file + return self._extract_from_csv(filepath) def _extract_from_csv(self, filepath: Path) -> List[str]: """ diff --git a/tinysearch/adapters/json_adapter.py b/tinysearch/adapters/json_adapter.py index 1f06293..12e504c 100644 --- a/tinysearch/adapters/json_adapter.py +++ b/tinysearch/adapters/json_adapter.py @@ -54,20 +54,13 @@ def extract(self, filepath: Union[str, Path]) -> List[str]: raise FileNotFoundError(f"File not found: {filepath}") if filepath.is_dir(): - # If a directory is provided, process all JSON files in it - json_files = list(filepath.glob("**/*.json")) - - result = [] - for file in json_files: - try: - result.extend(self._extract_from_json(file)) - except Exception as e: - print(f"Error reading {file}: {e}") - - return result - else: - # Process a single file - return self._extract_from_json(filepath) + raise ValueError( + f"JSONAdapter.extract() does not accept directories. " + "Use iter_input_files() to iterate files, then call extract() on each." 
+ ) + + # Process a single file + return self._extract_from_json(filepath) def _extract_from_json(self, filepath: Path) -> List[str]: """ diff --git a/tinysearch/adapters/markdown.py b/tinysearch/adapters/markdown.py index fd42856..1e97a2b 100644 --- a/tinysearch/adapters/markdown.py +++ b/tinysearch/adapters/markdown.py @@ -55,22 +55,13 @@ def extract(self, filepath: Union[str, Path]) -> List[str]: raise FileNotFoundError(f"File not found: {filepath}") if filepath.is_dir(): - # If a directory is provided, process all Markdown files in it - md_files = [] - for ext in [".md", ".markdown", ".mdown", ".mkdn"]: - md_files.extend(filepath.glob(f"**/*{ext}")) - - result = [] - for file in md_files: - try: - result.extend(self._extract_from_markdown(file)) - except Exception as e: - print(f"Error reading {file}: {e}") - - return result - else: - # Process a single file - return self._extract_from_markdown(filepath) + raise ValueError( + f"MarkdownAdapter.extract() does not accept directories. " + "Use iter_input_files() to iterate files, then call extract() on each." + ) + + # Process a single file + return self._extract_from_markdown(filepath) def _extract_from_markdown(self, filepath: Path) -> List[str]: """ diff --git a/tinysearch/adapters/pdf.py b/tinysearch/adapters/pdf.py index be90bc4..1d35403 100644 --- a/tinysearch/adapters/pdf.py +++ b/tinysearch/adapters/pdf.py @@ -56,20 +56,13 @@ def extract(self, filepath: Union[str, Path]) -> List[str]: raise FileNotFoundError(f"File not found: {filepath}") if filepath.is_dir(): - # If a directory is provided, process all PDF files in it - pdf_files = list(filepath.glob("**/*.pdf")) - - result = [] - for file in pdf_files: - try: - result.extend(self._extract_from_pdf(file)) - except Exception as e: - print(f"Error reading {file}: {e}") - - return result - else: - # Process a single file - return self._extract_from_pdf(filepath) + raise ValueError( + f"PDFAdapter.extract() does not accept directories. 
" + "Use iter_input_files() to iterate files, then call extract() on each." + ) + + # Process a single file + return self._extract_from_pdf(filepath) def _extract_from_pdf(self, filepath: Path) -> List[str]: """ diff --git a/tinysearch/adapters/text.py b/tinysearch/adapters/text.py index bb085c6..d484a39 100644 --- a/tinysearch/adapters/text.py +++ b/tinysearch/adapters/text.py @@ -39,23 +39,13 @@ def extract(self, filepath: Union[str, Path]) -> List[str]: raise FileNotFoundError(f"File not found: {filepath}") if filepath.is_dir(): - # If a directory is provided, process all text files in it - text_files = [] - for ext in [".txt", ".text", ".md", ".py", ".js", ".html", ".css", ".json"]: - text_files.extend(filepath.glob(f"**/*{ext}")) - - result = [] - for file in text_files: - try: - with open(file, "r", encoding=self.encoding, errors=self.errors) as f: - result.append(f.read()) - except Exception as e: - print(f"Error reading {file}: {e}") - - return result - else: - # Process a single file - with open(filepath, "r", encoding=self.encoding, errors=self.errors) as f: - content = f.read() - - return [content] \ No newline at end of file + raise ValueError( + f"TextAdapter.extract() does not accept directories. " + "Use iter_input_files() to iterate files, then call extract() on each." 
+ ) + + # Process a single file + with open(filepath, "r", encoding=self.encoding, errors=self.errors) as f: + content = f.read() + + return [content] \ No newline at end of file diff --git a/tinysearch/api.py b/tinysearch/api.py index f1cd75b..276583a 100644 --- a/tinysearch/api.py +++ b/tinysearch/api.py @@ -19,7 +19,10 @@ import tempfile from .config import Config -from .cli import load_embedder, load_indexer, load_query_engine, load_adapter, load_splitter +from .cli import ( + load_embedder, load_indexer, load_query_engine, load_adapter, load_splitter, + _load_hybrid_retriever_indexes, +) from .flow.controller import FlowController from .logger import get_logger, configure_logger, log_success, log_warning, log_error @@ -177,6 +180,8 @@ def initialize_components(): logger = get_logger("api_init") if Path(index_path).exists(): indexer.load(index_path) + # Load non-vector retriever indexes (BM25, Substring) for hybrid mode + _load_hybrid_retriever_indexes(query_engine, Path(index_path), logger) log_success(f"TinySearch initialized with index: {index_path}") else: log_warning(f"Index not found: {index_path}") @@ -409,8 +414,11 @@ async def query( raise HTTPException(status_code=500, detail="TinySearch not initialized") try: - # Execute query - results = query_engine.retrieve(query_request.query, top_k=query_request.top_k) + # Execute query โ€” forward params (filters, weights, etc.) to the engine + extra_params = query_request.params or {} + results = query_engine.retrieve( + query_request.query, top_k=query_request.top_k, **extra_params + ) # Format results matches = [] diff --git a/tinysearch/base.py b/tinysearch/base.py index cd91871..676103c 100644 --- a/tinysearch/base.py +++ b/tinysearch/base.py @@ -26,6 +26,32 @@ def extract(self, filepath: Union[str, pathlib.Path]) -> List[str]: pass +class RecordAdapter(ABC): + """ + Interface for adapters that convert structured records (dicts) to TextChunks. 
+ Unlike DataAdapter (file-oriented), RecordAdapter works with in-memory data + from APIs, databases, or other programmatic sources. + """ + + @abstractmethod + def to_chunk(self, record_id: str, record: Dict[str, Any]) -> "TextChunk": + """ + Convert one record to a TextChunk with text and metadata. + + The implementation should: + - Extract/compose the searchable text from the record fields + - Build metadata dict including at minimum {"record_id": record_id} + + Args: + record_id: Unique identifier for the record + record: Record data as a dictionary + + Returns: + TextChunk with text and metadata derived from the record + """ + pass + + class TextChunk: """ Represents a chunk of text with optional metadata @@ -152,20 +178,105 @@ def format_query(self, query: str) -> str: pass @abstractmethod - def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + def retrieve(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]: """ Retrieve relevant chunks for a query - + Args: query: Query string top_k: Number of results to return - + **kwargs: Engine-specific parameters (e.g. filters, weights) + Returns: List of dictionaries containing text chunks and similarity scores """ pass +class Retriever(ABC): + """ + Interface for text-level retrievers that search by query string. + Unlike VectorIndexer (which takes pre-embedded vectors), Retriever + operates on raw text queries, making it suitable for BM25, substring, + and wrapped vector search. 
+ """ + + @abstractmethod + def build(self, chunks: List[TextChunk]) -> None: + """ + Build the retriever index from text chunks + + Args: + chunks: List of TextChunk objects to index + """ + pass + + @abstractmethod + def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + """ + Retrieve relevant chunks for a text query + + Args: + query: Raw query string + top_k: Number of results to return + + Returns: + List of dicts with keys: text, metadata, score, retrieval_method + """ + pass + + @abstractmethod + def save(self, path: Union[str, pathlib.Path]) -> None: + """Save the retriever index to disk""" + pass + + @abstractmethod + def load(self, path: Union[str, pathlib.Path]) -> None: + """Load the retriever index from disk""" + pass + + +class FusionStrategy(ABC): + """ + Interface for fusing results from multiple retrievers into a single ranked list. + """ + + @abstractmethod + def fuse(self, results_list: List[List[Dict[str, Any]]], **kwargs) -> List[Dict[str, Any]]: + """ + Fuse multiple result lists into a single ranked list + + Args: + results_list: List of result lists from different retrievers + **kwargs: Strategy-specific parameters + + Returns: + Fused and ranked list of results + """ + pass + + +class Reranker(ABC): + """ + Interface for rerankers that re-score candidates using a cross-encoder or similar model. 
+ """ + + @abstractmethod + def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]: + """ + Re-rank candidates for a query + + Args: + query: Query string + candidates: List of candidate results to re-rank + top_k: Number of results to return + + Returns: + Re-ranked list of results + """ + pass + + class FlowController(ABC): """ Interface for flow controllers that orchestrate the data pipeline diff --git a/tinysearch/cli.py b/tinysearch/cli.py index 2c126bb..d0fc959 100644 --- a/tinysearch/cli.py +++ b/tinysearch/cli.py @@ -10,13 +10,24 @@ from typing import Dict, Any, List, Optional, Union, Type, cast from .config import Config -from .base import DataAdapter, TextSplitter, Embedder, VectorIndexer, QueryEngine +from .base import ( + DataAdapter, TextSplitter, Embedder, VectorIndexer, QueryEngine, + Retriever, FusionStrategy, Reranker, +) from .adapters import TextAdapter, PDFAdapter, CSVAdapter, MarkdownAdapter, JSONAdapter +from tinysearch.utils.file_discovery import iter_input_files from .splitters import CharacterTextSplitter from .embedders import HuggingFaceEmbedder # ็›ดๆŽฅไปŽๆจกๅ—ๅฏผๅ…ฅ from .indexers.faiss_indexer import FAISSIndexer from .query.template import TemplateQueryEngine +from .query.hybrid import HybridQueryEngine +from .retrievers.vector_retriever import VectorRetriever +from .retrievers.bm25_retriever import BM25Retriever +from .retrievers.substring_retriever import SubstringRetriever +from .fusion.rrf import ReciprocalRankFusion +from .fusion.weighted import WeightedFusion +from .rerankers.cross_encoder import CrossEncoderReranker from .logger import get_logger, configure_logger, log_step, log_progress, log_success, log_error @@ -137,30 +148,225 @@ def load_indexer(config: Config) -> FAISSIndexer: raise ValueError(f"Unsupported indexer type: {indexer_type}") +def load_retriever(config: Config, retriever_config: Dict[str, Any], embedder: Embedder, indexer: FAISSIndexer) -> Retriever: + """ + 
Load a single retriever based on its config dict. + + Args: + config: Global configuration object + retriever_config: Dict like {"type": "vector"} or {"type": "bm25", "tokenizer": "jieba"} + embedder: Embedder instance (used by vector retriever) + indexer: FAISSIndexer instance (used by vector retriever) + + Returns: + Retriever instance + """ + rtype = retriever_config.get("type", "vector") + + if rtype == "vector": + return VectorRetriever( + embedder=embedder, + indexer=indexer, + query_template=config.get("query_engine.template"), + ) + elif rtype == "bm25": + tokenizer_name = retriever_config.get("tokenizer", "default") + tokenizer = None # use default + if tokenizer_name == "jieba": + try: + import jieba + tokenizer = lambda text: list(jieba.cut(text.lower())) + except ImportError: + pass # fallback to default + return BM25Retriever(tokenizer=tokenizer) + elif rtype == "substring": + return SubstringRetriever( + is_regex=retriever_config.get("is_regex", False), + ) + else: + raise ValueError(f"Unsupported retriever type: {rtype}") + + +def load_retrievers(config: Config, embedder: Embedder, indexer: FAISSIndexer) -> List[Retriever]: + """ + Load all configured retrievers. + + Returns: + List of Retriever instances + """ + retriever_configs = config.get("retrievers", [{"type": "vector"}]) + return [ + load_retriever(config, rc, embedder, indexer) + for rc in retriever_configs + ] + + +def load_fusion(config: Config) -> FusionStrategy: + """ + Load a fusion strategy based on configuration. 
+ + Returns: + FusionStrategy instance + """ + strategy = config.get("fusion.strategy", "weighted") + if strategy == "rrf": + return ReciprocalRankFusion( + k=config.get("fusion.k", 60), + ) + elif strategy == "weighted": + return WeightedFusion( + weights=config.get("fusion.weights"), + min_score=config.get("fusion.min_score", 0.0), + ) + else: + raise ValueError(f"Unsupported fusion strategy: {strategy}") + + +def load_reranker(config: Config) -> Optional[Reranker]: + """ + Load a reranker if enabled in configuration. + + Returns: + Reranker instance or None + """ + if not config.get("reranker.enabled", False): + return None + + return CrossEncoderReranker( + model_name=config.get("reranker.model", "BAAI/bge-reranker-v2-m3"), + device=config.get("reranker.device"), + batch_size=config.get("reranker.batch_size", 64), + max_length=config.get("reranker.max_length", 512), + use_fp16=config.get("reranker.use_fp16", True), + ) + + def load_query_engine(config: Config, embedder: Embedder, indexer: FAISSIndexer) -> QueryEngine: """ - Load a query engine based on configuration - + Load a query engine based on configuration. 
+ + Supports: + - "template": Original TemplateQueryEngine (backward compatible) + - "hybrid": HybridQueryEngine with multi-retriever fusion + Args: config: Configuration object embedder: Embedder instance indexer: FAISSIndexer instance - + Returns: QueryEngine instance """ query_engine_type = config.get("query_engine.method", "template") - + if query_engine_type == "template": return TemplateQueryEngine( embedder=embedder, indexer=indexer, template=config.get("query_engine.template", "่ฏทๅธฎๆˆ‘ๆŸฅๆ‰พ๏ผš{query}") ) + elif query_engine_type == "hybrid": + from tinysearch.indexers.metadata_index import MetadataIndex + retrievers = load_retrievers(config, embedder, indexer) + fusion = load_fusion(config) + reranker = load_reranker(config) + filter_mode = config.get("query_engine.filter_mode", "auto") + metadata_index = MetadataIndex() if filter_mode != "post" else None + return HybridQueryEngine( + retrievers=retrievers, + fusion_strategy=fusion, + reranker=reranker, + recall_multiplier=config.get("query_engine.recall_multiplier", 2), + min_scores=config.get("query_engine.min_scores", None), + filter_multiplier=config.get("query_engine.filter_multiplier", 3), + metadata_index=metadata_index, + filter_mode=filter_mode, + ) else: raise ValueError(f"Unsupported query engine type: {query_engine_type}") +def _get_retriever_index_dir(index_path: Path) -> Path: + """Derive the retriever index directory from the FAISS index path. + + FAISS saves into ``index_path.with_suffix('')`` (e.g. ``index.faiss`` โ†’ + ``index/``). Non-vector retriever indexes are stored as subdirectories + inside the same directory so they stay together with the FAISS index. 
+ """ + return index_path.with_suffix('') if index_path.suffix else index_path + + +def _build_hybrid_retriever_indexes( + query_engine: QueryEngine, chunks, logger=None, +) -> None: + """Build BM25 / Substring / MetadataIndex when the engine is a HybridQueryEngine.""" + if not isinstance(query_engine, HybridQueryEngine): + return + for retriever in query_engine.retrievers: + if isinstance(retriever, VectorRetriever): + continue # already covered by FAISS indexer.build() + rname = type(retriever).__name__ + if logger: + log_step(f"Building {rname} index") + retriever.build(chunks) + + # Build metadata index for pre-filtering + if query_engine.metadata_index is not None: + if logger: + log_step("Building MetadataIndex") + query_engine.metadata_index.build(chunks) + + +def _save_hybrid_retriever_indexes( + query_engine: QueryEngine, index_path: Path, logger=None, +) -> None: + """Save non-vector retriever indexes and metadata index alongside the FAISS index.""" + if not isinstance(query_engine, HybridQueryEngine): + return + index_dir = _get_retriever_index_dir(index_path) + for retriever in query_engine.retrievers: + if isinstance(retriever, VectorRetriever): + continue + rname = type(retriever).__name__.lower().replace("retriever", "") + rpath = index_dir / f"{rname}_index" + if logger: + log_step(f"Saving {type(retriever).__name__} index to {rpath}") + retriever.save(rpath) + + # Save metadata index + if query_engine.metadata_index is not None: + mpath = index_dir / "metadata_index.json" + if logger: + log_step(f"Saving MetadataIndex to {mpath}") + query_engine.metadata_index.save(mpath) + + +def _load_hybrid_retriever_indexes( + query_engine: QueryEngine, index_path: Path, logger=None, +) -> None: + """Load non-vector retriever indexes and metadata index from alongside the FAISS index.""" + if not isinstance(query_engine, HybridQueryEngine): + return + index_dir = _get_retriever_index_dir(index_path) + for retriever in query_engine.retrievers: + if 
isinstance(retriever, VectorRetriever): + continue + rname = type(retriever).__name__.lower().replace("retriever", "") + rpath = index_dir / f"{rname}_index" + if rpath.exists(): + if logger: + log_step(f"Loading {type(retriever).__name__} index from {rpath}") + retriever.load(rpath) + + # Load metadata index + if query_engine.metadata_index is not None: + mpath = index_dir / "metadata_index.json" + if mpath.exists(): + if logger: + log_step(f"Loading MetadataIndex from {mpath}") + query_engine.metadata_index.load(mpath) + + def build_index(args: argparse.Namespace, config: Config) -> None: """ Build a search index @@ -177,15 +383,31 @@ def build_index(args: argparse.Namespace, config: Config) -> None: splitter = load_splitter(config) embedder = load_embedder(config) indexer = load_indexer(config) + query_engine = load_query_engine(config, embedder, indexer) - # Extract text from data source + # Extract text from data source with per-file source metadata. + # When args.data is a directory, process each file individually so that + # every text gets its actual file path as "source" โ€” not just the + # directory name. This prevents fusion dedup key collisions between + # chunks from different files that happen to share the same text. 
log_step("Extracting text") - texts = adapter.extract(args.data) + data_path = Path(args.data) + texts: list = [] + metadata: list = [] + adapter_type = config.get("adapter.type", "text") + for file_path in iter_input_files(data_path, adapter_type=adapter_type): + try: + file_texts = adapter.extract(file_path) + except Exception: + continue + for t in file_texts: + texts.append(t) + metadata.append({"source": str(file_path)}) logger.info(f"๐Ÿ“„ Extracted {len(texts)} documents") # Split text into chunks log_step("Splitting text into chunks") - chunks = splitter.split(texts) + chunks = splitter.split(texts, metadata) logger.info(f"โœ‚๏ธ Created {len(chunks)} text chunks") # Generate embeddings @@ -193,15 +415,21 @@ def build_index(args: argparse.Namespace, config: Config) -> None: vectors = embedder.embed([chunk.text for chunk in chunks]) logger.info(f"๐Ÿง  Generated {len(vectors)} embedding vectors") - # Build index + # Build FAISS index log_step("Building index") indexer.build(vectors, chunks) - # Save index + # Build non-vector retriever indexes (BM25, Substring) for hybrid mode + _build_hybrid_retriever_indexes(query_engine, chunks, logger) + + # Save FAISS index index_path = Path(config.get("indexer.index_path", "index.faiss")) log_step(f"Saving index to {index_path}") indexer.save(index_path) + # Save non-vector retriever indexes + _save_hybrid_retriever_indexes(query_engine, index_path, logger) + log_success("Index built successfully") @@ -220,11 +448,14 @@ def query_index(args: argparse.Namespace, config: Config) -> None: indexer = load_indexer(config) query_engine = load_query_engine(config, embedder, indexer) - # Load index + # Load FAISS index index_path = Path(config.get("indexer.index_path", "index.faiss")) log_step(f"Loading index from {index_path}") indexer.load(index_path) + # Load non-vector retriever indexes (BM25, Substring) for hybrid mode + _load_hybrid_retriever_indexes(query_engine, index_path, logger) + # Process query query = args.q top_k = 
args.top_k or config.get("query_engine.top_k", 5) diff --git a/tinysearch/config.py b/tinysearch/config.py index b50fe68..5a2019a 100644 --- a/tinysearch/config.py +++ b/tinysearch/config.py @@ -42,10 +42,23 @@ def __init__(self, config_path: Optional[Union[str, Path]] = None): "index_path": ".cache/index.faiss", "metric": "cosine" }, + "retrievers": [ + {"type": "vector"} + ], + "fusion": { + "strategy": "weighted", + "weights": [1.0], + }, + "reranker": { + "enabled": False, + }, "query_engine": { "method": "template", "template": "่ฏทๅธฎๆˆ‘ๆŸฅๆ‰พ๏ผš{query}", - "top_k": 5 + "top_k": 5, + "recall_multiplier": 2, + "min_scores": None, + "filter_multiplier": 3, }, "flow": { "use_cache": True, diff --git a/tinysearch/flow/controller.py b/tinysearch/flow/controller.py index bd67071..8f6156c 100644 --- a/tinysearch/flow/controller.py +++ b/tinysearch/flow/controller.py @@ -7,9 +7,12 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union, Set, Tuple, cast, Callable -from tinysearch.base import DataAdapter, TextSplitter, Embedder, VectorIndexer, QueryEngine +from tinysearch.base import DataAdapter, TextSplitter, Embedder, VectorIndexer, QueryEngine, Retriever from tinysearch.base import TextChunk, FlowController as FlowControllerBase from tinysearch.flow.hot_update import HotUpdateManager +from tinysearch.query.hybrid import HybridQueryEngine +from tinysearch.retrievers.vector_retriever import VectorRetriever +from tinysearch.utils.file_discovery import iter_input_files class FlowController(FlowControllerBase): @@ -18,9 +21,12 @@ class FlowController(FlowControllerBase): Manages the entire data processing from ingestion to query handling. 
""" + # Soft-delete threshold: trigger full rebuild when exceeded + DELETE_REBUILD_THRESHOLD = 100 + def __init__( self, - data_adapter: DataAdapter, + data_adapter: Optional[DataAdapter], text_splitter: TextSplitter, embedder: Embedder, indexer: VectorIndexer, @@ -29,9 +35,9 @@ def __init__( ): """ Initialize the flow controller with all required components - + Args: - data_adapter: Component for data extraction + data_adapter: Component for data extraction (None when using record-based API) text_splitter: Component for text chunking embedder: Component for generating embeddings indexer: Component for vector indexing and search @@ -141,6 +147,9 @@ def process_file(self, file_path: Union[str, Path], force_reprocess: bool = Fals # Add to index self.indexer.build(vectors, chunks) + + # Build retriever indexes for HybridQueryEngine + self._build_retriever_indexes(chunks) def process_directory(self, dir_path: Union[str, Path], extensions: Optional[List[str]] = None, recursive: bool = True, force_reprocess: bool = False) -> None: @@ -154,24 +163,15 @@ def process_directory(self, dir_path: Union[str, Path], extensions: Optional[Lis force_reprocess: If True, reprocess even if file is in cache """ dir_path = Path(dir_path) - + if not dir_path.is_dir(): raise ValueError(f"{dir_path} is not a directory") - - # Function to check if a file should be processed - def should_process(file_path: Path) -> bool: - if extensions and file_path.suffix.lower() not in extensions: - return False + + adapter_type = self.config.get("adapter", {}).get("type", "text") + for file_path in iter_input_files(dir_path, adapter_type=adapter_type, extensions=extensions, recursive=recursive): if not force_reprocess and str(file_path) in self.processed_files: - return False - return True - - # Process all matching files in directory - for item in dir_path.iterdir(): - if item.is_file() and should_process(item): - self.process_file(item, force_reprocess) - elif item.is_dir() and recursive: - 
self.process_directory(item, extensions, recursive, force_reprocess) + continue + self.process_file(file_path, force_reprocess) def build_index(self, data_path: Union[str, Path], **kwargs) -> None: """ @@ -200,32 +200,242 @@ def build_index(self, data_path: Union[str, Path], **kwargs) -> None: else: self.process_file(data_path, force_reprocess=force_reprocess) + def build_from_records( + self, + records: Dict[str, Dict[str, Any]], + adapter: Any, + splitter: Optional[TextSplitter] = None, + ) -> List[TextChunk]: + """ + Build the search index from in-memory records. + + Args: + records: Mapping of record_id -> record_data + adapter: RecordAdapter to convert records to TextChunks + splitter: Optional TextSplitter for further chunking + + Returns: + List of TextChunks that were indexed + """ + from tinysearch.records import build_chunks_from_records + + chunks = build_chunks_from_records(records, adapter, splitter) + if not chunks: + return chunks + + vectors = self.embedder.embed([c.text for c in chunks]) + self.indexer.build(vectors, chunks) + self._build_retriever_indexes(chunks) + + return chunks + + def build_incremental( + self, + records: Dict[str, Dict[str, Any]], + adapter: Any, + hash_tracker: Any, + splitter: Optional[TextSplitter] = None, + delete_rebuild_threshold: Optional[int] = None, + ) -> Dict[str, Any]: + """ + Incrementally update the search index based on record changes. + + Pipeline: + 1. adapter.to_chunk() for each record + 2. hash_tracker.detect_changes() โ†’ new/modified/deleted + 3. If soft deletes exceed threshold โ†’ full rebuild + 4. Else โ†’ FAISS.add() + MetadataIndex.add_chunks() for new/modified + 5. 
BM25/Substring always full rebuild (fast) + + Args: + records: Current complete set of records {record_id: record_data} + adapter: RecordAdapter to convert records to TextChunks + hash_tracker: ContentHashTracker for change detection + splitter: Optional TextSplitter + delete_rebuild_threshold: Max soft deletes before full rebuild + + Returns: + Dict with stats: {new, modified, deleted, unchanged, full_rebuild} + """ + from tinysearch.records import build_chunks_from_records + + threshold = delete_rebuild_threshold or self.DELETE_REBUILD_THRESHOLD + + # Step 1: Convert all current records to chunks for change detection + current_record_chunks: Dict[str, TextChunk] = {} + for rid, rdata in records.items(): + chunk = adapter.to_chunk(rid, rdata) + if "record_id" not in chunk.metadata: + chunk.metadata["record_id"] = rid + current_record_chunks[rid] = chunk + + # Step 2: Detect changes + changes = hash_tracker.detect_changes(current_record_chunks) + + stats = { + "new": len(changes.new), + "modified": len(changes.modified), + "deleted": len(changes.deleted), + "unchanged": len(changes.unchanged), + "full_rebuild": False, + } + + if not changes.has_changes: + return stats + + # Step 3: Check if full rebuild needed + total_soft_deletes = len(changes.deleted) + len(changes.modified) + if isinstance(self.query_engine, HybridQueryEngine): + total_soft_deletes += self.query_engine.soft_delete_count + + if total_soft_deletes >= threshold: + # Full rebuild path + stats["full_rebuild"] = True + self.build_from_records(records, adapter, splitter) + + if isinstance(self.query_engine, HybridQueryEngine): + self.query_engine.clear_soft_deletes() + + hash_tracker.remove(changes.deleted) + hash_tracker.update(current_record_chunks) + return stats + + # Step 4: Incremental path + + # 4a: Soft-delete modified + deleted + if isinstance(self.query_engine, HybridQueryEngine): + ids_to_soft_delete = changes.deleted | set(changes.modified) + if ids_to_soft_delete: + 
self.query_engine.add_soft_deletes(ids_to_soft_delete) + + # 4b: Embed and add new + modified to FAISS + records_to_add = {rid: records[rid] for rid in changes.new + changes.modified} + if records_to_add: + new_chunks = build_chunks_from_records(records_to_add, adapter, splitter) + if new_chunks: + vectors = self.embedder.embed([c.text for c in new_chunks]) + self.indexer.add(vectors, new_chunks) + + # Incrementally add to MetadataIndex + if (isinstance(self.query_engine, HybridQueryEngine) + and self.query_engine.metadata_index is not None): + start_id = self.query_engine.metadata_index.total_chunks + self.query_engine.metadata_index.add_chunks(new_chunks, start_id) + + # 4c: BM25/Substring always full rebuild (fast) + all_current_chunks = build_chunks_from_records(records, adapter, splitter) + if isinstance(self.query_engine, HybridQueryEngine): + for retriever in self.query_engine.retrievers: + if isinstance(retriever, VectorRetriever): + continue + retriever.build(all_current_chunks) + + # Step 5: Update hash tracker + hash_tracker.remove(changes.deleted) + hash_tracker.update({ + rid: current_record_chunks[rid] + for rid in changes.new + changes.modified + }) + + return stats + + def _get_hybrid_retrievers(self) -> List[Retriever]: + """Get retrievers from HybridQueryEngine, if applicable""" + if isinstance(self.query_engine, HybridQueryEngine): + return self.query_engine.retrievers + return [] + + def _build_retriever_indexes(self, chunks: List[TextChunk]) -> None: + """Build indexes for non-vector retrievers and metadata index in HybridQueryEngine""" + for retriever in self._get_hybrid_retrievers(): + # Skip VectorRetriever - it's already handled by self.indexer.build() + if isinstance(retriever, VectorRetriever): + continue + retriever.build(chunks) + + # Build metadata index for pre-filtering + if isinstance(self.query_engine, HybridQueryEngine) and self.query_engine.metadata_index is not None: + self.query_engine.metadata_index.build(chunks) + def 
save_index(self, path: Optional[Union[str, Path]] = None) -> None: """ - Save the built index to disk - + Save the built index to disk. + Also saves retriever indexes for HybridQueryEngine. + Args: path: Path to save the index to, if None use the config path """ if path is None: path = self.config.get("indexer", {}).get("index_path", "index.faiss") - - # Use cast to ensure type safety for optional path + + # Save the main vector index self.indexer.save(cast(Union[str, Path], path)) - + + # Save non-vector retriever indexes + self._save_retriever_indexes(Path(str(path))) + + def _save_retriever_indexes(self, base_path: Path) -> None: + """Save indexes for non-vector retrievers and metadata index inside the FAISS index directory""" + # FAISS saves into base_path.with_suffix('') (e.g. "index.faiss" โ†’ "index/") + index_dir = base_path.with_suffix('') if base_path.suffix else base_path + for retriever in self._get_hybrid_retrievers(): + if isinstance(retriever, VectorRetriever): + continue + # Derive subdirectory name from retriever class + retriever_name = type(retriever).__name__.lower().replace("retriever", "") + retriever_path = index_dir / f"{retriever_name}_index" + retriever.save(retriever_path) + + # Save metadata index + if isinstance(self.query_engine, HybridQueryEngine) and self.query_engine.metadata_index is not None: + self.query_engine.metadata_index.save(index_dir / "metadata_index.json") + + # Save soft-delete set + if isinstance(self.query_engine, HybridQueryEngine) and self.query_engine.soft_deleted_ids: + with open(index_dir / "soft_deletes.json", "w") as f: + json.dump(sorted(self.query_engine.soft_deleted_ids), f) + def load_index(self, path: Optional[Union[str, Path]] = None) -> None: """ - Load an index from disk - + Load an index from disk. + Also loads retriever indexes for HybridQueryEngine. 
+ Args: path: Path to load the index from, if None use the config path """ if path is None: path = self.config.get("indexer", {}).get("index_path", "index.faiss") - - # Use cast to ensure type safety for optional path + + # Load the main vector index self.indexer.load(cast(Union[str, Path], path)) - + + # Load non-vector retriever indexes + self._load_retriever_indexes(Path(str(path))) + + def _load_retriever_indexes(self, base_path: Path) -> None: + """Load indexes for non-vector retrievers and metadata index from the FAISS index directory""" + index_dir = base_path.with_suffix('') if base_path.suffix else base_path + for retriever in self._get_hybrid_retrievers(): + if isinstance(retriever, VectorRetriever): + continue + retriever_name = type(retriever).__name__.lower().replace("retriever", "") + retriever_path = index_dir / f"{retriever_name}_index" + if retriever_path.exists(): + retriever.load(retriever_path) + + # Load metadata index + if isinstance(self.query_engine, HybridQueryEngine) and self.query_engine.metadata_index is not None: + metadata_path = index_dir / "metadata_index.json" + if metadata_path.exists(): + self.query_engine.metadata_index.load(metadata_path) + + # Load soft-delete set + if isinstance(self.query_engine, HybridQueryEngine): + soft_delete_path = index_dir / "soft_deletes.json" + if soft_delete_path.exists(): + with open(soft_delete_path, "r") as f: + self.query_engine.soft_deleted_ids = set(json.load(f)) + def query(self, query_text: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]: """ Process a query and return relevant chunks @@ -242,7 +452,7 @@ def query(self, query_text: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any top_k = self.config.get("query_engine", {}).get("top_k", 5) # Use cast to ensure type safety for optional top_k - return self.query_engine.retrieve(query_text, cast(int, top_k)) + return self.query_engine.retrieve(query_text, cast(int, top_k), **kwargs) def clear_cache(self) -> None: """Clear all cached 
data""" diff --git a/tinysearch/fusion/__init__.py b/tinysearch/fusion/__init__.py new file mode 100644 index 0000000..cfa509e --- /dev/null +++ b/tinysearch/fusion/__init__.py @@ -0,0 +1,11 @@ +""" +Fusion strategies for combining multi-retriever results +""" + +from .rrf import ReciprocalRankFusion +from .weighted import WeightedFusion + +__all__ = [ + "ReciprocalRankFusion", + "WeightedFusion", +] diff --git a/tinysearch/fusion/_utils.py b/tinysearch/fusion/_utils.py new file mode 100644 index 0000000..aa4edd5 --- /dev/null +++ b/tinysearch/fusion/_utils.py @@ -0,0 +1,22 @@ +""" +Shared utilities for fusion strategies. +""" +from typing import Any, Dict + + +def make_doc_key(result: Dict[str, Any]) -> str: + """ + Create a dedup key from a retrieval result. + + Uses text + metadata (source, chunk_index) to distinguish genuinely + different chunks that happen to share the same text content, while still + merging the same chunk found by different retrievers. + + Disambiguation levels: + - source: separates same text from different files + - chunk_index: separates same text in different positions within one file + """ + meta = result.get("metadata", {}) + source = meta.get("source", "") + chunk_index = meta.get("chunk_index", "") + return f"{result['text']}\x00{source}\x00{chunk_index}" diff --git a/tinysearch/fusion/rrf.py b/tinysearch/fusion/rrf.py new file mode 100644 index 0000000..e8cf012 --- /dev/null +++ b/tinysearch/fusion/rrf.py @@ -0,0 +1,71 @@ +""" +Reciprocal Rank Fusion (RRF) strategy +""" +from collections import defaultdict +from typing import Any, Dict, List + +from tinysearch.base import FusionStrategy +from tinysearch.fusion._utils import make_doc_key + + +class ReciprocalRankFusion(FusionStrategy): + """ + Reciprocal Rank Fusion combines multiple ranked lists using: + score(doc) = sum( 1 / (rank_i + k) ) + + where k is a constant (default 60) that reduces the impact of high-ranked items. 
+ This is a robust, parameter-free fusion method widely used in information retrieval. + """ + + def __init__(self, k: int = 60): + """ + Args: + k: RRF constant. Higher values reduce the gap between ranks. + Default 60 is the standard value from the original RRF paper. + """ + self.k = k + + def fuse(self, results_list: List[List[Dict[str, Any]]], **kwargs) -> List[Dict[str, Any]]: + """ + Fuse multiple result lists using RRF. + + Args: + results_list: List of result lists from different retrievers. + Each result must have 'text' key for deduplication. + + Returns: + Fused list sorted by RRF score descending. + """ + # Track RRF scores and best result per document (keyed by text) + rrf_scores: Dict[str, float] = defaultdict(float) + best_result: Dict[str, Dict[str, Any]] = {} + per_method_scores: Dict[str, Dict[str, float]] = defaultdict(dict) + + for result_list in results_list: + for rank, result in enumerate(result_list): + doc_key = make_doc_key(result) + rrf_score = 1.0 / (rank + self.k) + rrf_scores[doc_key] += rrf_score + + method = result.get("retrieval_method", "unknown") + per_method_scores[doc_key][method] = result.get("score", 0.0) + + # Keep the result with the highest original score + if doc_key not in best_result or result.get("score", 0) > best_result[doc_key].get("score", 0): + best_result[doc_key] = result + + # Build fused results + fused = [] + for doc_key, rrf_score in rrf_scores.items(): + base = best_result[doc_key] + fused.append({ + "text": base["text"], + "metadata": base.get("metadata", {}), + "score": rrf_score, + "retrieval_method": "hybrid", + "scores": per_method_scores[doc_key], + }) + + # Sort by fused score descending + fused.sort(key=lambda x: x["score"], reverse=True) + return fused diff --git a/tinysearch/fusion/weighted.py b/tinysearch/fusion/weighted.py new file mode 100644 index 0000000..2bad2d4 --- /dev/null +++ b/tinysearch/fusion/weighted.py @@ -0,0 +1,105 @@ +""" +Weighted score fusion strategy +""" +from collections 
import defaultdict +from typing import Any, Dict, List, Optional + +from tinysearch.base import FusionStrategy +from tinysearch.fusion._utils import make_doc_key + + +class WeightedFusion(FusionStrategy): + """ + Weighted fusion normalizes each retriever's scores to [0, 1] via min-max, + then computes a weighted sum: + fusion_score = sum( normalized_score_i * weight_i ) + """ + + def __init__( + self, + weights: Optional[List[float]] = None, + min_score: float = 0.0, + ): + """ + Args: + weights: Weight for each retriever. If None, equal weights are used. + min_score: Minimum fusion score threshold. Results below this are dropped. + """ + self.weights = weights + self.min_score = min_score + + def fuse(self, results_list: List[List[Dict[str, Any]]], **kwargs) -> List[Dict[str, Any]]: + """ + Fuse multiple result lists using weighted score fusion. + + Args: + results_list: List of result lists from different retrievers. + **kwargs: + weights: Override weights (takes precedence over self.weights) + + Returns: + Fused list sorted by fusion score descending. 
+ """ + if not results_list: + return [] + + n_retrievers = len(results_list) + weights = kwargs.get("weights", self.weights) + if weights is None: + weights = [1.0 / n_retrievers] * n_retrievers + + if len(weights) != n_retrievers: + raise ValueError( + f"Number of weights ({len(weights)}) must match " + f"number of result lists ({n_retrievers})" + ) + + # Normalize scores per retriever (min-max to [0, 1]) + normalized_lists = [] + for result_list in results_list: + if not result_list: + normalized_lists.append([]) + continue + scores = [r.get("score", 0.0) for r in result_list] + min_s = min(scores) + max_s = max(scores) + range_s = max_s - min_s if max_s != min_s else 1.0 + + normalized = [] + for r in result_list: + norm_score = (r.get("score", 0.0) - min_s) / range_s + normalized.append({**r, "_norm_score": norm_score}) + normalized_lists.append(normalized) + + # Aggregate by document text + doc_scores: Dict[str, float] = defaultdict(float) + best_result: Dict[str, Dict[str, Any]] = {} + per_method_scores: Dict[str, Dict[str, float]] = defaultdict(dict) + + for i, (result_list, weight) in enumerate(zip(normalized_lists, weights)): + for result in result_list: + doc_key = make_doc_key(result) + doc_scores[doc_key] += result["_norm_score"] * weight + + method = result.get("retrieval_method", "unknown") + per_method_scores[doc_key][method] = result.get("score", 0.0) + + if doc_key not in best_result: + best_result[doc_key] = result + + # Build fused results + fused = [] + for doc_key, fusion_score in doc_scores.items(): + if fusion_score < self.min_score: + continue + base = best_result[doc_key] + fused.append({ + "text": base["text"], + "metadata": base.get("metadata", {}), + "score": fusion_score, + "retrieval_method": "hybrid", + "scores": per_method_scores[doc_key], + }) + + fused.sort(key=lambda x: x["score"], reverse=True) + return fused diff --git a/tinysearch/indexers/__init__.py b/tinysearch/indexers/__init__.py index ca8b3e8..d88d23f 100644 --- 
a/tinysearch/indexers/__init__.py +++ b/tinysearch/indexers/__init__.py @@ -6,9 +6,14 @@ # Make the FAISS indexer available from the root module from .faiss_indexer import FAISSIndexer +from .metadata_index import MetadataIndex +from .hash_tracker import ContentHashTracker, ChangeSet __all__ = [ - "FAISSIndexer" + "FAISSIndexer", + "MetadataIndex", + "ContentHashTracker", + "ChangeSet", ] def index_exists(path: Union[str, Path]) -> bool: diff --git a/tinysearch/indexers/faiss_indexer.py b/tinysearch/indexers/faiss_indexer.py index 4d6e87e..b55e913 100644 --- a/tinysearch/indexers/faiss_indexer.py +++ b/tinysearch/indexers/faiss_indexer.py @@ -92,12 +92,17 @@ def build(self, vectors: List[List[float]], texts: List[TextChunk]) -> None: # Create ID mapping self.ids_map = {i: i for i in range(len(texts))} - def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, Any]]: + def search(self, query_vector: List[float], top_k: int = 5, + candidate_ids: Optional[set] = None) -> List[Dict[str, Any]]: """ - Search the index for vectors similar to the query vector + Search the index for vectors similar to the query vector. + Args: query_vector: Query embedding vector top_k: Number of results to return + candidate_ids: Optional set of chunk indices to restrict search to. + When provided, over-retrieves then filters by membership. 
+ Returns: List of dictionaries containing text chunks, similarity scores, and embedding vectors """ @@ -115,8 +120,16 @@ def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, An # Normalize query vector if using cosine similarity if self.metric == "cosine": faiss.normalize_L2(query_np) + + # When candidate_ids is set, over-recall then filter + if candidate_ids is not None: + effective_k = min(self.index.ntotal, top_k * 3) + else: + effective_k = top_k + # Search index - distances, indices = self.index.search(query_np, top_k) + distances, indices = self.index.search(query_np, effective_k) + # Convert to result format results = [] for i in range(len(indices[0])): @@ -125,6 +138,9 @@ def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, An # Skip invalid indices (can happen if there are fewer results than top_k) if idx < 0 or idx >= len(self.texts): continue + # Pre-filter: skip if not in candidate set + if candidate_ids is not None and idx not in candidate_ids: + continue # Map to original text chunk text_idx = self.ids_map[idx] text_chunk = self.texts[text_idx] @@ -146,6 +162,8 @@ def search(self, query_vector: List[float], top_k: int = 5) -> List[Dict[str, An "score": similarity, "embedding": embedding }) + if candidate_ids is not None and len(results) >= top_k: + break return results def save(self, path: Union[str, Path]) -> None: diff --git a/tinysearch/indexers/hash_tracker.py b/tinysearch/indexers/hash_tracker.py new file mode 100644 index 0000000..061fe5b --- /dev/null +++ b/tinysearch/indexers/hash_tracker.py @@ -0,0 +1,153 @@ +""" +Content hash tracking for incremental indexing change detection. 
+""" +import hashlib +import json +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Union + +from tinysearch.base import TextChunk + +logger = logging.getLogger(__name__) + + +class ChangeSet: + """Result of change detection between current and tracked records.""" + + __slots__ = ("new", "modified", "deleted", "unchanged") + + def __init__( + self, + new: List[str], + modified: List[str], + deleted: Set[str], + unchanged: Set[str], + ): + self.new = new + self.modified = modified + self.deleted = deleted + self.unchanged = unchanged + + @property + def has_changes(self) -> bool: + return bool(self.new or self.modified or self.deleted) + + def __repr__(self) -> str: + return ( + f"ChangeSet(new={len(self.new)}, modified={len(self.modified)}, " + f"deleted={len(self.deleted)}, unchanged={len(self.unchanged)})" + ) + + +class ContentHashTracker: + """ + Track content hashes for record-level change detection. + + Each record is identified by its record_id (string). The hash is + computed from the text and a configurable subset of metadata keys. + """ + + _INTERNAL_KEYS = frozenset({"chunk_index", "total_chunks"}) + + def __init__(self, hash_metadata_keys: Optional[List[str]] = None): + """ + Args: + hash_metadata_keys: Metadata keys to include in hash. + If None, all keys except internal ones are included. 
+ """ + self._hashes: Dict[str, str] = {} + self._hash_metadata_keys = hash_metadata_keys + + def compute_hash(self, text: str, metadata: Dict[str, Any]) -> str: + """Compute MD5 hash of text + sorted metadata.""" + hasher = hashlib.md5() + hasher.update(text.encode("utf-8")) + + if self._hash_metadata_keys is not None: + keys = self._hash_metadata_keys + else: + keys = sorted(k for k in metadata if k not in self._INTERNAL_KEYS) + + for key in sorted(keys): + if key in metadata: + val = metadata[key] + # Normalize lists for deterministic hashing + if isinstance(val, list): + val = json.dumps(sorted(str(v) for v in val), ensure_ascii=False) + hasher.update(f"{key}={val}".encode("utf-8")) + + return hasher.hexdigest() + + def detect_changes(self, current_records: Dict[str, TextChunk]) -> ChangeSet: + """ + Compare current records against tracked hashes. + + Args: + current_records: Mapping of record_id -> TextChunk + + Returns: + ChangeSet with new, modified, deleted, unchanged + """ + tracked_ids = set(self._hashes.keys()) + + new_ids: List[str] = [] + modified_ids: List[str] = [] + unchanged_ids: Set[str] = set() + + for rid, chunk in current_records.items(): + h = self.compute_hash(chunk.text, chunk.metadata) + if rid not in tracked_ids: + new_ids.append(rid) + elif self._hashes[rid] != h: + modified_ids.append(rid) + else: + unchanged_ids.add(rid) + + deleted_ids = tracked_ids - set(current_records.keys()) + + return ChangeSet( + new=new_ids, + modified=modified_ids, + deleted=deleted_ids, + unchanged=unchanged_ids, + ) + + def update(self, records: Dict[str, TextChunk]) -> None: + """Update tracked hashes after a successful build.""" + for rid, chunk in records.items(): + self._hashes[rid] = self.compute_hash(chunk.text, chunk.metadata) + + def remove(self, record_ids: Set[str]) -> None: + """Remove record_ids from tracking.""" + for rid in record_ids: + self._hashes.pop(rid, None) + + @property + def tracked_count(self) -> int: + return len(self._hashes) + + 
def save(self, path: Union[str, Path]) -> None: + """Save hash state to JSON.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump( + { + "version": 1, + "hash_metadata_keys": self._hash_metadata_keys, + "hashes": self._hashes, + }, + f, + ensure_ascii=False, + ) + logger.info("ContentHashTracker saved to %s (%d records)", path, len(self._hashes)) + + def load(self, path: Union[str, Path]) -> None: + """Load hash state from JSON.""" + path = Path(path) + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + self._hashes = data["hashes"] + self._hash_metadata_keys = data.get("hash_metadata_keys") + logger.info("ContentHashTracker loaded from %s (%d records)", path, len(self._hashes)) diff --git a/tinysearch/indexers/metadata_index.py b/tinysearch/indexers/metadata_index.py new file mode 100644 index 0000000..2dabf5a --- /dev/null +++ b/tinysearch/indexers/metadata_index.py @@ -0,0 +1,228 @@ +""" +Inverted index over TextChunk metadata for fast candidate set lookup. + +Enables O(1) pre-filtering by metadata fields (e.g., grade, type, tags) +instead of linear post-filtering over all retrieval results. +""" +import json +import logging +from collections import defaultdict +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +from tinysearch.base import TextChunk + +logger = logging.getLogger(__name__) + +# Filter value types matching HybridQueryEngine._match_filters +FilterValue = Union[str, int, float, bool, List, Callable] + +# Scalar types that can be directly indexed +_INDEXABLE_TYPES = (str, int, float, bool) + + +class MetadataIndex: + """ + Inverted index: metadata field -> value -> set of chunk IDs. + + Supports: + - Scalar values (str, int, float, bool): indexed directly + - List[str/int/...] 
"""
Inverted index over TextChunk metadata for fast candidate set lookup.

Enables O(1) pre-filtering by metadata fields (e.g., grade, type, tags)
instead of linear post-filtering over all retrieval results.
"""
import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

logger = logging.getLogger(__name__)

# Filter value types matching HybridQueryEngine._match_filters
FilterValue = Union[str, int, float, bool, List, Callable]

# Scalar types that can be directly indexed
_INDEXABLE_TYPES = (str, int, float, bool)


class MetadataIndex:
    """
    Inverted index: metadata field -> value -> set of chunk IDs.

    Supports:
    - Scalar values (str, int, float, bool): indexed directly
    - List values (e.g., tags): each indexable element indexed separately

    Chunks are duck-typed: anything exposing ``.metadata`` works.
    Non-indexable metadata values (dicts, None, custom objects) are skipped,
    so chunks carrying only such values are invisible to pre-filtering.
    """

    def __init__(self) -> None:
        # field -> value -> {chunk_id, ...}
        self._index: Dict[str, Dict[Any, Set[int]]] = defaultdict(
            lambda: defaultdict(set)
        )
        self._total_chunks: int = 0

    def _index_chunk(self, chunk_id: int, chunk: "TextChunk") -> None:
        """Index one chunk's metadata under chunk_id.

        Shared by build() and add_chunks() so the indexing rules cannot
        drift apart (previously the loop body was duplicated in both).
        """
        if not chunk.metadata:
            return
        for key, value in chunk.metadata.items():
            if isinstance(value, _INDEXABLE_TYPES):
                self._index[key][value].add(chunk_id)
            elif isinstance(value, list):
                for element in value:
                    if isinstance(element, _INDEXABLE_TYPES):
                        self._index[key][element].add(chunk_id)

    def build(self, chunks: List["TextChunk"]) -> None:
        """
        Build inverted indices from chunk metadata, replacing any prior state.

        Args:
            chunks: Ordered list of chunks. The positional index (0-based)
                becomes the chunk ID used in candidate sets.
        """
        self._index = defaultdict(lambda: defaultdict(set))
        self._total_chunks = len(chunks)

        for i, chunk in enumerate(chunks):
            self._index_chunk(i, chunk)

        field_stats = {k: len(v) for k, v in self._index.items()}
        logger.info(
            "MetadataIndex built: %d chunks, fields=%s",
            self._total_chunks,
            field_stats,
        )

    def add_chunks(self, chunks: List["TextChunk"], start_id: int) -> None:
        """
        Incrementally add new chunks to the inverted index.

        Args:
            chunks: New chunks to index
            start_id: First chunk ID to assign (typically current total_chunks)
        """
        for offset, chunk in enumerate(chunks):
            self._index_chunk(start_id + offset, chunk)

        self._total_chunks += len(chunks)
        logger.info(
            "MetadataIndex: added %d chunks (total now %d)",
            len(chunks),
            self._total_chunks,
        )

    def lookup(self, filters: Dict[str, FilterValue]) -> Optional[Set[int]]:
        """
        Resolve filters to a candidate ID set via the inverted index.

        Filter semantics (matches HybridQueryEngine._match_filters):
        - scalar (str/int/float/bool): exact match -> direct set lookup
        - list: OR over values -> union of sets
        - multiple filter keys: AND -> set intersection
        - callable: cannot be resolved -> returns None

        Returns:
            Set[int] of matching chunk IDs; None when *filters* is empty or
            contains any callable (meaning "no restriction resolvable here").
        """
        if not filters:
            return None

        result: Optional[Set[int]] = None

        for key, condition in filters.items():
            if callable(condition):
                return None

            matched = self._lookup_single(key, condition)

            if result is None:
                result = matched
            else:
                result = result & matched

            # Short-circuit on empty intersection
            if result is not None and len(result) == 0:
                return set()

        return result if result is not None else set()

    def _lookup_single(self, key: str, condition: FilterValue) -> Set[int]:
        """Lookup a single filter key-value pair (empty set on unknown field)."""
        field_index = self._index.get(key)
        if field_index is None:
            return set()

        if isinstance(condition, list):
            # OR: union of all matching values
            matched: Set[int] = set()
            for val in condition:
                matched |= field_index.get(val, set())
            return matched
        else:
            # Exact match; copy so callers cannot mutate index internals
            return set(field_index.get(condition, set()))

    def classify_filters(
        self, filters: Dict[str, FilterValue]
    ) -> Tuple[Dict[str, FilterValue], Dict[str, FilterValue]]:
        """
        Split filters into indexable and non-indexable (callable) parts.

        Returns:
            (indexable_filters, callable_filters)
        """
        indexable: Dict[str, FilterValue] = {}
        callables: Dict[str, FilterValue] = {}
        for key, condition in filters.items():
            if callable(condition):
                callables[key] = condition
            else:
                indexable[key] = condition
        return indexable, callables

    @property
    def total_chunks(self) -> int:
        """Number of chunks the index has been told about."""
        return self._total_chunks

    @property
    def fields(self) -> List[str]:
        """Metadata field names present in the index."""
        return list(self._index.keys())

    def save(self, path: Union[str, Path]) -> None:
        """
        Save to JSON. Sets become sorted lists; value types are preserved
        via a type tag for proper reconstruction on load.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        serializable: Dict[str, Any] = {
            "version": 1,
            "total_chunks": self._total_chunks,
            "fields": {},
        }

        for field_name, value_map in self._index.items():
            entries = {}
            for value, ids in value_map.items():
                # JSON-encoded key carries both the value and its type tag
                type_tag = type(value).__name__
                str_key = json.dumps({"v": value, "t": type_tag}, ensure_ascii=False)
                entries[str_key] = sorted(ids)
            serializable["fields"][field_name] = entries

        with open(path, "w", encoding="utf-8") as f:
            json.dump(serializable, f, ensure_ascii=False)

        logger.info("MetadataIndex saved to %s", path)

    def load(self, path: Union[str, Path]) -> None:
        """Load from JSON, reconstructing sets and value types."""
        path = Path(path)

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        self._total_chunks = data["total_chunks"]
        self._index = defaultdict(lambda: defaultdict(set))

        type_constructors = {"str": str, "int": int, "float": float, "bool": bool}

        for field_name, entries in data["fields"].items():
            for str_key, id_list in entries.items():
                key_data = json.loads(str_key)
                value = key_data["v"]
                type_tag = key_data["t"]
                # Reconstruct the typed value (e.g., bool round-trips as bool)
                constructor = type_constructors.get(type_tag)
                if constructor:
                    value = constructor(value)
                self._index[field_name][value] = set(id_list)

        logger.info(
            "MetadataIndex loaded from %s: %d chunks, %d fields",
            path,
            self._total_chunks,
            len(self._index),
        )
"""
Hybrid query engine - multi-retriever fusion with optional reranking
"""
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from tinysearch.base import QueryEngine, Retriever, FusionStrategy, Reranker

logger = logging.getLogger(__name__)

# Type alias for filter values
FilterValue = Union[str, int, float, bool, List, Callable]


class HybridQueryEngine(QueryEngine):
    """
    Query engine that fans a query out to several retrievers, fuses their
    results with a FusionStrategy, and optionally reranks the fused list.

    Pipeline:
        1. Resolve metadata filters into a pre-filter candidate set (via
           MetadataIndex) and/or post-filters, depending on filter_mode
        2. Each retriever recalls top_k * recall_multiplier candidates
           (times filter_multiplier only when post-filters are active)
        3. Per-retriever min_score thresholding
        4. Metadata post-filtering for callable/non-indexable filters
        5. FusionStrategy merges and deduplicates results
        6. Soft-deleted records removed, then optional reranking
        7. Final top_k results returned
    """

    def __init__(
        self,
        retrievers: List[Retriever],
        fusion_strategy: FusionStrategy,
        reranker: Optional[Reranker] = None,
        recall_multiplier: int = 2,
        min_scores: Optional[List[float]] = None,
        filter_multiplier: int = 3,
        metadata_index=None,
        filter_mode: str = "auto",
        soft_deleted_ids: Optional[set] = None,
    ):
        """
        Args:
            retrievers: Retriever instances used for multi-path recall
            fusion_strategy: Strategy that merges per-retriever results
            reranker: Optional reranker applied after fusion
            recall_multiplier: Each retriever recalls top_k * this many hits
            min_scores: Per-retriever score thresholds (length must match retrievers)
            filter_multiplier: Extra recall factor while post-filters are active
            metadata_index: Optional MetadataIndex for inverted-index pre-filtering
            filter_mode: "pre" (always pre-filter), "post" (always post-filter),
                or "auto" (pre-filter indexable parts, post-filter callables)
            soft_deleted_ids: Optional set of record_ids to exclude from results

        Raises:
            ValueError: On empty retrievers, mismatched min_scores length,
                or an unknown filter_mode.
        """
        if not retrievers:
            raise ValueError("At least one retriever is required")
        if min_scores is not None and len(min_scores) != len(retrievers):
            raise ValueError(
                f"min_scores length ({len(min_scores)}) must match "
                f"retrievers length ({len(retrievers)})"
            )
        if filter_mode not in ("pre", "post", "auto"):
            raise ValueError(f"filter_mode must be 'pre', 'post', or 'auto', got '{filter_mode}'")
        self.retrievers = retrievers
        self.fusion_strategy = fusion_strategy
        self.reranker = reranker
        self.recall_multiplier = recall_multiplier
        self.min_scores = min_scores
        self.filter_multiplier = filter_multiplier
        self.metadata_index = metadata_index
        self.filter_mode = filter_mode
        # NOTE: a caller-supplied set is aliased on purpose so external code
        # can observe soft-delete updates.
        self.soft_deleted_ids: set = soft_deleted_ids or set()

    def add_soft_deletes(self, record_ids: set) -> None:
        """Mark record_ids as soft-deleted (excluded from results)."""
        self.soft_deleted_ids |= record_ids

    def clear_soft_deletes(self) -> None:
        """Drop every soft-delete mark (typically after a full rebuild)."""
        self.soft_deleted_ids.clear()

    @property
    def soft_delete_count(self) -> int:
        """Number of record_ids currently soft-deleted."""
        return len(self.soft_deleted_ids)

    def format_query(self, query: str) -> str:
        """Pass-through: the hybrid engine does not transform queries."""
        return query

    def retrieve(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]:
        """
        Multi-path retrieval with fusion and optional reranking.

        Args:
            query: Query string
            top_k: Number of final results to return
            **kwargs:
                filters: Dict of metadata filters (see _match_filters)
                weights: List of floats overriding fusion weights dynamically

        Returns:
            Fused (and optionally reranked) list of results
        """
        return self._retrieve_pipeline(query, top_k, **kwargs)["results"]

    def retrieve_with_details(
        self, query: str, top_k: int = 5, **kwargs
    ) -> Dict[str, Any]:
        """
        Like retrieve(), but returns structured details of each pipeline stage.

        Returns:
            Dict with keys:
                results: Final top_k results
                per_retriever: Per-retriever raw results (after min_score)
                fused_before_rerank: Fused results before reranking
        """
        return self._retrieve_pipeline(query, top_k, **kwargs)

    def _resolve_filters(self, filters):
        """
        Turn user filters into (candidate_ids, post_filters, short_circuit).

        candidate_ids is a chunk-id set for retriever-side pre-filtering (or
        None), post_filters are the filters that must be applied after recall,
        and short_circuit is True when the candidate set is provably empty.
        """
        if not filters:
            return None, None, False
        # Without an index, or in explicit "post" mode, everything post-filters.
        if self.metadata_index is None or self.filter_mode == "post":
            return None, filters, False
        if self.filter_mode == "pre":
            ids = self.metadata_index.lookup(filters)
            if ids is None:
                # Callable filters cannot be pre-resolved; fall back entirely.
                return None, filters, False
            return ids, None, len(ids) == 0
        # "auto": pre-filter what the index can answer, post-filter the rest.
        indexable, callables = self.metadata_index.classify_filters(filters)
        ids = self.metadata_index.lookup(indexable) if indexable else None
        return ids, (callables or None), ids is not None and len(ids) == 0

    def _empty_details(self) -> Dict[str, Any]:
        """Details payload for a query whose candidate set resolved to nothing."""
        return {
            "results": [],
            "per_retriever": [[] for _ in self.retrievers],
            "fused_before_rerank": [],
        }

    def _recall_one(self, position, retriever, query, recall_k, candidate_ids):
        """Recall from one retriever; a failure degrades to an empty list."""
        extra: Dict[str, Any] = {}
        if candidate_ids is not None:
            extra["candidate_ids"] = candidate_ids
        try:
            hits = retriever.retrieve(query, top_k=recall_k, **extra)
        except Exception as e:
            logger.warning(
                "Retriever %s failed, skipping: %s", type(retriever).__name__, e
            )
            hits = []
        # Per-retriever minimum-score thresholding
        if self.min_scores is not None:
            cutoff = self.min_scores[position]
            hits = [h for h in hits if h.get("score", 0) >= cutoff]
        return hits

    def _retrieve_pipeline(
        self, query: str, top_k: int, **kwargs
    ) -> Dict[str, Any]:
        """Core pipeline shared by retrieve() and retrieve_with_details()."""
        filters = kwargs.pop("filters", None)
        weights = kwargs.pop("weights", None)

        candidate_ids, post_filters, short_circuit = self._resolve_filters(filters)
        if short_circuit:
            return self._empty_details()

        # Over-recall; widen further only when post-filters will prune results.
        recall_k = top_k * self.recall_multiplier
        if post_filters:
            recall_k *= self.filter_multiplier

        # Stage 1: recall from every retriever (min_score already applied)
        per_retriever = [
            self._recall_one(pos, retriever, query, recall_k, candidate_ids)
            for pos, retriever in enumerate(self.retrievers)
        ]

        # Stage 2: post-filter whatever the inverted index could not answer
        if post_filters:
            per_retriever = [
                self._apply_filters(hits, post_filters) for hits in per_retriever
            ]

        # Stage 3: fuse, honoring per-call weight overrides
        fuse_kwargs: Dict[str, Any] = {}
        if weights is not None:
            fuse_kwargs["weights"] = weights
        fused = self.fusion_strategy.fuse(per_retriever, **fuse_kwargs)

        # Stage 3.5: drop soft-deleted records
        if self.soft_deleted_ids:
            fused = [
                r for r in fused
                if r.get("metadata", {}).get("record_id") not in self.soft_deleted_ids
            ]

        fused_before_rerank = list(fused)

        # Stage 4: optional cross-encoder rerank of the fused candidates
        if self.reranker is not None and fused:
            fused = self.reranker.rerank(query, fused, top_k=top_k)

        return {
            "results": fused[:top_k],
            "per_retriever": per_retriever,
            "fused_before_rerank": fused_before_rerank,
        }

    @staticmethod
    def _match_filters(
        metadata: Optional[Dict[str, Any]], filters: Dict[str, FilterValue]
    ) -> bool:
        """
        Check whether a metadata dict satisfies every filter (AND semantics).

        Per-key semantics:
            str/int/float/bool -> exact match
            list -> membership (OR over the listed values)
            callable -> predicate on the metadata value

        Missing metadata, or a missing required key, never matches.
        """
        if metadata is None:
            return False
        for key, condition in filters.items():
            if key not in metadata:
                return False
            value = metadata[key]
            if callable(condition):
                passed = bool(condition(value))
            elif isinstance(condition, list):
                passed = value in condition
            else:
                passed = value == condition
            if not passed:
                return False
        return True

    @classmethod
    def _apply_filters(
        cls, results: List[Dict[str, Any]], filters: Dict[str, FilterValue]
    ) -> List[Dict[str, Any]]:
        """Keep only results whose metadata matches all filter criteria."""
        return [r for r in results if cls._match_filters(r.get("metadata"), filters)]
If None, each record becomes exactly + one TextChunk. If provided, each record's text is further + split, with the original metadata inherited by all sub-chunks. + + Returns: + Ordered list of TextChunks (order follows dict iteration order) + """ + chunks: List[TextChunk] = [] + + for record_id, record_data in records.items(): + chunk = adapter.to_chunk(record_id, record_data) + + # Ensure record_id is always in metadata + if "record_id" not in chunk.metadata: + chunk.metadata["record_id"] = record_id + + if splitter is None: + chunks.append(chunk) + else: + sub_chunks = splitter.split([chunk.text], [chunk.metadata]) + chunks.extend(sub_chunks) + + logger.info( + "Built %d chunks from %d records (splitter=%s)", + len(chunks), + len(records), + type(splitter).__name__ if splitter else "None", + ) + return chunks diff --git a/tinysearch/rerankers/__init__.py b/tinysearch/rerankers/__init__.py new file mode 100644 index 0000000..296350e --- /dev/null +++ b/tinysearch/rerankers/__init__.py @@ -0,0 +1,9 @@ +""" +Rerankers for re-scoring retrieval results +""" + +from .cross_encoder import CrossEncoderReranker + +__all__ = [ + "CrossEncoderReranker", +] diff --git a/tinysearch/rerankers/cross_encoder.py b/tinysearch/rerankers/cross_encoder.py new file mode 100644 index 0000000..592bc8b --- /dev/null +++ b/tinysearch/rerankers/cross_encoder.py @@ -0,0 +1,126 @@ +""" +Cross-encoder reranker using FlagEmbedding BGE Reranker +""" +from typing import Any, Dict, List, Optional + +from tinysearch.base import Reranker +from tinysearch.logger import get_logger + +logger = get_logger("CrossEncoderReranker") + +try: + from FlagEmbedding import FlagReranker + FLAGEMBEDDING_AVAILABLE = True +except ImportError: + FLAGEMBEDDING_AVAILABLE = False + + +class CrossEncoderReranker(Reranker): + """ + Cross-encoder reranker using BAAI/bge-reranker-v2-m3 (or compatible model). + + Uses FlagEmbedding for GPU-accelerated cross-encoder inference. + The model is lazily loaded on first use. 
+ """ + + def __init__( + self, + model_name: str = "BAAI/bge-reranker-v2-m3", + device: Optional[str] = None, + batch_size: int = 64, + max_length: int = 512, + use_fp16: bool = True, + ): + """ + Args: + model_name: HuggingFace model name or local path + device: Device to use ("cuda", "cpu", or None for auto) + batch_size: Batch size for inference + max_length: Maximum sequence length + use_fp16: Whether to use FP16 for faster inference + """ + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.max_length = max_length + self.use_fp16 = use_fp16 + self._model = None + + def _ensure_model(self) -> None: + """Lazily load the reranker model""" + if self._model is not None: + return + + if not FLAGEMBEDDING_AVAILABLE: + raise ImportError( + "FlagEmbedding is required for CrossEncoderReranker. " + "Install with: pip install FlagEmbedding" + ) + + # Normalize device for FlagReranker + device = self.device + if device is None: + try: + import torch + device = 0 if torch.cuda.is_available() else "cpu" + except ImportError: + device = "cpu" + elif device == "cuda": + device = 0 + elif device.startswith("cuda:") and device[5:].isdigit(): + device = int(device[5:]) + + logger.info(f"Loading reranker model: {self.model_name} on device={device}") + + self._model = FlagReranker( + self.model_name, + devices=device, + batch_size=self.batch_size, + max_length=self.max_length, + use_fp16=self.use_fp16, + ) + + # Trigger full weight loading with dummy inference + try: + self._model.compute_score([["warmup", "warmup"]]) + except Exception: + pass + + logger.info("Reranker model loaded successfully") + + def rerank(self, query: str, candidates: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]: + """Re-rank candidates using cross-encoder scoring""" + if not candidates: + return [] + + self._ensure_model() + + # Build query-document pairs + pairs = [] + for item in candidates: + doc_text = item.get("text", "") + pairs.append([query, 
doc_text]) + + # Compute scores + scores = self._model.compute_score(pairs) + + # Handle single result (compute_score returns a float instead of list) + if isinstance(scores, (int, float)): + scores = [scores] + + # Attach rerank scores + reranked = [] + for item, score in zip(candidates, scores): + reranked.append({ + **item, + "rerank_score": float(score), + }) + + # Sort by rerank score descending + reranked.sort(key=lambda x: x["rerank_score"], reverse=True) + return reranked[:top_k] + + @classmethod + def is_available(cls) -> bool: + """Check if FlagEmbedding is installed""" + return FLAGEMBEDDING_AVAILABLE diff --git a/tinysearch/retrievers/__init__.py b/tinysearch/retrievers/__init__.py new file mode 100644 index 0000000..bf71115 --- /dev/null +++ b/tinysearch/retrievers/__init__.py @@ -0,0 +1,13 @@ +""" +Retrievers for text-level search +""" + +from .vector_retriever import VectorRetriever +from .bm25_retriever import BM25Retriever +from .substring_retriever import SubstringRetriever + +__all__ = [ + "VectorRetriever", + "BM25Retriever", + "SubstringRetriever", +] diff --git a/tinysearch/retrievers/bm25_retriever.py b/tinysearch/retrievers/bm25_retriever.py new file mode 100644 index 0000000..0c47e3b --- /dev/null +++ b/tinysearch/retrievers/bm25_retriever.py @@ -0,0 +1,179 @@ +""" +BM25 Retriever - Keyword-based search using bm25s + +Fast BM25 implementation with optional jieba Chinese tokenization. 
+""" +import json +import pickle +import re +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +from tinysearch.base import Retriever, TextChunk +from tinysearch.logger import get_logger + +logger = get_logger("BM25Retriever") + +try: + import bm25s + BM25S_AVAILABLE = True +except ImportError: + BM25S_AVAILABLE = False + +try: + import jieba + JIEBA_AVAILABLE = True +except ImportError: + JIEBA_AVAILABLE = False + + +def _default_tokenizer(text: str) -> List[str]: + """Default tokenizer: jieba if available, else whitespace split""" + text_lower = text.lower() + if JIEBA_AVAILABLE: + return list(jieba.cut(text_lower)) + return re.findall(r'[\w]+', text_lower) + + +class BM25Retriever(Retriever): + """ + BM25 keyword-based retriever using the bm25s library. + + Features: + - Fast indexing and retrieval via bm25s + - Configurable tokenizer (default: jieba for Chinese, fallback to whitespace) + - Persistent index storage + """ + + def __init__( + self, + tokenizer: Optional[Callable[[str], List[str]]] = None, + ): + """ + Args: + tokenizer: Custom tokenizer function. Takes a string, returns list of tokens. + Defaults to jieba (if available) or whitespace splitting. + """ + if not BM25S_AVAILABLE: + logger.warning( + "bm25s not installed. BM25 retrieval will not work. " + "Install with: pip install bm25s" + ) + self.tokenizer = tokenizer or _default_tokenizer + + # Internal state + self._index: Optional[Any] = None # bm25s.BM25 + self._chunks: List[TextChunk] = [] + self._corpus_tokens: List[List[str]] = [] + + def build(self, chunks: List[TextChunk]) -> None: + """Build BM25 index from text chunks""" + if not BM25S_AVAILABLE: + raise ImportError( + "bm25s is required for BM25Retriever. 
Install with: pip install bm25s"
+            )
+        if not chunks:
+            return
+
+        self._chunks = list(chunks)
+
+        # Tokenize all documents
+        self._corpus_tokens = [self.tokenizer(chunk.text) for chunk in chunks]
+
+        # Build bm25s index
+        self._index = bm25s.BM25()
+        self._index.index(self._corpus_tokens)
+
+    def retrieve(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]:
+        """
+        Retrieve documents using BM25 keyword matching.
+
+        Args:
+            query: Query string
+            top_k: Number of results to return
+            **kwargs:
+                candidate_ids: Optional Set[int] of chunk indices to restrict search to
+        """
+        if self._index is None:
+            return []
+
+        candidate_ids = kwargs.get("candidate_ids")
+
+        # Tokenize query
+        query_tokens = self.tokenizer(query)
+        if not query_tokens:
+            return []
+
+        # When pre-filtering, over-recall then filter by candidate_ids
+        if candidate_ids is not None:
+            effective_k = min(len(self._chunks), top_k * 3)
+        else:
+            effective_k = min(top_k, len(self._chunks))
+        if effective_k == 0:
+            return []
+
+        # Retrieve (no corpus → returns indices instead of documents)
+        result = self._index.retrieve(
+            [query_tokens], k=effective_k, return_as="tuple"
+        )
+        indices = result.documents[0]
+        scores = result.scores[0]
+
+        # Build result list
+        results = []
+        for idx, score in zip(indices, scores):
+            idx = int(idx)
+            if idx < 0 or idx >= len(self._chunks):
+                continue
+            if candidate_ids is not None and idx not in candidate_ids:
+                continue
+            chunk = self._chunks[idx]
+            results.append({
+                "text": chunk.text,
+                "metadata": chunk.metadata,
+                "score": float(score),
+                "retrieval_method": "bm25",
+            })
+            if len(results) >= top_k:
+                break
+
+        return results
+
+    def save(self, path: Union[str, Path]) -> None:
+        """Save BM25 index to disk"""
+        if self._index is None:
+            raise ValueError("No index to save. Call build() first.")
+
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        # Save bm25s index
+        self._index.save(str(path), corpus=self._corpus_tokens)
+
+        # Save chunks metadata
+        chunks_data = [(chunk.text, chunk.metadata) for chunk in self._chunks]
+        with open(path / "chunks.pkl", "wb") as f:
+            pickle.dump(chunks_data, f)
+
+    def load(self, path: Union[str, Path]) -> None:
+        """Load BM25 index from disk"""
+        if not BM25S_AVAILABLE:
+            raise ImportError(
+                "bm25s is required for BM25Retriever. Install with: pip install bm25s"
+            )
+
+        path = Path(path)
+        if not path.exists():
+            raise FileNotFoundError(f"BM25 index directory not found: {path}")
+
+        # Load bm25s index (don't load corpus - we manage chunks separately)
+        self._index = bm25s.BM25.load(str(path), load_corpus=False)
+
+        # Load chunks metadata
+        chunks_file = path / "chunks.pkl"
+        if chunks_file.exists():
+            with open(chunks_file, "rb") as f:
+                chunks_data = pickle.load(f)
+            self._chunks = [TextChunk(text, metadata) for text, metadata in chunks_data]
+        else:
+            self._chunks = []
diff --git a/tinysearch/retrievers/substring_retriever.py b/tinysearch/retrievers/substring_retriever.py
new file mode 100644
index 0000000..97dc730
--- /dev/null
+++ b/tinysearch/retrievers/substring_retriever.py
@@ -0,0 +1,159 @@
+"""
+Substring Retriever - Ctrl+F style exact match search
+
+Fast regex-based substring search. No external dependencies.
+"""
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+import json
+import pickle
+
+from tinysearch.base import Retriever, TextChunk
+
+
+class SubstringRetriever(Retriever):
+    """
+    Regex/substring retriever for exact match queries.
+
+    Operates on raw text with regex or plain substring matching.
+    Useful for finding exact phrases, codes, or patterns.
+    """
+
+    def __init__(self, is_regex: bool = False):
+        """
+        Args:
+            is_regex: If True, treat query as regex pattern.
+                If False, escape query for literal substring matching.
+        """
+        self.is_regex = is_regex
+        self._chunks: List[TextChunk] = []
+
+    def build(self, chunks: List[TextChunk]) -> None:
+        """Store chunks in memory for substring search"""
+        self._chunks = list(chunks)
+
+    def retrieve(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]:
+        """
+        Search chunks using regex/substring matching.
+
+        Args:
+            query: Query string or regex pattern
+            top_k: Number of results to return
+            **kwargs:
+                candidate_ids: Optional Set[int] of chunk indices to restrict search to
+        """
+        if not self._chunks or not query:
+            return []
+
+        candidate_ids = kwargs.get("candidate_ids")
+        results = []
+
+        try:
+            if self.is_regex:
+                pattern = re.compile(query, re.IGNORECASE)
+            else:
+                pattern = re.compile(re.escape(query), re.IGNORECASE)
+
+            for i, chunk in enumerate(self._chunks):
+                if candidate_ids is not None and i not in candidate_ids:
+                    continue
+                match = pattern.search(chunk.text)
+                if match:
+                    score = self._calculate_match_score(match, chunk.text)
+                    results.append({
+                        "text": chunk.text,
+                        "metadata": chunk.metadata,
+                        "score": score,
+                        "retrieval_method": "substring",
+                        "match_text": match.group(0),
+                    })
+
+                    # Collect extra then sort
+                    if len(results) >= top_k * 2:
+                        break
+
+        except re.error:
+            # Fallback to plain substring if regex fails
+            results = self._plain_substring_search(query, top_k * 2)
+
+        # Sort by score descending and take top_k
+        results.sort(key=lambda x: x["score"], reverse=True)
+        return results[:top_k]
+
+    def _calculate_match_score(self, match: re.Match, text: str) -> float:
+        """Calculate match quality score"""
+        score = 1.0
+
+        # Bonus for match at start
+        if match.start() == 0:
+            score += 2.0
+
+        # Bonus for longer matches (max +3)
+        match_len = len(match.group(0))
+        score += min(match_len / 10.0, 3.0)
+
+        # Small penalty for matching deep in long text
+        if len(text) > 500 and match.start() > 100:
+            score -= 0.5
+
+        return score
+
+    def _plain_substring_search(self, query: str, limit: int) -> List[Dict[str, Any]]:
+        """Fallback plain substring search when regex fails"""
+        results = []
+        query_lower = query.lower()
+
+        for chunk in self._chunks:
+            text_lower = chunk.text.lower()
+            if query_lower in text_lower:
+                pos = text_lower.find(query_lower)
+                score = 1.0
+                if pos == 0:
+                    score += 2.0
+                score += min(len(query) / 10.0, 3.0)
+
+                results.append({
+                    "text": chunk.text,
+                    "metadata": chunk.metadata,
+                    "score": score,
+                    "retrieval_method": "substring",
+                    "match_text": query[:50],
+                })
+
+                if len(results) >= limit:
+                    break
+
+        return results
+
+    def save(self, path: Union[str, Path]) -> None:
+        """Save chunks to disk"""
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        chunks_data = [(chunk.text, chunk.metadata) for chunk in self._chunks]
+        with open(path / "chunks.pkl", "wb") as f:
+            pickle.dump(chunks_data, f)
+
+        # Save config
+        config = {"is_regex": self.is_regex}
+        with open(path / "config.json", "w") as f:
+            json.dump(config, f)
+
+    def load(self, path: Union[str, Path]) -> None:
+        """Load chunks from disk"""
+        path = Path(path)
+        if not path.exists():
+            raise FileNotFoundError(f"Substring index directory not found: {path}")
+
+        chunks_file = path / "chunks.pkl"
+        if chunks_file.exists():
+            with open(chunks_file, "rb") as f:
+                chunks_data = pickle.load(f)
+            self._chunks = [TextChunk(text, metadata) for text, metadata in chunks_data]
+
+        config_file = path / "config.json"
+        if config_file.exists():
+            with open(config_file, "r") as f:
+                config = json.load(f)
+            self.is_regex = config.get("is_regex", False)
diff --git a/tinysearch/retrievers/vector_retriever.py b/tinysearch/retrievers/vector_retriever.py
new file mode 100644
index 0000000..2e77c5b
--- /dev/null
+++ b/tinysearch/retrievers/vector_retriever.py
@@ -0,0 +1,94 @@
+"""
+Vector retriever - wraps Embedder + VectorIndexer as a Retriever
+"""
+from typing import Any, Dict, List, Optional, Union
+import pathlib
+import numpy as np
+
+from tinysearch.base import Retriever, Embedder, VectorIndexer, TextChunk
+
+
+class VectorRetriever(Retriever):
+    """
+    Wraps an Embedder and VectorIndexer into the Retriever interface.
+
+    This is the bridge between the existing vector search pipeline and
+    the new Retriever abstraction, ensuring backward compatibility.
+    """
+
+    def __init__(
+        self,
+        embedder: Embedder,
+        indexer: VectorIndexer,
+        query_template: Optional[str] = None,
+    ):
+        """
+        Args:
+            embedder: Embedder to convert text to vectors
+            indexer: VectorIndexer for similarity search
+            query_template: Optional template for formatting queries (e.g. "请帮我查找：{query}")
+        """
+        self.embedder = embedder
+        self.indexer = indexer
+        self.query_template = query_template
+
+    def build(self, chunks: List[TextChunk]) -> None:
+        """Embed chunks and build the vector index"""
+        if not chunks:
+            return
+        texts = [chunk.text for chunk in chunks]
+        vectors = self.embedder.embed(texts)
+        self.indexer.build(vectors, chunks)
+
+    def retrieve(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]:
+        """
+        Embed query and search the vector index.
+
+        Args:
+            query: Query string
+            top_k: Number of results to return
+            **kwargs:
+                candidate_ids: Optional Set[int] of chunk indices to restrict search to
+        """
+        candidate_ids = kwargs.get("candidate_ids")
+
+        # Apply query template if configured
+        formatted_query = query
+        if self.query_template:
+            try:
+                formatted_query = self.query_template.format(query=query)
+            except (KeyError, ValueError):
+                formatted_query = query
+
+        # Embed the query
+        query_vectors = self.embedder.embed([formatted_query])
+        if not query_vectors:
+            return []
+        query_vector = query_vectors[0]
+
+        # Search (forward candidate_ids to indexer)
+        raw_results = self.indexer.search(query_vector, top_k, candidate_ids=candidate_ids)
+
+        # Normalize scores to [0, 1] and add retrieval_method
+        results = []
+        for r in raw_results:
+            result = {
+                "text": r["text"],
+                "metadata": r.get("metadata", {}),
+                "score": float(r.get("score", 0.0)),
+                "retrieval_method": "vector",
+            }
+            # Preserve embedding if present
+            if "embedding" in r:
+                result["embedding"] = r["embedding"]
+            results.append(result)
+
+        return results
+
+    def save(self, path: Union[str, pathlib.Path]) -> None:
+        """Save the vector index to disk"""
+        self.indexer.save(path)
+
+    def load(self, path: Union[str, pathlib.Path]) -> None:
+        """Load the vector index from disk"""
+        self.indexer.load(path)
diff --git a/tinysearch/utils/__init__.py b/tinysearch/utils/__init__.py
new file mode 100644
index 0000000..e8307ab
--- /dev/null
+++ b/tinysearch/utils/__init__.py
@@ -0,0 +1,6 @@
+"""
+Utility modules for TinySearch.
+"""
+from tinysearch.utils.file_discovery import iter_input_files, ADAPTER_EXTENSIONS
+
+__all__ = ["iter_input_files", "ADAPTER_EXTENSIONS"]
diff --git a/tinysearch/utils/file_discovery.py b/tinysearch/utils/file_discovery.py
new file mode 100644
index 0000000..b5c0d77
--- /dev/null
+++ b/tinysearch/utils/file_discovery.py
@@ -0,0 +1,55 @@
+"""
+Centralized file discovery for directory-based indexing.
+"""
+import logging
+from pathlib import Path
+from typing import Iterator, List, Optional
+
+logger = logging.getLogger(__name__)
+
+# Default file extensions per adapter type
+ADAPTER_EXTENSIONS = {
+    "text": [".txt", ".text", ".md", ".py", ".js", ".html", ".css", ".json"],
+    "pdf": [".pdf"],
+    "csv": [".csv"],
+    "markdown": [".md", ".markdown", ".mdown", ".mkdn"],
+    "json": [".json"],
+}
+
+
+def iter_input_files(
+    data_path: Path,
+    adapter_type: str = "text",
+    extensions: Optional[List[str]] = None,
+    recursive: bool = True,
+) -> Iterator[Path]:
+    """
+    Discover files under a path that a given adapter can process.
+
+    Args:
+        data_path: File or directory path
+        adapter_type: Adapter type name, used to look up default extensions
+        extensions: Custom extension list (overrides defaults)
+        recursive: Whether to recurse into subdirectories
+
+    Yields:
+        Path objects in sorted order for deterministic results
+    """
+    data_path = Path(data_path)
+
+    if not data_path.exists():
+        raise FileNotFoundError(f"Path not found: {data_path}")
+
+    if data_path.is_file():
+        yield data_path
+        return
+
+    # Directory: filter by extensions
+    allowed = set(
+        ext.lower() for ext in (extensions or ADAPTER_EXTENSIONS.get(adapter_type, []))
+    )
+
+    pattern = data_path.rglob("*") if recursive else data_path.iterdir()
+    for child in sorted(pattern):
+        if child.is_file() and (not allowed or child.suffix.lower() in allowed):
+            yield child