diff --git a/.gitignore b/.gitignore index cd8005815..274ab5e13 100644 --- a/.gitignore +++ b/.gitignore @@ -234,3 +234,4 @@ tests/data # Local working directory (personal scripts, docs, tools) local/ +local_docs/ diff --git a/CLAUDE.md b/CLAUDE.md index 09ab66439..9f22e9b93 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -48,6 +48,9 @@ index = SearchIndex(schema, redis_url="redis://localhost:6379") token.strip().strip(",").replace("“", "").replace("”", "").lower() ``` +### Protected Directories +**CRITICAL**: NEVER delete the `local_docs/` directory or any files within it. + ### Git Operations **CRITICAL**: NEVER use `git push` or attempt to push to remote repositories. The user will handle all git push operations. diff --git a/docs/api/cli.rst b/docs/api/cli.rst new file mode 100644 index 000000000..d13ff5c9b --- /dev/null +++ b/docs/api/cli.rst @@ -0,0 +1,614 @@ +********************** +Command Line Interface +********************** + +RedisVL provides a command line interface (CLI) called ``rvl`` for managing vector search indices. The CLI enables you to create, inspect, and delete indices directly from your terminal without writing Python code. + +Installation +============ + +The ``rvl`` command is included when you install RedisVL. + +.. code-block:: bash + + pip install redisvl + +Verify the installation by running: + +.. code-block:: bash + + rvl version + +Connection Configuration +======================== + +The CLI connects to Redis using the following resolution order: + +1. The ``REDIS_URL`` environment variable, if set +2. Explicit connection flags (``--host``, ``--port``, ``--url``) +3. Default values (``localhost:6379``) + +**Connection Flags** + +All commands that interact with Redis accept these optional flags: + +.. 
list-table:: + :widths: 20 15 50 15 + :header-rows: 1 + + * - Flag + - Type + - Description + - Default + * - ``-u``, ``--url`` + - string + - Full Redis URL (e.g., ``redis://localhost:6379``) + - None + * - ``--host`` + - string + - Redis server hostname + - ``localhost`` + * - ``-p``, ``--port`` + - integer + - Redis server port + - ``6379`` + * - ``--user`` + - string + - Redis username for authentication + - ``default`` + * - ``-a``, ``--password`` + - string + - Redis password for authentication + - Empty + * - ``--ssl`` + - flag + - Enable SSL/TLS encryption + - Disabled + +**Examples** + +Connect using environment variable: + +.. code-block:: bash + + export REDIS_URL="redis://localhost:6379" + rvl index listall + +Connect with explicit host and port: + +.. code-block:: bash + + rvl index listall --host myredis.example.com --port 6380 + +Connect with authentication and SSL: + +.. code-block:: bash + + rvl index listall --user admin --password secret --ssl + +Getting Help +============ + +All commands support the ``-h`` and ``--help`` flags to display usage information. + +.. list-table:: + :widths: 25 75 + :header-rows: 1 + + * - Flag + - Description + * - ``-h``, ``--help`` + - Display usage information for the command + +**Examples** + +.. code-block:: bash + + # Display top-level help + rvl --help + + # Display help for a command group + rvl index --help + + # Display help for a specific subcommand + rvl index create --help + +Running ``rvl`` without any arguments also displays the top-level help message. + +.. tip:: + + For a hands-on tutorial with practical examples, see the :doc:`/user_guide/cli`. + +Commands +======== + +rvl version +----------- + +Display the installed RedisVL version. + +**Syntax** + +.. code-block:: bash + + rvl version [OPTIONS] + +**Options** + +.. 
list-table:: + :widths: 25 75 + :header-rows: 1 + + * - Option + - Description + * - ``-s``, ``--short`` + - Print only the version number without additional formatting + +**Examples** + +.. code-block:: bash + + # Full version output + rvl version + + # Version number only + rvl version --short + +rvl index +--------- + +Manage vector search indices. This command group provides subcommands for creating, inspecting, listing, and removing indices. + +**Syntax** + +.. code-block:: bash + + rvl index [OPTIONS] + +**Subcommands** + +.. list-table:: + :widths: 15 85 + :header-rows: 1 + + * - Subcommand + - Description + * - ``create`` + - Create a new index from a YAML schema file + * - ``info`` + - Display detailed information about an index + * - ``listall`` + - List all existing indices in the Redis instance + * - ``delete`` + - Remove an index while preserving the underlying data + * - ``destroy`` + - Remove an index and delete all associated data + +rvl index create +^^^^^^^^^^^^^^^^ + +Create a new vector search index from a YAML schema definition. + +**Syntax** + +.. code-block:: bash + + rvl index create -s <schema_file> [CONNECTION_OPTIONS] + +**Required Options** + +.. list-table:: + :widths: 25 75 + :header-rows: 1 + + * - Option + - Description + * - ``-s``, ``--schema`` + - Path to the YAML schema file defining the index structure + +**Example** + +.. code-block:: bash + + rvl index create -s schema.yaml + +**Schema File Format** + +The schema file must be valid YAML with the following structure: + +.. code-block:: yaml + + version: '0.1.0' + + index: + name: my_index + prefix: doc + storage_type: hash + + fields: + - name: content + type: text + - name: embedding + type: vector + attrs: + dims: 768 + algorithm: hnsw + distance_metric: cosine + +rvl index info +^^^^^^^^^^^^^^ + +Display detailed information about an existing index, including field definitions and index options. + +**Syntax** + +.. 
code-block:: bash + + rvl index info (-i <index_name> | -s <schema_file>) [OPTIONS] + +**Options** + +.. list-table:: + :widths: 25 75 + :header-rows: 1 + + * - Option + - Description + * - ``-i``, ``--index`` + - Name of the index to inspect + * - ``-s``, ``--schema`` + - Path to the schema file (alternative to specifying index name) + +**Example** + +.. code-block:: bash + + rvl index info -i my_index + +**Output** + +The command displays two tables: + +1. **Index Information** containing the index name, storage type, key prefixes, index options, and indexing status +2. **Index Fields** listing each field with its name, attribute, type, and any additional field options + +rvl index listall +^^^^^^^^^^^^^^^^^ + +List all vector search indices in the connected Redis instance. + +**Syntax** + +.. code-block:: bash + + rvl index listall [CONNECTION_OPTIONS] + +**Example** + +.. code-block:: bash + + rvl index listall + +**Output** + +Returns a numbered list of all index names: + +.. code-block:: text + + Indices: + 1. products_index + 2. documents_index + 3. embeddings_index + +rvl index delete +^^^^^^^^^^^^^^^^ + +Remove an index from Redis while preserving the underlying data. Use this when you want to rebuild an index with a different schema without losing your data. + +**Syntax** + +.. code-block:: bash + + rvl index delete (-i <index_name> | -s <schema_file>) [CONNECTION_OPTIONS] + +**Options** + +.. list-table:: + :widths: 25 75 + :header-rows: 1 + + * - Option + - Description + * - ``-i``, ``--index`` + - Name of the index to delete + * - ``-s``, ``--schema`` + - Path to the schema file (alternative to specifying index name) + +**Example** + +.. code-block:: bash + + rvl index delete -i my_index + +rvl index destroy +^^^^^^^^^^^^^^^^^ + +Remove an index and permanently delete all associated data from Redis. This operation cannot be undone. + +**Syntax** + +.. code-block:: bash + + rvl index destroy (-i <index_name> | -s <schema_file>) [CONNECTION_OPTIONS] + +**Options** + +.. 
list-table:: + :widths: 25 75 + :header-rows: 1 + + * - Option + - Description + * - ``-i``, ``--index`` + - Name of the index to destroy + * - ``-s``, ``--schema`` + - Path to the schema file (alternative to specifying index name) + +**Example** + +.. code-block:: bash + + rvl index destroy -i my_index + +.. warning:: + + This command permanently deletes both the index and all documents stored with the index prefix. Ensure you have backups before running this command. + +rvl stats +--------- + +Display statistics about an existing index, including document counts, memory usage, and indexing performance metrics. + +**Syntax** + +.. code-block:: bash + + rvl stats (-i <index_name> | -s <schema_file>) [OPTIONS] + +**Options** + +.. list-table:: + :widths: 25 75 + :header-rows: 1 + + * - Option + - Description + * - ``-i``, ``--index`` + - Name of the index to query + * - ``-s``, ``--schema`` + - Path to the schema file (alternative to specifying index name) + +**Example** + +.. code-block:: bash + + rvl stats -i my_index + +**Statistics Reference** + +The command returns the following metrics: + +.. 
list-table:: + :widths: 35 65 + :header-rows: 1 + + * - Metric + - Description + * - ``num_docs`` + - Total number of indexed documents + * - ``num_terms`` + - Number of distinct terms in text fields + * - ``max_doc_id`` + - Highest internal document ID + * - ``num_records`` + - Total number of index records + * - ``percent_indexed`` + - Percentage of documents fully indexed + * - ``hash_indexing_failures`` + - Number of documents that failed to index + * - ``number_of_uses`` + - Number of times the index has been queried + * - ``bytes_per_record_avg`` + - Average bytes per index record + * - ``doc_table_size_mb`` + - Document table size in megabytes + * - ``inverted_sz_mb`` + - Inverted index size in megabytes + * - ``key_table_size_mb`` + - Key table size in megabytes + * - ``offset_bits_per_record_avg`` + - Average offset bits per record + * - ``offset_vectors_sz_mb`` + - Offset vectors size in megabytes + * - ``offsets_per_term_avg`` + - Average offsets per term + * - ``records_per_doc_avg`` + - Average records per document + * - ``sortable_values_size_mb`` + - Sortable values size in megabytes + * - ``total_indexing_time`` + - Total time spent indexing in milliseconds + * - ``total_inverted_index_blocks`` + - Number of inverted index blocks + * - ``vector_index_sz_mb`` + - Vector index size in megabytes + +rvl migrate +----------- + +Manage document-preserving index migrations. This command group provides subcommands for planning, executing, and validating schema migrations that preserve existing data. + +**Syntax** + +.. code-block:: bash + + rvl migrate [OPTIONS] + +**Subcommands** + +.. 
list-table:: + :widths: 20 80 + :header-rows: 1 + + * - Subcommand + - Description + * - ``helper`` + - Show migration guidance and supported capabilities + * - ``list`` + - List all available indexes + * - ``plan`` + - Generate a migration plan from a schema patch or target schema + * - ``wizard`` + - Interactively build a migration plan and schema patch + * - ``apply`` + - Execute a reviewed drop/recreate migration plan + * - ``estimate`` + - Estimate disk space required for a migration (dry-run) + * - ``validate`` + - Validate a completed migration against the live index + * - ``batch-plan`` + - Generate a batch migration plan for multiple indexes + * - ``batch-apply`` + - Execute a batch migration plan with checkpointing + * - ``batch-resume`` + - Resume an interrupted batch migration + * - ``batch-status`` + - Show status of an in-progress or completed batch migration + +rvl migrate plan +^^^^^^^^^^^^^^^^ + +Generate a migration plan for a document-preserving drop/recreate migration. + +**Syntax** + +.. code-block:: bash + + rvl migrate plan --index <index_name> (--schema-patch <patch_file> | --target-schema <schema_file>) [OPTIONS] + +**Required Options** + +.. list-table:: + :widths: 30 70 + :header-rows: 1 + + * - Option + - Description + * - ``--index``, ``-i`` + - Name of the source index to migrate + * - ``--schema-patch`` + - Path to a YAML schema patch file (mutually exclusive with ``--target-schema``) + * - ``--target-schema`` + - Path to a full target schema YAML file (mutually exclusive with ``--schema-patch``) + +**Optional Options** + +.. list-table:: + :widths: 30 70 + :header-rows: 1 + + * - Option + - Description + * - ``--plan-out`` + - Output path for the migration plan YAML (default: ``migration_plan.yaml``) + +**Example** + +.. code-block:: bash + + rvl migrate plan -i my_index --schema-patch changes.yaml --plan-out plan.yaml + +rvl migrate apply +^^^^^^^^^^^^^^^^^ + +Execute a reviewed drop/recreate migration plan. 
Use ``--async`` for large migrations involving vector quantization. + +**Syntax** + +.. code-block:: bash + + rvl migrate apply --plan <plan_file> [OPTIONS] + +**Required Options** + +.. list-table:: + :widths: 30 70 + :header-rows: 1 + + * - Option + - Description + * - ``--plan`` + - Path to the migration plan YAML file + +**Optional Options** + +.. list-table:: + :widths: 30 70 + :header-rows: 1 + + * - Option + - Description + * - ``--async`` + - Run migration asynchronously (recommended for large quantization jobs) + * - ``--query-check-file`` + - Path to a YAML file with post-migration query checks + * - ``--resume`` + - Path to a checkpoint file for crash-safe recovery + +**Example** + +.. code-block:: bash + + rvl migrate apply --plan plan.yaml + rvl migrate apply --plan plan.yaml --async --resume checkpoint.yaml + +rvl migrate wizard +^^^^^^^^^^^^^^^^^^ + +Interactively build a schema patch and migration plan through a guided wizard. + +**Syntax** + +.. code-block:: bash + + rvl migrate wizard [--index <index_name>] [OPTIONS] + +**Example** + +.. code-block:: bash + + rvl migrate wizard -i my_index --plan-out plan.yaml + +Exit Codes +========== + +The CLI returns the following exit codes: + +.. 
list-table:: + :widths: 15 85 + :header-rows: 1 + + * - Code + - Description + * - ``0`` + - Command completed successfully + * - ``1`` + - Command failed due to missing required arguments or invalid input + +Related Resources +================= + +- :doc:`/user_guide/cli` for a tutorial-style walkthrough +- :doc:`schema` for YAML schema format details +- :doc:`searchindex` for the Python ``SearchIndex`` API + diff --git a/docs/concepts/field-attributes.md b/docs/concepts/field-attributes.md index c7764a4a7..73b0d4cf1 100644 --- a/docs/concepts/field-attributes.md +++ b/docs/concepts/field-attributes.md @@ -267,7 +267,7 @@ Key vector attributes: - `dims`: Vector dimensionality (required) - `algorithm`: `flat`, `hnsw`, or `svs-vamana` - `distance_metric`: `COSINE`, `L2`, or `IP` -- `datatype`: `float16`, `float32`, `float64`, or `bfloat16` +- `datatype`: Vector precision (see table below) - `index_missing`: Allow searching for documents without vectors ```yaml @@ -281,6 +281,48 @@ Key vector attributes: index_missing: true # Handle documents without embeddings ``` +### Vector Datatypes + +The `datatype` attribute controls how vector components are stored. Smaller datatypes reduce memory usage but may affect precision. + +| Datatype | Bits | Memory (768 dims) | Use Case | +|----------|------|-------------------|----------| +| `float32` | 32 | 3 KB | Default. Best precision for most applications. | +| `float16` | 16 | 1.5 KB | Good balance of memory and precision. Recommended for large-scale deployments. | +| `bfloat16` | 16 | 1.5 KB | Better dynamic range than float16. Useful when embeddings have large value ranges. | +| `float64` | 64 | 6 KB | Maximum precision. Rarely needed. | +| `int8` | 8 | 768 B | Integer quantization. Significant memory savings with some precision loss. | +| `uint8` | 8 | 768 B | Unsigned integer quantization. For embeddings with non-negative values. 
| + +**Algorithm Compatibility:** + +| Datatype | FLAT | HNSW | SVS-VAMANA | +|----------|------|------|------------| +| `float32` | Yes | Yes | Yes | +| `float16` | Yes | Yes | Yes | +| `bfloat16` | Yes | Yes | No | +| `float64` | Yes | Yes | No | +| `int8` | Yes | Yes | No | +| `uint8` | Yes | Yes | No | + +**Choosing a Datatype:** + +- **Start with `float32`** unless you have memory constraints +- **Use `float16`** for production systems with millions of vectors (50% memory savings, minimal precision loss) +- **Use `int8`/`uint8`** only after benchmarking recall on your specific dataset +- **SVS-VAMANA users**: Must use `float16` or `float32` + +**Quantization with the Migrator:** + +You can change vector datatypes on existing indexes using the migration wizard: + +```bash +rvl migrate wizard --index my_index --url redis://localhost:6379 +# Select "Update field" > choose vector field > change datatype +``` + +The migrator automatically re-encodes stored vectors to the new precision. See {doc}`/user_guide/how_to_guides/migrate-indexes` for details. + ## Redis-Specific Subtleties ### Modifier Ordering @@ -304,6 +346,54 @@ Not all attributes work with all field types: | `unf` | ✓ | ✗ | ✓ | ✗ | ✗ | | `withsuffixtrie` | ✓ | ✓ | ✗ | ✗ | ✗ | +### Migration Support + +The migration wizard (`rvl migrate wizard`) supports updating field attributes on existing indexes. The table below shows which attributes can be updated via the wizard vs requiring manual schema patch editing. 
+ +**Wizard Prompts:** + +| Attribute | Text | Tag | Numeric | Geo | Vector | +|-----------|------|-----|---------|-----|--------| +| `sortable` | Wizard | Wizard | Wizard | Wizard | N/A | +| `index_missing` | Wizard | Wizard | Wizard | Wizard | N/A | +| `index_empty` | Wizard | Wizard | N/A | N/A | N/A | +| `no_index` | Wizard | Wizard | Wizard | Wizard | N/A | +| `unf` | Wizard* | N/A | Wizard* | N/A | N/A | +| `separator` | N/A | Wizard | N/A | N/A | N/A | +| `case_sensitive` | N/A | Wizard | N/A | N/A | N/A | +| `no_stem` | Wizard | N/A | N/A | N/A | N/A | +| `weight` | Wizard | N/A | N/A | N/A | N/A | +| `algorithm` | N/A | N/A | N/A | N/A | Wizard | +| `datatype` | N/A | N/A | N/A | N/A | Wizard | +| `distance_metric` | N/A | N/A | N/A | N/A | Wizard | +| `m`, `ef_construction` | N/A | N/A | N/A | N/A | Wizard | + +*\* `unf` is only prompted when `sortable` is enabled.* + +**Manual Schema Patch Required:** + +| Attribute | Notes | +|-----------|-------| +| `withsuffixtrie` | Suffix/contains search optimization | + +*Note: `phonetic_matcher` is supported by the wizard for text fields.* + +**Example manual patch** for adding `index_missing` to a field: + +```yaml +# schema_patch.yaml +version: 1 +changes: + update_fields: + - name: category + attrs: + index_missing: true +``` + +```bash +rvl migrate plan --index my_index --schema-patch schema_patch.yaml +``` + ### JSON Path for Nested Fields When using JSON storage, use the `path` attribute to index nested fields: diff --git a/docs/concepts/index-migrations.md b/docs/concepts/index-migrations.md new file mode 100644 index 000000000..fb12d75ca --- /dev/null +++ b/docs/concepts/index-migrations.md @@ -0,0 +1,255 @@ +--- +myst: + html_meta: + "description lang=en": | + Learn how RedisVL index migrations work and which schema changes are supported. +--- + +# Index Migrations + +Redis Search indexes are immutable. To change an index schema, you must drop the existing index and create a new one. 
RedisVL provides a migration workflow that automates this process while preserving your data. + +This page explains how migrations work and which changes are supported. For step by step instructions, see the [migration guide](../user_guide/how_to_guides/migrate-indexes.md). + +## Supported and blocked changes + +The migrator classifies schema changes into two categories: + +| Change | Status | +|--------|--------| +| Add or remove a field | Supported | +| Rename a field | Supported | +| Change field options (sortable, separator) | Supported | +| Change key prefix | Supported | +| Rename the index | Supported | +| Change vector algorithm (FLAT, HNSW, SVS-VAMANA) | Supported | +| Change distance metric (COSINE, L2, IP) | Supported | +| Tune algorithm parameters (M, EF_CONSTRUCTION) | Supported | +| Quantize vectors (float32 to float16/bfloat16/int8/uint8) | Supported | +| Change vector dimensions | Blocked | +| Change storage type (hash to JSON) | Blocked | +| Add a new vector field | Blocked | + +**Note:** INT8 and UINT8 vector datatypes require Redis 8.0+. SVS-VAMANA algorithm requires Redis 8.2+ and Intel AVX-512 hardware. + +**Supported** changes can be applied automatically using `rvl migrate`. The migrator handles the index rebuild and any necessary data transformations. + +**Blocked** changes require manual intervention because they involve incompatible data formats or missing data. The migrator will reject these changes and explain why. + +## How the migrator works + +The migrator uses a plan first workflow: + +1. **Plan**: Capture the current schema, classify your changes, and generate a migration plan +2. **Review**: Inspect the plan before making any changes +3. **Apply**: Drop the index, transform data if needed, and recreate with the new schema +4. **Validate**: Verify the result matches expectations + +This separation ensures you always know what will happen before any changes are made. 
+ +## Migration mode: drop_recreate + +The `drop_recreate` mode rebuilds the index in place while preserving your documents. + +The process: + +1. Drop only the index structure (documents remain in Redis) +2. For datatype changes, re-encode vectors to the target precision +3. Recreate the index with the new schema +4. Wait for Redis to re-index the existing documents +5. Validate the result + +**Tradeoff**: The index is unavailable during the rebuild. Review the migration plan carefully before applying. + +## Index only vs document dependent changes + +Schema changes fall into two categories based on whether they require modifying stored data. + +**Index only changes** affect how Redis Search indexes data, not the data itself: + +- Algorithm changes: The stored vector bytes are identical. Only the index structure differs. +- Distance metric changes: Same vectors, different similarity calculation. +- Adding or removing fields: The documents already contain the data. The index just starts or stops indexing it. + +These changes complete quickly because they only require rebuilding the index. + +**Document dependent changes** require modifying the stored data: + +- Datatype changes (float32 to float16): Stored vector bytes must be re-encoded. +- Field renames: Stored field names must be updated in every document. +- Dimension changes: Vectors must be re-embedded with a different model. + +The migrator handles datatype changes and field renames automatically. Dimension changes are blocked because they require re-embedding with a different model (application level logic). + +## Vector quantization + +Changing vector precision from float32 to float16 reduces memory usage at the cost of slight precision loss. The migrator handles this automatically by: + +1. Reading all vectors from Redis +2. Converting to the target precision +3. Writing updated vectors back +4. 
Recreating the index with the new schema + +Typical reductions: + +| Metric | Value | +|--------|-------| +| Index size reduction | ~50% | +| Memory reduction | ~35% | + +Quantization time is proportional to document count. Plan for downtime accordingly. + +## Why some changes are blocked + +### Vector dimension changes + +Vector dimensions are determined by your embedding model. A 384 dimensional vector from one model is mathematically incompatible with a 768 dimensional index expecting vectors from a different model. There is no way to resize an embedding. + +**Resolution**: Re-embed your documents using the new model and load them into a new index. + +### Storage type changes + +Hash and JSON have different data layouts. Hash stores flat key value pairs. JSON stores nested structures. Converting between them requires understanding your schema and restructuring each document. + +**Resolution**: Export your data, transform it to the new format, and reload into a new index. + +### Adding a vector field + +Adding a vector field means all existing documents need vectors for that field. The migrator cannot generate these vectors because it does not know which embedding model to use or what content to embed. + +**Resolution**: Add vectors to your documents using your application, then run the migration. + +## Downtime considerations + +With `drop_recreate`, your index is unavailable between the drop and when re-indexing completes. + +**CRITICAL**: Downtime requires both reads AND writes to be paused: + +| Requirement | Reason | +|-------------|--------| +| **Pause reads** | Index is unavailable during migration | +| **Pause writes** | Redis updates indexes synchronously. 
Writes during migration may conflict with vector re-encoding or be missed | + +Plan for: + +- Search unavailability during the migration window +- Partial results while indexing is in progress +- Resource usage from the re-indexing process +- Quantization time if changing vector datatypes + +The duration depends on document count, field count, and vector dimensions. For large indexes, consider running migrations during low traffic periods. + +## Sync vs async execution + +The migrator provides both synchronous and asynchronous execution modes. + +### What becomes async and what stays sync + +The migration workflow has distinct phases. Here is what each mode affects: + +| Phase | Sync mode | Async mode | Notes | +|-------|-----------|------------|-------| +| **Plan generation** | `MigrationPlanner.create_plan()` | `AsyncMigrationPlanner.create_plan()` | Reads index metadata from Redis | +| **Schema snapshot** | Sync Redis calls | Async Redis calls | Single `FT.INFO` command | +| **Enumeration** | FT.AGGREGATE (or SCAN fallback) | FT.AGGREGATE (or SCAN fallback) | Before drop, only if quantization needed | +| **Drop index** | `index.delete()` | `await index.delete()` | Single `FT.DROPINDEX` command | +| **Quantization** | Sequential HGET + HSET | Sequential HGET + batched HSET | Uses pre-enumerated keys | +| **Create index** | `index.create()` | `await index.create()` | Single `FT.CREATE` command | +| **Readiness polling** | `time.sleep()` loop | `asyncio.sleep()` loop | Polls `FT.INFO` until indexed | +| **Validation** | Sync Redis calls | Async Redis calls | Schema and doc count checks | +| **CLI interaction** | Always sync | Always sync | User prompts, file I/O | +| **YAML read/write** | Always sync | Always sync | Local filesystem only | + +### When to use sync (default) + +Sync execution is simpler and sufficient for most migrations: + +- Small to medium indexes (under 100K documents) +- Index-only changes (algorithm, distance metric, field options) +- 
Interactive CLI usage where blocking is acceptable + +For migrations without quantization, the Redis operations are fast single commands. Sync mode adds no meaningful overhead. + +### When to use async + +Async execution (`--async` flag) provides benefits in specific scenarios: + +**Large quantization jobs (1M+ vectors)** + +Converting float32 to float16 requires reading every vector, converting it, and writing it back. The async executor: + +- Enumerates documents using `FT.AGGREGATE WITHCURSOR` for index-specific enumeration (falls back to `SCAN` only if indexing failures exist) +- Pipelines `HSET` operations in batches (100-1000 operations per pipeline is optimal for Redis) +- Yields to the event loop between batches so other tasks can proceed + +**Large keyspaces (40M+ keys)** + +When your Redis instance has many keys and the index has indexing failures (requiring SCAN fallback), async mode yields between batches. + +**Async application integration** + +If your application uses asyncio, you can integrate migration directly: + +```python +import asyncio +from redisvl.migration import AsyncMigrationPlanner, AsyncMigrationExecutor + +async def migrate(): + planner = AsyncMigrationPlanner() + plan = await planner.create_plan("myindex", redis_url="redis://localhost:6379") + + executor = AsyncMigrationExecutor() + report = await executor.apply(plan, redis_url="redis://localhost:6379") + +asyncio.run(migrate()) +``` + +### Why async helps with quantization + +The migrator uses an optimized enumeration strategy: + +1. **Index-based enumeration**: Uses `FT.AGGREGATE WITHCURSOR` to enumerate only indexed documents (not the entire keyspace) +2. **Fallback for safety**: If the index has indexing failures (`hash_indexing_failures > 0`), falls back to `SCAN` to ensure completeness +3. 
**Enumerate before drop**: Captures the document list while the index still exists, then drops and quantizes + +This optimization provides 10-1000x speedup for sparse indexes (where only a small fraction of prefix-matching keys are indexed). + +**Sync quantization:** +``` +enumerate keys (FT.AGGREGATE or SCAN) -> store list +for each batch of 500 keys: + for each key: + HGET field (blocks) + convert array + pipeline.HSET(field, new_bytes) + pipeline.execute() (blocks) +``` + +**Async quantization:** +``` +enumerate keys (FT.AGGREGATE or SCAN) -> store list +for each batch of 500 keys: + for each key: + await HGET field (yields) + convert array + pipeline.HSET(field, new_bytes) + await pipeline.execute() (yields) +``` + +Each `await` is a yield point where other coroutines can run. For millions of vectors, this prevents your application from freezing. + +### What async does NOT improve + +Async execution does not reduce: + +- **Total migration time**: Same work, different scheduling +- **Redis server load**: Same commands execute on the server +- **Downtime window**: Index remains unavailable during rebuild +- **Network round trips**: Same number of Redis calls + +The benefit is application responsiveness, not faster migration. + +## Learn more + +- [Migration guide](../user_guide/how_to_guides/migrate-indexes.md): Step by step instructions +- [Search and indexing](search-and-indexing.md): How Redis Search indexes work diff --git a/docs/concepts/index.md b/docs/concepts/index.md index 0e522b1a2..02f4d8b01 100644 --- a/docs/concepts/index.md +++ b/docs/concepts/index.md @@ -26,6 +26,13 @@ How RedisVL components connect: schemas, indexes, queries, and extensions. Schemas, fields, documents, storage types, and query patterns. ::: +:::{grid-item-card} 🔄 Index Migrations +:link: index-migrations +:link-type: doc + +How RedisVL handles migration planning, rebuilds, and future shadow migration. 
+::: + :::{grid-item-card} 🏷️ Field Attributes :link: field-attributes :link-type: doc @@ -62,6 +69,7 @@ Pre-built patterns: caching, message history, and semantic routing. architecture search-and-indexing +index-migrations field-attributes queries utilities diff --git a/docs/concepts/search-and-indexing.md b/docs/concepts/search-and-indexing.md index b4fe69569..5312d7dfb 100644 --- a/docs/concepts/search-and-indexing.md +++ b/docs/concepts/search-and-indexing.md @@ -106,9 +106,14 @@ To change a schema, you create a new index with the updated configuration, reind Planning your schema carefully upfront reduces the need for migrations, but the capability exists when requirements evolve. ---- +RedisVL now includes a dedicated migration workflow for this lifecycle: + +- `drop_recreate` for document-preserving rebuilds, including vector quantization (`float32` → `float16`) -**Related concepts:** {doc}`field-attributes` explains how to configure field options like `sortable` and `index_missing`. {doc}`queries` covers the different query types available. +That means schema evolution is no longer only a manual operational pattern. It is also a product surface in RedisVL with a planner, CLI, and validation artifacts. + +--- -**Learn more:** {doc}`/user_guide/01_getting_started` walks through building your first index. {doc}`/user_guide/05_hash_vs_json` compares storage options in depth. {doc}`/user_guide/02_complex_filtering` covers query composition. +**Related concepts:** {doc}`field-attributes` explains how to configure field options like `sortable` and `index_missing`. {doc}`queries` covers the different query types available. {doc}`index-migrations` explains migration modes, supported changes, and architecture. +**Learn more:** {doc}`/user_guide/01_getting_started` walks through building your first index. {doc}`/user_guide/05_hash_vs_json` compares storage options in depth. {doc}`/user_guide/02_complex_filtering` covers query composition. 
{doc}`/user_guide/how_to_guides/migrate-indexes` shows how to use the migration CLI in practice. diff --git a/docs/user_guide/cli.ipynb b/docs/user_guide/cli.ipynb index ba9d645a3..060b3bce7 100644 --- a/docs/user_guide/cli.ipynb +++ b/docs/user_guide/cli.ipynb @@ -6,7 +6,7 @@ "source": [ "# The RedisVL CLI\n", "\n", - "RedisVL is a Python library with a dedicated CLI to help load and create vector search indices within Redis.\n", + "RedisVL is a Python library with a dedicated CLI to help load, inspect, migrate, and create vector search indices within Redis.\n", "\n", "This notebook will walk through how to use the Redis Vector Library CLI (``rvl``).\n", "\n", @@ -50,7 +50,16 @@ "| `rvl index` | `delete --index` or `-i ` | remove the specified index, leaving the data still in Redis|\n", "| `rvl index` | `destroy --index` or `-i `| remove the specified index, as well as the associated data|\n", "| `rvl stats` | `--index` or `-i ` | display the index statistics, including number of docs, average bytes per record, indexing time, etc|\n", - "| `rvl stats` | `--schema` or `-s ` | display the index statistics of a schema defined in . The index must have already been created within Redis|" + "| `rvl stats` | `--schema` or `-s ` | display the index statistics of a schema defined in . 
The index must have already been created within Redis|\n", + "| `rvl migrate` | `helper` or `list` | show migration guidance and list indexes available for migration|\n", + "| `rvl migrate` | `wizard` | interactively build a migration plan and schema patch|\n", + "| `rvl migrate` | `plan` | generate `migration_plan.yaml` from a patch or target schema|\n", + "| `rvl migrate` | `apply` | execute a reviewed `drop_recreate` migration|\n", + "| `rvl migrate` | `validate` | validate a completed migration and emit report artifacts|\n", + "| `rvl migrate` | `batch-plan` | create a batch migration plan for multiple indexes|\n", + "| `rvl migrate` | `batch-apply` | execute a batch migration|\n", + "| `rvl migrate` | `batch-resume` | resume an interrupted batch migration|\n", + "| `rvl migrate` | `batch-status` | check batch migration progress|" ] }, { @@ -355,6 +364,35 @@ "!rvl stats -i vectorizers" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Migrate\n", + "\n", + "The ``rvl migrate`` command provides a full workflow for changing index schemas without losing data. 
Common use cases include vector quantization (float32 → float16), algorithm changes (HNSW → FLAT), and adding/removing fields.\n", + "\n", + "```bash\n", + "# List available indexes\n", + "rvl migrate list --url redis://localhost:6379\n", + "\n", + "# Build a migration plan interactively\n", + "rvl migrate wizard --index myindex --url redis://localhost:6379\n", + "\n", + "# Or generate from a schema patch file\n", + "rvl migrate plan --index myindex --schema-patch patch.yaml --url redis://localhost:6379\n", + "\n", + "# Apply with backup and multi-worker quantization\n", + "rvl migrate apply --plan migration_plan.yaml --url redis://localhost:6379 \\\n", + " --backup-dir /tmp/backups --workers 4 --batch-size 500\n", + "\n", + "# Validate the result\n", + "rvl migrate validate --plan migration_plan.yaml --url redis://localhost:6379\n", + "```\n", + "\n", + "See the [Migration Guide](how_to_guides/migrate-indexes.md) for detailed usage, performance tuning, and examples." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -374,15 +412,6 @@ }, { "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Choosing your Redis instance\n", - "By default rvl first checks if you have `REDIS_URL` environment variable defined and tries to connect to that. If not, it then falls back to `localhost:6379`, unless you pass the `--host` or `--port` arguments" - ] - }, - { - "cell_type": "code", - "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2026-02-16T15:58:08.651332Z", @@ -391,33 +420,23 @@ "shell.execute_reply": "2026-02-16T15:58:10.874011Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Indices:\n", - "1. vectorizers\n" - ] - } - ], "source": [ - "# specify your Redis instance to connect to\n", - "!rvl index listall --host localhost --port 6379" + "### Choosing your Redis instance\n", + "By default rvl first checks if you have `REDIS_URL` environment variable defined and tries to connect to that. 
If not, it then falls back to `localhost:6379`, unless you pass the `--host` or `--port` arguments" ] }, { - "cell_type": "markdown", + "cell_type": "code", "metadata": {}, "source": [ - "### Using SSL encryption\n", - "If your Redis instance is configured to use SSL encryption then set the `--ssl` flag.\n", - "You can similarly specify the username and password to construct the full Redis URL" - ] + "# specify your Redis instance to connect to\n", + "!rvl index listall --host localhost --port 6379" + ], + "outputs": [], + "execution_count": null }, { - "cell_type": "code", - "execution_count": 12, + "cell_type": "markdown", "metadata": { "execution": { "iopub.execute_input": "2026-02-16T15:58:10.876537Z", @@ -426,10 +445,10 @@ "shell.execute_reply": "2026-02-16T15:58:13.099303Z" } }, - "outputs": [], "source": [ - "# connect to rediss://jane_doe:password123@localhost:6379\n", - "!rvl index listall --user jane_doe -a password123 --ssl" + "### Using SSL encryption\n", + "If your Redis instance is configured to use SSL encryption then set the `--ssl` flag.\n", + "You can similarly specify the username and password to construct the full Redis URL" ] }, { @@ -453,8 +472,16 @@ } ], "source": [ - "!rvl index destroy -i vectorizers" + "# connect to rediss://jane_doe:password123@localhost:6379\n", + "!rvl index listall --user jane_doe -a password123 --ssl" ] + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "!rvl index destroy -i vectorizers" } ], "metadata": { diff --git a/docs/user_guide/how_to_guides/index.md b/docs/user_guide/how_to_guides/index.md index c03d705da..f6511d54c 100644 --- a/docs/user_guide/how_to_guides/index.md +++ b/docs/user_guide/how_to_guides/index.md @@ -34,6 +34,7 @@ How-to guides are **task-oriented** recipes that help you accomplish specific go :::{grid-item-card} 💾 Storage - [Choose a Storage Type](../05_hash_vs_json.ipynb) -- Hash vs JSON formats and nested data +- [Migrate an 
Index](migrate-indexes.md) -- use the migrator helper, wizard, plan, apply, and validate workflow ::: :::{grid-item-card} 💻 CLI Operations @@ -59,6 +60,7 @@ How-to guides are **task-oriented** recipes that help you accomplish specific go | Optimize index performance | [Optimize Indexes with SVS-VAMANA](../09_svs_vamana.ipynb) | | Decide on storage format | [Choose a Storage Type](../05_hash_vs_json.ipynb) | | Manage indices from terminal | [Manage Indices with the CLI](../cli.ipynb) | +| Plan and run a supported index migration | [Migrate an Index](migrate-indexes.md) | ```{toctree} :hidden: @@ -74,4 +76,5 @@ Optimize Indexes with SVS-VAMANA <../09_svs_vamana> Cache Embeddings <../10_embeddings_cache> Use Advanced Query Types <../11_advanced_queries> Write SQL Queries for Redis <../12_sql_to_redis_queries> +Migrate an Index ``` diff --git a/docs/user_guide/how_to_guides/migrate-indexes.md b/docs/user_guide/how_to_guides/migrate-indexes.md new file mode 100644 index 000000000..65f04d7d0 --- /dev/null +++ b/docs/user_guide/how_to_guides/migrate-indexes.md @@ -0,0 +1,1119 @@ +--- +myst: + html_meta: + "description lang=en": | + How to migrate a RedisVL index schema without losing data. +--- + +# Migrate an Index + +This guide shows how to safely change your index schema using the RedisVL migrator. + +## Quick Start + +Add a field to your index in 4 commands: + +```bash +# 1. See what indexes exist +rvl migrate list --url redis://localhost:6379 + +# 2. Use the wizard to build a migration plan +rvl migrate wizard --index myindex --url redis://localhost:6379 + +# 3. Apply the migration +rvl migrate apply --plan migration_plan.yaml --url redis://localhost:6379 + +# 4. 
Verify the result +rvl migrate validate --plan migration_plan.yaml --url redis://localhost:6379 +``` + +## Prerequisites + +- Redis with the Search module (Redis Stack, Redis Cloud, or Redis Enterprise) +- An existing index to migrate +- `redisvl` installed (`pip install redisvl`) + +```bash +# Local development with Redis 8.0+ (recommended for full feature support) +docker run -d --name redis -p 6379:6379 redis:8.0 +``` + +**Note:** Redis 8.0+ is required for INT8/UINT8 vector datatypes. SVS-VAMANA algorithm requires Redis 8.2+ and Intel AVX-512 hardware. + +## Step 1: Discover Available Indexes + +```bash +rvl migrate list --url redis://localhost:6379 +``` + +**Example output:** +``` +Available indexes: + 1. products_idx + 2. users_idx + 3. orders_idx +``` + +## Step 2: Build Your Schema Change + +Choose one of these approaches: + +### Option A: Use the Wizard (Recommended) + +The wizard guides you through building a migration interactively. Run: + +```bash +rvl migrate wizard --index myindex --url redis://localhost:6379 +``` + +**Example wizard session (adding a field):** + +```text +Building a migration plan for index 'myindex' +Current schema: +- Index name: myindex +- Storage type: hash + - title (text) + - embedding (vector) + +Choose an action: +1. Add field (text, tag, numeric, geo) +2. Update field (sortable, weight, separator) +3. Remove field +4. Preview patch (show pending changes as YAML) +5. Finish +Enter a number: 1 + +Field name: category +Field type options: text, tag, numeric, geo +Field type: tag + Sortable: enables sorting and aggregation on this field +Sortable [y/n]: n + Separator: character that splits multiple values (default: comma) +Separator [leave blank to keep existing/default]: | + +Choose an action: +1. Add field (text, tag, numeric, geo) +2. Update field (sortable, weight, separator) +3. Remove field +4. Preview patch (show pending changes as YAML) +5. 
Finish +Enter a number: 5 + +Migration plan written to /path/to/migration_plan.yaml +Mode: drop_recreate +Supported: True +Warnings: +- Index downtime is required +``` + +**Example wizard session (quantizing vectors):** + +```text +Choose an action: +1. Add field (text, tag, numeric, geo) +2. Update field (sortable, weight, separator) +3. Remove field +4. Preview patch (show pending changes as YAML) +5. Finish +Enter a number: 2 + +Updatable fields: +1. title (text) +2. embedding (vector) +Select a field to update by number or name: 2 + +Current vector config for 'embedding': + algorithm: HNSW + datatype: float32 + distance_metric: cosine + dims: 384 (cannot be changed) + m: 16 + ef_construction: 200 + +Leave blank to keep current value. + Algorithm: vector search method (FLAT=brute force, HNSW=graph, SVS-VAMANA=compressed graph) +Algorithm [current: HNSW]: + Datatype: float16, float32, bfloat16, float64, int8, uint8 + (float16 reduces memory ~50%, int8/uint8 reduce ~75%) +Datatype [current: float32]: float16 + Distance metric: how similarity is measured (cosine, l2, ip) +Distance metric [current: cosine]: + M: number of connections per node (higher=better recall, more memory) +M [current: 16]: + EF_CONSTRUCTION: build-time search depth (higher=better recall, slower build) +EF_CONSTRUCTION [current: 200]: + +Choose an action: +... +5. 
Finish +Enter a number: 5 + +Migration plan written to /path/to/migration_plan.yaml +Mode: drop_recreate +Supported: True +``` + +### Option B: Write a Schema Patch (YAML) + +Create `schema_patch.yaml` manually: + +```yaml +version: 1 +changes: + add_fields: + - name: category + type: tag + path: $.category + attrs: + separator: "|" + remove_fields: + - legacy_field + update_fields: + - name: title + attrs: + sortable: true + - name: embedding + attrs: + datatype: float16 # quantize vectors + algorithm: HNSW + distance_metric: cosine +``` + +Then generate the plan: + +```bash +rvl migrate plan \ + --index myindex \ + --schema-patch schema_patch.yaml \ + --url redis://localhost:6379 \ + --plan-out migration_plan.yaml +``` + +### Option C: Provide a Target Schema + +If you have the complete target schema, use it directly: + +```bash +rvl migrate plan \ + --index myindex \ + --target-schema target_schema.yaml \ + --url redis://localhost:6379 \ + --plan-out migration_plan.yaml +``` + +## Step 3: Review the Migration Plan + +Before applying, review `migration_plan.yaml`: + +```yaml +# migration_plan.yaml (example) +version: 1 +mode: drop_recreate + +source: + schema_snapshot: + index: + name: myindex + prefix: "doc:" + storage_type: json + fields: + - name: title + type: text + - name: embedding + type: vector + attrs: + dims: 384 + algorithm: hnsw + datatype: float32 + stats_snapshot: + num_docs: 10000 + keyspace: + prefixes: ["doc:"] + key_sample: ["doc:1", "doc:2", "doc:3"] + +requested_changes: + add_fields: + - name: category + type: tag + +diff_classification: + supported: true + blocked_reasons: [] + +rename_operations: + rename_index: null + change_prefix: null + rename_fields: [] + +merged_target_schema: + index: + name: myindex + prefix: "doc:" + storage_type: json + fields: + - name: title + type: text + - name: category + type: tag + - name: embedding + type: vector + attrs: + dims: 384 + algorithm: hnsw + datatype: float32 + +warnings: + - "Index downtime 
is required" +``` + +**Key fields to check:** +- `diff_classification.supported` - Must be `true` to proceed +- `diff_classification.blocked_reasons` - Must be empty +- `warnings` - Top-level warnings about the migration +- `merged_target_schema` - The final schema after migration + +## Understanding Downtime Requirements + +**CRITICAL**: During a `drop_recreate` migration, your application must: + +| Requirement | Description | +|-------------|-------------| +| **Pause reads** | Index is unavailable during migration | +| **Pause writes** | Writes during migration may be missed or cause conflicts | + +### Why Both Reads AND Writes Must Be Paused + +- **Reads**: The index definition is dropped and recreated. Any queries during this window will fail. +- **Writes**: Redis updates indexes synchronously on every write. If your app writes documents while the index is dropped, those writes are not indexed. Additionally, if you're quantizing vectors (float32 → float16), concurrent writes may conflict with the migration's re-encoding process. + +### What "Downtime" Means + +| Downtime Type | Reads | Writes | Safe? | +|---------------|-------|--------|-------| +| Full quiesce (recommended) | Stopped | Stopped | **YES** | +| Read-only pause | Stopped | Continuing | **NO** | +| Active | Active | Active | **NO** | + +### Recovery from Interrupted Migration + +| Interruption Point | Documents | Index | Recovery | +|--------------------|-----------|-------|----------| +| After drop, before quantize | Unchanged | **None** | Re-run apply (or `--resume` if checkpoint exists) | +| During quantization | Partially quantized | **None** | Re-run with `--resume` to continue from checkpoint | +| After quantization, before create | Quantized | **None** | Re-run apply (will recreate index) | +| After create | Correct | Rebuilding | Wait for index ready | + +The underlying documents are **never deleted** by `drop_recreate` mode. 
For large quantization jobs, use `--resume` to enable checkpoint-based recovery. See [Crash-safe resume for quantization](#crash-safe-resume-for-quantization) below. + +## Step 4: Apply the Migration + +The `apply` command executes the migration. The index will be temporarily unavailable during the drop-recreate process. + +```bash +rvl migrate apply \ + --plan migration_plan.yaml \ + --url redis://localhost:6379 \ + --report-out migration_report.yaml \ + --benchmark-out benchmark_report.yaml +``` + +### What `apply` does + +The migration executor follows this sequence: + +**STEP 1: Enumerate keys** (before any modifications) +- Discovers all document keys belonging to the source index +- Uses `FT.AGGREGATE WITHCURSOR` for efficient enumeration +- Falls back to `SCAN` if the index has indexing failures +- Keys are stored in memory for quantization or rename operations + +**STEP 2: Drop source index** +- Issues `FT.DROPINDEX` to remove the index structure +- **The underlying documents remain in Redis** - only the index metadata is deleted +- After this point, the index is unavailable until step 6 completes + +**STEP 3: Quantize vectors** (if changing vector datatype) +- For each document in the enumerated key list: + - Reads the document (including the old vector) + - Converts the vector to the new datatype (e.g., float32 → float16) + - Writes back the converted vector to the same document +- Processes documents in batches of 500 using Redis pipelines +- Skipped for JSON storage (vectors are re-indexed automatically on recreate) +- **Checkpoint support**: For large datasets, use `--resume` to enable crash-safe recovery + +**STEP 4: Key renames** (if changing key prefix) +- If the migration changes the key prefix, renames each key from old prefix to new prefix +- Skipped if no prefix change + +**STEP 5: Create target index** +- Issues `FT.CREATE` with the merged target schema +- Redis begins background indexing of existing documents + +**STEP 6: Wait for 
re-indexing** +- Polls `FT.INFO` until indexing completes +- The index becomes available for queries when this completes + +**Summary**: The migration preserves all documents, drops only the index structure, performs any document-level transformations (quantization, renames), then recreates the index with the new schema. + +### Async execution for large migrations + +For large migrations (especially those involving vector quantization), use the `--async` flag: + +```bash +rvl migrate apply \ + --plan migration_plan.yaml \ + --async \ + --url redis://localhost:6379 +``` + +**What becomes async:** + +- Document enumeration during quantization (uses `FT.AGGREGATE WITHCURSOR` for index-specific enumeration, falling back to SCAN only if indexing failures exist) +- Vector read/write operations (sequential async HGET, batched HSET via pipeline) +- Index readiness polling (uses `asyncio.sleep()` instead of blocking) +- Validation checks + +**What stays sync:** + +- CLI prompts and user interaction +- YAML file reading/writing +- Progress display + +**When to use async:** + +- Quantizing millions of vectors (float32 to float16) +- Integrating into an async application + +For most migrations (index-only changes, small datasets), sync mode is sufficient and simpler. + +See {doc}`/concepts/index-migrations` for detailed async vs sync guidance. + +### Crash-safe resume for quantization + +When migrating large datasets with vector quantization (e.g. float32 to float16), the re-encoding step can take minutes or hours. If the process is interrupted (crash, network drop, OOM kill), you don't want to start over. The `--resume` flag enables checkpoint-based recovery. + +#### How it works + +1. **Pre-flight estimate**: before any mutations, `apply` prints a disk space estimate showing RDB snapshot cost, AOF growth (if enabled), and post-migration memory savings. +2. 
**BGSAVE safety snapshot**: the migrator triggers a Redis `BGSAVE` and waits for it to complete before modifying any data. This gives you a point-in-time snapshot to fall back on. +3. **Checkpoint file**: when `--resume` is provided, the migrator writes a YAML checkpoint after every batch of 500 documents. The checkpoint records how many keys have been processed and the last batch of keys written. +4. **Batch undo buffer**: if a single batch fails mid-write, original vector values are rolled back via pipeline before the error propagates. Only the current batch is held in memory. +5. **Idempotent skip**: on resume, vectors that were already converted are detected by byte-width inspection and skipped automatically. + +#### Step-by-step: using crash-safe resume + +**1. Estimate disk space (dry-run, no mutations):** + +```bash +rvl migrate estimate --plan migration_plan.yaml +``` + +Example output: + +```text +Pre-migration disk space estimate: + Index: products_idx (1,000,000 documents) + Vector field 'embedding': 768 dims, float32 -> float16 + + RDB snapshot (BGSAVE): ~2.87 GB + AOF growth: not estimated (pass aof_enabled=True if AOF is on) + Total new disk required: ~2.87 GB + + Post-migration memory savings: ~1.43 GB (50% reduction) +``` + +If AOF is enabled: + +```bash +rvl migrate estimate --plan migration_plan.yaml --aof-enabled +``` + +**2. Apply with checkpoint enabled:** + +```bash +rvl migrate apply \ + --plan migration_plan.yaml \ + --resume quantize_checkpoint.yaml \ + --url redis://localhost:6379 \ + --report-out migration_report.yaml +``` + +The `--resume` flag takes a path to a checkpoint file. If the file does not exist, a new checkpoint is created. If it already exists (from a previous interrupted run), the migrator resumes from where it left off. + +**3. 
If the process crashes or is interrupted:** + +The checkpoint file (`quantize_checkpoint.yaml`) will contain the progress: + +```yaml +index_name: products_idx +total_keys: 1000000 +completed_keys: 450000 +completed_batches: 900 +last_batch_keys: + - 'products:449501' + - 'products:449502' + # ... +status: in_progress +checkpoint_path: quantize_checkpoint.yaml +``` + +**4. Resume the migration:** + +Re-run the exact same command: + +```bash +rvl migrate apply \ + --plan migration_plan.yaml \ + --resume quantize_checkpoint.yaml \ + --url redis://localhost:6379 \ + --report-out migration_report.yaml +``` + +The migrator will: +- Detect the existing checkpoint and skip already-processed keys +- Re-enumerate documents via SCAN (the index was already dropped before the crash) +- Continue quantizing from where it left off +- Print progress like `[4/6] Quantize vectors: 450,000/1,000,000 docs` + +**5. On successful completion:** + +The checkpoint status is set to `completed`. You can safely delete the checkpoint file. + +#### What gets rolled back on batch failure + +If a batch of 500 documents fails mid-write (e.g. Redis returns an error), the migrator: +1. Restores original vector bytes for all documents in that batch using the undo buffer +2. Saves the checkpoint (so progress up to the last successful batch is preserved) +3. Raises the error + +This means you never end up with partially-written vectors in a single batch. + +#### Limitations + +- **Same-width conversions** (float16 to bfloat16, or int8 to uint8) are **not supported** with `--resume`. These conversions cannot be detected by byte-width inspection, so idempotent skip is impossible. The migrator will refuse to proceed and suggest running without `--resume`. +- **JSON storage** does not need vector re-encoding (Redis re-indexes JSON vectors on `FT.CREATE`). The checkpoint is still created for consistency but no batched writes occur. +- The checkpoint file must match the migration plan. 
If you change the plan, delete the old checkpoint and start fresh. + +#### Python API with checkpoints + +```python +from redisvl.migration import MigrationExecutor + +executor = MigrationExecutor() +report = executor.apply( + plan, + redis_url="redis://localhost:6379", + checkpoint_path="quantize_checkpoint.yaml", +) +``` + +For async: + +```python +from redisvl.migration import AsyncMigrationExecutor + +executor = AsyncMigrationExecutor() +report = await executor.apply( + plan, + redis_url="redis://localhost:6379", + checkpoint_path="quantize_checkpoint.yaml", +) +``` + +## Step 5: Validate the Result + +Validation happens automatically during `apply`, but you can run it separately: + +```bash +rvl migrate validate \ + --plan migration_plan.yaml \ + --url redis://localhost:6379 \ + --report-out migration_report.yaml +``` + +**Validation checks:** +- Live schema matches `merged_target_schema` +- Document count matches the source snapshot +- Sampled keys still exist +- No increase in indexing failures + +## What's Supported + +| Change | Supported | Notes | +|--------|-----------|-------| +| Add text/tag/numeric/geo field | ✅ | | +| Remove a field | ✅ | | +| Rename a field | ✅ | Renames field in all documents | +| Change key prefix | ✅ | Renames keys via RENAME command | +| Rename the index | ✅ | Index-only | +| Make a field sortable | ✅ | | +| Change field options (separator, stemming) | ✅ | | +| Change vector algorithm (FLAT ↔ HNSW ↔ SVS-VAMANA) | ✅ | Index-only | +| Change distance metric (COSINE ↔ L2 ↔ IP) | ✅ | Index-only | +| Tune HNSW parameters (M, EF_CONSTRUCTION) | ✅ | Index-only | +| Quantize vectors (float32 → float16/bfloat16/int8/uint8) | ✅ | Auto re-encode | + +## What's Blocked + +| Change | Why | Workaround | +|--------|-----|------------| +| Change vector dimensions | Requires re-embedding | Re-embed with new model, reload data | +| Change storage type (hash ↔ JSON) | Different data format | Export, transform, reload | +| Add a new vector field | 
Requires vectors for all docs | Add vectors first, then migrate |
+
+## CLI Reference
+
+### Single-Index Commands
+
+| Command | Description |
+|---------|-------------|
+| `rvl migrate list` | List all indexes |
+| `rvl migrate wizard` | Build a migration interactively |
+| `rvl migrate plan` | Generate a migration plan |
+| `rvl migrate apply` | Execute a migration |
+| `rvl migrate estimate` | Estimate disk space for a migration (dry-run) |
+| `rvl migrate validate` | Verify a migration result |
+| `rvl migrate rollback` | Restore original vectors from backup files |
+
+### Batch Commands
+
+| Command | Description |
+|---------|-------------|
+| `rvl migrate batch-plan` | Create a batch migration plan |
+| `rvl migrate batch-apply` | Execute a batch migration |
+| `rvl migrate batch-resume` | Resume an interrupted batch |
+| `rvl migrate batch-status` | Check batch progress |
+
+**Common flags:**
+- `--url` : Redis connection URL
+- `--index` : Index name to migrate
+- `--plan` / `--plan-out` : Path to migration plan
+- `--async` : Use async executor for large migrations (apply only)
+- `--report-out` : Path for validation report
+- `--benchmark-out` : Path for performance metrics
+
+**Apply flags (quantization & reliability):**
+- `--backup-dir <path>` : Directory for vector backup files. Enables crash-safe resume and manual rollback. Required when using `--workers` > 1.
+- `--batch-size <n>` : Keys per pipeline batch (default 500). Values 200–1000 are typical.
+- `--workers <n>` : Parallel quantization workers (default 1). Each worker opens its own Redis connection. See [Performance](#performance-tuning) for guidance.
+- `--keep-backup` : Retain backup files after a successful migration (default: auto-cleanup).
+ +**Batch-specific flags:** +- `--pattern` : Glob pattern to match index names (e.g., `*_idx`) +- `--indexes` : Explicit list of index names +- `--indexes-file` : File containing index names (one per line) +- `--schema-patch` : Path to shared schema patch YAML +- `--state` : Path to checkpoint state file +- `--failure-policy` : `fail_fast` or `continue_on_error` +- `--accept-data-loss` : Required for quantization (lossy changes) +- `--retry-failed` : Retry previously failed indexes on resume + +## Troubleshooting + +### Migration blocked: "unsupported change" + +The planner detected a change that requires data transformation. Check `diff_classification.blocked_reasons` in the plan for details. + +### Apply failed: "source schema mismatch" + +The live index schema changed since the plan was generated. Re-run `rvl migrate plan` to create a fresh plan. + +### Apply failed: "timeout waiting for index ready" + +The index is taking longer to rebuild than expected. This can happen with large datasets. Check Redis logs and consider increasing the timeout or running during lower traffic periods. + +### Validation failed: "document count mismatch" + +Documents were added or removed between plan and apply. This is expected if your application is actively writing. Re-run `plan` and `apply` during a quieter period when the document count is stable, or verify the mismatch is due only to normal application traffic. + +### How to recover from a failed migration + +If `apply` fails mid-migration: + +1. **Check if the index exists:** `rvl index info --index myindex` +2. **If the index exists but is wrong:** Re-run `apply` with the same plan +3. **If the index was dropped:** Recreate it from the plan's `merged_target_schema` + +The underlying documents are never deleted by `drop_recreate`. 
+ +## Backup, Resume & Rollback + +### How Backups Work + +When you pass `--backup-dir` (or `backup_dir` in the Python API), the +migration executor saves **original vector bytes** to disk before mutating +them. This enables two key capabilities: + +1. **Crash-safe resume** — if the process dies mid-migration, re-running the + same command with the same `--backup-dir` automatically detects partial + progress and resumes from the last completed batch. +2. **Manual rollback** — the backup files contain the original (pre-quantization) + vector values, which can be restored to undo a migration. + +Backup files are written to the specified directory with this layout: + +``` +/ + migration_backup_.header # JSON: phase, progress counters, field metadata + migration_backup_.data # Binary: length-prefixed batches of original vectors +``` + +**Disk usage:** approximately `num_docs × dims × bytes_per_element`. +For example, 1M docs with 768-dim float32 vectors ≈ 2.9 GB. + +By default, backup files are **automatically deleted** after a successful +migration. Pass `--keep-backup` to retain them for post-migration auditing +or potential rollback. + +### Crash-Safe Resume + +If a migration is interrupted (crash, network error, Ctrl+C), simply re-run +the exact same command: + +```bash +# Original command that was interrupted +rvl migrate apply --plan plan.yaml --url redis://localhost:6379 \ + --backup-dir /tmp/backups --workers 4 + +# Just re-run it — progress is resumed automatically +rvl migrate apply --plan plan.yaml --url redis://localhost:6379 \ + --backup-dir /tmp/backups --workers 4 +``` + +The executor detects the existing backup header, reads how many batches were +completed, and resumes from the next unfinished batch. No data is duplicated +or lost. + +```{note} +**Single-worker vs multi-worker resume:** In single-worker mode, the full +backup is written *before* the index is dropped, so a crash at any point +leaves a complete backup on disk. 
In multi-worker mode, dump and quantize +are fused (each worker reads, backs up, and converts its shard in one pass +*after* the index drop). A crash during this fused phase may leave partial +backup shards. Re-running detects and resumes from partial state. +``` + +### Rollback + +If you need to undo a quantization migration and restore original vectors, +use the `rollback` command: + +```bash +rvl migrate rollback --backup-dir /tmp/backups --url redis://localhost:6379 +``` + +This reads every batch from the backup files and pipeline-HSETs the original +(pre-quantization) vector bytes back into Redis. After rollback completes: + +- Your vector data is restored to its original datatype +- You will need to **manually recreate the original index schema** if the + index was changed during migration (the rollback command restores data + only, not the index definition) + +```bash +# After rollback, recreate the original index if needed: +rvl index create --schema original_schema.yaml --url redis://localhost:6379 +``` + +```{important} +Rollback requires that backup files were preserved. Either pass +`--keep-backup` during migration, or ensure the backup directory was not +cleaned up. Without backup files, rollback is not possible. 
+``` + +### Python API for Rollback + +```python +from redisvl.migration.backup import VectorBackup +import redis + +r = redis.from_url("redis://localhost:6379") +backup = VectorBackup.load("/tmp/backups/migration_backup_myindex") + +for keys, originals in backup.iter_batches(): + pipe = r.pipeline(transaction=False) + for key in keys: + if key in originals: + for field_name, original_bytes in originals[key].items(): + pipe.hset(key, field_name, original_bytes) + pipe.execute() + +print("Rollback complete") +``` + +## Python API + +For programmatic migrations, use the migration classes directly: + +### Sync API + +```python +from redisvl.migration import MigrationPlanner, MigrationExecutor + +planner = MigrationPlanner() +plan = planner.create_plan( + "myindex", + redis_url="redis://localhost:6379", + schema_patch_path="schema_patch.yaml", +) + +executor = MigrationExecutor() +report = executor.apply(plan, redis_url="redis://localhost:6379") +print(f"Migration result: {report.result}") +``` + +With backup and multi-worker quantization: + +```python +report = executor.apply( + plan, + redis_url="redis://localhost:6379", + backup_dir="/tmp/migration_backups", # enables crash-safe resume + batch_size=500, # keys per pipeline batch + num_workers=4, # parallel quantization workers + keep_backup=True, # retain backups for rollback +) +print(f"Quantized in {report.timings.quantize_duration_seconds}s") +``` + +### Async API + +```python +import asyncio +from redisvl.migration import AsyncMigrationPlanner, AsyncMigrationExecutor + +async def migrate(): + planner = AsyncMigrationPlanner() + plan = await planner.create_plan( + "myindex", + redis_url="redis://localhost:6379", + schema_patch_path="schema_patch.yaml", + ) + + executor = AsyncMigrationExecutor() + report = await executor.apply( + plan, + redis_url="redis://localhost:6379", + backup_dir="/tmp/migration_backups", + num_workers=4, + ) + print(f"Migration result: {report.result}") + +asyncio.run(migrate()) +``` + +## 
Batch Migration + +When you need to apply the same schema change to multiple indexes, use batch migration. This is common for: + +- Quantizing all indexes from float32 → float16 +- Standardizing vector algorithms across indexes +- Coordinated migrations during maintenance windows + +### Quick Start: Batch Migration + +```bash +# 1. Create a shared patch (applies to any index with an 'embedding' field) +cat > quantize_patch.yaml << 'EOF' +version: 1 +changes: + update_fields: + - name: embedding + attrs: + datatype: float16 +EOF + +# 2. Create a batch plan for all indexes matching a pattern +rvl migrate batch-plan \ + --pattern "*_idx" \ + --schema-patch quantize_patch.yaml \ + --plan-out batch_plan.yaml \ + --url redis://localhost:6379 + +# 3. Apply the batch plan +rvl migrate batch-apply \ + --plan batch_plan.yaml \ + --accept-data-loss \ + --url redis://localhost:6379 + +# 4. Check status +rvl migrate batch-status --state batch_state.yaml +``` + +### Batch Plan Options + +**Select indexes by pattern:** +```bash +rvl migrate batch-plan \ + --pattern "*_idx" \ + --schema-patch quantize_patch.yaml \ + --plan-out batch_plan.yaml \ + --url redis://localhost:6379 +``` + +**Select indexes by explicit list:** +```bash +rvl migrate batch-plan \ + --indexes "products_idx,users_idx,orders_idx" \ + --schema-patch quantize_patch.yaml \ + --plan-out batch_plan.yaml \ + --url redis://localhost:6379 +``` + +**Select indexes from a file (for 100+ indexes):** +```bash +# Create index list file +echo -e "products_idx\nusers_idx\norders_idx" > indexes.txt + +rvl migrate batch-plan \ + --indexes-file indexes.txt \ + --schema-patch quantize_patch.yaml \ + --plan-out batch_plan.yaml \ + --url redis://localhost:6379 +``` + +### Batch Plan Review + +The generated `batch_plan.yaml` shows which indexes will be migrated: + +```yaml +version: 1 +batch_id: "batch_20260320_100000" +mode: drop_recreate +failure_policy: fail_fast +requires_quantization: true + +shared_patch: + version: 1 + 
changes: + update_fields: + - name: embedding + attrs: + datatype: float16 + +indexes: + - name: products_idx + applicable: true + skip_reason: null + - name: users_idx + applicable: true + skip_reason: null + - name: legacy_idx + applicable: false + skip_reason: "Field 'embedding' not found" + +created_at: "2026-03-20T10:00:00Z" +``` + +**Key fields:** +- `applicable: true` means the patch applies to this index +- `skip_reason` explains why an index will be skipped + +### Applying a Batch Plan + +```bash +# Apply with fail-fast (default: stop on first error) +rvl migrate batch-apply \ + --plan batch_plan.yaml \ + --accept-data-loss \ + --url redis://localhost:6379 + +# Apply with continue-on-error (set at batch-plan time) +# Note: failure_policy is set during batch-plan, not batch-apply +rvl migrate batch-plan \ + --pattern "*_idx" \ + --schema-patch quantize_patch.yaml \ + --failure-policy continue_on_error \ + --plan-out batch_plan.yaml \ + --url redis://localhost:6379 + +rvl migrate batch-apply \ + --plan batch_plan.yaml \ + --accept-data-loss \ + --url redis://localhost:6379 +``` + +**Flags for batch-apply:** +- `--accept-data-loss` : Required when quantizing vectors (float32 → float16 is lossy) +- `--state` : Path to checkpoint file (default: `batch_state.yaml`) +- `--report-dir` : Directory for per-index reports (default: `./reports/`) + +**Note:** `--failure-policy` is set during `batch-plan`, not `batch-apply`. The policy is stored in the batch plan file. + +### Resume After Failure + +Batch migration automatically checkpoints progress. 
If interrupted: + +```bash +# Resume from where it left off +rvl migrate batch-resume \ + --state batch_state.yaml \ + --accept-data-loss \ + --url redis://localhost:6379 + +# Retry previously failed indexes +rvl migrate batch-resume \ + --state batch_state.yaml \ + --retry-failed \ + --accept-data-loss \ + --url redis://localhost:6379 +``` + +**Note:** If the batch plan involves quantization (e.g., `float32` → `float16`), you must pass `--accept-data-loss` to `batch-resume`, just as with `batch-apply`. Omit `--accept-data-loss` if the batch plan does not involve quantization. + +### Checking Batch Status + +```bash +rvl migrate batch-status --state batch_state.yaml +``` + +**Example output:** +``` +Batch Migration Status +====================== +Batch ID: batch_20260320_100000 +Started: 2026-03-20T10:00:00Z +Updated: 2026-03-20T10:25:00Z + +Completed: 2 + - products_idx: success (10:02:30) + - users_idx: failed - Redis connection timeout (10:05:45) + +In Progress: inventory_idx +Remaining: 1 (analytics_idx) +``` + +### Batch Report + +After completion, a `batch_report.yaml` is generated: + +```yaml +version: 1 +batch_id: "batch_20260320_100000" +status: completed # or partial_failure, failed +summary: + total_indexes: 3 + successful: 3 + failed: 0 + skipped: 0 + total_duration_seconds: 127.5 +indexes: + - name: products_idx + status: success + report_path: ./reports/products_idx_report.yaml + - name: users_idx + status: success + report_path: ./reports/users_idx_report.yaml + - name: orders_idx + status: success + report_path: ./reports/orders_idx_report.yaml +completed_at: "2026-03-20T10:02:07Z" +``` + +### Python API for Batch Migration + +```python +from redisvl.migration import BatchMigrationPlanner, BatchMigrationExecutor + +# Create batch plan +planner = BatchMigrationPlanner() +batch_plan = planner.create_batch_plan( + redis_url="redis://localhost:6379", + pattern="*_idx", + schema_patch_path="quantize_patch.yaml", +) + +# Review applicability +for idx in 
batch_plan.indexes: + if idx.applicable: + print(f"Will migrate: {idx.name}") + else: + print(f"Skipping {idx.name}: {idx.skip_reason}") + +# Execute batch +executor = BatchMigrationExecutor() +report = executor.apply( + batch_plan, + redis_url="redis://localhost:6379", + state_path="batch_state.yaml", + report_dir="./reports/", + progress_callback=lambda name, pos, total, status: print(f"[{pos}/{total}] {name}: {status}"), +) + +print(f"Batch status: {report.status}") +print(f"Successful: {report.summary.successful}/{report.summary.total_indexes}") +``` + +### Batch Migration Tips + +1. **Test on a single index first**: Run a single-index migration to verify the patch works before applying to a batch. + +2. **Use `continue_on_error` for large batches**: This ensures one failure doesn't block all remaining indexes. + +3. **Schedule during low-traffic periods**: Each index has downtime during migration. + +4. **Review skipped indexes**: The `skip_reason` often indicates schema differences that need attention. + +5. **Keep checkpoint files**: The `batch_state.yaml` is essential for resume. Don't delete it until the batch completes successfully. + +## Performance Tuning + +### Quantization Throughput + +Vector quantization (e.g. float32 → float16) is the most time-consuming +phase of a datatype migration. Observed throughput on a local Redis instance: + +| Workers | Dims | Throughput | Notes | +|---------|------|------------|-------| +| 1 | 256 | ~70K docs/sec | Single worker is fastest for low dims | +| 4 | 256 | ~62K docs/sec | Worker overhead exceeds parallelism benefit | +| 1 | 1536 | ~15K docs/sec | Higher dims = more conversion work | +| 4 | 1536 | ~15K docs/sec | I/O-bound; Redis is the bottleneck | + +**Guidance:** +- For **low-dimensional vectors** (≤ 256 dims), use `--workers 1` (the default). Per-vector conversion is so cheap that process-spawning and extra-connection overhead outweigh the parallelism benefit. 
+- For **high-dimensional vectors** (≥ 768 dims), `--workers 2-4` may help if the Redis server has available CPU headroom. Diminishing returns above 4–8 workers on a single Redis instance because Redis command processing is single-threaded. +- The main bottleneck for large migrations is typically **index rebuild time** (the `FT.CREATE` background indexing after vectors are written), not quantization itself. + +### Batch Size + +The `--batch-size` flag controls how many keys are read/written per Redis +pipeline round-trip. The default of 500 is a good balance. Larger batches +(1000+) reduce round-trips but increase per-batch memory and latency. + +### Backup Disk Space + +When `--backup-dir` is provided, original vectors are saved to disk before +mutation. Approximate size: `num_docs × dims × bytes_per_element`. + +| Docs | Dims | Source dtype | Backup size | +|--------|------|-------------|-------------| +| 100K | 768 | float32 | ~292 MB | +| 1M | 768 | float32 | ~2.9 GB | +| 1M | 1536 | float32 | ~5.7 GB | + +### HNSW vs FLAT Index Capacity + +```{note} +When migrating from **HNSW** to **FLAT**, the target index may report a +*higher* document count than the source. This is not a bug — it reflects +a fundamental difference in how the two algorithms store vectors. + +HNSW maintains a navigable small-world graph with per-node neighbor lists. +This graph overhead limits how many vectors can fit in available memory. +FLAT stores vectors as a simple array with no graph overhead. + +If the source HNSW index was operating near its memory capacity, some +documents may have been registered in Redis Search's document table but +not fully indexed into the HNSW graph. After migration to FLAT, those +same documents become fully searchable because FLAT requires less memory +per vector. + +The migration validator compares the total key count +(`num_docs + hash_indexing_failures`) between source and target, so this +scenario is handled correctly in the general case. 
+``` + +## Learn more + +- {doc}`/concepts/index-migrations`: How migrations work and which changes are supported diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md index 5d2cf6dfd..d85177e73 100644 --- a/docs/user_guide/index.md +++ b/docs/user_guide/index.md @@ -39,7 +39,7 @@ Schema → Index → Load → Query **Solve specific problems.** Task-oriented recipes for LLM extensions, querying, embeddings, optimization, and storage. +++ -LLM Caching • Filtering • Vectorizers • Reranking +LLM Caching • Filtering • Vectorizers • Reranking • Migrations ::: :::{grid-item-card} 💻 CLI Reference @@ -49,7 +49,7 @@ LLM Caching • Filtering • Vectorizers • Reranking **Command-line tools.** Manage indices, inspect stats, and work with schemas using the `rvl` CLI. +++ -rvl index • rvl stats • Schema YAML +rvl index • rvl stats • rvl migrate • Schema YAML ::: :::{grid-item-card} 💡 Use Cases diff --git a/local_docs/index_migrator/30_quantize_performance_spec.md b/local_docs/index_migrator/30_quantize_performance_spec.md new file mode 100644 index 000000000..b47469354 --- /dev/null +++ b/local_docs/index_migrator/30_quantize_performance_spec.md @@ -0,0 +1,603 @@ +# Quantization Performance Overhaul + +## Problem Statement + +Migrating 1M keys currently takes ~50 minutes. Quantization (re-encoding vectors from one dtype to another) accounts for 78% of that time. The bottleneck is the quantize loop performing individual `HGET` calls per key per field — one network round trip each — while only pipelining the `HSET` writes. 
+ +### Current per-document cost breakdown (from benchmark at 100K) + +| Component | Per-doc cost | % of migration | +|-----------|-------------|----------------| +| Quantize | 241 µs | 78% | +| Reindex | 66 µs | ~20% | +| Other | — | ~2% | + +### Why it's slow + +The inner loop of `_quantize_vectors` does: + +```python +for key in batch: # 500 keys per batch + for field_name, change in datatype_changes.items(): + field_data = client.hget(key, field_name) # ← 1 round trip per key per field + # ... convert ... + pipe.hset(key, field_name, new_bytes) # pipelined, but writes only + pipe.execute() # 1 round trip for all writes +``` + +For 1M keys × 1 field: **1,000,000 read round trips + 2,000 write round trips** = ~1,002,000 total round trips. + +### Additional problems + +1. **BGSAVE is heavy and imprecise.** Before mutation, the executor triggers `BGSAVE` as a safety snapshot. This snapshots the *entire* database, not just the vectors being modified. For large DBs it takes minutes and provides no targeted rollback capability. + +2. **QuantizationCheckpoint tracks progress but not data.** The current checkpoint records which keys have been processed. On crash you can resume, but the original vector bytes are gone — they've been overwritten in Redis. There is no rollback to original values after a crash (the `BatchUndoBuffer` only survives within a single batch, not across process crashes). + +3. **Single-worker execution.** One process, one Redis connection, sequential batches. No parallelism. + +## Design + +Three changes, applied in order: + +### Change 1: Pipeline reads + +Batch `HGET` reads into a pipeline, same as writes. This is the single biggest win with zero risk. + +**Before:** 1,002,000 round trips for 1M keys. +**After:** 4,000 round trips for 1M keys (2,000 read + 2,000 write). 
+ +```python +for i in range(0, total_keys, batch_size): + batch = keys[i : i + batch_size] + + # Phase A: pipelined reads + read_pipe = client.pipeline(transaction=False) + read_meta = [] + for key in batch: + for field_name, change in datatype_changes.items(): + read_pipe.hget(key, field_name) + read_meta.append((key, field_name, change)) + read_results = read_pipe.execute() + + # Phase B: convert + pipelined writes + write_pipe = client.pipeline(transaction=False) + for (key, field_name, change), field_data in zip(read_meta, read_results): + if not field_data: + continue + # ... idempotent check, convert, store original for backup ... + write_pipe.hset(key, field_name, new_bytes) + if writes_pending: + write_pipe.execute() +``` + +**Estimated improvement:** 50 min → 3-5 min (10-15x speedup from eliminating read round trips). + +Applies to both sync (`MigrationExecutor`) and async (`AsyncMigrationExecutor`). + +### Change 2: Replace BGSAVE + QuantizationCheckpoint with a vector backup file + +Remove BGSAVE entirely. Replace the current checkpoint (which only tracks progress) with a backup file that stores the **original vector bytes** for every key being quantized. + +**Two-phase approach (Alternative A):** + +``` +Phase 1 — DUMP (before index drop, index still alive): + Enumerate keys via FT.AGGREGATE + For each batch of 500 keys: + Pipeline-read all vector fields via HGET + Append {key: {field: original_bytes}} to backup file + Flush to disk + +Phase 2 — QUANTIZE (after index drop): + For each batch in the backup file: + Read original vectors FROM the backup file (no Redis reads) + Convert dtype + Pipeline-write new vectors to Redis + Update progress counter in backup file header +``` + +The dump runs while the index is alive, so FT.AGGREGATE is always available +for enumeration. After the dump completes, the backup file contains the full +key list and all original vectors. 
The quantize phase never reads from Redis +— it reads originals from the local file and writes converted vectors back. +**SCAN is never used at any point.** + +#### Backup file format + +Binary file using msgpack for compactness and speed. Structure: + +``` +[header] + index_name: str + total_keys: int + fields: {field_name: {source_dtype, target_dtype, dims}} + phase: "dump" | "ready" | "active" | "completed" + dump_completed_batches: int + quantize_completed_batches: int + batch_size: int + +[batch 0] + keys: [key1, key2, ..., key500] + vectors: {key1: {field1: bytes, field2: bytes}, key2: {...}, ...} + +[batch 1] + ... +``` + +Each batch is written as a length-prefixed msgpack blob. The header is +stored in a separate small file (`.header`) and updated atomically +(write to temp file + `os.rename`) after each batch operation. + +#### Progress tracking + +The header has two counters that track exactly where the migration is: + +- `dump_completed_batches` — how many batches have been read from Redis + and written to the backup file +- `quantize_completed_batches` — how many batches have been converted and + written back to Redis + +**Dump progress (Phase 1):** + +``` +for batch_idx, batch in enumerate(chunked(keys, batch_size)): + originals = pipeline_read(client, batch, fields) # pipelined HGET + backup.write_batch(batch_idx, batch, originals) # append to file + backup.update_header(dump_completed_batches=batch_idx + 1) # atomic +``` + +No Redis state modified. On crash at `dump_completed_batches=N`, just +re-enumerate and resume from batch N. 
+ +**Quantize progress (Phase 2):** + +``` +for batch_idx, (batch_keys, originals) in enumerate(backup.iter_batches()): + if batch_idx < backup.header.quantize_completed_batches: + continue # already written to Redis in a previous run + converted = convert_all(originals, datatype_changes) + pipeline_write(client, converted) # pipelined HSET + backup.update_header(quantize_completed_batches=batch_idx + 1) # atomic +``` + +The header update happens **only after** `pipeline_write` succeeds. If the +process crashes mid-batch (after some HSETs execute but before the header +update), the header still says `quantize_completed_batches=N`. On resume, +batch N is re-processed. This is safe because: + +- **HSET is idempotent** — re-writing the same converted value is a no-op +- **Originals are in the backup file** — if you need to rollback, read + batch N's originals from disk and HSET them back +- **No partial state** — either the header says "batch N done" (all HSETs + committed + header flushed) or it doesn't (re-process the whole batch) + +#### Phase transitions and resume + +Phase values: `dump` → `ready` → `active` → `completed` + +| Phase | Meaning | Index state | What to do on resume | +|-------|---------|-------------|---------------------| +| `dump` | Dump in progress | **Alive** | Re-enumerate via FT.AGGREGATE, resume dump from `dump_completed_batches` | +| `ready` | Dump complete, pre-drop | **Alive** | All originals on disk. Proceed to drop + quantize | +| `active` | Quantize in progress | **Dropped** | Key list + originals in file. Resume quantize from `quantize_completed_batches` | +| `completed` | All writes done | **Dropped** | Skip to FT.CREATE | + +On resume: +- **No backup file:** No mutations happened. Restart from scratch. +- **Backup file exists:** Read header, check `phase`, resume from the + appropriate point. No SCAN, no re-discovery. The file has everything. 
+- **Rollback:** At any point after dump completes, read originals from + backup file, pipeline-HSET them back to Redis. + +#### Worked example: crash and resume + +Setup: 2,000 keys, batch_size=500, quantizing `embedding` from float32 → float16. +4 batches total: + +``` +batch 0 = doc:0 – doc:499 (always these keys, fixed at dump time) +batch 1 = doc:500 – doc:999 +batch 2 = doc:1000 – doc:1499 +batch 3 = doc:1500 – doc:1999 +``` + +The batch layout is frozen when the dump completes. Batch index 2 always +means keys doc:1000–doc:1499 whether it's the first run or a resume after +crash. One integer counter is enough to track progress. + +**Phase 1 — Dump (index alive, no mutations):** + +``` +batch 0: pipeline HGET doc:0..499 → write float32 bytes to backup file + header: phase=dump, dump_completed_batches=1 +batch 1: pipeline HGET doc:500..999 → write to file + header: dump_completed_batches=2 +batch 2: pipeline HGET doc:1000..1499 → write to file + header: dump_completed_batches=3 +batch 3: pipeline HGET doc:1500..1999 → write to file + header: dump_completed_batches=4 + +header: phase=ready (dump complete, all originals on disk) +``` + +No Redis data modified. Index still alive. 
+ +**Drop index + key renames happen here.** + +``` +header: phase=active (mutations starting) +``` + +**Phase 2 — Quantize (index dropped):** + +``` +batch 0: read originals from backup file (local disk, no Redis read) + convert float32 → float16 in memory + pipeline HSET doc:0..499 with float16 bytes → pipe.execute() ✓ + header: quantize_completed_batches=1 ← atomic write AFTER execute + +batch 1: read from file → convert → pipeline HSET doc:500..999 → execute ✓ + header: quantize_completed_batches=2 + +batch 2: read from file → convert → pipeline HSET doc:1000..1499 + pipe.execute() → ☠️ CRASH (process killed mid-pipeline) + header update NEVER HAPPENED +``` + +**State after crash:** + +``` +Redis: + doc:0 – doc:499 → float16 ✅ (batch 0, fully committed) + doc:500 – doc:999 → float16 ✅ (batch 1, fully committed) + doc:1000 – doc:1200 → float16 ⚠️ (batch 2, partial — ~200 keys written before crash) + doc:1201 – doc:1499 → float32 ❌ (batch 2, rest not yet written) + doc:1500 – doc:1999 → float32 ❌ (batch 3, not started) + +Backup file header: + phase: active + quantize_completed_batches: 2 ← NOT updated (crash happened before atomic write) + +Backup file data: + batch 0: {doc:0..499: original float32 bytes} ← all originals preserved + batch 1: {doc:500..999: original float32 bytes} + batch 2: {doc:1000..1499: original float32 bytes} + batch 3: {doc:1500..1999: original float32 bytes} +``` + +**Resume — user runs migrate again:** + +``` +Read header → phase=active, quantize_completed_batches=2 + +batch 0: 0 < 2 → skip +batch 1: 1 < 2 → skip + +batch 2: 2 >= 2 → re-process entire batch + read doc:1000..1499 originals from backup file + convert float32 → float16 + pipeline HSET all 500 keys → pipe.execute() ✓ + header: quantize_completed_batches=3 + + doc:1000–doc:1200 already had float16 from the crashed run. + HSET overwrites with the same float16 value. Idempotent. No harm. 
+ +batch 3: 3 >= 2 → process normally + read doc:1500..1999 from file → convert → HSET → execute ✓ + header: quantize_completed_batches=4 + +header: phase=completed +``` + +All 2,000 keys now float16. Proceed to FT.CREATE. + +**Rollback — user wants original float32 back instead:** + +``` +Read backup file +batch 0: read original float32 bytes → pipeline HSET doc:0..499 +batch 1: read originals → pipeline HSET doc:500..999 +batch 2: read originals → pipeline HSET doc:1000..1499 +batch 3: read originals → pipeline HSET doc:1500..1999 +Done — all 2,000 keys restored to float32 +``` + +The backup file makes rollback possible at any point because it stores +the actual vector bytes, not just a list of which keys were processed. + +#### What this replaces + +| Old component | New replacement | +|---------------|----------------| +| `trigger_bgsave_and_wait` | Removed entirely | +| `async_trigger_bgsave_and_wait` | Removed entirely | +| `QuantizationCheckpoint` model | `VectorBackup` (new) | +| `BatchUndoBuffer` | Backup file (originals always on disk) | +| `is_already_quantized` check | Kept as safety net, but less critical since backup has originals | +| BGSAVE CLI step | Removed from progress labels | + + +#### Disk space + +Backup file size = N_keys × N_fields × bytes_per_vector. + +| Scale | Dims | Source dtype | Backup size | +|-------|------|-------------|-------------| +| 100K | 768 | float32 | ~292 MB | +| 1M | 768 | float32 | ~2.9 GB | +| 1M | 1536 | float32 | ~5.7 GB | +| 10M | 1536 | float32 | ~57 GB | + +Formula: `N × dims × bytes_per_element` (plus ~100 bytes/key overhead for key names and msgpack framing). + +The CLI should estimate and display the required disk space before starting, and abort if insufficient. + +### Change 3: Multi-worker parallelism (opt-in) + +Split the key list into N slices. Each worker gets its own Redis connection, its own backup file shard, and processes its slice independently. 
+
```
                    ┌─────────────────────┐
                    │   Coordinator       │
                    │   (main thread)     │
                    │                     │
                    │ 1. Enumerate keys   │
                    │ 2. Split into N     │
                    │ 3. Launch workers   │
                    │ 4. Wait for all     │
                    │ 5. Merge progress   │
                    └────────┬────────────┘
                             │
              ┌──────────────┼──────────────┐
              │              │              │
     ┌────────▼───────┐ ┌───▼────────────┐ ┌─▼──────────────┐
     │  Worker 0      │ │  Worker 1      │ │  Worker N-1    │
     │  keys[0:250K]  │ │ keys[250K:500K]│ │  keys[750K:]   │
     │  connection 0  │ │  connection 1  │ │  connection N-1│
     │  backup_0.bin  │ │  backup_1.bin  │ │  backup_N.bin  │
     │                │ │                │ │                │
     │  dump → quant  │ │  dump → quant  │ │  dump → quant  │
     └────────────────┘ └────────────────┘ └────────────────┘
```

#### Implementation

**Sync executor:** `concurrent.futures.ThreadPoolExecutor`. GIL is released during socket I/O (Redis calls) and numpy operations, so threads provide real parallelism for the I/O-bound workload.

**Async executor:** `asyncio.gather` with N concurrent coroutines, each using its own Redis connection from a pool.

#### Worker function (pseudocode)

```python
def _quantize_worker(
    worker_id: int,
    redis_url: str,
    keys: List[str],
    datatype_changes: Dict,
    backup_path: str,
    batch_size: int,
    progress_queue: Queue,
) -> WorkerResult:
    """Independent worker: dump + quantize a slice of keys."""
    client = Redis.from_url(redis_url)  # own connection
    backup = VectorBackup.create(backup_path, ...)
+ + # Phase 1: dump originals + for batch in chunked(keys, batch_size): + originals = pipeline_read(client, batch, datatype_changes) + backup.write_batch(batch, originals) + progress_queue.put(("dump", worker_id, len(batch))) + + backup.mark_dump_complete() + + # Phase 2: quantize from backup + for batch_idx, (batch_keys, originals) in enumerate(backup.iter_batches()): + converted = {k: convert(v, datatype_changes) for k, v in originals.items()} + pipeline_write(client, converted) + backup.mark_batch_quantized(batch_idx) + progress_queue.put(("quantize", worker_id, len(batch_keys))) + + backup.mark_complete() + client.close() + return WorkerResult(worker_id, docs_quantized=len(keys)) +``` + +#### Configuration + +| Parameter | Default | CLI flag | Notes | +|-----------|---------|----------|-------| +| `workers` | 1 | `--workers N` | Number of parallel workers | +| `batch_size` | 500 | `--batch-size N` | Keys per pipeline batch | +| `backup_dir` | `.` | `--backup-dir PATH` | Directory for backup files | + +#### Safety constraints (from benchmark report) + +Per the benchmark notes on N-worker risks: + +1. **Default N=1.** Opt-in only. The user must explicitly pass `--workers N` to enable parallelism. +2. **Replication backlog.** N concurrent HSET writers increase replication lag. For replicated deployments, recommend N ≤ 4 and monitor replication offset. +3. **AOF pressure.** N writers accelerate AOF buffer growth. If AOF is enabled, warn the user and suggest lower N or disabling AOF during migration. +4. **Redis is single-threaded.** N connections do not give Nx server throughput. The speedup comes from overlapping client-side I/O (network round trips). Diminishing returns above N=4-8 for a single Redis instance. +5. **Cluster mode.** Each worker's key slice must map to keys the worker's connection can reach. For non-clustered Redis this is trivial. 
For cluster, keys are already partitioned by slot — the coordinator should group keys by slot and assign slot-contiguous ranges to workers. + +## Migration flow (new) + +``` +STEP 1: Enumerate keys (FT.AGGREGATE — index is alive) + │ +STEP 2: Field renames (if any, pipelined) + │ +STEP 3: Dump originals to backup file ← NEW, index still alive + N workers, pipelined HGET reads + backup file per worker + │ +STEP 4: Drop index (FT.DROPINDEX) + │ +STEP 5: Key renames (if any, DUMP/RESTORE/DEL) + │ +STEP 6: Quantize (read from backup file, write to Redis) + N workers, pipelined HSET writes + no Redis reads — originals come from local file + │ +STEP 7: Create index (FT.CREATE with new schema) + │ +STEP 8: Wait for indexing (poll FT.INFO percent_indexed) + │ +STEP 9: Validate (schema match, doc count, key sample, query checks) +``` + +BGSAVE removed entirely. SCAN removed entirely — never needed. + +### Why SCAN is never needed + +The dump phase (Step 3) runs while the index is still alive. Enumeration +always uses FT.AGGREGATE against the live index. Once the dump completes, +every key and its original vector bytes are stored in the backup file. All +subsequent steps (drop, key renames, quantize) use the key list from the +backup file — they never need to re-discover keys from Redis. + +On resume after crash: +- **Crash during dump (Steps 1-3):** Index is still alive (hasn't been + dropped). Re-enumerate via FT.AGGREGATE. Resume dump from + `dump_completed_batches`. +- **Crash after drop (Steps 4-6):** Backup file has the complete key list + and all original vectors. Resume quantize from `quantize_completed_batches`. + No enumeration needed — just read the file. +- **Crash during create/validate (Steps 7-9):** Data is fully written. + Just re-run FT.CREATE + validate. 
+ +### Crash recovery matrix + +| Crash point | Index state | Backup file state | Recovery | +|---|---|---|---| +| During enumerate (1) | Alive | Doesn't exist | Restart from scratch | +| During field renames (2) | Alive | Doesn't exist | Restart — renames are idempotent (HSET) | +| During dump (3) | **Alive** | `phase=dump`, partial | Re-enumerate via FT.AGGREGATE, resume dump | +| After dump, before drop | **Alive** | `phase=ready`, complete | Proceed to drop | +| During/after drop (4) | **Gone** | `phase=active` | Key list + originals in file, proceed | +| During key renames (5) | Gone | `phase=active` | Proceed — renames are idempotent | +| During quantize (6) | Gone | `phase=active`, `quantize_batch=M` | Resume from batch M | +| During create (7) | Gone/rebuilding | `phase=completed` | Re-run FT.CREATE | +| During wait/validate (8-9) | Building | `phase=completed` | Re-poll, re-validate | + +**No crash scenario requires SCAN.** Every recovery path uses either +FT.AGGREGATE (index alive) or the backup file (index dropped but file has keys). + +## Files changed + +### New files + +| File | Purpose | +|------|---------| +| `redisvl/migration/backup.py` | `VectorBackup` class — read/write backup files | + +### Modified files + +| File | Changes | +|------|---------| +| `redisvl/migration/executor.py` | Pipeline reads, replace checkpoint with backup, multi-worker | +| `redisvl/migration/async_executor.py` | Same changes for async path | +| `redisvl/migration/reliability.py` | Remove `QuantizationCheckpoint`, `trigger_bgsave_and_wait`, `async_trigger_bgsave_and_wait`, `BatchUndoBuffer`. Keep `is_already_quantized`, `detect_vector_dtype`, `is_same_width_dtype_conversion`. | +| `redisvl/cli/migrate.py` | Add `--workers`, `--batch-size`, `--backup-dir` flags. Remove BGSAVE progress step. 
| +| `docs/concepts/index-migrations.md` | Update flow description, remove BGSAVE reference | +| `docs/user_guide/how_to_guides/migrate-indexes.md` | Update CLI flags, flow description | + +### Removed functionality + +| Component | Reason | +|-----------|--------| +| `trigger_bgsave_and_wait` | Replaced by backup file | +| `async_trigger_bgsave_and_wait` | Replaced by backup file | +| `QuantizationCheckpoint` | Replaced by `VectorBackup` | +| `BatchUndoBuffer` | Replaced by backup file (originals always on disk) | +| `--resume` / `checkpoint_path` parameter | Replaced by `--backup-dir` resume semantics | +| BGSAVE progress step in CLI | No longer needed | + +## Expected performance + +### Pipeline reads only (Change 1, N=1) + +| Scale | Current | After pipelining | Speedup | +|-------|---------|-----------------|---------| +| 100K | 31s | ~3s | ~10x | +| 1M | ~50 min | ~5 min | ~10x | +| 10M | ~8 hrs | ~50 min | ~10x | + +### Pipeline reads + 4 workers (Changes 1+3, N=4) + +| Scale | Current | After both | Speedup | +|-------|---------|-----------|---------| +| 100K | 31s | ~1s | ~30x | +| 1M | ~50 min | ~1.5 min | ~33x | +| 10M | ~8 hrs | ~15 min | ~32x | + +Note: Worker speedup is sub-linear (not 4x) because Redis is single-threaded. The gains come from overlapping client-side network I/O. Actual speedup depends on network latency, Redis load, and whether the deployment is standalone or clustered. + +### Two-phase overhead (Change 2) + +The dump phase adds one extra read pass over all keys. But since reads are now pipelined, this adds ~50% to read time (one extra pipeline execution per batch). The quantize phase reads from the local backup file instead of Redis, which is faster than Redis reads. Net effect: roughly neutral — the dump overhead is offset by faster quantize reads. + +## Implementation order + +1. **Pipeline reads** (sync + async). Run benchmarks to validate 10x improvement. +2. **VectorBackup file** (new module). Unit test read/write/resume. +3. 
**Replace BGSAVE + checkpoint** in executor with backup file. Integration test. +4. **Multi-worker** (sync first, then async). Benchmark at N=1,2,4. +5. **CLI flags** (`--workers`, `--batch-size`, `--backup-dir`). +6. **Update docs** and existing tests. + +## Open questions + +1. **msgpack vs pickle vs custom binary?** msgpack is compact and fast but adds a dependency. pickle is stdlib but not portable. Custom binary is zero-dep but more code. Recommendation: msgpack (already a common dep in data workflows, ~10x faster than YAML). +2. **Should backup files be cleaned up automatically on success?** Recommend yes with a `--keep-backup` flag to retain. +3. **Should the dump phase estimate and display ETA?** Yes — can calculate from batch throughput after the first few batches. +4. **Cluster slot-aware worker assignment?** Defer to a follow-up. For now, workers are assigned contiguous key ranges which works for standalone Redis. Cluster support needs slot grouping. + +## References + +- Benchmark report: `local_docs/index_migrator/05_migration_benchmark_report.md` +- Scaling notes: `local_docs/index_migrator/notes_scaling_and_reliability.md` +- PR review (pipelined reads): nkode run `ce95e0e4`, finding #2 +- PR review (async pipelining): Copilot comment on `async_executor.py:639-654` +- PR review (sync pipelining): Copilot comment on `executor.py:674-695` + +## Implementation Status — COMPLETE + +Branch: `feat/migrate-perf-overhaul` + +All three spec changes are implemented. Legacy components removed. 
+ +### Implemented + +| Component | File | Status | +|-----------|------|--------| +| `VectorBackup` class | `redisvl/migration/backup.py` | ✅ 16 tests | +| Pipeline read/write/convert helpers | `redisvl/migration/quantize.py` | ✅ 6 tests | +| `split_keys`, `MultiWorkerResult` | `redisvl/migration/quantize.py` | ✅ 5 tests | +| `multi_worker_quantize` (sync) | `redisvl/migration/quantize.py` | ✅ 3 tests | +| `async_multi_worker_quantize` | `redisvl/migration/quantize.py` | ✅ Done | +| `_dump_vectors` / `_quantize_from_backup` | Both executors | ✅ 4 tests | +| `apply()` refactored: dump→drop→quantize | Both executors | ✅ Done | +| Multi-worker wired into `apply()` | Both executors | ✅ Done | +| `--backup-dir` CLI flag | `redisvl/cli/migrate.py` | ✅ Done | +| `--batch-size` CLI flag | `redisvl/cli/migrate.py` | ✅ Done | +| `--workers` CLI flag | `redisvl/cli/migrate.py` | ✅ Done | +| `--keep-backup` CLI flag | `redisvl/cli/migrate.py` | ✅ Done | +| Auto-cleanup backup files on success | Both executors | ✅ Done | +| Disk space estimation | Already existed | ✅ Pre-existing | + +### Removed (legacy components) + +| Component | Status | +|-----------|--------| +| `QuantizationCheckpoint` usage | ❌ Removed from executors | +| `BatchUndoBuffer` usage | ❌ Removed from executors | +| `trigger_bgsave_and_wait` / async | ❌ Removed from executors | +| `_quantize_vectors` / `_async_quantize_vectors` | ❌ Removed | +| `--resume` / `checkpoint_path` CLI flag | ❌ Removed | +| 32 legacy tests | ❌ Removed | + +Note: `reliability.py` module still exists (provides `is_same_width_dtype_conversion` +and other utilities). Its legacy classes are no longer imported by executors. + +**Test results:** 788 unit tests pass, mypy clean, pre-commit clean. 
\ No newline at end of file diff --git a/redisvl/cli/main.py b/redisvl/cli/main.py index 1353192fc..0147adc1b 100644 --- a/redisvl/cli/main.py +++ b/redisvl/cli/main.py @@ -2,6 +2,7 @@ import sys from redisvl.cli.index import Index +from redisvl.cli.migrate import Migrate from redisvl.cli.stats import Stats from redisvl.cli.version import Version from redisvl.utils.log import get_logger @@ -14,6 +15,7 @@ def _usage(): "rvl []\n", "Commands:", "\tindex Index manipulation (create, delete, etc.)", + "\tmigrate Index migration (plan, apply, validate)", "\tversion Obtain the version of RedisVL", "\tstats Obtain statistics about an index", ] @@ -46,6 +48,10 @@ def version(self): Version() exit(0) + def migrate(self): + Migrate() + exit(0) + def stats(self): Stats() exit(0) diff --git a/redisvl/cli/migrate.py b/redisvl/cli/migrate.py new file mode 100644 index 000000000..8ad07fe32 --- /dev/null +++ b/redisvl/cli/migrate.py @@ -0,0 +1,1023 @@ +import argparse +import asyncio +import os +import sys +from pathlib import Path +from typing import Optional + +from redisvl.cli.utils import add_redis_connection_options, create_redis_url +from redisvl.migration import ( + AsyncMigrationExecutor, + BatchMigrationExecutor, + BatchMigrationPlanner, + MigrationExecutor, + MigrationPlanner, + MigrationValidator, + MigrationWizard, +) +from redisvl.migration.utils import ( + detect_aof_enabled, + estimate_disk_space, + list_indexes, + load_migration_plan, + load_yaml, + write_benchmark_report, + write_migration_report, +) +from redisvl.redis.connection import RedisConnectionFactory +from redisvl.utils.log import get_logger + +logger = get_logger("[RedisVL]") + + +class Migrate: + usage = "\n".join( + [ + "rvl migrate []\n", + "Commands:", + "\thelper Show migration guidance and supported capabilities", + "\tlist List all available indexes", + "\twizard Interactively build a migration plan and schema patch", + "\tplan Generate a migration plan for a document-preserving drop/recreate 
migration", + "\tapply Execute a reviewed drop/recreate migration plan (use --async for large migrations)", + "\testimate Estimate disk space required for a migration plan (dry-run, no mutations)", + "\trollback Restore original vectors from a backup directory (undo quantization)", + "\tvalidate Validate a completed migration plan against the live index", + "\tbatch-plan Generate a batch migration plan for multiple indexes", + "\tbatch-apply Execute a batch migration plan with checkpointing", + "\tbatch-resume Resume an interrupted batch migration", + "\tbatch-status Show status of an in-progress or completed batch migration", + "\n", + ] + ) + + def __init__(self): + parser = argparse.ArgumentParser(usage=self.usage) + parser.add_argument("command", help="Subcommand to run") + + args = parser.parse_args(sys.argv[2:3]) + command = args.command.replace("-", "_") + if not hasattr(self, command): + print(f"Unknown subcommand: {args.command}") + parser.print_help() + sys.exit(1) + + try: + getattr(self, command)() + except Exception as e: + logger.error(e) + sys.exit(1) + + def helper(self): + parser = argparse.ArgumentParser( + usage="rvl migrate helper [--host --port | --url ]" + ) + parser = add_redis_connection_options(parser) + args = parser.parse_args(sys.argv[3:]) + redis_url = create_redis_url(args) + indexes = list_indexes(redis_url=redis_url) + + print("RedisVL Index Migrator\n\nAvailable indexes:") + if indexes: + for position, index_name in enumerate(indexes, start=1): + print(f" {position}. 
{index_name}") + else: + print(" (none found)") + + print( + """\nSupported changes: + - Adding or removing non-vector fields (text, tag, numeric, geo) + - Changing field options (sortable, separator, weight) + - Changing vector algorithm (FLAT, HNSW, SVS-VAMANA) + - Changing distance metric (COSINE, L2, IP) + - Tuning algorithm parameters (M, EF_CONSTRUCTION, EF_RUNTIME, EPSILON) + - Quantizing vectors (float32 to float16/bfloat16/int8/uint8) + - Changing key prefix (renames all keys) + - Renaming fields (updates all documents) + - Renaming the index + +Not yet supported: + - Changing vector dimensions + - Changing storage type (hash to JSON) + +Commands: + rvl migrate list List all indexes + rvl migrate wizard --index Guided migration builder + rvl migrate plan --index --schema-patch + rvl migrate apply --plan + rvl migrate validate --plan """ + ) + + def list(self): + parser = argparse.ArgumentParser( + usage="rvl migrate list [--host --port | --url ]" + ) + parser = add_redis_connection_options(parser) + args = parser.parse_args(sys.argv[3:]) + redis_url = create_redis_url(args) + indexes = list_indexes(redis_url=redis_url) + print("Available indexes:") + for position, index_name in enumerate(indexes, start=1): + print(f"{position}. 
{index_name}") + + def wizard(self): + parser = argparse.ArgumentParser( + usage=( + "rvl migrate wizard [--index ] " + "[--patch ] " + "[--plan-out ] [--patch-out ]" + ) + ) + parser.add_argument("-i", "--index", help="Source index name", required=False) + parser.add_argument( + "--patch", + help="Load an existing schema patch to continue editing", + default=None, + ) + parser.add_argument( + "--plan-out", + help="Path to write migration_plan.yaml", + default="migration_plan.yaml", + ) + parser.add_argument( + "--patch-out", + help="Path to write schema_patch.yaml (for later editing)", + default="schema_patch.yaml", + ) + parser.add_argument( + "--target-schema-out", + help="Optional path to write the merged target schema", + default=None, + ) + parser.add_argument( + "--key-sample-limit", + help="Maximum number of keys to sample from the index keyspace", + type=int, + default=10, + ) + parser = add_redis_connection_options(parser) + args = parser.parse_args(sys.argv[3:]) + + redis_url = create_redis_url(args) + wizard = MigrationWizard( + planner=MigrationPlanner(key_sample_limit=args.key_sample_limit) + ) + plan = wizard.run( + index_name=args.index, + redis_url=redis_url, + existing_patch_path=args.patch, + plan_out=args.plan_out, + patch_out=args.patch_out, + target_schema_out=args.target_schema_out, + ) + self._print_plan_summary(args.plan_out, plan) + + def plan(self): + parser = argparse.ArgumentParser( + usage=( + "rvl migrate plan --index " + "(--schema-patch | --target-schema )" + ) + ) + parser.add_argument("-i", "--index", help="Source index name", required=True) + parser.add_argument("--schema-patch", help="Path to a schema patch file") + parser.add_argument("--target-schema", help="Path to a target schema file") + parser.add_argument( + "--plan-out", + help="Path to write migration_plan.yaml", + default="migration_plan.yaml", + ) + parser.add_argument( + "--key-sample-limit", + help="Maximum number of keys to sample from the index keyspace", + 
type=int, + default=10, + ) + parser = add_redis_connection_options(parser) + + args = parser.parse_args(sys.argv[3:]) + redis_url = create_redis_url(args) + planner = MigrationPlanner(key_sample_limit=args.key_sample_limit) + plan = planner.create_plan( + args.index, + redis_url=redis_url, + schema_patch_path=args.schema_patch, + target_schema_path=args.target_schema, + ) + planner.write_plan(plan, args.plan_out) + self._print_plan_summary(args.plan_out, plan) + + def apply(self): + parser = argparse.ArgumentParser( + usage=( + "rvl migrate apply --plan " + "[--async] [--backup-dir ] [--workers N] " + "[--report-out ]" + ) + ) + parser.add_argument("--plan", help="Path to migration_plan.yaml", required=True) + parser.add_argument( + "--async", + dest="use_async", + help="Use async executor (recommended for large migrations with quantization)", + action="store_true", + ) + parser.add_argument( + "--backup-dir", + dest="backup_dir", + help="Directory for vector backup files. Enables crash-safe resume and rollback.", + default=None, + ) + parser.add_argument( + "--batch-size", + dest="batch_size", + type=int, + help="Keys per pipeline batch (default 500)", + default=500, + ) + parser.add_argument( + "--workers", + dest="num_workers", + type=int, + help="Number of parallel workers for quantization (default 1). " + "Each worker gets its own Redis connection. 
Requires --backup-dir.", + default=1, + ) + parser.add_argument( + "--keep-backup", + dest="keep_backup", + action="store_true", + help="Keep backup files after successful migration (default: auto-delete).", + default=False, + ) + # Deprecated alias for --backup-dir (was --resume in previous versions) + parser.add_argument( + "--resume", + dest="legacy_resume", + help=argparse.SUPPRESS, # hidden from help + default=None, + ) + parser.add_argument( + "--report-out", + help="Path to write migration_report.yaml", + default="migration_report.yaml", + ) + parser.add_argument( + "--benchmark-out", + help="Optional path to write benchmark_report.yaml", + default=None, + ) + parser.add_argument( + "--query-check-file", + help="Optional YAML file containing fetch_ids and keys_exist checks", + default=None, + ) + parser = add_redis_connection_options(parser) + args = parser.parse_args(sys.argv[3:]) + + # Validate --workers + if args.num_workers < 1: + parser.error("--workers must be >= 1") + + # Handle deprecated --resume flag + if args.legacy_resume is not None: + import warnings + + # Fail fast if the value looks like a checkpoint file (old semantics) + if os.path.isfile(args.legacy_resume) or args.legacy_resume.endswith( + (".yaml", ".yml") + ): + parser.error( + "--resume semantics have changed: it now expects a backup " + "directory, not a checkpoint file. Use --backup-dir instead." + ) + + warnings.warn( + "--resume is deprecated and will be removed in a future version. 
" + "Use --backup-dir instead: the backup directory replaces " + "checkpoint files for crash-safe resume and rollback.", + DeprecationWarning, + stacklevel=1, + ) + if args.backup_dir is None: + args.backup_dir = args.legacy_resume + + # Validate --workers > 1 requires --backup-dir + if args.num_workers > 1 and args.backup_dir is None: + parser.error("--workers > 1 requires --backup-dir") + + redis_url = create_redis_url(args) + plan = load_migration_plan(args.plan) + + # Print disk space estimate for quantization migrations + aof_enabled = False + try: + client = RedisConnectionFactory.get_redis_connection(redis_url=redis_url) + try: + aof_enabled = detect_aof_enabled(client) + finally: + client.close() + except Exception as exc: + logger.debug("Could not detect AOF for CLI preflight estimate: %s", exc) + + disk_estimate = estimate_disk_space(plan, aof_enabled=aof_enabled) + if disk_estimate.has_quantization: + print(f"\n{disk_estimate.summary()}\n") + + if args.use_async: + report = asyncio.run( + self._apply_async( + plan, + redis_url, + args.query_check_file, + backup_dir=args.backup_dir, + batch_size=args.batch_size, + num_workers=args.num_workers, + keep_backup=args.keep_backup, + ) + ) + else: + report = self._apply_sync( + plan, + redis_url, + args.query_check_file, + backup_dir=args.backup_dir, + batch_size=args.batch_size, + num_workers=args.num_workers, + keep_backup=args.keep_backup, + ) + + write_migration_report(report, args.report_out) + if args.benchmark_out: + write_benchmark_report(report, args.benchmark_out) + self._print_report_summary(args.report_out, report, args.benchmark_out) + + def estimate(self): + """Estimate disk space required for a migration plan (dry-run).""" + parser = argparse.ArgumentParser( + usage="rvl migrate estimate --plan " + ) + parser.add_argument("--plan", help="Path to migration_plan.yaml", required=True) + parser.add_argument( + "--aof-enabled", + action="store_true", + help="Include AOF growth in the disk space 
estimate", + ) + args = parser.parse_args(sys.argv[3:]) + + plan = load_migration_plan(args.plan) + disk_estimate = estimate_disk_space(plan, aof_enabled=args.aof_enabled) + print(disk_estimate.summary()) + + # Phases that indicate a safe/complete backup for rollback + _SAFE_ROLLBACK_PHASES = frozenset({"ready", "active", "completed"}) + + def rollback(self): + """Restore original vectors from a backup directory (undo quantization).""" + parser = argparse.ArgumentParser( + usage=( + "rvl migrate rollback --backup-dir " + "[--index ] [--yes] [--force] [--url ]" + ) + ) + parser.add_argument( + "--backup-dir", + dest="backup_dir", + help="Directory containing vector backup files from a prior migration", + required=True, + ) + parser.add_argument( + "--index", + dest="index_name", + help="Only restore backups for this index name (filters by backup header)", + default=None, + ) + parser.add_argument( + "--yes", + "-y", + dest="yes", + action="store_true", + help="Skip confirmation prompt for multi-index rollback", + default=False, + ) + parser.add_argument( + "--force", + dest="force", + action="store_true", + help="Proceed even if backup phase indicates incomplete dump", + default=False, + ) + parser = add_redis_connection_options(parser) + args = parser.parse_args(sys.argv[3:]) + + redis_url = create_redis_url(args) + + from redisvl.migration.backup import VectorBackup + from redisvl.redis.connection import RedisConnectionFactory + + # Find backup files in the directory + backup_dir = args.backup_dir + if not os.path.isdir(backup_dir): + print(f"Error: backup directory not found: {backup_dir}") + sys.exit(1) + + # Look for .header files to find backups + header_files = sorted(Path(backup_dir).glob("*.header")) + if not header_files: + print(f"Error: no backup files found in {backup_dir}") + sys.exit(1) + + # Derive backup base paths (strip .header suffix) + backup_paths = [str(h.with_suffix("")) for h in header_files] + + # Load, filter, and validate backups + 
backups_to_restore = [] + for bp in backup_paths: + backup = VectorBackup.load(bp) + if backup is None: + print(f" Skipping {bp}: could not load backup") + continue + if args.index_name and backup.header.index_name != args.index_name: + print( + f" Skipping {os.path.basename(bp)}: " + f"index '{backup.header.index_name}' != '{args.index_name}'" + ) + continue + # Gate on backup phase — refuse incomplete backups unless --force + if backup.header.phase not in self._SAFE_ROLLBACK_PHASES: + if args.force: + print( + f" Warning: {os.path.basename(bp)} has phase " + f"'{backup.header.phase}' (incomplete dump) — " + f"proceeding due to --force" + ) + else: + print( + f" Skipping {os.path.basename(bp)}: backup phase " + f"'{backup.header.phase}' indicates incomplete dump. " + f"Use --force to restore from partial backups." + ) + continue + backups_to_restore.append((bp, backup)) + + if not backups_to_restore: + print("Error: no matching backup files found") + sys.exit(1) + + # Require --index or --yes when multiple distinct indexes detected + distinct_indexes = {b.header.index_name for _, b in backups_to_restore} + if len(distinct_indexes) > 1 and not args.index_name and not args.yes: + print( + f"Error: found backups for {len(distinct_indexes)} distinct indexes: " + f"{', '.join(sorted(distinct_indexes))}. " + f"Use --index to filter or --yes to restore all." 
+ ) + sys.exit(1) + + client = RedisConnectionFactory.get_redis_connection(redis_url=redis_url) + total_restored = 0 + try: + for bp, backup in backups_to_restore: + print( + f"Restoring from: {os.path.basename(bp)} " + f"(index={backup.header.index_name}, " + f"phase={backup.header.phase}, " + f"batches={backup.header.dump_completed_batches})" + ) + + batch_count = 0 + for keys, originals in backup.iter_batches(): + pipe = client.pipeline(transaction=False) + batch_restored = 0 + for key in keys: + if key in originals: + for field_name, original_bytes in originals[key].items(): + pipe.hset(key, field_name, original_bytes) + batch_restored += 1 + pipe.execute() + batch_count += 1 + total_restored += batch_restored + if batch_count % 10 == 0: + print( + f" Restored {total_restored:,} vectors " + f"({batch_count}/{backup.header.dump_completed_batches} batches)" + ) + + print( + f" Done: {batch_count} batches restored from {os.path.basename(bp)}" + ) + finally: + client.close() + + print( + f"\nRollback complete: {total_restored:,} vectors restored to original values" + ) + print( + "Note: You may need to recreate the original index schema " + "(FT.CREATE) if the index was changed during migration." 
+ ) + + @staticmethod + def _make_progress_callback(): + """Create a progress callback for migration apply.""" + step_labels = { + "enumerate": "[1/8] Enumerate keys", + "bgsave": "[2/8] BGSAVE snapshot", + "field_rename": "[3/8] Rename fields", + "drop": "[4/8] Drop index", + "key_rename": "[5/8] Rename keys", + "quantize": "[6/8] Quantize vectors", + "create": "[7/8] Create index", + "index": "[8/8] Re-indexing", + "validate": "Validate", + } + + def progress_callback(step: str, detail: Optional[str]) -> None: + label = step_labels.get(step, step) + if detail and not detail.startswith("done"): + print(f" {label}: {detail} ", end="\r", flush=True) + else: + print(f" {label}: {detail} ") + + return progress_callback + + def _apply_sync( + self, + plan, + redis_url: str, + query_check_file: Optional[str], + backup_dir: Optional[str] = None, + batch_size: int = 500, + num_workers: int = 1, + keep_backup: bool = False, + ): + """Execute migration synchronously.""" + executor = MigrationExecutor() + + print(f"\nApplying migration to '{plan.source.index_name}'...") + + report = executor.apply( + plan, + redis_url=redis_url, + query_check_file=query_check_file, + progress_callback=self._make_progress_callback(), + backup_dir=backup_dir, + batch_size=batch_size, + num_workers=num_workers, + keep_backup=keep_backup, + ) + + self._print_apply_result(report) + return report + + async def _apply_async( + self, + plan, + redis_url: str, + query_check_file: Optional[str], + backup_dir: Optional[str] = None, + batch_size: int = 500, + num_workers: int = 1, + keep_backup: bool = False, + ): + """Execute migration asynchronously (non-blocking for large quantization jobs).""" + executor = AsyncMigrationExecutor() + + print(f"\nApplying migration to '{plan.source.index_name}' (async mode)...") + + report = await executor.apply( + plan, + redis_url=redis_url, + query_check_file=query_check_file, + progress_callback=self._make_progress_callback(), + backup_dir=backup_dir, + 
batch_size=batch_size, + num_workers=num_workers, + keep_backup=keep_backup, + ) + + self._print_apply_result(report) + return report + + def _print_apply_result(self, report) -> None: + """Print the result summary after migration apply.""" + if report.result == "succeeded": + total_time = report.timings.total_migration_duration_seconds or 0 + downtime = report.timings.downtime_duration_seconds or 0 + print(f"\nMigration completed in {total_time}s (downtime: {downtime}s)") + else: + print(f"\nMigration {report.result}") + if report.validation.errors: + for error in report.validation.errors: + print(f" ERROR: {error}") + + def validate(self): + parser = argparse.ArgumentParser( + usage=( + "rvl migrate validate --plan " + "[--report-out ]" + ) + ) + parser.add_argument("--plan", help="Path to migration_plan.yaml", required=True) + parser.add_argument( + "--report-out", + help="Path to write migration_report.yaml", + default="migration_report.yaml", + ) + parser.add_argument( + "--benchmark-out", + help="Optional path to write benchmark_report.yaml", + default=None, + ) + parser.add_argument( + "--query-check-file", + help="Optional YAML file containing fetch_ids and keys_exist checks", + default=None, + ) + parser = add_redis_connection_options(parser) + args = parser.parse_args(sys.argv[3:]) + + redis_url = create_redis_url(args) + plan = load_migration_plan(args.plan) + validator = MigrationValidator() + + from redisvl.migration.utils import timestamp_utc + + started_at = timestamp_utc() + validation, target_info, validation_duration = validator.validate( + plan, + redis_url=redis_url, + query_check_file=args.query_check_file, + ) + finished_at = timestamp_utc() + + from redisvl.migration.models import ( + MigrationBenchmarkSummary, + MigrationReport, + MigrationTimings, + ) + + source_size = float( + plan.source.stats_snapshot.get("vector_index_sz_mb", 0) or 0 + ) + target_size = float(target_info.get("vector_index_sz_mb", 0) or 0) + + report = MigrationReport( + 
source_index=plan.source.index_name, + target_index=plan.merged_target_schema["index"]["name"], + result="succeeded" if not validation.errors else "failed", + started_at=started_at, + finished_at=finished_at, + timings=MigrationTimings(validation_duration_seconds=validation_duration), + validation=validation, + benchmark_summary=MigrationBenchmarkSummary( + source_index_size_mb=round(source_size, 3), + target_index_size_mb=round(target_size, 3), + index_size_delta_mb=round(target_size - source_size, 3), + ), + warnings=list(plan.warnings), + manual_actions=( + ["Review validation errors before proceeding."] + if validation.errors + else [] + ), + ) + write_migration_report(report, args.report_out) + if args.benchmark_out: + write_benchmark_report(report, args.benchmark_out) + self._print_report_summary(args.report_out, report, args.benchmark_out) + + def _print_plan_summary(self, plan_out: str, plan) -> None: + import os + + abs_path = os.path.abspath(plan_out) + print( + f"""Migration plan written to {abs_path} +Mode: {plan.mode} +Supported: {plan.diff_classification.supported}""" + ) + if plan.warnings: + print("Warnings:") + for warning in plan.warnings: + print(f"- {warning}") + if plan.diff_classification.blocked_reasons: + print("Blocked reasons:") + for reason in plan.diff_classification.blocked_reasons: + print(f"- {reason}") + + print( + f"""\nNext steps: + Review the plan: cat {plan_out} + Apply the migration: rvl migrate apply --plan {plan_out} + Validate the result: rvl migrate validate --plan {plan_out} + To cancel: rm {plan_out}""" + ) + + def _print_report_summary( + self, + report_out: str, + report, + benchmark_out: Optional[str], + ) -> None: + print( + f"""Migration report written to {report_out} +Result: {report.result} +Schema match: {report.validation.schema_match} +Doc count match: {report.validation.doc_count_match} +Key sample exists: {report.validation.key_sample_exists} +Indexing failures delta: 
{report.validation.indexing_failures_delta}""" + ) + if report.validation.errors: + print("Errors:") + for error in report.validation.errors: + print(f"- {error}") + if report.manual_actions: + print("Manual actions:") + for action in report.manual_actions: + print(f"- {action}") + if benchmark_out: + print(f"Benchmark report written to {benchmark_out}") + + def batch_plan(self): + """Generate a batch migration plan for multiple indexes.""" + parser = argparse.ArgumentParser( + usage=( + "rvl migrate batch-plan --schema-patch " + "(--pattern | --indexes | --indexes-file )" + ) + ) + parser.add_argument( + "--schema-patch", help="Path to shared schema patch file", required=True + ) + parser.add_argument( + "--pattern", help="Glob pattern to match index names (e.g., '*_idx')" + ) + parser.add_argument("--indexes", help="Comma-separated list of index names") + parser.add_argument( + "--indexes-file", help="File with index names (one per line)" + ) + parser.add_argument( + "--failure-policy", + help="How to handle failures: fail_fast or continue_on_error", + choices=["fail_fast", "continue_on_error"], + default="fail_fast", + ) + parser.add_argument( + "--plan-out", + help="Path to write batch_plan.yaml", + default="batch_plan.yaml", + ) + parser = add_redis_connection_options(parser) + args = parser.parse_args(sys.argv[3:]) + + redis_url = create_redis_url(args) + indexes = ( + [idx.strip() for idx in args.indexes.split(",") if idx.strip()] + if args.indexes + else None + ) + + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + indexes=indexes, + pattern=args.pattern, + indexes_file=args.indexes_file, + schema_patch_path=args.schema_patch, + redis_url=redis_url, + failure_policy=args.failure_policy, + ) + + planner.write_batch_plan(batch_plan, args.plan_out) + self._print_batch_plan_summary(args.plan_out, batch_plan) + + def batch_apply(self): + """Execute a batch migration plan with checkpointing.""" + parser = argparse.ArgumentParser( + 
usage=( + "rvl migrate batch-apply --plan " + "[--state ] [--report-dir <./reports>]" + ) + ) + parser.add_argument("--plan", help="Path to batch_plan.yaml", required=True) + parser.add_argument( + "--accept-data-loss", + help="Acknowledge that quantization is lossy and cannot be reverted", + action="store_true", + ) + parser.add_argument( + "--state", + help="Path to checkpoint state file", + default="batch_state.yaml", + ) + parser.add_argument( + "--report-dir", + help="Directory for per-index migration reports", + default="./reports", + ) + parser = add_redis_connection_options(parser) + args = parser.parse_args(sys.argv[3:]) + + from redisvl.migration.models import BatchPlan + + plan_data = load_yaml(args.plan) + batch_plan = BatchPlan.model_validate(plan_data) + + if batch_plan.requires_quantization and not args.accept_data_loss: + print( + """WARNING: This batch migration includes quantization (e.g., float32 -> float16). + Vector data will be modified. Original precision cannot be recovered. + To proceed, add --accept-data-loss flag. 
+ + If you need to preserve original vectors, backup your data first: + redis-cli BGSAVE""" + ) + sys.exit(1) + + redis_url = create_redis_url(args) + executor = BatchMigrationExecutor() + + def progress_callback( + index_name: str, position: int, total: int, status: str + ) -> None: + print(f"[{position}/{total}] {index_name}: {status}") + + report = executor.apply( + batch_plan, + batch_plan_path=args.plan, + state_path=args.state, + report_dir=args.report_dir, + redis_url=redis_url, + progress_callback=progress_callback, + ) + + self._print_batch_report_summary(report) + + def batch_resume(self): + """Resume an interrupted batch migration.""" + parser = argparse.ArgumentParser( + usage=( + "rvl migrate batch-resume --state " + "[--plan ] [--retry-failed]" + ) + ) + parser.add_argument( + "--state", help="Path to checkpoint state file", required=True + ) + parser.add_argument( + "--plan", help="Path to batch_plan.yaml (optional, uses state.plan_path)" + ) + parser.add_argument( + "--retry-failed", + help="Retry previously failed indexes", + action="store_true", + ) + parser.add_argument( + "--accept-data-loss", + help="Acknowledge vector quantization data loss", + action="store_true", + ) + parser.add_argument( + "--report-dir", + help="Directory for per-index migration reports", + default="./reports", + ) + parser = add_redis_connection_options(parser) + args = parser.parse_args(sys.argv[3:]) + + # Load the batch plan to check for quantization safety gate + executor = BatchMigrationExecutor() + state = executor._load_state(args.state) + plan_path = args.plan or state.plan_path or None + if plan_path: + batch_plan = executor._load_batch_plan(plan_path) + if batch_plan.requires_quantization and not args.accept_data_loss: + print( + """WARNING: This batch migration includes quantization (e.g., float32 -> float16). + Vector data will be modified. Original precision cannot be recovered. + To proceed, add --accept-data-loss flag. 
+ + If you need to preserve original vectors, backup your data first: + redis-cli BGSAVE""" + ) + sys.exit(1) + + redis_url = create_redis_url(args) + + def progress_callback( + index_name: str, position: int, total: int, status: str + ) -> None: + print(f"[{position}/{total}] {index_name}: {status}") + + report = executor.resume( + args.state, + batch_plan_path=args.plan, + retry_failed=args.retry_failed, + report_dir=args.report_dir, + redis_url=redis_url, + progress_callback=progress_callback, + ) + + self._print_batch_report_summary(report) + + def batch_status(self): + """Show status of an in-progress or completed batch migration.""" + parser = argparse.ArgumentParser( + usage="rvl migrate batch-status --state " + ) + parser.add_argument( + "--state", help="Path to checkpoint state file", required=True + ) + args = parser.parse_args(sys.argv[3:]) + + state_path = Path(args.state).resolve() + if not state_path.exists(): + print(f"State file not found: {args.state}") + sys.exit(1) + + from redisvl.migration.models import BatchState + + state_data = load_yaml(args.state) + state = BatchState.model_validate(state_data) + + print( + f"""Batch ID: {state.batch_id} +Started at: {state.started_at} +Updated at: {state.updated_at} +Current index: {state.current_index or '(none)'} +Remaining: {len(state.remaining)} +Completed: {len(state.completed)} + - Succeeded: {state.success_count} + - Failed: {state.failed_count} + - Skipped: {state.skipped_count}""" + ) + + if state.completed: + print("\nCompleted indexes:") + for idx in state.completed: + if idx.status == "success": + status_icon = "[OK]" + elif idx.status == "skipped": + status_icon = "[SKIP]" + else: + status_icon = "[FAIL]" + print(f" {status_icon} {idx.name}") + if idx.error: + print(f" Error: {idx.error}") + + if state.remaining: + print(f"\nRemaining indexes ({len(state.remaining)}):") + for name in state.remaining[:10]: + print(f" - {name}") + if len(state.remaining) > 10: + print(f" ... 
and {len(state.remaining) - 10} more") + + def _print_batch_plan_summary(self, plan_out: str, batch_plan) -> None: + """Print summary after generating batch plan.""" + import os + + abs_path = os.path.abspath(plan_out) + print( + f"""Batch plan written to {abs_path} +Batch ID: {batch_plan.batch_id} +Mode: {batch_plan.mode} +Failure policy: {batch_plan.failure_policy} +Requires quantization: {batch_plan.requires_quantization} +Total indexes: {len(batch_plan.indexes)} + - Applicable: {batch_plan.applicable_count} + - Skipped: {batch_plan.skipped_count}""" + ) + + if batch_plan.skipped_count > 0: + print("\nSkipped indexes:") + for idx in batch_plan.indexes: + if not idx.applicable: + print(f" - {idx.name}: {idx.skip_reason}") + + print( + f""" +Next steps: + Review the plan: cat {plan_out} + Apply the migration: rvl migrate batch-apply --plan {plan_out}""" + ) + + if batch_plan.requires_quantization: + print(" (add --accept-data-loss for quantization)") + + def _print_batch_report_summary(self, report) -> None: + """Print summary after batch migration completes.""" + print( + f""" +Batch migration {report.status} +Batch ID: {report.batch_id} +Duration: {report.summary.total_duration_seconds}s +Total: {report.summary.total_indexes} + - Succeeded: {report.summary.successful} + - Failed: {report.summary.failed} + - Skipped: {report.summary.skipped}""" + ) + + if report.summary.failed > 0: + print("\nFailed indexes:") + for idx in report.indexes: + if idx.status == "failed": + print(f" - {idx.name}: {idx.error}") diff --git a/redisvl/cli/utils.py b/redisvl/cli/utils.py index 5d76a1842..9b19a126c 100644 --- a/redisvl/cli/utils.py +++ b/redisvl/cli/utils.py @@ -14,9 +14,10 @@ def create_redis_url(args: Namespace) -> str: elif args.url: return args.url else: - url = "redis://" if args.ssl: - url += "rediss://" + url = "rediss://" + else: + url = "redis://" if args.user: url += args.user if args.password: @@ -26,11 +27,7 @@ def create_redis_url(args: Namespace) -> str: 
return url -def add_index_parsing_options(parser: ArgumentParser) -> ArgumentParser: - parser.add_argument("-i", "--index", help="Index name", type=str, required=False) - parser.add_argument( - "-s", "--schema", help="Path to schema file", type=str, required=False - ) +def add_redis_connection_options(parser: ArgumentParser) -> ArgumentParser: parser.add_argument("-u", "--url", help="Redis URL", type=str, required=False) parser.add_argument("--host", help="Redis host", type=str, default="localhost") parser.add_argument("-p", "--port", help="Redis port", type=int, default=6379) @@ -38,3 +35,11 @@ def add_index_parsing_options(parser: ArgumentParser) -> ArgumentParser: parser.add_argument("--ssl", help="Use SSL", action="store_true") parser.add_argument("-a", "--password", help="Redis password", type=str, default="") return parser + + +def add_index_parsing_options(parser: ArgumentParser) -> ArgumentParser: + parser.add_argument("-i", "--index", help="Index name", type=str, required=False) + parser.add_argument( + "-s", "--schema", help="Path to schema file", type=str, required=False + ) + return add_redis_connection_options(parser) diff --git a/redisvl/migration/__init__.py b/redisvl/migration/__init__.py new file mode 100644 index 000000000..adb0f118a --- /dev/null +++ b/redisvl/migration/__init__.py @@ -0,0 +1,25 @@ +from redisvl.migration.async_executor import AsyncMigrationExecutor +from redisvl.migration.async_planner import AsyncMigrationPlanner +from redisvl.migration.async_validation import AsyncMigrationValidator +from redisvl.migration.batch_executor import BatchMigrationExecutor +from redisvl.migration.batch_planner import BatchMigrationPlanner +from redisvl.migration.executor import MigrationExecutor +from redisvl.migration.models import BatchPlan, BatchState, SchemaPatch +from redisvl.migration.planner import MigrationPlanner +from redisvl.migration.validation import MigrationValidator +from redisvl.migration.wizard import MigrationWizard + +__all__ = [ 
+ "AsyncMigrationExecutor", + "AsyncMigrationPlanner", + "AsyncMigrationValidator", + "BatchMigrationExecutor", + "BatchMigrationPlanner", + "BatchPlan", + "BatchState", + "MigrationExecutor", + "MigrationPlanner", + "MigrationValidator", + "MigrationWizard", + "SchemaPatch", +] diff --git a/redisvl/migration/async_executor.py b/redisvl/migration/async_executor.py new file mode 100644 index 000000000..07344d4aa --- /dev/null +++ b/redisvl/migration/async_executor.py @@ -0,0 +1,1364 @@ +from __future__ import annotations + +import asyncio +import hashlib +import time +from pathlib import Path +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional + +if TYPE_CHECKING: + from redisvl.migration.backup import VectorBackup + +from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster +from redis.exceptions import ResponseError + +from redisvl.index import AsyncSearchIndex +from redisvl.migration.async_planner import AsyncMigrationPlanner +from redisvl.migration.async_validation import AsyncMigrationValidator +from redisvl.migration.models import ( + MigrationBenchmarkSummary, + MigrationPlan, + MigrationReport, + MigrationTimings, + MigrationValidation, +) +from redisvl.migration.reliability import is_same_width_dtype_conversion +from redisvl.migration.utils import ( + build_scan_match_patterns, + estimate_disk_space, + get_schema_field_path, + normalize_keys, + timestamp_utc, +) +from redisvl.types import AsyncRedisClient +from redisvl.utils.log import get_logger + +logger = get_logger(__name__) + + +class AsyncMigrationExecutor: + """Async migration executor for document-preserving drop/recreate flows. + + This is the async version of MigrationExecutor. It uses AsyncSearchIndex + and async Redis operations for better performance on large indexes, + especially during vector quantization. 
+ """ + + def __init__(self, validator: Optional[AsyncMigrationValidator] = None): + self.validator = validator or AsyncMigrationValidator() + + async def _detect_aof_enabled(self, client: Any) -> bool: + """Best-effort detection of whether AOF is enabled on the live Redis.""" + try: + info = await client.info("persistence") + if isinstance(info, dict) and "aof_enabled" in info: + return bool(int(info["aof_enabled"])) + except Exception: + logger.debug("Could not read Redis INFO persistence for AOF detection.") + + try: + config = await client.config_get("appendonly") + if isinstance(config, dict): + value = config.get("appendonly") + if value is not None: + return str(value).lower() in {"yes", "1", "true", "on"} + except Exception: + logger.debug("Could not read Redis CONFIG GET appendonly.") + + return False + + async def _enumerate_indexed_keys( + self, + client: AsyncRedisClient, + index_name: str, + batch_size: int = 1000, + key_separator: str = ":", + ) -> AsyncGenerator[str, None]: + """Async version: Enumerate document keys using FT.AGGREGATE with SCAN fallback. + + Uses FT.AGGREGATE WITHCURSOR for efficient enumeration when the index + has no indexing failures. Falls back to SCAN if: + - Index has hash_indexing_failures > 0 (would miss failed docs) + - FT.AGGREGATE command fails for any reason + """ + # Check for indexing failures - if any, fall back to SCAN + try: + info = await client.ft(index_name).info() + failures = int(info.get("hash_indexing_failures", 0) or 0) + if failures > 0: + logger.warning( + f"Index '{index_name}' has {failures} indexing failures. " + "Using SCAN for complete enumeration." + ) + async for key in self._enumerate_with_scan( + client, index_name, batch_size, key_separator + ): + yield key + return + except Exception as e: + logger.warning(f"Failed to check index info: {e}. 
Using SCAN fallback.") + async for key in self._enumerate_with_scan( + client, index_name, batch_size, key_separator + ): + yield key + return + + # Try FT.AGGREGATE enumeration + try: + async for key in self._enumerate_with_aggregate( + client, index_name, batch_size + ): + yield key + except ResponseError as e: + logger.warning( + f"FT.AGGREGATE failed: {e}. Falling back to SCAN enumeration." + ) + async for key in self._enumerate_with_scan( + client, index_name, batch_size, key_separator + ): + yield key + + async def _enumerate_with_aggregate( + self, + client: AsyncRedisClient, + index_name: str, + batch_size: int = 1000, + ) -> AsyncGenerator[str, None]: + """Async version: Enumerate keys using FT.AGGREGATE WITHCURSOR. + + Uses MAXIDLE to extend the server-side cursor idle timeout (default + ~5 min). If the cursor still expires, the ResponseError propagates + so the caller can fall back to SCAN. + """ + cursor_id: Optional[int] = None + + try: + # Initial aggregate call with LOAD 1 __key + result = await client.execute_command( + "FT.AGGREGATE", + index_name, + "*", + "LOAD", + "1", + "__key", + "WITHCURSOR", + "COUNT", + str(batch_size), + "MAXIDLE", + "300000", + ) + + while True: + results_data, cursor_id = result + + # Extract keys from results + for item in results_data[1:]: + if isinstance(item, (list, tuple)) and len(item) >= 2: + key = item[1] + yield key.decode() if isinstance(key, bytes) else str(key) + + if cursor_id == 0: + break + + result = await client.execute_command( + "FT.CURSOR", + "READ", + index_name, + str(cursor_id), + "COUNT", + str(batch_size), + ) + finally: + if cursor_id and cursor_id != 0: + try: + await client.execute_command( + "FT.CURSOR", "DEL", index_name, str(cursor_id) + ) + except Exception: + pass + + async def _enumerate_with_scan( + self, + client: AsyncRedisClient, + index_name: str, + batch_size: int = 1000, + key_separator: str = ":", + ) -> AsyncGenerator[str, None]: + """Async version: Enumerate keys using SCAN 
with prefix matching.""" + # Get prefix from index info + try: + info = await client.ft(index_name).info() + if isinstance(info, dict): + prefixes = info.get("index_definition", {}).get("prefixes", []) + else: + prefixes = [] + for i, item in enumerate(info): + if item == b"index_definition" or item == "index_definition": + defn = info[i + 1] + if isinstance(defn, dict): + prefixes = defn.get("prefixes", []) + elif isinstance(defn, list): + for j, d in enumerate(defn): + if d in (b"prefixes", "prefixes") and j + 1 < len(defn): + prefixes = defn[j + 1] + break + normalized_prefixes = [ + p.decode() if isinstance(p, bytes) else str(p) for p in prefixes + ] + except Exception as e: + logger.warning(f"Failed to get prefix from index info: {e}") + normalized_prefixes = [] + + seen_keys: set[str] = set() + for match_pattern in build_scan_match_patterns( + normalized_prefixes, key_separator + ): + cursor: int = 0 + while True: + cursor, keys = await client.scan( + cursor=cursor, + match=match_pattern, + count=batch_size, + ) + for key in keys: + key_str = key.decode() if isinstance(key, bytes) else str(key) + if key_str not in seen_keys: + seen_keys.add(key_str) + yield key_str + + if cursor == 0: + break + + async def _rename_keys( + self, + client: AsyncRedisClient, + keys: List[str], + old_prefix: str, + new_prefix: str, + progress_callback: Optional[Callable[[int, int], None]] = None, + ) -> int: + """Async version: Rename keys from old prefix to new prefix. + + Uses RENAMENX for standalone Redis. For Redis Cluster, falls back + to DUMP/RESTORE/DEL to avoid CROSSSLOT errors. 
+ """ + is_cluster = isinstance(client, AsyncRedisCluster) + if is_cluster: + return await self._rename_keys_cluster( + client, keys, old_prefix, new_prefix, progress_callback + ) + return await self._rename_keys_standalone( + client, keys, old_prefix, new_prefix, progress_callback + ) + + async def _rename_keys_standalone( + self, + client: AsyncRedisClient, + keys: List[str], + old_prefix: str, + new_prefix: str, + progress_callback: Optional[Callable[[int, int], None]] = None, + ) -> int: + """Rename keys using pipelined RENAMENX (standalone Redis only).""" + renamed = 0 + total = len(keys) + pipeline_size = 100 + collisions: List[str] = [] + successfully_renamed: List[tuple] = [] + + for i in range(0, total, pipeline_size): + batch = keys[i : i + pipeline_size] + pipe = client.pipeline(transaction=False) + batch_key_pairs: List[tuple] = [] + + for key in batch: + if key.startswith(old_prefix): + new_key = new_prefix + key[len(old_prefix) :] + else: + logger.warning( + f"Key '{key}' does not start with prefix '{old_prefix}'" + ) + continue + pipe.renamenx(key, new_key) + batch_key_pairs.append((key, new_key)) + + try: + results = await pipe.execute() + for j, r in enumerate(results): + if r is True or r == 1: + renamed += 1 + successfully_renamed.append(batch_key_pairs[j]) + else: + collisions.append(batch_key_pairs[j][1]) + except Exception as e: + logger.warning(f"Error in rename batch: {e}") + raise + + if collisions: + raise RuntimeError( + f"Prefix rename aborted after {renamed} successful rename(s): " + f"{len(collisions)} destination key(s) already exist " + f"(first 5: {collisions[:5]}). This would overwrite existing data. " + f"Remove conflicting keys or choose a different prefix. " + f"Note: {renamed} key(s) were already renamed from " + f"'{old_prefix}*' to '{new_prefix}*' and must be reversed " + f"manually if you want to retry." 
+ ) + + if progress_callback: + progress_callback(min(i + pipeline_size, total), total) + + return renamed + + async def _rename_keys_cluster( + self, + client: AsyncRedisClient, + keys: List[str], + old_prefix: str, + new_prefix: str, + progress_callback: Optional[Callable[[int, int], None]] = None, + ) -> int: + """Rename keys using batched DUMP/RESTORE/DEL for Redis Cluster. + + RENAME/RENAMENX raises CROSSSLOT errors when source and destination + hash to different slots. DUMP/RESTORE works across slots. + + Batches DUMP+PTTL reads and RESTORE+DEL writes in groups of + ``pipeline_size`` to reduce per-key round-trip overhead. + """ + renamed = 0 + total = len(keys) + pipeline_size = 100 + + for i in range(0, total, pipeline_size): + batch = keys[i : i + pipeline_size] + + # Build (key, new_key) pairs for this batch + pairs = [] + for key in batch: + if not key.startswith(old_prefix): + logger.warning( + "Key '%s' does not start with prefix '%s'", key, old_prefix + ) + continue + new_key = new_prefix + key[len(old_prefix) :] + pairs.append((key, new_key)) + + if not pairs: + continue + + # Phase 1: Check destination keys don't exist (batched) + check_pipe = client.pipeline(transaction=False) + for _, new_key in pairs: + check_pipe.exists(new_key) + exists_results = await check_pipe.execute() + for (_, new_key), exists in zip(pairs, exists_results): + if exists: + raise RuntimeError( + f"Prefix rename aborted after {renamed} successful rename(s): " + f"destination key '{new_key}' already exists. " + f"Remove conflicting keys or choose a different prefix." 
+ ) + + # Phase 2: DUMP + PTTL all source keys (batched — 1 RTT) + dump_pipe = client.pipeline(transaction=False) + for key, _ in pairs: + dump_pipe.dump(key) + dump_pipe.pttl(key) + dump_results = await dump_pipe.execute() + + # Phase 3: RESTORE + DEL (batched — 1 RTT) + restore_pipe = client.pipeline(transaction=False) + valid_pairs = [] + for idx, (key, new_key) in enumerate(pairs): + dumped = dump_results[idx * 2] + ttl = dump_results[idx * 2 + 1] + if dumped is None: + logger.warning("Key '%s' does not exist, skipping", key) + continue + restore_ttl = max(ttl, 0) + restore_pipe.restore(new_key, restore_ttl, dumped, replace=False) + restore_pipe.delete(key) + valid_pairs.append((key, new_key)) + + if valid_pairs: + await restore_pipe.execute() + renamed += len(valid_pairs) + + if progress_callback: + progress_callback(min(i + pipeline_size, total), total) + + if progress_callback: + progress_callback(total, total) + + return renamed + + async def _rename_field_in_hash( + self, + client: AsyncRedisClient, + keys: List[str], + old_name: str, + new_name: str, + progress_callback: Optional[Callable[[int, int], None]] = None, + ) -> int: + """Async version: Rename a field in hash documents.""" + renamed = 0 + total = len(keys) + pipeline_size = 100 + + for i in range(0, total, pipeline_size): + batch = keys[i : i + pipeline_size] + + # Get old field values AND check if destination exists + pipe = client.pipeline(transaction=False) + for key in batch: + pipe.hget(key, old_name) + pipe.hexists(key, new_name) + raw_results = await pipe.execute() + # Interleaved: [hget_0, hexists_0, hget_1, hexists_1, ...] 
+ values = raw_results[0::2] + dest_exists = raw_results[1::2] + + pipe = client.pipeline(transaction=False) + batch_ops = 0 + for key, value, exists in zip(batch, values, dest_exists): + if value is not None: + if exists: + logger.warning( + "Field '%s' already exists in key '%s'; " + "overwriting with value from '%s'", + new_name, + key, + old_name, + ) + pipe.hset(key, new_name, value) + pipe.hdel(key, old_name) + batch_ops += 1 + + try: + await pipe.execute() + # Count by number of keys that had old field values, + # not by HSET return (HSET returns 0 for existing field updates) + renamed += batch_ops + except Exception as e: + logger.warning(f"Error in field rename batch: {e}") + raise + + if progress_callback: + progress_callback(min(i + pipeline_size, total), total) + + return renamed + + async def _rename_field_in_json( + self, + client: AsyncRedisClient, + keys: List[str], + old_path: str, + new_path: str, + progress_callback: Optional[Callable[[int, int], None]] = None, + ) -> int: + """Async version: Rename a field in JSON documents.""" + renamed = 0 + total = len(keys) + pipeline_size = 100 + + for i in range(0, total, pipeline_size): + batch = keys[i : i + pipeline_size] + + pipe = client.pipeline(transaction=False) + for key in batch: + pipe.json().get(key, old_path) + values = await pipe.execute() + + # JSONPath GET returns results as a list; unwrap single-element + # results to preserve the original document shape. + # Missing paths return None or [] depending on Redis version. 
+ pipe = client.pipeline(transaction=False) + batch_ops = 0 + for key, value in zip(batch, values): + if value is None or value == []: + continue + if isinstance(value, list) and len(value) == 1: + value = value[0] + pipe.json().set(key, new_path, value) + pipe.json().delete(key, old_path) + batch_ops += 1 + try: + await pipe.execute() + # Count by number of keys that had old field values, + # not by JSON.SET return value + renamed += batch_ops + except Exception as e: + logger.warning(f"Error in JSON field rename batch: {e}") + raise + + if progress_callback: + progress_callback(min(i + pipeline_size, total), total) + + return renamed + + async def apply( + self, + plan: MigrationPlan, + *, + redis_url: Optional[str] = None, + redis_client: Optional[AsyncRedisClient] = None, + query_check_file: Optional[str] = None, + progress_callback: Optional[Callable[[str, Optional[str]], None]] = None, + backup_dir: Optional[str] = None, + batch_size: int = 500, + num_workers: int = 1, + keep_backup: bool = False, + checkpoint_path: Optional[str] = None, # deprecated, use backup_dir + ) -> MigrationReport: + """Apply a migration plan asynchronously. + + Async counterpart of :meth:`MigrationExecutor.apply`. Uses + ``await`` for Redis I/O so the event loop remains responsive during + large quantization jobs. Multi-worker quantization uses + ``asyncio.gather`` with independent connections. + + Args: + plan: The migration plan to apply (from + ``AsyncMigrationPlanner.create_plan``). + redis_url: Redis connection URL (e.g. + ``"redis://localhost:6379"``). Required when + *num_workers* > 1. + redis_client: Optional existing async Redis client. + query_check_file: Optional YAML file with post-migration queries. + progress_callback: Optional ``callback(step, detail)``. + backup_dir: Directory for vector backup files. Enables crash-safe + resume and rollback. Required when *num_workers* > 1. + Disk usage ≈ ``num_docs × dims × bytes_per_element``. 
+ batch_size: Keys per pipeline batch (default 500). Values + between 200 and 1000 are typical. + num_workers: Parallel quantization workers (default 1). For + low-dimensional vectors (≤ 256 dims) a single worker is + often fastest. Diminishing returns above 4–8 workers. + keep_backup: Retain backup files after success (default + ``False``). + """ + started_at = timestamp_utc() + started = time.perf_counter() + + report = MigrationReport( + source_index=plan.source.index_name, + target_index=plan.merged_target_schema["index"]["name"], + result="failed", + started_at=started_at, + finished_at=started_at, + warnings=list(plan.warnings), + ) + + if not plan.diff_classification.supported: + report.validation.errors.extend(plan.diff_classification.blocked_reasons) + report.manual_actions.append( + "This change requires document migration, which is not yet supported." + ) + report.finished_at = timestamp_utc() + return report + + # Handle deprecated checkpoint_path parameter + if checkpoint_path is not None: + import warnings + + warnings.warn( + "checkpoint_path is deprecated and will be removed in a future " + "version. Use backup_dir instead.", + DeprecationWarning, + stacklevel=2, + ) + if backup_dir is None: + backup_dir = checkpoint_path + + # Check if we are resuming from a backup file (post-crash). 
+ from redisvl.migration.backup import VectorBackup + + resuming_from_backup = False + existing_backup: Optional[VectorBackup] = None + backup_path: Optional[str] = None + + if backup_dir: + safe_name = ( + plan.source.index_name.replace("/", "_") + .replace("\\", "_") + .replace(":", "_") + ) + name_hash = hashlib.sha256(plan.source.index_name.encode()).hexdigest()[:8] + backup_path = str( + Path(backup_dir) / f"migration_backup_{safe_name}_{name_hash}" + ) + existing_backup = VectorBackup.load(backup_path) + + # Fallback: probe for legacy backup filename (pre-hash naming) + if existing_backup is None: + legacy_path = str(Path(backup_dir) / f"migration_backup_{safe_name}") + legacy_backup = VectorBackup.load(legacy_path) + if legacy_backup is not None: + logger.info( + "Found legacy backup at %s (pre-hash naming), using it", + legacy_path, + ) + existing_backup = legacy_backup + backup_path = legacy_path + + if existing_backup is not None: + if existing_backup.header.index_name != plan.source.index_name: + existing_backup = None + elif existing_backup.header.phase == "completed": + resuming_from_backup = True + elif existing_backup.header.phase in ("active", "ready"): + resuming_from_backup = True + elif existing_backup.header.phase == "dump": + Path(backup_path + ".header").unlink(missing_ok=True) + Path(backup_path + ".data").unlink(missing_ok=True) + existing_backup = None + + resuming = resuming_from_backup + + if not resuming: + if not await self._async_current_source_matches_snapshot( + plan.source.index_name, + plan.source.schema_snapshot, + redis_url=redis_url, + redis_client=redis_client, + ): + report.validation.errors.append( + "The current live source schema no longer matches the saved source snapshot." + ) + report.manual_actions.append( + "Re-run `rvl migrate plan` to refresh the migration plan before applying." 
+ ) + report.finished_at = timestamp_utc() + return report + + source_index = await AsyncSearchIndex.from_existing( + plan.source.index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + else: + # Source index was dropped before crash; reconstruct from snapshot + # to get a valid AsyncSearchIndex with a Redis client attached. + source_index = AsyncSearchIndex.from_dict( + plan.source.schema_snapshot, + redis_url=redis_url, + redis_client=redis_client, + ) + + target_index = AsyncSearchIndex.from_dict( + plan.merged_target_schema, + redis_url=redis_url, + redis_client=redis_client, + ) + + enumerate_duration = 0.0 + drop_duration = 0.0 + quantize_duration = 0.0 + field_rename_duration = 0.0 + key_rename_duration = 0.0 + recreate_duration = 0.0 + indexing_duration = 0.0 + target_info: Dict[str, Any] = {} + docs_quantized = 0 + keys_to_process: List[str] = [] + storage_type = plan.source.keyspace.storage_type + + datatype_changes = AsyncMigrationPlanner.get_vector_datatype_changes( + plan.source.schema_snapshot, + plan.merged_target_schema, + rename_operations=plan.rename_operations, + ) + + # Check for rename operations + rename_ops = plan.rename_operations + has_prefix_change = rename_ops.change_prefix is not None + has_field_renames = bool(rename_ops.rename_fields) + needs_quantization = bool(datatype_changes) and storage_type != "json" + needs_enumeration = needs_quantization or has_prefix_change or has_field_renames + has_same_width_quantization = any( + is_same_width_dtype_conversion(change["source"], change["target"]) + for change in datatype_changes.values() + ) + + if backup_dir and has_same_width_quantization: + report.validation.errors.append( + "Crash-safe resume is not supported for same-width datatype " + "changes (float16<->bfloat16 or int8<->uint8)." + ) + report.manual_actions.append( + "Re-run without --backup-dir for same-width vector conversions, or " + "split the migration to avoid same-width datatype changes." 
+ ) + report.finished_at = timestamp_utc() + return report + + def _notify(step: str, detail: Optional[str] = None) -> None: + if progress_callback: + progress_callback(step, detail) + + try: + client = await source_index._get_client() + if client is None: + raise ValueError("Failed to get Redis client from source index") + aof_enabled = await self._detect_aof_enabled(client) + disk_estimate = estimate_disk_space(plan, aof_enabled=aof_enabled) + if disk_estimate.has_quantization: + logger.info( + "Disk space estimate: RDB ~%d bytes, AOF ~%d bytes, total ~%d bytes", + disk_estimate.rdb_snapshot_disk_bytes, + disk_estimate.aof_growth_bytes, + disk_estimate.total_new_disk_bytes, + ) + report.disk_space_estimate = disk_estimate + + if resuming_from_backup and existing_backup is not None: + if existing_backup.header.phase == "completed": + _notify("enumerate", "skipped (resume from backup)") + _notify("drop", "skipped (already dropped)") + _notify("quantize", "skipped (already completed)") + elif existing_backup.header.phase in ("active", "ready"): + _notify("enumerate", "skipped (resume from backup)") + _notify("drop", "skipped (already dropped)") + effective_changes = datatype_changes + if has_field_renames: + field_rename_map = { + fr.old_name: fr.new_name for fr in rename_ops.rename_fields + } + effective_changes = { + field_rename_map.get(k, k): v + for k, v in datatype_changes.items() + } + _notify("quantize", "Resuming vector re-encoding from backup...") + quantize_started = time.perf_counter() + docs_quantized = await self._quantize_from_backup( + client=client, + backup=existing_backup, + datatype_changes=effective_changes, + progress_callback=lambda done, total: _notify( + "quantize", f"{done:,}/{total:,} docs" + ), + ) + quantize_duration = round(time.perf_counter() - quantize_started, 3) + _notify( + "quantize", + f"done ({docs_quantized:,} docs in {quantize_duration}s)", + ) + + # Key prefix renames may not have happened before the crash + # (they run after 
index drop in the normal path). Re-apply + # idempotently. + if has_prefix_change: + resume_keys = [] + for batch_keys, _ in existing_backup.iter_batches(): + resume_keys.extend(batch_keys) + if resume_keys: + old_prefix = plan.source.keyspace.prefixes[0] + new_prefix = rename_ops.change_prefix + assert new_prefix is not None + _notify("key_rename", "Renaming keys (resume)...") + key_rename_started = time.perf_counter() + renamed_count = await self._rename_keys( + client, + resume_keys, + old_prefix, + new_prefix, + progress_callback=lambda done, total: _notify( + "key_rename", f"{done:,}/{total:,} keys" + ), + ) + key_rename_duration = round( + time.perf_counter() - key_rename_started, 3 + ) + _notify( + "key_rename", + f"done ({renamed_count:,} keys in {key_rename_duration}s)", + ) + else: + # Normal (non-resume) path + if needs_enumeration: + _notify("enumerate", "Enumerating indexed documents...") + enumerate_started = time.perf_counter() + keys_to_process = [ + key + async for key in self._enumerate_indexed_keys( + client, + plan.source.index_name, + batch_size=1000, + key_separator=plan.source.keyspace.key_separator, + ) + ] + keys_to_process = normalize_keys(keys_to_process) + enumerate_duration = round( + time.perf_counter() - enumerate_started, 3 + ) + _notify( + "enumerate", + f"found {len(keys_to_process):,} documents ({enumerate_duration}s)", + ) + + # Field renames + if has_field_renames and keys_to_process: + _notify("field_rename", "Renaming fields in documents...") + field_rename_started = time.perf_counter() + for field_rename in rename_ops.rename_fields: + if storage_type == "json": + old_path = get_schema_field_path( + plan.source.schema_snapshot, field_rename.old_name + ) + new_path = get_schema_field_path( + plan.merged_target_schema, field_rename.new_name + ) + if not old_path or not new_path or old_path == new_path: + continue + await self._rename_field_in_json( + client, + keys_to_process, + old_path, + new_path, + progress_callback=lambda 
done, total: _notify( + "field_rename", + f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", + ), + ) + else: + await self._rename_field_in_hash( + client, + keys_to_process, + field_rename.old_name, + field_rename.new_name, + progress_callback=lambda done, total: _notify( + "field_rename", + f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", + ), + ) + field_rename_duration = round( + time.perf_counter() - field_rename_started, 3 + ) + _notify("field_rename", f"done ({field_rename_duration}s)") + + # Dump original vectors to backup file (before drop) + active_backup = None + use_multi_worker = num_workers > 1 and backup_dir is not None + if ( + needs_quantization + and keys_to_process + and backup_path + and not use_multi_worker + ): + effective_changes = datatype_changes + if has_field_renames: + field_rename_map = { + fr.old_name: fr.new_name for fr in rename_ops.rename_fields + } + effective_changes = { + field_rename_map.get(k, k): v + for k, v in datatype_changes.items() + } + _notify("dump", "Backing up original vectors...") + dump_started = time.perf_counter() + active_backup = await self._dump_vectors( + client=client, + index_name=plan.source.index_name, + keys=keys_to_process, + datatype_changes=effective_changes, + backup_path=backup_path, + batch_size=batch_size, + progress_callback=lambda done, total: _notify( + "dump", f"{done:,}/{total:,} docs" + ), + ) + dump_duration = round(time.perf_counter() - dump_started, 3) + _notify("dump", f"done ({dump_duration}s)") + + # Drop the index + _notify("drop", "Dropping index definition...") + drop_started = time.perf_counter() + await source_index.delete(drop=False) + drop_duration = round(time.perf_counter() - drop_started, 3) + _notify("drop", f"done ({drop_duration}s)") + + # Key renames + if has_prefix_change and keys_to_process: + _notify("key_rename", "Renaming keys...") + key_rename_started = time.perf_counter() + old_prefix = 
plan.source.keyspace.prefixes[0] + new_prefix = rename_ops.change_prefix + assert new_prefix is not None + renamed_count = await self._rename_keys( + client, + keys_to_process, + old_prefix, + new_prefix, + progress_callback=lambda done, total: _notify( + "key_rename", f"{done:,}/{total:,} keys" + ), + ) + key_rename_duration = round( + time.perf_counter() - key_rename_started, 3 + ) + _notify( + "key_rename", + f"done ({renamed_count:,} keys in {key_rename_duration}s)", + ) + + # Quantize vectors + if needs_quantization and keys_to_process: + effective_changes = datatype_changes + if has_field_renames: + field_rename_map = { + fr.old_name: fr.new_name for fr in rename_ops.rename_fields + } + effective_changes = { + field_rename_map.get(k, k): v + for k, v in datatype_changes.items() + } + + # Update key references if prefix changed + if has_prefix_change and rename_ops.change_prefix: + old_prefix = plan.source.keyspace.prefixes[0] + new_prefix = rename_ops.change_prefix + keys_to_process = [ + ( + new_prefix + k[len(old_prefix) :] + if k.startswith(old_prefix) + else k + ) + for k in keys_to_process + ] + + if use_multi_worker: + from redisvl.migration.quantize import ( + async_multi_worker_quantize, + ) + + if backup_dir is None: + raise ValueError( + "--backup-dir is required when using --workers > 1" + ) + if redis_url is None: + raise ValueError( + "redis_url is required when using num_workers > 1" + ) + _notify( + "quantize", + f"Re-encoding vectors ({num_workers} workers)...", + ) + quantize_started = time.perf_counter() + mw_result = await async_multi_worker_quantize( + redis_url=redis_url, + keys=keys_to_process, + datatype_changes=effective_changes, + backup_dir=backup_dir, + index_name=plan.source.index_name, + num_workers=num_workers, + batch_size=batch_size, + ) + docs_quantized = mw_result.total_docs_quantized + elif active_backup: + _notify("quantize", "Re-encoding vectors from backup...") + quantize_started = time.perf_counter() + docs_quantized = 
await self._quantize_from_backup( + client=client, + backup=active_backup, + datatype_changes=effective_changes, + progress_callback=lambda done, total: _notify( + "quantize", f"{done:,}/{total:,} docs" + ), + ) + else: + # No backup dir — direct pipeline read + convert + write + from redisvl.migration.quantize import ( + convert_vectors, + pipeline_write_vectors, + ) + + _notify("quantize", "Re-encoding vectors...") + quantize_started = time.perf_counter() + docs_quantized = 0 + total = len(keys_to_process) + field_names = list(effective_changes.keys()) + for batch_start in range(0, total, batch_size): + batch_keys = keys_to_process[ + batch_start : batch_start + batch_size + ] + # Async pipelined read + pipe = client.pipeline(transaction=False) + call_order: list[tuple] = [] + for key in batch_keys: + for fn in field_names: + pipe.hget(key, fn) + call_order.append((key, fn)) + results = await pipe.execute() + originals: dict[str, dict[str, bytes]] = {} + for (key, fn), value in zip(call_order, results): + if value is not None: + originals.setdefault(key, {})[fn] = value + converted = convert_vectors(originals, effective_changes) + if converted: + wpipe = client.pipeline(transaction=False) + for key, fields in converted.items(): + for fn, data in fields.items(): + wpipe.hset(key, fn, data) + await wpipe.execute() + docs_quantized += len(converted) if converted else 0 + if progress_callback: + _notify( + "quantize", + f"{docs_quantized:,}/{total:,} docs", + ) + quantize_duration = round(time.perf_counter() - quantize_started, 3) + _notify( + "quantize", + f"done ({docs_quantized:,} docs in {quantize_duration}s)", + ) + report.warnings.append( + f"Re-encoded {docs_quantized} documents for vector quantization: " + f"{datatype_changes}" + ) + elif datatype_changes and storage_type == "json": + _notify( + "quantize", "skipped (JSON vectors are re-indexed on recreate)" + ) + + _notify("create", "Creating index with new schema...") + recreate_started = 
time.perf_counter() + await target_index.create() + recreate_duration = round(time.perf_counter() - recreate_started, 3) + _notify("create", f"done ({recreate_duration}s)") + + _notify("index", "Waiting for re-indexing...") + + def _index_progress(indexed: int, total: int, pct: float) -> None: + _notify("index", f"{indexed:,}/{total:,} docs ({pct:.0f}%)") + + target_info, indexing_duration = await self._async_wait_for_index_ready( + target_index, progress_callback=_index_progress + ) + _notify("index", f"done ({indexing_duration}s)") + + _notify("validate", "Validating migration...") + validation, target_info, validation_duration = ( + await self.validator.validate( + plan, + redis_url=redis_url, + redis_client=redis_client, + query_check_file=query_check_file, + ) + ) + _notify("validate", f"done ({validation_duration}s)") + report.validation = validation + total_duration = round(time.perf_counter() - started, 3) + report.timings = MigrationTimings( + total_migration_duration_seconds=total_duration, + drop_duration_seconds=drop_duration, + quantize_duration_seconds=( + quantize_duration if quantize_duration else None + ), + field_rename_duration_seconds=( + field_rename_duration if field_rename_duration else None + ), + key_rename_duration_seconds=( + key_rename_duration if key_rename_duration else None + ), + recreate_duration_seconds=recreate_duration, + initial_indexing_duration_seconds=indexing_duration, + validation_duration_seconds=validation_duration, + downtime_duration_seconds=round( + drop_duration + + field_rename_duration + + key_rename_duration + + quantize_duration + + recreate_duration + + indexing_duration, + 3, + ), + ) + report.benchmark_summary = self._build_benchmark_summary( + plan, + target_info, + report.timings, + ) + report.result = "succeeded" if not validation.errors else "failed" + if validation.errors: + report.manual_actions.append( + "Review validation errors before treating the migration as complete." 
+ ) + except Exception as exc: + total_duration = round(time.perf_counter() - started, 3) + report.timings = MigrationTimings( + total_migration_duration_seconds=total_duration, + drop_duration_seconds=drop_duration or None, + quantize_duration_seconds=quantize_duration or None, + field_rename_duration_seconds=field_rename_duration or None, + key_rename_duration_seconds=key_rename_duration or None, + recreate_duration_seconds=recreate_duration or None, + initial_indexing_duration_seconds=indexing_duration or None, + downtime_duration_seconds=( + round( + drop_duration + + field_rename_duration + + key_rename_duration + + quantize_duration + + recreate_duration + + indexing_duration, + 3, + ) + if drop_duration + or field_rename_duration + or key_rename_duration + or quantize_duration + or recreate_duration + or indexing_duration + else None + ), + ) + report.validation = MigrationValidation( + errors=[f"Migration execution failed: {exc}"] + ) + report.manual_actions.extend( + [ + "Inspect the Redis index state before retrying.", + "If the source index was dropped, recreate it from the saved migration plan.", + ] + ) + finally: + report.finished_at = timestamp_utc() + + # Auto-cleanup backup files on success + if backup_dir and not keep_backup and report.result == "succeeded": + self._cleanup_backup_files(backup_dir, plan.source.index_name) + + return report + + def _cleanup_backup_files(self, backup_dir: str, index_name: str) -> None: + """Remove backup files after successful migration. + + Only removes files with the exact extensions produced by VectorBackup + (.header and .data), avoiding accidental deletion of unrelated files + that happen to share the same prefix. 
+ """ + safe_name = index_name.replace("/", "_").replace("\\", "_").replace(":", "_") + name_hash = hashlib.sha256(index_name.encode()).hexdigest()[:8] + base_prefix = f"migration_backup_{safe_name}_{name_hash}" + known_suffixes = (".header", ".data") + backup_dir_path = Path(backup_dir) + + for entry in backup_dir_path.iterdir(): + if not entry.is_file(): + continue + name = entry.name + if not name.startswith(base_prefix): + continue + if not any(name.endswith(s) for s in known_suffixes): + continue + remainder = name[len(base_prefix) :] + if remainder and remainder[0] not in (".", "_"): + continue + try: + entry.unlink() + logger.debug("Removed backup file: %s", entry) + except OSError as e: + logger.warning("Failed to remove backup file %s: %s", entry, e) + + # ------------------------------------------------------------------ + # Two-phase quantization: dump originals → convert from backup + # ------------------------------------------------------------------ + + async def _dump_vectors( + self, + client: Any, + index_name: str, + keys: List[str], + datatype_changes: Dict[str, Dict[str, Any]], + backup_path: str, + batch_size: int = 500, + progress_callback: Optional[Callable[[int, int], None]] = None, + ) -> "VectorBackup": + """Phase 1: Pipeline-read original vectors and write to backup file. + + Async version. Runs BEFORE index drop. 
+ """ + from redisvl.migration.backup import VectorBackup + + backup = VectorBackup.create( + path=backup_path, + index_name=index_name, + fields=datatype_changes, + batch_size=batch_size, + ) + + total = len(keys) + field_names = list(datatype_changes.keys()) + + for batch_start in range(0, total, batch_size): + batch_keys = keys[batch_start : batch_start + batch_size] + + # Pipelined async reads + pipe = client.pipeline(transaction=False) + call_order: List[tuple] = [] + for key in batch_keys: + for field_name in field_names: + pipe.hget(key, field_name) + call_order.append((key, field_name)) + results = await pipe.execute() + + # Reassemble + originals: Dict[str, Dict[str, bytes]] = {} + for (key, field_name), value in zip(call_order, results): + if value is not None: + if key not in originals: + originals[key] = {} + originals[key][field_name] = value + + backup.write_batch(batch_start // batch_size, batch_keys, originals) + if progress_callback: + progress_callback(min(batch_start + batch_size, total), total) + + backup.mark_dump_complete() + return backup + + async def _quantize_from_backup( + self, + client: Any, + backup: "VectorBackup", + datatype_changes: Dict[str, Dict[str, Any]], + progress_callback: Optional[Callable[[int, int], None]] = None, + ) -> int: + """Phase 2: Read originals from backup file, convert, pipeline-write. + + Async version. Runs AFTER index drop. 
+ """ + from redisvl.migration.quantize import convert_vectors + + if backup.header.phase == "ready": + backup.start_quantize() + + docs_quantized = 0 + start_batch = backup.header.quantize_completed_batches + docs_done = start_batch * backup.header.batch_size + + for batch_idx, (batch_keys, originals) in enumerate( + backup.iter_remaining_batches() + ): + actual_batch_idx = start_batch + batch_idx + converted = convert_vectors(originals, datatype_changes) + if converted: + pipe = client.pipeline(transaction=False) + for key, fields in converted.items(): + for field_name, data in fields.items(): + pipe.hset(key, field_name, data) + await pipe.execute() + backup.mark_batch_quantized(actual_batch_idx) + docs_quantized += len(batch_keys) + docs_done += len(batch_keys) + if progress_callback: + total = backup.header.dump_completed_batches * backup.header.batch_size + progress_callback(docs_done, total) + + backup.mark_complete() + return docs_quantized + + async def _async_wait_for_index_ready( + self, + index: AsyncSearchIndex, + *, + timeout_seconds: int = 1800, + poll_interval_seconds: float = 0.5, + progress_callback: Optional[Callable[[int, int, float], None]] = None, + ) -> tuple[Dict[str, Any], float]: + """Wait for index to finish indexing all documents (async version).""" + start = time.perf_counter() + deadline = start + timeout_seconds + latest_info = await index.info() + + stable_ready_checks: Optional[int] = None + while time.perf_counter() < deadline: + ready = False + latest_info = await index.info() + indexing = latest_info.get("indexing") + percent_indexed = latest_info.get("percent_indexed") + + if percent_indexed is not None or indexing is not None: + pct = float(percent_indexed) if percent_indexed is not None else None + is_indexing = bool(indexing) + if pct is not None: + ready = pct >= 1.0 and not is_indexing + else: + # percent_indexed missing but indexing flag present: + # treat as ready when indexing flag is falsy (0 / False). 
+ ready = not is_indexing + if progress_callback: + total_docs = int(latest_info.get("num_docs", 0)) + display_pct = pct if pct is not None else (1.0 if ready else 0.0) + indexed_docs = int(total_docs * display_pct) + progress_callback(indexed_docs, total_docs, display_pct * 100) + else: + current_docs = latest_info.get("num_docs") + if current_docs is None: + ready = True + else: + if stable_ready_checks is None: + stable_ready_checks = int(current_docs) + await asyncio.sleep(poll_interval_seconds) + continue + current = int(current_docs) + if current == stable_ready_checks: + ready = True + else: + # num_docs changed; update baseline and keep waiting + stable_ready_checks = current + + if ready: + return latest_info, round(time.perf_counter() - start, 3) + + await asyncio.sleep(poll_interval_seconds) + + raise TimeoutError( + f"Index {index.schema.index.name} did not become ready within {timeout_seconds} seconds" + ) + + async def _async_current_source_matches_snapshot( + self, + index_name: str, + expected_schema: Dict[str, Any], + *, + redis_url: Optional[str] = None, + redis_client: Optional[AsyncRedisClient] = None, + ) -> bool: + """Check if current source schema matches the snapshot (async version).""" + from redisvl.migration.utils import schemas_equal + + try: + current_index = await AsyncSearchIndex.from_existing( + index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + except Exception: + # Index no longer exists (e.g. 
already dropped during migration) + return False + return schemas_equal(current_index.schema.to_dict(), expected_schema) + + def _build_benchmark_summary( + self, + plan: MigrationPlan, + target_info: dict, + timings: MigrationTimings, + ) -> MigrationBenchmarkSummary: + source_index_size = float( + plan.source.stats_snapshot.get("vector_index_sz_mb", 0) or 0 + ) + target_index_size = float(target_info.get("vector_index_sz_mb", 0) or 0) + source_num_docs = int(plan.source.stats_snapshot.get("num_docs", 0) or 0) + indexed_per_second = None + indexing_time = timings.initial_indexing_duration_seconds + if indexing_time and indexing_time > 0: + indexed_per_second = round(source_num_docs / indexing_time, 3) + + return MigrationBenchmarkSummary( + documents_indexed_per_second=indexed_per_second, + source_index_size_mb=round(source_index_size, 3), + target_index_size_mb=round(target_index_size, 3), + index_size_delta_mb=round(target_index_size - source_index_size, 3), + ) diff --git a/redisvl/migration/async_planner.py b/redisvl/migration/async_planner.py new file mode 100644 index 000000000..6c75efda2 --- /dev/null +++ b/redisvl/migration/async_planner.py @@ -0,0 +1,294 @@ +from __future__ import annotations + +from typing import Any, List, Optional + +from redisvl.index import AsyncSearchIndex +from redisvl.migration.models import ( + KeyspaceSnapshot, + MigrationPlan, + SchemaPatch, + SourceSnapshot, +) +from redisvl.migration.planner import MigrationPlanner +from redisvl.redis.connection import supports_svs_async +from redisvl.schema.schema import IndexSchema +from redisvl.types import AsyncRedisClient + + +class AsyncMigrationPlanner: + """Async migration planner for document-preserving drop/recreate flows. + + This is the async version of MigrationPlanner. It uses AsyncSearchIndex + and async Redis operations for better performance on large indexes. 
    async def create_plan(
        self,
        index_name: str,
        *,
        redis_url: Optional[str] = None,
        schema_patch_path: Optional[str] = None,
        target_schema_path: Optional[str] = None,
        redis_client: Optional[AsyncRedisClient] = None,
    ) -> MigrationPlan:
        """Build a migration plan from a schema patch or a full target schema.

        Exactly one of ``schema_patch_path`` / ``target_schema_path`` must be
        provided; a full target schema is normalized into an equivalent patch
        against the current source schema.

        Raises:
            ValueError: If neither or both of the two schema inputs are given.
        """
        if not schema_patch_path and not target_schema_path:
            raise ValueError(
                "Must provide either --schema-patch or --target-schema for migration planning"
            )
        if schema_patch_path and target_schema_path:
            raise ValueError(
                "Provide only one of --schema-patch or --target-schema for migration planning"
            )

        # Snapshot once here and forward it so create_plan_from_patch does
        # not re-read the index.
        snapshot = await self.snapshot_source(
            index_name,
            redis_url=redis_url,
            redis_client=redis_client,
        )
        source_schema = IndexSchema.from_dict(snapshot.schema_snapshot)

        if schema_patch_path:
            schema_patch = self._sync_planner.load_schema_patch(schema_patch_path)
        else:
            # target_schema_path is guaranteed to be not None here
            assert target_schema_path is not None
            schema_patch = self._sync_planner.normalize_target_schema_to_patch(
                source_schema, target_schema_path
            )

        return await self.create_plan_from_patch(
            index_name,
            schema_patch=schema_patch,
            redis_url=redis_url,
            redis_client=redis_client,
            _snapshot=snapshot,
        )

    async def create_plan_from_patch(
        self,
        index_name: str,
        *,
        schema_patch: SchemaPatch,
        redis_url: Optional[str] = None,
        redis_client: Optional[AsyncRedisClient] = None,
        _snapshot: Optional[Any] = None,
    ) -> MigrationPlan:
        """Assemble the full MigrationPlan for a given schema patch.

        CPU-bound steps (merging, rename extraction, diff classification) are
        delegated to the sync planner; only snapshotting and the SVS support
        probe are async.

        Args:
            index_name: Source index to migrate.
            schema_patch: Requested schema changes.
            _snapshot: Pre-taken source snapshot (internal fast path used by
                ``create_plan``); taken fresh when None.
        """
        if _snapshot is None:
            _snapshot = await self.snapshot_source(
                index_name,
                redis_url=redis_url,
                redis_client=redis_client,
            )
        snapshot = _snapshot
        source_schema = IndexSchema.from_dict(snapshot.schema_snapshot)
        merged_target_schema = self._sync_planner.merge_patch(
            source_schema, schema_patch
        )

        # Extract rename operations first
        rename_operations, rename_warnings = (
            self._sync_planner._extract_rename_operations(source_schema, schema_patch)
        )

        # Classify diff with awareness of rename operations
        diff_classification = self._sync_planner.classify_diff(
            source_schema, schema_patch, merged_target_schema, rename_operations
        )

        # Build warnings list
        warnings = ["Index downtime is required"]
        warnings.extend(rename_warnings)

        # Warn if source index has hash indexing failures
        source_failures = int(
            snapshot.stats_snapshot.get("hash_indexing_failures", 0) or 0
        )
        if source_failures > 0:
            warnings.append(
                f"Source index has {source_failures:,} hash indexing failure(s). "
                "Documents that previously failed to index may become indexable after "
                "migration, causing the post-migration document count to differ from "
                "the pre-migration count. This is expected and validation accounts for it."
            )

        # Check for SVS-VAMANA in target schema and add appropriate warnings
        svs_warnings = await self._check_svs_vamana_requirements(
            merged_target_schema,
            redis_url=redis_url,
            redis_client=redis_client,
        )
        warnings.extend(svs_warnings)

        return MigrationPlan(
            source=snapshot,
            requested_changes=schema_patch.model_dump(exclude_none=True),
            merged_target_schema=merged_target_schema.to_dict(),
            diff_classification=diff_classification,
            rename_operations=rename_operations,
            warnings=warnings,
        )
    async def _check_svs_vamana_requirements(
        self,
        target_schema: IndexSchema,
        *,
        redis_url: Optional[str] = None,
        redis_client: Optional[AsyncRedisClient] = None,
    ) -> List[str]:
        """Async version: Check SVS-VAMANA requirements and return warnings.

        Returns an empty list when no vector field uses SVS-VAMANA. Otherwise
        probes the server for support (best-effort) and always appends a
        hardware note about Intel AVX-512 / compression behaviour.
        """
        warnings: List[str] = []
        target_dict = target_schema.to_dict()

        # Check if any vector field uses SVS-VAMANA
        uses_svs = False
        uses_compression = False
        compression_types: set = set()

        for field in target_dict.get("fields", []):
            if field.get("type") != "vector":
                continue
            attrs = field.get("attrs", {})
            algo = attrs.get("algorithm", "").upper()
            if algo == "SVS-VAMANA":
                uses_svs = True
                compression = attrs.get("compression", "")
                if compression:
                    uses_compression = True
                    compression_types.add(compression)

        if not uses_svs:
            return warnings

        # Check Redis version support. A client created here (from redis_url)
        # is owned by this method and closed in the finally block; a caller
        # supplied redis_client is left open.
        created_client = False
        try:
            if redis_client:
                client = redis_client
            elif redis_url:
                from redis.asyncio import Redis

                client = Redis.from_url(redis_url)
                created_client = True
            else:
                client = None

            if client and not await supports_svs_async(client):
                warnings.append(
                    "SVS-VAMANA requires Redis >= 8.2.0 and Redis Search >= 2.8.10. "
                    "The target Redis instance may not support this algorithm. "
                    "Migration will fail at apply time if requirements are not met."
                )
        except Exception:
            # Probe failed (connection error etc.) — fall back to a generic
            # warning rather than failing plan creation.
            warnings.append(
                "SVS-VAMANA requires Redis >= 8.2.0 and Redis Search >= 2.8.10. "
                "Verify your Redis instance supports this algorithm before applying."
            )
        finally:
            # Short-circuit keeps this safe even if `client` was never bound
            # (created_client is only True after a successful from_url).
            if created_client and client is not None:
                await client.aclose()  # type: ignore[union-attr]

        # Intel hardware warning for compression
        if uses_compression:
            compression_label = ", ".join(sorted(compression_types))
            warnings.append(
                f"SVS-VAMANA with {compression_label} compression: "
                "LVQ and LeanVec optimizations require Intel hardware with AVX-512 support. "
                "On non-Intel platforms or Redis Open Source, these fall back to basic "
                "8-bit scalar quantization with reduced performance benefits."
            )
        else:
            warnings.append(
                "SVS-VAMANA: For optimal performance, Intel hardware with AVX-512 support "
                "is recommended. LVQ/LeanVec compression options provide additional memory "
                "savings on supported hardware."
            )

        return warnings

    async def snapshot_source(
        self,
        index_name: str,
        *,
        redis_url: Optional[str] = None,
        redis_client: Optional[AsyncRedisClient] = None,
    ) -> SourceSnapshot:
        """Capture the source index: schema, FT.INFO stats, and a key sample.

        Raises:
            ValueError: If the loaded index exposes no Redis client.
        """
        index = await AsyncSearchIndex.from_existing(
            index_name,
            redis_url=redis_url,
            redis_client=redis_client,
        )
        schema_dict = index.schema.to_dict()
        stats_snapshot = await index.info()
        # Normalize prefix to a list — the schema may store one or many.
        prefixes = index.schema.index.prefix
        prefix_list = prefixes if isinstance(prefixes, list) else [prefixes]

        client = index.client
        if client is None:
            raise ValueError("Failed to get Redis client from index")

        return SourceSnapshot(
            index_name=index_name,
            schema_snapshot=schema_dict,
            stats_snapshot=stats_snapshot,
            keyspace=KeyspaceSnapshot(
                storage_type=index.schema.index.storage_type.value,
                prefixes=prefix_list,
                key_separator=index.schema.index.key_separator,
                key_sample=await self._async_sample_keys(
                    client=client,
                    prefixes=prefix_list,
                    key_separator=index.schema.index.key_separator,
                ),
            ),
        )
version of _sample_keys.""" + key_sample: List[str] = [] + if self.key_sample_limit <= 0: + return key_sample + + for prefix in prefixes: + if len(key_sample) >= self.key_sample_limit: + break + if prefix == "": + match_pattern = "*" + elif prefix.endswith(key_separator): + match_pattern = f"{prefix}*" + else: + match_pattern = f"{prefix}{key_separator}*" + cursor: int = 0 + while True: + cursor, keys = await client.scan( + cursor=cursor, + match=match_pattern, + count=max(self.key_sample_limit, 10), + ) + for key in keys: + decoded_key = key.decode() if isinstance(key, bytes) else str(key) + if decoded_key not in key_sample: + key_sample.append(decoded_key) + if len(key_sample) >= self.key_sample_limit: + return key_sample + if cursor == 0: + break + return key_sample + + def write_plan(self, plan: MigrationPlan, plan_out: str) -> None: + """Delegate to sync planner for file I/O.""" + self._sync_planner.write_plan(plan, plan_out) diff --git a/redisvl/migration/async_validation.py b/redisvl/migration/async_validation.py new file mode 100644 index 000000000..18331e4f8 --- /dev/null +++ b/redisvl/migration/async_validation.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import time +from typing import Any, Dict, List, Optional + +from redis.commands.search.query import Query + +from redisvl.index import AsyncSearchIndex +from redisvl.migration.models import ( + MigrationPlan, + MigrationValidation, + QueryCheckResult, +) +from redisvl.migration.utils import load_yaml, schemas_equal +from redisvl.types import AsyncRedisClient + + +class AsyncMigrationValidator: + """Async migration validator for post-migration checks. + + This is the async version of MigrationValidator. It uses AsyncSearchIndex + and async Redis operations for better performance. 
+ """ + + async def validate( + self, + plan: MigrationPlan, + *, + redis_url: Optional[str] = None, + redis_client: Optional[AsyncRedisClient] = None, + query_check_file: Optional[str] = None, + ) -> tuple[MigrationValidation, Dict[str, Any], float]: + started = time.perf_counter() + target_index = await AsyncSearchIndex.from_existing( + plan.merged_target_schema["index"]["name"], + redis_url=redis_url, + redis_client=redis_client, + ) + target_info = await target_index.info() + validation = MigrationValidation() + + live_schema = target_index.schema.to_dict() + # Exclude query-time and creation-hint attributes (ef_runtime, epsilon, + # initial_cap, phonetic_matcher) that are not part of index structure + # validation. Confirmed by RediSearch team as not relevant for this check. + validation.schema_match = schemas_equal( + live_schema, plan.merged_target_schema, strip_excluded=True + ) + + source_num_docs = int(plan.source.stats_snapshot.get("num_docs", 0) or 0) + target_num_docs = int(target_info.get("num_docs", 0) or 0) + + source_failures = int( + plan.source.stats_snapshot.get("hash_indexing_failures", 0) or 0 + ) + target_failures = int(target_info.get("hash_indexing_failures", 0) or 0) + validation.indexing_failures_delta = target_failures - source_failures + + # Compare total keys (num_docs + hash_indexing_failures) instead of + # just num_docs. Migrations can resolve indexing failures (e.g. a + # vector datatype change may fix documents that previously failed to + # index), shifting counts between the two buckets while the total + # number of keys under the prefix stays the same. 
+ source_total = source_num_docs + source_failures + target_total = target_num_docs + target_failures + validation.doc_count_match = source_total == target_total + + key_sample = plan.source.keyspace.key_sample + client = target_index.client + if not key_sample: + validation.key_sample_exists = True + elif client is None: + validation.key_sample_exists = False + validation.errors.append("Failed to get Redis client for key sample check") + else: + # Handle prefix change: transform key_sample to use new prefix. + # Must match the executor's RENAME logic exactly: + # new_key = new_prefix + key[len(old_prefix):] + keys_to_check = key_sample + if plan.rename_operations.change_prefix is not None: + old_prefix = plan.source.keyspace.prefixes[0] + new_prefix = plan.rename_operations.change_prefix + keys_to_check = [] + for k in key_sample: + if k.startswith(old_prefix): + keys_to_check.append(new_prefix + k[len(old_prefix) :]) + else: + keys_to_check.append(k) + existing_count = await client.exists(*keys_to_check) + validation.key_sample_exists = existing_count == len(keys_to_check) + + # Run automatic functional checks (always). + # Use source_total (num_docs + failures) as the expected count so that + # resolved indexing failures don't cause the wildcard check to fail. 
+ functional_checks = await self._run_functional_checks( + target_index, source_total + ) + validation.query_checks.extend(functional_checks) + + # Run user-provided query checks (if file provided) + if query_check_file: + user_checks = await self._run_query_checks(target_index, query_check_file) + validation.query_checks.extend(user_checks) + + if not validation.schema_match and plan.validation.require_schema_match: + validation.errors.append("Live schema does not match merged_target_schema.") + if not validation.doc_count_match and plan.validation.require_doc_count_match: + validation.errors.append( + f"Total key count mismatch: source had {source_total} " + f"(num_docs={source_num_docs}, failures={source_failures}), " + f"target has {target_total} " + f"(num_docs={target_num_docs}, failures={target_failures})." + ) + if validation.indexing_failures_delta > 0: + validation.errors.append("Indexing failures increased during migration.") + if not validation.key_sample_exists: + validation.errors.append( + "One or more sampled source keys is missing after migration." 
+ ) + if any(not query_check.passed for query_check in validation.query_checks): + validation.errors.append("One or more query checks failed.") + + return validation, target_info, round(time.perf_counter() - started, 3) + + async def _run_query_checks( + self, + target_index: AsyncSearchIndex, + query_check_file: str, + ) -> list[QueryCheckResult]: + query_checks = load_yaml(query_check_file) + results: list[QueryCheckResult] = [] + + for doc_id in query_checks.get("fetch_ids", []): + fetched = await target_index.fetch(doc_id) + results.append( + QueryCheckResult( + name=f"fetch:{doc_id}", + passed=fetched is not None, + details=( + "Document fetched successfully" + if fetched is not None + else "Document not found" + ), + ) + ) + + client = target_index.client + for key in query_checks.get("keys_exist", []): + if client is None: + results.append( + QueryCheckResult( + name=f"key:{key}", + passed=False, + details="Failed to get Redis client", + ) + ) + else: + exists = bool(await client.exists(key)) + results.append( + QueryCheckResult( + name=f"key:{key}", + passed=exists, + details="Key exists" if exists else "Key not found", + ) + ) + + return results + + async def _run_functional_checks( + self, target_index: AsyncSearchIndex, expected_doc_count: int + ) -> List[QueryCheckResult]: + """Run automatic functional checks to verify the index is operational. + + These checks run automatically after every migration to prove the index + actually works, not just that the schema looks correct. + """ + results: List[QueryCheckResult] = [] + + # Check 1: Wildcard search - proves the index responds and returns docs + try: + search_result = await target_index.search(Query("*").paging(0, 1)) + total_found = search_result.total + # When expected_doc_count is 0 (empty index), a successful + # search returning 0 docs is correct behaviour, not a failure. 
+        if expected_doc_count == 0:
+            # An empty source index should yield an empty wildcard result.
+            passed = total_found == 0
+            detail_expectation = "expected 0"
+        else:
+            passed = total_found > 0
+            detail_expectation = f"expected >0, source had {expected_doc_count}"
+        results.append(
+            QueryCheckResult(
+                name="functional:wildcard_search",
+                passed=passed,
+                details=(
+                    f"Wildcard search returned {total_found} docs "
+                    f"({detail_expectation})"
+                ),
+            )
+        )
+    except Exception as e:
+        results.append(
+            QueryCheckResult(
+                name="functional:wildcard_search",
+                passed=False,
+                details=f"Wildcard search failed: {str(e)}",
+            )
+        )
+
+    return results
diff --git a/redisvl/migration/backup.py b/redisvl/migration/backup.py
new file mode 100644
index 000000000..aa3b316a9
--- /dev/null
+++ b/redisvl/migration/backup.py
+"""Vector backup file for crash-safe quantization.
+
+Stores original vector bytes on disk so that:
+- Quantization can resume from where it left off after a crash
+- Original vectors can be restored (rollback) at any time
+- No BGSAVE or Redis-side checkpointing is needed
+
+File layout:
+    .header — JSON file with phase, progress counters, metadata
+    .data — Binary file with length-prefixed pickle blobs per batch
+"""
+
+import json
+import os
+import pickle
+import struct
+import tempfile
+# NOTE(review): `field` and `Path` appear unused in this module — consider dropping.
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, Generator, List, Optional, Tuple
+
+
+@dataclass
+class BackupHeader:
+    """Metadata and progress tracking for a vector backup."""
+
+    index_name: str
+    fields: Dict[str, Dict[str, Any]]
+    batch_size: int
+    phase: str = "dump"  # dump → ready → active → completed
+    dump_completed_batches: int = 0
+    quantize_completed_batches: int = 0
+
+    def to_dict(self) -> dict:
+        """Serialize the header for JSON persistence."""
+        return {
+            "index_name": self.index_name,
+            "fields": self.fields,
+            "batch_size": self.batch_size,
+            "phase": self.phase,
+            "dump_completed_batches": self.dump_completed_batches,
+            "quantize_completed_batches": self.quantize_completed_batches,
+        }
+
+    @classmethod
+    def from_dict(cls, d: dict) -> "BackupHeader":
+        """Rebuild a header from its JSON dict, tolerating missing progress keys."""
+        return cls(
+            index_name=d["index_name"],
+            fields=d["fields"],
+            batch_size=d.get("batch_size", 500),
+            phase=d.get("phase", "dump"),
+            dump_completed_batches=d.get("dump_completed_batches", 0),
+            quantize_completed_batches=d.get("quantize_completed_batches", 0),
+        )
+
+
+class VectorBackup:
+    """Manages a vector backup file for crash-safe quantization.
+
+    Two files on disk:
+        .header — small JSON, atomically updated after each batch
+        .data — append-only binary, one length-prefixed pickle blob per batch
+    """
+
+    def __init__(self, path: str, header: BackupHeader) -> None:
+        self._path = path
+        self._header_path = path + ".header"
+        self._data_path = path + ".data"
+        self.header = header
+
+    # ------------------------------------------------------------------
+    # Construction
+    # ------------------------------------------------------------------
+
+    @classmethod
+    def create(
+        cls,
+        path: str,
+        index_name: str,
+        fields: Dict[str, Dict[str, Any]],
+        batch_size: int = 500,
+    ) -> "VectorBackup":
+        """Create a new backup file. Raises FileExistsError if one already exists."""
+        header_path = path + ".header"
+        if os.path.exists(header_path):
+            raise FileExistsError(f"Backup already exists at {header_path}")
+
+        header = BackupHeader(
+            index_name=index_name,
+            fields=fields,
+            batch_size=batch_size,
+        )
+        backup = cls(path, header)
+        backup._save_header()
+        return backup
+
+    @classmethod
+    def load(cls, path: str) -> Optional["VectorBackup"]:
+        """Load an existing backup from disk. Returns None if not found."""
+        header_path = path + ".header"
+        if not os.path.exists(header_path):
+            return None
+        with open(header_path, "r") as f:
+            header = BackupHeader.from_dict(json.load(f))
+        return cls(path, header)
+
+    # ------------------------------------------------------------------
+    # Header persistence (atomic write via temp + fsync + rename)
+    # ------------------------------------------------------------------
+
+    def _save_header(self) -> None:
+        """Atomically and durably write the header to disk.
+
+        The temp file is created in the destination directory so the
+        os.replace() is a same-filesystem atomic rename. The file is
+        flushed and fsync'd before the rename: without the fsync, a
+        crash/power loss shortly after os.replace() could leave an
+        empty or truncated header behind, defeating the crash safety
+        this module exists to provide (the .data writer already
+        fsyncs every batch).
+        """
+        dir_path = os.path.dirname(self._header_path) or "."
+        fd, tmp = tempfile.mkstemp(dir=dir_path, suffix=".tmp")
+        try:
+            with os.fdopen(fd, "w") as f:
+                json.dump(self.header.to_dict(), f)
+                f.flush()
+                os.fsync(f.fileno())
+            os.replace(tmp, self._header_path)
+        except BaseException:
+            # Best-effort cleanup of the temp file; re-raise the original error.
+            try:
+                os.unlink(tmp)
+            except OSError:
+                pass
+            raise
+
+    # ------------------------------------------------------------------
+    # Dump phase: write batches of original vectors
+    # ------------------------------------------------------------------
+
+    def write_batch(
+        self,
+        batch_idx: int,
+        keys: List[str],
+        originals: Dict[str, Dict[str, bytes]],
+    ) -> None:
+        """Append a batch of original vectors to the data file.
+
+        Args:
+            batch_idx: Sequential batch index (0, 1, 2, ...)
+            keys: Ordered list of Redis keys in this batch
+            originals: {key: {field_name: original_bytes}}
+        """
+        if self.header.phase != "dump":
+            raise ValueError(
+                f"Cannot write batch in phase '{self.header.phase}'. "
+                "Only allowed during 'dump' phase."
+            )
+        blob = pickle.dumps({"keys": keys, "vectors": originals})
+        # Length-prefixed: 4 bytes big-endian length + blob, so batches can be
+        # streamed back sequentially without any index structure.
+        length_prefix = struct.pack(">I", len(blob))
+        with open(self._data_path, "ab") as f:
+            f.write(length_prefix)
+            f.write(blob)
+            f.flush()
+            os.fsync(f.fileno())
+
+        # Advance the progress counter only after the data is durable on disk.
+        self.header.dump_completed_batches = batch_idx + 1
+        self._save_header()
+
+    def mark_dump_complete(self) -> None:
+        """Transition from dump → ready."""
+        if self.header.phase != "dump":
+            raise ValueError(
+                f"Cannot mark dump complete in phase '{self.header.phase}'"
+            )
+        self.header.phase = "ready"
+        self._save_header()
+
+    # ------------------------------------------------------------------
+    # Quantize phase: track which batches have been written to Redis
+    # ------------------------------------------------------------------
+
+    def start_quantize(self) -> None:
+        """Transition from ready → active.
+
+        Accepting 'active' as well makes this idempotent, which supports
+        re-entry after a crash mid-quantization.
+        """
+        if self.header.phase not in ("ready", "active"):
+            raise ValueError(f"Cannot start quantize in phase '{self.header.phase}'")
+        self.header.phase = "active"
+        self._save_header()
+
+    def mark_batch_quantized(self, batch_idx: int) -> None:
+        """Record that a batch has been successfully written to Redis.
+
+        Called ONLY after pipeline_write succeeds.
+        """
+        self.header.quantize_completed_batches = batch_idx + 1
+        self._save_header()
+
+    def mark_complete(self) -> None:
+        """Transition from active → completed."""
+        self.header.phase = "completed"
+        self._save_header()
+
+    # ------------------------------------------------------------------
+    # Reading batches back
+    # ------------------------------------------------------------------
+
+    def iter_batches(
+        self,
+    ) -> Generator[Tuple[List[str], Dict[str, Dict[str, bytes]]], None, None]:
+        """Iterate ALL batches in the data file.
+
+        Yields (keys, originals) for each batch.
+        """
+        if not os.path.exists(self._data_path):
+            return
+        with open(self._data_path, "rb") as f:
+            # Only read batches the header says were fully written; a
+            # partially-appended trailing batch (crash mid-write) is ignored.
+            for _ in range(self.header.dump_completed_batches):
+                length_bytes = f.read(4)
+                if len(length_bytes) < 4:
+                    # Truncated file: stop silently rather than raise.
+                    return
+                length = struct.unpack(">I", length_bytes)[0]
+                blob = f.read(length)
+                if len(blob) < length:
+                    return
+                # pickle is acceptable here only because the .data file is
+                # produced by this same process; never point it at untrusted input.
+                batch = pickle.loads(blob)
+                yield batch["keys"], batch["vectors"]
+
+    def iter_remaining_batches(
+        self,
+    ) -> Generator[Tuple[List[str], Dict[str, Dict[str, bytes]]], None, None]:
+        """Iterate batches that have NOT been quantized yet.
+
+        Skips the first `quantize_completed_batches` batches.
+        """
+        skip = self.header.quantize_completed_batches
+        for idx, (keys, vectors) in enumerate(self.iter_batches()):
+            if idx < skip:
+                continue
+            yield keys, vectors
diff --git a/redisvl/migration/batch_executor.py b/redisvl/migration/batch_executor.py
new file mode 100644
index 000000000..038e0a2a3
--- /dev/null
+++ b/redisvl/migration/batch_executor.py
+"""Batch migration executor with checkpointing and resume support."""
+
+from __future__ import annotations
+
+import time
+from pathlib import Path
+from typing import Any, Callable, Optional
+
+import yaml
+
+from redisvl.migration.executor import MigrationExecutor
+from redisvl.migration.models import (
+    BatchIndexReport,
+    BatchIndexState,
+    BatchPlan,
+    BatchReport,
+    BatchReportSummary,
+    BatchState,
+)
+from redisvl.migration.planner import MigrationPlanner
+from redisvl.migration.utils import timestamp_utc, write_yaml
+from redisvl.redis.connection import RedisConnectionFactory
+
+
+class BatchMigrationExecutor:
+    """Executor for batch migration of multiple indexes.
+
+    Supports:
+    - Sequential execution (one index at a time)
+    - Checkpointing for resume after failure
+    - Configurable failure policies (fail_fast, continue_on_error)
+    """
+
+    def __init__(self, executor: Optional[MigrationExecutor] = None):
+        # The single-index executor is injectable (eases testing); defaults
+        # to the real MigrationExecutor.
+        self._single_executor = executor or MigrationExecutor()
+        self._planner = MigrationPlanner()
+
+    def apply(
+        self,
+        batch_plan: BatchPlan,
+        *,
+        batch_plan_path: Optional[str] = None,
+        state_path: str = "batch_state.yaml",
+        report_dir: str = "./reports",
+        redis_url: Optional[str] = None,
+        redis_client: Optional[Any] = None,
+        progress_callback: Optional[Callable[[str, int, int, str], None]] = None,
+    ) -> BatchReport:
+        """Execute batch migration with checkpointing.
+
+        Args:
+            batch_plan: The batch plan to execute.
+            batch_plan_path: Path to the batch plan file (stored in state for resume).
+            state_path: Path to checkpoint state file.
+            report_dir: Directory for per-index reports.
+            redis_url: Redis connection URL.
+            redis_client: Existing Redis client.
+            progress_callback: Optional callback(index_name, position, total, status).
+
+        Returns:
+            BatchReport with results for all indexes.
+        """
+        # Get Redis client
+        client = redis_client
+        if client is None:
+            if not redis_url:
+                raise ValueError("Must provide either redis_url or redis_client")
+            client = RedisConnectionFactory.get_redis_connection(redis_url=redis_url)
+
+        # Ensure report directory exists
+        report_path = Path(report_dir).resolve()
+        report_path.mkdir(parents=True, exist_ok=True)
+
+        # Initialize or load state
+        state = self._init_or_load_state(batch_plan, state_path, batch_plan_path)
+        started_at = state.started_at
+        batch_start_time = time.perf_counter()
+
+        # Get applicable indexes
+        applicable_indexes = [idx for idx in batch_plan.indexes if idx.applicable]
+        total = len(applicable_indexes)
+
+        # Calculate the correct starting position for progress reporting
+        # (accounts for already-completed indexes during resume)
+        already_completed = len(state.completed)
+
+        # Process each remaining index. Iterate over a copy (state.remaining[:])
+        # because state.remaining is mutated inside the loop.
+        for offset, index_name in enumerate(state.remaining[:]):
+            # Checkpoint BEFORE migrating so a crash mid-index is visible
+            # in current_index on resume.
+            state.current_index = index_name
+            state.updated_at = timestamp_utc()
+            self._write_state(state, state_path)
+
+            position = already_completed + offset + 1
+            if progress_callback:
+                progress_callback(index_name, position, total, "starting")
+
+            # Find the index entry
+            index_entry = next(
+                (idx for idx in batch_plan.indexes if idx.name == index_name), None
+            )
+            if not index_entry or not index_entry.applicable:
+                # Skip non-applicable indexes
+                state.remaining.remove(index_name)
+                state.completed.append(
+                    BatchIndexState(
+                        name=index_name,
+                        status="skipped",
+                        completed_at=timestamp_utc(),
+                    )
+                )
+                state.current_index = None
+                state.updated_at = timestamp_utc()
+                self._write_state(state, state_path)
+                if progress_callback:
+                    progress_callback(index_name, position, total, "skipped")
+                continue
+
+            # Execute migration for this index
+            index_state = self._migrate_single_index(
+                index_name=index_name,
+                batch_plan=batch_plan,
+                report_dir=report_path,
+                redis_client=client,
+            )
+
+            # Update state (checkpoint AFTER the index finished, success or not)
+            state.remaining.remove(index_name)
+            state.completed.append(index_state)
+            state.current_index = None
+            state.updated_at = timestamp_utc()
+            self._write_state(state, state_path)
+
+            if progress_callback:
+                progress_callback(index_name, position, total, index_state.status)
+
+            # Check failure policy
+            if (
+                index_state.status == "failed"
+                and batch_plan.failure_policy == "fail_fast"
+            ):
+                # Leave remaining indexes in state.remaining so that
+                # checkpoint resume can pick them up later.
+                break
+
+        # Build final report
+        total_duration = time.perf_counter() - batch_start_time
+        return self._build_batch_report(batch_plan, state, started_at, total_duration)
+
+    def resume(
+        self,
+        state_path: str,
+        *,
+        batch_plan_path: Optional[str] = None,
+        retry_failed: bool = False,
+        report_dir: str = "./reports",
+        redis_url: Optional[str] = None,
+        redis_client: Optional[Any] = None,
+        progress_callback: Optional[Callable[[str, int, int, str], None]] = None,
+    ) -> BatchReport:
+        """Resume batch migration from checkpoint.
+
+        Args:
+            state_path: Path to checkpoint state file.
+            batch_plan_path: Path to batch plan (uses state.plan_path if not provided).
+            retry_failed: If True, retry previously failed indexes.
+            report_dir: Directory for per-index reports.
+            redis_url: Redis connection URL.
+            redis_client: Existing Redis client.
+            progress_callback: Optional callback(index_name, position, total, status).
+        """
+        state = self._load_state(state_path)
+        plan_path = batch_plan_path or state.plan_path
+        if not plan_path or not plan_path.strip():
+            raise ValueError(
+                "No batch plan path available. Provide batch_plan_path explicitly, "
+                "or ensure the checkpoint state contains a valid plan_path."
+            )
+        batch_plan = self._load_batch_plan(plan_path)
+
+        # Optionally retry failed indexes: move them from completed back to
+        # the FRONT of remaining so they run first on this resume.
+        if retry_failed:
+            failed_names = [
+                idx.name for idx in state.completed if idx.status == "failed"
+            ]
+            state.remaining = failed_names + state.remaining
+            state.completed = [idx for idx in state.completed if idx.status != "failed"]
+            # Write updated state back to file so apply() picks up the changes
+            self._write_state(state, state_path)
+
+        # Re-run apply with the updated state
+        return self.apply(
+            batch_plan,
+            batch_plan_path=batch_plan_path,
+            state_path=state_path,
+            report_dir=report_dir,
+            redis_url=redis_url,
+            redis_client=redis_client,
+            progress_callback=progress_callback,
+        )
+
+    def _migrate_single_index(
+        self,
+        *,
+        index_name: str,
+        batch_plan: BatchPlan,
+        report_dir: Path,
+        redis_client: Any,
+    ) -> BatchIndexState:
+        """Execute migration for a single index.
+
+        Never raises: any failure is captured as a BatchIndexState with
+        status 'failed' so the batch loop can apply its failure policy.
+        """
+        try:
+            # Create migration plan for this index
+            plan = self._planner.create_plan_from_patch(
+                index_name,
+                schema_patch=batch_plan.shared_patch,
+                redis_client=redis_client,
+            )
+
+            # Execute migration
+            report = self._single_executor.apply(
+                plan,
+                redis_client=redis_client,
+            )
+
+            # Sanitize index_name to prevent path traversal and invalid filenames
+            safe_name = (
+                index_name.replace("/", "_")
+                .replace("\\", "_")
+                .replace("..", "_")
+                .replace(":", "_")
+            )
+            report_file = report_dir / f"{safe_name}_report.yaml"
+            write_yaml(report.model_dump(exclude_none=True), str(report_file))
+
+            return BatchIndexState(
+                name=index_name,
+                status="success" if report.result == "succeeded" else "failed",
+                completed_at=timestamp_utc(),
+                report_path=str(report_file),
+                error=report.validation.errors[0] if report.validation.errors else None,
+            )
+
+        except Exception as e:
+            return BatchIndexState(
+                name=index_name,
+                status="failed",
+                completed_at=timestamp_utc(),
+                error=str(e),
+            )
+
+    def _init_or_load_state(
+        self,
+        batch_plan: BatchPlan,
+        state_path: str,
+        batch_plan_path: Optional[str] = None,
+    ) -> BatchState:
+        """Initialize new state or load existing checkpoint."""
+        path = Path(state_path).resolve()
+        if path.exists():
+            loaded = self._load_state(state_path)
+            # Validate that loaded state matches the current batch plan
+            if loaded.batch_id and loaded.batch_id != batch_plan.batch_id:
+                raise ValueError(
+                    f"Checkpoint state batch_id '{loaded.batch_id}' does not match "
+                    f"current batch plan '{batch_plan.batch_id}'. "
+                    "Remove the stale state file or use a different state_path."
+                )
+            # Update plan_path if caller provided one (handles cases where
+            # the original path was empty or pointed to a deleted temp dir).
+            if batch_plan_path:
+                loaded.plan_path = str(Path(batch_plan_path).resolve())
+            return loaded
+
+        # Create new state with plan_path for resume support
+        applicable_names = [idx.name for idx in batch_plan.indexes if idx.applicable]
+        return BatchState(
+            batch_id=batch_plan.batch_id,
+            plan_path=str(Path(batch_plan_path).resolve()) if batch_plan_path else "",
+            started_at=timestamp_utc(),
+            updated_at=timestamp_utc(),
+            remaining=applicable_names,
+            completed=[],
+            current_index=None,
+        )
+
+    def _write_state(self, state: BatchState, state_path: str) -> None:
+        """Write checkpoint state to file atomically.
+
+        Writes to a temporary file first, then renames to avoid corruption
+        if the process crashes mid-write.
+        """
+        path = Path(state_path).resolve()
+        path.parent.mkdir(parents=True, exist_ok=True)
+        # NOTE(review): with_suffix() replaces an existing extension, so
+        # "a.yaml" and "a.json" in the same directory would share "a.tmp" —
+        # acceptable for a single state file, worth confirming if reused.
+        tmp_path = path.with_suffix(".tmp")
+        with open(tmp_path, "w") as f:
+            yaml.safe_dump(state.model_dump(exclude_none=True), f, sort_keys=False)
+            f.flush()
+        tmp_path.replace(path)
+
+    def _load_state(self, state_path: str) -> BatchState:
+        """Load checkpoint state from file."""
+        path = Path(state_path).resolve()
+        if not path.is_file():
+            raise FileNotFoundError(f"State file not found: {state_path}")
+        with open(path, "r") as f:
+            data = yaml.safe_load(f) or {}
+        return BatchState.model_validate(data)
+
+    def _load_batch_plan(self, plan_path: str) -> BatchPlan:
+        """Load batch plan from file."""
+        path = Path(plan_path).resolve()
+        if not path.is_file():
+            raise FileNotFoundError(f"Batch plan not found: {plan_path}")
+        with open(path, "r") as f:
+            data = yaml.safe_load(f) or {}
+        return BatchPlan.model_validate(data)
+
+    def _build_batch_report(
+        self,
+        batch_plan: BatchPlan,
+        state: BatchState,
+        started_at: str,
+        total_duration: float,
+    ) -> BatchReport:
+        """Build final batch report from state."""
+        index_reports = []
+        succeeded = 0
+        failed = 0
+        skipped = 0
+
+        for idx_state in state.completed:
+            index_reports.append(
+                BatchIndexReport(
+                    name=idx_state.name,
+                    status=idx_state.status,
+                    report_path=idx_state.report_path,
+                    error=idx_state.error,
+                )
+            )
+            if idx_state.status == "success":
+                succeeded += 1
+            elif idx_state.status == "failed":
+                failed += 1
+            else:
+                skipped += 1
+
+        # Add remaining indexes (fail-fast left them pending) as skipped
+        for remaining_name in state.remaining:
+            index_reports.append(
+                BatchIndexReport(
+                    name=remaining_name,
+                    status="skipped",
+                    error="Skipped due to fail_fast policy",
+                )
+            )
+            skipped += 1
+
+        # Add non-applicable indexes as skipped
+        for idx in batch_plan.indexes:
+            if not idx.applicable:
+                index_reports.append(
+                    BatchIndexReport(
+                        name=idx.name,
+                        status="skipped",
+                        error=idx.skip_reason,
+                    )
+                )
+                skipped += 1
+
+        # Determine overall status.
+        # NOTE(review): if nothing succeeded, nothing failed, but indexes
+        # remain pending, this reports "failed" — confirm that is intended.
+        if failed == 0 and len(state.remaining) == 0:
+            status = "completed"
+        elif succeeded > 0:
+            status = "partial_failure"
+        else:
+            status = "failed"
+
+        return BatchReport(
+            batch_id=batch_plan.batch_id,
+            status=status,
+            started_at=started_at,
+            completed_at=timestamp_utc(),
+            summary=BatchReportSummary(
+                total_indexes=len(batch_plan.indexes),
+                successful=succeeded,
+                failed=failed,
+                skipped=skipped,
+                total_duration_seconds=round(total_duration, 3),
+            ),
+            indexes=index_reports,
+        )
diff --git a/redisvl/migration/batch_planner.py b/redisvl/migration/batch_planner.py
new file mode 100644
index 000000000..eab94c810
--- /dev/null
+++ b/redisvl/migration/batch_planner.py
+"""Batch migration planner for migrating multiple indexes with a shared patch."""
+
+from __future__ import annotations
+
+import fnmatch
+import uuid
+from pathlib import Path
+from typing import Any, List, Optional, Tuple
+
+import redis.exceptions
+import yaml
+
+from redisvl.index import SearchIndex
+from redisvl.migration.models import BatchIndexEntry, BatchPlan, SchemaPatch
+from redisvl.migration.planner import MigrationPlanner
+from redisvl.migration.utils import list_indexes, timestamp_utc
+from redisvl.redis.connection import RedisConnectionFactory
+
+
+class BatchMigrationPlanner:
+    """Planner for batch migration of multiple indexes with a shared patch.
+
+    The batch planner applies a single SchemaPatch to multiple indexes,
+    checking applicability for each index based on field name matching.
+    """
+
+    def __init__(self):
+        self._single_planner = MigrationPlanner()
+
+    def create_batch_plan(
+        self,
+        *,
+        indexes: Optional[List[str]] = None,
+        pattern: Optional[str] = None,
+        indexes_file: Optional[str] = None,
+        schema_patch_path: str,
+        redis_url: Optional[str] = None,
+        redis_client: Optional[Any] = None,
+        failure_policy: str = "fail_fast",
+    ) -> BatchPlan:
+        """Create a batch migration plan for multiple indexes.
+
+        Args:
+            indexes: Explicit list of index names.
+            pattern: Glob pattern to match index names (e.g., "*_idx").
+            indexes_file: Path to file with index names (one per line).
+            schema_patch_path: Path to shared schema patch YAML file.
+            redis_url: Redis connection URL.
+            redis_client: Existing Redis client.
+            failure_policy: "fail_fast" or "continue_on_error".
+
+        Returns:
+            BatchPlan with shared patch and per-index applicability.
+        """
+        # Validate failure_policy before touching Redis so bad input fails fast.
+        _VALID_FAILURE_POLICIES = {"fail_fast", "continue_on_error"}
+        if failure_policy not in _VALID_FAILURE_POLICIES:
+            raise ValueError(
+                f"Invalid failure_policy '{failure_policy}'. "
+                f"Must be one of: {sorted(_VALID_FAILURE_POLICIES)}"
+            )
+
+        # Get Redis client
+        client = redis_client
+        if client is None:
+            if not redis_url:
+                raise ValueError("Must provide either redis_url or redis_client")
+            client = RedisConnectionFactory.get_redis_connection(redis_url=redis_url)
+
+        # Resolve index list
+        index_names = self._resolve_index_names(
+            indexes=indexes,
+            pattern=pattern,
+            indexes_file=indexes_file,
+            redis_client=client,
+        )
+
+        if not index_names:
+            raise ValueError("No indexes found matching the specified criteria")
+
+        # Load shared patch
+        shared_patch = self._single_planner.load_schema_patch(schema_patch_path)
+
+        # Check applicability for each index
+        batch_entries: List[BatchIndexEntry] = []
+        requires_quantization = False
+
+        for index_name in index_names:
+            entry, has_quantization = self._check_index_applicability(
+                index_name=index_name,
+                shared_patch=shared_patch,
+                redis_client=client,
+            )
+            batch_entries.append(entry)
+            if has_quantization:
+                requires_quantization = True
+
+        batch_id = f"batch_{uuid.uuid4().hex[:12]}"
+
+        return BatchPlan(
+            batch_id=batch_id,
+            mode="drop_recreate",
+            failure_policy=failure_policy,
+            requires_quantization=requires_quantization,
+            shared_patch=shared_patch,
+            indexes=batch_entries,
+            created_at=timestamp_utc(),
+        )
+
+    def _resolve_index_names(
+        self,
+        *,
+        indexes: Optional[List[str]],
+        pattern: Optional[str],
+        indexes_file: Optional[str],
+        redis_client: Any,
+    ) -> List[str]:
+        """Resolve index names from explicit list, pattern, or file."""
+        sources = sum([bool(indexes), bool(pattern), bool(indexes_file)])
+        if sources == 0:
+            raise ValueError("Must provide one of: indexes, pattern, or indexes_file")
+        if sources > 1:
+            raise ValueError("Provide only one of: indexes, pattern, or indexes_file")
+
+        if indexes:
+            # Deduplicate while preserving order
+            return list(dict.fromkeys(indexes))
+
+        if indexes_file:
+            return self._load_indexes_from_file(indexes_file)
+
+        # Pattern matching -- pattern is guaranteed non-None at this point
+        assert pattern is not None, "pattern must be set when reaching fnmatch"
+        all_indexes = list_indexes(redis_client=redis_client)
+        matched = [idx for idx in all_indexes if fnmatch.fnmatch(idx, pattern)]
+        return sorted(matched)
+
+    def _load_indexes_from_file(self, file_path: str) -> List[str]:
+        """Load index names from a file (one per line).
+
+        Blank lines and '#'-prefixed comment lines are ignored.
+        """
+        path = Path(file_path).resolve()
+        if not path.exists():
+            raise FileNotFoundError(f"Indexes file not found: {file_path}")
+
+        with open(path, "r") as f:
+            lines = f.readlines()
+
+        return [
+            stripped
+            for line in lines
+            if (stripped := line.strip()) and not stripped.startswith("#")
+        ]
+
+    def _check_index_applicability(
+        self,
+        *,
+        index_name: str,
+        shared_patch: SchemaPatch,
+        redis_client: Any,
+    ) -> Tuple[BatchIndexEntry, bool]:
+        """Check if the shared patch can be applied to a specific index.
+
+        Returns:
+            Tuple of (BatchIndexEntry, requires_quantization).
+        """
+        try:
+            index = SearchIndex.from_existing(index_name, redis_client=redis_client)
+            schema_dict = index.schema.to_dict()
+            field_names = {f["name"] for f in schema_dict.get("fields", [])}
+
+            # Build a set of field names that includes rename targets so
+            # that update_fields referencing the NEW name of a renamed field
+            # are considered applicable.
+            rename_target_names = {
+                fr.new_name for fr in shared_patch.changes.rename_fields
+            }
+            effective_field_names = field_names | rename_target_names
+
+            # Check that all update_fields exist in this index (or are rename targets)
+            missing_fields = []
+            for field_update in shared_patch.changes.update_fields:
+                if field_update.name not in effective_field_names:
+                    missing_fields.append(field_update.name)
+
+            if missing_fields:
+                return (
+                    BatchIndexEntry(
+                        name=index_name,
+                        applicable=False,
+                        skip_reason=f"Missing fields: {', '.join(missing_fields)}",
+                    ),
+                    False,
+                )
+
+            # Validate rename targets don't collide with each other or
+            # existing fields (after accounting for the source being renamed away)
+            if shared_patch.changes.rename_fields:
+                rename_targets = [
+                    fr.new_name for fr in shared_patch.changes.rename_fields
+                ]
+                rename_sources = {
+                    fr.old_name for fr in shared_patch.changes.rename_fields
+                }
+                seen_targets: dict[str, int] = {}
+                for t in rename_targets:
+                    seen_targets[t] = seen_targets.get(t, 0) + 1
+                duplicates = [t for t, c in seen_targets.items() if c > 1]
+                if duplicates:
+                    return (
+                        BatchIndexEntry(
+                            name=index_name,
+                            applicable=False,
+                            skip_reason=f"Rename targets collide: {', '.join(duplicates)}",
+                        ),
+                        False,
+                    )
+                # Check if any rename target already exists and isn't itself being renamed away
+                collisions = [
+                    t
+                    for t in rename_targets
+                    if t in field_names and t not in rename_sources
+                ]
+                if collisions:
+                    return (
+                        BatchIndexEntry(
+                            name=index_name,
+                            applicable=False,
+                            skip_reason=f"Rename targets already exist: {', '.join(collisions)}",
+                        ),
+                        False,
+                    )
+
+            # Check that add_fields don't already exist.
+            # Fields being renamed away free their name for new additions.
+            rename_sources = {fr.old_name for fr in shared_patch.changes.rename_fields}
+            post_rename_fields = (field_names - rename_sources) | rename_target_names
+            existing_adds: list[str] = []
+            for field in shared_patch.changes.add_fields:
+                field_name = field.get("name")
+                if field_name and field_name in post_rename_fields:
+                    existing_adds.append(field_name)
+
+            if existing_adds:
+                return (
+                    BatchIndexEntry(
+                        name=index_name,
+                        applicable=False,
+                        skip_reason=f"Fields already exist: {', '.join(existing_adds)}",
+                    ),
+                    False,
+                )
+
+            # Try creating a plan to check for blocked changes
+            plan = self._single_planner.create_plan_from_patch(
+                index_name,
+                schema_patch=shared_patch,
+                redis_client=redis_client,
+            )
+
+            if not plan.diff_classification.supported:
+                return (
+                    BatchIndexEntry(
+                        name=index_name,
+                        applicable=False,
+                        skip_reason=(
+                            plan.diff_classification.blocked_reasons[0]
+                            if plan.diff_classification.blocked_reasons
+                            else "Unsupported changes"
+                        ),
+                    ),
+                    False,
+                )
+
+            # Detect quantization from the plan we already created
+            has_quantization = bool(
+                MigrationPlanner.get_vector_datatype_changes(
+                    plan.source.schema_snapshot,
+                    plan.merged_target_schema,
+                    rename_operations=plan.rename_operations,
+                )
+            )
+
+            return BatchIndexEntry(name=index_name, applicable=True), has_quantization
+
+        # Built-in ConnectionError/TimeoutError are OSError subclasses; they
+        # are listed explicitly for readability. The `as e` binding was
+        # removed because the handler only re-raises.
+        except (
+            ConnectionError,
+            OSError,
+            TimeoutError,
+            redis.exceptions.ConnectionError,
+        ):
+            # Infrastructure failures should propagate, not be silently
+            # treated as "not applicable".
+            raise
+        except Exception as e:
+            # Any other error (bad schema, planner failure) marks the index
+            # as not applicable instead of aborting the whole batch plan.
+            return (
+                BatchIndexEntry(
+                    name=index_name,
+                    applicable=False,
+                    skip_reason=str(e),
+                ),
+                False,
+            )
+
+    def write_batch_plan(self, batch_plan: BatchPlan, path: str) -> None:
+        """Write batch plan to YAML file."""
+        plan_path = Path(path).resolve()
+        with open(plan_path, "w") as f:
+            yaml.safe_dump(batch_plan.model_dump(exclude_none=True), f, sort_keys=False)
diff --git a/redisvl/migration/executor.py b/redisvl/migration/executor.py
new file mode 100644
index 000000000..85e85d5ce
--- /dev/null
+++ b/redisvl/migration/executor.py
+from __future__ import annotations
+
+import hashlib
+import time
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional
+
+if TYPE_CHECKING:
+    from redisvl.migration.backup import VectorBackup
+
+from redis.cluster import RedisCluster
+from redis.exceptions import ResponseError
+
+from redisvl.index import SearchIndex
+from redisvl.migration.models import (
+    MigrationBenchmarkSummary,
+    MigrationPlan,
+    MigrationReport,
+    MigrationTimings,
+    MigrationValidation,
+)
+from redisvl.migration.planner import MigrationPlanner
+from redisvl.migration.reliability import is_same_width_dtype_conversion
+from redisvl.migration.utils import (
+    build_scan_match_patterns,
+    current_source_matches_snapshot,
+    detect_aof_enabled,
+    estimate_disk_space,
+    get_schema_field_path,
+    normalize_keys,
+    timestamp_utc,
+    wait_for_index_ready,
+)
+from redisvl.migration.validation import MigrationValidator
+from redisvl.types import SyncRedisClient
+from redisvl.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class MigrationExecutor:
+    """Executes a MigrationPlan against Redis: key enumeration, prefix
+    renames, and post-migration validation via MigrationValidator."""
+
+    def __init__(self, validator: Optional[MigrationValidator] = None):
+        # Injectable validator eases testing; defaults to the real one.
+        self.validator = validator or MigrationValidator()
+
+    def _enumerate_indexed_keys(
+        self,
+        client: SyncRedisClient,
+        index_name: str,
+        batch_size: int = 1000,
+        key_separator: str = ":",
+    ) -> Generator[str, None, None]:
+        """Enumerate document keys using FT.AGGREGATE with SCAN fallback.
+
+        Uses FT.AGGREGATE WITHCURSOR for efficient enumeration when the index
+        has no indexing failures. Falls back to SCAN if:
+        - Index has hash_indexing_failures > 0 (would miss failed docs)
+        - FT.AGGREGATE command fails for any reason
+
+        Args:
+            client: Redis client
+            index_name: Name of the index to enumerate
+            batch_size: Number of keys per batch
+            key_separator: Separator between prefix and key ID
+
+        Yields:
+            Document keys as strings
+        """
+        # Check for indexing failures - if any, fall back to SCAN
+        try:
+            info = client.ft(index_name).info()
+            failures = int(info.get("hash_indexing_failures", 0) or 0)
+            if failures > 0:
+                logger.warning(
+                    f"Index '{index_name}' has {failures} indexing failures. "
+                    "Using SCAN for complete enumeration."
+                )
+                yield from self._enumerate_with_scan(
+                    client, index_name, batch_size, key_separator
+                )
+                return
+        except Exception as e:
+            logger.warning(f"Failed to check index info: {e}. Using SCAN fallback.")
+            yield from self._enumerate_with_scan(
+                client, index_name, batch_size, key_separator
+            )
+            return
+
+        # Try FT.AGGREGATE enumeration
+        try:
+            yield from self._enumerate_with_aggregate(client, index_name, batch_size)
+        except ResponseError as e:
+            logger.warning(
+                f"FT.AGGREGATE failed: {e}. Falling back to SCAN enumeration."
+            )
+            yield from self._enumerate_with_scan(
+                client, index_name, batch_size, key_separator
+            )
+
+    def _enumerate_with_aggregate(
+        self,
+        client: SyncRedisClient,
+        index_name: str,
+        batch_size: int = 1000,
+    ) -> Generator[str, None, None]:
+        """Enumerate keys using FT.AGGREGATE WITHCURSOR.
+
+        More efficient than SCAN for sparse indexes (only returns indexed docs).
+        Requires LOAD 1 __key to retrieve document keys.
+
+        Note: FT.AGGREGATE cursors expire after ~5 minutes of idle time on the
+        server side. If the caller processes a batch slowly (e.g. performing
+        heavy per-key work between reads), a subsequent FT.CURSOR READ will
+        fail with a ``Cursor not found`` error. This is caught and re-raised
+        so the caller (_enumerate_indexed_keys) can fall back to SCAN.
+        """
+        cursor_id: Optional[int] = None
+
+        try:
+            # Initial aggregate call with LOAD 1 __key (not LOAD 0!)
+            # Use MAXIDLE to extend the server-side cursor idle timeout.
+            # Default Redis cursor idle timeout is 300 000 ms (5 min);
+            # we request the maximum allowed (300 000 ms).
+            result = client.execute_command(
+                "FT.AGGREGATE",
+                index_name,
+                "*",
+                "LOAD",
+                "1",
+                "__key",
+                "WITHCURSOR",
+                "COUNT",
+                str(batch_size),
+                "MAXIDLE",
+                "300000",
+            )
+
+            while True:
+                results_data, cursor_id = result
+
+                # Extract keys from results (skip first element which is count)
+                for item in results_data[1:]:
+                    if isinstance(item, (list, tuple)) and len(item) >= 2:
+                        key = item[1]
+                        yield key.decode() if isinstance(key, bytes) else str(key)
+
+                # Check if done (cursor_id == 0)
+                if cursor_id == 0:
+                    break
+
+                # Read next batch. The cursor may have expired if the caller
+                # took longer than MAXIDLE between reads — let the
+                # ResponseError propagate so the caller can fall back to SCAN.
+                result = client.execute_command(
+                    "FT.CURSOR",
+                    "READ",
+                    index_name,
+                    str(cursor_id),
+                    "COUNT",
+                    str(batch_size),
+                )
+        finally:
+            # Clean up cursor if interrupted
+            if cursor_id and cursor_id != 0:
+                try:
+                    client.execute_command(
+                        "FT.CURSOR", "DEL", index_name, str(cursor_id)
+                    )
+                except Exception:
+                    pass  # Cursor may have expired
+
+    def _enumerate_with_scan(
+        self,
+        client: SyncRedisClient,
+        index_name: str,
+        batch_size: int = 1000,
+        key_separator: str = ":",
+    ) -> Generator[str, None, None]:
+        """Enumerate keys using SCAN with prefix matching.
+
+        Fallback method that scans all keys matching the index prefix.
+        Less efficient but more complete (includes failed-to-index docs).
+        """
+        # Get prefix from index info
+        try:
+            info = client.ft(index_name).info()
+            # Handle both dict and list formats from FT.INFO
+            if isinstance(info, dict):
+                prefixes = info.get("index_definition", {}).get("prefixes", [])
+            else:
+                # List format - find index_definition
+                prefixes = []
+                for i, item in enumerate(info):
+                    if item == b"index_definition" or item == "index_definition":
+                        defn = info[i + 1]
+                        if isinstance(defn, dict):
+                            prefixes = defn.get("prefixes", [])
+                        elif isinstance(defn, list):
+                            for j, d in enumerate(defn):
+                                if d in (b"prefixes", "prefixes") and j + 1 < len(defn):
+                                    prefixes = defn[j + 1]
+                                    break
+            normalized_prefixes = [
+                p.decode() if isinstance(p, bytes) else str(p) for p in prefixes
+            ]
+        except Exception as e:
+            logger.warning(f"Failed to get prefix from index info: {e}")
+            normalized_prefixes = []
+
+        # De-duplicate across patterns: overlapping prefixes could otherwise
+        # yield the same key twice.
+        seen_keys: set[str] = set()
+        for match_pattern in build_scan_match_patterns(
+            normalized_prefixes, key_separator
+        ):
+            cursor = 0
+            while True:
+                cursor, keys = client.scan(  # type: ignore[misc]
+                    cursor=cursor,
+                    match=match_pattern,
+                    count=batch_size,
+                )
+                for key in keys:
+                    key_str = key.decode() if isinstance(key, bytes) else str(key)
+                    if key_str not in seen_keys:
+                        seen_keys.add(key_str)
+                        yield key_str
+
+                if cursor == 0:
+                    break
+
+    def _rename_keys(
+        self,
+        client: SyncRedisClient,
+        keys: List[str],
+        old_prefix: str,
+        new_prefix: str,
+        progress_callback: Optional[Callable[[int, int], None]] = None,
+    ) -> int:
+        """Rename keys from old prefix to new prefix.
+
+        Uses RENAMENX to avoid overwriting existing destination keys.
+        Raises on collision to prevent silent data loss.
+
+        For Redis Cluster, RENAME/RENAMENX fails with CROSSSLOT errors when
+        old and new keys hash to different slots. In that case we fall back
+        to DUMP/RESTORE/DEL per key, which works across slots.
+
+        Args:
+            client: Redis client
+            keys: List of keys to rename
+            old_prefix: Current prefix (e.g., "doc:")
+            new_prefix: New prefix (e.g., "article:")
+            progress_callback: Optional callback(done, total)
+
+        Returns:
+            Number of keys successfully renamed
+        """
+        is_cluster = isinstance(client, RedisCluster)
+        if is_cluster:
+            return self._rename_keys_cluster(
+                client, keys, old_prefix, new_prefix, progress_callback
+            )
+        return self._rename_keys_standalone(
+            client, keys, old_prefix, new_prefix, progress_callback
+        )
+
+    def _rename_keys_standalone(
+        self,
+        client: SyncRedisClient,
+        keys: List[str],
+        old_prefix: str,
+        new_prefix: str,
+        progress_callback: Optional[Callable[[int, int], None]] = None,
+    ) -> int:
+        """Rename keys using pipelined RENAMENX (standalone Redis only)."""
+        renamed = 0
+        total = len(keys)
+        pipeline_size = 100
+        collisions: List[str] = []
+        # NOTE(review): successfully_renamed is appended to but never read in
+        # the visible code — confirm whether it is dead or used for recovery.
+        successfully_renamed: List[tuple] = []  # (old_key, new_key) for recovery info
+
+        for i in range(0, total, pipeline_size):
+            batch = keys[i : i + pipeline_size]
+            pipe = client.pipeline(transaction=False)
+            batch_key_pairs: List[tuple] = []  # (old_key, new_key)
+
+            for key in batch:
+                if key.startswith(old_prefix):
+                    new_key = new_prefix + key[len(old_prefix) :]
+                else:
+                    logger.warning(
+                        f"Key '{key}' does not start with prefix '{old_prefix}'"
+                    )
+                    continue
+                pipe.renamenx(key, new_key)
+                batch_key_pairs.append((key, new_key))
+
+            try:
+                results = pipe.execute()
+                for j, r in enumerate(results):
+                    # RENAMENX returns 1/True on success, 0/False when the
+                    # destination already exists (collision).
+                    if r is True or r == 1:
+                        renamed += 1
+                        successfully_renamed.append(batch_key_pairs[j])
+                    else:
+                        collisions.append(batch_key_pairs[j][1])
+            except Exception as e:
+                logger.warning(f"Error in rename batch: {e}")
+                raise
+
+            # Fail fast on collisions to avoid partial renames across batches.
+            if collisions:
+                raise RuntimeError(
+                    f"Prefix rename aborted after {renamed} successful rename(s): "
+                    f"{len(collisions)} destination key(s) already exist "
+                    f"(first 5: {collisions[:5]}). This would overwrite existing data. "
+                    f"Remove conflicting keys or choose a different prefix. "
+                    f"Note: {renamed} key(s) were already renamed from "
+                    f"'{old_prefix}*' to '{new_prefix}*' and must be reversed "
+                    f"manually if you want to retry."
+                )
+
+            if progress_callback:
+                progress_callback(min(i + pipeline_size, total), total)
+
+        return renamed
+
+    def _rename_keys_cluster(
+        self,
+        client: SyncRedisClient,
+        keys: List[str],
+        old_prefix: str,
+        new_prefix: str,
+        progress_callback: Optional[Callable[[int, int], None]] = None,
+    ) -> int:
+        """Rename keys using batched DUMP/RESTORE/DEL for Redis Cluster.
+
+        RENAME/RENAMENX raises CROSSSLOT errors when source and destination
+        hash to different slots. DUMP/RESTORE works across slots.
+
+        Batches DUMP+PTTL reads and RESTORE+DEL writes in groups of
+        ``pipeline_size`` to reduce per-key round-trip overhead.
+        """
+        renamed = 0
+        total = len(keys)
+        pipeline_size = 100
+
+        for i in range(0, total, pipeline_size):
+            batch = keys[i : i + pipeline_size]
+
+            # Build (key, new_key) pairs for this batch
+            pairs = []
+            for key in batch:
+                if not key.startswith(old_prefix):
+                    logger.warning(
+                        "Key '%s' does not start with prefix '%s'", key, old_prefix
+                    )
+                    continue
+                new_key = new_prefix + key[len(old_prefix) :]
+                pairs.append((key, new_key))
+
+            if not pairs:
+                continue
+
+            # Phase 1: Check destination keys don't exist (batched)
+            check_pipe = client.pipeline(transaction=False)
+            for _, new_key in pairs:
+                check_pipe.exists(new_key)
+            exists_results = check_pipe.execute()
+            for (_, new_key), exists in zip(pairs, exists_results):
+                if exists:
+                    raise RuntimeError(
+                        f"Prefix rename aborted after {renamed} successful rename(s): "
+                        f"destination key '{new_key}' already exists. "
+                        f"Remove conflicting keys or choose a different prefix."
+ ) + + # Phase 2: DUMP + PTTL all source keys (batched — 1 RTT) + dump_pipe = client.pipeline(transaction=False) + for key, _ in pairs: + dump_pipe.dump(key) + dump_pipe.pttl(key) + dump_results = dump_pipe.execute() + + # Phase 3: RESTORE + DEL (batched — 1 RTT) + restore_pipe = client.pipeline(transaction=False) + valid_pairs = [] + for idx, (key, new_key) in enumerate(pairs): + dumped = dump_results[idx * 2] + ttl = int(dump_results[idx * 2 + 1]) # type: ignore[arg-type] + if dumped is None: + logger.warning("Key '%s' does not exist, skipping", key) + continue + restore_ttl = max(ttl, 0) + restore_pipe.restore(new_key, restore_ttl, dumped, replace=False) # type: ignore[arg-type] + restore_pipe.delete(key) + valid_pairs.append((key, new_key)) + + if valid_pairs: + restore_pipe.execute() + renamed += len(valid_pairs) + + if progress_callback: + progress_callback(min(i + pipeline_size, total), total) + + if progress_callback: + progress_callback(total, total) + + return renamed + + def _rename_field_in_hash( + self, + client: SyncRedisClient, + keys: List[str], + old_name: str, + new_name: str, + progress_callback: Optional[Callable[[int, int], None]] = None, + ) -> int: + """Rename a field in hash documents. + + For each document: + 1. HGET key old_name -> value + 2. HSET key new_name value + 3. HDEL key old_name + """ + renamed = 0 + total = len(keys) + pipeline_size = 100 + + for i in range(0, total, pipeline_size): + batch = keys[i : i + pipeline_size] + + # First, get old field values AND check if destination exists + pipe = client.pipeline(transaction=False) + for key in batch: + pipe.hget(key, old_name) + pipe.hexists(key, new_name) + raw_results = pipe.execute() + # Interleaved: [hget_0, hexists_0, hget_1, hexists_1, ...] 
            # De-interleave the pipeline reply (see comment above).
            values = raw_results[0::2]
            dest_exists = raw_results[1::2]

            # Now set new field and delete old
            pipe = client.pipeline(transaction=False)
            batch_ops = 0
            for key, value, exists in zip(batch, values, dest_exists):
                if value is not None:
                    if exists:
                        # Destination field already present — overwrite is
                        # intentional but logged for the operator.
                        logger.warning(
                            "Field '%s' already exists in key '%s'; "
                            "overwriting with value from '%s'",
                            new_name,
                            key,
                            old_name,
                        )
                    pipe.hset(key, new_name, value)
                    pipe.hdel(key, old_name)
                    batch_ops += 1

            try:
                pipe.execute()
                # Count by number of keys that had old field values,
                # not by HSET return (HSET returns 0 for existing field updates)
                renamed += batch_ops
            except Exception as e:
                logger.warning(f"Error in field rename batch: {e}")
                raise

            if progress_callback:
                progress_callback(min(i + pipeline_size, total), total)

        return renamed

    def _rename_field_in_json(
        self,
        client: SyncRedisClient,
        keys: List[str],
        old_path: str,
        new_path: str,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> int:
        """Rename a field in JSON documents.

        For each document:
        1. JSON.GET key old_path -> value
        2. JSON.SET key new_path value
        3. JSON.DEL key old_path

        Returns:
            Number of documents in which the old path held a value and
            was moved to the new path.
        """
        renamed = 0
        total = len(keys)
        pipeline_size = 100

        for i in range(0, total, pipeline_size):
            batch = keys[i : i + pipeline_size]

            # First, get all old field values
            pipe = client.pipeline(transaction=False)
            for key in batch:
                pipe.json().get(key, old_path)
            values = pipe.execute()

            # Now set new field and delete old
            # JSONPath GET returns results as a list; unwrap single-element
            # results to preserve the original document shape.
            # Missing paths return None or [] depending on Redis version.
            pipe = client.pipeline(transaction=False)
            batch_ops = 0
            for key, value in zip(batch, values):
                if value is None or value == []:
                    # Path absent in this document — nothing to move.
                    continue
                if isinstance(value, list) and len(value) == 1:
                    value = value[0]
                pipe.json().set(key, new_path, value)
                pipe.json().delete(key, old_path)
                batch_ops += 1
            try:
                pipe.execute()
                # Count by number of keys that had old field values,
                # not by JSON.SET return value
                renamed += batch_ops
            except Exception as e:
                logger.warning(f"Error in JSON field rename batch: {e}")
                raise

            if progress_callback:
                progress_callback(min(i + pipeline_size, total), total)

        return renamed

    def apply(
        self,
        plan: MigrationPlan,
        *,
        redis_url: Optional[str] = None,
        redis_client: Optional[Any] = None,
        query_check_file: Optional[str] = None,
        progress_callback: Optional[Callable[[str, Optional[str]], None]] = None,
        backup_dir: Optional[str] = None,
        batch_size: int = 500,
        num_workers: int = 1,
        keep_backup: bool = False,
        checkpoint_path: Optional[str] = None,  # deprecated, use backup_dir
    ) -> MigrationReport:
        """Apply a migration plan.

        Executes the migration phases in order: enumerate → dump → drop →
        key-renames → quantize → create → index → validate.

        **Single-worker mode** (default): original vectors are read from Redis
        and backed up to disk *before* the index is dropped, then converted
        and written back after the drop. This provides the strongest
        crash-safety: if the process dies after drop, the complete backup is
        already on disk for manual rollback.

        **Multi-worker mode** (``num_workers > 1``): for performance, the dump
        and quantize phases are fused — each worker reads its key shard,
        writes the original to its backup shard, converts, and writes the
        quantized vector back, all *after* the index drop. This avoids a
        redundant full read pass but means the backup may be incomplete if
        the process crashes mid-quantize. A re-run with the same
        ``backup_dir`` will detect partial backups and resume from where it
        left off.

        Args:
            plan: The migration plan to apply (from ``MigrationPlanner.create_plan``).
            redis_url: Redis connection URL (e.g. ``"redis://localhost:6379"``).
                Required when *num_workers* > 1 so each worker can open its own
                connection. Mutually exclusive with *redis_client* for the
                multi-worker path.
            redis_client: Optional existing Redis client. Ignored when
                *num_workers* > 1.
            query_check_file: Optional YAML file containing post-migration
                queries to verify search results.
            progress_callback: Optional ``callback(step, detail)`` invoked
                during each migration phase.

                * *step*: phase name (``"enumerate"``, ``"dump"``, ``"drop"``,
                  ``"quantize"``, ``"create"``, ``"index"``, ``"validate"``)
                * *detail*: human-readable progress string
                  (e.g. ``"1000/5000 docs"``) or ``None``
            backup_dir: Directory for vector backup files. When provided,
                original vectors are saved to disk before mutation, enabling
                crash-safe resume (re-run the same command) and manual rollback.
                Required when *num_workers* > 1. Disk usage is approximately
                ``num_docs × dims × bytes_per_element`` (e.g. ~2.9 GB for 1 M
                768-dim float32 vectors).
            batch_size: Number of keys per Redis pipeline batch (default 500).
                Controls the granularity of pipelined ``HGET``/``HSET`` calls.
                Larger batches reduce round-trips but increase per-batch memory.
                Values between 200 and 1000 are typical.
            num_workers: Number of parallel quantization workers (default 1).
                Each worker opens its own Redis connection and writes to its own
                backup-file shard. Requires *backup_dir* and *redis_url*.
                Parallelism improves throughput for high-dimensional vectors
                where conversion is CPU-bound. For low-dimensional vectors
                (≤ 256 dims), a single worker is often faster because the
                per-worker overhead (process spawning, extra connections)
                outweighs the parallelism benefit. Diminishing returns above
                4–8 workers on a single Redis instance.
            keep_backup: If ``True``, retain backup files after a successful
                migration. Default ``False`` (auto-cleanup on success). Useful
                for post-migration auditing or manual rollback.

        Returns:
            MigrationReport: Outcome including timing breakdown, validation
            results, and any warnings or manual actions.
        """
        # Handle deprecated checkpoint_path parameter
        if checkpoint_path is not None:
            import warnings

            warnings.warn(
                "checkpoint_path is deprecated and will be removed in a future "
                "version. Use backup_dir instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            if backup_dir is None:
                backup_dir = checkpoint_path

        started_at = timestamp_utc()
        started = time.perf_counter()

        # Report starts in "failed" state; flipped to "succeeded" only after
        # validation passes at the end of the happy path.
        report = MigrationReport(
            source_index=plan.source.index_name,
            target_index=plan.merged_target_schema["index"]["name"],
            result="failed",
            started_at=started_at,
            finished_at=started_at,
            warnings=list(plan.warnings),
        )

        if not plan.diff_classification.supported:
            report.validation.errors.extend(plan.diff_classification.blocked_reasons)
            report.manual_actions.append(
                "This change requires document migration, which is not yet supported."
            )
            report.finished_at = timestamp_utc()
            return report

        # Check if we are resuming from a backup file (post-crash).
        # New migration order: enumerate → field-renames → DUMP → DROP
        # → key-renames → QUANTIZE → CREATE.
        # The backup file stores original vectors and tracks progress.
        # If a backup file exists, we can determine exactly where the
        # previous run stopped and resume from there.
        from redisvl.migration.backup import VectorBackup

        resuming_from_backup = False
        existing_backup: Optional[VectorBackup] = None
        backup_path: Optional[str] = None

        if backup_dir:
            # Sanitize index name for filesystem with hash suffix to avoid
            # collisions between distinct names that sanitize identically
            # (e.g., "a/b" and "a:b" both become "a_b").
            safe_name = (
                plan.source.index_name.replace("/", "_")
                .replace("\\", "_")
                .replace(":", "_")
            )
            name_hash = hashlib.sha256(plan.source.index_name.encode()).hexdigest()[:8]
            backup_path = str(
                Path(backup_dir) / f"migration_backup_{safe_name}_{name_hash}"
            )
            existing_backup = VectorBackup.load(backup_path)

            # Fallback: probe for legacy backup filename (pre-hash naming)
            if existing_backup is None:
                legacy_path = str(Path(backup_dir) / f"migration_backup_{safe_name}")
                legacy_backup = VectorBackup.load(legacy_path)
                if legacy_backup is not None:
                    logger.info(
                        "Found legacy backup at %s (pre-hash naming), using it",
                        legacy_path,
                    )
                    existing_backup = legacy_backup
                    backup_path = legacy_path

            if existing_backup is not None:
                if existing_backup.header.index_name != plan.source.index_name:
                    # Stale backup from a different index — do not resume.
                    logger.warning(
                        "Backup index '%s' does not match plan index '%s', ignoring",
                        existing_backup.header.index_name,
                        plan.source.index_name,
                    )
                    existing_backup = None
                elif existing_backup.header.phase == "completed":
                    # Previous run completed quantization. Index may need recreating.
                    resuming_from_backup = True
                    logger.info(
                        "Backup at %s is completed; skipping to index creation",
                        backup_path,
                    )
                elif existing_backup.header.phase in ("active", "ready"):
                    # Crash after dump (possibly after drop). Resume.
                    resuming_from_backup = True
                    logger.info(
                        "Backup at %s found (phase=%s), resuming migration",
                        backup_path,
                        existing_backup.header.phase,
                    )
                elif existing_backup.header.phase == "dump":
                    # Crash during dump — index should still be alive.
                    # For simplicity, remove partial backup and restart.
                    logger.info(
                        "Partial dump found at %s, restarting dump",
                        backup_path,
                    )
                    Path(backup_path + ".header").unlink(missing_ok=True)
                    Path(backup_path + ".data").unlink(missing_ok=True)
                    existing_backup = None

        resuming = resuming_from_backup

        if not resuming:
            # Fresh run: refuse to apply against a drifted live schema.
            if not current_source_matches_snapshot(
                plan.source.index_name,
                plan.source.schema_snapshot,
                redis_url=redis_url,
                redis_client=redis_client,
            ):
                report.validation.errors.append(
                    "The current live source schema no longer matches the saved source snapshot."
                )
                report.manual_actions.append(
                    "Re-run `rvl migrate plan` to refresh the migration plan before applying."
                )
                report.finished_at = timestamp_utc()
                return report

            source_index = SearchIndex.from_existing(
                plan.source.index_name,
                redis_url=redis_url,
                redis_client=redis_client,
            )
        else:
            # Source index was dropped before crash; reconstruct from snapshot
            # to get a valid SearchIndex with a Redis client attached.
            source_index = SearchIndex.from_dict(
                plan.source.schema_snapshot,
                redis_url=redis_url,
                redis_client=redis_client,
            )

        target_index = SearchIndex.from_dict(
            plan.merged_target_schema,
            redis_url=redis_url,
            redis_client=redis_client,
        )

        # Per-phase timing accumulators (seconds); zero means "phase not run".
        enumerate_duration = 0.0
        drop_duration = 0.0
        quantize_duration = 0.0
        field_rename_duration = 0.0
        key_rename_duration = 0.0
        recreate_duration = 0.0
        indexing_duration = 0.0
        target_info: Dict[str, Any] = {}
        docs_quantized = 0
        keys_to_process: List[str] = []
        storage_type = plan.source.keyspace.storage_type

        # Check if we need to re-encode vectors for datatype changes
        datatype_changes = MigrationPlanner.get_vector_datatype_changes(
            plan.source.schema_snapshot,
            plan.merged_target_schema,
            rename_operations=plan.rename_operations,
        )

        # Check for rename operations
        rename_ops = plan.rename_operations
        has_prefix_change = rename_ops.change_prefix is not None
        has_field_renames = bool(rename_ops.rename_fields)
        # JSON vectors are re-encoded by RediSearch on re-index, so only
        # hash storage needs an explicit quantize pass.
        needs_quantization = bool(datatype_changes) and storage_type != "json"
        needs_enumeration = needs_quantization or has_prefix_change or has_field_renames
        has_same_width_quantization = any(
            is_same_width_dtype_conversion(change["source"], change["target"])
            for change in datatype_changes.values()
        )

        if backup_dir and has_same_width_quantization:
            # Same-width conversions are byte-identical on disk, so a resumed
            # run cannot tell converted from unconverted vectors — refuse.
            report.validation.errors.append(
                "Crash-safe resume is not supported for same-width datatype "
                "changes (float16<->bfloat16 or int8<->uint8)."
            )
            report.manual_actions.append(
                "Re-run without --backup-dir for same-width vector conversions, or "
                "split the migration to avoid same-width datatype changes."
            )
            report.finished_at = timestamp_utc()
            return report

        def _notify(step: str, detail: Optional[str] = None) -> None:
            # Thin guard so phases can call _notify unconditionally.
            if progress_callback:
                progress_callback(step, detail)

        try:
            client = source_index._redis_client
            aof_enabled = detect_aof_enabled(client)
            disk_estimate = estimate_disk_space(plan, aof_enabled=aof_enabled)
            if disk_estimate.has_quantization:
                logger.info(
                    "Disk space estimate: RDB ~%d bytes, AOF ~%d bytes, total ~%d bytes",
                    disk_estimate.rdb_snapshot_disk_bytes,
                    disk_estimate.aof_growth_bytes,
                    disk_estimate.total_new_disk_bytes,
                )
                report.disk_space_estimate = disk_estimate

            if resuming_from_backup and existing_backup is not None:
                # Resume from backup file. The backup has the key list
                # and original vectors — no enumeration or SCAN needed.
                if existing_backup.header.phase == "completed":
                    # Quantize already done, skip to CREATE
                    _notify("enumerate", "skipped (resume from backup)")
                    _notify("drop", "skipped (already dropped)")
                    _notify("quantize", "skipped (already completed)")
                elif existing_backup.header.phase in ("active", "ready"):
                    _notify("enumerate", "skipped (resume from backup)")
                    _notify("drop", "skipped (already dropped)")

                    # Remap datatype_changes if field renames happened
                    effective_changes = datatype_changes
                    if has_field_renames:
                        field_rename_map = {
                            fr.old_name: fr.new_name for fr in rename_ops.rename_fields
                        }
                        effective_changes = {
                            field_rename_map.get(k, k): v
                            for k, v in datatype_changes.items()
                        }

                    _notify("quantize", "Resuming vector re-encoding from backup...")
                    quantize_started = time.perf_counter()
                    docs_quantized = self._quantize_from_backup(
                        client=client,
                        backup=existing_backup,
                        datatype_changes=effective_changes,
                        progress_callback=lambda done, total: _notify(
                            "quantize", f"{done:,}/{total:,} docs"
                        ),
                    )
                    quantize_duration = round(time.perf_counter() - quantize_started, 3)
                    _notify(
                        "quantize",
                        f"done ({docs_quantized:,} docs in {quantize_duration}s)",
                    )

                # Key prefix renames may not have happened before the crash
                # (they run after index drop in the normal path). Re-apply
                # idempotently — RENAME is a no-op if old == new or key
                # was already renamed.
                if has_prefix_change:
                    # Collect keys from backup to know what to rename
                    resume_keys = []
                    for batch_keys, _ in existing_backup.iter_batches():
                        resume_keys.extend(batch_keys)
                    if resume_keys:
                        old_prefix = plan.source.keyspace.prefixes[0]
                        new_prefix = rename_ops.change_prefix
                        assert new_prefix is not None
                        _notify("key_rename", "Renaming keys (resume)...")
                        key_rename_started = time.perf_counter()
                        renamed_count = self._rename_keys(
                            client,
                            resume_keys,
                            old_prefix,
                            new_prefix,
                            progress_callback=lambda done, total: _notify(
                                "key_rename", f"{done:,}/{total:,} keys"
                            ),
                        )
                        key_rename_duration = round(
                            time.perf_counter() - key_rename_started, 3
                        )
                        _notify(
                            "key_rename",
                            f"done ({renamed_count:,} keys in {key_rename_duration}s)",
                        )
            else:
                # Normal (non-resume) path
                # STEP 1: Enumerate keys BEFORE any modifications
                if needs_enumeration:
                    _notify("enumerate", "Enumerating indexed documents...")
                    enumerate_started = time.perf_counter()
                    keys_to_process = list(
                        self._enumerate_indexed_keys(
                            client,
                            plan.source.index_name,
                            batch_size=1000,
                            key_separator=plan.source.keyspace.key_separator,
                        )
                    )
                    keys_to_process = normalize_keys(keys_to_process)
                    enumerate_duration = round(
                        time.perf_counter() - enumerate_started, 3
                    )
                    _notify(
                        "enumerate",
                        f"found {len(keys_to_process):,} documents ({enumerate_duration}s)",
                    )

                # STEP 2: Field renames (before dropping index)
                if has_field_renames and keys_to_process:
                    _notify("field_rename", "Renaming fields in documents...")
                    field_rename_started = time.perf_counter()
                    for field_rename in rename_ops.rename_fields:
                        if storage_type == "json":
                            old_path = get_schema_field_path(
                                plan.source.schema_snapshot, field_rename.old_name
                            )
                            new_path = get_schema_field_path(
                                plan.merged_target_schema, field_rename.new_name
                            )
                            # Nothing to move when either path is unknown or
                            # the paths already coincide.
                            if not old_path or not new_path or old_path == new_path:
                                continue
                            self._rename_field_in_json(
                                client,
                                keys_to_process,
                                old_path,
                                new_path,
                                progress_callback=lambda done, total: _notify(
                                    "field_rename",
                                    f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}",
                                ),
                            )
                        else:
                            self._rename_field_in_hash(
                                client,
                                keys_to_process,
                                field_rename.old_name,
                                field_rename.new_name,
                                progress_callback=lambda done, total: _notify(
                                    "field_rename",
                                    f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}",
                                ),
                            )
                    field_rename_duration = round(
                        time.perf_counter() - field_rename_started, 3
                    )
                    _notify("field_rename", f"done ({field_rename_duration}s)")

                # STEP 3: Dump original vectors to backup file (before drop)
                # For multi-worker, dump happens inside multi_worker_quantize
                # after the drop, so we skip the separate dump step.
                dump_duration = 0.0
                active_backup = None
                use_multi_worker = num_workers > 1 and backup_dir is not None
                if (
                    needs_quantization
                    and keys_to_process
                    and backup_path
                    and not use_multi_worker
                ):
                    # Single-worker dump before drop
                    effective_changes = datatype_changes
                    if has_field_renames:
                        field_rename_map = {
                            fr.old_name: fr.new_name for fr in rename_ops.rename_fields
                        }
                        effective_changes = {
                            field_rename_map.get(k, k): v
                            for k, v in datatype_changes.items()
                        }
                    _notify("dump", "Backing up original vectors...")
                    dump_started = time.perf_counter()
                    active_backup = self._dump_vectors(
                        client=client,
                        index_name=plan.source.index_name,
                        keys=keys_to_process,
                        datatype_changes=effective_changes,
                        backup_path=backup_path,
                        batch_size=batch_size,
                        progress_callback=lambda done, total: _notify(
                            "dump", f"{done:,}/{total:,} docs"
                        ),
                    )
                    dump_duration = round(time.perf_counter() - dump_started, 3)
                    _notify("dump", f"done ({dump_duration}s)")

                # STEP 4: Drop the index (drop=False keeps the documents)
                _notify("drop", "Dropping index definition...")
                drop_started = time.perf_counter()
                source_index.delete(drop=False)
                drop_duration = round(time.perf_counter() - drop_started, 3)
                _notify("drop", f"done ({drop_duration}s)")

                # STEP 5: Key renames (after drop, before recreate)
                if has_prefix_change and keys_to_process:
                    _notify("key_rename", "Renaming keys...")
                    key_rename_started = time.perf_counter()
                    old_prefix = plan.source.keyspace.prefixes[0]
                    new_prefix = rename_ops.change_prefix
                    assert new_prefix is not None
                    renamed_count = self._rename_keys(
                        client,
                        keys_to_process,
                        old_prefix,
                        new_prefix,
                        progress_callback=lambda done, total: _notify(
                            "key_rename", f"{done:,}/{total:,} keys"
                        ),
                    )
                    key_rename_duration = round(
                        time.perf_counter() - key_rename_started, 3
                    )
                    _notify(
                        "key_rename",
                        f"done ({renamed_count:,} keys in {key_rename_duration}s)",
                    )

                # STEP 6: Quantize vectors
                if needs_quantization and keys_to_process:
                    effective_changes = datatype_changes
                    if has_field_renames:
                        field_rename_map = {
                            fr.old_name: fr.new_name for fr in rename_ops.rename_fields
                        }
                        effective_changes = {
                            field_rename_map.get(k, k): v
                            for k, v in datatype_changes.items()
                        }

                    # Update key references if prefix changed
                    if has_prefix_change and rename_ops.change_prefix:
                        old_prefix = plan.source.keyspace.prefixes[0]
                        new_prefix = rename_ops.change_prefix
                        keys_to_process = [
                            (
                                new_prefix + k[len(old_prefix) :]
                                if k.startswith(old_prefix)
                                else k
                            )
                            for k in keys_to_process
                        ]

                    if use_multi_worker:
                        # Multi-worker path: dump + quantize in parallel
                        from redisvl.migration.quantize import multi_worker_quantize

                        if backup_dir is None:
                            raise ValueError(
                                "--backup-dir is required when using --workers > 1"
                            )
                        if redis_url is None:
                            raise ValueError(
                                "redis_url is required when using num_workers > 1"
                            )
                        _notify(
                            "quantize",
                            f"Re-encoding vectors ({num_workers} workers)...",
                        )
                        quantize_started = time.perf_counter()
                        mw_result = multi_worker_quantize(
                            redis_url=redis_url,
                            keys=keys_to_process,
                            datatype_changes=effective_changes,
                            backup_dir=backup_dir,
                            index_name=plan.source.index_name,
                            num_workers=num_workers,
                            batch_size=batch_size,
                        )
                        docs_quantized = mw_result.total_docs_quantized
                    elif active_backup:
                        # Single-worker backup path
                        _notify("quantize", "Re-encoding vectors from backup...")
                        quantize_started = time.perf_counter()
                        docs_quantized = self._quantize_from_backup(
                            client=client,
                            backup=active_backup,
                            datatype_changes=effective_changes,
                            progress_callback=lambda done, total: _notify(
                                "quantize", f"{done:,}/{total:,} docs"
                            ),
                        )
                    else:
                        # No backup dir — direct pipeline read + write
                        from redisvl.migration.quantize import (
                            convert_vectors,
                            pipeline_read_vectors,
                            pipeline_write_vectors,
                        )

                        _notify("quantize", "Re-encoding vectors...")
                        quantize_started = time.perf_counter()
                        docs_quantized = 0
                        total = len(keys_to_process)
                        for batch_start in range(0, total, batch_size):
                            batch_keys = keys_to_process[
                                batch_start : batch_start + batch_size
                            ]
                            originals = pipeline_read_vectors(
                                client, batch_keys, effective_changes
                            )
                            converted = convert_vectors(originals, effective_changes)
                            if converted:
                                pipeline_write_vectors(client, converted)
                                # NOTE(review): the `if converted else 0`
                                # ternary is redundant inside this branch.
                                docs_quantized += len(converted) if converted else 0
                            # Guard avoids building the f-string when no
                            # callback is registered (_notify re-checks it).
                            if progress_callback:
                                _notify(
                                    "quantize",
                                    f"{docs_quantized:,}/{total:,} docs",
                                )
                    quantize_duration = round(time.perf_counter() - quantize_started, 3)
                    _notify(
                        "quantize",
                        f"done ({docs_quantized:,} docs in {quantize_duration}s)",
                    )
                    report.warnings.append(
                        f"Re-encoded {docs_quantized} documents for vector quantization: "
                        f"{datatype_changes}"
                    )
                elif datatype_changes and storage_type == "json":
                    _notify(
                        "quantize", "skipped (JSON vectors are re-indexed on recreate)"
                    )

            _notify("create", "Creating index with new schema...")
            recreate_started = time.perf_counter()
            target_index.create()
            recreate_duration = round(time.perf_counter() - recreate_started, 3)
            _notify("create", f"done ({recreate_duration}s)")

            _notify("index", "Waiting for re-indexing...")

            def _index_progress(indexed: int, total: int, pct: float) -> None:
                # Adapter: wait_for_index_ready reports (indexed, total, pct).
                _notify("index", f"{indexed:,}/{total:,} docs ({pct:.0f}%)")

            target_info, indexing_duration = wait_for_index_ready(
                target_index, progress_callback=_index_progress
            )
            _notify("index", f"done ({indexing_duration}s)")

            _notify("validate", "Validating migration...")
            validation, target_info, validation_duration = self.validator.validate(
                plan,
                redis_url=redis_url,
                redis_client=redis_client,
                query_check_file=query_check_file,
            )
            _notify("validate", f"done ({validation_duration}s)")
            report.validation = validation
            total_duration = round(time.perf_counter() - started, 3)
            report.timings = MigrationTimings(
                total_migration_duration_seconds=total_duration,
                drop_duration_seconds=drop_duration,
                quantize_duration_seconds=(
                    quantize_duration if quantize_duration else None
                ),
                field_rename_duration_seconds=(
                    field_rename_duration if field_rename_duration else None
                ),
                key_rename_duration_seconds=(
                    key_rename_duration if key_rename_duration else None
                ),
                recreate_duration_seconds=recreate_duration,
                initial_indexing_duration_seconds=indexing_duration,
                validation_duration_seconds=validation_duration,
                # Downtime spans the full drop → re-index window.
                downtime_duration_seconds=round(
                    drop_duration
                    + field_rename_duration
                    + key_rename_duration
                    + quantize_duration
                    + recreate_duration
                    + indexing_duration,
                    3,
                ),
            )
            report.benchmark_summary = self._build_benchmark_summary(
                plan,
                target_info,
                report.timings,
            )
            report.result = "succeeded" if not validation.errors else "failed"
            if validation.errors:
                report.manual_actions.append(
                    "Review validation errors before treating the migration as complete."
                )
        except Exception as exc:
            # Failure path: capture whatever timings were accumulated so the
            # report still shows how far the migration got.
            total_duration = round(time.perf_counter() - started, 3)
            report.timings = MigrationTimings(
                total_migration_duration_seconds=total_duration,
                drop_duration_seconds=drop_duration or None,
                quantize_duration_seconds=quantize_duration or None,
                field_rename_duration_seconds=field_rename_duration or None,
                key_rename_duration_seconds=key_rename_duration or None,
                recreate_duration_seconds=recreate_duration or None,
                initial_indexing_duration_seconds=indexing_duration or None,
                downtime_duration_seconds=(
                    round(
                        drop_duration
                        + field_rename_duration
                        + key_rename_duration
                        + quantize_duration
                        + recreate_duration
                        + indexing_duration,
                        3,
                    )
                    if drop_duration
                    or field_rename_duration
                    or key_rename_duration
                    or quantize_duration
                    or recreate_duration
                    or indexing_duration
                    else None
                ),
            )
            report.validation = MigrationValidation(
                errors=[f"Migration execution failed: {exc}"]
            )
            report.manual_actions.extend(
                [
                    "Inspect the Redis index state before retrying.",
                    "If the source index was dropped, recreate it from the saved migration plan.",
                ]
            )
        finally:
            report.finished_at = timestamp_utc()

            # Auto-cleanup backup files on success
            if backup_dir and not keep_backup and report.result == "succeeded":
                self._cleanup_backup_files(backup_dir, plan.source.index_name)

        return report

    def _cleanup_backup_files(self, backup_dir: str, index_name: str) -> None:
        """Remove backup files after successful migration.

        Only removes files with the exact extensions produced by VectorBackup
        (.header and .data), avoiding accidental deletion of unrelated files
        that happen to share the same prefix.
        """
        safe_name = index_name.replace("/", "_").replace("\\", "_").replace(":", "_")
        name_hash = hashlib.sha256(index_name.encode()).hexdigest()[:8]
        base_prefix = f"migration_backup_{safe_name}_{name_hash}"
        # Exact suffixes written by VectorBackup
        known_suffixes = (".header", ".data")
        backup_dir_path = Path(backup_dir)

        # NOTE(review): iterdir() raises FileNotFoundError if backup_dir was
        # removed between backup and cleanup — confirm callers guarantee the
        # directory still exists (apply() only calls this when backup_dir was
        # used for this run).
        for entry in backup_dir_path.iterdir():
            if not entry.is_file():
                continue
            name = entry.name
            # Match: base_prefix exactly, or base_prefix + shard suffix
            # e.g., migration_backup_myidx.header
            #       migration_backup_myidx_shard_0.header
            if not name.startswith(base_prefix):
                continue
            # Check that the file ends with a known extension
            if not any(name.endswith(s) for s in known_suffixes):
                continue
            # Verify the character after the prefix is either a dot or underscore
            # (prevents matching migration_backup_myidx2.header)
            remainder = name[len(base_prefix) :]
            if remainder and remainder[0] not in (".", "_"):
                continue
            try:
                entry.unlink()
                logger.debug("Removed backup file: %s", entry)
            except OSError as e:
                # Best-effort cleanup; a leftover file is harmless.
                logger.warning("Failed to remove backup file %s: %s", entry, e)

    # ------------------------------------------------------------------
    # Two-phase quantization: dump originals → convert from backup
    # ------------------------------------------------------------------

    def _dump_vectors(
        self,
        client: Any,
        index_name: str,
        keys: List[str],
        datatype_changes: Dict[str, Dict[str, Any]],
        backup_path: str,
        batch_size: int = 500,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> "VectorBackup":
        """Phase 1: Pipeline-read original vectors and write to backup file.

        Runs BEFORE index drop — the index is still alive.
        No Redis state is modified.

        Args:
            client: Redis client
            index_name: Name of the source index
            keys: Pre-enumerated list of document keys
            datatype_changes: {field_name: {"source", "target", "dims"}}
            backup_path: Path prefix for backup files
            batch_size: Keys per pipeline batch
            progress_callback: Optional callback(docs_done, total_docs)

        Returns:
            VectorBackup in "ready" phase (dump complete)
        """
        from redisvl.migration.backup import VectorBackup
        from redisvl.migration.quantize import pipeline_read_vectors

        backup = VectorBackup.create(
            path=backup_path,
            index_name=index_name,
            fields=datatype_changes,
            batch_size=batch_size,
        )

        total = len(keys)
        for batch_idx in range(0, total, batch_size):
            batch_keys = keys[batch_idx : batch_idx + batch_size]
            originals = pipeline_read_vectors(client, batch_keys, datatype_changes)
            # Batch number (batch_idx // batch_size) keys the backup record
            # so quantize can resume at batch granularity.
            backup.write_batch(batch_idx // batch_size, batch_keys, originals)
            if progress_callback:
                progress_callback(min(batch_idx + batch_size, total), total)

        backup.mark_dump_complete()
        return backup

    def _quantize_from_backup(
        self,
        client: Any,
        backup: "VectorBackup",
        datatype_changes: Dict[str, Dict[str, Any]],
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> int:
        """Phase 2: Read originals from backup file, convert, pipeline-write.

        Runs AFTER index drop. Reads from local disk, not Redis.
        Tracks progress via backup header for crash-safe resume.

        Args:
            client: Redis client
            backup: VectorBackup in "ready" or "active" phase
            datatype_changes: {field_name: {"source", "target", "dims"}}
            progress_callback: Optional callback(docs_done, total_docs)

        Returns:
            Number of documents quantized
        """
        from redisvl.migration.quantize import convert_vectors, pipeline_write_vectors

        # Transition "ready" → "active" exactly once; a resumed run arrives
        # here already in "active" phase.
        if backup.header.phase == "ready":
            backup.start_quantize()

        docs_quantized = 0
        start_batch = backup.header.quantize_completed_batches
        docs_done = start_batch * backup.header.batch_size

        for batch_idx, (batch_keys, originals) in enumerate(
            backup.iter_remaining_batches()
        ):
            actual_batch_idx = start_batch + batch_idx
            converted = convert_vectors(originals, datatype_changes)
            if converted:
                pipeline_write_vectors(client, converted)
            # Mark the batch done only after the write completed (crash-safe).
            backup.mark_batch_quantized(actual_batch_idx)
            docs_quantized += len(batch_keys)
            docs_done += len(batch_keys)
            if progress_callback:
                # NOTE(review): this total overestimates when the final dump
                # batch is partial (batches × batch_size rounds up) — confirm
                # whether an exact doc count is available in the header.
                total = backup.header.dump_completed_batches * backup.header.batch_size
                progress_callback(docs_done, total)

        backup.mark_complete()
        return docs_quantized

    def _build_benchmark_summary(
        self,
        plan: MigrationPlan,
        target_info: dict,
        timings: MigrationTimings,
    ) -> MigrationBenchmarkSummary:
        """Build before/after size and throughput numbers for the report.

        Sizes come from FT.INFO stats (``vector_index_sz_mb``); throughput is
        docs divided by the initial-indexing duration, or None when the
        duration is missing/zero.
        """
        # `or 0` guards against None/empty stats values before float()/int().
        source_index_size = float(
            plan.source.stats_snapshot.get("vector_index_sz_mb", 0) or 0
        )
        target_index_size = float(target_info.get("vector_index_sz_mb", 0) or 0)
        source_num_docs = int(plan.source.stats_snapshot.get("num_docs", 0) or 0)
        indexed_per_second = None
        indexing_time = timings.initial_indexing_duration_seconds
        if indexing_time and indexing_time > 0:
            indexed_per_second = round(source_num_docs / indexing_time, 3)

        return MigrationBenchmarkSummary(
            documents_indexed_per_second=indexed_per_second,
            source_index_size_mb=round(source_index_size, 3),
            target_index_size_mb=round(target_index_size, 3),
            index_size_delta_mb=round(target_index_size - source_index_size, 3),
        )
diff --git
a/redisvl/migration/models.py b/redisvl/migration/models.py new file mode 100644 index 000000000..8cffd2a4f --- /dev/null +++ b/redisvl/migration/models.py @@ -0,0 +1,380 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field, model_validator + + +class FieldUpdate(BaseModel): + """Partial field update for schema patch inputs.""" + + name: str + type: Optional[str] = None + path: Optional[str] = None + attrs: Dict[str, Any] = Field(default_factory=dict) + options: Dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="after") + def merge_options_into_attrs(self) -> "FieldUpdate": + if self.options: + merged_attrs = dict(self.attrs) + merged_attrs.update(self.options) + self.attrs = merged_attrs + self.options = {} + return self + + +class FieldRename(BaseModel): + """Field rename specification for schema patch inputs.""" + + old_name: str + new_name: str + + +class SchemaPatchChanges(BaseModel): + add_fields: List[Dict[str, Any]] = Field(default_factory=list) + remove_fields: List[str] = Field(default_factory=list) + update_fields: List[FieldUpdate] = Field(default_factory=list) + rename_fields: List[FieldRename] = Field(default_factory=list) + index: Dict[str, Any] = Field(default_factory=dict) + + +class SchemaPatch(BaseModel): + version: int = 1 + changes: SchemaPatchChanges = Field(default_factory=SchemaPatchChanges) + + +class KeyspaceSnapshot(BaseModel): + storage_type: str + prefixes: List[str] + key_separator: str + key_sample: List[str] = Field(default_factory=list) + + +class SourceSnapshot(BaseModel): + index_name: str + schema_snapshot: Dict[str, Any] + stats_snapshot: Dict[str, Any] + keyspace: KeyspaceSnapshot + + +class DiffClassification(BaseModel): + supported: bool + blocked_reasons: List[str] = Field(default_factory=list) + + +class ValidationPolicy(BaseModel): + require_doc_count_match: bool = True + require_schema_match: bool = True + + +class 
RenameOperations(BaseModel): + """Tracks which rename operations are required for a migration.""" + + rename_index: Optional[str] = None # New index name if renaming + change_prefix: Optional[str] = None # New prefix if changing + rename_fields: List[FieldRename] = Field(default_factory=list) + + @property + def has_operations(self) -> bool: + return bool( + self.rename_index is not None + or self.change_prefix is not None + or self.rename_fields + ) + + +class MigrationPlan(BaseModel): + version: int = 1 + mode: str = "drop_recreate" + source: SourceSnapshot + requested_changes: Dict[str, Any] + merged_target_schema: Dict[str, Any] + diff_classification: DiffClassification + rename_operations: RenameOperations = Field(default_factory=RenameOperations) + warnings: List[str] = Field(default_factory=list) + validation: ValidationPolicy = Field(default_factory=ValidationPolicy) + + +class QueryCheckResult(BaseModel): + name: str + passed: bool + details: Optional[str] = None + + +class MigrationValidation(BaseModel): + schema_match: bool = False + doc_count_match: bool = False + key_sample_exists: bool = False + indexing_failures_delta: int = 0 + query_checks: List[QueryCheckResult] = Field(default_factory=list) + errors: List[str] = Field(default_factory=list) + + +class MigrationTimings(BaseModel): + total_migration_duration_seconds: Optional[float] = None + drop_duration_seconds: Optional[float] = None + quantize_duration_seconds: Optional[float] = None + field_rename_duration_seconds: Optional[float] = None + key_rename_duration_seconds: Optional[float] = None + recreate_duration_seconds: Optional[float] = None + initial_indexing_duration_seconds: Optional[float] = None + validation_duration_seconds: Optional[float] = None + downtime_duration_seconds: Optional[float] = None + + +class MigrationBenchmarkSummary(BaseModel): + documents_indexed_per_second: Optional[float] = None + source_index_size_mb: Optional[float] = None + target_index_size_mb: Optional[float] = 
None + index_size_delta_mb: Optional[float] = None + + +class MigrationReport(BaseModel): + version: int = 1 + mode: str = "drop_recreate" + source_index: str + target_index: str + result: str + started_at: str + finished_at: str + timings: MigrationTimings = Field(default_factory=MigrationTimings) + validation: MigrationValidation = Field(default_factory=MigrationValidation) + benchmark_summary: MigrationBenchmarkSummary = Field( + default_factory=MigrationBenchmarkSummary + ) + disk_space_estimate: Optional["DiskSpaceEstimate"] = None + warnings: List[str] = Field(default_factory=list) + manual_actions: List[str] = Field(default_factory=list) + + +# ----------------------------------------------------------------------------- +# Disk Space Estimation +# ----------------------------------------------------------------------------- + +# Bytes per element for each vector datatype +DTYPE_BYTES: Dict[str, int] = { + "float64": 8, + "float32": 4, + "float16": 2, + "bfloat16": 2, + "int8": 1, + "uint8": 1, +} + +# AOF protocol overhead per HSET command (RESP framing) +AOF_HSET_OVERHEAD_BYTES = 114 +# JSON.SET has slightly larger RESP framing +AOF_JSON_SET_OVERHEAD_BYTES = 140 +# RDB compression ratio for pseudo-random vector data (compresses poorly) +RDB_COMPRESSION_RATIO = 0.95 + + +class VectorFieldEstimate(BaseModel): + """Per-field disk space breakdown for a single vector field.""" + + field_name: str + dims: int + source_dtype: str + target_dtype: str + source_bytes_per_doc: int + target_bytes_per_doc: int + + +class DiskSpaceEstimate(BaseModel): + """Pre-migration estimate of disk and memory costs. + + Produced by estimate_disk_space() as a pure calculation from the migration + plan. No Redis mutations are performed. 
+ """ + + # Index metadata + index_name: str + doc_count: int + storage_type: str = "hash" + + # Per-field breakdowns + vector_fields: List[VectorFieldEstimate] = Field(default_factory=list) + + # Aggregate vector data sizes + total_source_vector_bytes: int = 0 + total_target_vector_bytes: int = 0 + + # RDB snapshot cost (BGSAVE before migration) + rdb_snapshot_disk_bytes: int = 0 + rdb_cow_memory_if_concurrent_bytes: int = 0 + + # AOF growth cost (only if aof_enabled is True) + aof_enabled: bool = False + aof_growth_bytes: int = 0 + + # Totals + total_new_disk_bytes: int = 0 + memory_savings_after_bytes: int = 0 + + @property + def has_quantization(self) -> bool: + return len(self.vector_fields) > 0 + + def summary(self) -> str: + """Human-readable summary for CLI output.""" + if not self.has_quantization: + return "No vector quantization in this migration. No additional disk space required." + + lines = [ + "Pre-migration disk space estimate:", + f" Index: {self.index_name} ({self.doc_count:,} documents)", + ] + for vf in self.vector_fields: + lines.append( + f" Vector field '{vf.field_name}': {vf.dims} dims, " + f"{vf.source_dtype} -> {vf.target_dtype}" + ) + + lines.append("") + lines.append( + f" RDB snapshot (BGSAVE): ~{_format_bytes(self.rdb_snapshot_disk_bytes)}" + ) + if self.aof_enabled: + lines.append( + f" AOF growth (appendonly=yes): ~{_format_bytes(self.aof_growth_bytes)}" + ) + else: + lines.append( + " AOF growth: not estimated (pass aof_enabled=True if AOF is on)" + ) + lines.append( + f" Total new disk required: ~{_format_bytes(self.total_new_disk_bytes)}" + ) + lines.append("") + lines.append( + f" Post-migration memory delta: ~{_format_bytes(abs(self.memory_savings_after_bytes))} " + f"({'reduction' if self.memory_savings_after_bytes >= 0 else 'increase'}, " + f"{abs(self._savings_pct())}%)" + ) + return "\n".join(lines) + + def _savings_pct(self) -> int: + if self.total_source_vector_bytes == 0: + return 0 + return round( + 100 * 
self.memory_savings_after_bytes / self.total_source_vector_bytes + ) + + +def _format_bytes(n: int) -> str: + """Format byte count as human-readable string.""" + if n >= 1_073_741_824: + return f"{n / 1_073_741_824:.2f} GB" + if n >= 1_048_576: + return f"{n / 1_048_576:.1f} MB" + if n >= 1024: + return f"{n / 1024:.1f} KB" + return f"{n} bytes" + + +# ----------------------------------------------------------------------------- +# Batch Migration Models +# ----------------------------------------------------------------------------- + + +class BatchIndexEntry(BaseModel): + """Entry for a single index in a batch migration plan.""" + + name: str + applicable: bool = True + skip_reason: Optional[str] = None + + +class BatchPlan(BaseModel): + """Plan for migrating multiple indexes with a shared patch.""" + + version: int = 1 + batch_id: str + mode: str = "drop_recreate" + failure_policy: str = "fail_fast" # or "continue_on_error" + requires_quantization: bool = False + shared_patch: SchemaPatch + indexes: List[BatchIndexEntry] = Field(default_factory=list) + created_at: str + + @property + def applicable_count(self) -> int: + return sum(1 for idx in self.indexes if idx.applicable) + + @property + def skipped_count(self) -> int: + return sum(1 for idx in self.indexes if not idx.applicable) + + +class BatchIndexState(BaseModel): + """State of a single index in batch execution.""" + + name: str + status: str # pending, in_progress, success, failed, skipped + started_at: Optional[str] = None + completed_at: Optional[str] = None + failed_at: Optional[str] = None + error: Optional[str] = None + report_path: Optional[str] = None + + +class BatchState(BaseModel): + """Checkpoint state for batch migration execution.""" + + batch_id: str + plan_path: str + started_at: str + updated_at: str + completed: List[BatchIndexState] = Field(default_factory=list) + current_index: Optional[str] = None + remaining: List[str] = Field(default_factory=list) + + @property + def 
success_count(self) -> int: + return sum(1 for idx in self.completed if idx.status == "success") + + @property + def failed_count(self) -> int: + return sum(1 for idx in self.completed if idx.status == "failed") + + @property + def skipped_count(self) -> int: + return sum(1 for idx in self.completed if idx.status == "skipped") + + @property + def is_complete(self) -> bool: + return len(self.remaining) == 0 and self.current_index is None + + +class BatchReportSummary(BaseModel): + """Summary statistics for batch migration.""" + + total_indexes: int = 0 + successful: int = 0 + failed: int = 0 + skipped: int = 0 + total_duration_seconds: float = 0.0 + + +class BatchIndexReport(BaseModel): + """Report for a single index in batch execution.""" + + name: str + status: str # success, failed, skipped + duration_seconds: Optional[float] = None + docs_migrated: Optional[int] = None + report_path: Optional[str] = None + error: Optional[str] = None + + +class BatchReport(BaseModel): + """Final report for batch migration execution.""" + + version: int = 1 + batch_id: str + status: str # completed, partial_failure, failed + summary: BatchReportSummary = Field(default_factory=BatchReportSummary) + indexes: List[BatchIndexReport] = Field(default_factory=list) + started_at: str + completed_at: str diff --git a/redisvl/migration/planner.py b/redisvl/migration/planner.py new file mode 100644 index 000000000..e5d40dc10 --- /dev/null +++ b/redisvl/migration/planner.py @@ -0,0 +1,786 @@ +from __future__ import annotations + +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import yaml + +from redisvl.index import SearchIndex +from redisvl.migration.models import ( + DiffClassification, + FieldRename, + KeyspaceSnapshot, + MigrationPlan, + RenameOperations, + SchemaPatch, + SourceSnapshot, +) +from redisvl.redis.connection import supports_svs +from redisvl.schema.schema import IndexSchema + + +class MigrationPlanner: + 
"""Migration planner for drop/recreate-based index migrations. + + The `drop_recreate` mode drops the index definition and recreates it with + a new schema. By default, documents are preserved in Redis. When possible, + the planner/executor can apply transformations so the preserved documents + remain compatible with the new index schema. + + This means: + - Index-only changes are always safe (algorithm, distance metric, tuning + params, quantization, etc.) + - Some document-dependent changes are supported via explicit migration + operations in the migration plan + + Supported document-dependent changes: + - Prefix/keyspace changes: keys are renamed via RENAME command + - Field renames: documents are updated to use new field names + - Index renaming: the new index is created with a different name + + Document-dependent changes that remain unsupported: + - Vector dimensions: stored vectors have wrong number of dimensions + - Storage type: documents are in hash format but index expects JSON + """ + + def __init__(self, key_sample_limit: int = 10): + self.key_sample_limit = key_sample_limit + + def create_plan( + self, + index_name: str, + *, + redis_url: Optional[str] = None, + schema_patch_path: Optional[str] = None, + target_schema_path: Optional[str] = None, + redis_client: Optional[Any] = None, + ) -> MigrationPlan: + """Generate a migration plan by comparing the live index to a desired schema. + + Snapshots the current index metadata from Redis, loads the requested + changes from either a *schema_patch_path* or *target_schema_path*, and + produces a :class:`MigrationPlan` that describes every step required to + reach the target schema. + + No data is modified — this is a read-only planning step. The resulting + plan should be reviewed before passing to + :meth:`MigrationExecutor.apply`. + + Args: + index_name: Name of the existing Redis Search index. + redis_url: Redis connection URL + (e.g. ``"redis://localhost:6379"``). 
+ schema_patch_path: Path to a YAML schema-patch file describing + incremental changes (add/remove/update fields, change + algorithm, rename fields, etc.). + target_schema_path: Path to a full target-schema YAML file. + The planner diffs the live schema against this target. + redis_client: Optional pre-existing Redis client instance. + + Returns: + MigrationPlan: An immutable plan object containing the source + snapshot, diff classification, target schema, and any warnings. + + Raises: + ValueError: If neither or both of *schema_patch_path* and + *target_schema_path* are provided. + """ + if not schema_patch_path and not target_schema_path: + raise ValueError( + "Must provide either --schema-patch or --target-schema for migration planning" + ) + if schema_patch_path and target_schema_path: + raise ValueError( + "Provide only one of --schema-patch or --target-schema for migration planning" + ) + + snapshot = self.snapshot_source( + index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + source_schema = IndexSchema.from_dict(snapshot.schema_snapshot) + + if schema_patch_path: + schema_patch = self.load_schema_patch(schema_patch_path) + else: + # target_schema_path is guaranteed non-None here due to validation above + assert target_schema_path is not None + schema_patch = self.normalize_target_schema_to_patch( + source_schema, target_schema_path + ) + + return self.create_plan_from_patch( + index_name, + schema_patch=schema_patch, + redis_url=redis_url, + redis_client=redis_client, + _snapshot=snapshot, + ) + + def create_plan_from_patch( + self, + index_name: str, + *, + schema_patch: SchemaPatch, + redis_url: Optional[str] = None, + redis_client: Optional[Any] = None, + _snapshot: Optional[Any] = None, + ) -> MigrationPlan: + if _snapshot is None: + _snapshot = self.snapshot_source( + index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + snapshot = _snapshot + source_schema = IndexSchema.from_dict(snapshot.schema_snapshot) + 
merged_target_schema = self.merge_patch(source_schema, schema_patch) + + # Extract rename operations first + rename_operations, rename_warnings = self._extract_rename_operations( + source_schema, schema_patch + ) + + # Classify diff with awareness of rename operations + diff_classification = self.classify_diff( + source_schema, schema_patch, merged_target_schema, rename_operations + ) + + # Build warnings list + warnings = ["Index downtime is required"] + warnings.extend(rename_warnings) + + # Warn if source index has hash indexing failures + source_failures = int( + snapshot.stats_snapshot.get("hash_indexing_failures", 0) or 0 + ) + if source_failures > 0: + warnings.append( + f"Source index has {source_failures:,} hash indexing failure(s). " + "Documents that previously failed to index may become indexable after " + "migration, causing the post-migration document count to differ from " + "the pre-migration count. This is expected and validation accounts for it." + ) + + # Check for SVS-VAMANA in target schema and add appropriate warnings + svs_warnings = self._check_svs_vamana_requirements( + merged_target_schema, + redis_url=redis_url, + redis_client=redis_client, + ) + warnings.extend(svs_warnings) + + return MigrationPlan( + source=snapshot, + requested_changes=schema_patch.model_dump(exclude_none=True), + merged_target_schema=merged_target_schema.to_dict(), + diff_classification=diff_classification, + rename_operations=rename_operations, + warnings=warnings, + ) + + def snapshot_source( + self, + index_name: str, + *, + redis_url: Optional[str] = None, + redis_client: Optional[Any] = None, + ) -> SourceSnapshot: + index = SearchIndex.from_existing( + index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + schema_dict = index.schema.to_dict() + stats_snapshot = index.info() + prefixes = index.schema.index.prefix + prefix_list = prefixes if isinstance(prefixes, list) else [prefixes] + + return SourceSnapshot( + index_name=index_name, + 
schema_snapshot=schema_dict, + stats_snapshot=stats_snapshot, + keyspace=KeyspaceSnapshot( + storage_type=index.schema.index.storage_type.value, + prefixes=prefix_list, + key_separator=index.schema.index.key_separator, + key_sample=self._sample_keys( + client=index.client, + prefixes=prefix_list, + key_separator=index.schema.index.key_separator, + ), + ), + ) + + def load_schema_patch(self, schema_patch_path: str) -> SchemaPatch: + patch_path = Path(schema_patch_path).resolve() + if not patch_path.exists(): + raise FileNotFoundError( + f"Schema patch file {schema_patch_path} does not exist" + ) + + with open(patch_path, "r") as f: + patch_data = yaml.safe_load(f) or {} + return SchemaPatch.model_validate(patch_data) + + def normalize_target_schema_to_patch( + self, source_schema: IndexSchema, target_schema_path: str + ) -> SchemaPatch: + target_schema = IndexSchema.from_yaml(target_schema_path) + source_dict = source_schema.to_dict() + target_dict = target_schema.to_dict() + + changes: Dict[str, Any] = { + "add_fields": [], + "remove_fields": [], + "update_fields": [], + "index": {}, + } + + source_fields = {field["name"]: field for field in source_dict["fields"]} + target_fields = {field["name"]: field for field in target_dict["fields"]} + + for field_name, target_field in target_fields.items(): + if field_name not in source_fields: + changes["add_fields"].append(target_field) + elif source_fields[field_name] != target_field: + changes["update_fields"].append(target_field) + + for field_name in source_fields: + if field_name not in target_fields: + changes["remove_fields"].append(field_name) + + for index_key, target_value in target_dict["index"].items(): + source_value = source_dict["index"].get(index_key) + # Normalize single-element list prefixes for comparison so that + # e.g. source ["docs"] and target "docs" are treated as equal. 
+ sv, tv = source_value, target_value + if index_key == "prefix": + if isinstance(sv, list) and len(sv) == 1: + sv = sv[0] + if isinstance(tv, list) and len(tv) == 1: + tv = tv[0] + if sv != tv: + changes["index"][index_key] = target_value + + return SchemaPatch.model_validate({"version": 1, "changes": changes}) + + def merge_patch( + self, source_schema: IndexSchema, schema_patch: SchemaPatch + ) -> IndexSchema: + schema_dict = deepcopy(source_schema.to_dict()) + changes = schema_patch.changes + fields_by_name = { + field["name"]: deepcopy(field) for field in schema_dict["fields"] + } + + # Apply field renames first (before other modifications) + # This ensures the merged schema's field names match the executor's renamed fields + for rename in changes.rename_fields: + if rename.old_name not in fields_by_name: + raise ValueError( + f"Cannot rename field '{rename.old_name}' because it does not exist in the source schema" + ) + if rename.new_name in fields_by_name and rename.new_name != rename.old_name: + raise ValueError( + f"Cannot rename field '{rename.old_name}' to '{rename.new_name}' because a field with the new name already exists" + ) + if rename.new_name == rename.old_name: + continue # No-op rename + field_def = fields_by_name.pop(rename.old_name) + field_def["name"] = rename.new_name + fields_by_name[rename.new_name] = field_def + + for field_name in changes.remove_fields: + fields_by_name.pop(field_name, None) + + # Build a mapping from old field names to new names so that + # update_fields entries referencing pre-rename names still resolve. + rename_map = { + rename.old_name: rename.new_name + for rename in changes.rename_fields + if rename.old_name != rename.new_name + } + + for field_update in changes.update_fields: + # Resolve through renames: if the update references the old name, + # look up the field under its new name. 
+ resolved_name = rename_map.get(field_update.name, field_update.name) + if resolved_name not in fields_by_name: + raise ValueError( + f"Cannot update field '{field_update.name}' because it does not exist in the source schema" + ) + existing_field = fields_by_name[resolved_name] + if field_update.type is not None: + existing_field["type"] = field_update.type + if field_update.path is not None: + existing_field["path"] = field_update.path + if field_update.attrs: + merged_attrs = dict(existing_field.get("attrs", {})) + merged_attrs.update(field_update.attrs) + existing_field["attrs"] = merged_attrs + + for field in changes.add_fields: + field_name = field["name"] + if field_name in fields_by_name: + raise ValueError( + f"Cannot add field '{field_name}' because it already exists in the source schema" + ) + fields_by_name[field_name] = deepcopy(field) + + schema_dict["fields"] = list(fields_by_name.values()) + schema_dict["index"].update(changes.index) + return IndexSchema.from_dict(schema_dict) + + def _extract_rename_operations( + self, + source_schema: IndexSchema, + schema_patch: SchemaPatch, + ) -> Tuple[RenameOperations, List[str]]: + """Extract rename operations from the patch and generate warnings. 
+ + Returns: + Tuple of (RenameOperations, warnings list) + """ + source_dict = source_schema.to_dict() + changes = schema_patch.changes + warnings: List[str] = [] + + # Index rename + rename_index: Optional[str] = None + if "name" in changes.index: + new_name = changes.index["name"] + old_name = source_dict["index"].get("name") + if new_name != old_name: + rename_index = new_name + warnings.append( + f"Index rename: '{old_name}' -> '{new_name}' (index-only change, no document migration needed)" + ) + + # Prefix change + change_prefix: Optional[str] = None + if "prefix" in changes.index: + new_prefix = changes.index["prefix"] + # Normalize list-type prefix to a single string (local copy only) + if isinstance(new_prefix, list): + if len(new_prefix) != 1: + raise ValueError( + f"Target prefix must be a single string, got list: {new_prefix}. " + f"Multi-prefix migrations are not supported." + ) + new_prefix = new_prefix[0] + old_prefix = source_dict["index"].get("prefix") + # Normalize single-element list to string for comparison + if isinstance(old_prefix, list) and len(old_prefix) == 1: + old_prefix = old_prefix[0] + if new_prefix != old_prefix: + # Block multi-prefix migrations - we only support single prefix + if isinstance(old_prefix, list) and len(old_prefix) > 1: + raise ValueError( + f"Cannot change prefix for multi-prefix indexes. " + f"Source index has multiple prefixes: {old_prefix}. " + f"Multi-prefix migrations are not supported." 
+ ) + change_prefix = new_prefix + warnings.append( + f"Prefix change: '{old_prefix}' -> '{new_prefix}' " + "(requires RENAME for all keys, may be slow for large datasets)" + ) + + # Field renames from explicit rename_fields + rename_fields: List[FieldRename] = list(changes.rename_fields) + for field_rename in rename_fields: + warnings.append( + f"Field rename: '{field_rename.old_name}' -> '{field_rename.new_name}' " + "(requires read/write for all documents, may be slow for large datasets)" + ) + + return ( + RenameOperations( + rename_index=rename_index, + change_prefix=change_prefix, + rename_fields=rename_fields, + ), + warnings, + ) + + def _check_svs_vamana_requirements( + self, + target_schema: IndexSchema, + *, + redis_url: Optional[str] = None, + redis_client: Optional[Any] = None, + ) -> List[str]: + """Check SVS-VAMANA requirements and return warnings. + + Checks: + 1. If target uses SVS-VAMANA, verify Redis version supports it + 2. Add Intel hardware warning for LVQ/LeanVec optimizations + """ + warnings: List[str] = [] + target_dict = target_schema.to_dict() + + # Check if any vector field uses SVS-VAMANA + uses_svs = False + uses_compression = False + compression_types: set = set() + + for field in target_dict.get("fields", []): + if field.get("type") != "vector": + continue + attrs = field.get("attrs", {}) + algo = attrs.get("algorithm", "").upper() + if algo == "SVS-VAMANA": + uses_svs = True + compression = attrs.get("compression", "") + if compression: + uses_compression = True + compression_types.add(compression) + + if not uses_svs: + return warnings + + # Check Redis version support + created_client = None + try: + if redis_client: + client = redis_client + elif redis_url: + from redis import Redis + + client = Redis.from_url(redis_url) + created_client = client + else: + client = None + + if client and not supports_svs(client): + warnings.append( + "SVS-VAMANA requires Redis >= 8.2.0 and Redis Search >= 2.8.10. 
" + "The target Redis instance may not support this algorithm. " + "Migration will fail at apply time if requirements are not met." + ) + except Exception: + # If we can't check, add a general warning + warnings.append( + "SVS-VAMANA requires Redis >= 8.2.0 and Redis Search >= 2.8.10. " + "Verify your Redis instance supports this algorithm before applying." + ) + finally: + if created_client: + created_client.close() + + # Intel hardware warning for compression + if uses_compression: + compression_label = ", ".join(sorted(compression_types)) + warnings.append( + f"SVS-VAMANA with {compression_label} compression: " + "LVQ and LeanVec optimizations require Intel hardware with AVX-512 support. " + "On non-Intel platforms or Redis Open Source, these fall back to basic " + "8-bit scalar quantization with reduced performance benefits." + ) + else: + warnings.append( + "SVS-VAMANA: For optimal performance, Intel hardware with AVX-512 support " + "is recommended. LVQ/LeanVec compression options provide additional memory " + "savings on supported hardware." + ) + + return warnings + + def classify_diff( + self, + source_schema: IndexSchema, + schema_patch: SchemaPatch, + merged_target_schema: IndexSchema, + rename_operations: Optional[RenameOperations] = None, + ) -> DiffClassification: + blocked_reasons: List[str] = [] + changes = schema_patch.changes + source_dict = source_schema.to_dict() + target_dict = merged_target_schema.to_dict() + + # Check which rename operations are being handled + has_index_rename = rename_operations and rename_operations.rename_index + has_prefix_change = ( + rename_operations and rename_operations.change_prefix is not None + ) + has_field_renames = ( + rename_operations and len(rename_operations.rename_fields) > 0 + ) + + for index_key, target_value in changes.index.items(): + source_value = source_dict["index"].get(index_key) + # Normalize single-element list prefixes for comparison so that + # e.g. 
source ``["docs"]`` and target ``"docs"`` are treated as equal. + sv, tv = source_value, target_value + if index_key == "prefix": + if isinstance(sv, list) and len(sv) == 1: + sv = sv[0] + if isinstance(tv, list) and len(tv) == 1: + tv = tv[0] + if sv == tv: + continue + if index_key == "name": + # Index rename is now supported - skip blocking if we have rename_operations + if not has_index_rename: + blocked_reasons.append( + "Changing the index name requires document migration (not yet supported)." + ) + elif index_key == "prefix": + # Prefix change is now supported + if not has_prefix_change: + blocked_reasons.append( + "Changing index prefixes requires document migration (not yet supported)." + ) + elif index_key == "key_separator": + blocked_reasons.append( + "Changing the key separator requires document migration (not yet supported)." + ) + elif index_key == "storage_type": + blocked_reasons.append( + "Changing the storage type requires document migration (not yet supported)." + ) + + source_fields = {field["name"]: field for field in source_dict["fields"]} + target_fields = {field["name"]: field for field in target_dict["fields"]} + + for field in changes.add_fields: + if field["type"] == "vector": + blocked_reasons.append( + f"Adding vector field '{field['name']}' requires document migration (not yet supported)." 
+ ) + + # Build rename mappings: old->new and new->old so update_fields + # can reference either the pre-rename or post-rename name + classify_rename_map = { + rename.old_name: rename.new_name + for rename in changes.rename_fields + if rename.old_name != rename.new_name + } + reverse_rename_map = {v: k for k, v in classify_rename_map.items()} + + for field_update in changes.update_fields: + # Resolve through renames: update_fields may use old or new name + if field_update.name in classify_rename_map: + # update references old name -> look up source by old, target by new + source_name = field_update.name + target_name = classify_rename_map[field_update.name] + elif field_update.name in reverse_rename_map: + # update references new name -> look up source by old, target by new + source_name = reverse_rename_map[field_update.name] + target_name = field_update.name + else: + # no rename involved + source_name = field_update.name + target_name = field_update.name + source_field = source_fields.get(source_name) + target_field = target_fields.get(target_name) + if source_field is None or target_field is None: + # Field not found in source or target; skip classification + continue + source_type = source_field["type"] + target_type = target_field["type"] + + if source_type != target_type: + blocked_reasons.append( + f"Changing field '{field_update.name}' type from {source_type} to {target_type} is not supported by drop_recreate." + ) + continue + + source_path = source_field.get("path") + target_path = target_field.get("path") + if source_path != target_path: + blocked_reasons.append( + f"Changing field '{field_update.name}' path from {source_path} to {target_path} is not supported by drop_recreate." 
+ ) + continue + + if target_type == "vector" and source_field != target_field: + # Check for document-dependent changes that are not yet supported + vector_blocked = self._classify_vector_field_change( + source_field, target_field + ) + blocked_reasons.extend(vector_blocked) + + # Detect possible undeclared field renames. When explicit renames + # exist, exclude those fields from heuristic detection so we still + # catch additional add/remove pairs that look like renames. + detect_source = dict(source_fields) + detect_target = dict(target_fields) + if has_field_renames and rename_operations: + for fr in rename_operations.rename_fields: + detect_source.pop(fr.old_name, None) + detect_target.pop(fr.new_name, None) + blocked_reasons.extend( + self._detect_possible_field_renames(detect_source, detect_target) + ) + + return DiffClassification( + supported=len(blocked_reasons) == 0, + blocked_reasons=self._dedupe(blocked_reasons), + ) + + def write_plan(self, plan: MigrationPlan, plan_out: str) -> None: + plan_path = Path(plan_out).resolve() + with open(plan_path, "w") as f: + yaml.safe_dump(plan.model_dump(exclude_none=True), f, sort_keys=False) + + def _sample_keys( + self, *, client: Any, prefixes: List[str], key_separator: str + ) -> List[str]: + key_sample: List[str] = [] + if client is None or self.key_sample_limit <= 0: + return key_sample + + for prefix in prefixes: + if len(key_sample) >= self.key_sample_limit: + break + if prefix == "": + match_pattern = "*" + elif prefix.endswith(key_separator): + match_pattern = f"{prefix}*" + else: + match_pattern = f"{prefix}{key_separator}*" + cursor = 0 + while True: + cursor, keys = client.scan( + cursor=cursor, + match=match_pattern, + count=max(self.key_sample_limit, 10), + ) + for key in keys: + decoded_key = key.decode() if isinstance(key, bytes) else str(key) + if decoded_key not in key_sample: + key_sample.append(decoded_key) + if len(key_sample) >= self.key_sample_limit: + return key_sample + if cursor == 0: + 
break + return key_sample + + def _detect_possible_field_renames( + self, + source_fields: Dict[str, Dict[str, Any]], + target_fields: Dict[str, Dict[str, Any]], + ) -> List[str]: + blocked_reasons: List[str] = [] + added_fields = [ + field for name, field in target_fields.items() if name not in source_fields + ] + removed_fields = [ + field for name, field in source_fields.items() if name not in target_fields + ] + + for removed_field in removed_fields: + for added_field in added_fields: + if self._fields_match_except_name(removed_field, added_field): + blocked_reasons.append( + f"Possible field rename from '{removed_field['name']}' to '{added_field['name']}' is not supported by drop_recreate." + ) + return blocked_reasons + + @staticmethod + def _classify_vector_field_change( + source_field: Dict[str, Any], target_field: Dict[str, Any] + ) -> List[str]: + """Classify vector field changes as supported or blocked for drop_recreate. + + Index-only changes (allowed with drop_recreate): + - algorithm (FLAT -> HNSW -> SVS-VAMANA) + - distance_metric (COSINE, L2, IP) + - initial_cap + - Algorithm tuning: m, ef_construction, ef_runtime, epsilon, block_size, + graph_max_degree, construction_window_size, search_window_size, etc. + + Quantization changes (allowed with drop_recreate, requires vector re-encoding): + - datatype (float32 -> float16, etc.) 
- executor will re-encode vectors + + Document-dependent changes (blocked, not yet supported): + - dims (vectors stored with wrong number of dimensions) + """ + blocked_reasons: List[str] = [] + field_name = source_field.get("name", "unknown") + source_attrs = source_field.get("attrs", {}) + target_attrs = target_field.get("attrs", {}) + + # Document-dependent properties (not yet supported) + if source_attrs.get("dims") != target_attrs.get("dims"): + blocked_reasons.append( + f"Changing vector field '{field_name}' dims from {source_attrs.get('dims')} " + f"to {target_attrs.get('dims')} requires document migration (not yet supported). " + "Vectors are stored with incompatible dimensions." + ) + + # Datatype changes are now ALLOWED - executor will re-encode vectors + # before recreating the index + + # All other vector changes are index-only and allowed + return blocked_reasons + + @staticmethod + def get_vector_datatype_changes( + source_schema: Dict[str, Any], + target_schema: Dict[str, Any], + rename_operations: Optional[Any] = None, + ) -> Dict[str, Dict[str, Any]]: + """Identify vector fields that need datatype conversion (quantization). + + Handles renamed vector fields by using rename_operations to map + source field names to their target counterparts. 
+ + Returns: + Dict mapping source_field_name -> { + "source": source_dtype, + "target": target_dtype, + "dims": int # vector dimensions for idempotent detection + } + """ + changes: Dict[str, Dict[str, Any]] = {} + source_fields = {f["name"]: f for f in source_schema.get("fields", [])} + target_fields = {f["name"]: f for f in target_schema.get("fields", [])} + + # Build rename map: source_name -> target_name + field_rename_map: Dict[str, str] = {} + if rename_operations and hasattr(rename_operations, "rename_fields"): + for fr in rename_operations.rename_fields: + field_rename_map[fr.old_name] = fr.new_name + + for name, source_field in source_fields.items(): + if source_field.get("type") != "vector": + continue + # Look up target by renamed name if applicable + target_name = field_rename_map.get(name, name) + target_field = target_fields.get(target_name) + if not target_field or target_field.get("type") != "vector": + continue + + source_dtype = source_field.get("attrs", {}).get("datatype", "float32") + target_dtype = target_field.get("attrs", {}).get("datatype", "float32") + dims = source_field.get("attrs", {}).get("dims", 0) + + if source_dtype != target_dtype: + changes[name] = { + "source": source_dtype, + "target": target_dtype, + "dims": dims, + } + + return changes + + @staticmethod + def _fields_match_except_name( + source_field: Dict[str, Any], target_field: Dict[str, Any] + ) -> bool: + comparable_source = {k: v for k, v in source_field.items() if k != "name"} + comparable_target = {k: v for k, v in target_field.items() if k != "name"} + return comparable_source == comparable_target + + @staticmethod + def _dedupe(values: List[str]) -> List[str]: + deduped: List[str] = [] + for value in values: + if value not in deduped: + deduped.append(value) + return deduped diff --git a/redisvl/migration/quantize.py b/redisvl/migration/quantize.py new file mode 100644 index 000000000..3fd9f9fda --- /dev/null +++ b/redisvl/migration/quantize.py @@ -0,0 +1,474 @@ 
+"""Pipelined vector quantization helpers. + +Provides pipeline-read, convert, and pipeline-write functions that replace +the per-key HGET loop with batched pipeline operations. + +Also provides multi-worker orchestration for parallel quantization +using ThreadPoolExecutor (sync) or asyncio.gather (async). +""" + +import hashlib +import logging +import math +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +from redisvl.redis.utils import array_to_buffer, buffer_to_array + + +def pipeline_read_vectors( + client: Any, + keys: List[str], + datatype_changes: Dict[str, Dict[str, Any]], +) -> Dict[str, Dict[str, bytes]]: + """Pipeline-read vector fields from Redis for a batch of keys. + + Instead of N individual HGET calls (N round trips), uses a single + pipeline with N*F HGET calls (1 round trip). + + Args: + client: Redis client + keys: List of Redis keys to read + datatype_changes: {field_name: {"source", "target", "dims"}} + + Returns: + {key: {field_name: original_bytes}} — only includes keys/fields + that returned non-None data. + """ + if not keys: + return {} + + pipe = client.pipeline(transaction=False) + # Track the order of pipelined calls: (key, field_name) + call_order: List[tuple] = [] + field_names = list(datatype_changes.keys()) + + for key in keys: + for field_name in field_names: + pipe.hget(key, field_name) + call_order.append((key, field_name)) + + results = pipe.execute() + + # Reassemble into {key: {field: bytes}} + output: Dict[str, Dict[str, bytes]] = {} + for (key, field_name), value in zip(call_order, results): + if value is not None: + if key not in output: + output[key] = {} + output[key][field_name] = value + + return output + + +def pipeline_write_vectors( + client: Any, + converted: Dict[str, Dict[str, bytes]], +) -> None: + """Pipeline-write converted vectors to Redis. 
+ + Args: + client: Redis client + converted: {key: {field_name: new_bytes}} + """ + if not converted: + return + + pipe = client.pipeline(transaction=False) + for key, fields in converted.items(): + for field_name, data in fields.items(): + pipe.hset(key, field_name, data) + pipe.execute() + + +def convert_vectors( + originals: Dict[str, Dict[str, bytes]], + datatype_changes: Dict[str, Dict[str, Any]], +) -> Dict[str, Dict[str, bytes]]: + """Convert vector bytes from source dtype to target dtype. + + Args: + originals: {key: {field_name: original_bytes}} + datatype_changes: {field_name: {"source", "target", "dims"}} + + Returns: + {key: {field_name: converted_bytes}} + """ + converted: Dict[str, Dict[str, bytes]] = {} + for key, fields in originals.items(): + converted[key] = {} + for field_name, data in fields.items(): + change = datatype_changes.get(field_name) + if not change: + continue + array = buffer_to_array(data, change["source"]) + new_bytes = array_to_buffer(array, change["target"]) + converted[key][field_name] = new_bytes + return converted + + +logger = logging.getLogger(__name__) + + +@dataclass +class MultiWorkerResult: + """Result from multi-worker quantization.""" + + total_docs_quantized: int + num_workers: int + worker_results: List[Dict[str, Any]] = field(default_factory=list) + + +def split_keys(keys: List[str], num_workers: int) -> List[List[str]]: + """Split keys into N contiguous slices for parallel processing. 
+ + Args: + keys: Full list of Redis keys + num_workers: Number of workers + + Returns: + List of key slices (some may be empty if keys < workers) + """ + if num_workers < 1: + raise ValueError(f"num_workers must be >= 1, got {num_workers}") + if not keys: + return [] + n = len(keys) + chunk_size = math.ceil(n / num_workers) + return [keys[i : i + chunk_size] for i in range(0, n, chunk_size)] + + +def _worker_quantize( + worker_id: int, + redis_url: str, + keys: List[str], + datatype_changes: Dict[str, Dict[str, Any]], + backup_path: str, + index_name: str, + batch_size: int, + progress_callback: Optional[Callable[[str, int, int], None]] = None, +) -> Dict[str, Any]: + """Single worker: dump originals + convert + write back. + + Each worker gets its own Redis connection and backup file shard. + """ + from redisvl.migration.backup import VectorBackup + from redisvl.redis.connection import RedisConnectionFactory + + client = RedisConnectionFactory.get_redis_connection(redis_url=redis_url) + try: + # Try to resume from existing backup shard first + backup = VectorBackup.load(backup_path) + if backup is not None: + logger.info( + "Worker %d: resuming from existing backup (phase=%s, " + "dump_batches=%d, quantize_batches=%d)", + worker_id, + backup.header.phase, + backup.header.dump_completed_batches, + backup.header.quantize_completed_batches, + ) + else: + backup = VectorBackup.create( + path=backup_path, + index_name=index_name, + fields=datatype_changes, + batch_size=batch_size, + ) + + total = len(keys) + + # Phase 1: Dump originals to backup shard (skip if already complete) + if backup.header.phase == "dump": + start_batch = backup.header.dump_completed_batches + for batch_start in range(start_batch * batch_size, total, batch_size): + batch_keys = keys[batch_start : batch_start + batch_size] + originals = pipeline_read_vectors(client, batch_keys, datatype_changes) + backup.write_batch(batch_start // batch_size, batch_keys, originals) + if progress_callback: + 
progress_callback( + "dump", worker_id, min(batch_start + batch_size, total) + ) + backup.mark_dump_complete() + + # Phase 2: Convert + write from backup (skip completed batches) + if backup.header.phase in ("ready", "active"): + backup.start_quantize() + docs_quantized = 0 + + for batch_idx, (batch_keys, originals) in enumerate(backup.iter_batches()): + if batch_idx < backup.header.quantize_completed_batches: + docs_quantized += len(batch_keys) + continue + converted = convert_vectors(originals, datatype_changes) + if converted: + pipeline_write_vectors(client, converted) + backup.mark_batch_quantized(batch_idx) + docs_quantized += len(batch_keys) + if progress_callback: + progress_callback("quantize", worker_id, docs_quantized) + + backup.mark_complete() + elif backup.header.phase == "completed": + # Already done from previous run + docs_quantized = sum( + 1 for _ in range(0, total, batch_size) for _ in keys[:batch_size] + ) + docs_quantized = total + + return {"worker_id": worker_id, "docs": docs_quantized} + finally: + try: + client.close() + except Exception: + pass + + +def multi_worker_quantize( + redis_url: str, + keys: List[str], + datatype_changes: Dict[str, Dict[str, Any]], + backup_dir: str, + index_name: str, + num_workers: int = 1, + batch_size: int = 500, + progress_callback: Optional[Callable[[str, int, int], None]] = None, +) -> MultiWorkerResult: + """Orchestrate multi-worker quantization. + + Splits keys across N workers, each with its own Redis connection + and backup file shard. Uses ThreadPoolExecutor for parallelism. 
+ + Args: + redis_url: Redis connection URL + keys: Full list of document keys to quantize + datatype_changes: {field_name: {"source", "target", "dims"}} + backup_dir: Directory for backup file shards + index_name: Source index name + num_workers: Number of parallel workers (default 1) + batch_size: Keys per pipeline batch + progress_callback: Optional callback(phase, worker_id, docs_done) + + Returns: + MultiWorkerResult with total docs quantized and per-worker results + """ + from pathlib import Path + + slices = split_keys(keys, num_workers) + actual_workers = len(slices) + + if actual_workers == 0: + return MultiWorkerResult( + total_docs_quantized=0, num_workers=0, worker_results=[] + ) + + # Generate backup paths per worker + safe_name = index_name.replace("/", "_").replace("\\", "_").replace(":", "_") + name_hash = hashlib.sha256(index_name.encode()).hexdigest()[:8] + worker_backup_paths = [ + str(Path(backup_dir) / f"migration_backup_{safe_name}_{name_hash}_worker{i}") + for i in range(actual_workers) + ] + + if actual_workers == 1: + # Single worker — run directly, no ThreadPoolExecutor overhead + result = _worker_quantize( + worker_id=0, + redis_url=redis_url, + keys=slices[0], + datatype_changes=datatype_changes, + backup_path=worker_backup_paths[0], + index_name=index_name, + batch_size=batch_size, + progress_callback=progress_callback, + ) + return MultiWorkerResult( + total_docs_quantized=result["docs"], + num_workers=1, + worker_results=[result], + ) + + # Multi-worker — ThreadPoolExecutor + worker_results: List[Dict[str, Any]] = [] + with ThreadPoolExecutor(max_workers=actual_workers) as executor: + futures = {} + for i, key_slice in enumerate(slices): + future = executor.submit( + _worker_quantize, + worker_id=i, + redis_url=redis_url, + keys=key_slice, + datatype_changes=datatype_changes, + backup_path=worker_backup_paths[i], + index_name=index_name, + batch_size=batch_size, + progress_callback=progress_callback, + ) + futures[future] = i + + for 
future in as_completed(futures): + result = future.result() # raises if worker failed + worker_results.append(result) + + total_docs = sum(r["docs"] for r in worker_results) + return MultiWorkerResult( + total_docs_quantized=total_docs, + num_workers=actual_workers, + worker_results=worker_results, + ) + + +async def _async_worker_quantize( + worker_id: int, + redis_url: str, + keys: List[str], + datatype_changes: Dict[str, Dict[str, Any]], + backup_path: str, + index_name: str, + batch_size: int, + progress_callback: Optional[Callable[[str, int, int], None]] = None, +) -> Dict[str, Any]: + """Async single worker: dump originals + convert + write back.""" + import redis.asyncio as aioredis + + from redisvl.migration.backup import VectorBackup + + client = aioredis.from_url(redis_url) + try: + # Try to resume from existing backup shard first + backup = VectorBackup.load(backup_path) + if backup is not None: + logger.info( + "Async worker %d: resuming from existing backup (phase=%s, " + "dump_batches=%d, quantize_batches=%d)", + worker_id, + backup.header.phase, + backup.header.dump_completed_batches, + backup.header.quantize_completed_batches, + ) + else: + backup = VectorBackup.create( + path=backup_path, + index_name=index_name, + fields=datatype_changes, + batch_size=batch_size, + ) + + total = len(keys) + field_names = list(datatype_changes.keys()) + + # Phase 1: Dump originals (skip if already complete) + if backup.header.phase == "dump": + start_batch = backup.header.dump_completed_batches + for batch_start in range(start_batch * batch_size, total, batch_size): + batch_keys = keys[batch_start : batch_start + batch_size] + pipe = client.pipeline(transaction=False) + call_order: List[tuple] = [] + for key in batch_keys: + for field_name in field_names: + pipe.hget(key, field_name) + call_order.append((key, field_name)) + results = await pipe.execute() + + originals: Dict[str, Dict[str, bytes]] = {} + for (key, field_name), value in zip(call_order, results): + if 
value is not None: + if key not in originals: + originals[key] = {} + originals[key][field_name] = value + + backup.write_batch(batch_start // batch_size, batch_keys, originals) + if progress_callback: + progress_callback( + "dump", worker_id, min(batch_start + batch_size, total) + ) + backup.mark_dump_complete() + + # Phase 2: Convert + write from backup (skip completed batches) + if backup.header.phase in ("ready", "active"): + backup.start_quantize() + docs_quantized = 0 + + for batch_idx, (batch_keys, batch_originals) in enumerate( + backup.iter_batches() + ): + if batch_idx < backup.header.quantize_completed_batches: + docs_quantized += len(batch_keys) + continue + converted = convert_vectors(batch_originals, datatype_changes) + if converted: + pipe = client.pipeline(transaction=False) + for key, fields in converted.items(): + for field_name, data in fields.items(): + pipe.hset(key, field_name, data) + await pipe.execute() + backup.mark_batch_quantized(batch_idx) + docs_quantized += len(batch_keys) + if progress_callback: + progress_callback("quantize", worker_id, docs_quantized) + + backup.mark_complete() + elif backup.header.phase == "completed": + docs_quantized = total + + return {"worker_id": worker_id, "docs": docs_quantized} + finally: + await client.aclose() + + +async def async_multi_worker_quantize( + redis_url: str, + keys: List[str], + datatype_changes: Dict[str, Dict[str, Any]], + backup_dir: str, + index_name: str, + num_workers: int = 1, + batch_size: int = 500, + progress_callback: Optional[Callable[[str, int, int], None]] = None, +) -> MultiWorkerResult: + """Orchestrate async multi-worker quantization via asyncio.gather. + + Each worker gets its own async Redis connection and backup file shard. 
+ """ + import asyncio + from pathlib import Path + + slices = split_keys(keys, num_workers) + actual_workers = len(slices) + + if actual_workers == 0: + return MultiWorkerResult( + total_docs_quantized=0, num_workers=0, worker_results=[] + ) + + safe_name = index_name.replace("/", "_").replace("\\", "_").replace(":", "_") + name_hash = hashlib.sha256(index_name.encode()).hexdigest()[:8] + worker_backup_paths = [ + str(Path(backup_dir) / f"migration_backup_{safe_name}_{name_hash}_worker{i}") + for i in range(actual_workers) + ] + + coroutines = [ + _async_worker_quantize( + worker_id=i, + redis_url=redis_url, + keys=slices[i], + datatype_changes=datatype_changes, + backup_path=worker_backup_paths[i], + index_name=index_name, + batch_size=batch_size, + progress_callback=progress_callback, + ) + for i in range(actual_workers) + ] + + results = await asyncio.gather(*coroutines) + worker_results = list(results) + total_docs = sum(r["docs"] for r in worker_results) + + return MultiWorkerResult( + total_docs_quantized=total_docs, + num_workers=actual_workers, + worker_results=worker_results, + ) diff --git a/redisvl/migration/reliability.py b/redisvl/migration/reliability.py new file mode 100644 index 000000000..71f6e672e --- /dev/null +++ b/redisvl/migration/reliability.py @@ -0,0 +1,340 @@ +"""Crash-safe quantization utilities for index migration. + +Provides idempotent dtype detection, checkpointing, BGSAVE safety, +and bounded undo buffering for reliable vector re-encoding. +""" + +import asyncio +import os +import tempfile +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import yaml +from pydantic import BaseModel, Field + +from redisvl.migration.models import DTYPE_BYTES +from redisvl.utils.log import get_logger + +logger = get_logger(__name__) + +# Dtypes that share byte widths and are functionally interchangeable +# for idempotent detection purposes (same byte length per element). 
+_DTYPE_FAMILY: Dict[str, str] = { + "float64": "8byte", + "float32": "4byte", + "float16": "2byte", + "bfloat16": "2byte", + "int8": "1byte", + "uint8": "1byte", +} + + +def is_same_width_dtype_conversion(source_dtype: str, target_dtype: str) -> bool: + """Return True when two dtypes share byte width but differ in encoding.""" + if source_dtype == target_dtype: + return False + source_family = _DTYPE_FAMILY.get(source_dtype) + target_family = _DTYPE_FAMILY.get(target_dtype) + if source_family is None or target_family is None: + return False + return source_family == target_family + + +# --------------------------------------------------------------------------- +# Idempotent Dtype Detection +# --------------------------------------------------------------------------- + + +def detect_vector_dtype(data: bytes, expected_dims: int) -> Optional[str]: + """Inspect raw vector bytes and infer the storage dtype. + + Uses byte length and expected dimensions to determine which dtype + the vector is currently stored as. Returns the canonical representative + for each byte-width family (float16 for 2-byte, int8 for 1-byte), + since dtypes within a family cannot be distinguished by length alone. + + Args: + data: Raw vector bytes from Redis. + expected_dims: Number of dimensions expected for this vector field. + + Returns: + Detected dtype string (e.g. "float32", "float16", "int8") or None + if the size does not match any known dtype. + """ + if not data or expected_dims <= 0: + return None + + nbytes = len(data) + + # Check each dtype in decreasing element size to avoid ambiguity. + # Only canonical representatives are checked (float16 covers bfloat16, + # int8 covers uint8) since they share byte widths. 
+ for dtype in ("float64", "float32", "float16", "int8"): + if nbytes == expected_dims * DTYPE_BYTES[dtype]: + return dtype + + return None + + +def is_already_quantized( + data: bytes, + expected_dims: int, + source_dtype: str, + target_dtype: str, +) -> bool: + """Check whether a vector has already been converted to the target dtype. + + Uses byte-width families to handle ambiguous dtypes. For example, + if source is float32 and target is float16, a vector detected as + 2-bytes-per-element is considered already quantized (the byte width + shrank from 4 to 2, so conversion already happened). + + However, same-width conversions (e.g. float16 -> bfloat16 or + int8 -> uint8) are NOT skipped because the encoding semantics + differ even though the byte length is identical. We cannot + distinguish these by length, so we must always re-encode. + + Args: + data: Raw vector bytes. + expected_dims: Number of dimensions. + source_dtype: The dtype the vector was originally stored as. + target_dtype: The dtype we want to convert to. + + Returns: + True if the vector already matches the target dtype (skip conversion). + """ + detected = detect_vector_dtype(data, expected_dims) + if detected is None: + return False + + detected_family = _DTYPE_FAMILY.get(detected) + target_family = _DTYPE_FAMILY.get(target_dtype) + source_family = _DTYPE_FAMILY.get(source_dtype) + + # If detected byte-width matches target family, the vector looks converted. + # But if source and target share the same byte-width family (e.g. + # float16 -> bfloat16), we cannot tell whether conversion happened, + # so we must NOT skip -- always re-encode for same-width migrations. 
+ if source_family == target_family: + return False + + return detected_family == target_family + + +# --------------------------------------------------------------------------- +# Quantization Checkpoint +# --------------------------------------------------------------------------- + + +class QuantizationCheckpoint(BaseModel): + """Tracks migration progress for crash-safe resume.""" + + index_name: str + total_keys: int + completed_keys: int = 0 + completed_batches: int = 0 + last_batch_keys: List[str] = Field(default_factory=list) + # Retained for backward compatibility with older checkpoint files. + # New checkpoints rely on completed_keys with deterministic key ordering + # instead of rewriting an ever-growing processed key list on every batch. + processed_keys: List[str] = Field(default_factory=list) + status: str = "in_progress" + checkpoint_path: str = "" + + def record_batch(self, keys: List[str]) -> None: + """Record a successfully processed batch. + + Does not auto-save to disk. Call save() after record_batch() + to persist the checkpoint for crash recovery. + """ + self.completed_keys += len(keys) + self.completed_batches += 1 + self.last_batch_keys = list(keys) + if self.processed_keys: + self.processed_keys.extend(keys) + + def mark_complete(self) -> None: + """Mark the migration as completed.""" + self.status = "completed" + + def save(self) -> None: + """Persist checkpoint to disk atomically. + + Writes to a temporary file first, then renames. This ensures a + crash mid-write does not corrupt the checkpoint file. 
+ """ + path = Path(self.checkpoint_path) + path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp_path = tempfile.mkstemp( + dir=path.parent, suffix=".tmp", prefix=".checkpoint_" + ) + try: + exclude = set() + if not self.processed_keys: + exclude.add("processed_keys") + with os.fdopen(fd, "w") as f: + yaml.safe_dump( + self.model_dump(exclude=exclude), + f, + sort_keys=False, + ) + os.replace(tmp_path, str(path)) + except BaseException: + # Clean up temp file on any failure + try: + os.unlink(tmp_path) + except OSError: + pass + raise + + @classmethod + def load(cls, path: str) -> Optional["QuantizationCheckpoint"]: + """Load a checkpoint from disk. Returns None if file does not exist. + + Always sets checkpoint_path to the path used to load, not the + value stored in the file. This ensures subsequent save() calls + write to the correct location even if the file was moved. + """ + p = Path(path) + if not p.exists(): + return None + with open(p, "r") as f: + data = yaml.safe_load(f) + if not data: + return None + checkpoint = cls.model_validate(data) + if checkpoint.processed_keys and checkpoint.completed_keys < len( + checkpoint.processed_keys + ): + checkpoint.completed_keys = len(checkpoint.processed_keys) + checkpoint.checkpoint_path = str(p) + return checkpoint + + def get_remaining_keys(self, all_keys: List[str]) -> List[str]: + """Return keys that have not yet been processed.""" + if self.processed_keys: + done = set(self.processed_keys) + return [k for k in all_keys if k not in done] + + if self.completed_keys <= 0: + return list(all_keys) + + return all_keys[self.completed_keys :] + + +# --------------------------------------------------------------------------- +# BGSAVE Safety Net +# --------------------------------------------------------------------------- + + +def trigger_bgsave_and_wait( + client: Any, + *, + timeout_seconds: int = 300, + poll_interval: float = 1.0, +) -> bool: + """Trigger a Redis BGSAVE and wait for it to complete. 
+ + If a BGSAVE is already in progress, waits for it instead. + + Args: + client: Sync Redis client. + timeout_seconds: Max seconds to wait for BGSAVE to finish. + poll_interval: Seconds between status polls. + + Returns: + True if BGSAVE completed successfully. + """ + try: + client.bgsave() + except Exception as exc: + if "already in progress" not in str(exc).lower(): + raise + logger.info("BGSAVE already in progress, waiting for it to finish.") + + deadline = time.monotonic() + timeout_seconds + while time.monotonic() < deadline: + info = client.info("persistence") + if isinstance(info, dict) and not info.get("rdb_bgsave_in_progress", 0): + status = info.get("rdb_last_bgsave_status", "ok") + if status != "ok": + logger.warning("BGSAVE completed with status: %s", status) + return False + return True + time.sleep(poll_interval) + + raise TimeoutError(f"BGSAVE did not complete within {timeout_seconds}s") + + +async def async_trigger_bgsave_and_wait( + client: Any, + *, + timeout_seconds: int = 300, + poll_interval: float = 1.0, +) -> bool: + """Async version of trigger_bgsave_and_wait.""" + try: + await client.bgsave() + except Exception as exc: + if "already in progress" not in str(exc).lower(): + raise + logger.info("BGSAVE already in progress, waiting for it to finish.") + + deadline = time.monotonic() + timeout_seconds + while time.monotonic() < deadline: + info = await client.info("persistence") + if isinstance(info, dict) and not info.get("rdb_bgsave_in_progress", 0): + status = info.get("rdb_last_bgsave_status", "ok") + if status != "ok": + logger.warning("BGSAVE completed with status: %s", status) + return False + return True + await asyncio.sleep(poll_interval) + + raise TimeoutError(f"BGSAVE did not complete within {timeout_seconds}s") + + +# --------------------------------------------------------------------------- +# Bounded Undo Buffer +# --------------------------------------------------------------------------- + + +class BatchUndoBuffer: + 
"""Stores original vector values for the current batch to allow rollback. + + Memory-bounded: only holds data for one batch at a time. Call clear() + after each successful batch commit. + """ + + def __init__(self) -> None: + self._entries: List[Tuple[str, str, bytes]] = [] + + @property + def size(self) -> int: + return len(self._entries) + + def store(self, key: str, field: str, original_value: bytes) -> None: + """Record the original value of a field before mutation.""" + self._entries.append((key, field, original_value)) + + def rollback(self, pipe: Any) -> None: + """Restore all stored originals via the given pipeline (sync).""" + if not self._entries: + return + for key, field, value in self._entries: + pipe.hset(key, field, value) + pipe.execute() + + async def async_rollback(self, pipe: Any) -> None: + """Restore all stored originals via the given pipeline (async).""" + if not self._entries: + return + for key, field, value in self._entries: + pipe.hset(key, field, value) + await pipe.execute() + + def clear(self) -> None: + """Discard all stored entries.""" + self._entries.clear() diff --git a/redisvl/migration/utils.py b/redisvl/migration/utils.py new file mode 100644 index 000000000..11dc98664 --- /dev/null +++ b/redisvl/migration/utils.py @@ -0,0 +1,465 @@ +from __future__ import annotations + +import json +import time +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +import yaml + +from redisvl.index import SearchIndex +from redisvl.migration.models import ( + AOF_HSET_OVERHEAD_BYTES, + AOF_JSON_SET_OVERHEAD_BYTES, + DTYPE_BYTES, + RDB_COMPRESSION_RATIO, + DiskSpaceEstimate, + MigrationPlan, + MigrationReport, + VectorFieldEstimate, +) +from redisvl.redis.connection import RedisConnectionFactory +from redisvl.schema.schema import IndexSchema +from redisvl.utils.log import get_logger + +logger = get_logger(__name__) + + +def list_indexes( + *, redis_url: Optional[str] = None, redis_client: Optional[Any] = None 
+): + if redis_client is None: + if not redis_url: + raise ValueError("Must provide either redis_url or redis_client") + redis_client = RedisConnectionFactory.get_redis_connection(redis_url=redis_url) + index = SearchIndex.from_dict( + {"index": {"name": "__redisvl_migration_helper__"}, "fields": []}, + redis_client=redis_client, + ) + return index.listall() + + +def load_yaml(path: str) -> Dict[str, Any]: + resolved = Path(path).resolve() + with open(resolved, "r") as f: + return yaml.safe_load(f) or {} + + +def write_yaml(data: Dict[str, Any], path: str) -> None: + resolved = Path(path).resolve() + with open(resolved, "w") as f: + yaml.safe_dump(data, f, sort_keys=False) + + +def load_migration_plan(path: str) -> MigrationPlan: + return MigrationPlan.model_validate(load_yaml(path)) + + +def write_migration_report(report: MigrationReport, path: str) -> None: + write_yaml(report.model_dump(exclude_none=True), path) + + +def write_benchmark_report(report: MigrationReport, path: str) -> None: + benchmark_report = { + "version": report.version, + "mode": report.mode, + "source_index": report.source_index, + "target_index": report.target_index, + "result": report.result, + "timings": report.timings.model_dump(exclude_none=True), + "benchmark_summary": report.benchmark_summary.model_dump(exclude_none=True), + "validation": { + "schema_match": report.validation.schema_match, + "doc_count_match": report.validation.doc_count_match, + "indexing_failures_delta": report.validation.indexing_failures_delta, + "key_sample_exists": report.validation.key_sample_exists, + }, + } + write_yaml(benchmark_report, path) + + +def normalize_keys(keys: List[str]) -> List[str]: + """Deduplicate and sort keys for deterministic resume behavior.""" + return sorted(set(keys)) + + +def build_scan_match_patterns(prefixes: List[str], key_separator: str) -> List[str]: + """Build SCAN patterns for all configured prefixes.""" + if not prefixes: + logger.warning( + "No prefixes provided for SCAN 
def detect_aof_enabled(client: Any) -> bool:
    """Best-effort detection of whether AOF is enabled on the live Redis.

    Tries INFO persistence first, then falls back to CONFIG GET. Any error
    (restricted commands, proxies, etc.) is deliberately swallowed and
    treated as "not enabled", since this only feeds a disk-space estimate.
    """
    # Preferred source: INFO persistence exposes an explicit flag.
    try:
        persistence = client.info("persistence")
        if isinstance(persistence, dict) and "aof_enabled" in persistence:
            return bool(int(persistence["aof_enabled"]))
    except Exception:
        pass

    # Fallback: read the appendonly config value directly.
    try:
        settings = client.config_get("appendonly")
        if isinstance(settings, dict):
            raw = settings.get("appendonly")
            if raw is not None:
                return str(raw).lower() in {"yes", "1", "true", "on"}
    except Exception:
        pass

    return False


def get_schema_field_path(schema: Dict[str, Any], field_name: str) -> Optional[str]:
    """Return the JSON path configured for a field, if present.

    Looks at the field's top-level ``path`` first, then inside ``attrs``.
    Returns None when the field is absent or has no path configured.
    """
    for candidate in schema.get("fields", []):
        if candidate.get("name") != field_name:
            continue
        configured = candidate.get("path")
        if configured is None:
            configured = candidate.get("attrs", {}).get("path")
        return None if configured is None else str(configured)
    return None


# Attributes excluded from schema validation comparison.
# These are query-time or creation-hint parameters that FT.INFO does not return
# and are not relevant for index structure validation (confirmed by RediSearch team).
# - ef_runtime, epsilon: query-time tuning knobs, not index definition attributes
# - initial_cap: creation-time memory pre-allocation hint, diverges after indexing
EXCLUDED_VECTOR_ATTRS = {"ef_runtime", "epsilon", "initial_cap"}
+# withsuffixtrie: returned as a flag in FT.INFO but not as a KV attribute, +# so RedisVL's parser does not capture it yet. +EXCLUDED_TEXT_ATTRS = {"phonetic_matcher", "withsuffixtrie"} +EXCLUDED_TAG_ATTRS = {"withsuffixtrie"} + + +def _strip_excluded_attrs(field: Dict[str, Any]) -> Dict[str, Any]: + """Remove attributes not relevant for index validation comparison. + + These are either query-time parameters, creation-time hints, or attributes + whose server-side representation differs from the schema definition. + + Also normalizes attributes that have implicit behavior: + - For NUMERIC + SORTABLE, Redis auto-applies UNF, so we normalize to unf=True + """ + field = field.copy() + attrs = field.get("attrs", {}) + if not attrs: + return field + + attrs = attrs.copy() + field_type = field.get("type", "").lower() + + if field_type == "vector": + for attr in EXCLUDED_VECTOR_ATTRS: + attrs.pop(attr, None) + elif field_type == "text": + for attr in EXCLUDED_TEXT_ATTRS: + attrs.pop(attr, None) + # Normalize weight to int for comparison (FT.INFO may return float) + if "weight" in attrs and isinstance(attrs["weight"], float): + if attrs["weight"] == int(attrs["weight"]): + attrs["weight"] = int(attrs["weight"]) + elif field_type == "tag": + for attr in EXCLUDED_TAG_ATTRS: + attrs.pop(attr, None) + elif field_type == "numeric": + # Redis auto-applies UNF when SORTABLE is set on NUMERIC fields. + # Normalize unf to True when sortable is True to avoid false mismatches. + if attrs.get("sortable"): + attrs["unf"] = True + + field["attrs"] = attrs + return field + + +def canonicalize_schema( + schema_dict: Dict[str, Any], + *, + strip_unreliable: bool = False, + strip_excluded: bool = False, +) -> Dict[str, Any]: + """Canonicalize schema for comparison. + + Args: + schema_dict: The schema dictionary to canonicalize. + strip_unreliable: Deprecated alias for strip_excluded. Kept for + backward compatibility. 
+ strip_excluded: If True, remove query-time and creation-hint attributes + that are not part of index structure validation. + """ + schema = IndexSchema.from_dict(schema_dict).to_dict() + + should_strip = strip_excluded or strip_unreliable + fields = schema.get("fields", []) + if should_strip: + fields = [_strip_excluded_attrs(f) for f in fields] + + schema["fields"] = sorted(fields, key=lambda field: field["name"]) + prefixes = schema["index"].get("prefix") + if isinstance(prefixes, list): + schema["index"]["prefix"] = sorted(prefixes) + stopwords = schema["index"].get("stopwords") + if isinstance(stopwords, list): + schema["index"]["stopwords"] = sorted(stopwords) + return schema + + +def schemas_equal( + left: Dict[str, Any], + right: Dict[str, Any], + *, + strip_unreliable: bool = False, + strip_excluded: bool = False, +) -> bool: + """Compare two schemas for equality. + + Args: + left: First schema dictionary. + right: Second schema dictionary. + strip_unreliable: Deprecated alias for strip_excluded. Kept for + backward compatibility. + strip_excluded: If True, exclude query-time and creation-hint attributes + (ef_runtime, epsilon, initial_cap, phonetic_matcher) from comparison. + """ + should_strip = strip_excluded or strip_unreliable + return json.dumps( + canonicalize_schema(left, strip_excluded=should_strip), sort_keys=True + ) == json.dumps( + canonicalize_schema(right, strip_excluded=should_strip), sort_keys=True + ) + + +def wait_for_index_ready( + index: SearchIndex, + *, + timeout_seconds: int = 1800, + poll_interval_seconds: float = 0.5, + progress_callback: Optional[Callable[[int, int, float], None]] = None, +) -> Tuple[Dict[str, Any], float]: + """Wait for index to finish indexing all documents. + + Args: + index: The SearchIndex to monitor. + timeout_seconds: Maximum time to wait. + poll_interval_seconds: How often to check status. + progress_callback: Optional callback(indexed_docs, total_docs, percent). 
def current_source_matches_snapshot(
    index_name: str,
    expected_schema: Dict[str, Any],
    *,
    redis_url: Optional[str] = None,
    redis_client: Optional[Any] = None,
) -> bool:
    """Return True when the live index schema still equals the snapshot.

    A missing index (e.g. already dropped during migration) counts as a
    mismatch rather than an error.
    """
    try:
        live_index = SearchIndex.from_existing(
            index_name,
            redis_url=redis_url,
            redis_client=redis_client,
        )
    except Exception:
        return False
    return schemas_equal(live_index.schema.to_dict(), expected_schema)


def timestamp_utc() -> str:
    """Current UTC time formatted as an ISO-8601 'Z' timestamp."""
    return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())


def estimate_disk_space(
    plan: MigrationPlan,
    *,
    aof_enabled: bool = False,
) -> DiskSpaceEstimate:
    """Estimate disk space required for a migration with quantization.

    Pure calculation over the migration plan; no Redis commands are issued.

    Args:
        plan: The migration plan containing source/target schemas.
        aof_enabled: Whether AOF persistence is active on the Redis instance.

    Returns:
        DiskSpaceEstimate with projected costs.
    """
    doc_count = int(plan.source.stats_snapshot.get("num_docs", 0) or 0)
    storage_type = plan.source.keyspace.storage_type

    # Index source/target fields by name for pairwise comparison.
    by_source_name = {
        f["name"]: f for f in plan.source.schema_snapshot.get("fields", [])
    }
    by_target_name = {
        f["name"]: f for f in plan.merged_target_schema.get("fields", [])
    }

    # Map old field names to new ones when renames are planned.
    renames: Dict[str, str] = {}
    rename_ops = plan.rename_operations
    if rename_ops and rename_ops.rename_fields:
        renames = {fr.old_name: fr.new_name for fr in rename_ops.rename_fields}

    per_write_overhead = (
        AOF_JSON_SET_OVERHEAD_BYTES
        if storage_type == "json"
        else AOF_HSET_OVERHEAD_BYTES
    )

    estimates: List[VectorFieldEstimate] = []
    source_bytes = 0
    target_bytes = 0
    aof_growth = 0

    for name, src in by_source_name.items():
        if src.get("type") != "vector":
            continue
        # Pair with the target field, following any planned rename.
        tgt = by_target_name.get(renames.get(name, name))
        if not tgt or tgt.get("type") != "vector":
            continue

        src_attrs = src.get("attrs", {})
        tgt_attrs = tgt.get("attrs", {})
        src_dtype = src_attrs.get("datatype", "float32").lower()
        tgt_dtype = tgt_attrs.get("datatype", "float32").lower()
        if src_dtype == tgt_dtype:
            continue

        for label, dtype in (("source", src_dtype), ("target", tgt_dtype)):
            if dtype not in DTYPE_BYTES:
                raise ValueError(
                    f"Unknown {label} vector datatype '{dtype}' for field '{name}'. "
                    f"Supported datatypes: {', '.join(sorted(DTYPE_BYTES.keys()))}"
                )

        if storage_type == "json":
            # JSON-backed migrations do not rewrite per-document vector payloads
            # during apply(); they rely on recreate + re-index instead.
            continue

        dims = int(src_attrs.get("dims", 0))
        src_vec = dims * DTYPE_BYTES[src_dtype]
        tgt_vec = dims * DTYPE_BYTES[tgt_dtype]

        estimates.append(
            VectorFieldEstimate(
                field_name=name,
                dims=dims,
                source_dtype=src_dtype,
                target_dtype=tgt_dtype,
                source_bytes_per_doc=src_vec,
                target_bytes_per_doc=tgt_vec,
            )
        )

        source_bytes += doc_count * src_vec
        target_bytes += doc_count * tgt_vec
        if aof_enabled:
            aof_growth += doc_count * (tgt_vec + per_write_overhead)

    snapshot_disk = int(source_bytes * RDB_COMPRESSION_RATIO)
    return DiskSpaceEstimate(
        index_name=plan.source.index_name,
        doc_count=doc_count,
        storage_type=storage_type,
        vector_fields=estimates,
        total_source_vector_bytes=source_bytes,
        total_target_vector_bytes=target_bytes,
        rdb_snapshot_disk_bytes=snapshot_disk,
        rdb_cow_memory_if_concurrent_bytes=source_bytes,
        aof_enabled=aof_enabled,
        aof_growth_bytes=aof_growth,
        total_new_disk_bytes=snapshot_disk + aof_growth,
        memory_savings_after_bytes=source_bytes - target_bytes,
    )
+ validation.schema_match = schemas_equal( + live_schema, plan.merged_target_schema, strip_excluded=True + ) + + source_num_docs = int(plan.source.stats_snapshot.get("num_docs", 0) or 0) + target_num_docs = int(target_info.get("num_docs", 0) or 0) + + source_failures = int( + plan.source.stats_snapshot.get("hash_indexing_failures", 0) or 0 + ) + target_failures = int(target_info.get("hash_indexing_failures", 0) or 0) + validation.indexing_failures_delta = target_failures - source_failures + + # Compare total keys (num_docs + hash_indexing_failures) instead of + # just num_docs. Migrations can resolve indexing failures (e.g. a + # vector datatype change may fix documents that previously failed to + # index), shifting counts between the two buckets while the total + # number of keys under the prefix stays the same. + source_total = source_num_docs + source_failures + target_total = target_num_docs + target_failures + validation.doc_count_match = source_total == target_total + + key_sample = plan.source.keyspace.key_sample + if not key_sample: + validation.key_sample_exists = True + else: + # Handle prefix change: transform key_sample to use new prefix. + # Must match the executor's RENAME logic exactly: + # new_key = new_prefix + key[len(old_prefix):] + keys_to_check = key_sample + if plan.rename_operations.change_prefix is not None: + old_prefix = plan.source.keyspace.prefixes[0] + new_prefix = plan.rename_operations.change_prefix + # Mirror executor logic exactly: + # new_key = new_prefix + key[len(old_prefix):] + keys_to_check = [] + for k in key_sample: + if k.startswith(old_prefix): + keys_to_check.append(new_prefix + k[len(old_prefix) :]) + else: + keys_to_check.append(k) + existing_count = target_index.client.exists(*keys_to_check) + validation.key_sample_exists = existing_count == len(keys_to_check) + + # Run automatic functional checks (always). 
+ # Use source_total (num_docs + failures) as the expected count so that + # resolved indexing failures don't cause the wildcard check to fail. + functional_checks = self._run_functional_checks(target_index, source_total) + validation.query_checks.extend(functional_checks) + + # Run user-provided query checks (if file provided) + if query_check_file: + user_checks = self._run_query_checks(target_index, query_check_file) + validation.query_checks.extend(user_checks) + + if not validation.schema_match and plan.validation.require_schema_match: + validation.errors.append("Live schema does not match merged_target_schema.") + if not validation.doc_count_match and plan.validation.require_doc_count_match: + validation.errors.append( + f"Total key count mismatch: source had {source_total} " + f"(num_docs={source_num_docs}, failures={source_failures}), " + f"target has {target_total} " + f"(num_docs={target_num_docs}, failures={target_failures})." + ) + if validation.indexing_failures_delta > 0: + validation.errors.append("Indexing failures increased during migration.") + if not validation.key_sample_exists: + validation.errors.append( + "One or more sampled source keys is missing after migration." 
+ ) + if any(not query_check.passed for query_check in validation.query_checks): + validation.errors.append("One or more query checks failed.") + + return validation, target_info, round(time.perf_counter() - started, 3) + + def _run_query_checks( + self, + target_index: SearchIndex, + query_check_file: str, + ) -> list[QueryCheckResult]: + query_checks = load_yaml(query_check_file) + results: list[QueryCheckResult] = [] + + for doc_id in query_checks.get("fetch_ids", []): + fetched = target_index.fetch(doc_id) + results.append( + QueryCheckResult( + name=f"fetch:{doc_id}", + passed=fetched is not None, + details=( + "Document fetched successfully" + if fetched is not None + else "Document not found" + ), + ) + ) + + for key in query_checks.get("keys_exist", []): + client = target_index.client + if client is None: + raise ValueError("Redis client not connected") + exists = bool(client.exists(key)) + results.append( + QueryCheckResult( + name=f"key:{key}", + passed=exists, + details="Key exists" if exists else "Key not found", + ) + ) + + return results + + def _run_functional_checks( + self, target_index: SearchIndex, expected_doc_count: int + ) -> List[QueryCheckResult]: + """Run automatic functional checks to verify the index is operational. + + These checks run automatically after every migration to prove the index + actually works, not just that the schema looks correct. + """ + results: List[QueryCheckResult] = [] + + # Check 1: Wildcard search - proves the index responds and returns docs + try: + search_result = target_index.search(Query("*").paging(0, 1)) + total_found = search_result.total + # When expected_doc_count is 0 (empty index), a successful + # search returning 0 docs is correct behaviour, not a failure. 
+ if expected_doc_count == 0: + passed = total_found == 0 + else: + passed = total_found > 0 + if expected_doc_count == 0: + detail_expectation = "expected 0" + else: + detail_expectation = f"expected >0, source had {expected_doc_count}" + results.append( + QueryCheckResult( + name="functional:wildcard_search", + passed=passed, + details=( + f"Wildcard search returned {total_found} docs " + f"({detail_expectation})" + ), + ) + ) + except Exception as e: + results.append( + QueryCheckResult( + name="functional:wildcard_search", + passed=False, + details=f"Wildcard search failed: {str(e)}", + ) + ) + + return results diff --git a/redisvl/migration/wizard.py b/redisvl/migration/wizard.py new file mode 100644 index 000000000..fc2607359 --- /dev/null +++ b/redisvl/migration/wizard.py @@ -0,0 +1,895 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import yaml + +from redisvl.migration.models import ( + FieldRename, + FieldUpdate, + SchemaPatch, + SchemaPatchChanges, +) +from redisvl.migration.planner import MigrationPlanner +from redisvl.migration.utils import list_indexes, write_yaml +from redisvl.schema.schema import IndexSchema + +SUPPORTED_FIELD_TYPES = ["text", "tag", "numeric", "geo"] +UPDATABLE_FIELD_TYPES = ["text", "tag", "numeric", "geo", "vector"] + + +class MigrationWizard: + def __init__(self, planner: Optional[MigrationPlanner] = None): + self.planner = planner or MigrationPlanner() + self._existing_sortable: bool = False + + def run( + self, + *, + index_name: Optional[str] = None, + redis_url: Optional[str] = None, + redis_client: Optional[Any] = None, + existing_patch_path: Optional[str] = None, + plan_out: str = "migration_plan.yaml", + patch_out: Optional[str] = None, + target_schema_out: Optional[str] = None, + ): + resolved_index_name = self._resolve_index_name( + index_name=index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + snapshot = self.planner.snapshot_source( + resolved_index_name, + 
redis_url=redis_url, + redis_client=redis_client, + ) + source_schema = IndexSchema.from_dict(snapshot.schema_snapshot) + + # Guard: the wizard does not support indexes with multiple prefixes. + prefixes = source_schema.index.prefix + if isinstance(prefixes, list) and len(prefixes) > 1: + raise ValueError( + f"Index '{resolved_index_name}' has multiple prefixes " + f"({prefixes}). The migration wizard only supports single-prefix " + "indexes. Use the planner API directly for multi-prefix indexes." + ) + + print(f"Building a migration plan for index '{resolved_index_name}'") + self._print_source_schema(source_schema.to_dict()) + + # Load existing patch if provided + existing_changes = None + if existing_patch_path: + existing_changes = self._load_existing_patch(existing_patch_path) + + schema_patch = self._build_patch( + source_schema.to_dict(), existing_changes=existing_changes + ) + plan = self.planner.create_plan_from_patch( + resolved_index_name, + schema_patch=schema_patch, + redis_url=redis_url, + redis_client=redis_client, + ) + self.planner.write_plan(plan, plan_out) + + if patch_out: + write_yaml(schema_patch.model_dump(exclude_none=True), patch_out) + if target_schema_out: + write_yaml(plan.merged_target_schema, target_schema_out) + + return plan + + def _load_existing_patch(self, patch_path: str) -> SchemaPatchChanges: + from redisvl.migration.utils import load_yaml + + data = load_yaml(patch_path) + patch = SchemaPatch.model_validate(data) + print(f"Loaded existing patch from {patch_path}") + print(f" Add fields: {len(patch.changes.add_fields)}") + print(f" Update fields: {len(patch.changes.update_fields)}") + print(f" Remove fields: {len(patch.changes.remove_fields)}") + print(f" Rename fields: {len(patch.changes.rename_fields)}") + if patch.changes.index: + print(f" Index changes: {list(patch.changes.index.keys())}") + return patch.changes + + def _resolve_index_name( + self, + *, + index_name: Optional[str], + redis_url: Optional[str], + redis_client: 
Optional[Any], + ) -> str: + if index_name: + return index_name + + indexes = list_indexes(redis_url=redis_url, redis_client=redis_client) + if not indexes: + raise ValueError("No indexes found in Redis") + + print("Available indexes:") + for position, name in enumerate(indexes, start=1): + print(f"{position}. {name}") + + while True: + choice = input("Select an index by number or name: ").strip() + if choice in indexes: + return choice + if choice.isdigit(): + offset = int(choice) - 1 + if 0 <= offset < len(indexes): + return indexes[offset] + print("Invalid selection. Please try again.") + + @staticmethod + def _filter_staged_adds( + working_schema: Dict[str, Any], staged_add_names: set + ) -> Dict[str, Any]: + """Return a copy of working_schema with staged-add fields removed. + + This prevents staged additions from appearing in update/rename + candidate lists. + """ + import copy + + filtered = copy.deepcopy(working_schema) + filtered["fields"] = [ + f for f in filtered["fields"] if f["name"] not in staged_add_names + ] + return filtered + + def _apply_staged_changes( + self, + source_schema: Dict[str, Any], + changes: SchemaPatchChanges, + ) -> Dict[str, Any]: + """Build a working copy of source_schema reflecting staged changes. + + This ensures subsequent prompts show the current state of the schema + after renames, removes, and adds have been queued. + """ + import copy + + working = copy.deepcopy(source_schema) + + # Apply removes + removed_names = set(changes.remove_fields) + working["fields"] = [ + f for f in working["fields"] if f["name"] not in removed_names + ] + + # Apply renames. Apply each rename sequentially so that chained + # renames (A→B, B→C) are handled correctly even if they weren't + # collapsed at input time. 
+ rename_map = {r.old_name: r.new_name for r in changes.rename_fields} + for r in changes.rename_fields: + for field in working["fields"]: + if field["name"] == r.old_name: + field["name"] = r.new_name + break + + # Apply updates (reflect attribute changes in working schema). + # Resolve update names through the rename map so that updates staged + # before a rename (referencing the old name) still match. + update_map = {} + for u in changes.update_fields: + resolved = rename_map.get(u.name, u.name) + update_map[resolved] = u + for field in working["fields"]: + if field["name"] in update_map: + upd = update_map[field["name"]] + if upd.attrs: + field.setdefault("attrs", {}).update(upd.attrs) + if upd.type: + field["type"] = upd.type + + # Apply adds + for added in changes.add_fields: + working["fields"].append(added) + + # Apply index-level changes (name, prefix) so preview reflects them + if changes.index: + for key, value in changes.index.items(): + working["index"][key] = value + + return working + + def _build_patch( + self, + source_schema: Dict[str, Any], + existing_changes: Optional[SchemaPatchChanges] = None, + ) -> SchemaPatch: + if existing_changes: + changes = existing_changes + else: + changes = SchemaPatchChanges() + done = False + while not done: + # Refresh working schema to reflect staged changes + working_schema = self._apply_staged_changes(source_schema, changes) + + print("\nChoose an action:") + print("1. Add field (text, tag, numeric, geo)") + print("2. Update field (sortable, weight, separator, vector config)") + print("3. Remove field") + print("4. Rename field (rename field in all documents)") + print("5. Rename index (change index name)") + print("6. Change prefix (rename all keys)") + print("7. Preview patch (show pending changes as YAML)") + print("8. 
Finish") + action = input("Enter a number: ").strip() + + if action == "1": + field = self._prompt_add_field(working_schema) + if field: + staged_names = {f["name"] for f in changes.add_fields} + if field["name"] in staged_names: + print( + f"Field '{field['name']}' is already staged for addition." + ) + else: + changes.add_fields.append(field) + elif action == "2": + # Filter out staged additions from update candidates + staged_add_names = {f["name"] for f in changes.add_fields} + update_schema = self._filter_staged_adds( + working_schema, staged_add_names + ) + update = self._prompt_update_field(update_schema) + if update: + # Merge with existing update for same field if present + existing = next( + (u for u in changes.update_fields if u.name == update.name), + None, + ) + if existing: + if update.attrs: + existing.attrs = {**(existing.attrs or {}), **update.attrs} + if update.type: + existing.type = update.type + else: + changes.update_fields.append(update) + elif action == "3": + field_name = self._prompt_remove_field(working_schema) + if field_name: + # If removing a staged-add, cancel the add instead of + # appending to remove_fields + staged_add_names = {f["name"] for f in changes.add_fields} + if field_name in staged_add_names: + changes.add_fields = [ + f for f in changes.add_fields if f["name"] != field_name + ] + print(f"Cancelled staged addition of '{field_name}'.") + else: + changes.remove_fields.append(field_name) + # Also remove any queued updates or renames for this field. + # Check both old_name and new_name so that: + # - renames FROM this field are dropped (old_name match) + # - renames TO this field are dropped (new_name match) + # Also drop updates referencing either the field itself or + # any pre-rename name that mapped to it. 
+ rename_aliases = {field_name} + for r in changes.rename_fields: + if r.new_name == field_name: + rename_aliases.add(r.old_name) + if r.old_name == field_name: + rename_aliases.add(r.new_name) + changes.update_fields = [ + u + for u in changes.update_fields + if u.name not in rename_aliases + ] + changes.rename_fields = [ + r + for r in changes.rename_fields + if r.old_name != field_name and r.new_name != field_name + ] + elif action == "4": + # Filter out staged additions from rename candidates + staged_add_names = {f["name"] for f in changes.add_fields} + rename_schema = self._filter_staged_adds( + working_schema, staged_add_names + ) + field_rename = self._prompt_rename_field(rename_schema) + if field_rename: + # Check rename target doesn't collide with staged additions + if field_rename.new_name in staged_add_names: + print( + f"Cannot rename to '{field_rename.new_name}': " + "a field with that name is already staged for addition." + ) + else: + # Collapse chained renames: if there's an existing + # rename X→Y and the user now renames Y→Z, collapse + # into a single X→Z rename. + collapsed = False + for ridx, prev_rename in enumerate(changes.rename_fields): + if prev_rename.new_name == field_rename.old_name: + changes.rename_fields[ridx] = FieldRename( + old_name=prev_rename.old_name, + new_name=field_rename.new_name, + ) + collapsed = True + break + if not collapsed: + changes.rename_fields.append(field_rename) + elif action == "5": + new_name = self._prompt_rename_index(working_schema) + if new_name: + changes.index["name"] = new_name + elif action == "6": + new_prefix = self._prompt_change_prefix(working_schema) + if new_prefix: + changes.index["prefix"] = new_prefix + elif action == "7": + print( + yaml.safe_dump( + { + "version": 1, + "changes": changes.model_dump(exclude_none=True), + }, + sort_keys=False, + ) + ) + elif action == "8": + done = True + else: + print("Invalid action. 
Please choose 1-8.") + + return SchemaPatch(version=1, changes=changes) + + def _prompt_add_field( + self, source_schema: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + field_name = input("Field name: ").strip() + existing_names = {field["name"] for field in source_schema["fields"]} + if not field_name: + print("Field name is required.") + return None + if field_name in existing_names: + print(f"Field '{field_name}' already exists in the source schema.") + return None + + field_type = self._prompt_from_choices( + "Field type", + SUPPORTED_FIELD_TYPES, + block_message="Vector fields cannot be added (requires embedding all documents). Only text, tag, numeric, and geo are supported.", + ) + if not field_type: + return None + + field: Dict[str, Any] = {"name": field_name, "type": field_type} + storage_type = source_schema["index"]["storage_type"] + if storage_type == "json": + print(" JSON path: location in document where this field is stored") + path = ( + input(f"JSON path [default $.{field_name}]: ").strip() + or f"$.{field_name}" + ) + field["path"] = path + + attrs = self._prompt_common_attrs(field_type) + if attrs: + field["attrs"] = attrs + return field + + def _prompt_update_field( + self, source_schema: Dict[str, Any] + ) -> Optional[FieldUpdate]: + fields = [ + field + for field in source_schema["fields"] + if field["type"] in UPDATABLE_FIELD_TYPES + ] + if not fields: + print("No updatable fields are available.") + return None + + print("Updatable fields:") + for position, field in enumerate(fields, start=1): + print(f"{position}. 
{field['name']} ({field['type']})")

        choice = input("Select a field to update by number or name: ").strip()
        # Accept either the 1-based list position or the exact field name.
        selected: Optional[Dict[str, Any]] = None
        for position, field in enumerate(fields, start=1):
            if choice == str(position) or choice == field["name"]:
                selected = field
                break
        if not selected:
            print("Invalid field selection.")
            return None

        # Vector fields carry their own attribute set (algorithm, datatype,
        # ...); every other type shares the common attribute prompts.
        if selected["type"] == "vector":
            attrs = self._prompt_vector_attrs(selected)
        else:
            attrs = self._prompt_common_attrs(
                selected["type"],
                allow_blank=True,
                existing_attrs=selected.get("attrs"),
            )
        if not attrs:
            print("No changes collected.")
            return None
        return FieldUpdate(name=selected["name"], attrs=attrs)

    def _prompt_remove_field(self, source_schema: Dict[str, Any]) -> Optional[str]:
        """Interactively choose a field to remove from the schema.

        Removing a vector field requires the user to type 'yes' to
        confirm. Returns the selected field name, or None when the
        selection is invalid or the confirmation is declined.
        """
        removable_fields = [field["name"] for field in source_schema["fields"]]
        if not removable_fields:
            print("No fields available to remove.")
            return None

        print("Removable fields:")
        for position, field in enumerate(source_schema["fields"], start=1):
            field_type = field["type"]
            warning = " [WARNING: vector field]" if field_type == "vector" else ""
            print(f"{position}. 
{field['name']} ({field_type}){warning}")

        choice = input("Select a field to remove by number or name: ").strip()
        # An exact name match wins; otherwise the input is interpreted as a
        # 1-based list position.
        selected_name: Optional[str] = None
        if choice in removable_fields:
            selected_name = choice
        elif choice.isdigit():
            offset = int(choice) - 1
            if 0 <= offset < len(removable_fields):
                selected_name = removable_fields[offset]

        if not selected_name:
            print("Invalid field selection.")
            return None

        # Check if it's a vector field and require confirmation
        selected_field = next(
            (f for f in source_schema["fields"] if f["name"] == selected_name), None
        )
        if selected_field and selected_field["type"] == "vector":
            print(
                f"\n WARNING: Removing vector field '{selected_name}' will:\n"
                " - Remove it from the search index\n"
                " - Leave vector data in documents (wasted storage)\n"
                " - Require re-embedding if you want to restore it later"
            )
            confirm = input("Type 'yes' to confirm removal: ").strip().lower()
            if confirm != "yes":
                print("Cancelled.")
                return None

        return selected_name

    def _prompt_rename_field(
        self, source_schema: Dict[str, Any]
    ) -> Optional[FieldRename]:
        """Prompt user to rename a field in all documents.

        Returns a FieldRename(old_name, new_name) or None when the
        selection/new name is invalid.
        """
        fields = source_schema["fields"]
        if not fields:
            print("No fields available to rename.")
            return None

        print("Fields available for renaming:")
        for position, field in enumerate(fields, start=1):
            print(f"{position}. {field['name']} ({field['type']})")

        choice = input("Select a field to rename by number or name: ").strip()
        # Accept either the 1-based list position or the exact field name.
        selected: Optional[Dict[str, Any]] = None
        for position, field in enumerate(fields, start=1):
            if choice == str(position) or choice == field["name"]:
                selected = field
                break
        if not selected:
            print("Invalid field selection.")
            return None

        old_name = selected["name"]
        print(f"Renaming field '{old_name}'")
        print(
            " Warning: This will modify all documents to rename the field. "
            "This is an expensive operation for large datasets."
+ ) + new_name = input("New field name: ").strip() + if not new_name: + print("New field name is required.") + return None + if new_name == old_name: + print("New name is the same as the old name.") + return None + + existing_names = {f["name"] for f in fields} + if new_name in existing_names: + print(f"Field '{new_name}' already exists.") + return None + + return FieldRename(old_name=old_name, new_name=new_name) + + def _prompt_rename_index(self, source_schema: Dict[str, Any]) -> Optional[str]: + """Prompt user to rename the index.""" + current_name = source_schema["index"]["name"] + print(f"Current index name: {current_name}") + print( + " Note: This only changes the index name. " + "Documents and keys are unchanged." + ) + new_name = input("New index name: ").strip() + if not new_name: + print("New index name is required.") + return None + if new_name == current_name: + print("New name is the same as the current name.") + return None + return new_name + + def _prompt_change_prefix(self, source_schema: Dict[str, Any]) -> Optional[str]: + """Prompt user to change the key prefix.""" + current_prefix = source_schema["index"]["prefix"] + print(f"Current prefix: {current_prefix}") + print( + " Warning: This will RENAME all keys from the old prefix to the new prefix. " + "This is an expensive operation for large datasets." 
+ ) + new_prefix = input("New prefix: ").strip() + if not new_prefix: + print("New prefix is required.") + return None + if new_prefix == current_prefix: + print("New prefix is the same as the current prefix.") + return None + return new_prefix + + def _prompt_common_attrs( + self, + field_type: str, + allow_blank: bool = False, + existing_attrs: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + attrs: Dict[str, Any] = {} + + # Sortable - available for all non-vector types + print(" Sortable: enables sorting and aggregation on this field") + sortable = self._prompt_bool("Sortable", allow_blank=allow_blank) + if sortable is not None: + attrs["sortable"] = sortable + + # Index missing - available for all types (requires Redis Search 2.10+) + print( + " Index missing: enables ismissing() queries for documents without this field" + ) + index_missing = self._prompt_bool("Index missing", allow_blank=allow_blank) + if index_missing is not None: + attrs["index_missing"] = index_missing + + # Index empty - index documents where field value is empty string + print( + " Index empty: enables isempty() queries for documents with empty string values" + ) + index_empty = self._prompt_bool("Index empty", allow_blank=allow_blank) + if index_empty is not None: + attrs["index_empty"] = index_empty + + # Track whether the field was already sortable so that type-specific + # prompt helpers (text UNF, numeric UNF) can offer dependent prompts + # even when the user leaves sortable blank during an update. + self._existing_sortable = (existing_attrs or {}).get("sortable", False) + + # Type-specific attributes + if field_type == "text": + self._prompt_text_attrs(attrs, allow_blank) + elif field_type == "tag": + self._prompt_tag_attrs(attrs, allow_blank) + elif field_type == "numeric": + self._prompt_numeric_attrs(attrs, allow_blank, sortable) + + # No index - only meaningful with sortable. 
        # When updating (allow_blank), also check the existing field's sortable
        # state so we offer dependent prompts even if the user left sortable blank.
        # But if sortable was explicitly set to False, skip dependent prompts.
        _existing_sortable = self._existing_sortable
        if sortable or (
            sortable is None
            and allow_blank
            and (_existing_sortable or attrs.get("sortable"))
        ):
            print(" No index: store field for sorting only, not searchable")
            no_index = self._prompt_bool("No index", allow_blank=allow_blank)
            if no_index is not None:
                attrs["no_index"] = no_index

        # When explicitly disabling sortable on a previously-sortable field,
        # clear sortable-dependent attributes that are no longer meaningful.
        # UNF and no_index are only used with sortable; leaving them set would
        # be confusing even though Redis technically allows it.
        if sortable is False and _existing_sortable:
            if "unf" not in attrs:
                attrs["unf"] = False
            if "no_index" not in attrs:
                attrs["no_index"] = False

        return attrs

    def _prompt_text_attrs(self, attrs: Dict[str, Any], allow_blank: bool) -> None:
        """Prompt for text field specific attributes.

        Mutates ``attrs`` in place; blank/skipped answers leave it untouched.
        """
        # No stem
        print(
            " Disable stemming: prevents word variations (running/runs) from matching"
        )
        no_stem = self._prompt_bool("Disable stemming", allow_blank=allow_blank)
        if no_stem is not None:
            attrs["no_stem"] = no_stem

        # Weight — must be a positive float; invalid input is reported and ignored.
        print(" Weight: relevance multiplier for full-text search (default: 1.0)")
        weight_input = input("Weight [leave blank for default]: ").strip()
        if weight_input:
            try:
                weight = float(weight_input)
                if weight > 0:
                    attrs["weight"] = weight
                else:
                    print("Weight must be positive.")
            except ValueError:
                print("Invalid weight value.")

        # Phonetic matcher
        print(
            " Phonetic matcher: enables phonetic matching (e.g., 'dm:en' for Metaphone)"
        )
        phonetic = input("Phonetic matcher [leave blank for none]: ").strip()
        if phonetic:
            attrs["phonetic_matcher"] = phonetic

        # UNF 
(only if sortable – skip if sortable was explicitly set to False)
        if attrs.get("sortable") or (
            attrs.get("sortable") is not False and self._existing_sortable
        ):
            print(" UNF: preserve original form (no lowercasing) for sorting")
            unf = self._prompt_bool("UNF (un-normalized form)", allow_blank=allow_blank)
            if unf is not None:
                attrs["unf"] = unf

    def _prompt_tag_attrs(self, attrs: Dict[str, Any], allow_blank: bool) -> None:
        """Prompt for tag field specific attributes.

        Mutates ``attrs`` in place; blank/skipped answers leave it untouched.
        """
        # Separator
        print(" Separator: character that splits multiple values (default: comma)")
        separator = input("Separator [leave blank to keep existing/default]: ").strip()
        if separator:
            attrs["separator"] = separator

        # Case sensitive
        print(" Case sensitive: match tags with exact case (default: false)")
        case_sensitive = self._prompt_bool("Case sensitive", allow_blank=allow_blank)
        if case_sensitive is not None:
            attrs["case_sensitive"] = case_sensitive

    def _prompt_numeric_attrs(
        self, attrs: Dict[str, Any], allow_blank: bool, sortable: Optional[bool]
    ) -> None:
        """Prompt for numeric field specific attributes.

        Mutates ``attrs`` in place; blank/skipped answers leave it untouched.
        """
        # UNF (only if sortable – skip if sortable was explicitly set to False)
        if sortable or (
            sortable is not False and (attrs.get("sortable") or self._existing_sortable)
        ):
            print(" UNF: preserve exact numeric representation for sorting")
            unf = self._prompt_bool("UNF (un-normalized form)", allow_blank=allow_blank)
            if unf is not None:
                attrs["unf"] = unf

    def _prompt_vector_attrs(self, field: Dict[str, Any]) -> Dict[str, Any]:
        """Prompt for updates to a vector field's attributes.

        Shows the current configuration, then lets the user change the
        algorithm, datatype, distance metric, and algorithm-specific
        parameters. ``dims`` is displayed but cannot be changed. Blank
        input keeps the current value. Returns a dict containing only the
        attributes the user entered.
        """
        attrs: Dict[str, Any] = {}
        current = field.get("attrs", {})
        field_name = field["name"]

        print(f"Current vector config for '{field_name}':")
        # Stored algorithm may be lower-case; compare upper-cased throughout.
        current_algo = current.get("algorithm", "hnsw").upper()
        print(f" algorithm: {current_algo}")
        print(f" datatype: {current.get('datatype', 'float32')}")
        print(f" distance_metric: {current.get('distance_metric', 'cosine')}")
        print(f" dims: {current.get('dims')} 
(cannot be changed)") + if current_algo == "HNSW": + print(f" m: {current.get('m', 16)}") + print(f" ef_construction: {current.get('ef_construction', 200)}") + + print("\nLeave blank to keep current value.") + + # Algorithm + print( + " Algorithm: vector search method (FLAT=brute force, HNSW=graph, SVS-VAMANA=compressed graph)" + ) + algo = ( + input(f"Algorithm [current: {current_algo}]: ") + .strip() + .upper() + .replace("_", "-") # Normalize SVS_VAMANA to SVS-VAMANA + ) + if algo and algo in ("FLAT", "HNSW", "SVS-VAMANA") and algo != current_algo: + attrs["algorithm"] = algo + + # Datatype (quantization) - show algorithm-specific options + effective_algo = attrs.get("algorithm", current_algo) + valid_datatypes: tuple[str, ...] + if effective_algo == "SVS-VAMANA": + # SVS-VAMANA only supports float16, float32 + print( + " Datatype for SVS-VAMANA: float16, float32 " + "(float16 reduces memory by ~50%)" + ) + valid_datatypes = ("float16", "float32") + else: + # FLAT/HNSW support: float16, float32, bfloat16, float64, int8, uint8 + print( + " Datatype: float16, float32, bfloat16, float64, int8, uint8\n" + " (float16 reduces memory ~50%, int8/uint8 reduce ~75%)" + ) + valid_datatypes = ( + "float16", + "float32", + "bfloat16", + "float64", + "int8", + "uint8", + ) + current_datatype = current.get("datatype", "float32") + # If switching to SVS-VAMANA and current datatype is incompatible, + # require the user to pick a valid one. + force_datatype = ( + effective_algo == "SVS-VAMANA" and current_datatype not in valid_datatypes + ) + if force_datatype: + print( + f" Current datatype '{current_datatype}' is not compatible with SVS-VAMANA. " + "You must select a valid datatype." 
+ ) + datatype = input(f"Datatype [current: {current_datatype}]: ").strip().lower() + if datatype and datatype in valid_datatypes: + attrs["datatype"] = datatype + elif force_datatype: + # Default to float32 when user skips but current dtype is incompatible + print(" Defaulting to float32 for SVS-VAMANA compatibility.") + attrs["datatype"] = "float32" + + # Distance metric + print(" Distance metric: how similarity is measured (cosine, l2, ip)") + metric = ( + input( + f"Distance metric [current: {current.get('distance_metric', 'cosine')}]: " + ) + .strip() + .lower() + ) + if metric and metric in ("cosine", "l2", "ip"): + attrs["distance_metric"] = metric + + # Algorithm-specific params (effective_algo already computed above) + if effective_algo == "HNSW": + print( + " M: number of connections per node (higher=better recall, more memory)" + ) + m_input = input(f"M [current: {current.get('m', 16)}]: ").strip() + if m_input and m_input.isdigit(): + attrs["m"] = int(m_input) + + print( + " EF_CONSTRUCTION: build-time search depth (higher=better recall, slower build)" + ) + ef_input = input( + f"EF_CONSTRUCTION [current: {current.get('ef_construction', 200)}]: " + ).strip() + if ef_input and ef_input.isdigit(): + attrs["ef_construction"] = int(ef_input) + + print( + " EF_RUNTIME: query-time search depth (higher=better recall, slower queries)" + ) + ef_runtime_input = input( + f"EF_RUNTIME [current: {current.get('ef_runtime', 10)}]: " + ).strip() + if ef_runtime_input and ef_runtime_input.isdigit(): + ef_runtime_val = int(ef_runtime_input) + if ef_runtime_val > 0: + attrs["ef_runtime"] = ef_runtime_val + + print( + " EPSILON: relative factor for range queries (0.0-1.0, lower=more accurate)" + ) + epsilon_input = input( + f"EPSILON [current: {current.get('epsilon', 0.01)}]: " + ).strip() + if epsilon_input: + try: + epsilon_val = float(epsilon_input) + if 0.0 <= epsilon_val <= 1.0: + attrs["epsilon"] = epsilon_val + else: + print(" Epsilon must be between 0.0 and 1.0, 
ignoring.") + except ValueError: + print(" Invalid epsilon value, ignoring.") + + elif effective_algo == "SVS-VAMANA": + print( + " GRAPH_MAX_DEGREE: max edges per node (higher=better recall, more memory)" + ) + gmd_input = input( + f"GRAPH_MAX_DEGREE [current: {current.get('graph_max_degree', 40)}]: " + ).strip() + if gmd_input and gmd_input.isdigit(): + attrs["graph_max_degree"] = int(gmd_input) + + print(" COMPRESSION: optional vector compression for memory savings") + print(" Options: LVQ4, LVQ8, LVQ4x4, LVQ4x8, LeanVec4x8, LeanVec8x8") + print( + " Note: LVQ/LeanVec optimizations require Intel hardware with AVX-512" + ) + compression_input = ( + input("COMPRESSION [leave blank for none]: ").strip().upper() + ) + # Map input to correct enum case (CompressionType expects exact case) + compression_map = { + "LVQ4": "LVQ4", + "LVQ8": "LVQ8", + "LVQ4X4": "LVQ4x4", + "LVQ4X8": "LVQ4x8", + "LEANVEC4X8": "LeanVec4x8", + "LEANVEC8X8": "LeanVec8x8", + } + compression = compression_map.get(compression_input) + if compression: + attrs["compression"] = compression + + # Prompt for REDUCE if LeanVec compression is selected + if compression.startswith("LeanVec"): + dims = current.get("dims", 0) + recommended = dims // 2 if dims > 0 else None + print( + f" REDUCE: dimensionality reduction for LeanVec (must be < {dims})" + ) + if recommended: + print( + f" Recommended: {recommended} (dims/2 for balanced performance)" + ) + reduce_input = input(f"REDUCE [leave blank to skip]: ").strip() + if reduce_input and reduce_input.isdigit(): + reduce_val = int(reduce_input) + if reduce_val > 0 and reduce_val < dims: + attrs["reduce"] = reduce_val + else: + print( + f" Invalid: reduce must be > 0 and < {dims}, ignoring." 
+ ) + + return attrs + + def _prompt_bool(self, label: str, allow_blank: bool = False) -> Optional[bool]: + suffix = " [y/n]" if not allow_blank else " [y/n/skip]" + while True: + value = input(f"{label}{suffix}: ").strip().lower() + if value in ("y", "yes"): + return True + if value in ("n", "no"): + return False + if allow_blank and value in ("", "skip", "s"): + return None + if not allow_blank and value == "": + return False + hint = "y, n, or skip" if allow_blank else "y or n" + print(f"Please answer {hint}.") + + def _prompt_from_choices( + self, + label: str, + choices: List[str], + *, + block_message: str, + ) -> Optional[str]: + print(f"{label} options: {', '.join(choices)}") + value = input(f"{label}: ").strip().lower() + if value not in choices: + print(block_message) + return None + return value + + def _print_source_schema(self, schema_dict: Dict[str, Any]) -> None: + print("Current schema:") + print(f"- Index name: {schema_dict['index']['name']}") + print(f"- Storage type: {schema_dict['index']['storage_type']}") + for field in schema_dict["fields"]: + path = field.get("path") + suffix = f" path={path}" if path else "" + print(f" - {field['name']} ({field['type']}){suffix}") diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 6c435c547..cf5a38193 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -304,6 +304,19 @@ def parse_vector_attrs(attrs): # Default to float32 if missing normalized["datatype"] = "float32" + # Handle HNSW-specific parameters + if "m" in vector_attrs: + try: + normalized["m"] = int(vector_attrs["m"]) + except (ValueError, TypeError): + pass + + if "ef_construction" in vector_attrs: + try: + normalized["ef_construction"] = int(vector_attrs["ef_construction"]) + except (ValueError, TypeError): + pass + # Handle SVS-VAMANA specific parameters # Compression - Redis uses different internal names, so we need to map them if "compression" in vector_attrs: diff --git 
a/scripts/test_crash_resume_e2e.py b/scripts/test_crash_resume_e2e.py new file mode 100644 index 000000000..8c84f184d --- /dev/null +++ b/scripts/test_crash_resume_e2e.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python3 +"""E2E crash-resume test: 10,000 docs, float32→float16, 4 simulated crashes. + +Strategy: + - Single-worker with backup_dir for deterministic checkpoint tracking + - Monkey-patch pipeline_write_vectors to raise after N batches + - 10,000 docs / batch_size=500 = 20 batches total + - Crash at batches: 5 (25%), 10 (50%), 15 (75%), 18 (90%) + - Each resume verifies partial progress, then continues + - Final resume completes and verifies all 10,000 docs are float16 +""" +import json +import os +import shutil +import sys +import tempfile +import time + +import numpy as np +import redis +import yaml + +REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379") +INDEX_NAME = "e2e_crash_test_idx" +PREFIX = "e2e_crash:" +NUM_DOCS = 10_000 +DIMS = 128 +BATCH_SIZE = 500 +TOTAL_BATCHES = NUM_DOCS // BATCH_SIZE # 20 + +# Crash after these many TOTAL batches have been quantized +CRASH_AFTER_BATCHES = [3, 7, 11, 16, 19] + + +def log(msg: str): + print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True) + + +def cleanup_index(r): + try: + r.execute_command("FT.DROPINDEX", INDEX_NAME) + except Exception: + pass + keys = list(r.scan_iter(match=f"{PREFIX}*", count=1000)) + while keys: + r.delete(*keys[:500]) + keys = keys[500:] + if not keys: + keys = list(r.scan_iter(match=f"{PREFIX}*", count=1000)) + + +def create_index_and_load(r): + log(f"Creating index '{INDEX_NAME}' with {NUM_DOCS:,} docs ({DIMS}-dim float32)...") + r.execute_command( + "FT.CREATE", INDEX_NAME, "ON", "HASH", "PREFIX", "1", PREFIX, + "SCHEMA", "title", "TEXT", + "embedding", "VECTOR", "FLAT", "6", + "TYPE", "FLOAT32", "DIM", str(DIMS), "DISTANCE_METRIC", "COSINE", + ) + pipe = r.pipeline(transaction=False) + for i in range(NUM_DOCS): + vec = np.random.randn(DIMS).astype(np.float32).tobytes() + 
pipe.hset(f"{PREFIX}{i}", mapping={"title": f"Doc {i}", "embedding": vec}) + if (i + 1) % 500 == 0: + pipe.execute() + pipe = r.pipeline(transaction=False) + pipe.execute() + # Wait for indexing + for _ in range(60): + info = r.execute_command("FT.INFO", INDEX_NAME) + info_dict = dict(zip(info[::2], info[1::2])) + num_indexed = int(info_dict.get(b"num_docs", info_dict.get("num_docs", 0))) + if num_indexed >= NUM_DOCS: + break + time.sleep(0.5) + log(f"Index ready: {num_indexed:,} docs indexed") + return num_indexed + + +def verify_vectors(r, expected_bytes, label=""): + """Count docs by vector size. Returns (correct_count, wrong_count).""" + pipe = r.pipeline(transaction=False) + for i in range(NUM_DOCS): + pipe.hget(f"{PREFIX}{i}", "embedding") + results = pipe.execute() + correct = sum(1 for d in results if d and len(d) == expected_bytes) + wrong = NUM_DOCS - correct + if label: + log(f" {label}: {correct:,} correct ({expected_bytes}B), {wrong:,} other") + return correct, wrong + + +def count_quantized_docs(r, float16_bytes=256, float32_bytes=512): + """Count how many docs are already float16 vs float32.""" + pipe = r.pipeline(transaction=False) + for i in range(NUM_DOCS): + pipe.hget(f"{PREFIX}{i}", "embedding") + results = pipe.execute() + f16 = sum(1 for d in results if d and len(d) == float16_bytes) + f32 = sum(1 for d in results if d and len(d) == float32_bytes) + return f16, f32 + + +def make_plan(backup_dir): + from redisvl.migration.planner import MigrationPlanner + schema_patch = { + "version": 1, + "changes": { + "update_fields": [{ + "name": "embedding", + "attrs": {"algorithm": "flat", "datatype": "float16", "distance_metric": "cosine"}, + }] + }, + } + patch_path = os.path.join(backup_dir, "schema_patch.yaml") + with open(patch_path, "w") as f: + yaml.dump(schema_patch, f) + planner = MigrationPlanner() + plan = planner.create_plan(index_name=INDEX_NAME, redis_url=REDIS_URL, schema_patch_path=patch_path) + # Save plan for resume + plan_path = 
os.path.join(backup_dir, "plan.yaml") + with open(plan_path, "w") as f: + yaml.dump(plan.model_dump(), f, sort_keys=False) + return plan, plan_path + + +class SimulatedCrash(Exception): + """Raised to simulate a process crash during quantization.""" + pass + + + +def run_attempt(plan, backup_dir, crash_after=None, attempt_num=0): + """Run apply(). If crash_after is set, crash after that many total quantize batches. + + Uses direct monkey-patching of the module attribute to ensure the + executor's local `from ... import` picks up the patched version. + """ + from redisvl.migration.executor import MigrationExecutor + import redisvl.migration.quantize as quantize_mod + + original_write = quantize_mod.pipeline_write_vectors + executor = MigrationExecutor() + events = [] + + def progress_cb(step, detail=None): + msg = f" [{step}] {detail}" if detail else f" [{step}]" + events.append(msg) + log(msg) + + if crash_after is not None: + # Read backup to see how many batches already done + from redisvl.migration.backup import VectorBackup + safe = INDEX_NAME.replace("/", "_").replace("\\", "_").replace(":", "_") + bp = str(os.path.join(backup_dir, f"migration_backup_{safe}")) + existing = VectorBackup.load(bp) + already_done = existing.header.quantize_completed_batches if existing else 0 + new_batches_allowed = crash_after - already_done + call_counter = [0] + + log(f" [attempt {attempt_num}] Crash after {crash_after} total batches " + f"({already_done} already done, {new_batches_allowed} new allowed)") + + def crashing_write(client, converted): + call_counter[0] += 1 + if call_counter[0] > new_batches_allowed: + raise SimulatedCrash( + f"💥 Simulated crash at write call {call_counter[0]} " + f"(allowed {new_batches_allowed})!" 
+ ) + return original_write(client, converted) + + # Monkey-patch at module level + quantize_mod.pipeline_write_vectors = crashing_write + try: + report = executor.apply( + plan, redis_url=REDIS_URL, progress_callback=progress_cb, + backup_dir=backup_dir, batch_size=BATCH_SIZE, + num_workers=1, keep_backup=True, + ) + finally: + quantize_mod.pipeline_write_vectors = original_write + log(f" [attempt {attempt_num}] Write calls made: {call_counter[0]}") + return report, events + else: + log(f" [attempt {attempt_num}] Final run — no crash limit") + report = executor.apply( + plan, redis_url=REDIS_URL, progress_callback=progress_cb, + backup_dir=backup_dir, batch_size=BATCH_SIZE, + num_workers=1, keep_backup=True, + ) + return report, events + + +def inspect_backup(backup_dir): + """Read backup header and report state.""" + from redisvl.migration.backup import VectorBackup + safe = INDEX_NAME.replace("/", "_").replace("\\", "_").replace(":", "_") + bp = str(os.path.join(backup_dir, f"migration_backup_{safe}")) + backup = VectorBackup.load(bp) + if backup: + h = backup.header + log(f" Backup: phase={h.phase}, dump_batches={h.dump_completed_batches}, " + f"quantize_batches={h.quantize_completed_batches}") + return h + else: + log(" Backup: not found") + return None + + +def main(): + log("=" * 70) + log(f"CRASH-RESUME E2E: {NUM_DOCS:,} docs, {DIMS}d, float32→float16") + log(f" Batch size: {BATCH_SIZE}, Total batches: {TOTAL_BATCHES}") + log(f" Crash points: {CRASH_AFTER_BATCHES} batches") + log("=" * 70) + + r = redis.from_url(REDIS_URL) + cleanup_index(r) + + num_docs = create_index_and_load(r) + assert num_docs >= NUM_DOCS, f"Only {num_docs} indexed!" 
+ + correct, _ = verify_vectors(r, 512, "Pre-migration float32") + assert correct == NUM_DOCS + + backup_dir = tempfile.mkdtemp(prefix="crash_resume_backup_") + log(f"\nBackup dir: {backup_dir}") + + plan, plan_path = make_plan(backup_dir) + log(f"Plan: mode={plan.mode}, changes detected: " + f"{len(plan.requested_changes.get('changes', {}).get('update_fields', []))}") + + try: + # ── CRASH 1-4: Simulate crashes during quantization ── + for crash_num, crash_at in enumerate(CRASH_AFTER_BATCHES): + log(f"\n{'─'*60}") + log(f"CRASH {crash_num + 1}/{len(CRASH_AFTER_BATCHES)}: " + f"Crashing after batch {crash_at}/{TOTAL_BATCHES} " + f"({crash_at * BATCH_SIZE:,} docs)") + log(f"{'─'*60}") + + report, events = run_attempt( + plan, backup_dir, crash_after=crash_at, attempt_num=crash_num + 1 + ) + log(f" Result: {report.result}") + + # Verify backup state + header = inspect_backup(backup_dir) + assert header is not None, "Backup should exist after crash!" + assert header.quantize_completed_batches == crash_at, ( + f"Expected {crash_at} batches quantized, got {header.quantize_completed_batches}" + ) + assert header.phase in ("active", "ready"), ( + f"Expected phase 'active' or 'ready', got '{header.phase}'" + ) + + # Verify partial progress: some docs should be float16 + f16, f32 = count_quantized_docs(r) + expected_f16 = crash_at * BATCH_SIZE + log(f" Partial progress: {f16:,} float16, {f32:,} float32") + assert f16 == expected_f16, ( + f"Expected {expected_f16} float16 docs, got {f16}" + ) + assert f32 == NUM_DOCS - expected_f16, ( + f"Expected {NUM_DOCS - expected_f16} float32 docs, got {f32}" + ) + log(f" ✅ Crash {crash_num + 1} verified: {f16:,} quantized, " + f"{f32:,} remaining") + + # ── FINAL RESUME: Complete the migration ── + log(f"\n{'─'*60}") + log(f"FINAL RESUME: Completing remaining " + f"{TOTAL_BATCHES - CRASH_AFTER_BATCHES[-1]} batches") + log(f"{'─'*60}") + + report, events = run_attempt(plan, backup_dir, crash_after=None, attempt_num=5) + log(f" Result: 
{report.result}") + assert report.result == "succeeded", f"Final resume failed: {report.result}" + + # Verify ALL docs are float16 + correct, wrong = verify_vectors(r, 256, "Post-migration float16") + assert correct == NUM_DOCS, f"Only {correct}/{NUM_DOCS} docs are float16!" + assert wrong == 0 + + log(f"\n✅ ALL {NUM_DOCS:,} docs verified as float16!") + + # Verify backup is completed + header = inspect_backup(backup_dir) + assert header is not None + assert header.phase == "completed" + assert header.quantize_completed_batches == TOTAL_BATCHES + + log(f"\n{'='*70}") + log("RESULTS") + log(f"{'='*70}") + log(f" {NUM_DOCS:,} docs migrated float32→float16") + log(f" Crashes simulated: {len(CRASH_AFTER_BATCHES)}") + for i, cb in enumerate(CRASH_AFTER_BATCHES): + log(f" Crash {i+1}: after batch {cb}/{TOTAL_BATCHES} " + f"({cb*BATCH_SIZE:,}/{NUM_DOCS:,} docs)") + log(f" Final resume completed remaining {TOTAL_BATCHES - CRASH_AFTER_BATCHES[-1]} batches") + log(f" All {NUM_DOCS:,} vectors verified ✅") + log(f"{'='*70}") + + finally: + cleanup_index(r) + shutil.rmtree(backup_dir, ignore_errors=True) + r.close() + return 0 + + +if __name__ == "__main__": + sys.exit(main()) + diff --git a/scripts/test_migration_e2e.py b/scripts/test_migration_e2e.py new file mode 100644 index 000000000..e78c82ea6 --- /dev/null +++ b/scripts/test_migration_e2e.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +"""End-to-end migration benchmark: realistic KM index, HNSW float32 → FLAT float16. + +Mirrors a real production knowledge-management index with 16 fields: + tags, text, numeric, and a high-dimensional HNSW vector. 
+ +Usage: + python scripts/test_migration_e2e.py # defaults + NUM_DOCS=50000 python scripts/test_migration_e2e.py # override doc count +""" +import glob +import os +import random +import shutil +import string +import struct +import sys +import tempfile +import time +import uuid + +import numpy as np +import redis +import yaml + +REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379") +INDEX_NAME = "KM_benchmark_idx" +PREFIX = "KM:benchmark:" +NUM_DOCS = int(os.environ.get("NUM_DOCS", 10_000)) +DIMS = int(os.environ.get("DIMS", 1536)) +NUM_WORKERS = int(os.environ.get("NUM_WORKERS", 4)) +BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 500)) + + +def log(msg: str): + print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True) + + +def cleanup_index(r): + try: + r.execute_command("FT.DROPINDEX", INDEX_NAME) + except Exception: + pass + # Batched delete for large key counts + deleted = 0 + while True: + keys = list(r.scan_iter(match=f"{PREFIX}*", count=5000)) + if not keys: + break + pipe = r.pipeline(transaction=False) + for k in keys: + pipe.delete(k) + pipe.execute() + deleted += len(keys) + if deleted % 50000 == 0: + log(f" cleanup: {deleted:,} keys deleted...") + if deleted: + log(f" cleanup: {deleted:,} keys deleted total") + + +SAMPLE_NAMES = ["Q4 Earnings Report", "Investment Memo", "Risk Assessment", + "Portfolio Summary", "Market Analysis", "Due Diligence", "Credit Review", + "Bond Prospectus", "Fund Factsheet", "Regulatory Filing"] +SAMPLE_AUTHORS = ["alice@corp.com", "bob@corp.com", "carol@corp.com", "dave@corp.com"] + + +def _random_text(n=200): + words = ["the", "fund", "portfolio", "return", "risk", "asset", "bond", + "equity", "market", "yield", "rate", "credit", "cash", "flow", + "price", "value", "growth", "income", "dividend", "capital"] + return " ".join(random.choice(words) for _ in range(n)) + + +def create_index_and_load(r): + log(f"Creating {NUM_DOCS:,} docs ({DIMS}-dim float32, 16 fields)...") + log(f" Step 1: Load data into Redis (no 
index yet for max speed)...") + + load_start = time.perf_counter() + + # Pre-generate reusable data to avoid per-doc overhead + doc_ids = [str(uuid.uuid4()) for _ in range(max(1, NUM_DOCS // 50))] + file_ids = [str(uuid.uuid4()) for _ in range(max(1, NUM_DOCS // 10))] + text_pool = [_random_text(200) for _ in range(100)] + desc_pool = [_random_text(50) for _ in range(50)] + cusip_pool = [ + f"{random.randint(0,999999):06d}{random.choice(string.ascii_uppercase)}" + f"{random.choice(string.ascii_uppercase)}{random.randint(0,9)}" + for _ in range(200) + ] + now_base = int(time.time()) + + # Stream vectors in small batches — never hold more than LOAD_BATCH in memory + LOAD_BATCH = 1000 + insert_start = time.perf_counter() + pipe = r.pipeline(transaction=False) + + for batch_start in range(0, NUM_DOCS, LOAD_BATCH): + batch_end = min(batch_start + LOAD_BATCH, NUM_DOCS) + batch_size = batch_end - batch_start + vecs = np.random.randn(batch_size, DIMS).astype(np.float32) + + for j in range(batch_size): + i = batch_start + j + mapping = { + "doc_base_id": doc_ids[i % len(doc_ids)], + "file_id": file_ids[i % len(file_ids)], + "page_text": text_pool[i % len(text_pool)], + "chunk_number": i % 50, + "start_page": (i % 50) + 1, + "end_page": (i % 50) + 2, + "created_by": SAMPLE_AUTHORS[i % len(SAMPLE_AUTHORS)], + "file_name": f"{SAMPLE_NAMES[i % len(SAMPLE_NAMES)]}_{i}.pdf", + "created_time": now_base - (i * 31), + "last_updated_by": SAMPLE_AUTHORS[(i + 1) % len(SAMPLE_AUTHORS)], + "last_updated_time": now_base - (i * 31) + 3600, + "embedding": vecs[j].tobytes(), + } + if i % 3 == 0: + mapping["CUSIP"] = cusip_pool[i % len(cusip_pool)] + mapping["description"] = desc_pool[i % len(desc_pool)] + mapping["name"] = SAMPLE_NAMES[i % len(SAMPLE_NAMES)] + mapping["price"] = round(10.0 + (i % 49000) * 0.01, 2) + pipe.hset(f"{PREFIX}{i}", mapping=mapping) + + pipe.execute() + pipe = r.pipeline(transaction=False) + + if batch_end % 10_000 == 0: + elapsed_so_far = time.perf_counter() - 
insert_start + rate = batch_end / elapsed_so_far + eta = (NUM_DOCS - batch_end) / rate if rate > 0 else 0 + log(f" inserted {batch_end:,}/{NUM_DOCS:,} docs " + f"({rate:,.0f}/s, ETA {eta:.0f}s)...") + pipe.execute() + load_elapsed = time.perf_counter() - insert_start + log(f" Data inserted in {load_elapsed:.1f}s " + f"({NUM_DOCS/load_elapsed:,.0f} docs/s)") + + # Step 2: Create HNSW index on existing data (background indexing) + log(f" Step 2: Creating HNSW index (background indexing {NUM_DOCS:,} docs)...") + r.execute_command( + "FT.CREATE", INDEX_NAME, "ON", "HASH", "PREFIX", "1", PREFIX, + "SCHEMA", + "doc_base_id", "TAG", "SEPARATOR", ",", + "file_id", "TAG", "SEPARATOR", ",", + "page_text", "TEXT", "WEIGHT", "1", + "chunk_number", "NUMERIC", + "start_page", "NUMERIC", + "end_page", "NUMERIC", + "created_by", "TAG", "SEPARATOR", ",", + "file_name", "TEXT", "WEIGHT", "1", + "created_time", "NUMERIC", + "last_updated_by", "TEXT", "WEIGHT", "1", + "last_updated_time", "NUMERIC", + "embedding", "VECTOR", "HNSW", "10", + "TYPE", "FLOAT32", "DIM", str(DIMS), + "DISTANCE_METRIC", "COSINE", "M", "16", "EF_CONSTRUCTION", "200", + "CUSIP", "TAG", "SEPARATOR", ",", "INDEXMISSING", + "description", "TEXT", "WEIGHT", "1", "INDEXMISSING", + "name", "TEXT", "WEIGHT", "1", "INDEXMISSING", + "price", "NUMERIC", "INDEXMISSING", + ) + + # Wait for HNSW indexing + idx_start = time.perf_counter() + for attempt in range(7200): + info = r.execute_command("FT.INFO", INDEX_NAME) + info_dict = dict(zip(info[::2], info[1::2])) + num_indexed = int(info_dict.get(b"num_docs", info_dict.get("num_docs", 0))) + pct = float(info_dict.get(b"percent_indexed", + info_dict.get("percent_indexed", "0"))) + if pct >= 1.0: + break + if attempt % 15 == 0: + elapsed_idx = time.perf_counter() - idx_start + log(f" indexing: {num_indexed:,}/{NUM_DOCS:,} docs " + f"({pct*100:.1f}%, {elapsed_idx:.0f}s elapsed)...") + time.sleep(1) + idx_elapsed = time.perf_counter() - idx_start + log(f" Index ready: 
{num_indexed:,} docs indexed in {idx_elapsed:.1f}s")
    return num_indexed


def verify_vectors(r, expected_dtype, bytes_per_element, sample_size=10000):
    """Spot-check stored embeddings by sampling keys and validating byte length.

    Args:
        r: Redis client.
        expected_dtype: Label used only in log output (e.g. "float32").
        bytes_per_element: Bytes per vector element (4 = float32, 2 = float16).
        sample_size: Upper bound on how many docs to sample.

    Returns:
        Count of sampled docs whose "embedding" field is missing or has an
        unexpected byte length (0 means the sample looks correct).
    """
    expected_bytes = bytes_per_element * DIMS
    check_count = min(NUM_DOCS, sample_size)
    log(f"Verifying {expected_dtype} vectors (sampling {check_count:,}/{NUM_DOCS:,})...")
    errors = 0
    # Sample evenly across the key space
    step = max(1, NUM_DOCS // check_count)
    indices = list(range(0, NUM_DOCS, step))[:check_count]
    # One pipelined round trip for all sampled reads (no transaction needed).
    pipe = r.pipeline(transaction=False)
    for i in indices:
        pipe.hget(f"{PREFIX}{i}", "embedding")
    results = pipe.execute()
    for idx, data in zip(indices, results):
        if data is None:
            errors += 1
        elif len(data) != expected_bytes:
            if errors < 5:  # only log the first few failures to keep output readable
                log(f" ERROR: doc {idx}: {len(data)} bytes, expected {expected_bytes}")
            errors += 1
    if errors == 0:
        log(f" ✅ All {check_count:,} sampled docs correct ({expected_bytes} bytes each)")
    else:
        log(f" ❌ {errors}/{check_count:,} docs have incorrect vectors!")
    return errors


def run_migration(backup_dir):
    """Plan and execute the float32 -> float16 migration, timing each phase.

    Args:
        backup_dir: Directory that receives the schema patch file and backups.

    Returns:
        (report, phase_times, elapsed): the executor's report, a dict mapping
        each progress step name to a [start, end] perf_counter pair, and the
        total wall-clock seconds for the apply() call.
    """
    from redisvl.migration.executor import MigrationExecutor
    from redisvl.migration.planner import MigrationPlanner

    # Patch only the vector field: switch algorithm to FLAT and datatype to
    # float16 while keeping the cosine distance metric.
    schema_patch = {
        "version": 1,
        "changes": {
            "update_fields": [
                {
                    "name": "embedding",
                    "attrs": {
                        "algorithm": "flat",
                        "datatype": "float16",
                        "distance_metric": "cosine",
                    },
                }
            ]
        },
    }
    patch_path = os.path.join(backup_dir, "schema_patch.yaml")
    with open(patch_path, "w") as f:
        yaml.dump(schema_patch, f)

    log("Planning migration: float32 → float16...")
    planner = MigrationPlanner()
    plan = planner.create_plan(index_name=INDEX_NAME, redis_url=REDIS_URL, schema_patch_path=patch_path)
    log(f"Plan: mode={plan.mode}")
    log(f" Changes: {plan.requested_changes}")
    log(f" Supported: {plan.diff_classification.supported}")

    executor = MigrationExecutor()
    phase_times = {}  # step -> [start, end]
    # Single-element list so the nested callback can mutate the current phase
    # without a `nonlocal` declaration.
    current_phase = [None]

    def progress_cb(step, detail=None):
        now = 
time.perf_counter() + # Track phase transitions + if step != current_phase[0]: + if current_phase[0] and current_phase[0] in phase_times: + phase_times[current_phase[0]][1] = now + if step not in phase_times: + phase_times[step] = [now, now] + current_phase[0] = step + else: + phase_times[step][1] = now + msg = f" [{step}] {detail}" if detail else f" [{step}]" + log(msg) + + log(f"\nApplying: {NUM_WORKERS} workers, batch_size={BATCH_SIZE}...") + started = time.perf_counter() + report = executor.apply( + plan, redis_url=REDIS_URL, progress_callback=progress_cb, + backup_dir=backup_dir, batch_size=BATCH_SIZE, + num_workers=NUM_WORKERS, keep_backup=True, + ) + elapsed = time.perf_counter() - started + # Close last phase + if current_phase[0] and current_phase[0] in phase_times: + phase_times[current_phase[0]][1] = started + elapsed + + log(f"\nMigration completed in {elapsed:.3f}s") + log(f" Result: {report.result}") + if report.validation: + log(f" Schema match: {report.validation.schema_match}") + log(f" Doc count: {report.validation.doc_count_match}") + return report, phase_times, elapsed + + +def main(): + global NUM_DOCS + log("=" * 60) + log(f"E2E Migration Test: {NUM_DOCS:,} docs, {DIMS}d, float32→float16, {NUM_WORKERS} workers") + log("=" * 60) + r = redis.from_url(REDIS_URL) + cleanup_index(r) + + num_docs = create_index_and_load(r) + if num_docs < NUM_DOCS: + log(f" ⚠️ Only {num_docs:,}/{NUM_DOCS:,} docs indexed " + f"(HNSW memory limit). Benchmarking with {num_docs:,}.") + NUM_DOCS = num_docs + + errors = verify_vectors(r, "float32", 4) + assert errors == 0 + + backup_dir = tempfile.mkdtemp(prefix="migration_backup_") + log(f"\nBackup dir: {backup_dir}") + + try: + report, phase_times, elapsed = run_migration(backup_dir) + # When switching HNSW→FLAT, FLAT may index MORE docs than HNSW could + # (HNSW has memory overhead that limits capacity). Treat this as success. 
+ if report.result == "failed" and report.validation: + if report.validation.schema_match and not report.validation.doc_count_match: + log("\n⚠️ Doc count mismatch (expected with HNSW→FLAT: " + "FLAT indexes all docs HNSW couldn't fit).") + log(" Treating as success — schema matched, all data preserved.") + else: + assert False, f"FAILED: {report.result} — {report.validation}" + elif report.result != "succeeded": + assert False, f"FAILED: {report.result}" + log("\n✅ Migration completed!") + + errors = verify_vectors(r, "float16", 2) + assert errors == 0, "Float16 verification failed!" + + # Cleanup backup + from redisvl.migration.executor import MigrationExecutor + executor = MigrationExecutor() + safe = INDEX_NAME.replace("/", "_").replace("\\", "_").replace(":", "_") + pattern = os.path.join(backup_dir, f"migration_backup_{safe}*") + backup_files = glob.glob(pattern) + total_backup_mb = sum(os.path.getsize(f) for f in backup_files) / (1024 * 1024) + executor._cleanup_backup_files(backup_dir, INDEX_NAME) + + # ── Benchmark results ── + data_mb = (NUM_DOCS * DIMS * 4) / (1024 * 1024) + + log("\n" + "=" * 74) + log(" MIGRATION BENCHMARK") + log("=" * 74) + log(f" Schema: HNSW float32 → FLAT float16") + log(f" Documents: {NUM_DOCS:,}") + log(f" Dimensions: {DIMS}") + log(f" Workers: {NUM_WORKERS}") + log(f" Batch size: {BATCH_SIZE:,}") + log(f" Vector data: {data_mb:,.1f} MB → {data_mb/2:,.1f} MB " + f"({data_mb/2:,.1f} MB saved)") + log(f" Backup size: {total_backup_mb:,.1f} MB ({len(backup_files)} files)") + log("") + log(" Phase breakdown:") + log(f" {'Phase':<16} {'Time':>10} {'Docs/sec':>12} Notes") + log(f" {'─'*16} {'─'*10} {'─'*12} {'─'*25}") + for phase in ["enumerate", "dump", "drop", "quantize", "create", "index", "validate"]: + if phase in phase_times: + dt = phase_times[phase][1] - phase_times[phase][0] + dps = f"{NUM_DOCS / dt:,.0f}" if dt > 0.001 else "—" + notes = "" + if phase == "quantize": + notes = f"read+convert+write ({NUM_WORKERS} workers)" + 
elif phase == "dump": + notes = f"pipeline read → backup file" + elif phase == "index": + notes = "Redis FLAT re-index" + elif phase == "enumerate": + notes = "FT.SEARCH scan" + log(f" {phase:<16} {dt:>9.3f}s {dps:>12} {notes}") + log(f" {'─'*16} {'─'*10}") + log(f" {'TOTAL':<16} {elapsed:>9.3f}s " + f"{NUM_DOCS / elapsed:>11,.0f}/s") + log("") + + # Quantize-only throughput (the work we actually do) + if "quantize" in phase_times: + qt = phase_times["quantize"][1] - phase_times["quantize"][0] + log(f" ⚡ Quantize throughput: {NUM_DOCS/qt:,.0f} docs/sec " + f"({data_mb/qt:,.1f} MB/sec) [{qt:.3f}s]") + log(f" ✅ All {NUM_DOCS:,} vectors verified as float16") + log("=" * 74) + + finally: + cleanup_index(r) + shutil.rmtree(backup_dir, ignore_errors=True) + r.close() + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/verify_data_correctness.py b/scripts/verify_data_correctness.py new file mode 100644 index 000000000..13e432ac5 --- /dev/null +++ b/scripts/verify_data_correctness.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +"""Verify migration actually produces correct float16 conversions of original float32 data.""" +import shutil +import tempfile +import time + +import numpy as np +import redis + +DIMS = 256 +N = 1000 +PREFIX = "verify_test:" + +r = redis.from_url("redis://localhost:6379") + +# 1. Create known vectors +print(f"Creating {N} docs with known float32 vectors ({DIMS}d)...") +original_vectors = {} +pipe = r.pipeline(transaction=False) +for i in range(N): + vec = np.random.randn(DIMS).astype(np.float32) + original_vectors[i] = vec.copy() + pipe.hset(f"{PREFIX}{i}", mapping={"text": f"doc {i}", "embedding": vec.tobytes()}) +pipe.execute() + +# 2. 
Create HNSW index
# Source index: HNSW over float32 vectors, matching the migration's "before" state.
r.execute_command(
    "FT.CREATE", "verify_idx", "ON", "HASH", "PREFIX", "1", PREFIX,
    "SCHEMA", "text", "TEXT",
    "embedding", "VECTOR", "HNSW", "10",
    "TYPE", "FLOAT32", "DIM", str(DIMS),
    "DISTANCE_METRIC", "COSINE", "M", "16", "EF_CONSTRUCTION", "200",
)
# Give background indexing a moment to finish for this small dataset.
# NOTE(review): a fixed sleep is a heuristic — polling FT.INFO percent_indexed
# would be deterministic; confirm 3s suffices on slower hosts.
time.sleep(3)

# 3. Verify float32 stored correctly
# Read every embedding back in one pipeline and compare byte-for-byte against
# the in-memory originals.
pipe = r.pipeline(transaction=False)
for i in range(N):
    pipe.hget(f"{PREFIX}{i}", "embedding")
pre = pipe.execute()
f32_ok = all(np.array_equal(np.frombuffer(pre[i], dtype=np.float32), original_vectors[i]) for i in range(N))
print(f"Float32 pre-migration: {'PASS' if f32_ok else 'FAIL'}")

# 4. Run migration
print("\nRunning migration: HNSW float32 -> FLAT float16...")
import os
import yaml
from redisvl.migration.planner import MigrationPlanner
from redisvl.migration.executor import MigrationExecutor

backup_dir = tempfile.mkdtemp()
# Patch requests the quantizing change: FLAT algorithm + float16 datatype.
schema_patch = {
    "version": 1,
    "changes": {
        "update_fields": [{
            "name": "embedding",
            "attrs": {"algorithm": "flat", "datatype": "float16", "distance_metric": "cosine"},
        }],
    },
}
patch_path = os.path.join(backup_dir, "patch.yaml")
with open(patch_path, "w") as f:
    yaml.dump(schema_patch, f)

plan = MigrationPlanner().create_plan(
    index_name="verify_idx", redis_url="redis://localhost:6379",
    schema_patch_path=patch_path,
)
report = MigrationExecutor().apply(
    plan, redis_url="redis://localhost:6379",
    backup_dir=backup_dir, batch_size=200, num_workers=1,
)
print(f" Result: {report.result}")
print(f" Doc count: {report.validation.doc_count_match}")
print(f" Schema: {report.validation.schema_match}")

# 5. 
THE REAL CHECK +print("\n=== DATA CORRECTNESS CHECK ===") +pipe = r.pipeline(transaction=False) +for i in range(N): + pipe.hget(f"{PREFIX}{i}", "embedding") +post = pipe.execute() + +missing = wrong_size = value_errors = 0 +max_abs = 0.0 +total_abs = 0.0 +max_rel = 0.0 + +for i in range(N): + data = post[i] + if data is None: + missing += 1 + continue + if len(data) != DIMS * 2: + wrong_size += 1 + continue + + actual_f16 = np.frombuffer(data, dtype=np.float16) + expected_f16 = original_vectors[i].astype(np.float16) + + if not np.array_equal(actual_f16, expected_f16): + value_errors += 1 + if value_errors <= 3: + diff = np.abs(actual_f16.astype(np.float32) - expected_f16.astype(np.float32)) + print(f" doc {i}: max_diff={diff.max():.8f}") + print(f" expected[:5] = {expected_f16[:5]}") + print(f" actual[:5] = {actual_f16[:5]}") + + abs_err = np.abs(actual_f16.astype(np.float32) - original_vectors[i]) + max_abs = max(max_abs, abs_err.max()) + total_abs += abs_err.mean() + + nz = np.abs(original_vectors[i]) > 1e-10 + if nz.any(): + rel = abs_err[nz] / np.abs(original_vectors[i][nz]) + max_rel = max(max_rel, rel.max()) + +print(f"\nMissing docs: {missing}") +print(f"Wrong size: {wrong_size}") +print(f"Value mismatches: {value_errors} (actual != expected float16)") +print(f"Max abs error: {max_abs:.8f} (vs original float32)") +print(f"Avg abs error: {total_abs/N:.8f}") +print(f"Max relative error: {max_rel:.6f} ({max_rel*100:.4f}%)") + +if missing == 0 and wrong_size == 0 and value_errors == 0: + print("\n✅ ALL DATA CORRECT: every vector is the exact float16 conversion of its original float32") +else: + print(f"\n❌ ISSUES FOUND") + +# Cleanup +try: + r.execute_command("FT.DROPINDEX", "verify_idx") +except Exception: + pass +pipe = r.pipeline(transaction=False) +for i in range(N): + pipe.delete(f"{PREFIX}{i}") +pipe.execute() +shutil.rmtree(backup_dir, ignore_errors=True) +r.close() diff --git a/tests/benchmarks/index_migrator_real_benchmark.py 
b/tests/benchmarks/index_migrator_real_benchmark.py new file mode 100644 index 000000000..c2a28bd1a --- /dev/null +++ b/tests/benchmarks/index_migrator_real_benchmark.py @@ -0,0 +1,647 @@ +from __future__ import annotations + +import argparse +import csv +import json +import statistics +import tempfile +import time +from pathlib import Path +from typing import Any, Dict, Iterable, List, Sequence + +import numpy as np +import yaml +from datasets import load_dataset +from redis import Redis +from sentence_transformers import SentenceTransformer + +from redisvl.index import SearchIndex +from redisvl.migration import MigrationPlanner +from redisvl.query import VectorQuery +from redisvl.redis.utils import array_to_buffer + +AG_NEWS_LABELS = { + 0: "world", + 1: "sports", + 2: "business", + 3: "sci_tech", +} + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Run a real local benchmark for migrating from HNSW/FP32 to FLAT/FP16 " + "with a real internet dataset and sentence-transformers embeddings." 
+ ) + ) + parser.add_argument( + "--redis-url", + default="redis://localhost:6379", + help="Redis URL for the local benchmark target.", + ) + parser.add_argument( + "--sizes", + nargs="+", + type=int, + default=[1000, 10000, 100000], + help="Dataset sizes to benchmark.", + ) + parser.add_argument( + "--query-count", + type=int, + default=25, + help="Number of held-out query documents to benchmark search latency.", + ) + parser.add_argument( + "--top-k", + type=int, + default=10, + help="Number of nearest neighbors to fetch for overlap checks.", + ) + parser.add_argument( + "--embedding-batch-size", + type=int, + default=256, + help="Batch size for sentence-transformers encoding.", + ) + parser.add_argument( + "--load-batch-size", + type=int, + default=500, + help="Batch size for SearchIndex.load calls.", + ) + parser.add_argument( + "--model", + default="sentence-transformers/all-MiniLM-L6-v2", + help="Sentence-transformers model name.", + ) + parser.add_argument( + "--dataset-csv", + default="", + help=( + "Optional path to a local AG News CSV file with label,title,description columns. " + "If provided, the benchmark skips Hugging Face dataset downloads." 
+ ), + ) + parser.add_argument( + "--output", + default="index_migrator_benchmark_results.json", + help="Where to write the benchmark report.", + ) + return parser.parse_args() + + +def build_schema( + *, + index_name: str, + prefix: str, + dims: int, + algorithm: str, + datatype: str, +) -> Dict[str, Any]: + return { + "index": { + "name": index_name, + "prefix": prefix, + "storage_type": "hash", + }, + "fields": [ + {"name": "doc_id", "type": "tag"}, + {"name": "label", "type": "tag"}, + {"name": "text", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "dims": dims, + "distance_metric": "cosine", + "algorithm": algorithm, + "datatype": datatype, + }, + }, + ], + } + + +def load_ag_news_records(num_docs: int, query_count: int) -> List[Dict[str, Any]]: + dataset = load_dataset("ag_news", split=f"train[:{num_docs + query_count}]") + records: List[Dict[str, Any]] = [] + for idx, row in enumerate(dataset): + records.append( + { + "doc_id": f"ag-news-{idx}", + "text": row["text"], + "label": AG_NEWS_LABELS[int(row["label"])], + } + ) + return records + + +def load_ag_news_records_from_csv( + csv_path: str, + *, + required_docs: int, +) -> List[Dict[str, Any]]: + records: List[Dict[str, Any]] = [] + with open(csv_path, "r", newline="", encoding="utf-8") as f: + reader = csv.reader(f) + for idx, row in enumerate(reader): + if len(row) < 3: + continue + # Skip header row if present (label column should be a digit) + if idx == 0 and not row[0].strip().isdigit(): + continue + if len(records) >= required_docs: + break + label, title, description = row + text = f"{title}. 
{description}".strip() + records.append( + { + "doc_id": f"ag-news-{len(records)}", + "text": text, + "label": AG_NEWS_LABELS[int(label) - 1], + } + ) + + if len(records) < required_docs: + raise ValueError( + f"Expected at least {required_docs} records in {csv_path}, found {len(records)}" + ) + return records + + +def encode_texts( + model_name: str, + texts: Sequence[str], + batch_size: int, +) -> tuple[np.ndarray, float]: + try: + encoder = SentenceTransformer(model_name, local_files_only=True) + except OSError: + # Model not cached locally yet; download it + print(f"Model '{model_name}' not found locally, downloading...") + encoder = SentenceTransformer(model_name) + start = time.perf_counter() + embeddings = encoder.encode( + list(texts), + batch_size=batch_size, + show_progress_bar=True, + convert_to_numpy=True, + normalize_embeddings=True, + ) + duration = time.perf_counter() - start + return np.asarray(embeddings, dtype=np.float32), duration + + +def iter_documents( + records: Sequence[Dict[str, Any]], + embeddings: np.ndarray, + *, + dtype: str, +) -> Iterable[Dict[str, Any]]: + for record, embedding in zip(records, embeddings): + yield { + "doc_id": record["doc_id"], + "label": record["label"], + "text": record["text"], + "embedding": array_to_buffer(embedding, dtype), + } + + +def wait_for_index_ready( + index: SearchIndex, + *, + timeout_seconds: int = 1800, + poll_interval_seconds: float = 0.5, +) -> Dict[str, Any]: + deadline = time.perf_counter() + timeout_seconds + latest_info = index.info() + while time.perf_counter() < deadline: + latest_info = index.info() + percent_indexed = float(latest_info.get("percent_indexed", 1)) + indexing = latest_info.get("indexing", 0) + if percent_indexed >= 1.0 and not indexing: + return latest_info + time.sleep(poll_interval_seconds) + raise TimeoutError( + f"Index {index.schema.index.name} did not finish indexing within {timeout_seconds} seconds" + ) + + +def get_memory_snapshot(client: Redis) -> Dict[str, Any]: + 
    """Return the Redis server's current memory usage as bytes, MB, and human string."""
    info = client.info("memory")
    used_memory_bytes = int(info.get("used_memory", 0))
    return {
        "used_memory_bytes": used_memory_bytes,
        "used_memory_mb": round(used_memory_bytes / (1024 * 1024), 3),
        "used_memory_human": info.get("used_memory_human"),
    }


def summarize_index_info(index_info: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the benchmark-relevant counters from an FT.INFO-style dict.

    The `or 0` guards coerce None values (missing keys) to 0 before casting.
    """
    return {
        "num_docs": int(index_info.get("num_docs", 0) or 0),
        "percent_indexed": float(index_info.get("percent_indexed", 0) or 0),
        "hash_indexing_failures": int(index_info.get("hash_indexing_failures", 0) or 0),
        "vector_index_sz_mb": float(index_info.get("vector_index_sz_mb", 0) or 0),
        "total_indexing_time": float(index_info.get("total_indexing_time", 0) or 0),
    }


def percentile(values: Sequence[float], pct: float) -> float:
    """Return the pct-th percentile of *values* rounded to 3 places (0.0 if empty)."""
    if not values:
        return 0.0
    return round(float(np.percentile(np.asarray(values), pct)), 3)


def run_query_benchmark(
    index: SearchIndex,
    query_embeddings: np.ndarray,
    *,
    dtype: str,
    top_k: int,
) -> Dict[str, Any]:
    """Run one KNN query per embedding and collect latency percentiles.

    Args:
        index: Index to query.
        query_embeddings: One query vector per row.
        dtype: Vector datatype the index expects (e.g. "float32", "float16").
        top_k: Number of nearest neighbors to request.

    Returns:
        Dict with query count, p50/p95/p99/mean latency in ms, and the raw
        ordered doc-id lists per query (``result_sets``) for overlap checks.
    """
    latencies_ms: List[float] = []
    result_sets: List[List[str]] = []

    for query_embedding in query_embeddings:
        query = VectorQuery(
            vector=query_embedding.tolist(),
            vector_field_name="embedding",
            return_fields=["doc_id", "label"],
            num_results=top_k,
            dtype=dtype,
        )
        # Time only the query round trip, not query construction.
        start = time.perf_counter()
        results = index.query(query)
        latencies_ms.append((time.perf_counter() - start) * 1000)
        # Fall back to the Redis key ("id") when "doc_id" isn't returned.
        result_sets.append(
            [result.get("doc_id") or result.get("id") for result in results if result]
        )

    return {
        "count": len(latencies_ms),
        "p50_ms": percentile(latencies_ms, 50),
        "p95_ms": percentile(latencies_ms, 95),
        "p99_ms": percentile(latencies_ms, 99),
        "mean_ms": round(statistics.mean(latencies_ms), 3),
        "result_sets": result_sets,
    }


def compute_overlap(
    source_result_sets: Sequence[Sequence[str]],
    target_result_sets: Sequence[Sequence[str]],
    *,
    top_k: int,
) -> Dict[str, Any]:
    """Compare per-query top-k result sets between source and target indices."""
    overlap_ratios: List[float] = []
    for source_results, 
target_results in zip(source_result_sets, target_result_sets): + source_set = set(source_results[:top_k]) + target_set = set(target_results[:top_k]) + overlap_ratios.append(len(source_set.intersection(target_set)) / max(top_k, 1)) + return { + "mean_overlap_at_k": round(statistics.mean(overlap_ratios), 4), + "min_overlap_at_k": round(min(overlap_ratios), 4), + "max_overlap_at_k": round(max(overlap_ratios), 4), + } + + +def run_quantization_migration( + planner: MigrationPlanner, + client: Redis, + source_index_name: str, + source_schema: Dict[str, Any], + dims: int, +) -> Dict[str, Any]: + """Run full HNSW/FP32 -> FLAT/FP16 migration with quantization.""" + from redisvl.migration import MigrationExecutor + + target_schema = build_schema( + index_name=source_schema["index"]["name"], + prefix=source_schema["index"]["prefix"], + dims=dims, + algorithm="flat", # Change algorithm + datatype="float16", # Change datatype (quantization) + ) + + with tempfile.TemporaryDirectory() as tmpdir: + target_schema_path = Path(tmpdir) / "target_schema.yaml" + plan_path = Path(tmpdir) / "migration_plan.yaml" + with open(target_schema_path, "w") as f: + yaml.safe_dump(target_schema, f, sort_keys=False) + + plan_start = time.perf_counter() + plan = planner.create_plan( + source_index_name, + redis_client=client, + target_schema_path=str(target_schema_path), + ) + planner.write_plan(plan, str(plan_path)) + plan_duration = time.perf_counter() - plan_start + + if not plan.diff_classification.supported: + raise AssertionError( + f"Expected planner to ALLOW quantization migration, " + f"but it blocked with: {plan.diff_classification.blocked_reasons}" + ) + + # Check datatype changes detected + datatype_changes = MigrationPlanner.get_vector_datatype_changes( + plan.source.schema_snapshot, plan.merged_target_schema + ) + + # Execute migration + executor = MigrationExecutor() + migrate_start = time.perf_counter() + report = executor.apply(plan, redis_client=client) + migrate_duration = 
time.perf_counter() - migrate_start + + if report.result != "succeeded": + raise AssertionError(f"Migration failed: {report.validation.errors}") + + return { + "test": "quantization_migration", + "plan_duration_seconds": round(plan_duration, 3), + "migration_duration_seconds": round(migrate_duration, 3), + "quantize_duration_seconds": report.timings.quantize_duration_seconds, + "supported": plan.diff_classification.supported, + "datatype_changes": datatype_changes, + "result": report.result, + } + + +def assert_planner_allows_algorithm_change( + planner: MigrationPlanner, + client: Redis, + source_index_name: str, + source_schema: Dict[str, Any], + dims: int, +) -> Dict[str, Any]: + """Test that algorithm-only changes (HNSW -> FLAT) are allowed.""" + target_schema = build_schema( + index_name=source_schema["index"]["name"], + prefix=source_schema["index"]["prefix"], + dims=dims, + algorithm="flat", # Different algorithm - should be allowed + datatype="float32", # Same datatype + ) + + with tempfile.TemporaryDirectory() as tmpdir: + target_schema_path = Path(tmpdir) / "target_schema.yaml" + plan_path = Path(tmpdir) / "migration_plan.yaml" + with open(target_schema_path, "w") as f: + yaml.safe_dump(target_schema, f, sort_keys=False) + + start = time.perf_counter() + plan = planner.create_plan( + source_index_name, + redis_client=client, + target_schema_path=str(target_schema_path), + ) + planner.write_plan(plan, str(plan_path)) + duration = time.perf_counter() - start + + if not plan.diff_classification.supported: + raise AssertionError( + f"Expected planner to ALLOW algorithm change (HNSW -> FLAT), " + f"but it blocked with: {plan.diff_classification.blocked_reasons}" + ) + + return { + "test": "algorithm_change_allowed", + "plan_duration_seconds": round(duration, 3), + "supported": plan.diff_classification.supported, + "blocked_reasons": plan.diff_classification.blocked_reasons, + } + + +def benchmark_scale( + *, + client: Redis, + all_records: Sequence[Dict[str, 
Any]], + all_embeddings: np.ndarray, + size: int, + query_count: int, + top_k: int, + load_batch_size: int, +) -> Dict[str, Any]: + records = list(all_records[:size]) + query_records = list(all_records[size : size + query_count]) + doc_embeddings = all_embeddings[:size] + query_embeddings = all_embeddings[size : size + query_count] + dims = int(all_embeddings.shape[1]) + + client.flushdb() + + baseline_memory = get_memory_snapshot(client) + planner = MigrationPlanner(key_sample_limit=5) + source_schema = build_schema( + index_name=f"benchmark_source_{size}", + prefix=f"benchmark:source:{size}", + dims=dims, + algorithm="hnsw", + datatype="float32", + ) + + source_index = SearchIndex.from_dict(source_schema, redis_client=client) + migrated_index = None # Will be set after migration + + try: + source_index.create(overwrite=True) + source_load_start = time.perf_counter() + source_index.load( + iter_documents(records, doc_embeddings, dtype="float32"), + id_field="doc_id", + batch_size=load_batch_size, + ) + source_info = wait_for_index_ready(source_index) + source_setup_duration = time.perf_counter() - source_load_start + source_memory = get_memory_snapshot(client) + + # Query source index before migration + source_query_metrics = run_query_benchmark( + source_index, + query_embeddings, + dtype="float32", + top_k=top_k, + ) + + # Run full quantization migration: HNSW/FP32 -> FLAT/FP16 + quantization_result = run_quantization_migration( + planner=planner, + client=client, + source_index_name=source_schema["index"]["name"], + source_schema=source_schema, + dims=dims, + ) + + # Get migrated index info and memory + migrated_index = SearchIndex.from_existing( + source_schema["index"]["name"], redis_client=client + ) + target_info = wait_for_index_ready(migrated_index) + overlap_memory = get_memory_snapshot(client) + + # Query migrated index + target_query_metrics = run_query_benchmark( + migrated_index, + query_embeddings.astype(np.float16), + dtype="float16", + 
top_k=top_k, + ) + + overlap_metrics = compute_overlap( + source_query_metrics["result_sets"], + target_query_metrics["result_sets"], + top_k=top_k, + ) + + post_cutover_memory = get_memory_snapshot(client) + + return { + "size": size, + "query_count": len(query_records), + "vector_dims": dims, + "source": { + "algorithm": "hnsw", + "datatype": "float32", + "setup_duration_seconds": round(source_setup_duration, 3), + "index_info": summarize_index_info(source_info), + "query_metrics": { + k: v for k, v in source_query_metrics.items() if k != "result_sets" + }, + }, + "migration": { + "quantization": quantization_result, + }, + "target": { + "algorithm": "flat", + "datatype": "float16", + "migration_duration_seconds": quantization_result[ + "migration_duration_seconds" + ], + "quantize_duration_seconds": quantization_result[ + "quantize_duration_seconds" + ], + "index_info": summarize_index_info(target_info), + "query_metrics": { + k: v for k, v in target_query_metrics.items() if k != "result_sets" + }, + }, + "memory": { + "baseline": baseline_memory, + "after_source": source_memory, + "during_overlap": overlap_memory, + "after_cutover": post_cutover_memory, + "overlap_increase_mb": round( + overlap_memory["used_memory_mb"] - source_memory["used_memory_mb"], + 3, + ), + "net_change_after_cutover_mb": round( + post_cutover_memory["used_memory_mb"] + - source_memory["used_memory_mb"], + 3, + ), + }, + "correctness": { + "source_num_docs": int(source_info.get("num_docs", 0) or 0), + "target_num_docs": int(target_info.get("num_docs", 0) or 0), + "doc_count_match": int(source_info.get("num_docs", 0) or 0) + == int(target_info.get("num_docs", 0) or 0), + "migration_succeeded": quantization_result["result"] == "succeeded", + **overlap_metrics, + }, + } + finally: + for idx in (source_index, migrated_index): + try: + if idx is not None: + idx.delete(drop=True) + except Exception: + pass + + +def main() -> None: + args = parse_args() + sizes = sorted(args.sizes) + max_size = 
max(sizes) + required_docs = max_size + args.query_count + + if args.dataset_csv: + print( + f"Loading AG News CSV from {args.dataset_csv} with {required_docs} records" + ) + records = load_ag_news_records_from_csv( + args.dataset_csv, + required_docs=required_docs, + ) + else: + print(f"Loading AG News dataset with {required_docs} records") + records = load_ag_news_records( + required_docs - args.query_count, + args.query_count, + ) + print(f"Encoding {len(records)} texts with {args.model}") + embeddings, embedding_duration = encode_texts( + args.model, + [record["text"] for record in records], + args.embedding_batch_size, + ) + + client = Redis.from_url(args.redis_url, decode_responses=False) + client.ping() + + report = { + "dataset": "ag_news", + "model": args.model, + "sizes": sizes, + "query_count": args.query_count, + "top_k": args.top_k, + "embedding_duration_seconds": round(embedding_duration, 3), + "results": [], + } + + for size in sizes: + print(f"\nRunning benchmark for {size} documents") + result = benchmark_scale( + client=client, + all_records=records, + all_embeddings=embeddings, + size=size, + query_count=args.query_count, + top_k=args.top_k, + load_batch_size=args.load_batch_size, + ) + report["results"].append(result) + print( + json.dumps( + { + "size": size, + "source_setup_duration_seconds": result["source"][ + "setup_duration_seconds" + ], + "migration_duration_seconds": result["target"][ + "migration_duration_seconds" + ], + "quantize_duration_seconds": result["target"][ + "quantize_duration_seconds" + ], + "migration_succeeded": result["correctness"]["migration_succeeded"], + "mean_overlap_at_k": result["correctness"]["mean_overlap_at_k"], + "memory_change_mb": result["memory"]["net_change_after_cutover_mb"], + }, + indent=2, + ) + ) + + output_path = Path(args.output).resolve() + with open(output_path, "w") as f: + json.dump(report, f, indent=2) + + print(f"\nBenchmark report written to {output_path}") + + +if __name__ == "__main__": + 
main() diff --git a/tests/benchmarks/migration_benchmark.py b/tests/benchmarks/migration_benchmark.py new file mode 100644 index 000000000..d2ef0a085 --- /dev/null +++ b/tests/benchmarks/migration_benchmark.py @@ -0,0 +1,642 @@ +"""Migration Benchmark: Measure end-to-end migration time at scale. + +Populates a realistic 16-field index (matching the KM production schema) +at 1K, 10K, 100K, and 1M vectors, then migrates: + - Sub-1M: HNSW FP32 -> FLAT FP16 + - 1M: HNSW FP32 -> HNSW FP16 + +Collects full MigrationTimings from MigrationExecutor.apply(). + +Usage: + python tests/benchmarks/migration_benchmark.py \\ + --redis-url redis://localhost:6379 \\ + --sizes 1000 10000 100000 \\ + --trials 3 \\ + --output tests/benchmarks/results_migration.json +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import random +import time +from pathlib import Path +from typing import Any, Dict, List, Optional + +import numpy as np +from redis import Redis + +from redisvl.index import SearchIndex +from redisvl.migration import ( + AsyncMigrationExecutor, + AsyncMigrationPlanner, + MigrationExecutor, + MigrationPlanner, +) +from redisvl.migration.models import FieldUpdate, SchemaPatch, SchemaPatchChanges +from redisvl.migration.utils import wait_for_index_ready +from redisvl.redis.utils import array_to_buffer + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +VECTOR_DIMS = 3072 +INDEX_PREFIX = "KM:benchmark:" +HNSW_M = 16 +HNSW_EF_CONSTRUCTION = 200 +BATCH_SIZE = 500 + +# Vocabularies for synthetic data +TAG_VOCABS = { + "doc_base_id": [f"base_{i}" for i in range(50)], + "file_id": [f"file_{i:06d}" for i in range(200)], + "created_by": ["alice", "bob", "carol", "dave", "eve"], + "CUSIP": [f"{random.randint(100000000, 999999999)}" for _ in range(100)], +} + +TEXT_WORDS = [ + "financial", + "report", + "quarterly", + 
"analysis",
    "revenue",
    "growth",
    "market",
    "portfolio",
    "investment",
    "dividend",
    "equity",
    "bond",
    "asset",
    "liability",
    "balance",
    "income",
    "statement",
    "forecast",
    "risk",
    "compliance",
]


# ---------------------------------------------------------------------------
# Schema helpers
# ---------------------------------------------------------------------------


def make_source_schema(index_name: str) -> Dict[str, Any]:
    """Build the 16-field HNSW FP32 source schema dict.

    Field layout: 11 plain tag/text/numeric fields, one 3072-dim HNSW/COSINE
    vector, and four trailing fields declared with ``index_missing`` so that
    sparsely-populated documents (see ``generate_document``) are exercised.
    """
    return {
        "index": {
            "name": index_name,
            "prefix": INDEX_PREFIX,
            "storage_type": "hash",
        },
        "fields": [
            {"name": "doc_base_id", "type": "tag", "attrs": {"separator": ","}},
            {"name": "file_id", "type": "tag", "attrs": {"separator": ","}},
            {"name": "page_text", "type": "text", "attrs": {"weight": 1}},
            {"name": "chunk_number", "type": "numeric"},
            {"name": "start_page", "type": "numeric"},
            {"name": "end_page", "type": "numeric"},
            {"name": "created_by", "type": "tag", "attrs": {"separator": ","}},
            {"name": "file_name", "type": "text", "attrs": {"weight": 1}},
            {"name": "created_time", "type": "numeric"},
            {"name": "last_updated_by", "type": "text", "attrs": {"weight": 1}},
            {"name": "last_updated_time", "type": "numeric"},
            {
                "name": "embedding",
                "type": "vector",
                "attrs": {
                    "algorithm": "hnsw",
                    "datatype": "float32",
                    "dims": VECTOR_DIMS,
                    "distance_metric": "COSINE",
                    "m": HNSW_M,
                    "ef_construction": HNSW_EF_CONSTRUCTION,
                },
            },
            {
                "name": "CUSIP",
                "type": "tag",
                "attrs": {"separator": ",", "index_missing": True},
            },
            {
                "name": "description",
                "type": "text",
                "attrs": {"weight": 1, "index_missing": True},
            },
            {
                "name": "name",
                "type": "text",
                "attrs": {"weight": 1, "index_missing": True},
            },
            {"name": "price", "type": "numeric", "attrs": {"index_missing": True}},
        ],
    }


def make_migration_patch(target_algo: str) -> SchemaPatch:
    """Build a SchemaPatch to change embedding from FP32 to FP16 (and optionally HNSW to FLAT)."""
    # Datatype always flips to float16; the algorithm only changes when the
    # caller asks for an exact ("FLAT") target index.
    attrs = {"datatype": "float16"}
    if target_algo == "FLAT":
        attrs["algorithm"] = "flat"
    return SchemaPatch(
        version=1,
        changes=SchemaPatchChanges(
            update_fields=[
                FieldUpdate(name="embedding", attrs=attrs),
            ]
        ),
    )


# ---------------------------------------------------------------------------
# Data generation
# ---------------------------------------------------------------------------


def generate_random_text(min_words: int = 10, max_words: int = 50) -> str:
    """Generate a random sentence from the vocabulary."""
    # Uses the module-level `random` (unseeded here) over the fixed TEXT_WORDS
    # vocabulary, so text content varies run to run.
    n = random.randint(min_words, max_words)
    return " ".join(random.choice(TEXT_WORDS) for _ in range(n))


def generate_document(doc_id: int, vector: np.ndarray) -> Dict[str, Any]:
    """Generate a single document with all 16 fields.

    The vector is serialized as a float32 buffer; the four ``index_missing``
    fields are present on only ~80% of documents so the INDEXMISSING path of
    the schema is actually exercised.
    """
    doc: Dict[str, Any] = {
        "doc_base_id": random.choice(TAG_VOCABS["doc_base_id"]),
        "file_id": random.choice(TAG_VOCABS["file_id"]),
        "page_text": generate_random_text(),
        "chunk_number": random.randint(0, 100),
        "start_page": random.randint(1, 500),
        "end_page": random.randint(1, 500),
        "created_by": random.choice(TAG_VOCABS["created_by"]),
        "file_name": f"document_{doc_id}.pdf",
        "created_time": int(time.time()) - random.randint(0, 86400 * 365),
        "last_updated_by": random.choice(TAG_VOCABS["created_by"]),
        "last_updated_time": int(time.time()) - random.randint(0, 86400 * 30),
        "embedding": array_to_buffer(vector, dtype="float32"),
    }
    # INDEXMISSING fields: populate ~80% of docs
    if random.random() < 0.8:
        doc["CUSIP"] = random.choice(TAG_VOCABS["CUSIP"])
    if random.random() < 0.8:
        doc["description"] = generate_random_text(5, 20)
    if random.random() < 0.8:
        doc["name"] = f"Entity {doc_id}"
    if random.random() < 0.8:
        doc["price"] = round(random.uniform(1.0, 10000.0), 2)
    return doc


# ---------------------------------------------------------------------------
# Population
# 
---------------------------------------------------------------------------


def populate_index(
    redis_url: str,
    index_name: str,
    num_docs: int,
) -> float:
    """Create the source index and populate it with synthetic data.

    Returns the time taken in seconds.

    The returned duration includes both the HSET population phase and the
    subsequent wait for background indexing to finish.
    """
    schema_dict = make_source_schema(index_name)
    index = SearchIndex.from_dict(schema_dict, redis_url=redis_url)

    # Drop existing index if any
    try:
        index.delete(drop=True)
    except Exception:
        pass

    # Clean up any leftover keys from previous runs
    client = Redis.from_url(redis_url)
    cursor = 0
    while True:
        cursor, keys = client.scan(cursor, match=f"{INDEX_PREFIX}*", count=5000)
        if keys:
            client.delete(*keys)
        if cursor == 0:
            break
    client.close()

    index.create(overwrite=True)

    print(f" Populating {num_docs:,} documents...")
    start = time.perf_counter()

    # Generate vectors in batches to manage memory
    # NOTE: the fixed seed means every trial writes identical vectors, so
    # trials differ only in the random document metadata.
    rng = np.random.default_rng(seed=42)
    client = Redis.from_url(redis_url)

    for batch_start in range(0, num_docs, BATCH_SIZE):
        batch_end = min(batch_start + BATCH_SIZE, num_docs)
        batch_count = batch_end - batch_start

        # Generate batch of random unit-normalized vectors
        vectors = rng.standard_normal((batch_count, VECTOR_DIMS)).astype(np.float32)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        vectors = vectors / norms

        pipe = client.pipeline(transaction=False)
        for i in range(batch_count):
            doc_id = batch_start + i
            key = f"{INDEX_PREFIX}{doc_id}"
            doc = generate_document(doc_id, vectors[i])
            pipe.hset(key, mapping=doc)

        pipe.execute()

        if (batch_end % 10000 == 0) or batch_end == num_docs:
            elapsed = time.perf_counter() - start
            rate = batch_end / elapsed if elapsed > 0 else 0
            print(f" {batch_end:,}/{num_docs:,} docs ({rate:,.0f} docs/sec)")

    populate_duration = time.perf_counter() - start
    client.close()

    # Wait for indexing to complete
    print(" Waiting for index to be ready...")
    idx = SearchIndex.from_existing(index_name, redis_url=redis_url)
    _, indexing_wait = wait_for_index_ready(idx)
    print(
        f" Index ready (waited {indexing_wait:.1f}s after {populate_duration:.1f}s populate)"
    )

    return populate_duration + indexing_wait


# ---------------------------------------------------------------------------
# Migration execution
# ---------------------------------------------------------------------------


def run_migration(
    redis_url: str,
    index_name: str,
    target_algo: str,
) -> Dict[str, Any]:
    """Run a single migration and return the full report as a dict.

    Returns a dict with 'report' (model_dump) and 'enumerate_method'
    indicating whether FT.AGGREGATE or SCAN was used for key discovery.
    """
    import logging

    patch = make_migration_patch(target_algo)
    planner = MigrationPlanner()
    plan = planner.create_plan_from_patch(
        index_name,
        schema_patch=patch,
        redis_url=redis_url,
    )

    if not plan.diff_classification.supported:
        raise RuntimeError(
            f"Migration not supported: {plan.diff_classification.blocked_reasons}"
        )

    executor = MigrationExecutor()

    # Capture enumerate method by intercepting executor logger warnings
    # NOTE(review): detection keys off the literal warning text emitted by the
    # executor; if that message wording or log level changes, the reported
    # enumerate_method silently stays at the default -- confirm against the
    # executor implementation.
    enumerate_method = "FT.AGGREGATE"  # default (happy path)
    _orig_logger = logging.getLogger("redisvl.migration.executor")
    _orig_level = _orig_logger.level

    class _EnumMethodHandler(logging.Handler):
        def emit(self, record):
            nonlocal enumerate_method
            msg = record.getMessage()
            if "Using SCAN" in msg or "Falling back to SCAN" in msg:
                enumerate_method = "SCAN"

    handler = _EnumMethodHandler()
    _orig_logger.addHandler(handler)
    _orig_logger.setLevel(logging.WARNING)

    def progress(step: str, detail: Optional[str] = None) -> None:
        # Progress callback wired into executor.apply(); prints only steps
        # that carry a detail string.
        if detail:
            print(f" [{step}] {detail}")

    try:
        report = executor.apply(
            plan,
            redis_url=redis_url,
            progress_callback=progress,
        )
    finally:
        # Always restore the logger, even if the migration raised.
        _orig_logger.removeHandler(handler)
        _orig_logger.setLevel(_orig_level)

    return {"report": report.model_dump(), "enumerate_method": enumerate_method}


async def async_run_migration(
    redis_url: str,
    index_name: str,
    target_algo: str,
) -> Dict[str, Any]:
    """Run a single migration using AsyncMigrationExecutor.

    Returns a dict with 'report' (model_dump) and 'enumerate_method'
    indicating whether FT.AGGREGATE or SCAN was used for key discovery.

    Mirrors ``run_migration`` exactly, but drives the async planner/executor
    and listens on the async executor's logger instead.
    """
    import logging

    patch = make_migration_patch(target_algo)
    planner = AsyncMigrationPlanner()
    plan = await planner.create_plan_from_patch(
        index_name,
        schema_patch=patch,
        redis_url=redis_url,
    )

    if not plan.diff_classification.supported:
        raise RuntimeError(
            f"Migration not supported: {plan.diff_classification.blocked_reasons}"
        )

    executor = AsyncMigrationExecutor()

    # Capture enumerate method by intercepting executor logger warnings
    enumerate_method = "FT.AGGREGATE"  # default (happy path)
    _orig_logger = logging.getLogger("redisvl.migration.async_executor")
    _orig_level = _orig_logger.level

    class _EnumMethodHandler(logging.Handler):
        def emit(self, record):
            nonlocal enumerate_method
            msg = record.getMessage()
            if "Using SCAN" in msg or "Falling back to SCAN" in msg:
                enumerate_method = "SCAN"

    handler = _EnumMethodHandler()
    _orig_logger.addHandler(handler)
    _orig_logger.setLevel(logging.WARNING)

    def progress(step: str, detail: Optional[str] = None) -> None:
        if detail:
            print(f" [{step}] {detail}")

    try:
        report = await executor.apply(
            plan,
            redis_url=redis_url,
            progress_callback=progress,
        )
    finally:
        _orig_logger.removeHandler(handler)
        _orig_logger.setLevel(_orig_level)

    return {"report": report.model_dump(), "enumerate_method": enumerate_method}


# ---------------------------------------------------------------------------
# Benchmark driver
# ---------------------------------------------------------------------------


def run_benchmark(
    redis_url: str,
    sizes: List[int],
    trials: int,
    
output_path: Optional[str], + use_async: bool = False, +) -> Dict[str, Any]: + """Run the full migration benchmark across all sizes and trials.""" + executor_label = "async" if use_async else "sync" + results: Dict[str, Any] = { + "benchmark": "migration_timing", + "executor": executor_label, + "schema_field_count": 16, + "vector_dims": VECTOR_DIMS, + "trials_per_size": trials, + "results": [], + } + + for size in sizes: + target_algo = "HNSW" if size >= 1_000_000 else "FLAT" + index_name = f"bench_migration_{size}" + print(f"\n{'='*60}") + print( + f"Size: {size:,} | Migration: HNSW FP32 -> {target_algo} FP16 | Executor: {executor_label}" + ) + print(f"{'='*60}") + + size_result = { + "size": size, + "source_algo": "HNSW", + "source_dtype": "FLOAT32", + "target_algo": target_algo, + "target_dtype": "FLOAT16", + "trials": [], + } + + for trial_num in range(1, trials + 1): + print(f"\n Trial {trial_num}/{trials}") + + # Step 1: Populate + populate_time = populate_index(redis_url, index_name, size) + + # Capture source memory + client = Redis.from_url(redis_url) + try: + info_raw = client.execute_command("FT.INFO", index_name) + # Parse the flat list into a dict + info_dict = {} + for i in range(0, len(info_raw), 2): + key = info_raw[i] + if isinstance(key, bytes): + key = key.decode() + info_dict[key] = info_raw[i + 1] + source_mem_mb = float(info_dict.get("vector_index_sz_mb", 0)) + source_total_mb = float(info_dict.get("total_index_memory_sz_mb", 0)) + source_num_docs = int(info_dict.get("num_docs", 0)) + except Exception as e: + print(f" Warning: could not read source FT.INFO: {e}") + source_mem_mb = 0.0 + source_total_mb = 0.0 + source_num_docs = 0 + finally: + client.close() + + print( + f" Source: {source_num_docs:,} docs, " + f"vector_idx={source_mem_mb:.1f}MB, " + f"total_idx={source_total_mb:.1f}MB" + ) + + # Step 2: Migrate + print(f" Running migration ({executor_label})...") + if use_async: + migration_result = asyncio.run( + 
async_run_migration(redis_url, index_name, target_algo) + ) + else: + migration_result = run_migration(redis_url, index_name, target_algo) + report_dict = migration_result["report"] + enumerate_method = migration_result["enumerate_method"] + + # Capture target memory + target_index_name = report_dict.get("target_index", index_name) + client = Redis.from_url(redis_url) + try: + info_raw = client.execute_command("FT.INFO", target_index_name) + info_dict = {} + for i in range(0, len(info_raw), 2): + key = info_raw[i] + if isinstance(key, bytes): + key = key.decode() + info_dict[key] = info_raw[i + 1] + target_mem_mb = float(info_dict.get("vector_index_sz_mb", 0)) + target_total_mb = float(info_dict.get("total_index_memory_sz_mb", 0)) + except Exception as e: + print(f" Warning: could not read target FT.INFO: {e}") + target_mem_mb = 0.0 + target_total_mb = 0.0 + finally: + client.close() + + timings = report_dict.get("timings", {}) + migrate_s = timings.get("total_migration_duration_seconds", 0) or 0 + total_s = round(populate_time + migrate_s, 3) + + # Vector memory savings (the real savings from FP32 -> FP16) + vec_savings_pct = ( + round((1 - target_mem_mb / source_mem_mb) * 100, 1) + if source_mem_mb > 0 + else 0 + ) + + trial_result = { + "trial": trial_num, + "load_time_seconds": round(populate_time, 3), + "migrate_time_seconds": round(migrate_s, 3), + "total_time_seconds": total_s, + "enumerate_method": enumerate_method, + "timings": timings, + "benchmark_summary": report_dict.get("benchmark_summary", {}), + "source_vector_index_mb": round(source_mem_mb, 3), + "source_total_index_mb": round(source_total_mb, 3), + "target_vector_index_mb": round(target_mem_mb, 3), + "target_total_index_mb": round(target_total_mb, 3), + "vector_memory_savings_pct": vec_savings_pct, + "validation_passed": report_dict.get("result") == "succeeded", + "num_docs": source_num_docs, + } + + # Print isolated timings + _enum_s = timings.get("drop_duration_seconds", 0) or 0 # noqa: F841 + 
quant_s = timings.get("quantize_duration_seconds") or 0 + index_s = timings.get("initial_indexing_duration_seconds") or 0 + down_s = timings.get("downtime_duration_seconds") or 0 + print( + f""" Results + load = {populate_time:.1f}s + migrate = {migrate_s:.1f}s (enumerate + drop + quantize + create + reindex + validate) + total = {total_s:.1f}s + enumerate = {enumerate_method} + quantize = {quant_s:.1f}s + reindex = {index_s:.1f}s + downtime = {down_s:.1f}s + vec memory = {source_mem_mb:.1f}MB -> {target_mem_mb:.1f}MB ({vec_savings_pct:.1f}% saved) + passed = {trial_result['validation_passed']}""" + ) + + size_result["trials"].append(trial_result) + + # Clean up for next trial (drop index + keys) + client = Redis.from_url(redis_url) + try: + try: + client.execute_command("FT.DROPINDEX", target_index_name) + except Exception: + pass + # Delete document keys + cursor = 0 + while True: + cursor, keys = client.scan( + cursor, match=f"{INDEX_PREFIX}*", count=5000 + ) + if keys: + client.delete(*keys) + if cursor == 0: + break + finally: + client.close() + + results["results"].append(size_result) + + # Save results + if output_path: + output = Path(output_path) + output.parent.mkdir(parents=True, exist_ok=True) + with open(output, "w") as f: + json.dump(results, f, indent=2, default=str) + print(f"\nResults saved to {output}") + + return results + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser(description="Migration timing benchmark") + parser.add_argument( + "--redis-url", default="redis://localhost:6379", help="Redis connection URL" + ) + parser.add_argument( + "--sizes", + nargs="+", + type=int, + default=[1000, 10000, 100000], + help="Corpus sizes to benchmark", + ) + parser.add_argument( + "--trials", type=int, default=3, help="Number of trials per size" + ) + parser.add_argument( + "--output", + 
default="tests/benchmarks/results_migration.json", + help="Output JSON file", + ) + parser.add_argument( + "--async", + dest="use_async", + action="store_true", + default=False, + help="Use AsyncMigrationExecutor instead of sync MigrationExecutor", + ) + args = parser.parse_args() + + executor_label = "AsyncMigrationExecutor" if args.use_async else "MigrationExecutor" + print( + f"""Migration Benchmark + Redis: {args.redis_url} + Sizes: {args.sizes} + Trials: {args.trials} + Vector dims: {VECTOR_DIMS} + Fields: 16 + Executor: {executor_label}""" + ) + + run_benchmark( + redis_url=args.redis_url, + sizes=args.sizes, + trials=args.trials, + output_path=args.output, + use_async=args.use_async, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/benchmarks/retrieval_benchmark.py b/tests/benchmarks/retrieval_benchmark.py new file mode 100644 index 000000000..48b663e91 --- /dev/null +++ b/tests/benchmarks/retrieval_benchmark.py @@ -0,0 +1,680 @@ +"""Retrieval Benchmark: FP32 vs FP16 x HNSW vs FLAT + +Replicates the methodology from the Redis SVS-VAMANA study using +pre-embedded datasets from HuggingFace (no embedding step required). 
+ +Comparison matrix (4 configurations): + - HNSW / FLOAT32 (approximate, full precision) + - HNSW / FLOAT16 (approximate, quantized) + - FLAT / FLOAT32 (exact, full precision -- ground truth) + - FLAT / FLOAT16 (exact, quantized) + +Datasets: + - dbpedia: 1536-dim OpenAI embeddings (KShivendu/dbpedia-entities-openai-1M) + - cohere: 768-dim Cohere embeddings (Cohere/wikipedia-22-12-en-embeddings) + +Metrics: + - Overlap@K (precision vs FLAT/FP32 ground truth) + - Query latency: p50, p95, p99, mean + - QPS (queries per second) + - Memory footprint per configuration + - Index build / load time + +Usage: + python tests/benchmarks/retrieval_benchmark.py \\ + --redis-url redis://localhost:6379 \\ + --dataset dbpedia \\ + --sizes 1000 10000 \\ + --top-k 10 \\ + --query-count 100 \\ + --output retrieval_benchmark_results.json +""" + +from __future__ import annotations + +import argparse +import json +import statistics +import time +from pathlib import Path +from typing import Any, Dict, Iterable, List, Sequence, Tuple + +import numpy as np +from redis import Redis + +from redisvl.index import SearchIndex +from redisvl.query import VectorQuery +from redisvl.redis.utils import array_to_buffer + +# --------------------------------------------------------------------------- +# Dataset registry +# --------------------------------------------------------------------------- + +DATASETS = { + "dbpedia": { + "hf_name": "KShivendu/dbpedia-entities-openai-1M", + "embedding_column": "openai", + "dims": 1536, + "distance_metric": "cosine", + "description": "DBpedia entities, OpenAI text-embedding-ada-002, 1536d", + }, + "cohere": { + "hf_name": "Cohere/wikipedia-22-12-en-embeddings", + "embedding_column": "emb", + "dims": 768, + "distance_metric": "cosine", + "description": "Wikipedia EN, Cohere multilingual encoder, 768d", + }, + "random768": { + "hf_name": None, + "embedding_column": None, + "dims": 768, + "distance_metric": "cosine", + "description": "Synthetic random unit vectors, 
768d (Cohere-scale proxy)",
    },
}

# Index configurations to benchmark
INDEX_CONFIGS = [
    {"algorithm": "flat", "datatype": "float32", "label": "FLAT_FP32"},
    {"algorithm": "flat", "datatype": "float16", "label": "FLAT_FP16"},
    {"algorithm": "hnsw", "datatype": "float32", "label": "HNSW_FP32"},
    {"algorithm": "hnsw", "datatype": "float16", "label": "HNSW_FP16"},
]

# HNSW parameters matching SVS-VAMANA study
HNSW_M = 16
HNSW_EF_CONSTRUCTION = 200
HNSW_EF_RUNTIME = 10

# Recall K values to compute recall curves
RECALL_K_VALUES = [1, 5, 10, 20, 50, 100]


def parse_args() -> argparse.Namespace:
    """Parse CLI options for the retrieval benchmark."""
    parser = argparse.ArgumentParser(
        description="Retrieval benchmark: FP32 vs FP16 x HNSW vs FLAT."
    )
    parser.add_argument("--redis-url", default="redis://localhost:6379")
    parser.add_argument(
        "--dataset",
        choices=list(DATASETS.keys()),
        default="dbpedia",
    )
    parser.add_argument(
        "--sizes",
        nargs="+",
        type=int,
        default=[1000, 10000],
    )
    parser.add_argument("--query-count", type=int, default=100)
    parser.add_argument("--top-k", type=int, default=10)
    parser.add_argument("--ef-runtime", type=int, default=10)
    parser.add_argument("--load-batch-size", type=int, default=500)
    parser.add_argument(
        "--recall-k-max",
        type=int,
        default=100,
        help="Max K for recall curve (queries will fetch this many results).",
    )
    parser.add_argument(
        "--output",
        default="retrieval_benchmark_results.json",
    )
    return parser.parse_args()


# ---------------------------------------------------------------------------
# Dataset loading
# ---------------------------------------------------------------------------


def load_dataset_vectors(
    dataset_key: str,
    num_vectors: int,
) -> Tuple[np.ndarray, int]:
    """Load pre-embedded vectors from HuggingFace or generate synthetic.

    Returns ``(vectors, dims)`` where vectors is float32 of shape
    (num_vectors, dims). The synthetic path is seeded, so repeated runs see
    identical unit-normalized vectors.
    """
    ds_info = DATASETS[dataset_key]
    dims = ds_info["dims"]

    if ds_info["hf_name"] is None:
        # Synthetic random unit vectors
        print(f"Generating {num_vectors} random unit vectors ({dims}d) ...")
        rng = np.random.default_rng(42)
        vectors = rng.standard_normal((num_vectors, dims)).astype(np.float32)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        vectors = vectors / norms
        print(f" Generated shape: {vectors.shape}")
        return vectors, dims

    # Local import to avoid requiring datasets for synthetic mode
    from datasets import load_dataset

    hf_name = ds_info["hf_name"]
    emb_col = ds_info["embedding_column"]

    print(f"Loading {num_vectors} vectors from {hf_name} ...")
    ds = load_dataset(hf_name, split=f"train[:{num_vectors}]")
    vectors = np.array(ds[emb_col], dtype=np.float32)
    print(f" Loaded shape: {vectors.shape}")
    return vectors, dims


# ---------------------------------------------------------------------------
# Schema helpers
# ---------------------------------------------------------------------------


def build_schema(
    *,
    index_name: str,
    prefix: str,
    dims: int,
    algorithm: str,
    datatype: str,
    distance_metric: str,
    ef_runtime: int = HNSW_EF_RUNTIME,
) -> Dict[str, Any]:
    """Build an index schema dict for a given config.

    HNSW-specific attributes (m, ef_construction, ef_runtime) are attached
    only when ``algorithm == "hnsw"``; FLAT indices take none of them.
    """
    vector_attrs: Dict[str, Any] = {
        "dims": dims,
        "distance_metric": distance_metric,
        "algorithm": algorithm,
        "datatype": datatype,
    }
    if algorithm == "hnsw":
        vector_attrs["m"] = HNSW_M
        vector_attrs["ef_construction"] = HNSW_EF_CONSTRUCTION
        vector_attrs["ef_runtime"] = ef_runtime

    return {
        "index": {
            "name": index_name,
            "prefix": prefix,
            "storage_type": "hash",
        },
        "fields": [
            {"name": "doc_id", "type": "tag"},
            {
                "name": "embedding",
                "type": "vector",
                "attrs": vector_attrs,
            },
        ],
    }


# ---------------------------------------------------------------------------
# Data loading into Redis
# ---------------------------------------------------------------------------


def iter_documents(
    vectors: np.ndarray,
    *,
    dtype: str,
) -> Iterable[Dict[str, Any]]:
    """Yield documents 
ready for SearchIndex.load()."""
    for i, vec in enumerate(vectors):
        yield {
            "doc_id": f"doc-{i}",
            "embedding": array_to_buffer(vec, dtype),
        }


def wait_for_index_ready(
    index: SearchIndex,
    *,
    timeout_seconds: int = 3600,
    poll_interval: float = 0.5,
) -> Dict[str, Any]:
    """Block until the index reports 100% indexed.

    Polls FT.INFO until ``percent_indexed`` reaches 1.0 and the ``indexing``
    flag clears; raises TimeoutError past the deadline.
    """
    deadline = time.perf_counter() + timeout_seconds
    info = index.info()
    while time.perf_counter() < deadline:
        info = index.info()
        pct = float(info.get("percent_indexed", 0))
        indexing = info.get("indexing", 0)
        if pct >= 1.0 and not indexing:
            return info
        time.sleep(poll_interval)
    raise TimeoutError(
        f"Index {index.schema.index.name} not ready within {timeout_seconds}s"
    )


# ---------------------------------------------------------------------------
# Memory helpers
# ---------------------------------------------------------------------------


def get_memory_mb(client: Redis) -> float:
    # Whole-server used_memory (INFO MEMORY), not per-index memory.
    info = client.info("memory")
    return round(int(info.get("used_memory", 0)) / (1024 * 1024), 3)


# ---------------------------------------------------------------------------
# Query execution & overlap
# ---------------------------------------------------------------------------


def percentile(values: Sequence[float], pct: float) -> float:
    # Empty input yields 0.0 rather than raising inside np.percentile.
    if not values:
        return 0.0
    return round(float(np.percentile(np.asarray(values), pct)), 6)


def run_queries(
    index: SearchIndex,
    query_vectors: np.ndarray,
    *,
    dtype: str,
    top_k: int,
) -> Dict[str, Any]:
    """Run query vectors; return latency stats and result doc-id lists.

    Queries run sequentially, so QPS here is single-client throughput
    (count / summed latency), not a concurrent load figure.
    """
    latencies_ms: List[float] = []
    result_sets: List[List[str]] = []

    for qvec in query_vectors:
        q = VectorQuery(
            vector=qvec.tolist(),
            vector_field_name="embedding",
            return_fields=["doc_id"],
            num_results=top_k,
            dtype=dtype,
        )
        t0 = time.perf_counter()
        results = index.query(q)
        latencies_ms.append((time.perf_counter() - t0) * 1000)
        result_sets.append([r.get("doc_id") or r.get("id", "") for r in results if r])

    total_s = sum(latencies_ms) / 1000
    qps = len(latencies_ms) / total_s if total_s > 0 else 0

    return {
        "count": len(latencies_ms),
        "p50_ms": percentile(latencies_ms, 50),
        "p95_ms": percentile(latencies_ms, 95),
        "p99_ms": percentile(latencies_ms, 99),
        "mean_ms": round(statistics.mean(latencies_ms), 3),
        "qps": round(qps, 2),
        "result_sets": result_sets,
    }


def compute_overlap(
    ground_truth: List[List[str]],
    candidate: List[List[str]],
    *,
    top_k: int,
) -> Dict[str, Any]:
    """Compute Overlap@K (precision) of candidate vs ground truth.

    Denominator is always top_k, so a ground-truth list shorter than top_k
    caps the achievable overlap below 1.0.
    """
    ratios: List[float] = []
    for gt, cand in zip(ground_truth, candidate):
        gt_set = set(gt[:top_k])
        cand_set = set(cand[:top_k])
        ratios.append(len(gt_set & cand_set) / max(top_k, 1))
    return {
        "mean_overlap_at_k": round(statistics.mean(ratios), 4),
        "min_overlap_at_k": round(min(ratios), 4),
        "max_overlap_at_k": round(max(ratios), 4),
        "std_overlap_at_k": (
            round(statistics.stdev(ratios), 4) if len(ratios) > 1 else 0.0
        ),
    }


def compute_recall(
    ground_truth: List[List[str]],
    candidate: List[List[str]],
    *,
    k_values: Sequence[int],
    ground_truth_depth: int,
) -> Dict[str, Any]:
    """Compute Recall@K at multiple K values.

    For each K, recall is defined as:
        |candidate_top_K intersection ground_truth_top_GT_DEPTH| / GT_DEPTH

    The ground truth set is FIXED at ground_truth_depth (e.g., top-100 from
    FLAT FP32). As K increases from 1 to ground_truth_depth, recall should
    climb from low to 1.0 (for exact search) or near-1.0 (for approximate).

    This is the standard recall metric from ANN benchmarks -- it answers
    "what fraction of the true nearest neighbors did we find?"
    """
    recall_at_k: Dict[str, float] = {}
    recall_detail: Dict[str, Dict[str, float]] = {}
    for k in k_values:
        ratios: List[float] = []
        for gt, cand in zip(ground_truth, candidate):
            gt_set = set(gt[:ground_truth_depth])
            cand_set = set(cand[:k])
            # Denominator shrinks if the ground-truth list itself is short.
            denom = min(ground_truth_depth, len(gt_set))
            if denom == 0:
                # Empty ground truth means nothing to recall; use 0.0
                ratios.append(0.0)
            else:
                ratios.append(len(gt_set & cand_set) / denom)
        mean_recall = round(statistics.mean(ratios), 4)
        recall_at_k[f"recall@{k}"] = mean_recall
        recall_detail[f"recall@{k}"] = {
            "mean": mean_recall,
            "min": round(min(ratios), 4),
            "max": round(max(ratios), 4),
            "std": round(statistics.stdev(ratios), 4) if len(ratios) > 1 else 0.0,
        }
    return {
        "recall_at_k": recall_at_k,
        "recall_detail": recall_detail,
        "ground_truth_depth": ground_truth_depth,
    }


# ---------------------------------------------------------------------------
# Single-config benchmark
# ---------------------------------------------------------------------------


def benchmark_single_config(
    *,
    client: Redis,
    doc_vectors: np.ndarray,
    query_vectors: np.ndarray,
    config: Dict[str, str],
    dims: int,
    distance_metric: str,
    size: int,
    top_k: int,
    ef_runtime: int,
    load_batch_size: int,
) -> Dict[str, Any]:
    """Build one index config, load data, query, and return metrics."""
    label = config["label"]
    algo = config["algorithm"]
    dtype = config["datatype"]

    index_name = f"bench_{label}_{size}"
    prefix = f"bench:{label}:{size}"

    schema = build_schema(
        index_name=index_name,
        prefix=prefix,
        dims=dims,
        algorithm=algo,
        datatype=dtype,
        distance_metric=distance_metric,
        ef_runtime=ef_runtime,
    )

    idx = SearchIndex.from_dict(schema, redis_client=client)
    try:
        idx.create(overwrite=True)

        # Load data
        load_start = time.perf_counter()
        idx.load(
            iter_documents(doc_vectors, dtype=dtype),
            id_field="doc_id",
            batch_size=load_batch_size,
        )
        
info = wait_for_index_ready(idx)
        # load_duration includes both the write phase and the indexing wait.
        load_duration = time.perf_counter() - load_start

        memory_mb = get_memory_mb(client)

        # Query
        query_metrics = run_queries(
            idx,
            query_vectors,
            dtype=dtype,
            top_k=top_k,
        )

        return {
            "label": label,
            "algorithm": algo,
            "datatype": dtype,
            "load_duration_seconds": round(load_duration, 3),
            "num_docs": int(info.get("num_docs", 0) or 0),
            "vector_index_sz_mb": float(info.get("vector_index_sz_mb", 0) or 0),
            "memory_mb": memory_mb,
            "latency": {
                "queried_top_k": top_k,
                **{k: v for k, v in query_metrics.items() if k != "result_sets"},
            },
            "result_sets": query_metrics["result_sets"],
        }
    finally:
        # Always drop the per-config index so configs don't share memory.
        try:
            idx.delete(drop=True)
        except Exception:
            pass


# ---------------------------------------------------------------------------
# Scale-level benchmark (runs all 4 configs for one size)
# ---------------------------------------------------------------------------


def benchmark_scale(
    *,
    client: Redis,
    all_vectors: np.ndarray,
    size: int,
    query_count: int,
    dims: int,
    distance_metric: str,
    top_k: int,
    ef_runtime: int,
    load_batch_size: int,
    recall_k_max: int = 100,
) -> Dict[str, Any]:
    """Run all 4 index configs for a given dataset size.

    Query vectors are taken from the slice immediately after the document
    slice, so for a given size they are held out of that size's corpus.
    """
    doc_vectors = all_vectors[:size]
    query_vectors = all_vectors[size : size + query_count].copy()

    # Use the larger of top_k and recall_k_max for querying
    # so we have enough results for recall curve computation
    effective_top_k = max(top_k, recall_k_max)

    baseline_memory = get_memory_mb(client)

    config_results: Dict[str, Any] = {}
    ground_truth_results: List[List[str]] = []

    # Run FLAT_FP32 first to establish ground truth
    gt_config = INDEX_CONFIGS[0]  # FLAT_FP32
    assert gt_config["label"] == "FLAT_FP32"

    for config in INDEX_CONFIGS:
        label = config["label"]
        print(f" [{label}] Building and querying ...")

        result = benchmark_single_config(
            client=client,
            doc_vectors=doc_vectors,
            query_vectors=query_vectors,
            config=config,
            dims=dims,
            distance_metric=distance_metric,
            size=size,
            top_k=effective_top_k,
            ef_runtime=ef_runtime,
            load_batch_size=load_batch_size,
        )

        if label == "FLAT_FP32":
            ground_truth_results = result["result_sets"]

        config_results[label] = result

    # Compute overlap vs ground truth for every config (at original top_k)
    overlap_results: Dict[str, Any] = {}
    for label, result in config_results.items():
        overlap = compute_overlap(
            ground_truth_results,
            result["result_sets"],
            top_k=top_k,
        )
        overlap_results[label] = overlap

    # Compute recall at multiple K values.
    # Ground truth depth is fixed at top_k (e.g., 10). We measure what
    # fraction of those top_k true results appear in candidate top-K as
    # K varies from 1 up to effective_top_k.
    valid_k_values = [k for k in RECALL_K_VALUES if k <= effective_top_k]
    recall_results: Dict[str, Any] = {}
    for label, result in config_results.items():
        recall = compute_recall(
            ground_truth_results,
            result["result_sets"],
            k_values=valid_k_values,
            ground_truth_depth=top_k,
        )
        recall_results[label] = recall

    # Strip raw result_sets from output (too large for JSON)
    for label in config_results:
        del config_results[label]["result_sets"]

    return {
        "size": size,
        "query_count": query_count,
        "dims": dims,
        "distance_metric": distance_metric,
        "top_k": top_k,
        "recall_k_max": recall_k_max,
        "ef_runtime": ef_runtime,
        "hnsw_m": HNSW_M,
        "hnsw_ef_construction": HNSW_EF_CONSTRUCTION,
        "baseline_memory_mb": baseline_memory,
        "configs": config_results,
        "overlap_vs_ground_truth": overlap_results,
        "recall_vs_ground_truth": recall_results,
    }


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main() -> None:
    args = parse_args()
    sizes = sorted(args.sizes)
    # Need docs for the largest size plus a held-out query slice after them.
    max_needed = max(sizes) + args.query_count

    ds_info = DATASETS[args.dataset]

    print(
        f"""Retrieval Benchmark
 Dataset: {args.dataset} ({ds_info['description']})
 Dims: {ds_info['dims']}
 Sizes: {sizes}
 Query count: {args.query_count}
 Top-K: {args.top_k}
 Recall K max: {args.recall_k_max}
 EF runtime: {args.ef_runtime}
 HNSW M: {HNSW_M}
 EF construct: {HNSW_EF_CONSTRUCTION}
 Redis URL: {args.redis_url}
 Configs: {[c['label'] for c in INDEX_CONFIGS]}"""
    )

    # Load vectors once
    all_vectors, dims = load_dataset_vectors(args.dataset, max_needed)
    if all_vectors.shape[0] < max_needed:
        raise ValueError(
            f"Dataset has {all_vectors.shape[0]} vectors but need {max_needed} "
            f"(max_size={max(sizes)} + query_count={args.query_count})"
        )

    client = Redis.from_url(args.redis_url, decode_responses=False)
    client.ping()
    print("Connected to Redis")

    report = {
        "benchmark": "retrieval_fp32_vs_fp16",
        "dataset": args.dataset,
        "dataset_description": ds_info["description"],
        "dims": dims,
        "distance_metric": ds_info["distance_metric"],
        "hnsw_m": HNSW_M,
        "hnsw_ef_construction": HNSW_EF_CONSTRUCTION,
        "ef_runtime": args.ef_runtime,
        "top_k": args.top_k,
        "recall_k_max": args.recall_k_max,
        "recall_k_values": [
            k for k in RECALL_K_VALUES if k <= max(args.top_k, args.recall_k_max)
        ],
        "query_count": args.query_count,
        "configs": [c["label"] for c in INDEX_CONFIGS],
        "results": [],
    }

    for size in sizes:
        print(f"\n{'='*60}")
        print(f" Size: {size:,} documents")
        print(f"{'='*60}")

        # NOTE(review): flushes the whole selected Redis DB between sizes --
        # do not point this at a database holding other data.
        client.flushdb()

        result = benchmark_scale(
            client=client,
            all_vectors=all_vectors,
            size=size,
            query_count=args.query_count,
            dims=dims,
            distance_metric=ds_info["distance_metric"],
            top_k=args.top_k,
            ef_runtime=args.ef_runtime,
            load_batch_size=args.load_batch_size,
            recall_k_max=args.recall_k_max,
        )
        report["results"].append(result)

        # Print summary table for this size
        print(
            f"\n {'Config':<12} {'Load(s)':>8} {'Memory(MB)':>11} "
            f"{'p50(ms)':>8} 
{'p95(ms)':>8} {'QPS':>7} {'Overlap@K':>10}" + ) + print(f" {'-'*12} {'-'*8} {'-'*11} {'-'*8} {'-'*8} {'-'*7} {'-'*10}") + for label, cfg in result["configs"].items(): + overlap = result["overlap_vs_ground_truth"][label] + print( + f" {label:<12} " + f"{cfg['load_duration_seconds']:>8.1f} " + f"{cfg['memory_mb']:>11.1f} " + f"{cfg['latency']['p50_ms']:>8.2f} " + f"{cfg['latency']['p95_ms']:>8.2f} " + f"{cfg['latency']['qps']:>7.1f} " + f"{overlap['mean_overlap_at_k']:>10.4f}" + ) + + # Print recall curve summary + recall_data = result.get("recall_vs_ground_truth", {}) + if recall_data: + first_label = next(iter(recall_data)) + k_keys = sorted( + recall_data[first_label].get("recall_at_k", {}).keys(), + key=lambda x: int(x.split("@")[1]), + ) + header = f" {'Config':<12} " + " ".join(f"{k:>10}" for k in k_keys) + print(f"\n Recall Curve:") + print(header) + print(f" {'-'*12} " + " ".join(f"{'-'*10}" for _ in k_keys)) + for label, rdata in recall_data.items(): + vals = " ".join( + f"{rdata['recall_at_k'].get(k, 0):>10.4f}" for k in k_keys + ) + print(f" {label:<12} {vals}") + + # Write report + output_path = Path(args.output).resolve() + with open(output_path, "w") as f: + json.dump(report, f, indent=2) + print(f"\nReport written to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/tests/benchmarks/visualize_results.py b/tests/benchmarks/visualize_results.py new file mode 100644 index 000000000..8b282743a --- /dev/null +++ b/tests/benchmarks/visualize_results.py @@ -0,0 +1,529 @@ +#!/usr/bin/env python3 +""" +Visualization script for retrieval benchmark results. + +Generates charts replicating the style of the Redis SVS-VAMANA blog post: + 1. Memory footprint comparison (FP32 vs FP16, bar chart) + 2. Precision (Overlap@K) comparison (grouped bar chart) + 3. QPS comparison (grouped bar chart) + 4. Latency comparison (p50/p95, grouped bar chart) + 5. 
QPS vs Overlap@K curve (line chart) + +Usage: + python tests/benchmarks/visualize_results.py \ + --input tests/benchmarks/results_dbpedia.json \ + --output-dir tests/benchmarks/charts/ +""" + +import argparse +import json +import os +from typing import Any, Dict, List + +try: + import matplotlib.pyplot as plt + import matplotlib.ticker as mticker +except ImportError: + raise ImportError( + "matplotlib is required by this visualization script. " + "Install it with: pip install matplotlib" + ) +import numpy as np + +# Redis-inspired color palette +COLORS = { + "FLAT_FP32": "#1E3A5F", # dark navy + "FLAT_FP16": "#3B82F6", # bright blue + "HNSW_FP32": "#DC2626", # Redis red + "HNSW_FP16": "#F97316", # orange +} + +LABELS = { + "FLAT_FP32": "FLAT FP32", + "FLAT_FP16": "FLAT FP16", + "HNSW_FP32": "HNSW FP32", + "HNSW_FP16": "HNSW FP16", +} + + +def load_results(path: str) -> Dict[str, Any]: + with open(path) as f: + return json.load(f) + + +def setup_style(): + """Apply a clean, modern chart style.""" + plt.rcParams.update( + { + "figure.facecolor": "white", + "axes.facecolor": "#F8F9FA", + "axes.edgecolor": "#DEE2E6", + "axes.grid": True, + "grid.color": "#E9ECEF", + "grid.alpha": 0.7, + "font.family": "sans-serif", + "font.size": 11, + "axes.titlesize": 14, + "axes.titleweight": "bold", + "axes.labelsize": 12, + } + ) + + +def chart_memory(results: List[Dict], dataset: str, output_dir: str): + """Chart 1: Memory footprint comparison per size (grouped bar chart).""" + fig, ax = plt.subplots(figsize=(10, 6)) + configs = ["FLAT_FP32", "FLAT_FP16", "HNSW_FP32", "HNSW_FP16"] + sizes = [r["size"] for r in results] + x = np.arange(len(sizes)) + width = 0.18 + + for i, cfg in enumerate(configs): + mem = [r["configs"][cfg]["memory_mb"] for r in results] + bars = ax.bar( + x + i * width, + mem, + width, + label=LABELS[cfg], + color=COLORS[cfg], + edgecolor="white", + linewidth=0.5, + ) + for bar, val in zip(bars, mem): + ax.text( + bar.get_x() + bar.get_width() / 2, + 
bar.get_height() + 1, + f"{val:.0f}", + ha="center", + va="bottom", + fontsize=8, + ) + + ax.set_xlabel("Corpus Size") + ax.set_ylabel("Total Memory (MB)") + ax.set_title(f"Memory Footprint: FP32 vs FP16 -- {dataset}") + ax.set_xticks(x + width * 1.5) + ax.set_xticklabels([f"{s:,}" for s in sizes]) + ax.legend(loc="upper left") + ax.set_ylim(bottom=0) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f"{dataset}_memory.png"), dpi=150) + plt.close(fig) + print(f" Saved {dataset}_memory.png") + + +def chart_overlap(results: List[Dict], dataset: str, output_dir: str): + """Chart 2: Overlap@K (precision) comparison per size.""" + fig, ax = plt.subplots(figsize=(10, 6)) + configs = ["FLAT_FP32", "FLAT_FP16", "HNSW_FP32", "HNSW_FP16"] + sizes = [r["size"] for r in results] + x = np.arange(len(sizes)) + width = 0.18 + + for i, cfg in enumerate(configs): + overlap = [ + r["overlap_vs_ground_truth"][cfg]["mean_overlap_at_k"] for r in results + ] + bars = ax.bar( + x + i * width, + overlap, + width, + label=LABELS[cfg], + color=COLORS[cfg], + edgecolor="white", + linewidth=0.5, + ) + for bar, val in zip(bars, overlap): + ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.005, + f"{val:.3f}", + ha="center", + va="bottom", + fontsize=8, + ) + + ax.set_xlabel("Corpus Size") + ax.set_ylabel("Overlap@K (Precision vs FLAT FP32)") + ax.set_title(f"Search Precision: FP32 vs FP16 -- {dataset}") + ax.set_xticks(x + width * 1.5) + ax.set_xticklabels([f"{s:,}" for s in sizes]) + ax.legend(loc="lower left") + ax.set_ylim(0, 1.1) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f"{dataset}_overlap.png"), dpi=150) + plt.close(fig) + print(f" Saved {dataset}_overlap.png") + + +def chart_qps(results: List[Dict], dataset: str, output_dir: str): + """Chart 3: QPS comparison per size.""" + fig, ax = plt.subplots(figsize=(10, 6)) + configs = ["FLAT_FP32", "FLAT_FP16", "HNSW_FP32", "HNSW_FP16"] + sizes = [r["size"] for r in results] + x = 
np.arange(len(sizes)) + width = 0.18 + + for i, cfg in enumerate(configs): + qps = [r["configs"][cfg]["latency"]["qps"] for r in results] + bars = ax.bar( + x + i * width, + qps, + width, + label=LABELS[cfg], + color=COLORS[cfg], + edgecolor="white", + linewidth=0.5, + ) + for bar, val in zip(bars, qps): + ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 10, + f"{val:.0f}", + ha="center", + va="bottom", + fontsize=7, + rotation=45, + ) + + ax.set_xlabel("Corpus Size") + ax.set_ylabel("Queries Per Second (QPS)") + ax.set_title(f"Query Throughput: FP32 vs FP16 -- {dataset}") + ax.set_xticks(x + width * 1.5) + ax.set_xticklabels([f"{s:,}" for s in sizes]) + ax.legend(loc="upper right") + ax.set_ylim(bottom=0) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f"{dataset}_qps.png"), dpi=150) + plt.close(fig) + print(f" Saved {dataset}_qps.png") + + +def chart_latency(results: List[Dict], dataset: str, output_dir: str): + """Chart 4: p50 and p95 latency comparison per size.""" + fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True) + configs = ["FLAT_FP32", "FLAT_FP16", "HNSW_FP32", "HNSW_FP16"] + sizes = [r["size"] for r in results] + x = np.arange(len(sizes)) + width = 0.18 + + for ax, metric, title in zip( + axes, ["p50_ms", "p95_ms"], ["p50 Latency", "p95 Latency"] + ): + for i, cfg in enumerate(configs): + vals = [r["configs"][cfg]["latency"][metric] for r in results] + bars = ax.bar( + x + i * width, + vals, + width, + label=LABELS[cfg], + color=COLORS[cfg], + edgecolor="white", + linewidth=0.5, + ) + for bar, val in zip(bars, vals): + ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.02, + f"{val:.2f}", + ha="center", + va="bottom", + fontsize=7, + ) + ax.set_xlabel("Corpus Size") + ax.set_ylabel("Latency (ms)") + ax.set_title(f"{title} -- {dataset}") + ax.set_xticks(x + width * 1.5) + ax.set_xticklabels([f"{s:,}" for s in sizes]) + ax.legend(loc="upper left", fontsize=9) + ax.set_ylim(bottom=0) + + 
fig.tight_layout() + fig.savefig(os.path.join(output_dir, f"{dataset}_latency.png"), dpi=150) + plt.close(fig) + print(f" Saved {dataset}_latency.png") + + +def chart_qps_vs_overlap(results: List[Dict], dataset: str, output_dir: str): + """Chart 5: QPS vs Overlap@K curve (Redis blog Chart 2 style).""" + fig, ax = plt.subplots(figsize=(10, 6)) + configs = ["FLAT_FP32", "FLAT_FP16", "HNSW_FP32", "HNSW_FP16"] + markers = {"FLAT_FP32": "s", "FLAT_FP16": "D", "HNSW_FP32": "o", "HNSW_FP16": "^"} + + for cfg in configs: + overlaps = [] + qps_vals = [] + for r in results: + overlaps.append(r["overlap_vs_ground_truth"][cfg]["mean_overlap_at_k"]) + qps_vals.append(r["configs"][cfg]["latency"]["qps"]) + + ax.plot( + overlaps, + qps_vals, + marker=markers[cfg], + markersize=8, + linewidth=2, + label=LABELS[cfg], + color=COLORS[cfg], + ) + # Annotate points with size + for ov, qps, r in zip(overlaps, qps_vals, results): + ax.annotate( + f'{r["size"]//1000}K', + (ov, qps), + textcoords="offset points", + xytext=(5, 5), + fontsize=7, + color=COLORS[cfg], + ) + + ax.set_xlabel("Overlap@K (Precision)") + ax.set_ylabel("Queries Per Second (QPS)") + ax.set_title(f"Precision vs Throughput -- {dataset}") + ax.legend(loc="best") + ax.set_xlim(0, 1.05) + ax.set_ylim(bottom=0) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f"{dataset}_qps_vs_overlap.png"), dpi=150) + plt.close(fig) + print(f" Saved {dataset}_qps_vs_overlap.png") + + +def chart_memory_savings(results: List[Dict], dataset: str, output_dir: str): + """Chart 6: Memory savings percentage (Redis blog Chart 1 style).""" + fig, ax = plt.subplots(figsize=(10, 6)) + sizes = [r["size"] for r in results] + + # Calculate savings: FP16 vs FP32 for both FLAT and HNSW + pairs = [ + ("FLAT", "FLAT_FP32", "FLAT_FP16", "#3B82F6"), + ("HNSW", "HNSW_FP32", "HNSW_FP16", "#F97316"), + ] + + x = np.arange(len(sizes)) + width = 0.3 + + for i, (label, fp32, fp16, color) in enumerate(pairs): + savings = [] + for r in results: + m32 = 
r["configs"][fp32]["memory_mb"] + m16 = r["configs"][fp16]["memory_mb"] + pct = (1 - m16 / m32) * 100 if m32 > 0 else 0.0 + savings.append(pct) + + bars = ax.bar( + x + i * width, + savings, + width, + label=f"{label} FP16 savings", + color=color, + edgecolor="white", + linewidth=0.5, + ) + for bar, val in zip(bars, savings): + ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.5, + f"{val:.1f}%", + ha="center", + va="bottom", + fontsize=9, + fontweight="bold", + ) + + ax.set_xlabel("Corpus Size") + ax.set_ylabel("Memory Savings (%)") + ax.set_title(f"FP16 Memory Savings vs FP32 -- {dataset}") + ax.set_xticks(x + width * 0.5) + ax.set_xticklabels([f"{s:,}" for s in sizes]) + ax.legend(loc="lower right") + ax.set_ylim(0, 60) + ax.yaxis.set_major_formatter(mticker.PercentFormatter()) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f"{dataset}_memory_savings.png"), dpi=150) + plt.close(fig) + print(f" Saved {dataset}_memory_savings.png") + + +def chart_build_time(results: List[Dict], dataset: str, output_dir: str): + """Chart 7: Index build/load time comparison.""" + fig, ax = plt.subplots(figsize=(10, 6)) + configs = ["FLAT_FP32", "FLAT_FP16", "HNSW_FP32", "HNSW_FP16"] + sizes = [r["size"] for r in results] + x = np.arange(len(sizes)) + width = 0.18 + + for i, cfg in enumerate(configs): + times = [r["configs"][cfg]["load_duration_seconds"] for r in results] + bars = ax.bar( + x + i * width, + times, + width, + label=LABELS[cfg], + color=COLORS[cfg], + edgecolor="white", + linewidth=0.5, + ) + for bar, val in zip(bars, times): + if val > 0.1: + ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.2, + f"{val:.1f}s", + ha="center", + va="bottom", + fontsize=7, + ) + + ax.set_xlabel("Corpus Size") + ax.set_ylabel("Build Time (seconds)") + ax.set_title(f"Index Build Time -- {dataset}") + ax.set_xticks(x + width * 1.5) + ax.set_xticklabels([f"{s:,}" for s in sizes]) + ax.legend(loc="upper left") + ax.set_ylim(bottom=0) + 
fig.tight_layout() + fig.savefig(os.path.join(output_dir, f"{dataset}_build_time.png"), dpi=150) + plt.close(fig) + print(f" Saved {dataset}_build_time.png") + + +def chart_recall_curve(results: List[Dict], dataset: str, output_dir: str): + """Chart 8: Recall@K curve -- recall at multiple K values for the largest size.""" + # Use the largest corpus size for the recall curve + r = results[-1] + recall_data = r.get("recall_vs_ground_truth") + if not recall_data: + print(f" Skipping recall curve (no recall data in results)") + return + + fig, ax = plt.subplots(figsize=(10, 6)) + configs = ["FLAT_FP32", "FLAT_FP16", "HNSW_FP32", "HNSW_FP16"] + markers = {"FLAT_FP32": "s", "FLAT_FP16": "D", "HNSW_FP32": "o", "HNSW_FP16": "^"} + linestyles = { + "FLAT_FP32": "-", + "FLAT_FP16": "--", + "HNSW_FP32": "-", + "HNSW_FP16": "--", + } + + for cfg in configs: + if cfg not in recall_data: + continue + recall_at_k = recall_data[cfg].get("recall_at_k", {}) + if not recall_at_k: + continue + k_vals = sorted([int(k.split("@")[1]) for k in recall_at_k.keys()]) + recalls = [recall_at_k[f"recall@{k}"] for k in k_vals] + + ax.plot( + k_vals, + recalls, + marker=markers[cfg], + markersize=7, + linewidth=2, + linestyle=linestyles[cfg], + label=LABELS[cfg], + color=COLORS[cfg], + ) + + ax.set_xlabel("K (number of results)") + ax.set_ylabel("Recall@K") + ax.set_title(f"Recall@K Curve at {r['size']:,} documents -- {dataset}") + ax.legend(loc="lower right") + ax.set_ylim(0, 1.05) + ax.set_xlim(left=0) + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f"{dataset}_recall_curve.png"), dpi=150) + plt.close(fig) + print(f" Saved {dataset}_recall_curve.png") + + +def chart_recall_by_size(results: List[Dict], dataset: str, output_dir: str): + """Chart 9: Recall@10 comparison across corpus sizes (grouped bar chart).""" + # Check if recall data exists + if not results[0].get("recall_vs_ground_truth"): + print(f" Skipping recall by size (no recall data)") + return + 
+ fig, ax = plt.subplots(figsize=(10, 6)) + configs = ["FLAT_FP32", "FLAT_FP16", "HNSW_FP32", "HNSW_FP16"] + sizes = [r["size"] for r in results] + x = np.arange(len(sizes)) + width = 0.18 + + for i, cfg in enumerate(configs): + recalls = [] + for r in results: + recall_data = r.get("recall_vs_ground_truth", {}).get(cfg, {}) + recall_at_k = recall_data.get("recall_at_k", {}) + recalls.append(recall_at_k.get("recall@10", 0)) + bars = ax.bar( + x + i * width, + recalls, + width, + label=LABELS[cfg], + color=COLORS[cfg], + edgecolor="white", + linewidth=0.5, + ) + for bar, val in zip(bars, recalls): + ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.005, + f"{val:.3f}", + ha="center", + va="bottom", + fontsize=8, + ) + + ax.set_xlabel("Corpus Size") + ax.set_ylabel("Recall@10") + ax.set_title(f"Recall@10: FP32 vs FP16 -- {dataset}") + ax.set_xticks(x + width * 1.5) + ax.set_xticklabels([f"{s:,}" for s in sizes]) + ax.legend(loc="lower left") + ax.set_ylim(0, 1.1) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f"{dataset}_recall.png"), dpi=150) + plt.close(fig) + print(f" Saved {dataset}_recall.png") + + +def main(): + parser = argparse.ArgumentParser(description="Visualize benchmark results.") + parser.add_argument( + "--input", nargs="+", required=True, help="One or more result JSON files." 
+ ) + parser.add_argument( + "--output-dir", + default="tests/benchmarks/charts/", + help="Directory to save chart images.", + ) + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + setup_style() + + for path in args.input: + data = load_results(path) + dataset = data["dataset"] + results = data["results"] + print(f"\nGenerating charts for {dataset} ({len(results)} sizes) ...") + + chart_memory(results, dataset, args.output_dir) + chart_overlap(results, dataset, args.output_dir) + chart_qps(results, dataset, args.output_dir) + chart_latency(results, dataset, args.output_dir) + chart_qps_vs_overlap(results, dataset, args.output_dir) + chart_memory_savings(results, dataset, args.output_dir) + chart_build_time(results, dataset, args.output_dir) + chart_recall_curve(results, dataset, args.output_dir) + chart_recall_by_size(results, dataset, args.output_dir) + + print(f"\nAll charts saved to {args.output_dir}") + + +if __name__ == "__main__": + main() diff --git a/tests/integration/test_async_migration_v1.py b/tests/integration/test_async_migration_v1.py new file mode 100644 index 000000000..c50fdaf84 --- /dev/null +++ b/tests/integration/test_async_migration_v1.py @@ -0,0 +1,150 @@ +"""Integration tests for async migration (Phase 1.5). + +These tests verify the async migration components work correctly with a real +Redis instance, mirroring the sync tests in test_migration_v1.py. 
+""" + +import uuid + +import pytest +import yaml + +from redisvl.index import AsyncSearchIndex +from redisvl.migration import ( + AsyncMigrationExecutor, + AsyncMigrationPlanner, + AsyncMigrationValidator, +) +from redisvl.migration.utils import load_migration_plan, schemas_equal +from redisvl.redis.utils import array_to_buffer + + +@pytest.mark.asyncio +async def test_async_drop_recreate_plan_apply_validate_flow( + redis_url, worker_id, tmp_path +): + """Test full async migration flow: plan -> apply -> validate.""" + unique_id = str(uuid.uuid4())[:8] + index_name = f"async_migration_v1_{worker_id}_{unique_id}" + prefix = f"async_migration_v1:{worker_id}:{unique_id}" + + source_index = AsyncSearchIndex.from_dict( + { + "index": { + "name": index_name, + "prefix": prefix, + "storage_type": "hash", + }, + "fields": [ + {"name": "doc_id", "type": "tag"}, + {"name": "title", "type": "text"}, + {"name": "price", "type": "numeric"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "hnsw", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + }, + redis_url=redis_url, + ) + + docs = [ + { + "doc_id": "1", + "title": "alpha", + "price": 1, + "category": "news", + "embedding": array_to_buffer([0.1, 0.2, 0.3], "float32"), + }, + { + "doc_id": "2", + "title": "beta", + "price": 2, + "category": "sports", + "embedding": array_to_buffer([0.2, 0.1, 0.4], "float32"), + }, + ] + + await source_index.create(overwrite=True) + await source_index.load(docs, id_field="doc_id") + + # Create schema patch + patch_path = tmp_path / "schema_patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "add_fields": [ + { + "name": "category", + "type": "tag", + "attrs": {"separator": ","}, + } + ], + "remove_fields": ["price"], + "update_fields": [{"name": "title", "attrs": {"sortable": True}}], + }, + }, + sort_keys=False, + ) + ) + + # Create plan using async planner + plan_path = tmp_path / 
"migration_plan.yaml" + planner = AsyncMigrationPlanner() + plan = await planner.create_plan( + index_name, + redis_url=redis_url, + schema_patch_path=str(patch_path), + ) + assert plan.diff_classification.supported is True + planner.write_plan(plan, str(plan_path)) + + # Create query checks + query_check_path = tmp_path / "query_checks.yaml" + query_check_path.write_text( + yaml.safe_dump({"fetch_ids": ["1", "2"]}, sort_keys=False) + ) + + # Apply migration using async executor + executor = AsyncMigrationExecutor() + report = await executor.apply( + load_migration_plan(str(plan_path)), + redis_url=redis_url, + query_check_file=str(query_check_path), + ) + + # Verify migration succeeded + assert report.result == "succeeded" + assert report.validation.schema_match is True + assert report.validation.doc_count_match is True + assert report.validation.key_sample_exists is True + assert report.validation.indexing_failures_delta == 0 + assert not report.validation.errors + assert report.benchmark_summary.documents_indexed_per_second is not None + + # Verify schema matches target + live_index = await AsyncSearchIndex.from_existing(index_name, redis_url=redis_url) + assert schemas_equal(live_index.schema.to_dict(), plan.merged_target_schema) + + # Test standalone async validator + validator = AsyncMigrationValidator() + validation, _target_info, _duration = await validator.validate( + load_migration_plan(str(plan_path)), + redis_url=redis_url, + query_check_file=str(query_check_path), + ) + assert validation.schema_match is True + assert validation.doc_count_match is True + assert validation.key_sample_exists is True + assert not validation.errors + + # Cleanup + await live_index.delete(drop=True) diff --git a/tests/integration/test_batch_migration_integration.py b/tests/integration/test_batch_migration_integration.py new file mode 100644 index 000000000..92ea7b94d --- /dev/null +++ b/tests/integration/test_batch_migration_integration.py @@ -0,0 +1,485 @@ +""" +Integration 
tests for batch migration. + +Tests the full batch migration flow with real Redis: +- Batch planning with patterns and explicit lists +- Batch apply with checkpointing +- Resume after interruption +- Failure policies (fail_fast, continue_on_error) +""" + +import uuid + +import pytest +import yaml + +from redisvl.index import SearchIndex +from redisvl.migration import BatchMigrationExecutor, BatchMigrationPlanner +from redisvl.redis.utils import array_to_buffer + + +def create_test_index(name: str, prefix: str, redis_url: str) -> SearchIndex: + """Helper to create a test index with standard schema.""" + index = SearchIndex.from_dict( + { + "index": { + "name": name, + "prefix": prefix, + "storage_type": "hash", + }, + "fields": [ + {"name": "doc_id", "type": "tag"}, + {"name": "title", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "hnsw", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + }, + redis_url=redis_url, + ) + return index + + +def load_test_data(index: SearchIndex) -> None: + """Load sample documents into an index.""" + docs = [ + { + "doc_id": "1", + "title": "alpha", + "embedding": array_to_buffer([0.1, 0.2, 0.3], "float32"), + }, + { + "doc_id": "2", + "title": "beta", + "embedding": array_to_buffer([0.2, 0.1, 0.4], "float32"), + }, + ] + index.load(docs, id_field="doc_id") + + +class TestBatchMigrationPlanIntegration: + """Test batch plan creation with real Redis.""" + + def test_batch_plan_with_pattern(self, redis_url, worker_id, tmp_path): + """Test creating a batch plan using pattern matching.""" + unique_id = str(uuid.uuid4())[:8] + prefix = f"batch_test:{worker_id}:{unique_id}" + indexes = [] + + # Create multiple indexes matching pattern + for i in range(3): + name = f"batch_{unique_id}_idx_{i}" + index = create_test_index(name, f"{prefix}_{i}", redis_url) + index.create(overwrite=True) + load_test_data(index) + indexes.append(index) + + # Create shared patch (add 
sortable to title) + patch_path = tmp_path / "patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "title", "attrs": {"sortable": True}} + ] + }, + }, + sort_keys=False, + ) + ) + + # Create batch plan + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + pattern=f"batch_{unique_id}_idx_*", + schema_patch_path=str(patch_path), + redis_url=redis_url, + ) + + # Verify batch plan + assert batch_plan.batch_id is not None + assert len(batch_plan.indexes) == 3 + for entry in batch_plan.indexes: + assert entry.applicable is True + assert entry.skip_reason is None + + # Cleanup + for index in indexes: + index.delete(drop=True) + + def test_batch_plan_with_explicit_list(self, redis_url, worker_id, tmp_path): + """Test creating a batch plan with explicit index list.""" + unique_id = str(uuid.uuid4())[:8] + prefix = f"batch_list_test:{worker_id}:{unique_id}" + index_names = [] + indexes = [] + + # Create indexes + for i in range(2): + name = f"list_batch_{unique_id}_{i}" + index = create_test_index(name, f"{prefix}_{i}", redis_url) + index.create(overwrite=True) + load_test_data(index) + indexes.append(index) + index_names.append(name) + + # Create shared patch + patch_path = tmp_path / "patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "title", "attrs": {"sortable": True}} + ] + }, + }, + sort_keys=False, + ) + ) + + # Create batch plan with explicit list + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + indexes=index_names, + schema_patch_path=str(patch_path), + redis_url=redis_url, + ) + + assert len(batch_plan.indexes) == 2 + assert all(idx.applicable for idx in batch_plan.indexes) + + # Cleanup + for index in indexes: + index.delete(drop=True) + + +class TestBatchMigrationApplyIntegration: + """Test batch apply with real Redis.""" + + def test_batch_apply_full_flow(self, 
redis_url, worker_id, tmp_path): + """Test complete batch apply flow: plan -> apply -> verify.""" + unique_id = str(uuid.uuid4())[:8] + prefix = f"batch_apply:{worker_id}:{unique_id}" + indexes = [] + index_names = [] + + # Create multiple indexes + for i in range(3): + name = f"apply_batch_{unique_id}_{i}" + index = create_test_index(name, f"{prefix}_{i}", redis_url) + index.create(overwrite=True) + load_test_data(index) + indexes.append(index) + index_names.append(name) + + # Create shared patch (make title sortable) + patch_path = tmp_path / "patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "title", "attrs": {"sortable": True}} + ] + }, + }, + sort_keys=False, + ) + ) + + # Create batch plan + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + indexes=index_names, + schema_patch_path=str(patch_path), + redis_url=redis_url, + ) + + # Save batch plan + plan_path = tmp_path / "batch_plan.yaml" + planner.write_batch_plan(batch_plan, str(plan_path)) + + # Apply batch migration + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / "reports" + executor = BatchMigrationExecutor() + report = executor.apply( + batch_plan, + state_path=str(state_path), + report_dir=str(report_dir), + redis_url=redis_url, + ) + + # Verify report + assert report.status == "completed" + assert report.summary.total_indexes == 3 + assert report.summary.successful == 3 + assert report.summary.failed == 0 + + # Verify all indexes were migrated (title is now sortable) + for name in index_names: + migrated = SearchIndex.from_existing(name, redis_url=redis_url) + title_field = migrated.schema.fields.get("title") + assert title_field is not None + assert title_field.attrs.sortable is True + + # Cleanup + for name in index_names: + idx = SearchIndex.from_existing(name, redis_url=redis_url) + idx.delete(drop=True) + + def test_batch_apply_with_inapplicable_indexes( + self, redis_url, 
worker_id, tmp_path + ): + """Test batch apply skips indexes that don't have matching fields.""" + unique_id = str(uuid.uuid4())[:8] + prefix = f"batch_skip:{worker_id}:{unique_id}" + indexes_to_cleanup = [] + + # Create an index WITH embedding field + with_embedding = f"with_emb_{unique_id}" + idx1 = create_test_index(with_embedding, f"{prefix}_1", redis_url) + idx1.create(overwrite=True) + load_test_data(idx1) + indexes_to_cleanup.append(with_embedding) + + # Create an index WITHOUT embedding field + without_embedding = f"no_emb_{unique_id}" + idx2 = SearchIndex.from_dict( + { + "index": { + "name": without_embedding, + "prefix": f"{prefix}_2", + "storage_type": "hash", + }, + "fields": [ + {"name": "doc_id", "type": "tag"}, + {"name": "content", "type": "text"}, + ], + }, + redis_url=redis_url, + ) + idx2.create(overwrite=True) + idx2.load([{"doc_id": "1", "content": "test"}], id_field="doc_id") + indexes_to_cleanup.append(without_embedding) + + # Create patch targeting embedding field (won't apply to idx2) + patch_path = tmp_path / "patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "embedding", "attrs": {"datatype": "float16"}} + ] + }, + }, + sort_keys=False, + ) + ) + + # Create batch plan + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + indexes=[with_embedding, without_embedding], + schema_patch_path=str(patch_path), + redis_url=redis_url, + ) + + # One should be applicable, one not + applicable = [idx for idx in batch_plan.indexes if idx.applicable] + not_applicable = [idx for idx in batch_plan.indexes if not idx.applicable] + assert len(applicable) == 1 + assert len(not_applicable) == 1 + assert "embedding" in not_applicable[0].skip_reason.lower() + + # Apply + executor = BatchMigrationExecutor() + report = executor.apply( + batch_plan, + state_path=str(tmp_path / "state.yaml"), + report_dir=str(tmp_path / "reports"), + redis_url=redis_url, + ) + + assert 
report.summary.successful == 1 + assert report.summary.skipped == 1 + + # Cleanup + for name in indexes_to_cleanup: + idx = SearchIndex.from_existing(name, redis_url=redis_url) + idx.delete(drop=True) + + +class TestBatchMigrationResumeIntegration: + """Test batch resume functionality with real Redis.""" + + def test_resume_from_checkpoint(self, redis_url, worker_id, tmp_path): + """Test resuming a batch migration from checkpoint state.""" + unique_id = str(uuid.uuid4())[:8] + prefix = f"batch_resume:{worker_id}:{unique_id}" + index_names = [] + indexes = [] + + # Create indexes + for i in range(3): + name = f"resume_batch_{unique_id}_{i}" + index = create_test_index(name, f"{prefix}_{i}", redis_url) + index.create(overwrite=True) + load_test_data(index) + indexes.append(index) + index_names.append(name) + + # Create patch + patch_path = tmp_path / "patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "title", "attrs": {"sortable": True}} + ] + }, + }, + sort_keys=False, + ) + ) + + # Create batch plan + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + indexes=index_names, + schema_patch_path=str(patch_path), + redis_url=redis_url, + ) + + # Save batch plan (needed for resume) + plan_path = tmp_path / "batch_plan.yaml" + planner.write_batch_plan(batch_plan, str(plan_path)) + + # Create a checkpoint state simulating partial completion + state_path = tmp_path / "batch_state.yaml" + partial_state = { + "batch_id": batch_plan.batch_id, + "plan_path": str(plan_path), + "started_at": "2026-03-20T10:00:00Z", + "updated_at": "2026-03-20T10:01:00Z", + "completed": [ + { + "name": index_names[0], + "status": "success", + "completed_at": "2026-03-20T10:00:30Z", + } + ], + "remaining": index_names[1:], # Still need to process idx 1 and 2 + "current_index": None, + } + state_path.write_text(yaml.safe_dump(partial_state, sort_keys=False)) + + # Resume from checkpoint + executor = 
BatchMigrationExecutor() + report = executor.resume( + state_path=str(state_path), + batch_plan_path=str(plan_path), + report_dir=str(tmp_path / "reports"), + redis_url=redis_url, + ) + + # Should complete remaining 2 indexes + # Note: The first index was marked as succeeded in checkpoint but not actually + # migrated, so the report will show 2 successful (the ones actually processed) + assert report.summary.successful >= 2 + assert report.status == "completed" + + # Verify at least the resumed indexes were migrated + for name in index_names[1:]: + migrated = SearchIndex.from_existing(name, redis_url=redis_url) + title_field = migrated.schema.fields.get("title") + assert title_field is not None + assert title_field.attrs.sortable is True + + # Cleanup + for name in index_names: + idx = SearchIndex.from_existing(name, redis_url=redis_url) + idx.delete(drop=True) + + def test_progress_callback_called(self, redis_url, worker_id, tmp_path): + """Test that progress callback is invoked during batch apply.""" + unique_id = str(uuid.uuid4())[:8] + prefix = f"batch_progress:{worker_id}:{unique_id}" + index_names = [] + indexes = [] + + # Create indexes + for i in range(2): + name = f"progress_batch_{unique_id}_{i}" + index = create_test_index(name, f"{prefix}_{i}", redis_url) + index.create(overwrite=True) + load_test_data(index) + indexes.append(index) + index_names.append(name) + + # Create patch + patch_path = tmp_path / "patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "title", "attrs": {"sortable": True}} + ] + }, + }, + sort_keys=False, + ) + ) + + # Create batch plan + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + indexes=index_names, + schema_patch_path=str(patch_path), + redis_url=redis_url, + ) + + # Track progress callbacks + progress_calls = [] + + def progress_cb(name, pos, total, status): + progress_calls.append((name, pos, total, status)) + + # Apply with 
progress callback + executor = BatchMigrationExecutor() + executor.apply( + batch_plan, + state_path=str(tmp_path / "state.yaml"), + report_dir=str(tmp_path / "reports"), + redis_url=redis_url, + progress_callback=progress_cb, + ) + + # Verify progress was reported for each index + assert len(progress_calls) >= 2 # At least one call per index + reported_names = {call[0] for call in progress_calls} + for name in index_names: + assert name in reported_names + + # Cleanup + for name in index_names: + idx = SearchIndex.from_existing(name, redis_url=redis_url) + idx.delete(drop=True) diff --git a/tests/integration/test_field_modifier_ordering_integration.py b/tests/integration/test_field_modifier_ordering_integration.py index b26463df0..b9d609674 100644 --- a/tests/integration/test_field_modifier_ordering_integration.py +++ b/tests/integration/test_field_modifier_ordering_integration.py @@ -399,6 +399,241 @@ def test_indexmissing_enables_ismissing_query(self, client, redis_url, worker_id index.delete(drop=True) +class TestIndexEmptyIntegration: + """Integration tests for INDEXEMPTY functionality.""" + + def test_text_field_index_empty_creates_successfully( + self, client, redis_url, worker_id + ): + """Test that INDEXEMPTY on text field allows index creation.""" + skip_if_search_version_below_for_indexmissing(client) + schema_dict = { + "index": { + "name": f"test_text_empty_{worker_id}", + "prefix": f"textempty_{worker_id}:", + "storage_type": "hash", + }, + "fields": [ + { + "name": "description", + "type": "text", + "attrs": {"index_empty": True}, + } + ], + } + + schema = IndexSchema.from_dict(schema_dict) + index = SearchIndex(schema=schema, redis_url=redis_url) + index.create(overwrite=True) + + # Verify index was created + info = client.execute_command("FT.INFO", f"test_text_empty_{worker_id}") + assert info is not None + + # Create documents with empty and non-empty values + client.hset(f"textempty_{worker_id}:1", "description", "has content") + 
client.hset(f"textempty_{worker_id}:2", "description", "") + client.hset(f"textempty_{worker_id}:3", "description", "more content") + + # Search should work, empty string doc should be indexed + result = client.execute_command( + "FT.SEARCH", + f"test_text_empty_{worker_id}", + "*", + ) + # All 3 docs should be found + assert result[0] == 3 + + # Cleanup + client.delete( + f"textempty_{worker_id}:1", + f"textempty_{worker_id}:2", + f"textempty_{worker_id}:3", + ) + index.delete(drop=True) + + def test_tag_field_index_empty_creates_successfully( + self, client, redis_url, worker_id + ): + """Test that INDEXEMPTY on tag field allows index creation.""" + skip_if_search_version_below_for_indexmissing(client) + schema_dict = { + "index": { + "name": f"test_tag_empty_{worker_id}", + "prefix": f"tagempty_{worker_id}:", + "storage_type": "hash", + }, + "fields": [ + { + "name": "category", + "type": "tag", + "attrs": {"index_empty": True}, + } + ], + } + + schema = IndexSchema.from_dict(schema_dict) + index = SearchIndex(schema=schema, redis_url=redis_url) + index.create(overwrite=True) + + # Verify index was created + info = client.execute_command("FT.INFO", f"test_tag_empty_{worker_id}") + assert info is not None + + # Create documents with empty and non-empty values + client.hset(f"tagempty_{worker_id}:1", "category", "electronics") + client.hset(f"tagempty_{worker_id}:2", "category", "") + client.hset(f"tagempty_{worker_id}:3", "category", "books") + + # Search should work + result = client.execute_command( + "FT.SEARCH", + f"test_tag_empty_{worker_id}", + "*", + ) + # All 3 docs should be found + assert result[0] == 3 + + # Cleanup + client.delete( + f"tagempty_{worker_id}:1", + f"tagempty_{worker_id}:2", + f"tagempty_{worker_id}:3", + ) + index.delete(drop=True) + + +class TestUnfModifierIntegration: + """Integration tests for UNF (un-normalized form) modifier.""" + + def test_text_field_unf_requires_sortable(self, client, redis_url, worker_id): + """Test that UNF on 
text field works only when sortable is also True.""" + skip_if_search_version_below_for_indexmissing(client) + schema_dict = { + "index": { + "name": f"test_text_unf_{worker_id}", + "prefix": f"textunf_{worker_id}:", + "storage_type": "hash", + }, + "fields": [ + { + "name": "title", + "type": "text", + "attrs": {"sortable": True, "unf": True}, + } + ], + } + + schema = IndexSchema.from_dict(schema_dict) + index = SearchIndex(schema=schema, redis_url=redis_url) + + # Should create successfully + index.create(overwrite=True) + + info = client.execute_command("FT.INFO", f"test_text_unf_{worker_id}") + assert info is not None + + index.delete(drop=True) + + def test_numeric_field_unf_with_sortable(self, client, redis_url, worker_id): + """Test that UNF on numeric field works when sortable is True.""" + skip_if_search_version_below_for_indexmissing(client) + schema_dict = { + "index": { + "name": f"test_num_unf_{worker_id}", + "prefix": f"numunf_{worker_id}:", + "storage_type": "hash", + }, + "fields": [ + { + "name": "price", + "type": "numeric", + "attrs": {"sortable": True, "unf": True}, + } + ], + } + + schema = IndexSchema.from_dict(schema_dict) + index = SearchIndex(schema=schema, redis_url=redis_url) + + # Should create successfully + index.create(overwrite=True) + + info = client.execute_command("FT.INFO", f"test_num_unf_{worker_id}") + assert info is not None + + index.delete(drop=True) + + +class TestNoIndexModifierIntegration: + """Integration tests for NOINDEX modifier.""" + + def test_noindex_with_sortable_allows_sorting_not_searching( + self, client, redis_url, worker_id + ): + """Test that NOINDEX field can be sorted but not searched.""" + schema_dict = { + "index": { + "name": f"test_noindex_{worker_id}", + "prefix": f"noindex_{worker_id}:", + "storage_type": "hash", + }, + "fields": [ + { + "name": "searchable", + "type": "text", + }, + { + "name": "sort_only", + "type": "numeric", + "attrs": {"sortable": True, "no_index": True}, + }, + ], + } + + 
schema = IndexSchema.from_dict(schema_dict) + index = SearchIndex(schema=schema, redis_url=redis_url) + index.create(overwrite=True) + + # Add test documents + client.hset( + f"noindex_{worker_id}:1", mapping={"searchable": "hello", "sort_only": 10} + ) + client.hset( + f"noindex_{worker_id}:2", mapping={"searchable": "world", "sort_only": 5} + ) + client.hset( + f"noindex_{worker_id}:3", mapping={"searchable": "test", "sort_only": 15} + ) + + # Sorting by no_index field should work + result = client.execute_command( + "FT.SEARCH", + f"test_noindex_{worker_id}", + "*", + "SORTBY", + "sort_only", + "ASC", + ) + assert result[0] == 3 + + # Filtering by NOINDEX field should return no results + filter_result = client.execute_command( + "FT.SEARCH", + f"test_noindex_{worker_id}", + "@sort_only:[5 10]", + ) + assert filter_result[0] == 0 + + # Cleanup + client.delete( + f"noindex_{worker_id}:1", + f"noindex_{worker_id}:2", + f"noindex_{worker_id}:3", + ) + index.delete(drop=True) + + class TestFieldTypeModifierSupport: """Test that field types only support their documented modifiers.""" diff --git a/tests/integration/test_migration_comprehensive.py b/tests/integration/test_migration_comprehensive.py new file mode 100644 index 000000000..1a9d9fca8 --- /dev/null +++ b/tests/integration/test_migration_comprehensive.py @@ -0,0 +1,1689 @@ +""" +Comprehensive integration tests for all 38 supported migration operations. + +This test suite validates migrations against real Redis with a tiered validation approach: +- L1: Execution (plan.supported == True) +- L2: Data Integrity (doc_count_match == True) +- L3: Key Existence (key_sample_exists == True) +- L4: Schema Match (schema_match == True) + +Test Categories: +1. Index-Level (2): rename index, change prefix +2. Field Add (4): text, tag, numeric, geo +3. Field Remove (5): text, tag, numeric, geo, vector +4. Field Rename (5): text, tag, numeric, geo, vector +5. Base Attrs (3): sortable, no_index, index_missing +6. 
Text Attrs (5): weight, no_stem, phonetic_matcher, index_empty, unf +7. Tag Attrs (3): separator, case_sensitive, index_empty +8. Numeric Attrs (1): unf +9. Vector Attrs (8): algorithm, distance_metric, initial_cap, m, ef_construction, + ef_runtime, epsilon, datatype +10. JSON Storage (2): add field, rename field + +Some tests use L2-only validation due to Redis FT.INFO limitations: +- prefix change (keys renamed), HNSW params, initial_cap, phonetic_matcher, numeric unf + +Run: pytest tests/integration/test_migration_comprehensive.py -v +Spec: local_docs/index_migrator/32_integration_test_spec.md +""" + +import uuid +from typing import Any, Dict, List + +import pytest +import yaml + +from redisvl.index import SearchIndex +from redisvl.migration import MigrationExecutor, MigrationPlanner +from redisvl.migration.utils import load_migration_plan, schemas_equal +from redisvl.redis.utils import array_to_buffer + +# ============================================================================== +# Fixtures +# ============================================================================== + + +@pytest.fixture +def unique_ids(worker_id): + """Generate unique identifiers for test isolation.""" + uid = str(uuid.uuid4())[:8] + return { + "name": f"mig_test_{worker_id}_{uid}", + "prefix": f"mig_test:{worker_id}:{uid}", + } + + +@pytest.fixture +def base_schema(unique_ids): + """Base schema with all field types for testing.""" + return { + "index": { + "name": unique_ids["name"], + "prefix": unique_ids["prefix"], + "storage_type": "hash", + }, + "fields": [ + {"name": "doc_id", "type": "tag"}, + {"name": "title", "type": "text"}, + {"name": "description", "type": "text"}, + {"name": "category", "type": "tag"}, + {"name": "price", "type": "numeric"}, + {"name": "location", "type": "geo"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "hnsw", + "dims": 4, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + + +@pytest.fixture 
+def sample_docs(): + """Sample documents with all field types.""" + return [ + { + "doc_id": "1", + "title": "Alpha Product", + "description": "First product description", + "category": "electronics", + "price": 99.99, + "location": "-122.4194,37.7749", # SF coordinates (lon,lat) + "embedding": array_to_buffer([0.1, 0.2, 0.3, 0.4], "float32"), + }, + { + "doc_id": "2", + "title": "Beta Service", + "description": "Second service description", + "category": "software", + "price": 149.99, + "location": "-73.9857,40.7484", # NYC coordinates (lon,lat) + "embedding": array_to_buffer([0.2, 0.3, 0.4, 0.5], "float32"), + }, + { + "doc_id": "3", + "title": "Gamma Item", + "description": "", # Empty for index_empty tests + "category": "", # Empty for index_empty tests + "price": 0, + "location": "-118.2437,34.0522", # LA coordinates (lon,lat) + "embedding": array_to_buffer([0.3, 0.4, 0.5, 0.6], "float32"), + }, + ] + + +def run_migration( + redis_url: str, + tmp_path, + index_name: str, + patch: Dict[str, Any], +) -> Dict[str, Any]: + """Helper to run a migration and return results.""" + patch_path = tmp_path / "patch.yaml" + patch_path.write_text(yaml.safe_dump(patch, sort_keys=False)) + + plan_path = tmp_path / "plan.yaml" + planner = MigrationPlanner() + plan = planner.create_plan( + index_name, + redis_url=redis_url, + schema_patch_path=str(patch_path), + ) + planner.write_plan(plan, str(plan_path)) + + executor = MigrationExecutor() + report = executor.apply( + load_migration_plan(str(plan_path)), + redis_url=redis_url, + ) + + return { + "plan": plan, + "report": report, + "supported": plan.diff_classification.supported, + "succeeded": report.result == "succeeded", + # Additional validation fields for granular checks + "doc_count_match": report.validation.doc_count_match, + "schema_match": report.validation.schema_match, + "key_sample_exists": report.validation.key_sample_exists, + "validation_errors": report.validation.errors, + } + + +def setup_index(redis_url: str, 
schema: Dict, docs: List[Dict]) -> SearchIndex: + """Create index and load documents.""" + index = SearchIndex.from_dict(schema, redis_url=redis_url) + index.create(overwrite=True) + index.load(docs, id_field="doc_id") + return index + + +def cleanup_index(index: SearchIndex): + """Clean up index after test.""" + try: + index.delete(drop=True) + except Exception: + pass + + +# ============================================================================== +# 1. Index-Level Changes +# ============================================================================== + + +class TestIndexLevelChanges: + """Tests for index-level migration operations.""" + + def test_rename_index(self, redis_url, tmp_path, base_schema, sample_docs): + """Test renaming an index.""" + index = setup_index(redis_url, base_schema, sample_docs) + old_name = base_schema["index"]["name"] + new_name = f"{old_name}_renamed" + + try: + result = run_migration( + redis_url, + tmp_path, + old_name, + {"version": 1, "changes": {"index": {"name": new_name}}}, + ) + + assert result["supported"], "Rename index should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + + # Verify new index exists + live_index = SearchIndex.from_existing(new_name, redis_url=redis_url) + assert live_index.schema.index.name == new_name + cleanup_index(live_index) + except Exception: + cleanup_index(index) + raise + + def test_change_prefix(self, redis_url, tmp_path, base_schema, sample_docs): + """Test changing the key prefix (requires key renames).""" + index = setup_index(redis_url, base_schema, sample_docs) + old_prefix = base_schema["index"]["prefix"] + new_prefix = f"{old_prefix}_newprefix" + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + {"version": 1, "changes": {"index": {"prefix": new_prefix}}}, + ) + + assert result["supported"], "Change prefix should be supported" + # Validation now handles prefix change by transforming key_sample to new 
prefix + assert result["succeeded"], f"Migration failed: {result['report']}" + + # Verify keys were renamed + live_index = SearchIndex.from_existing( + base_schema["index"]["name"], redis_url=redis_url + ) + assert live_index.schema.index.prefix == new_prefix + cleanup_index(live_index) + except Exception: + cleanup_index(index) + raise + + +# ============================================================================== +# 2. Field Operations - Add Fields +# ============================================================================== + + +class TestAddFields: + """Tests for adding fields of different types.""" + + def test_add_text_field(self, redis_url, tmp_path, unique_ids, sample_docs): + """Test adding a text field.""" + schema = { + "index": { + "name": unique_ids["name"], + "prefix": unique_ids["prefix"], + "storage_type": "hash", + }, + "fields": [{"name": "doc_id", "type": "tag"}], + } + index = setup_index(redis_url, schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + unique_ids["name"], + { + "version": 1, + "changes": { + "add_fields": [{"name": "title", "type": "text"}], + }, + }, + ) + + assert result["supported"], "Add text field should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_add_tag_field(self, redis_url, tmp_path, unique_ids, sample_docs): + """Test adding a tag field.""" + schema = { + "index": { + "name": unique_ids["name"], + "prefix": unique_ids["prefix"], + "storage_type": "hash", + }, + "fields": [{"name": "doc_id", "type": "tag"}], + } + index = setup_index(redis_url, schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + unique_ids["name"], + { + "version": 1, + "changes": { + "add_fields": [ + { + "name": "category", + "type": "tag", + "attrs": {"separator": ","}, + } + ], + }, + }, + ) + + assert result["supported"], "Add tag field should be 
supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_add_numeric_field(self, redis_url, tmp_path, unique_ids, sample_docs): + """Test adding a numeric field.""" + schema = { + "index": { + "name": unique_ids["name"], + "prefix": unique_ids["prefix"], + "storage_type": "hash", + }, + "fields": [{"name": "doc_id", "type": "tag"}], + } + index = setup_index(redis_url, schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + unique_ids["name"], + { + "version": 1, + "changes": { + "add_fields": [{"name": "price", "type": "numeric"}], + }, + }, + ) + + assert result["supported"], "Add numeric field should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_add_geo_field(self, redis_url, tmp_path, unique_ids, sample_docs): + """Test adding a geo field.""" + schema = { + "index": { + "name": unique_ids["name"], + "prefix": unique_ids["prefix"], + "storage_type": "hash", + }, + "fields": [{"name": "doc_id", "type": "tag"}], + } + index = setup_index(redis_url, schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + unique_ids["name"], + { + "version": 1, + "changes": { + "add_fields": [{"name": "location", "type": "geo"}], + }, + }, + ) + + assert result["supported"], "Add geo field should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + +# ============================================================================== +# 2. 
Field Operations - Remove Fields +# ============================================================================== + + +class TestRemoveFields: + """Tests for removing fields of different types.""" + + def test_remove_text_field(self, redis_url, tmp_path, base_schema, sample_docs): + """Test removing a text field.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + {"version": 1, "changes": {"remove_fields": ["description"]}}, + ) + + assert result["supported"], "Remove text field should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_remove_tag_field(self, redis_url, tmp_path, base_schema, sample_docs): + """Test removing a tag field.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + {"version": 1, "changes": {"remove_fields": ["category"]}}, + ) + + assert result["supported"], "Remove tag field should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_remove_numeric_field(self, redis_url, tmp_path, base_schema, sample_docs): + """Test removing a numeric field.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + {"version": 1, "changes": {"remove_fields": ["price"]}}, + ) + + assert result["supported"], "Remove numeric field should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_remove_geo_field(self, redis_url, tmp_path, base_schema, sample_docs): + """Test removing a geo field.""" + index = 
setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + {"version": 1, "changes": {"remove_fields": ["location"]}}, + ) + + assert result["supported"], "Remove geo field should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_remove_vector_field(self, redis_url, tmp_path, base_schema, sample_docs): + """Test removing a vector field (allowed but warned).""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + {"version": 1, "changes": {"remove_fields": ["embedding"]}}, + ) + + assert result["supported"], "Remove vector field should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + +# ============================================================================== +# 2. 
Field Operations - Rename Fields +# ============================================================================== + + +class TestRenameFields: + """Tests for renaming fields of different types.""" + + def test_rename_text_field(self, redis_url, tmp_path, base_schema, sample_docs): + """Test renaming a text field.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "rename_fields": [ + {"old_name": "title", "new_name": "headline"} + ], + }, + }, + ) + + assert result["supported"], "Rename text field should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_rename_tag_field(self, redis_url, tmp_path, base_schema, sample_docs): + """Test renaming a tag field.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "rename_fields": [{"old_name": "category", "new_name": "tags"}], + }, + }, + ) + + assert result["supported"], "Rename tag field should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_rename_numeric_field(self, redis_url, tmp_path, base_schema, sample_docs): + """Test renaming a numeric field.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "rename_fields": [{"old_name": "price", "new_name": "cost"}], + }, + }, + ) + + assert result["supported"], "Rename numeric field should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + 
cleanup_index(index) + raise + + def test_rename_geo_field(self, redis_url, tmp_path, base_schema, sample_docs): + """Test renaming a geo field.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "rename_fields": [ + {"old_name": "location", "new_name": "coordinates"} + ], + }, + }, + ) + + assert result["supported"], "Rename geo field should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_rename_vector_field(self, redis_url, tmp_path, base_schema, sample_docs): + """Test renaming a vector field.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "rename_fields": [ + {"old_name": "embedding", "new_name": "vector"} + ], + }, + }, + ) + + assert result["supported"], "Rename vector field should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + +# ============================================================================== +# 3. 
Base Attributes (All Non-Vector Types) +# ============================================================================== + + +class TestBaseAttributes: + """Tests for base attributes: sortable, no_index, index_missing.""" + + def test_add_sortable(self, redis_url, tmp_path, base_schema, sample_docs): + """Test adding sortable attribute to a field.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "title", "attrs": {"sortable": True}} + ], + }, + }, + ) + + assert result["supported"], "Add sortable should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_add_no_index(self, redis_url, tmp_path, unique_ids, sample_docs): + """Test adding no_index attribute (store only, no searching).""" + # Need a sortable field first + schema = { + "index": { + "name": unique_ids["name"], + "prefix": unique_ids["prefix"], + "storage_type": "hash", + }, + "fields": [ + {"name": "doc_id", "type": "tag"}, + {"name": "title", "type": "text", "attrs": {"sortable": True}}, + ], + } + index = setup_index(redis_url, schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + unique_ids["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "title", "attrs": {"no_index": True}} + ], + }, + }, + ) + + assert result["supported"], "Add no_index should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_add_index_missing(self, redis_url, tmp_path, base_schema, sample_docs): + """Test adding index_missing attribute.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + 
base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "title", "attrs": {"index_missing": True}} + ], + }, + }, + ) + + assert result["supported"], "Add index_missing should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + +# ============================================================================== +# 4. Text Field Attributes +# ============================================================================== + + +class TestTextAttributes: + """Tests for text field attributes: weight, no_stem, phonetic_matcher, etc.""" + + def test_change_weight(self, redis_url, tmp_path, base_schema, sample_docs): + """Test changing text field weight.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [{"name": "title", "attrs": {"weight": 2.0}}], + }, + }, + ) + + assert result["supported"], "Change weight should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_add_no_stem(self, redis_url, tmp_path, base_schema, sample_docs): + """Test adding no_stem attribute.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "title", "attrs": {"no_stem": True}} + ], + }, + }, + ) + + assert result["supported"], "Add no_stem should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_add_phonetic_matcher(self, redis_url, tmp_path, base_schema, sample_docs): + """Test adding 
phonetic_matcher attribute.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "title", "attrs": {"phonetic_matcher": "dm:en"}} + ], + }, + }, + ) + + assert result["supported"], "Add phonetic_matcher should be supported" + # phonetic_matcher is stripped from schema comparison (FT.INFO doesn't return it) + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_add_index_empty_text(self, redis_url, tmp_path, base_schema, sample_docs): + """Test adding index_empty to text field.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "title", "attrs": {"index_empty": True}} + ], + }, + }, + ) + + assert result["supported"], "Add index_empty should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_add_unf_text(self, redis_url, tmp_path, unique_ids, sample_docs): + """Test adding unf (un-normalized form) to text field.""" + # UNF requires sortable + schema = { + "index": { + "name": unique_ids["name"], + "prefix": unique_ids["prefix"], + "storage_type": "hash", + }, + "fields": [ + {"name": "doc_id", "type": "tag"}, + {"name": "title", "type": "text", "attrs": {"sortable": True}}, + ], + } + index = setup_index(redis_url, schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + unique_ids["name"], + { + "version": 1, + "changes": { + "update_fields": [{"name": "title", "attrs": {"unf": True}}], + }, + }, + ) + + assert result["supported"], "Add UNF should be supported" + assert result["succeeded"], 
f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + +# ============================================================================== +# 5. Tag Field Attributes +# ============================================================================== + + +class TestTagAttributes: + """Tests for tag field attributes: separator, case_sensitive, withsuffixtrie, etc.""" + + def test_change_separator(self, redis_url, tmp_path, base_schema, sample_docs): + """Test changing tag separator.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "category", "attrs": {"separator": "|"}} + ], + }, + }, + ) + + assert result["supported"], "Change separator should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_add_case_sensitive(self, redis_url, tmp_path, base_schema, sample_docs): + """Test adding case_sensitive attribute.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "category", "attrs": {"case_sensitive": True}} + ], + }, + }, + ) + + assert result["supported"], "Add case_sensitive should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_add_index_empty_tag(self, redis_url, tmp_path, base_schema, sample_docs): + """Test adding index_empty to tag field.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + 
"update_fields": [ + {"name": "category", "attrs": {"index_empty": True}} + ], + }, + }, + ) + + assert result["supported"], "Add index_empty should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + +# ============================================================================== +# 6. Numeric Field Attributes +# ============================================================================== + + +class TestNumericAttributes: + """Tests for numeric field attributes: unf.""" + + def test_add_unf_numeric(self, redis_url, tmp_path, unique_ids, sample_docs): + """Test adding unf (un-normalized form) to numeric field.""" + # UNF requires sortable + schema = { + "index": { + "name": unique_ids["name"], + "prefix": unique_ids["prefix"], + "storage_type": "hash", + }, + "fields": [ + {"name": "doc_id", "type": "tag"}, + {"name": "price", "type": "numeric", "attrs": {"sortable": True}}, + ], + } + index = setup_index(redis_url, schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + unique_ids["name"], + { + "version": 1, + "changes": { + "update_fields": [{"name": "price", "attrs": {"unf": True}}], + }, + }, + ) + + assert result["supported"], "Add UNF to numeric should be supported" + # Redis auto-applies UNF with SORTABLE on numeric fields, so both should match + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + +# ============================================================================== +# 7. 
Vector Field Attributes (Index-Only Changes) +# ============================================================================== + + +class TestVectorAttributes: + """Tests for vector field attributes: algorithm, distance_metric, HNSW params, etc.""" + + def test_change_algorithm_hnsw_to_flat( + self, redis_url, tmp_path, base_schema, sample_docs + ): + """Test changing vector algorithm from HNSW to FLAT.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "embedding", "attrs": {"algorithm": "flat"}} + ], + }, + }, + ) + + assert result["supported"], "Change algorithm should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_change_distance_metric( + self, redis_url, tmp_path, base_schema, sample_docs + ): + """Test changing distance metric.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "embedding", "attrs": {"distance_metric": "l2"}} + ], + }, + }, + ) + + assert result["supported"], "Change distance_metric should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_change_initial_cap(self, redis_url, tmp_path, base_schema, sample_docs): + """Test changing initial_cap.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "embedding", "attrs": {"initial_cap": 1000}} + ], + }, + }, + ) + + assert result["supported"], "Change 
initial_cap should be supported" + # Redis may not return initial_cap accurately in FT.INFO. + # Check doc_count_match to confirm the migration executed successfully. + assert result[ + "doc_count_match" + ], f"Migration failed: {result['validation_errors']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_change_hnsw_m(self, redis_url, tmp_path, base_schema, sample_docs): + """Test changing HNSW m parameter.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [{"name": "embedding", "attrs": {"m": 32}}], + }, + }, + ) + + assert result["supported"], "Change HNSW m should be supported" + # Redis may not return m accurately in FT.INFO. + # Check doc_count_match to confirm the migration executed successfully. + assert result[ + "doc_count_match" + ], f"Migration failed: {result['validation_errors']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_change_hnsw_ef_construction( + self, redis_url, tmp_path, base_schema, sample_docs + ): + """Test changing HNSW ef_construction parameter.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "embedding", "attrs": {"ef_construction": 400}} + ], + }, + }, + ) + + assert result["supported"], "Change ef_construction should be supported" + # Redis may not return ef_construction accurately in FT.INFO. + # Check doc_count_match to confirm the migration executed successfully. 
+ assert result[ + "doc_count_match" + ], f"Migration failed: {result['validation_errors']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_change_hnsw_ef_runtime( + self, redis_url, tmp_path, base_schema, sample_docs + ): + """Test changing HNSW ef_runtime parameter.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "embedding", "attrs": {"ef_runtime": 20}} + ], + }, + }, + ) + + assert result["supported"], "Change ef_runtime should be supported" + # Redis may not return ef_runtime accurately in FT.INFO (often returns defaults). + # Check doc_count_match to confirm the migration executed successfully. + assert result[ + "doc_count_match" + ], f"Migration failed: {result['validation_errors']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_change_hnsw_epsilon(self, redis_url, tmp_path, base_schema, sample_docs): + """Test changing HNSW epsilon parameter.""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "embedding", "attrs": {"epsilon": 0.05}} + ], + }, + }, + ) + + assert result["supported"], "Change epsilon should be supported" + # Redis may not return epsilon accurately in FT.INFO (often returns defaults). + # Check doc_count_match to confirm the migration executed successfully. 
+ assert result[ + "doc_count_match" + ], f"Migration failed: {result['validation_errors']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_change_datatype_quantization( + self, redis_url, tmp_path, base_schema, sample_docs + ): + """Test changing vector datatype (quantization).""" + index = setup_index(redis_url, base_schema, sample_docs) + + try: + result = run_migration( + redis_url, + tmp_path, + base_schema["index"]["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "embedding", "attrs": {"datatype": "float16"}} + ], + }, + }, + ) + + assert result["supported"], "Change datatype should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + +# ============================================================================== +# 8. JSON Storage Type Tests +# ============================================================================== + + +class TestJsonStorageType: + """Tests for migrations with JSON storage type.""" + + @pytest.fixture + def json_schema(self, unique_ids): + """Schema using JSON storage type.""" + return { + "index": { + "name": unique_ids["name"], + "prefix": unique_ids["prefix"], + "storage_type": "json", + }, + "fields": [ + {"name": "doc_id", "type": "tag", "path": "$.doc_id"}, + {"name": "title", "type": "text", "path": "$.title"}, + {"name": "category", "type": "tag", "path": "$.category"}, + {"name": "price", "type": "numeric", "path": "$.price"}, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "hnsw", + "dims": 4, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + + @pytest.fixture + def json_sample_docs(self): + """Sample JSON documents (as dicts for RedisJSON).""" + return [ + { + "doc_id": "1", + "title": "Alpha Product", + "category": "electronics", + "price": 99.99, + "embedding": [0.1, 
0.2, 0.3, 0.4], + }, + { + "doc_id": "2", + "title": "Beta Service", + "category": "software", + "price": 149.99, + "embedding": [0.2, 0.3, 0.4, 0.5], + }, + ] + + def test_json_add_field( + self, redis_url, tmp_path, unique_ids, json_schema, json_sample_docs, client + ): + """Test adding a field with JSON storage.""" + index = SearchIndex.from_dict(json_schema, redis_url=redis_url) + index.create(overwrite=True) + + # Load JSON docs directly + for i, doc in enumerate(json_sample_docs): + key = f"{unique_ids['prefix']}:{i+1}" + client.json().set(key, "$", json_sample_docs[i]) + + try: + result = run_migration( + redis_url, + tmp_path, + unique_ids["name"], + { + "version": 1, + "changes": { + "add_fields": [ + { + "name": "status", + "type": "tag", + "path": "$.status", + } + ], + }, + }, + ) + + assert result["supported"], "Add field with JSON should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_json_rename_field( + self, redis_url, tmp_path, unique_ids, json_schema, json_sample_docs, client + ): + """Test renaming a field with JSON storage.""" + index = SearchIndex.from_dict(json_schema, redis_url=redis_url) + index.create(overwrite=True) + + # Load JSON docs + for i, doc in enumerate(json_sample_docs): + key = f"{unique_ids['prefix']}:{i+1}" + client.json().set(key, "$", doc) + + try: + result = run_migration( + redis_url, + tmp_path, + unique_ids["name"], + { + "version": 1, + "changes": { + "rename_fields": [ + {"old_name": "title", "new_name": "headline"} + ], + }, + }, + ) + + assert result["supported"], "Rename field with JSON should be supported" + assert result["succeeded"], f"Migration failed: {result['report']}" + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + +# ============================================================================== +# 9. 
Hash Indexing Failures Validation Tests +# ============================================================================== + + +class TestHashIndexingFailuresValidation: + """Tests for validation when source index has hash_indexing_failures. + + These tests verify that the migrator correctly handles indexes where some + documents fail to index (e.g., due to wrong vector dimensions). The + validation logic should compare total keys (num_docs + failures) instead + of just num_docs, so that resolved failures don't trigger false negatives. + """ + + def test_migration_with_indexing_failures_passes_validation( + self, redis_url, tmp_path, unique_ids, client + ): + """Migration should pass validation when source has hash_indexing_failures. + + Scenario: Create index with dims=4, load 3 correct docs + 2 docs with + wrong-dimension vectors. The 2 bad docs cause hash_indexing_failures. + Run a simple migration (add a text field). After migration, validation + should pass because total keys (num_docs + failures) are conserved. 
+ """ + schema = { + "index": { + "name": unique_ids["name"], + "prefix": unique_ids["prefix"], + "storage_type": "hash", + }, + "fields": [ + {"name": "title", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "hnsw", + "dims": 4, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + + index = setup_index( + redis_url, + schema, + [ + { + "doc_id": "1", + "title": "Good doc one", + "embedding": array_to_buffer([0.1, 0.2, 0.3, 0.4], "float32"), + }, + { + "doc_id": "2", + "title": "Good doc two", + "embedding": array_to_buffer([0.2, 0.3, 0.4, 0.5], "float32"), + }, + { + "doc_id": "3", + "title": "Good doc three", + "embedding": array_to_buffer([0.3, 0.4, 0.5, 0.6], "float32"), + }, + ], + ) + + try: + # Manually add 2 keys with wrong-dimension vectors (8-dim instead of 4) + # These will cause hash_indexing_failures + bad_vec = array_to_buffer( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], "float32" + ) + client.hset( + f"{unique_ids['prefix']}:bad1", + mapping={"title": "Bad doc one", "embedding": bad_vec}, + ) + client.hset( + f"{unique_ids['prefix']}:bad2", + mapping={"title": "Bad doc two", "embedding": bad_vec}, + ) + + # Wait briefly for indexing to settle + import time + + time.sleep(0.5) + + # Verify we have indexing failures + info = index.info() + num_docs = int(info.get("num_docs", 0)) + failures = int(info.get("hash_indexing_failures", 0)) + assert num_docs == 3, f"Expected 3 indexed docs, got {num_docs}" + assert failures == 2, f"Expected 2 indexing failures, got {failures}" + + # Run migration: add a text field (simple, non-destructive) + result = run_migration( + redis_url, + tmp_path, + unique_ids["name"], + { + "version": 1, + "changes": { + "add_fields": [{"name": "category", "type": "tag"}], + }, + }, + ) + + assert result["supported"], "Add field should be supported" + assert result[ + "succeeded" + ], f"Migration failed: {result['validation_errors']}" + assert 
result["doc_count_match"], ( + f"Doc count should match (total keys conserved). " + f"Errors: {result['validation_errors']}" + ) + + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_quantization_resolves_failures_passes_validation( + self, redis_url, tmp_path, unique_ids, client + ): + """Quantization migration that resolves indexing failures should pass. + + Scenario: Create index with dims=4 float32, load 3 docs with float32 + vectors. Then add 2 docs with float16 vectors (same dims but wrong + byte size for float32). These cause hash_indexing_failures. Migrate to + float16 — now the previously failed docs become indexable and the + previously good docs get re-encoded. Total keys are conserved. + """ + schema = { + "index": { + "name": unique_ids["name"], + "prefix": unique_ids["prefix"], + "storage_type": "hash", + }, + "fields": [ + {"name": "title", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "hnsw", + "dims": 4, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + + index = setup_index( + redis_url, + schema, + [ + { + "doc_id": "1", + "title": "Float32 doc one", + "embedding": array_to_buffer([0.1, 0.2, 0.3, 0.4], "float32"), + }, + { + "doc_id": "2", + "title": "Float32 doc two", + "embedding": array_to_buffer([0.2, 0.3, 0.4, 0.5], "float32"), + }, + { + "doc_id": "3", + "title": "Float32 doc three", + "embedding": array_to_buffer([0.3, 0.4, 0.5, 0.6], "float32"), + }, + ], + ) + + try: + # Add 2 docs with float16 vectors (8 bytes for 4 dims vs 16 bytes) + # These will fail to index under float32 schema due to wrong byte size + f16_vec = array_to_buffer([0.4, 0.5, 0.6, 0.7], "float16") + client.hset( + f"{unique_ids['prefix']}:f16_1", + mapping={"title": "Float16 doc one", "embedding": f16_vec}, + ) + client.hset( + f"{unique_ids['prefix']}:f16_2", + mapping={"title": "Float16 doc two", "embedding": f16_vec}, + ) + + import time + + time.sleep(0.5) + 
+ # Verify initial state: 3 indexed + 2 failures + info = index.info() + num_docs = int(info.get("num_docs", 0)) + failures = int(info.get("hash_indexing_failures", 0)) + assert num_docs == 3, f"Expected 3 indexed docs, got {num_docs}" + assert failures == 2, f"Expected 2 indexing failures, got {failures}" + + # Run quantization migration: float32 -> float16 + # The executor re-encodes the 3 float32 docs to float16. + # After re-indexing, the 2 previously-failed float16 docs should now + # index successfully. Total keys: 5 before and 5 after. + result = run_migration( + redis_url, + tmp_path, + unique_ids["name"], + { + "version": 1, + "changes": { + "update_fields": [ + {"name": "embedding", "attrs": {"datatype": "float16"}} + ], + }, + }, + ) + + assert result["supported"], "Quantization should be supported" + assert result[ + "succeeded" + ], f"Migration failed: {result['validation_errors']}" + assert result["doc_count_match"], ( + f"Doc count should match (total keys conserved). " + f"Errors: {result['validation_errors']}" + ) + + cleanup_index(index) + except Exception: + cleanup_index(index) + raise + + def test_planner_warns_about_indexing_failures( + self, redis_url, tmp_path, unique_ids, client + ): + """Planner should emit a warning when source has hash_indexing_failures.""" + schema = { + "index": { + "name": unique_ids["name"], + "prefix": unique_ids["prefix"], + "storage_type": "hash", + }, + "fields": [ + {"name": "title", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": 4, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + + index = setup_index( + redis_url, + schema, + [ + { + "doc_id": "1", + "title": "Good doc", + "embedding": array_to_buffer([0.1, 0.2, 0.3, 0.4], "float32"), + }, + ], + ) + + try: + # Add a doc with wrong-dimension vector + bad_vec = array_to_buffer([0.1, 0.2], "float32") # 2-dim instead of 4 + client.hset( + f"{unique_ids['prefix']}:bad1", + 
mapping={"title": "Bad doc", "embedding": bad_vec}, + ) + + import time + + time.sleep(0.5) + + # Verify we have failures + info = index.info() + failures = int(info.get("hash_indexing_failures", 0)) + assert failures > 0, "Expected at least 1 indexing failure" + + # Create plan and check for warning + patch_path = tmp_path / "patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "add_fields": [{"name": "status", "type": "tag"}], + }, + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + unique_ids["name"], + redis_url=redis_url, + schema_patch_path=str(patch_path), + ) + + failure_warnings = [ + w for w in plan.warnings if "hash indexing failure" in w + ] + assert len(failure_warnings) == 1, ( + f"Expected 1 indexing failure warning, got {len(failure_warnings)}. " + f"All warnings: {plan.warnings}" + ) + + cleanup_index(index) + except Exception: + cleanup_index(index) + raise diff --git a/tests/integration/test_migration_routes.py b/tests/integration/test_migration_routes.py new file mode 100644 index 000000000..5d897d010 --- /dev/null +++ b/tests/integration/test_migration_routes.py @@ -0,0 +1,331 @@ +""" +Integration tests for migration routes. + +Tests the full Apply+Validate flow for all supported migration operations. +Requires Redis 8.0+ for INT8/UINT8 datatype tests. 
+""" + +import uuid + +import pytest +from redis import Redis + +from redisvl.index import SearchIndex +from redisvl.migration import MigrationExecutor, MigrationPlanner +from redisvl.migration.models import FieldUpdate, SchemaPatch +from tests.conftest import skip_if_redis_version_below + + +def create_source_index(redis_url, worker_id, source_attrs): + """Helper to create a source index with specified vector attributes.""" + unique_id = str(uuid.uuid4())[:8] + index_name = f"mig_route_{worker_id}_{unique_id}" + prefix = f"mig_route:{worker_id}:{unique_id}" + + base_attrs = { + "dims": 128, + "datatype": "float32", + "distance_metric": "cosine", + "algorithm": "flat", + } + base_attrs.update(source_attrs) + + index = SearchIndex.from_dict( + { + "index": {"name": index_name, "prefix": prefix, "storage_type": "json"}, + "fields": [ + {"name": "title", "type": "text", "path": "$.title"}, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": base_attrs, + }, + ], + }, + redis_url=redis_url, + ) + index.create(overwrite=True) + return index, index_name + + +def run_migration(redis_url, index_name, patch_attrs): + """Helper to run a migration with the given patch attributes.""" + patch = SchemaPatch( + version=1, + changes={ + "add_fields": [], + "remove_fields": [], + "update_fields": [FieldUpdate(name="embedding", attrs=patch_attrs)], + "rename_fields": [], + "index": {}, + }, + ) + + planner = MigrationPlanner() + plan = planner.create_plan_from_patch( + index_name, schema_patch=patch, redis_url=redis_url + ) + + executor = MigrationExecutor() + report = executor.apply(plan, redis_url=redis_url) + return report, plan + + +class TestAlgorithmChanges: + """Test algorithm migration routes.""" + + def test_hnsw_to_flat(self, redis_url, worker_id): + index, index_name = create_source_index( + redis_url, worker_id, {"algorithm": "hnsw"} + ) + try: + report, _ = run_migration(redis_url, index_name, {"algorithm": "flat"}) + assert report.result 
== "succeeded" + assert report.validation.schema_match is True + + live = SearchIndex.from_existing(index_name, redis_url=redis_url) + assert str(live.schema.fields["embedding"].attrs.algorithm).endswith("FLAT") + finally: + index.delete(drop=True) + + def test_flat_to_hnsw_with_params(self, redis_url, worker_id): + index, index_name = create_source_index( + redis_url, worker_id, {"algorithm": "flat"} + ) + try: + report, _ = run_migration( + redis_url, + index_name, + {"algorithm": "hnsw", "m": 32, "ef_construction": 200}, + ) + assert report.result == "succeeded" + assert report.validation.schema_match is True + + live = SearchIndex.from_existing(index_name, redis_url=redis_url) + attrs = live.schema.fields["embedding"].attrs + assert str(attrs.algorithm).endswith("HNSW") + assert attrs.m == 32 + assert attrs.ef_construction == 200 + finally: + index.delete(drop=True) + + +class TestDatatypeChanges: + """Test datatype migration routes.""" + + @pytest.mark.parametrize( + "source_dtype,target_dtype", + [ + ("float32", "float16"), + ("float32", "bfloat16"), + ("float16", "float32"), + ], + ) + def test_flat_datatype_change( + self, redis_url, worker_id, source_dtype, target_dtype + ): + index, index_name = create_source_index( + redis_url, worker_id, {"algorithm": "flat", "datatype": source_dtype} + ) + try: + report, _ = run_migration(redis_url, index_name, {"datatype": target_dtype}) + assert report.result == "succeeded" + assert report.validation.schema_match is True + finally: + index.delete(drop=True) + + @pytest.mark.parametrize("target_dtype", ["int8", "uint8"]) + def test_flat_quantized_datatype(self, redis_url, worker_id, target_dtype): + """Test INT8/UINT8 datatypes (requires Redis 8.0+).""" + client = Redis.from_url(redis_url) + skip_if_redis_version_below(client, "8.0.0", "INT8/UINT8 requires Redis 8.0+") + index, index_name = create_source_index( + redis_url, worker_id, {"algorithm": "flat"} + ) + try: + report, _ = run_migration(redis_url, index_name, 
{"datatype": target_dtype}) + assert report.result == "succeeded" + assert report.validation.schema_match is True + finally: + index.delete(drop=True) + + @pytest.mark.parametrize( + "source_dtype,target_dtype", + [ + ("float32", "float16"), + ("float32", "bfloat16"), + ], + ) + def test_hnsw_datatype_change( + self, redis_url, worker_id, source_dtype, target_dtype + ): + index, index_name = create_source_index( + redis_url, worker_id, {"algorithm": "hnsw", "datatype": source_dtype} + ) + try: + report, _ = run_migration(redis_url, index_name, {"datatype": target_dtype}) + assert report.result == "succeeded" + assert report.validation.schema_match is True + finally: + index.delete(drop=True) + + @pytest.mark.parametrize("target_dtype", ["int8", "uint8"]) + def test_hnsw_quantized_datatype(self, redis_url, worker_id, target_dtype): + """Test INT8/UINT8 datatypes with HNSW (requires Redis 8.0+).""" + client = Redis.from_url(redis_url) + skip_if_redis_version_below(client, "8.0.0", "INT8/UINT8 requires Redis 8.0+") + index, index_name = create_source_index( + redis_url, worker_id, {"algorithm": "hnsw"} + ) + try: + report, _ = run_migration(redis_url, index_name, {"datatype": target_dtype}) + assert report.result == "succeeded" + assert report.validation.schema_match is True + finally: + index.delete(drop=True) + + +class TestDistanceMetricChanges: + """Test distance metric migration routes.""" + + @pytest.mark.parametrize( + "source_metric,target_metric", + [ + ("cosine", "l2"), + ("cosine", "ip"), + ("l2", "cosine"), + ("ip", "l2"), + ], + ) + def test_distance_metric_change( + self, redis_url, worker_id, source_metric, target_metric + ): + index, index_name = create_source_index( + redis_url, + worker_id, + {"algorithm": "flat", "distance_metric": source_metric}, + ) + try: + report, _ = run_migration( + redis_url, index_name, {"distance_metric": target_metric} + ) + assert report.result == "succeeded" + assert report.validation.schema_match is True + finally: + 
index.delete(drop=True) + + +class TestHNSWTuningParameters: + """Test HNSW parameter tuning routes.""" + + def test_hnsw_m_parameter(self, redis_url, worker_id): + index, index_name = create_source_index( + redis_url, worker_id, {"algorithm": "hnsw"} + ) + try: + report, _ = run_migration(redis_url, index_name, {"m": 64}) + assert report.result == "succeeded" + assert report.validation.schema_match is True + + live = SearchIndex.from_existing(index_name, redis_url=redis_url) + assert live.schema.fields["embedding"].attrs.m == 64 + finally: + index.delete(drop=True) + + def test_hnsw_ef_construction_parameter(self, redis_url, worker_id): + index, index_name = create_source_index( + redis_url, worker_id, {"algorithm": "hnsw"} + ) + try: + report, _ = run_migration(redis_url, index_name, {"ef_construction": 500}) + assert report.result == "succeeded" + assert report.validation.schema_match is True + + live = SearchIndex.from_existing(index_name, redis_url=redis_url) + assert live.schema.fields["embedding"].attrs.ef_construction == 500 + finally: + index.delete(drop=True) + + def test_hnsw_ef_runtime_parameter(self, redis_url, worker_id): + index, index_name = create_source_index( + redis_url, worker_id, {"algorithm": "hnsw"} + ) + try: + report, _ = run_migration(redis_url, index_name, {"ef_runtime": 50}) + assert report.result == "succeeded" + assert report.validation.schema_match is True + finally: + index.delete(drop=True) + + def test_hnsw_epsilon_parameter(self, redis_url, worker_id): + index, index_name = create_source_index( + redis_url, worker_id, {"algorithm": "hnsw"} + ) + try: + report, _ = run_migration(redis_url, index_name, {"epsilon": 0.1}) + assert report.result == "succeeded" + assert report.validation.schema_match is True + finally: + index.delete(drop=True) + + def test_hnsw_all_params_combined(self, redis_url, worker_id): + index, index_name = create_source_index( + redis_url, worker_id, {"algorithm": "hnsw"} + ) + try: + report, _ = 
run_migration( + redis_url, + index_name, + {"m": 48, "ef_construction": 300, "ef_runtime": 75, "epsilon": 0.05}, + ) + assert report.result == "succeeded" + assert report.validation.schema_match is True + + live = SearchIndex.from_existing(index_name, redis_url=redis_url) + attrs = live.schema.fields["embedding"].attrs + assert attrs.m == 48 + assert attrs.ef_construction == 300 + finally: + index.delete(drop=True) + + +class TestCombinedChanges: + """Test combined migration routes (multiple changes at once).""" + + def test_flat_to_hnsw_with_datatype_and_metric(self, redis_url, worker_id): + index, index_name = create_source_index( + redis_url, worker_id, {"algorithm": "flat"} + ) + try: + report, _ = run_migration( + redis_url, + index_name, + {"algorithm": "hnsw", "datatype": "float16", "distance_metric": "l2"}, + ) + assert report.result == "succeeded" + assert report.validation.schema_match is True + + live = SearchIndex.from_existing(index_name, redis_url=redis_url) + attrs = live.schema.fields["embedding"].attrs + assert str(attrs.algorithm).endswith("HNSW") + finally: + index.delete(drop=True) + + def test_flat_to_hnsw_with_int8(self, redis_url, worker_id): + """Combined algorithm + quantized datatype (requires Redis 8.0+).""" + client = Redis.from_url(redis_url) + skip_if_redis_version_below(client, "8.0.0", "INT8 requires Redis 8.0+") + index, index_name = create_source_index( + redis_url, worker_id, {"algorithm": "flat"} + ) + try: + report, _ = run_migration( + redis_url, + index_name, + {"algorithm": "hnsw", "datatype": "int8"}, + ) + assert report.result == "succeeded" + assert report.validation.schema_match is True + finally: + index.delete(drop=True) diff --git a/tests/integration/test_migration_v1.py b/tests/integration/test_migration_v1.py new file mode 100644 index 000000000..88720cb94 --- /dev/null +++ b/tests/integration/test_migration_v1.py @@ -0,0 +1,129 @@ +import uuid + +import yaml + +from redisvl.index import SearchIndex +from 
redisvl.migration import MigrationExecutor, MigrationPlanner, MigrationValidator +from redisvl.migration.utils import load_migration_plan, schemas_equal +from redisvl.redis.utils import array_to_buffer + + +def test_drop_recreate_plan_apply_validate_flow(redis_url, worker_id, tmp_path): + unique_id = str(uuid.uuid4())[:8] + index_name = f"migration_v1_{worker_id}_{unique_id}" + prefix = f"migration_v1:{worker_id}:{unique_id}" + + source_index = SearchIndex.from_dict( + { + "index": { + "name": index_name, + "prefix": prefix, + "storage_type": "hash", + }, + "fields": [ + {"name": "doc_id", "type": "tag"}, + {"name": "title", "type": "text"}, + {"name": "price", "type": "numeric"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "hnsw", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + }, + redis_url=redis_url, + ) + + docs = [ + { + "doc_id": "1", + "title": "alpha", + "price": 1, + "category": "news", + "embedding": array_to_buffer([0.1, 0.2, 0.3], "float32"), + }, + { + "doc_id": "2", + "title": "beta", + "price": 2, + "category": "sports", + "embedding": array_to_buffer([0.2, 0.1, 0.4], "float32"), + }, + ] + + source_index.create(overwrite=True) + source_index.load(docs, id_field="doc_id") + + patch_path = tmp_path / "schema_patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "add_fields": [ + { + "name": "category", + "type": "tag", + "attrs": {"separator": ","}, + } + ], + "remove_fields": ["price"], + "update_fields": [{"name": "title", "attrs": {"sortable": True}}], + }, + }, + sort_keys=False, + ) + ) + + plan_path = tmp_path / "migration_plan.yaml" + planner = MigrationPlanner() + plan = planner.create_plan( + index_name, + redis_url=redis_url, + schema_patch_path=str(patch_path), + ) + assert plan.diff_classification.supported is True + planner.write_plan(plan, str(plan_path)) + + query_check_path = tmp_path / "query_checks.yaml" + 
query_check_path.write_text( + yaml.safe_dump({"fetch_ids": ["1", "2"]}, sort_keys=False) + ) + + executor = MigrationExecutor() + report = executor.apply( + load_migration_plan(str(plan_path)), + redis_url=redis_url, + query_check_file=str(query_check_path), + ) + + try: + assert report.result == "succeeded" + assert report.validation.schema_match is True + assert report.validation.doc_count_match is True + assert report.validation.key_sample_exists is True + assert report.validation.indexing_failures_delta == 0 + assert not report.validation.errors + assert report.benchmark_summary.documents_indexed_per_second is not None + + live_index = SearchIndex.from_existing(index_name, redis_url=redis_url) + assert schemas_equal(live_index.schema.to_dict(), plan.merged_target_schema) + + validator = MigrationValidator() + validation, _target_info, _duration = validator.validate( + load_migration_plan(str(plan_path)), + redis_url=redis_url, + query_check_file=str(query_check_path), + ) + assert validation.schema_match is True + assert validation.doc_count_match is True + assert validation.key_sample_exists is True + assert not validation.errors + finally: + live_index = SearchIndex.from_existing(index_name, redis_url=redis_url) + live_index.delete(drop=True) diff --git a/tests/unit/test_async_migration_executor.py b/tests/unit/test_async_migration_executor.py new file mode 100644 index 000000000..2360a4eff --- /dev/null +++ b/tests/unit/test_async_migration_executor.py @@ -0,0 +1,480 @@ +"""Unit tests for migration executors and disk space estimator. + +These tests mirror the sync MigrationExecutor patterns but use async/await. +Also includes pure-calculation tests for estimate_disk_space(). 
+""" + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from redisvl.migration import AsyncMigrationExecutor, MigrationExecutor +from redisvl.migration.models import ( + DiffClassification, + KeyspaceSnapshot, + MigrationPlan, + SourceSnapshot, + ValidationPolicy, + _format_bytes, +) +from redisvl.migration.utils import ( + build_scan_match_patterns, + estimate_disk_space, + normalize_keys, +) + + +def _make_basic_plan(): + """Create a basic migration plan for testing.""" + return MigrationPlan( + mode="drop_recreate", + source=SourceSnapshot( + index_name="test_index", + keyspace=KeyspaceSnapshot( + storage_type="hash", + prefixes=["test"], + key_separator=":", + key_sample=["test:1", "test:2"], + ), + schema_snapshot={ + "index": { + "name": "test_index", + "prefix": "test", + "storage_type": "hash", + }, + "fields": [ + {"name": "title", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + }, + stats_snapshot={"num_docs": 2}, + ), + requested_changes={}, + merged_target_schema={ + "index": { + "name": "test_index", + "prefix": "test", + "storage_type": "hash", + }, + "fields": [ + {"name": "title", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "hnsw", # Changed from flat + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + }, + diff_classification=DiffClassification( + supported=True, + blocked_reasons=[], + ), + validation=ValidationPolicy( + require_doc_count_match=True, + ), + warnings=["Index downtime is required"], + ) + + +def test_async_executor_instantiation(): + """Test AsyncMigrationExecutor can be instantiated.""" + executor = AsyncMigrationExecutor() + assert executor is not None + assert executor.validator is not None + + +def test_async_executor_with_validator(): + """Test AsyncMigrationExecutor 
with custom validator.""" + from redisvl.migration import AsyncMigrationValidator + + custom_validator = AsyncMigrationValidator() + executor = AsyncMigrationExecutor(validator=custom_validator) + assert executor.validator is custom_validator + + +@pytest.mark.asyncio +async def test_async_executor_handles_unsupported_plan(): + """Test executor returns error report for unsupported plan.""" + plan = _make_basic_plan() + plan.diff_classification.supported = False + plan.diff_classification.blocked_reasons = ["Test blocked reason"] + + executor = AsyncMigrationExecutor() + + # The executor doesn't raise an error - it returns a report with errors + report = await executor.apply(plan, redis_url="redis://localhost:6379") + assert report.result == "failed" + assert "Test blocked reason" in report.validation.errors + + +@pytest.mark.asyncio +async def test_async_executor_validates_redis_url(): + """Test executor requires redis_url or redis_client.""" + plan = _make_basic_plan() + executor = AsyncMigrationExecutor() + + # The executor should raise an error internally when trying to connect + # but let's verify it doesn't crash before it tries to apply + # For a proper test, we'd need to mock AsyncSearchIndex.from_existing + # For now, we just verify the executor is created + assert executor is not None + + +# ============================================================================= +# Disk Space Estimator Tests +# ============================================================================= + + +def _make_quantize_plan( + source_dtype="float32", + target_dtype="float16", + dims=3072, + doc_count=100_000, + storage_type="hash", +): + """Helper to create a migration plan with a vector datatype change.""" + return MigrationPlan( + mode="drop_recreate", + source=SourceSnapshot( + index_name="test_index", + keyspace=KeyspaceSnapshot( + storage_type=storage_type, + prefixes=["test"], + key_separator=":", + ), + schema_snapshot={ + "index": { + "name": "test_index", + 
"prefix": "test", + "storage_type": storage_type, + }, + "fields": [ + {"name": "title", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "hnsw", + "dims": dims, + "distance_metric": "cosine", + "datatype": source_dtype, + }, + }, + ], + }, + stats_snapshot={"num_docs": doc_count}, + ), + requested_changes={}, + merged_target_schema={ + "index": { + "name": "test_index", + "prefix": "test", + "storage_type": storage_type, + }, + "fields": [ + {"name": "title", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "hnsw", + "dims": dims, + "distance_metric": "cosine", + "datatype": target_dtype, + }, + }, + ], + }, + diff_classification=DiffClassification(supported=True, blocked_reasons=[]), + validation=ValidationPolicy(require_doc_count_match=True), + ) + + +def test_estimate_fp32_to_fp16(): + """FP32->FP16 with 3072 dims, 100K docs should produce expected byte counts.""" + plan = _make_quantize_plan("float32", "float16", dims=3072, doc_count=100_000) + est = estimate_disk_space(plan) + + assert est.has_quantization is True + assert len(est.vector_fields) == 1 + vf = est.vector_fields[0] + assert vf.source_bytes_per_doc == 3072 * 4 # 12288 + assert vf.target_bytes_per_doc == 3072 * 2 # 6144 + + assert est.total_source_vector_bytes == 100_000 * 12288 + assert est.total_target_vector_bytes == 100_000 * 6144 + assert est.memory_savings_after_bytes == 100_000 * (12288 - 6144) + + # RDB = source * 0.95 + assert est.rdb_snapshot_disk_bytes == int(100_000 * 12288 * 0.95) + # COW = full source + assert est.rdb_cow_memory_if_concurrent_bytes == 100_000 * 12288 + # AOF disabled by default + assert est.aof_enabled is False + assert est.aof_growth_bytes == 0 + assert est.total_new_disk_bytes == est.rdb_snapshot_disk_bytes + + +def test_estimate_with_aof_enabled(): + """AOF growth should include RESP overhead per HSET.""" + plan = _make_quantize_plan("float32", "float16", dims=3072, 
doc_count=100_000) + est = estimate_disk_space(plan, aof_enabled=True) + + assert est.aof_enabled is True + target_vec_size = 3072 * 2 + expected_aof = 100_000 * (target_vec_size + 114) # 114 = HSET overhead + assert est.aof_growth_bytes == expected_aof + assert est.total_new_disk_bytes == est.rdb_snapshot_disk_bytes + expected_aof + + +def test_estimate_json_storage_aof(): + """JSON storage quantization should not report in-place rewrite costs.""" + plan = _make_quantize_plan( + "float32", "float16", dims=128, doc_count=1000, storage_type="json" + ) + est = estimate_disk_space(plan, aof_enabled=True) + + assert est.has_quantization is False + assert est.aof_growth_bytes == 0 + assert est.total_new_disk_bytes == 0 + + +def test_estimate_no_quantization(): + """Same dtype source and target should produce empty estimate.""" + plan = _make_quantize_plan("float32", "float32", dims=128, doc_count=1000) + est = estimate_disk_space(plan) + + assert est.has_quantization is False + assert len(est.vector_fields) == 0 + assert est.total_new_disk_bytes == 0 + assert est.memory_savings_after_bytes == 0 + + +def test_estimate_fp32_to_int8(): + """FP32->INT8 should use 1 byte per element.""" + plan = _make_quantize_plan("float32", "int8", dims=768, doc_count=50_000) + est = estimate_disk_space(plan) + + assert est.vector_fields[0].source_bytes_per_doc == 768 * 4 + assert est.vector_fields[0].target_bytes_per_doc == 768 * 1 + assert est.memory_savings_after_bytes == 50_000 * 768 * 3 + + +def test_estimate_summary_with_quantization(): + """Summary string should contain key information.""" + plan = _make_quantize_plan("float32", "float16", dims=128, doc_count=1000) + est = estimate_disk_space(plan) + summary = est.summary() + + assert "Pre-migration disk space estimate" in summary + assert "test_index" in summary + assert "1,000 documents" in summary + assert "float32 -> float16" in summary + assert "RDB snapshot" in summary + assert "reduction" in summary or "memory savings" in 
summary + + +def test_estimate_summary_no_quantization(): + """Summary for non-quantization migration should say no disk needed.""" + plan = _make_quantize_plan("float32", "float32", dims=128, doc_count=1000) + est = estimate_disk_space(plan) + summary = est.summary() + + assert "No vector quantization" in summary + + +def test_format_bytes_gb(): + assert _format_bytes(1_073_741_824) == "1.00 GB" + assert _format_bytes(2_147_483_648) == "2.00 GB" + + +def test_format_bytes_mb(): + assert _format_bytes(1_048_576) == "1.0 MB" + assert _format_bytes(10_485_760) == "10.0 MB" + + +def test_format_bytes_kb(): + assert _format_bytes(1024) == "1.0 KB" + assert _format_bytes(2048) == "2.0 KB" + + +def test_format_bytes_bytes(): + assert _format_bytes(500) == "500 bytes" + assert _format_bytes(0) == "0 bytes" + + +def test_savings_pct(): + """Verify savings percentage calculation.""" + plan = _make_quantize_plan("float32", "float16", dims=128, doc_count=100) + est = estimate_disk_space(plan) + # FP32->FP16 = 50% savings + assert est._savings_pct() == 50 + + +# ============================================================================= +# TDD RED Phase: Idempotent Dtype Detection Tests +# ============================================================================= +# These test detect_vector_dtype() and is_already_quantized() which inspect +# raw vector bytes to determine whether a key needs conversion or can be skipped. 
+ + +def test_detect_dtype_float32_by_size(): + """A 3072-dim vector stored as FP32 should be 12288 bytes.""" + import numpy as np + + from redisvl.migration.reliability import detect_vector_dtype + + vec = np.random.randn(3072).astype(np.float32).tobytes() + detected = detect_vector_dtype(vec, expected_dims=3072) + assert detected == "float32" + + +def test_detect_dtype_float16_by_size(): + """A 3072-dim vector stored as FP16 should be 6144 bytes.""" + import numpy as np + + from redisvl.migration.reliability import detect_vector_dtype + + vec = np.random.randn(3072).astype(np.float16).tobytes() + detected = detect_vector_dtype(vec, expected_dims=3072) + assert detected == "float16" + + +def test_detect_dtype_int8_by_size(): + """A 768-dim vector stored as INT8 should be 768 bytes.""" + import numpy as np + + from redisvl.migration.reliability import detect_vector_dtype + + vec = np.zeros(768, dtype=np.int8).tobytes() + detected = detect_vector_dtype(vec, expected_dims=768) + assert detected == "int8" + + +def test_detect_dtype_bfloat16_by_size(): + """A 768-dim bfloat16 vector should be 1536 bytes (same as float16).""" + import numpy as np + + from redisvl.migration.reliability import detect_vector_dtype + + # bfloat16 and float16 are both 2 bytes per element + vec = np.random.randn(768).astype(np.float16).tobytes() + detected = detect_vector_dtype(vec, expected_dims=768) + # Cannot distinguish float16 from bfloat16 by size alone; returns "float16" + assert detected in ("float16", "bfloat16") + + +def test_detect_dtype_empty_returns_none(): + """Empty bytes should return None.""" + from redisvl.migration.reliability import detect_vector_dtype + + assert detect_vector_dtype(b"", expected_dims=128) is None + + +def test_detect_dtype_unknown_size(): + """Bytes that don't match any known dtype should return None.""" + from redisvl.migration.reliability import detect_vector_dtype + + # 7 bytes doesn't match any dtype for 3 dims + assert detect_vector_dtype(b"\x00" * 
7, expected_dims=3) is None + + +def test_is_already_quantized_skip(): + """If source is float32 and vector is already float16, should return True.""" + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.float16).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float32", target_dtype="float16" + ) + assert result is True + + +def test_is_already_quantized_needs_conversion(): + """If source is float32 and vector IS float32, should return False.""" + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.float32).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float32", target_dtype="float16" + ) + assert result is False + + +def test_is_already_quantized_bfloat16_target(): + """If target is bfloat16 and vector is 2-bytes-per-element, should return True. + + bfloat16 and float16 share the same byte width (2 bytes per element) + and are treated as the same dtype family for idempotent detection. + """ + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.float16).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float32", target_dtype="bfloat16" + ) + assert result is True + + +def test_is_already_quantized_uint8_target(): + """If target is uint8 and vector is 1-byte-per-element, should return True. + + uint8 and int8 share the same byte width (1 byte per element) + and are treated as the same dtype family for idempotent detection. 
+ """ + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.int8).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float32", target_dtype="uint8" + ) + assert result is True + + +def test_is_already_quantized_same_width_float16_to_bfloat16(): + """float16 -> bfloat16 should NOT be skipped (same byte width, different encoding).""" + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.float16).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float16", target_dtype="bfloat16" + ) + assert result is False + + +def test_is_already_quantized_same_width_int8_to_uint8(): + """int8 -> uint8 should NOT be skipped (same byte width, different encoding).""" + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.int8).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="int8", target_dtype="uint8" + ) + assert result is False diff --git a/tests/unit/test_async_migration_planner.py b/tests/unit/test_async_migration_planner.py new file mode 100644 index 000000000..93ce3d49d --- /dev/null +++ b/tests/unit/test_async_migration_planner.py @@ -0,0 +1,319 @@ +"""Unit tests for AsyncMigrationPlanner. + +These tests mirror the sync MigrationPlanner tests but use async/await patterns. 
+""" + +from fnmatch import fnmatch + +import pytest +import yaml + +from redisvl.migration import AsyncMigrationPlanner, MigrationPlanner +from redisvl.schema.schema import IndexSchema + + +class AsyncDummyClient: + """Async mock Redis client for testing.""" + + def __init__(self, keys): + self.keys = keys + + async def scan(self, cursor=0, match=None, count=None): + matched = [] + for key in self.keys: + decoded_key = key.decode() if isinstance(key, bytes) else str(key) + if match is None or fnmatch(decoded_key, match): + matched.append(key) + return 0, matched + + +class AsyncDummyIndex: + """Async mock SearchIndex for testing.""" + + def __init__(self, schema, stats, keys): + self.schema = schema + self._stats = stats + self._client = AsyncDummyClient(keys) + + @property + def client(self): + return self._client + + async def info(self): + return self._stats + + +def _make_source_schema(): + return IndexSchema.from_dict( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + { + "name": "title", + "type": "text", + "path": "$.title", + "attrs": {"sortable": False}, + }, + { + "name": "price", + "type": "numeric", + "path": "$.price", + "attrs": {"sortable": True}, + }, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + ) + + +@pytest.mark.asyncio +async def test_async_create_plan_from_schema_patch(monkeypatch, tmp_path): + """Test async planner creates valid plan from schema patch.""" + source_schema = _make_source_schema() + dummy_index = AsyncDummyIndex( + source_schema, + {"num_docs": 2, "indexing": False}, + [b"docs:1", b"docs:2", b"docs:3"], + ) + + async def mock_from_existing(*args, **kwargs): + return dummy_index + + monkeypatch.setattr( + "redisvl.migration.async_planner.AsyncSearchIndex.from_existing", + mock_from_existing, + ) + + 
patch_path = tmp_path / "schema_patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "add_fields": [ + { + "name": "category", + "type": "tag", + "path": "$.category", + "attrs": {"separator": ","}, + } + ], + "remove_fields": ["price"], + "update_fields": [ + { + "name": "title", + "options": {"sortable": True}, + } + ], + }, + }, + sort_keys=False, + ) + ) + + planner = AsyncMigrationPlanner(key_sample_limit=2) + plan = await planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + schema_patch_path=str(patch_path), + ) + + assert plan.diff_classification.supported is True + assert plan.source.index_name == "docs" + assert plan.source.keyspace.storage_type == "json" + assert plan.source.keyspace.prefixes == ["docs"] + assert plan.source.keyspace.key_separator == ":" + assert plan.source.keyspace.key_sample == ["docs:1", "docs:2"] + assert plan.warnings == ["Index downtime is required"] + + merged_fields = { + field["name"]: field for field in plan.merged_target_schema["fields"] + } + assert plan.merged_target_schema["index"]["prefix"] == "docs" + assert merged_fields["title"]["attrs"]["sortable"] is True + assert "price" not in merged_fields + assert merged_fields["category"]["type"] == "tag" + + # Test write_plan works (delegates to sync) + plan_path = tmp_path / "migration_plan.yaml" + planner.write_plan(plan, str(plan_path)) + written_plan = yaml.safe_load(plan_path.read_text()) + assert written_plan["mode"] == "drop_recreate" + assert written_plan["diff_classification"]["supported"] is True + + +@pytest.mark.asyncio +async def test_async_planner_datatype_change_allowed(monkeypatch, tmp_path): + """Changing vector datatype (quantization) is allowed - executor will re-encode.""" + source_schema = _make_source_schema() + dummy_index = AsyncDummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + + async def mock_from_existing(*args, **kwargs): + return dummy_index + + monkeypatch.setattr( + 
"redisvl.migration.async_planner.AsyncSearchIndex.from_existing", + mock_from_existing, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + {"name": "title", "type": "text", "path": "$.title"}, + {"name": "price", "type": "numeric", "path": "$.price"}, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float16", # Changed from float32 + }, + }, + ], + }, + sort_keys=False, + ) + ) + + planner = AsyncMigrationPlanner() + plan = await planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + assert plan.diff_classification.supported is True + assert len(plan.diff_classification.blocked_reasons) == 0 + + # Verify datatype changes are detected + datatype_changes = MigrationPlanner.get_vector_datatype_changes( + plan.source.schema_snapshot, plan.merged_target_schema + ) + assert "embedding" in datatype_changes + assert datatype_changes["embedding"]["source"] == "float32" + assert datatype_changes["embedding"]["target"] == "float16" + + +@pytest.mark.asyncio +async def test_async_planner_algorithm_change_allowed(monkeypatch, tmp_path): + """Changing vector algorithm is allowed (index-only change).""" + source_schema = _make_source_schema() + dummy_index = AsyncDummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + + async def mock_from_existing(*args, **kwargs): + return dummy_index + + monkeypatch.setattr( + "redisvl.migration.async_planner.AsyncSearchIndex.from_existing", + mock_from_existing, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", 
+ }, + "fields": [ + {"name": "title", "type": "text", "path": "$.title"}, + {"name": "price", "type": "numeric", "path": "$.price"}, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "hnsw", # Changed from flat + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + }, + sort_keys=False, + ) + ) + + planner = AsyncMigrationPlanner() + plan = await planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + assert plan.diff_classification.supported is True + assert len(plan.diff_classification.blocked_reasons) == 0 + + +@pytest.mark.asyncio +async def test_async_planner_prefix_change_is_supported(monkeypatch, tmp_path): + """Prefix change is supported: executor will rename keys.""" + source_schema = _make_source_schema() + dummy_index = AsyncDummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + + async def mock_from_existing(*args, **kwargs): + return dummy_index + + monkeypatch.setattr( + "redisvl.migration.async_planner.AsyncSearchIndex.from_existing", + mock_from_existing, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs_v2", # Changed prefix + "key_separator": ":", + "storage_type": "json", + }, + "fields": source_schema.to_dict()["fields"], + }, + sort_keys=False, + ) + ) + + planner = AsyncMigrationPlanner() + plan = await planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + # Prefix change is now supported + assert plan.diff_classification.supported is True + assert plan.rename_operations.change_prefix == "docs_v2" + # Should have a warning about key renaming + assert any("prefix" in w.lower() for w in plan.warnings) diff --git a/tests/unit/test_batch_migration.py b/tests/unit/test_batch_migration.py new file mode 100644 index 
000000000..38cb91763 --- /dev/null +++ b/tests/unit/test_batch_migration.py @@ -0,0 +1,1366 @@ +""" +Unit tests for BatchMigrationPlanner and BatchMigrationExecutor. + +Tests use mocked Redis clients to verify: +- Pattern matching and index selection +- Applicability checking +- Checkpoint persistence and resume +- Failure policies +- Progress callbacks +""" + +from fnmatch import fnmatch +from typing import Any, Dict, List +from unittest.mock import Mock + +import pytest +import yaml + +from redisvl.migration import ( + BatchMigrationExecutor, + BatchMigrationPlanner, + BatchPlan, + BatchState, + SchemaPatch, +) +from redisvl.migration.models import BatchIndexEntry, BatchIndexState +from redisvl.schema.schema import IndexSchema + +# ============================================================================= +# Test Fixtures and Mock Helpers +# ============================================================================= + + +class MockRedisClient: + """Mock Redis client for batch migration tests.""" + + def __init__(self, indexes: List[str] = None, keys: Dict[str, List[str]] = None): + self.indexes = indexes or [] + self.keys = keys or {} + self._data: Dict[str, Dict[str, bytes]] = {} + + def execute_command(self, *args, **kwargs): + if args[0] == "FT._LIST": + return [idx.encode() for idx in self.indexes] + raise NotImplementedError(f"Command not mocked: {args}") + + def scan(self, cursor=0, match=None, count=None): + matched = [] + all_keys = [] + for prefix_keys in self.keys.values(): + all_keys.extend(prefix_keys) + + for key in all_keys: + decoded_key = key.decode() if isinstance(key, bytes) else str(key) + if match is None or fnmatch(decoded_key, match): + matched.append(key if isinstance(key, bytes) else key.encode()) + return 0, matched + + def hget(self, key, field): + return self._data.get(key, {}).get(field) + + def hset(self, key, field, value): + if key not in self._data: + self._data[key] = {} + self._data[key][field] = value + + def 
pipeline(self): + return MockPipeline(self) + + +class MockPipeline: + """Mock Redis pipeline.""" + + def __init__(self, client: MockRedisClient): + self._client = client + self._commands: List[tuple] = [] + + def hset(self, key, field, value): + self._commands.append(("hset", key, field, value)) + return self + + def execute(self): + results = [] + for cmd in self._commands: + if cmd[0] == "hset": + self._client.hset(cmd[1], cmd[2], cmd[3]) + results.append(1) + self._commands = [] + return results + + +def make_dummy_index(name: str, schema_dict: Dict[str, Any], stats: Dict[str, Any]): + """Create a mock SearchIndex for testing.""" + mock_index = Mock() + mock_index.name = name + mock_index.schema = IndexSchema.from_dict(schema_dict) + mock_index._redis_client = MockRedisClient() + mock_index.client = mock_index._redis_client + mock_index.info = Mock(return_value=stats) + mock_index.delete = Mock() + mock_index.create = Mock() + mock_index.exists = Mock(return_value=True) + return mock_index + + +def make_test_schema(name: str, prefix: str = None, dims: int = 3) -> Dict[str, Any]: + """Create a test schema dictionary.""" + return { + "index": { + "name": name, + "prefix": prefix or name, + "key_separator": ":", + "storage_type": "hash", + }, + "fields": [ + {"name": "title", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": dims, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + + +def make_shared_patch( + update_fields: List[Dict] = None, + add_fields: List[Dict] = None, + remove_fields: List[str] = None, +) -> Dict[str, Any]: + """Create a test schema patch dictionary.""" + return { + "version": 1, + "changes": { + "update_fields": update_fields or [], + "add_fields": add_fields or [], + "remove_fields": remove_fields or [], + "index": {}, + }, + } + + +def make_batch_plan( + batch_id: str, + indexes: List[BatchIndexEntry], + failure_policy: str = "fail_fast", + 
requires_quantization: bool = False, +) -> BatchPlan: + """Create a BatchPlan with default values for testing.""" + return BatchPlan( + batch_id=batch_id, + shared_patch=SchemaPatch( + version=1, + changes={"update_fields": [], "add_fields": [], "remove_fields": []}, + ), + indexes=indexes, + requires_quantization=requires_quantization, + failure_policy=failure_policy, + created_at="2026-03-20T10:00:00Z", + ) + + +# ============================================================================= +# BatchMigrationPlanner Tests +# ============================================================================= + + +class TestBatchMigrationPlannerPatternMatching: + """Test pattern matching for index discovery.""" + + def test_pattern_matches_multiple_indexes(self, monkeypatch, tmp_path): + """Pattern should match multiple indexes.""" + mock_client = MockRedisClient( + indexes=["products_idx", "users_idx", "orders_idx", "logs_idx"] + ) + + def mock_list_indexes(**kwargs): + return ["products_idx", "users_idx", "orders_idx", "logs_idx"] + + monkeypatch.setattr( + "redisvl.migration.batch_planner.list_indexes", mock_list_indexes + ) + + # Mock from_existing for each index + def mock_from_existing(name, **kwargs): + return make_dummy_index( + name, make_test_schema(name), {"num_docs": 10, "indexing": False} + ) + + monkeypatch.setattr( + "redisvl.migration.batch_planner.SearchIndex.from_existing", + mock_from_existing, + ) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", mock_from_existing + ) + + patch_path = tmp_path / "patch.yaml" + patch_path.write_text( + yaml.safe_dump( + make_shared_patch( + update_fields=[ + {"name": "embedding", "attrs": {"algorithm": "hnsw"}} + ] + ) + ) + ) + + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + pattern="*_idx", + schema_patch_path=str(patch_path), + redis_client=mock_client, + ) + + assert len(batch_plan.indexes) == 4 + assert all(idx.name.endswith("_idx") for idx in 
batch_plan.indexes) + + def test_pattern_no_matches_raises_error(self, monkeypatch, tmp_path): + """Empty pattern results should raise ValueError.""" + mock_client = MockRedisClient(indexes=["products", "users"]) + + def mock_list_indexes(**kwargs): + return ["products", "users"] + + monkeypatch.setattr( + "redisvl.migration.batch_planner.list_indexes", mock_list_indexes + ) + + patch_path = tmp_path / "patch.yaml" + patch_path.write_text(yaml.safe_dump(make_shared_patch())) + + planner = BatchMigrationPlanner() + with pytest.raises(ValueError, match="No indexes found"): + planner.create_batch_plan( + pattern="*_idx", # Won't match anything + schema_patch_path=str(patch_path), + redis_client=mock_client, + ) + + def test_pattern_with_special_characters(self, monkeypatch, tmp_path): + """Pattern matching with special characters in index names.""" + mock_client = MockRedisClient( + indexes=["app:prod:idx", "app:dev:idx", "app:staging:idx"] + ) + + def mock_list_indexes(**kwargs): + return ["app:prod:idx", "app:dev:idx", "app:staging:idx"] + + monkeypatch.setattr( + "redisvl.migration.batch_planner.list_indexes", mock_list_indexes + ) + + def mock_from_existing(name, **kwargs): + return make_dummy_index( + name, make_test_schema(name), {"num_docs": 5, "indexing": False} + ) + + monkeypatch.setattr( + "redisvl.migration.batch_planner.SearchIndex.from_existing", + mock_from_existing, + ) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", mock_from_existing + ) + + patch_path = tmp_path / "patch.yaml" + patch_path.write_text(yaml.safe_dump(make_shared_patch())) + + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + pattern="app:*:idx", + schema_patch_path=str(patch_path), + redis_client=mock_client, + ) + + assert len(batch_plan.indexes) == 3 + + +class TestBatchMigrationPlannerIndexSelection: + """Test explicit index list selection.""" + + def test_explicit_index_list(self, monkeypatch, tmp_path): + """Explicit 
index list should be used directly."""
+        mock_client = MockRedisClient(indexes=["idx1", "idx2", "idx3", "idx4", "idx5"])
+
+        def mock_from_existing(name, **kwargs):
+            return make_dummy_index(
+                name, make_test_schema(name), {"num_docs": 10, "indexing": False}
+            )
+
+        monkeypatch.setattr(
+            "redisvl.migration.batch_planner.SearchIndex.from_existing",
+            mock_from_existing,
+        )
+        monkeypatch.setattr(
+            "redisvl.migration.planner.SearchIndex.from_existing", mock_from_existing
+        )
+
+        patch_path = tmp_path / "patch.yaml"
+        patch_path.write_text(yaml.safe_dump(make_shared_patch()))
+
+        planner = BatchMigrationPlanner()
+        batch_plan = planner.create_batch_plan(
+            indexes=["idx1", "idx3", "idx5"],
+            schema_patch_path=str(patch_path),
+            redis_client=mock_client,
+        )
+
+        assert len(batch_plan.indexes) == 3
+        assert [idx.name for idx in batch_plan.indexes] == ["idx1", "idx3", "idx5"]
+
+    def test_duplicate_index_names(self, monkeypatch, tmp_path):
+        """Duplicate index names in the list should be deduplicated."""
+        mock_client = MockRedisClient(indexes=["idx1", "idx2"])
+
+        def mock_from_existing(name, **kwargs):
+            return make_dummy_index(
+                name, make_test_schema(name), {"num_docs": 10, "indexing": False}
+            )
+
+        monkeypatch.setattr(
+            "redisvl.migration.batch_planner.SearchIndex.from_existing",
+            mock_from_existing,
+        )
+        monkeypatch.setattr(
+            "redisvl.migration.planner.SearchIndex.from_existing", mock_from_existing
+        )
+
+        patch_path = tmp_path / "patch.yaml"
+        patch_path.write_text(yaml.safe_dump(make_shared_patch()))
+
+        planner = BatchMigrationPlanner()
+        # Duplicates are deduplicated to avoid migrating the same index twice
+        batch_plan = planner.create_batch_plan(
+            indexes=["idx1", "idx1", "idx2"],
+            schema_patch_path=str(patch_path),
+            redis_client=mock_client,
+        )
+
+        assert len(batch_plan.indexes) == 2
+        assert [e.name for e in batch_plan.indexes] == ["idx1", "idx2"]
+
+    def test_non_existent_index(self, monkeypatch, tmp_path):
+        """Non-existent index should 
be marked as not applicable.""" + mock_client = MockRedisClient(indexes=["idx1"]) + + def mock_from_existing(name, **kwargs): + if name == "idx1": + return make_dummy_index( + name, make_test_schema(name), {"num_docs": 10, "indexing": False} + ) + raise Exception(f"Index '{name}' not found") + + monkeypatch.setattr( + "redisvl.migration.batch_planner.SearchIndex.from_existing", + mock_from_existing, + ) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", mock_from_existing + ) + + patch_path = tmp_path / "patch.yaml" + patch_path.write_text(yaml.safe_dump(make_shared_patch())) + + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + indexes=["idx1", "nonexistent"], + schema_patch_path=str(patch_path), + redis_client=mock_client, + ) + + assert len(batch_plan.indexes) == 2 + assert batch_plan.indexes[0].applicable is True + assert batch_plan.indexes[1].applicable is False + assert "not found" in batch_plan.indexes[1].skip_reason.lower() + + def test_indexes_from_file(self, monkeypatch, tmp_path): + """Load index names from file.""" + mock_client = MockRedisClient(indexes=["idx1", "idx2", "idx3"]) + + def mock_from_existing(name, **kwargs): + return make_dummy_index( + name, make_test_schema(name), {"num_docs": 10, "indexing": False} + ) + + monkeypatch.setattr( + "redisvl.migration.batch_planner.SearchIndex.from_existing", + mock_from_existing, + ) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", mock_from_existing + ) + + # Create indexes file + indexes_file = tmp_path / "indexes.txt" + indexes_file.write_text("idx1\n# comment\nidx2\n\nidx3\n") + + patch_path = tmp_path / "patch.yaml" + patch_path.write_text(yaml.safe_dump(make_shared_patch())) + + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + indexes_file=str(indexes_file), + schema_patch_path=str(patch_path), + redis_client=mock_client, + ) + + assert len(batch_plan.indexes) == 3 + assert [idx.name 
for idx in batch_plan.indexes] == ["idx1", "idx2", "idx3"] + + +class TestBatchMigrationPlannerApplicability: + """Test applicability checking for shared patches.""" + + def test_missing_field_marks_not_applicable(self, monkeypatch, tmp_path): + """Index missing field in update_fields should be marked not applicable.""" + mock_client = MockRedisClient(indexes=["idx1", "idx2"]) + + def mock_from_existing(name, **kwargs): + if name == "idx1": + # Has embedding field + return make_dummy_index( + name, make_test_schema(name), {"num_docs": 10, "indexing": False} + ) + # idx2 - no embedding field + schema = { + "index": {"name": name, "prefix": name, "storage_type": "hash"}, + "fields": [{"name": "title", "type": "text"}], + } + return make_dummy_index(name, schema, {"num_docs": 5, "indexing": False}) + + monkeypatch.setattr( + "redisvl.migration.batch_planner.SearchIndex.from_existing", + mock_from_existing, + ) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", mock_from_existing + ) + + patch_path = tmp_path / "patch.yaml" + patch_path.write_text( + yaml.safe_dump( + make_shared_patch( + update_fields=[ + {"name": "embedding", "attrs": {"algorithm": "hnsw"}} + ] + ) + ) + ) + + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + indexes=["idx1", "idx2"], + schema_patch_path=str(patch_path), + redis_client=mock_client, + ) + + idx1_entry = next(e for e in batch_plan.indexes if e.name == "idx1") + idx2_entry = next(e for e in batch_plan.indexes if e.name == "idx2") + + assert idx1_entry.applicable is True + assert idx2_entry.applicable is False + assert "embedding" in idx2_entry.skip_reason.lower() + + def test_field_already_exists_marks_not_applicable(self, monkeypatch, tmp_path): + """Adding field that already exists should mark not applicable.""" + mock_client = MockRedisClient(indexes=["idx1", "idx2"]) + + def mock_from_existing(name, **kwargs): + schema = make_test_schema(name) + # Add 'category' field to idx2 
+ if name == "idx2": + schema["fields"].append({"name": "category", "type": "tag"}) + return make_dummy_index(name, schema, {"num_docs": 10, "indexing": False}) + + monkeypatch.setattr( + "redisvl.migration.batch_planner.SearchIndex.from_existing", + mock_from_existing, + ) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", mock_from_existing + ) + + patch_path = tmp_path / "patch.yaml" + patch_path.write_text( + yaml.safe_dump( + make_shared_patch(add_fields=[{"name": "category", "type": "tag"}]) + ) + ) + + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + indexes=["idx1", "idx2"], + schema_patch_path=str(patch_path), + redis_client=mock_client, + ) + + idx1_entry = next(e for e in batch_plan.indexes if e.name == "idx1") + idx2_entry = next(e for e in batch_plan.indexes if e.name == "idx2") + + assert idx1_entry.applicable is True + assert idx2_entry.applicable is False + assert "category" in idx2_entry.skip_reason.lower() + + def test_blocked_change_marks_not_applicable(self, monkeypatch, tmp_path): + """Blocked changes (e.g., dims change) should mark not applicable.""" + mock_client = MockRedisClient(indexes=["idx1", "idx2"]) + + def mock_from_existing(name, **kwargs): + dims = 3 if name == "idx1" else 768 + return make_dummy_index( + name, + make_test_schema(name, dims=dims), + {"num_docs": 10, "indexing": False}, + ) + + monkeypatch.setattr( + "redisvl.migration.batch_planner.SearchIndex.from_existing", + mock_from_existing, + ) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", mock_from_existing + ) + + patch_path = tmp_path / "patch.yaml" + patch_path.write_text( + yaml.safe_dump( + make_shared_patch( + update_fields=[ + {"name": "embedding", "attrs": {"dims": 1536}} # Change dims + ] + ) + ) + ) + + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + indexes=["idx1", "idx2"], + schema_patch_path=str(patch_path), + redis_client=mock_client, + ) + 
+ # Both should be not applicable because dims change is blocked + for entry in batch_plan.indexes: + assert entry.applicable is False + assert "dims" in entry.skip_reason.lower() + + +class TestBatchMigrationPlannerQuantization: + """Test quantization detection in batch plans.""" + + def test_detects_quantization_required(self, monkeypatch, tmp_path): + """Batch plan should detect when quantization is required.""" + mock_client = MockRedisClient(indexes=["idx1"]) + + def mock_from_existing(name, **kwargs): + return make_dummy_index( + name, make_test_schema(name), {"num_docs": 10, "indexing": False} + ) + + monkeypatch.setattr( + "redisvl.migration.batch_planner.SearchIndex.from_existing", + mock_from_existing, + ) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", mock_from_existing + ) + + patch_path = tmp_path / "patch.yaml" + patch_path.write_text( + yaml.safe_dump( + make_shared_patch( + update_fields=[ + {"name": "embedding", "attrs": {"datatype": "float16"}} + ] + ) + ) + ) + + planner = BatchMigrationPlanner() + batch_plan = planner.create_batch_plan( + indexes=["idx1"], + schema_patch_path=str(patch_path), + redis_client=mock_client, + ) + + assert batch_plan.requires_quantization is True + + +class TestBatchMigrationPlannerEdgeCases: + """Test edge cases and error handling.""" + + def test_multiple_source_specification_error(self, tmp_path): + """Should error when multiple source types are specified.""" + mock_client = MockRedisClient(indexes=["idx1"]) + + patch_path = tmp_path / "patch.yaml" + patch_path.write_text(yaml.safe_dump(make_shared_patch())) + + planner = BatchMigrationPlanner() + with pytest.raises(ValueError, match="only one of"): + planner.create_batch_plan( + indexes=["idx1"], + pattern="*", # Can't specify both + schema_patch_path=str(patch_path), + redis_client=mock_client, + ) + + def test_no_source_specification_error(self, tmp_path): + """Should error when no source is specified.""" + mock_client = 
MockRedisClient(indexes=["idx1"]) + + patch_path = tmp_path / "patch.yaml" + patch_path.write_text(yaml.safe_dump(make_shared_patch())) + + planner = BatchMigrationPlanner() + with pytest.raises(ValueError, match="Must provide one of"): + planner.create_batch_plan( + schema_patch_path=str(patch_path), + redis_client=mock_client, + ) + + def test_missing_patch_file_error(self): + """Should error when patch file doesn't exist.""" + mock_client = MockRedisClient(indexes=["idx1"]) + + planner = BatchMigrationPlanner() + with pytest.raises(FileNotFoundError): + planner.create_batch_plan( + indexes=["idx1"], + schema_patch_path="/nonexistent/patch.yaml", + redis_client=mock_client, + ) + + def test_missing_indexes_file_error(self, tmp_path): + """Should error when indexes file doesn't exist.""" + mock_client = MockRedisClient(indexes=["idx1"]) + + patch_path = tmp_path / "patch.yaml" + patch_path.write_text(yaml.safe_dump(make_shared_patch())) + + planner = BatchMigrationPlanner() + with pytest.raises(FileNotFoundError): + planner.create_batch_plan( + indexes_file="/nonexistent/indexes.txt", + schema_patch_path=str(patch_path), + redis_client=mock_client, + ) + + +# ============================================================================= +# BatchMigrationExecutor Tests +# ============================================================================= + + +class MockMigrationPlan: + """Mock migration plan for testing.""" + + def __init__(self, index_name: str): + self.source = Mock() + self.source.schema_snapshot = make_test_schema(index_name) + self.merged_target_schema = make_test_schema(index_name) + + +class MockMigrationReport: + """Mock migration report for testing.""" + + def __init__(self, result: str = "succeeded", errors: List[str] = None): + self.result = result + self.validation = Mock(errors=errors or []) + + def model_dump(self, **kwargs): + return {"result": self.result} + + +def create_mock_executor( + succeed_on: List[str] = None, + fail_on: List[str] 
= None, + track_calls: List[str] = None, +): + """Create a properly configured BatchMigrationExecutor with mocks. + + Args: + succeed_on: Index names that should succeed. + fail_on: Index names that should fail. + track_calls: List to append index names as they're migrated. + + Returns: + A BatchMigrationExecutor with mocked planner and executor. + """ + succeed_on = succeed_on or [] + fail_on = fail_on or [] + if track_calls is None: + track_calls = [] + + # Create mock planner + mock_planner = Mock() + + def create_plan_from_patch(index_name, **kwargs): + track_calls.append(index_name) + return MockMigrationPlan(index_name) + + mock_planner.create_plan_from_patch = create_plan_from_patch + + # Create mock executor + mock_single_executor = Mock() + + def apply(plan, **kwargs): + # Determine if this should succeed or fail based on tracked calls + if track_calls: + last_index = track_calls[-1] + if last_index in fail_on: + return MockMigrationReport( + result="failed", errors=["Simulated failure"] + ) + return MockMigrationReport(result="succeeded") + + mock_single_executor.apply = apply + + # Create the batch executor with injected mocks + batch_executor = BatchMigrationExecutor(executor=mock_single_executor) + batch_executor._planner = mock_planner + + return batch_executor, track_calls + + +class TestBatchMigrationExecutorCheckpointing: + """Test checkpoint persistence and state management.""" + + def test_checkpoint_created_at_start(self, tmp_path): + """Checkpoint state file should be created when migration starts.""" + batch_plan = make_batch_plan( + batch_id="test-batch-001", + indexes=[ + BatchIndexEntry(name="idx1", applicable=True), + BatchIndexEntry(name="idx2", applicable=True), + ], + failure_policy="fail_fast", + ) + + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / "reports" + + executor, _ = create_mock_executor(succeed_on=["idx1", "idx2"]) + mock_client = MockRedisClient(indexes=["idx1", "idx2"]) + + executor.apply( + 
batch_plan, + state_path=str(state_path), + report_dir=str(report_dir), + redis_client=mock_client, + ) + + # Verify checkpoint file was created + assert state_path.exists() + state_data = yaml.safe_load(state_path.read_text()) + assert state_data["batch_id"] == "test-batch-001" + + def test_checkpoint_updated_after_each_index(self, monkeypatch, tmp_path): + """Checkpoint should be updated after each index is processed.""" + batch_plan = make_batch_plan( + batch_id="test-batch-002", + indexes=[ + BatchIndexEntry(name="idx1", applicable=True), + BatchIndexEntry(name="idx2", applicable=True), + BatchIndexEntry(name="idx3", applicable=True), + ], + failure_policy="continue_on_error", + ) + + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / "reports" + checkpoint_snapshots = [] + + # Capture checkpoints as they're written + original_write = BatchMigrationExecutor._write_state + + def capture_checkpoint(self, state, path): + checkpoint_snapshots.append( + {"remaining": list(state.remaining), "completed": len(state.completed)} + ) + return original_write(self, state, path) + + monkeypatch.setattr(BatchMigrationExecutor, "_write_state", capture_checkpoint) + + executor, _ = create_mock_executor(succeed_on=["idx1", "idx2", "idx3"]) + mock_client = MockRedisClient(indexes=["idx1", "idx2", "idx3"]) + + executor.apply( + batch_plan, + state_path=str(state_path), + report_dir=str(report_dir), + redis_client=mock_client, + ) + + # Verify checkpoints were written progressively + # Each index should trigger 2 writes: start and end + assert len(checkpoint_snapshots) >= 6 # At least 2 per index + + def test_resume_from_checkpoint(self, tmp_path): + """Resume should continue from where migration left off.""" + # Create a checkpoint state simulating interrupted migration + batch_plan = make_batch_plan( + batch_id="test-batch-003", + indexes=[ + BatchIndexEntry(name="idx1", applicable=True), + BatchIndexEntry(name="idx2", applicable=True), + 
BatchIndexEntry(name="idx3", applicable=True), + ], + failure_policy="continue_on_error", + ) + + # Write the batch plan + plan_path = tmp_path / "batch_plan.yaml" + with open(plan_path, "w") as f: + yaml.safe_dump(batch_plan.model_dump(exclude_none=True), f, sort_keys=False) + + # Write a checkpoint state (idx1 completed, idx2 and idx3 remaining) + state_path = tmp_path / "batch_state.yaml" + checkpoint_state = BatchState( + batch_id="test-batch-003", + plan_path=str(plan_path), + started_at="2026-03-20T10:00:00Z", + updated_at="2026-03-20T10:05:00Z", + remaining=["idx2", "idx3"], + completed=[ + BatchIndexState( + name="idx1", + status="success", + completed_at="2026-03-20T10:05:00Z", + ) + ], + current_index=None, + ) + with open(state_path, "w") as f: + yaml.safe_dump( + checkpoint_state.model_dump(exclude_none=True), f, sort_keys=False + ) + + report_dir = tmp_path / "reports" + migrated_indexes: List[str] = [] + + executor, migrated_indexes = create_mock_executor( + succeed_on=["idx2", "idx3"], + ) + mock_client = MockRedisClient(indexes=["idx1", "idx2", "idx3"]) + + # Resume from checkpoint + report = executor.resume( + state_path=str(state_path), + report_dir=str(report_dir), + redis_client=mock_client, + ) + + # idx1 should NOT be migrated again (already completed) + assert "idx1" not in migrated_indexes + # Only idx2 and idx3 should be migrated + assert migrated_indexes == ["idx2", "idx3"] + # Report should show all 3 as succeeded + assert report.summary.successful == 3 + + +class TestBatchMigrationExecutorFailurePolicies: + """Test failure policy behavior (fail_fast vs continue_on_error).""" + + def test_fail_fast_stops_on_first_error(self, tmp_path): + """fail_fast policy should stop processing after first failure.""" + batch_plan = make_batch_plan( + batch_id="test-batch-fail-fast", + indexes=[ + BatchIndexEntry(name="idx1", applicable=True), + BatchIndexEntry(name="idx2", applicable=True), # This will fail + BatchIndexEntry(name="idx3", 
applicable=True), + ], + failure_policy="fail_fast", + ) + + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / "reports" + + executor, migrated_indexes = create_mock_executor( + succeed_on=["idx1", "idx3"], + fail_on=["idx2"], + ) + mock_client = MockRedisClient(indexes=["idx1", "idx2", "idx3"]) + + report = executor.apply( + batch_plan, + state_path=str(state_path), + report_dir=str(report_dir), + redis_client=mock_client, + ) + + # idx3 should NOT have been attempted due to fail_fast + assert "idx3" not in migrated_indexes + assert migrated_indexes == ["idx1", "idx2"] + + # Report should show partial results + assert report.summary.successful == 1 + assert report.summary.failed == 1 + assert report.summary.skipped == 1 # idx3 was skipped + + def test_continue_on_error_processes_all(self, tmp_path): + """continue_on_error policy should process all indexes.""" + batch_plan = make_batch_plan( + batch_id="test-batch-continue", + indexes=[ + BatchIndexEntry(name="idx1", applicable=True), + BatchIndexEntry(name="idx2", applicable=True), # This will fail + BatchIndexEntry(name="idx3", applicable=True), + ], + failure_policy="continue_on_error", + ) + + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / "reports" + + executor, migrated_indexes = create_mock_executor( + succeed_on=["idx1", "idx3"], + fail_on=["idx2"], + ) + mock_client = MockRedisClient(indexes=["idx1", "idx2", "idx3"]) + + report = executor.apply( + batch_plan, + state_path=str(state_path), + report_dir=str(report_dir), + redis_client=mock_client, + ) + + # ALL indexes should have been attempted + assert migrated_indexes == ["idx1", "idx2", "idx3"] + + # Report should show mixed results + assert report.summary.successful == 2 # idx1 and idx3 + assert report.summary.failed == 1 # idx2 + assert report.summary.skipped == 0 + assert report.status == "partial_failure" + + def test_retry_failed_on_resume(self, tmp_path): + """retry_failed=True should retry previously failed 
indexes.""" + batch_plan = make_batch_plan( + batch_id="test-batch-retry", + indexes=[ + BatchIndexEntry(name="idx1", applicable=True), + BatchIndexEntry(name="idx2", applicable=True), + ], + failure_policy="continue_on_error", + ) + + plan_path = tmp_path / "batch_plan.yaml" + with open(plan_path, "w") as f: + yaml.safe_dump(batch_plan.model_dump(exclude_none=True), f, sort_keys=False) + + # Create checkpoint with idx1 failed + state_path = tmp_path / "batch_state.yaml" + checkpoint_state = BatchState( + batch_id="test-batch-retry", + plan_path=str(plan_path), + started_at="2026-03-20T10:00:00Z", + updated_at="2026-03-20T10:05:00Z", + remaining=[], # All "done" but idx1 failed + completed=[ + BatchIndexState( + name="idx1", status="failed", completed_at="2026-03-20T10:03:00Z" + ), + BatchIndexState( + name="idx2", status="success", completed_at="2026-03-20T10:05:00Z" + ), + ], + current_index=None, + ) + with open(state_path, "w") as f: + yaml.safe_dump( + checkpoint_state.model_dump(exclude_none=True), f, sort_keys=False + ) + + report_dir = tmp_path / "reports" + + executor, migrated_indexes = create_mock_executor(succeed_on=["idx1", "idx2"]) + mock_client = MockRedisClient(indexes=["idx1", "idx2"]) + + report = executor.resume( + state_path=str(state_path), + retry_failed=True, + report_dir=str(report_dir), + redis_client=mock_client, + ) + + # idx1 should be retried, idx2 should not (already succeeded) + assert "idx1" in migrated_indexes + assert "idx2" not in migrated_indexes + assert report.summary.successful == 2 + + +class TestBatchMigrationExecutorProgressCallback: + """Test progress callback functionality.""" + + def test_progress_callback_called_for_each_index(self, tmp_path): + """Progress callback should be invoked for each index.""" + batch_plan = make_batch_plan( + batch_id="test-batch-progress", + indexes=[ + BatchIndexEntry(name="idx1", applicable=True), + BatchIndexEntry(name="idx2", applicable=True), + BatchIndexEntry(name="idx3", 
applicable=True), + ], + failure_policy="continue_on_error", + ) + + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / "reports" + progress_events = [] + + def progress_callback(index_name, position, total, status): + progress_events.append( + {"index": index_name, "pos": position, "total": total, "status": status} + ) + + executor, _ = create_mock_executor(succeed_on=["idx1", "idx2", "idx3"]) + mock_client = MockRedisClient(indexes=["idx1", "idx2", "idx3"]) + + executor.apply( + batch_plan, + state_path=str(state_path), + report_dir=str(report_dir), + redis_client=mock_client, + progress_callback=progress_callback, + ) + + # Should have 2 events per index (starting + final status) + assert len(progress_events) == 6 + # Check first index events + assert progress_events[0] == { + "index": "idx1", + "pos": 1, + "total": 3, + "status": "starting", + } + assert progress_events[1] == { + "index": "idx1", + "pos": 1, + "total": 3, + "status": "success", + } + + +class TestBatchMigrationExecutorEdgeCases: + """Test edge cases and error scenarios.""" + + def test_exception_during_migration_captured(self, tmp_path): + """Exception during migration should be captured in state.""" + batch_plan = make_batch_plan( + batch_id="test-batch-exception", + indexes=[ + BatchIndexEntry(name="idx1", applicable=True), + BatchIndexEntry(name="idx2", applicable=True), + ], + failure_policy="continue_on_error", + ) + + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / "reports" + + # Track calls and raise exception for idx1 + call_count = [0] + + # Create mock planner that raises on idx1 + mock_planner = Mock() + + def create_plan_from_patch(index_name, **kwargs): + call_count[0] += 1 + if index_name == "idx1": + raise RuntimeError("Connection lost to Redis") + return MockMigrationPlan(index_name) + + mock_planner.create_plan_from_patch = create_plan_from_patch + + # Create mock executor + mock_single_executor = Mock() + mock_single_executor.apply = 
Mock( + return_value=MockMigrationReport(result="succeeded") + ) + + # Create batch executor with mocks + executor = BatchMigrationExecutor(executor=mock_single_executor) + executor._planner = mock_planner + mock_client = MockRedisClient(indexes=["idx1", "idx2"]) + + report = executor.apply( + batch_plan, + state_path=str(state_path), + report_dir=str(report_dir), + redis_client=mock_client, + ) + + # Both should have been attempted + assert call_count[0] == 2 + # idx1 failed with exception, idx2 succeeded + assert report.summary.failed == 1 + assert report.summary.successful == 1 + + # Check error message is captured + idx1_report = next(r for r in report.indexes if r.name == "idx1") + assert "Connection lost" in idx1_report.error + + def test_non_applicable_indexes_skipped(self, tmp_path): + """Non-applicable indexes should be skipped and reported.""" + batch_plan = make_batch_plan( + batch_id="test-batch-skip", + indexes=[ + BatchIndexEntry(name="idx1", applicable=True), + BatchIndexEntry( + name="idx2", + applicable=False, + skip_reason="Missing field: embedding", + ), + BatchIndexEntry(name="idx3", applicable=True), + ], + failure_policy="continue_on_error", + ) + + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / "reports" + + executor, migrated_indexes = create_mock_executor(succeed_on=["idx1", "idx3"]) + mock_client = MockRedisClient(indexes=["idx1", "idx2", "idx3"]) + + report = executor.apply( + batch_plan, + state_path=str(state_path), + report_dir=str(report_dir), + redis_client=mock_client, + ) + + # idx2 should NOT be migrated + assert "idx2" not in migrated_indexes + assert migrated_indexes == ["idx1", "idx3"] + + # Report should show idx2 as skipped + assert report.summary.successful == 2 + assert report.summary.skipped == 1 + + idx2_report = next(r for r in report.indexes if r.name == "idx2") + assert idx2_report.status == "skipped" + assert "Missing field" in idx2_report.error + + def test_empty_batch_plan(self, monkeypatch, 
tmp_path): + """Empty batch plan should complete immediately.""" + batch_plan = make_batch_plan( + batch_id="test-batch-empty", + indexes=[], # No indexes + failure_policy="fail_fast", + ) + + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / "reports" + + executor = BatchMigrationExecutor() + mock_client = MockRedisClient(indexes=[]) + + report = executor.apply( + batch_plan, + state_path=str(state_path), + report_dir=str(report_dir), + redis_client=mock_client, + ) + + assert report.status == "completed" + assert report.summary.total_indexes == 0 + assert report.summary.successful == 0 + + def test_missing_redis_connection_error(self, tmp_path): + """Should error when no Redis connection is provided.""" + batch_plan = make_batch_plan( + batch_id="test-batch-no-redis", + indexes=[BatchIndexEntry(name="idx1", applicable=True)], + failure_policy="fail_fast", + ) + + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / "reports" + + executor = BatchMigrationExecutor() + + with pytest.raises(ValueError, match="redis"): + executor.apply( + batch_plan, + state_path=str(state_path), + report_dir=str(report_dir), + # No redis_url or redis_client provided + ) + + def test_resume_missing_state_file_error(self, tmp_path): + """Resume should error when state file doesn't exist.""" + executor = BatchMigrationExecutor() + mock_client = MockRedisClient(indexes=[]) + + with pytest.raises(FileNotFoundError, match="State file"): + executor.resume( + state_path=str(tmp_path / "nonexistent_state.yaml"), + report_dir=str(tmp_path / "reports"), + redis_client=mock_client, + ) + + def test_resume_missing_plan_file_error(self, tmp_path): + """Resume should error when plan file doesn't exist.""" + # Create state file pointing to nonexistent plan + state_path = tmp_path / "batch_state.yaml" + state = BatchState( + batch_id="test-batch", + plan_path="/nonexistent/plan.yaml", + started_at="2026-03-20T10:00:00Z", + updated_at="2026-03-20T10:05:00Z", + 
remaining=["idx1"], + completed=[], + current_index=None, + ) + with open(state_path, "w") as f: + yaml.safe_dump(state.model_dump(exclude_none=True), f) + + executor = BatchMigrationExecutor() + mock_client = MockRedisClient(indexes=["idx1"]) + + with pytest.raises(FileNotFoundError, match="Batch plan"): + executor.resume( + state_path=str(state_path), + report_dir=str(tmp_path / "reports"), + redis_client=mock_client, + ) + + +class TestBatchMigrationExecutorReportGeneration: + """Test batch report generation.""" + + def test_report_contains_all_indexes(self, tmp_path): + """Final report should contain entries for all indexes.""" + batch_plan = make_batch_plan( + batch_id="test-batch-report", + indexes=[ + BatchIndexEntry(name="idx1", applicable=True), + BatchIndexEntry( + name="idx2", applicable=False, skip_reason="Missing field" + ), + BatchIndexEntry(name="idx3", applicable=True), + ], + failure_policy="continue_on_error", + ) + + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / "reports" + + executor, _ = create_mock_executor(succeed_on=["idx1", "idx3"]) + mock_client = MockRedisClient(indexes=["idx1", "idx2", "idx3"]) + + report = executor.apply( + batch_plan, + state_path=str(state_path), + report_dir=str(report_dir), + redis_client=mock_client, + ) + + # All indexes should be in report + index_names = {r.name for r in report.indexes} + assert index_names == {"idx1", "idx2", "idx3"} + + # Verify totals + assert report.summary.total_indexes == 3 + assert report.summary.successful == 2 + assert report.summary.skipped == 1 + + def test_per_index_reports_written(self, tmp_path): + """Individual reports should be written for each migrated index.""" + batch_plan = make_batch_plan( + batch_id="test-batch-files", + indexes=[ + BatchIndexEntry(name="idx1", applicable=True), + BatchIndexEntry(name="idx2", applicable=True), + ], + failure_policy="continue_on_error", + ) + + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / 
"reports" + + executor, _ = create_mock_executor(succeed_on=["idx1", "idx2"]) + mock_client = MockRedisClient(indexes=["idx1", "idx2"]) + + executor.apply( + batch_plan, + state_path=str(state_path), + report_dir=str(report_dir), + redis_client=mock_client, + ) + + # Report files should exist + assert (report_dir / "idx1_report.yaml").exists() + assert (report_dir / "idx2_report.yaml").exists() + + def test_completed_status_when_all_succeed(self, tmp_path): + """Status should be 'completed' when all indexes succeed.""" + batch_plan = make_batch_plan( + batch_id="test-batch-complete", + indexes=[ + BatchIndexEntry(name="idx1", applicable=True), + BatchIndexEntry(name="idx2", applicable=True), + ], + failure_policy="continue_on_error", + ) + + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / "reports" + + executor, _ = create_mock_executor(succeed_on=["idx1", "idx2"]) + mock_client = MockRedisClient(indexes=["idx1", "idx2"]) + + report = executor.apply( + batch_plan, + state_path=str(state_path), + report_dir=str(report_dir), + redis_client=mock_client, + ) + + assert report.status == "completed" + + def test_failed_status_when_all_fail(self, tmp_path): + """Status should be 'failed' when all indexes fail.""" + batch_plan = make_batch_plan( + batch_id="test-batch-all-fail", + indexes=[ + BatchIndexEntry(name="idx1", applicable=True), + BatchIndexEntry(name="idx2", applicable=True), + ], + failure_policy="continue_on_error", + ) + + state_path = tmp_path / "batch_state.yaml" + report_dir = tmp_path / "reports" + + # Create a mock that raises exceptions for all indexes + mock_planner = Mock() + mock_planner.create_plan_from_patch = Mock( + side_effect=RuntimeError("All migrations fail") + ) + + mock_single_executor = Mock() + executor = BatchMigrationExecutor(executor=mock_single_executor) + executor._planner = mock_planner + mock_client = MockRedisClient(indexes=["idx1", "idx2"]) + + report = executor.apply( + batch_plan, + 
state_path=str(state_path), + report_dir=str(report_dir), + redis_client=mock_client, + ) + + assert report.status == "failed" + assert report.summary.failed == 2 + assert report.summary.successful == 0 diff --git a/tests/unit/test_executor_backup_quantize.py b/tests/unit/test_executor_backup_quantize.py new file mode 100644 index 000000000..ce51aa8af --- /dev/null +++ b/tests/unit/test_executor_backup_quantize.py @@ -0,0 +1,196 @@ +"""Tests for the new two-phase quantize flow in MigrationExecutor. + +Verifies: + - dump_vectors: pipeline-reads originals, writes to backup file + - quantize_from_backup: reads backup file, converts, pipeline-writes + - Resume: reloads backup file, skips completed batches + - BGSAVE is NOT called +""" + +import struct +from typing import Any, Dict, List +from unittest.mock import MagicMock, call, patch + +import numpy as np +import pytest + + +def _make_float32_vector(dims: int = 4, seed: float = 0.0) -> bytes: + return struct.pack(f"<{dims}f", *[seed + i for i in range(dims)]) + + +class TestDumpVectors: + """Test Phase 1: dumping original vectors to backup file.""" + + def test_dump_creates_backup_and_reads_via_pipeline(self, tmp_path): + from redisvl.migration.executor import MigrationExecutor + + executor = MigrationExecutor() + mock_client = MagicMock() + mock_pipe = MagicMock() + mock_client.pipeline.return_value = mock_pipe + + dims = 4 + keys = [f"doc:{i}" for i in range(6)] + vec = _make_float32_vector(dims) + # 6 keys × 1 field = 6 results per execute + mock_pipe.execute.return_value = [vec] * 6 + + datatype_changes = { + "embedding": {"source": "float32", "target": "float16", "dims": dims} + } + + backup_path = str(tmp_path / "test_backup") + backup = executor._dump_vectors( + client=mock_client, + index_name="myindex", + keys=keys, + datatype_changes=datatype_changes, + backup_path=backup_path, + batch_size=3, + ) + + # Should use pipeline reads, not individual hget + mock_client.hget.assert_not_called() + # 2 batches of 3 
keys = 2 pipeline.execute() calls + assert mock_pipe.execute.call_count == 2 + # Backup file created and dump complete + assert backup.header.phase == "ready" + assert backup.header.dump_completed_batches == 2 + # All data readable + batches = list(backup.iter_batches()) + assert len(batches) == 2 + assert len(batches[0][0]) == 3 # first batch has 3 keys + assert len(batches[1][0]) == 3 # second batch has 3 keys + + +class TestQuantizeFromBackup: + """Test Phase 2: reading from backup, converting, writing to Redis.""" + + def _create_dumped_backup(self, tmp_path, num_keys=4, dims=4, batch_size=2): + from redisvl.migration.backup import VectorBackup + + backup_path = str(tmp_path / "test_backup") + backup = VectorBackup.create( + path=backup_path, + index_name="myindex", + fields={ + "embedding": {"source": "float32", "target": "float16", "dims": dims} + }, + batch_size=batch_size, + ) + for batch_idx in range(num_keys // batch_size): + start = batch_idx * batch_size + keys = [f"doc:{j}" for j in range(start, start + batch_size)] + vec = _make_float32_vector(dims) + originals = {k: {"embedding": vec} for k in keys} + backup.write_batch(batch_idx, keys, originals) + backup.mark_dump_complete() + return backup + + def test_quantize_writes_converted_via_pipeline(self, tmp_path): + from redisvl.migration.executor import MigrationExecutor + + executor = MigrationExecutor() + backup = self._create_dumped_backup(tmp_path) + + mock_client = MagicMock() + mock_pipe = MagicMock() + mock_client.pipeline.return_value = mock_pipe + + datatype_changes = { + "embedding": {"source": "float32", "target": "float16", "dims": 4} + } + + docs = executor._quantize_from_backup( + client=mock_client, + backup=backup, + datatype_changes=datatype_changes, + ) + + # Should write via pipeline, not individual hset + mock_client.hset.assert_not_called() + # 2 batches = 2 pipeline.execute() calls + assert mock_pipe.execute.call_count == 2 + # Each batch has 2 keys × 1 field = 2 hset calls per 
batch + assert mock_pipe.hset.call_count == 4 + # 4 docs quantized + assert docs == 4 + # Backup should be marked complete + assert backup.header.phase == "completed" + + def test_quantize_writes_correct_float16_data(self, tmp_path): + from redisvl.migration.executor import MigrationExecutor + + executor = MigrationExecutor() + backup = self._create_dumped_backup(tmp_path, num_keys=2, batch_size=2) + + written_data = {} + mock_client = MagicMock() + mock_pipe = MagicMock() + mock_client.pipeline.return_value = mock_pipe + + def capture_hset(key, field, value): + written_data[key] = {field: value} + + mock_pipe.hset.side_effect = capture_hset + + datatype_changes = { + "embedding": {"source": "float32", "target": "float16", "dims": 4} + } + + executor._quantize_from_backup( + client=mock_client, + backup=backup, + datatype_changes=datatype_changes, + ) + + # Verify written data is float16 (2 bytes per dim = 8 bytes total) + for key, fields in written_data.items(): + assert len(fields["embedding"]) == 4 * 2 # dims * sizeof(float16) + + +class TestQuantizeResume: + """Test resume after crash during quantize phase.""" + + def test_resume_skips_completed_batches(self, tmp_path): + from redisvl.migration.backup import VectorBackup + from redisvl.migration.executor import MigrationExecutor + + # Create backup with 4 batches, mark 2 as quantized + backup_path = str(tmp_path / "test_backup") + backup = VectorBackup.create( + path=backup_path, + index_name="myindex", + fields={"embedding": {"source": "float32", "target": "float16", "dims": 4}}, + batch_size=2, + ) + vec = _make_float32_vector(4) + for batch_idx in range(4): + keys = [f"doc:{batch_idx*2}", f"doc:{batch_idx*2+1}"] + backup.write_batch(batch_idx, keys, {k: {"embedding": vec} for k in keys}) + backup.mark_dump_complete() + backup.start_quantize() + backup.mark_batch_quantized(0) + backup.mark_batch_quantized(1) + # Simulate crash — save and reload + del backup + backup = VectorBackup.load(backup_path) + + 
mock_client = MagicMock() + mock_pipe = MagicMock() + mock_client.pipeline.return_value = mock_pipe + + executor = MigrationExecutor() + docs = executor._quantize_from_backup( + client=mock_client, + backup=backup, + datatype_changes={ + "embedding": {"source": "float32", "target": "float16", "dims": 4} + }, + ) + + # Only 2 remaining batches × 2 keys = 4 docs, but should only process 2 batches + assert mock_pipe.execute.call_count == 2 + assert mock_pipe.hset.call_count == 4 # 2 batches × 2 keys + assert docs == 4 diff --git a/tests/unit/test_migration_planner.py b/tests/unit/test_migration_planner.py new file mode 100644 index 000000000..672a5c4cc --- /dev/null +++ b/tests/unit/test_migration_planner.py @@ -0,0 +1,1008 @@ +from fnmatch import fnmatch + +import yaml + +from redisvl.migration import MigrationPlanner +from redisvl.schema.schema import IndexSchema + + +class DummyClient: + def __init__(self, keys): + self.keys = keys + + def scan(self, cursor=0, match=None, count=None): + matched = [] + for key in self.keys: + decoded_key = key.decode() if isinstance(key, bytes) else str(key) + if match is None or fnmatch(decoded_key, match): + matched.append(key) + return 0, matched + + +class DummyIndex: + def __init__(self, schema, stats, keys): + self.schema = schema + self._stats = stats + self._client = DummyClient(keys) + + @property + def client(self): + return self._client + + def info(self): + return self._stats + + +def _make_source_schema(): + return IndexSchema.from_dict( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + { + "name": "title", + "type": "text", + "path": "$.title", + "attrs": {"sortable": False}, + }, + { + "name": "price", + "type": "numeric", + "path": "$.price", + "attrs": {"sortable": True}, + }, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": 
"float32", + }, + }, + ], + } + ) + + +def test_create_plan_from_schema_patch_preserves_unspecified_config( + monkeypatch, tmp_path +): + source_schema = _make_source_schema() + dummy_index = DummyIndex( + source_schema, + {"num_docs": 2, "indexing": False}, + [b"docs:1", b"docs:2", b"docs:3"], + ) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + patch_path = tmp_path / "schema_patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "add_fields": [ + { + "name": "category", + "type": "tag", + "path": "$.category", + "attrs": {"separator": ","}, + } + ], + "remove_fields": ["price"], + "update_fields": [ + { + "name": "title", + "options": {"sortable": True}, + } + ], + }, + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner(key_sample_limit=2) + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + schema_patch_path=str(patch_path), + ) + + assert plan.diff_classification.supported is True + assert plan.source.index_name == "docs" + assert plan.source.keyspace.storage_type == "json" + assert plan.source.keyspace.prefixes == ["docs"] + assert plan.source.keyspace.key_separator == ":" + assert plan.source.keyspace.key_sample == ["docs:1", "docs:2"] + assert plan.warnings == ["Index downtime is required"] + + merged_fields = { + field["name"]: field for field in plan.merged_target_schema["fields"] + } + assert plan.merged_target_schema["index"]["prefix"] == "docs" + assert merged_fields["title"]["attrs"]["sortable"] is True + assert "price" not in merged_fields + assert merged_fields["category"]["type"] == "tag" + + plan_path = tmp_path / "migration_plan.yaml" + planner.write_plan(plan, str(plan_path)) + written_plan = yaml.safe_load(plan_path.read_text()) + assert written_plan["mode"] == "drop_recreate" + assert written_plan["validation"]["require_doc_count_match"] is True + assert 
written_plan["diff_classification"]["supported"] is True + + +def test_target_schema_vector_datatype_change_is_allowed(monkeypatch, tmp_path): + """Changing vector datatype (quantization) is allowed - executor will re-encode.""" + source_schema = _make_source_schema() + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + { + "name": "title", + "type": "text", + "path": "$.title", + "attrs": {"sortable": False}, + }, + { + "name": "price", + "type": "numeric", + "path": "$.price", + "attrs": {"sortable": True}, + }, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "flat", # Same algorithm + "dims": 3, + "distance_metric": "cosine", + "datatype": "float16", # Changed from float32 + }, + }, + ], + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + # Datatype change (quantization) should now be ALLOWED + assert plan.diff_classification.supported is True + assert len(plan.diff_classification.blocked_reasons) == 0 + + # Verify datatype changes are detected for the executor + datatype_changes = MigrationPlanner.get_vector_datatype_changes( + plan.source.schema_snapshot, plan.merged_target_schema + ) + assert "embedding" in datatype_changes + assert datatype_changes["embedding"]["source"] == "float32" + assert datatype_changes["embedding"]["target"] == "float16" + + +def test_target_schema_vector_algorithm_change_is_allowed(monkeypatch, tmp_path): + """Changing vector algorithm is allowed (index-only change).""" + 
source_schema = _make_source_schema() + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + { + "name": "title", + "type": "text", + "path": "$.title", + "attrs": {"sortable": False}, + }, + { + "name": "price", + "type": "numeric", + "path": "$.price", + "attrs": {"sortable": True}, + }, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "hnsw", # Changed from flat + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", # Same datatype + }, + }, + ], + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + # Algorithm change should be ALLOWED + assert plan.diff_classification.supported is True + assert len(plan.diff_classification.blocked_reasons) == 0 + + +# ============================================================================= +# BLOCKED CHANGES (Document-Dependent) - require iterative_shadow +# ============================================================================= + + +def test_target_schema_prefix_change_is_supported(monkeypatch, tmp_path): + """Prefix change is now supported via key rename operations.""" + source_schema = _make_source_schema() + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + 
"prefix": "docs_v2", + "key_separator": ":", + "storage_type": "json", + }, + "fields": source_schema.to_dict()["fields"], + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + # Prefix change is now supported + assert plan.diff_classification.supported is True + # Verify rename operation is populated + assert plan.rename_operations.change_prefix == "docs_v2" + # Verify warning is present + assert any("Prefix change" in w for w in plan.warnings) + + +def test_key_separator_change_is_blocked(monkeypatch, tmp_path): + """Key separator change is blocked: document keys don't match new pattern.""" + source_schema = _make_source_schema() + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": "/", # Changed from ":" + "storage_type": "json", + }, + "fields": source_schema.to_dict()["fields"], + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + assert plan.diff_classification.supported is False + assert any( + "key_separator" in reason.lower() or "separator" in reason.lower() + for reason in plan.diff_classification.blocked_reasons + ) + + +def test_storage_type_change_is_blocked(monkeypatch, tmp_path): + """Storage type change is blocked: documents are in wrong format.""" + source_schema = _make_source_schema() + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda 
*args, **kwargs: dummy_index, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "hash", # Changed from "json" + }, + "fields": [ + {"name": "title", "type": "text", "attrs": {"sortable": False}}, + {"name": "price", "type": "numeric", "attrs": {"sortable": True}}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + assert plan.diff_classification.supported is False + assert any( + "storage" in reason.lower() + for reason in plan.diff_classification.blocked_reasons + ) + + +def test_vector_dimension_change_is_blocked(monkeypatch, tmp_path): + """Vector dimension change is blocked: stored vectors have wrong size.""" + source_schema = _make_source_schema() + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + { + "name": "title", + "type": "text", + "path": "$.title", + "attrs": {"sortable": False}, + }, + { + "name": "price", + "type": "numeric", + "path": "$.price", + "attrs": {"sortable": True}, + }, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "flat", + "dims": 768, # Changed from 3 + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + }, + sort_keys=False, + ) 
+ ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + assert plan.diff_classification.supported is False + assert any( + "dims" in reason and "document migration" in reason + for reason in plan.diff_classification.blocked_reasons + ) + + +def test_field_path_change_is_blocked(monkeypatch, tmp_path): + """JSON path change is blocked: stored data is at wrong path.""" + source_schema = _make_source_schema() + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + { + "name": "title", + "type": "text", + "path": "$.metadata.title", # Changed from $.title + "attrs": {"sortable": False}, + }, + { + "name": "price", + "type": "numeric", + "path": "$.price", + "attrs": {"sortable": True}, + }, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + assert plan.diff_classification.supported is False + assert any( + "path" in reason.lower() for reason in plan.diff_classification.blocked_reasons + ) + + +def test_field_type_change_is_blocked(monkeypatch, tmp_path): + """Field type change is blocked: index expects different data format.""" + source_schema = _make_source_schema() + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( 
+ "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + { + "name": "title", + "type": "tag", # Changed from text + "path": "$.title", + }, + { + "name": "price", + "type": "numeric", + "path": "$.price", + "attrs": {"sortable": True}, + }, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + assert plan.diff_classification.supported is False + assert any( + "type" in reason.lower() for reason in plan.diff_classification.blocked_reasons + ) + + +def test_field_rename_is_detected_and_blocked(monkeypatch, tmp_path): + """Field rename is blocked: stored data uses old field name.""" + source_schema = _make_source_schema() + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + { + "name": "document_title", # Renamed from "title" + "type": "text", + "path": "$.title", + "attrs": {"sortable": False}, + }, + { + "name": "price", + "type": "numeric", + "path": "$.price", + "attrs": {"sortable": True}, + }, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": 
"flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + assert plan.diff_classification.supported is False + assert any( + "rename" in reason.lower() + for reason in plan.diff_classification.blocked_reasons + ) + + +# ============================================================================= +# ALLOWED CHANGES (Index-Only) +# ============================================================================= + + +def test_add_non_vector_field_is_allowed(monkeypatch, tmp_path): + """Adding a non-vector field is allowed.""" + source_schema = _make_source_schema() + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + patch_path = tmp_path / "schema_patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "add_fields": [ + {"name": "category", "type": "tag", "path": "$.category"} + ] + }, + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + schema_patch_path=str(patch_path), + ) + + assert plan.diff_classification.supported is True + + +def test_remove_field_is_allowed(monkeypatch, tmp_path): + """Removing a field from the index is allowed.""" + source_schema = _make_source_schema() + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + patch_path = tmp_path / "schema_patch.yaml" + patch_path.write_text( + yaml.safe_dump( + {"version": 1, "changes": {"remove_fields": ["price"]}}, + sort_keys=False, + ) + ) + + planner = 
MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + schema_patch_path=str(patch_path), + ) + + assert plan.diff_classification.supported is True + + +def test_change_field_sortable_is_allowed(monkeypatch, tmp_path): + """Changing field sortable option is allowed.""" + source_schema = _make_source_schema() + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + patch_path = tmp_path / "schema_patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "update_fields": [{"name": "title", "options": {"sortable": True}}] + }, + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + schema_patch_path=str(patch_path), + ) + + assert plan.diff_classification.supported is True + + +def test_change_vector_distance_metric_is_allowed(monkeypatch, tmp_path): + """Changing vector distance metric is allowed (index-only).""" + source_schema = _make_source_schema() + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + { + "name": "title", + "type": "text", + "path": "$.title", + "attrs": {"sortable": False}, + }, + { + "name": "price", + "type": "numeric", + "path": "$.price", + "attrs": {"sortable": True}, + }, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "L2", # Changed from cosine + "datatype": "float32", + }, 
+ }, + ], + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + assert plan.diff_classification.supported is True + assert len(plan.diff_classification.blocked_reasons) == 0 + + +def test_change_hnsw_tuning_params_is_allowed(monkeypatch, tmp_path): + """Changing HNSW tuning parameters is allowed (index-only).""" + source_schema = IndexSchema.from_dict( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "hnsw", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + "m": 16, + "ef_construction": 200, + }, + }, + ], + } + ) + dummy_index = DummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "hnsw", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + "m": 32, # Changed from 16 + "ef_construction": 400, # Changed from 200 + }, + }, + ], + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + assert plan.diff_classification.supported is True + assert len(plan.diff_classification.blocked_reasons) == 0 + + +def test_plan_warns_when_source_has_hash_indexing_failures(monkeypatch, tmp_path): + """Plan should include a warning when the source 
index has hash_indexing_failures > 0.""" + source_schema = _make_source_schema() + dummy_index = DummyIndex( + source_schema, + {"num_docs": 5, "hash_indexing_failures": 3}, + [b"docs:1"], + ) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + patch_path = tmp_path / "schema_patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "add_fields": [ + {"name": "status", "type": "tag", "path": "$.status"} + ], + }, + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + schema_patch_path=str(patch_path), + ) + + failure_warnings = [w for w in plan.warnings if "hash indexing failure" in w] + assert len(failure_warnings) == 1 + assert "3" in failure_warnings[0] + + +def test_plan_no_warning_when_source_has_zero_indexing_failures(monkeypatch, tmp_path): + """Plan should NOT include an indexing failure warning when failures == 0.""" + source_schema = _make_source_schema() + dummy_index = DummyIndex( + source_schema, + {"num_docs": 5, "hash_indexing_failures": 0}, + [b"docs:1"], + ) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + patch_path = tmp_path / "schema_patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "add_fields": [ + {"name": "status", "type": "tag", "path": "$.status"} + ], + }, + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + schema_patch_path=str(patch_path), + ) + + failure_warnings = [w for w in plan.warnings if "hash indexing failure" in w] + assert len(failure_warnings) == 0 + + +def test_plan_no_warning_when_stats_missing_failures_key(monkeypatch, tmp_path): + """Plan should handle missing hash_indexing_failures key gracefully.""" + source_schema 
= _make_source_schema() + dummy_index = DummyIndex( + source_schema, + {"num_docs": 5}, # No hash_indexing_failures key + [b"docs:1"], + ) + monkeypatch.setattr( + "redisvl.migration.planner.SearchIndex.from_existing", + lambda *args, **kwargs: dummy_index, + ) + + patch_path = tmp_path / "schema_patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "add_fields": [ + {"name": "status", "type": "tag", "path": "$.status"} + ], + }, + }, + sort_keys=False, + ) + ) + + planner = MigrationPlanner() + plan = planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + schema_patch_path=str(patch_path), + ) + + failure_warnings = [w for w in plan.warnings if "hash indexing failure" in w] + assert len(failure_warnings) == 0 diff --git a/tests/unit/test_migration_wizard.py b/tests/unit/test_migration_wizard.py new file mode 100644 index 000000000..bd53d6415 --- /dev/null +++ b/tests/unit/test_migration_wizard.py @@ -0,0 +1,1190 @@ +from redisvl.migration.wizard import MigrationWizard + + +def _make_vector_source_schema(algorithm="hnsw", datatype="float32"): + """Helper to create a source schema with a vector field.""" + return { + "index": { + "name": "test_index", + "prefix": "test:", + "storage_type": "hash", + }, + "fields": [ + {"name": "title", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": algorithm, + "dims": 384, + "distance_metric": "cosine", + "datatype": datatype, + "m": 16, + "ef_construction": 200, + }, + }, + ], + } + + +def test_wizard_builds_patch_from_interactive_inputs(monkeypatch): + source_schema = { + "index": { + "name": "docs", + "prefix": "docs", + "storage_type": "json", + }, + "fields": [ + {"name": "title", "type": "text", "path": "$.title"}, + {"name": "category", "type": "tag", "path": "$.category"}, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + 
"datatype": "float32", + }, + }, + ], + } + + answers = iter( + [ + # Add field + "1", + "status", # field name + "tag", # field type + "$.status", # JSON path + "y", # sortable + "n", # index_missing + "n", # index_empty + "|", # separator (tag-specific) + "n", # case_sensitive (tag-specific) + "n", # no_index (prompted since sortable=y) + # Update field + "2", + "title", # select field + "y", # sortable + "n", # index_missing + "n", # index_empty + "n", # no_stem (text-specific) + "", # weight (blank to skip, text-specific) + "", # phonetic_matcher (blank to skip) + "n", # unf (prompted since sortable=y) + "n", # no_index (prompted since sortable=y) + # Remove field + "3", + "category", + # Finish + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) # noqa: SLF001 + + assert patch.changes.add_fields == [ + { + "name": "status", + "type": "tag", + "path": "$.status", + "attrs": { + "sortable": True, + "index_missing": False, + "index_empty": False, + "separator": "|", + "case_sensitive": False, + "no_index": False, + }, + } + ] + assert patch.changes.remove_fields == ["category"] + assert len(patch.changes.update_fields) == 1 + assert patch.changes.update_fields[0].name == "title" + assert patch.changes.update_fields[0].attrs["sortable"] is True + assert patch.changes.update_fields[0].attrs["no_stem"] is False + + +# ============================================================================= +# Vector Algorithm Tests +# ============================================================================= + + +class TestVectorAlgorithmChanges: + """Test wizard handling of vector algorithm changes.""" + + def test_hnsw_to_flat(self, monkeypatch): + """Test changing from HNSW to FLAT algorithm.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", # Update field + "embedding", # Select vector field + "FLAT", # Change to 
FLAT + "", # datatype (keep current) + "", # distance_metric (keep current) + # No HNSW params prompted for FLAT + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 1 + update = patch.changes.update_fields[0] + assert update.name == "embedding" + assert update.attrs["algorithm"] == "FLAT" + + def test_flat_to_hnsw_with_params(self, monkeypatch): + """Test changing from FLAT to HNSW with custom M and EF_CONSTRUCTION.""" + source_schema = _make_vector_source_schema(algorithm="flat") + + answers = iter( + [ + "2", # Update field + "embedding", # Select vector field + "HNSW", # Change to HNSW + "", # datatype (keep current) + "", # distance_metric (keep current) + "32", # M + "400", # EF_CONSTRUCTION + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "HNSW" + assert update.attrs["m"] == 32 + assert update.attrs["ef_construction"] == 400 + + def test_hnsw_to_svs_vamana_with_underscore(self, monkeypatch): + """Test changing to SVS_VAMANA (underscore format) is normalized.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", # Update field + "embedding", # Select vector field + "SVS_VAMANA", # Underscore format (should be normalized) + "float16", # SVS only supports float16/float32 + "", # distance_metric (keep current) + "64", # GRAPH_MAX_DEGREE + "LVQ8", # COMPRESSION + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert 
update.attrs["algorithm"] == "SVS-VAMANA" # Normalized to hyphen + assert update.attrs["datatype"] == "float16" + assert update.attrs["graph_max_degree"] == 64 + assert update.attrs["compression"] == "LVQ8" + + def test_hnsw_to_svs_vamana_with_hyphen(self, monkeypatch): + """Test changing to SVS-VAMANA (hyphen format) works directly.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", # Update field + "embedding", # Select vector field + "SVS-VAMANA", # Hyphen format + "", # datatype (keep current) + "", # distance_metric (keep current) + "", # GRAPH_MAX_DEGREE (keep default) + "", # COMPRESSION (none) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "SVS-VAMANA" + + def test_svs_vamana_with_leanvec_compression(self, monkeypatch): + """Test SVS-VAMANA with LeanVec compression type.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", # Update field + "embedding", # Select vector field + "SVS-VAMANA", + "float16", + "", # distance_metric + "48", # GRAPH_MAX_DEGREE + "LEANVEC8X8", # COMPRESSION + "192", # REDUCE (dims/2) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "SVS-VAMANA" + assert update.attrs["compression"] == "LeanVec8x8" + assert update.attrs["reduce"] == 192 + + +# ============================================================================= +# Vector Datatype (Quantization) Tests +# ============================================================================= + + +class TestVectorDatatypeChanges: + """Test wizard handling of vector datatype/quantization 
changes.""" + + def test_float32_to_float16(self, monkeypatch): + """Test quantization from float32 to float16.""" + source_schema = _make_vector_source_schema(datatype="float32") + + answers = iter( + [ + "2", # Update field + "embedding", + "", # algorithm (keep current) + "float16", # datatype + "", # distance_metric + "", # M (keep current) + "", # EF_CONSTRUCTION (keep current) + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["datatype"] == "float16" + + def test_float16_to_float32(self, monkeypatch): + """Test changing from float16 back to float32.""" + source_schema = _make_vector_source_schema(datatype="float16") + + answers = iter( + [ + "2", # Update field + "embedding", + "", # algorithm + "float32", # datatype + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["datatype"] == "float32" + + def test_int8_accepted_for_hnsw(self, monkeypatch): + """Test that int8 is accepted for HNSW/FLAT (but not SVS-VAMANA).""" + source_schema = _make_vector_source_schema(datatype="float32") + + answers = iter( + [ + "2", # Update field + "embedding", + "", # algorithm (keep HNSW) + "int8", # Valid for HNSW/FLAT + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # int8 
is now valid for HNSW/FLAT + update = patch.changes.update_fields[0] + assert update.attrs["datatype"] == "int8" + + +# ============================================================================= +# Distance Metric Tests +# ============================================================================= + + +class TestDistanceMetricChanges: + """Test wizard handling of distance metric changes.""" + + def test_cosine_to_l2(self, monkeypatch): + """Test changing distance metric from cosine to L2.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", # Update field + "embedding", + "", # algorithm + "", # datatype + "l2", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["distance_metric"] == "l2" + + def test_cosine_to_ip(self, monkeypatch): + """Test changing distance metric from cosine to inner product.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", # Update field + "embedding", + "", # algorithm + "", # datatype + "ip", # distance_metric (inner product) + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["distance_metric"] == "ip" + + +# ============================================================================= +# Combined Changes Tests +# ============================================================================= + + +class TestCombinedVectorChanges: + """Test wizard handling of multiple vector attribute changes.""" 
+ + def test_algorithm_datatype_and_metric_change(self, monkeypatch): + """Test changing algorithm, datatype, and distance metric together.""" + source_schema = _make_vector_source_schema(algorithm="flat", datatype="float32") + + answers = iter( + [ + "2", # Update field + "embedding", + "HNSW", # algorithm + "float16", # datatype + "l2", # distance_metric + "24", # M + "300", # EF_CONSTRUCTION + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "HNSW" + assert update.attrs["datatype"] == "float16" + assert update.attrs["distance_metric"] == "l2" + assert update.attrs["m"] == 24 + assert update.attrs["ef_construction"] == 300 + + def test_svs_vamana_full_config(self, monkeypatch): + """Test SVS-VAMANA with all parameters configured.""" + source_schema = _make_vector_source_schema(algorithm="hnsw", datatype="float32") + + answers = iter( + [ + "2", # Update field + "embedding", + "SVS-VAMANA", # algorithm + "float16", # datatype (required for SVS) + "ip", # distance_metric + "50", # GRAPH_MAX_DEGREE + "LVQ4X8", # COMPRESSION + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "SVS-VAMANA" + assert update.attrs["datatype"] == "float16" + assert update.attrs["distance_metric"] == "ip" + assert update.attrs["graph_max_degree"] == 50 + assert update.attrs["compression"] == "LVQ4x8" + + def test_no_changes_when_all_blank(self, monkeypatch): + """Test that blank inputs result in no changes.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", # Update field + "embedding", + "", # 
algorithm (keep current) + "", # datatype (keep current) + "", # distance_metric (keep current) + "", # M (keep current) + "", # EF_CONSTRUCTION (keep current) + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # No changes collected means no update_fields + assert len(patch.changes.update_fields) == 0 + + +# ============================================================================= +# Adversarial / Edge Case Tests +# ============================================================================= + + +class TestWizardAdversarialInputs: + """Test wizard robustness against malformed, malicious, or edge case inputs.""" + + # ------------------------------------------------------------------------- + # Invalid Algorithm Inputs + # ------------------------------------------------------------------------- + + def test_typo_in_algorithm_ignored(self, monkeypatch): + """Test that typos in algorithm name are ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW_TYPO", # Invalid algorithm + "", # datatype + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Invalid algorithm should be ignored, no changes + assert len(patch.changes.update_fields) == 0 + + def test_partial_algorithm_name_ignored(self, monkeypatch): + """Test that partial algorithm names are ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNS", # Partial name + "", # datatype + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON 
+ "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + def test_algorithm_with_special_chars_ignored(self, monkeypatch): + """Test that algorithm with special characters is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW; DROP TABLE users;--", # SQL injection attempt + "", # datatype + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + def test_algorithm_lowercase_works(self, monkeypatch): + """Test that lowercase algorithm names work (case insensitive).""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "flat", # lowercase + "", + "", + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "FLAT" + + def test_algorithm_mixed_case_works(self, monkeypatch): + """Test that mixed case algorithm names work.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "SvS_VaMaNa", # Mixed case with underscore + "", + "", + "", + "", + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "SVS-VAMANA" + + # ------------------------------------------------------------------------- + # 
Invalid Numeric Inputs + # ------------------------------------------------------------------------- + + def test_negative_m_ignored(self, monkeypatch): + """Test that negative M value is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW", + "", # datatype + "", # distance_metric + "-16", # Negative M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Negative M is ignored, and since algorithm/datatype/metric are unchanged, + # no update should be generated at all + assert len(patch.changes.update_fields) == 0 + + def test_float_m_ignored(self, monkeypatch): + """Test that float M value is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW", + "", # datatype + "", # distance_metric + "16.5", # Float M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Float M is ignored, and since algorithm/datatype/metric are unchanged, + # no update should be generated at all + assert len(patch.changes.update_fields) == 0 + + def test_string_m_ignored(self, monkeypatch): + """Test that string M value is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW", + "", # datatype + "", # distance_metric + "sixteen", # String M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # String M is ignored, and since 
algorithm/datatype/metric are unchanged, + # no update should be generated at all + assert len(patch.changes.update_fields) == 0 + + def test_zero_m_accepted(self, monkeypatch): + """Test that zero M is accepted (validation happens at schema level).""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW", + "", # datatype + "", # distance_metric + "0", # Zero M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Zero is a valid digit, wizard accepts it (validation at apply time) + update = patch.changes.update_fields[0] + assert update.attrs.get("m") == 0 + + def test_very_large_ef_construction_accepted(self, monkeypatch): + """Test that very large EF_CONSTRUCTION is accepted by wizard.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW", + "", + "", + "", + "999999999", # Very large EF_CONSTRUCTION + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["ef_construction"] == 999999999 + + # ------------------------------------------------------------------------- + # Invalid Datatype Inputs + # ------------------------------------------------------------------------- + + def test_bfloat16_accepted_for_hnsw(self, monkeypatch): + """Test that bfloat16 is accepted for HNSW/FLAT.""" + source_schema = _make_vector_source_schema(datatype="float32") + + answers = iter( + [ + "2", + "embedding", + "", # algorithm + "bfloat16", # Valid for HNSW/FLAT + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", 
# EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["datatype"] == "bfloat16" + + def test_uint8_accepted_for_hnsw(self, monkeypatch): + """Test that uint8 is accepted for HNSW/FLAT.""" + source_schema = _make_vector_source_schema(datatype="float32") + + answers = iter( + [ + "2", + "embedding", + "", # algorithm + "uint8", # Valid for HNSW/FLAT + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["datatype"] == "uint8" + + def test_int8_rejected_for_svs_vamana(self, monkeypatch): + """Test that int8 is rejected for SVS-VAMANA (only float16/float32 allowed).""" + source_schema = _make_vector_source_schema(datatype="float32", algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "SVS-VAMANA", # Switch to SVS-VAMANA + "int8", # Invalid for SVS-VAMANA + "", + "", + "", # graph_max_degree + "", # compression + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Should have algorithm change but NOT datatype + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "SVS-VAMANA" + assert "datatype" not in update.attrs # int8 rejected + + # ------------------------------------------------------------------------- + # Invalid Distance Metric Inputs + # ------------------------------------------------------------------------- + + def test_invalid_distance_metric_ignored(self, monkeypatch): + """Test that invalid distance metric is ignored.""" + source_schema = 
_make_vector_source_schema() + + answers = iter( + [ + "2", + "embedding", + "", # algorithm + "", # datatype + "euclidean", # Invalid (should be 'l2') + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + def test_distance_metric_uppercase_works(self, monkeypatch): + """Test that uppercase distance metric works.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "embedding", + "", # algorithm + "", # datatype + "L2", # Uppercase + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["distance_metric"] == "l2" + + # ------------------------------------------------------------------------- + # Invalid Compression Inputs + # ------------------------------------------------------------------------- + + def test_invalid_compression_ignored(self, monkeypatch): + """Test that invalid compression type is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "SVS-VAMANA", + "", + "", + "", + "INVALID_COMPRESSION", # Invalid + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert "compression" not in update.attrs + + def test_compression_lowercase_works(self, monkeypatch): + """Test that lowercase compression works.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "SVS-VAMANA", + "", 
+ "", + "", + "lvq8", # lowercase + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["compression"] == "LVQ8" + + # ------------------------------------------------------------------------- + # Whitespace and Special Character Inputs + # ------------------------------------------------------------------------- + + def test_whitespace_only_treated_as_blank(self, monkeypatch): + """Test that whitespace-only input is treated as blank.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "embedding", + " ", # Whitespace only (algorithm) + " ", # datatype + " ", # distance_metric + " ", # M + " ", # EF_CONSTRUCTION + " ", # EF_RUNTIME + " ", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + def test_algorithm_with_leading_trailing_whitespace(self, monkeypatch): + """Test that algorithm with whitespace is trimmed and works.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + " FLAT ", # Whitespace around (FLAT has no extra params) + "", # datatype + "", # distance_metric + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "FLAT" + + def test_unicode_input_ignored(self, monkeypatch): + """Test that unicode/emoji inputs are ignored.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "embedding", + "HNSW\U0001f680", # Unicode emoji + "", # datatype + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + 
"", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + def test_very_long_input_ignored(self, monkeypatch): + """Test that very long inputs are ignored.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "embedding", + "A" * 10000, # Very long string + "", # datatype + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + # ------------------------------------------------------------------------- + # Field Selection Edge Cases + # ------------------------------------------------------------------------- + + def test_nonexistent_field_selection(self, monkeypatch): + """Test selecting a nonexistent field.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "nonexistent_field", # Doesn't exist + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Should print "Invalid field selection" and continue + assert len(patch.changes.update_fields) == 0 + + def test_field_selection_by_number_out_of_range(self, monkeypatch): + """Test selecting a field by out-of-range number.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "99", # Out of range + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + def test_field_selection_negative_number(self, monkeypatch): + 
"""Test selecting a field with negative number.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "-1", # Negative + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + # ------------------------------------------------------------------------- + # Menu Action Edge Cases + # ------------------------------------------------------------------------- + + def test_invalid_menu_action(self, monkeypatch): + """Test invalid menu action selection.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "99", # Invalid action + "abc", # Invalid action + "", # Empty + "8", # Finally finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Should handle invalid actions gracefully and eventually finish + assert patch is not None + + # ------------------------------------------------------------------------- + # SVS-VAMANA Specific Edge Cases + # ------------------------------------------------------------------------- + + def test_svs_vamana_negative_graph_max_degree_ignored(self, monkeypatch): + """Test that negative GRAPH_MAX_DEGREE is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "SVS-VAMANA", + "", + "", + "-40", # Negative + "", + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert "graph_max_degree" not in update.attrs + + def test_svs_vamana_string_graph_max_degree_ignored(self, monkeypatch): + """Test that string GRAPH_MAX_DEGREE is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + 
+ answers = iter( + [ + "2", + "embedding", + "SVS-VAMANA", + "", + "", + "forty", # String + "", + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert "graph_max_degree" not in update.attrs diff --git a/tests/unit/test_multi_worker_quantize.py b/tests/unit/test_multi_worker_quantize.py new file mode 100644 index 000000000..b69db83bd --- /dev/null +++ b/tests/unit/test_multi_worker_quantize.py @@ -0,0 +1,342 @@ +"""Tests for multi-worker quantization. + +TDD: tests written BEFORE implementation. + +Tests: + - Key splitting across N workers + - Per-worker backup file shards + - Multi-worker sync execution via ThreadPoolExecutor + - Progress aggregation +""" + +import struct +from typing import Any, Dict, List +from unittest.mock import MagicMock, patch + +import pytest + + +def _make_float32_vector(dims: int = 4, seed: float = 0.0) -> bytes: + return struct.pack(f"<{dims}f", *[seed + i for i in range(dims)]) + + +class TestSplitKeys: + """Test splitting keys into N contiguous slices.""" + + def test_split_evenly(self): + from redisvl.migration.quantize import split_keys + + keys = [f"doc:{i}" for i in range(8)] + slices = split_keys(keys, num_workers=4) + assert len(slices) == 4 + assert slices[0] == ["doc:0", "doc:1"] + assert slices[1] == ["doc:2", "doc:3"] + assert slices[2] == ["doc:4", "doc:5"] + assert slices[3] == ["doc:6", "doc:7"] + + def test_split_uneven(self): + from redisvl.migration.quantize import split_keys + + keys = [f"doc:{i}" for i in range(10)] + slices = split_keys(keys, num_workers=3) + assert len(slices) == 3 + # 10 / 3 = 4, 4, 2 + assert len(slices[0]) == 4 + assert len(slices[1]) == 4 + assert len(slices[2]) == 2 + # All keys present + all_keys = [k for s in slices for k in s] + assert all_keys == keys + + def test_split_fewer_keys_than_workers(self): + from 
redisvl.migration.quantize import split_keys + + keys = ["doc:0", "doc:1"] + slices = split_keys(keys, num_workers=5) + # Should produce only 2 non-empty slices (not 5) + non_empty = [s for s in slices if s] + assert len(non_empty) == 2 + + def test_split_single_worker(self): + from redisvl.migration.quantize import split_keys + + keys = [f"doc:{i}" for i in range(10)] + slices = split_keys(keys, num_workers=1) + assert len(slices) == 1 + assert slices[0] == keys + + def test_split_empty_keys(self): + from redisvl.migration.quantize import split_keys + + slices = split_keys([], num_workers=4) + assert slices == [] + + def test_split_zero_workers_raises(self): + from redisvl.migration.quantize import split_keys + + with pytest.raises(ValueError, match="num_workers must be >= 1"): + split_keys(["doc:0"], num_workers=0) + + def test_split_negative_workers_raises(self): + from redisvl.migration.quantize import split_keys + + with pytest.raises(ValueError, match="num_workers must be >= 1"): + split_keys(["doc:0", "doc:1"], num_workers=-1) + + def test_split_zero_workers_empty_keys_raises(self): + """Even with empty keys, invalid num_workers should still raise.""" + from redisvl.migration.quantize import split_keys + + with pytest.raises(ValueError, match="num_workers must be >= 1"): + split_keys([], num_workers=0) + + +class TestMultiWorkerSync: + """Test multi-worker quantization with ThreadPoolExecutor.""" + + def test_multi_worker_dump_and_quantize(self, tmp_path): + """4 workers process 8 keys (2 each). 
Each gets own backup shard.""" + from redisvl.migration.quantize import multi_worker_quantize + + dims = 4 + vec = _make_float32_vector(dims) + all_keys = [f"doc:{i}" for i in range(8)] + + # Mock Redis: each client.pipeline().execute() returns vectors + def make_mock_client(): + mock = MagicMock() + mock_pipe = MagicMock() + mock.pipeline.return_value = mock_pipe + mock_pipe.execute.return_value = [vec] * 2 # 2 keys per worker + return mock + + datatype_changes = { + "embedding": {"source": "float32", "target": "float16", "dims": dims} + } + + with patch( + "redisvl.redis.connection.RedisConnectionFactory.get_redis_connection" + ) as mock_get_conn: + mock_get_conn.side_effect = lambda **kwargs: make_mock_client() + + result = multi_worker_quantize( + redis_url="redis://localhost:6379", + keys=all_keys, + datatype_changes=datatype_changes, + backup_dir=str(tmp_path), + index_name="myindex", + num_workers=4, + batch_size=2, + ) + + assert result.total_docs_quantized == 8 + assert result.num_workers == 4 + # Each worker should have created a backup shard + assert len(list(tmp_path.glob("*.header"))) == 4 + + def test_single_worker_fallback(self, tmp_path): + """With num_workers=1, should still work (no ThreadPoolExecutor needed).""" + from redisvl.migration.quantize import multi_worker_quantize + + dims = 4 + vec = _make_float32_vector(dims) + keys = [f"doc:{i}" for i in range(4)] + + mock_client = MagicMock() + mock_pipe = MagicMock() + mock_client.pipeline.return_value = mock_pipe + mock_pipe.execute.return_value = [vec] * 4 + + datatype_changes = { + "embedding": {"source": "float32", "target": "float16", "dims": dims} + } + + with patch( + "redisvl.redis.connection.RedisConnectionFactory.get_redis_connection" + ) as mock_get_conn: + mock_get_conn.return_value = mock_client + + result = multi_worker_quantize( + redis_url="redis://localhost:6379", + keys=keys, + datatype_changes=datatype_changes, + backup_dir=str(tmp_path), + index_name="myindex", + num_workers=1, 
+ batch_size=4, + ) + + assert result.total_docs_quantized == 4 + assert result.num_workers == 1 + + +class TestMultiWorkerResult: + """Test the result object from multi-worker quantization.""" + + def test_result_attributes(self): + from redisvl.migration.quantize import MultiWorkerResult + + result = MultiWorkerResult( + total_docs_quantized=1000, + num_workers=4, + worker_results=[ + {"worker_id": 0, "docs": 250}, + {"worker_id": 1, "docs": 250}, + {"worker_id": 2, "docs": 250}, + {"worker_id": 3, "docs": 250}, + ], + ) + assert result.total_docs_quantized == 1000 + assert result.num_workers == 4 + assert len(result.worker_results) == 4 + + +class TestWorkerResume: + """Test sync and async worker resume from partial backups.""" + + def _make_partial_backup(self, tmp_path, phase="dump", dump_batches=1): + """Create a partial backup to simulate crash-resume.""" + from redisvl.migration.backup import VectorBackup + + bp = str(tmp_path / "migration_backup_testidx_shard_0") + datatype_changes = { + "embedding": {"source": "float32", "target": "float16", "dims": 4} + } + backup = VectorBackup.create( + path=bp, + index_name="testidx", + fields=datatype_changes, + batch_size=2, + ) + # Write some batches + for i in range(dump_batches): + keys = [f"doc:{i * 2}", f"doc:{i * 2 + 1}"] + originals = { + k: {"embedding": _make_float32_vector(4, seed=float(j))} + for j, k in enumerate(keys) + } + backup.write_batch(i, keys, originals) + + if phase == "ready": + backup.mark_dump_complete() + elif phase == "active": + backup.mark_dump_complete() + backup.start_quantize() + return bp, datatype_changes + + def test_sync_worker_resumes_from_ready_phase(self, tmp_path): + """Sync worker should skip dump and proceed to quantize on resume.""" + from redisvl.migration.backup import VectorBackup + + bp, dt_changes = self._make_partial_backup( + tmp_path, phase="ready", dump_batches=2 + ) + + # Verify backup is in ready phase + backup = VectorBackup.load(bp) + assert backup is not None 
+ assert backup.header.phase == "ready" + assert backup.header.dump_completed_batches == 2 + + def test_sync_worker_resumes_from_dump_phase(self, tmp_path): + """Sync worker should resume dumping from the last completed batch.""" + from redisvl.migration.backup import VectorBackup + + bp, dt_changes = self._make_partial_backup( + tmp_path, phase="dump", dump_batches=1 + ) + + backup = VectorBackup.load(bp) + assert backup is not None + assert backup.header.phase == "dump" + assert backup.header.dump_completed_batches == 1 + # Worker should start from batch 1, not 0 + + def test_sync_worker_skips_completed_backup(self, tmp_path): + """Completed backup should be detected and skipped.""" + from redisvl.migration.backup import VectorBackup + + bp, dt_changes = self._make_partial_backup( + tmp_path, phase="active", dump_batches=2 + ) + backup = VectorBackup.load(bp) + # Mark all batches quantized + for i in range(2): + backup.mark_batch_quantized(i) + backup.mark_complete() + + # Reload and verify + backup = VectorBackup.load(bp) + assert backup.header.phase == "completed" + + @pytest.mark.asyncio + async def test_async_worker_loads_existing_backup(self, tmp_path): + """Async worker should load existing backup instead of creating new.""" + from redisvl.migration.backup import VectorBackup + + bp, dt_changes = self._make_partial_backup( + tmp_path, phase="ready", dump_batches=2 + ) + + # Verify load succeeds and returns existing backup + backup = VectorBackup.load(bp) + assert backup is not None + assert backup.header.phase == "ready" + assert backup.header.dump_completed_batches == 2 + + # Verify create would fail (backup already exists) + with pytest.raises(FileExistsError): + VectorBackup.create( + path=bp, + index_name="testidx", + fields=dt_changes, + batch_size=2, + ) + + +class TestResumeDeprecation: + """Test --resume deprecated alias behavior and checkpoint_path shim.""" + + def test_checkpoint_path_kwarg_triggers_deprecation(self): + """Passing checkpoint_path= 
to executor.apply should emit DeprecationWarning.""" + import warnings + + from redisvl.migration.executor import MigrationExecutor + + executor = MigrationExecutor() + # We can't actually call apply() without a plan, but we can verify the + # parameter is accepted by checking the signature + import inspect + + sig = inspect.signature(executor.apply) + assert "checkpoint_path" in sig.parameters + assert "backup_dir" in sig.parameters + + def test_async_checkpoint_path_kwarg_accepted(self): + """Async executor should also accept checkpoint_path.""" + import inspect + + from redisvl.migration.async_executor import AsyncMigrationExecutor + + executor = AsyncMigrationExecutor() + sig = inspect.signature(executor.apply) + assert "checkpoint_path" in sig.parameters + assert "backup_dir" in sig.parameters + + def test_resume_flag_rejects_yaml_file(self, tmp_path): + """--resume with a .yaml path should be rejected by the CLI.""" + import os + + # Create a fake yaml file + yaml_path = tmp_path / "checkpoint.yaml" + yaml_path.write_text("test: true") + # The CLI checks: os.path.isfile(value) or value.endswith(('.yaml', '.yml')) + assert os.path.isfile(str(yaml_path)) + assert str(yaml_path).endswith(".yaml") + + def test_resume_flag_accepts_directory(self, tmp_path): + """--resume with a directory path should be accepted as backup_dir.""" + import os + + assert os.path.isdir(str(tmp_path)) + assert not str(tmp_path).endswith((".yaml", ".yml")) diff --git a/tests/unit/test_pipeline_quantize.py b/tests/unit/test_pipeline_quantize.py new file mode 100644 index 000000000..6f9944f14 --- /dev/null +++ b/tests/unit/test_pipeline_quantize.py @@ -0,0 +1,164 @@ +"""Tests for pipelined read/write quantization. + +TDD: tests written BEFORE refactoring _quantize_vectors. + +Tests the new quantize flow: + 1. Pipeline-read original vectors (dump phase) + 2. Convert dtype in memory + 3. 
Pipeline-write converted vectors (quantize phase) +""" + +import struct +from typing import Any, Dict, List +from unittest.mock import MagicMock, call, patch + +import pytest + + +def _make_float32_vector(dims: int = 4, seed: float = 0.0) -> bytes: + """Create a fake float32 vector.""" + return struct.pack(f"<{dims}f", *[seed + i for i in range(dims)]) + + +class TestPipelineReadBatch: + """Test that vector reads are pipelined, not individual HGET calls.""" + + def test_pipeline_read_batches_hgets(self): + """A batch of N keys with F fields should produce N*F pipelined HGET + calls and exactly 1 pipe.execute() — not N*F individual client.hget().""" + from redisvl.migration.backup import VectorBackup + + mock_client = MagicMock() + mock_pipe = MagicMock() + mock_client.pipeline.return_value = mock_pipe + + dims = 4 + keys = [f"doc:{i}" for i in range(5)] + vec = _make_float32_vector(dims) + # Pipeline execute returns one result per hget call + mock_pipe.execute.return_value = [vec] * 5 + + datatype_changes = { + "embedding": {"source": "float32", "target": "float16", "dims": dims} + } + + from redisvl.migration.quantize import pipeline_read_vectors + + result = pipeline_read_vectors(mock_client, keys, datatype_changes) + + # Should call pipeline(), not client.hget() + mock_client.pipeline.assert_called_once_with(transaction=False) + assert mock_pipe.hget.call_count == 5 + # Exactly 1 execute call (not 5) + mock_pipe.execute.assert_called_once() + # Should NOT call client.hget directly + mock_client.hget.assert_not_called() + # Returns dict of {key: {field: bytes}} + assert len(result) == 5 + assert result["doc:0"]["embedding"] == vec + + def test_pipeline_read_multiple_fields(self): + """Keys with multiple vector fields produce N*F pipelined HGETs.""" + mock_client = MagicMock() + mock_pipe = MagicMock() + mock_client.pipeline.return_value = mock_pipe + + dims = 4 + keys = ["doc:0", "doc:1"] + vec = _make_float32_vector(dims) + # 2 keys × 2 fields = 4 results + 
mock_pipe.execute.return_value = [vec, vec, vec, vec] + + datatype_changes = { + "embedding": {"source": "float32", "target": "float16", "dims": dims}, + "title_vec": {"source": "float32", "target": "float16", "dims": dims}, + } + + from redisvl.migration.quantize import pipeline_read_vectors + + result = pipeline_read_vectors(mock_client, keys, datatype_changes) + + assert mock_pipe.hget.call_count == 4 + mock_pipe.execute.assert_called_once() + assert "embedding" in result["doc:0"] + assert "title_vec" in result["doc:0"] + + def test_pipeline_read_handles_missing_keys(self): + """Missing keys (hget returns None) should be excluded from results.""" + mock_client = MagicMock() + mock_pipe = MagicMock() + mock_client.pipeline.return_value = mock_pipe + + keys = ["doc:0", "doc:1"] + vec = _make_float32_vector() + # doc:0 has data, doc:1 is missing + mock_pipe.execute.return_value = [vec, None] + + datatype_changes = { + "embedding": {"source": "float32", "target": "float16", "dims": 4} + } + + from redisvl.migration.quantize import pipeline_read_vectors + + result = pipeline_read_vectors(mock_client, keys, datatype_changes) + + assert "embedding" in result["doc:0"] + # doc:1 should have empty field dict or be excluded + assert result.get("doc:1", {}).get("embedding") is None + + +class TestPipelineWriteBatch: + """Test that converted vectors are written via pipeline.""" + + def test_pipeline_write_batches_hsets(self): + """Writing N keys should produce N pipelined HSET calls and 1 execute.""" + mock_client = MagicMock() + mock_pipe = MagicMock() + mock_client.pipeline.return_value = mock_pipe + + converted = { + "doc:0": {"embedding": b"\x00\x01\x02\x03"}, + "doc:1": {"embedding": b"\x04\x05\x06\x07"}, + } + + from redisvl.migration.quantize import pipeline_write_vectors + + pipeline_write_vectors(mock_client, converted) + + mock_client.pipeline.assert_called_once_with(transaction=False) + assert mock_pipe.hset.call_count == 2 + mock_pipe.execute.assert_called_once() 
+ + def test_pipeline_write_skips_empty(self): + """If no keys to write, don't create a pipeline at all.""" + mock_client = MagicMock() + + from redisvl.migration.quantize import pipeline_write_vectors + + pipeline_write_vectors(mock_client, {}) + + mock_client.pipeline.assert_not_called() + + +class TestConvertVectors: + """Test dtype conversion logic.""" + + def test_convert_float32_to_float16(self): + import numpy as np + + from redisvl.migration.quantize import convert_vectors + + dims = 4 + vec = _make_float32_vector(dims, seed=1.0) + originals = {"doc:0": {"embedding": vec}} + changes = { + "embedding": {"source": "float32", "target": "float16", "dims": dims} + } + + converted = convert_vectors(originals, changes) + + # Result should be float16 bytes (2 bytes per dim) + assert len(converted["doc:0"]["embedding"]) == dims * 2 + # Verify values round-trip through float16 + arr = np.frombuffer(converted["doc:0"]["embedding"], dtype=np.float16) + np.testing.assert_allclose(arr, [1.0, 2.0, 3.0, 4.0], rtol=1e-3) diff --git a/tests/unit/test_vector_backup.py b/tests/unit/test_vector_backup.py new file mode 100644 index 000000000..1aff0bd38 --- /dev/null +++ b/tests/unit/test_vector_backup.py @@ -0,0 +1,549 @@ +"""Tests for VectorBackup — the backup file for crash-safe quantization. + +TDD: these tests are written BEFORE the implementation. 
+""" + +import os +import struct +import tempfile + +import pytest + + +class TestVectorBackupCreate: + """Test creating a new backup file.""" + + def test_create_new_backup(self, tmp_path): + from redisvl.migration.backup import VectorBackup + + backup_path = str(tmp_path / "test_backup") + backup = VectorBackup.create( + path=backup_path, + index_name="myindex", + fields={ + "embedding": {"source": "float32", "target": "float16", "dims": 768} + }, + batch_size=500, + ) + assert backup.header.index_name == "myindex" + assert backup.header.phase == "dump" + assert backup.header.dump_completed_batches == 0 + assert backup.header.quantize_completed_batches == 0 + assert backup.header.batch_size == 500 + assert backup.header.fields == { + "embedding": {"source": "float32", "target": "float16", "dims": 768} + } + + def test_create_writes_header_to_disk(self, tmp_path): + from redisvl.migration.backup import VectorBackup + + backup_path = str(tmp_path / "test_backup") + VectorBackup.create( + path=backup_path, + index_name="myindex", + fields={ + "embedding": {"source": "float32", "target": "float16", "dims": 768} + }, + batch_size=500, + ) + # Header file should exist + assert os.path.exists(backup_path + ".header") + + def test_create_raises_if_already_exists(self, tmp_path): + from redisvl.migration.backup import VectorBackup + + backup_path = str(tmp_path / "test_backup") + VectorBackup.create( + path=backup_path, + index_name="myindex", + fields={ + "embedding": {"source": "float32", "target": "float16", "dims": 768} + }, + ) + with pytest.raises(FileExistsError): + VectorBackup.create( + path=backup_path, + index_name="myindex", + fields={ + "embedding": {"source": "float32", "target": "float16", "dims": 768} + }, + ) + + +class TestVectorBackupDump: + """Test writing batches during the dump phase.""" + + def _make_backup(self, tmp_path, batch_size=500): + from redisvl.migration.backup import VectorBackup + + backup_path = str(tmp_path / "test_backup") + return 
VectorBackup.create( + path=backup_path, + index_name="myindex", + fields={"embedding": {"source": "float32", "target": "float16", "dims": 4}}, + batch_size=batch_size, + ) + + def _fake_vector(self, dims=4): + """Create a fake float32 vector.""" + return struct.pack(f"<{dims}f", *[float(i) for i in range(dims)]) + + def test_write_batch(self, tmp_path): + backup = self._make_backup(tmp_path, batch_size=2) + keys = ["doc:0", "doc:1"] + originals = { + "doc:0": {"embedding": self._fake_vector()}, + "doc:1": {"embedding": self._fake_vector()}, + } + backup.write_batch(0, keys, originals) + assert backup.header.dump_completed_batches == 1 + + def test_write_multiple_batches(self, tmp_path): + backup = self._make_backup(tmp_path, batch_size=2) + vec = self._fake_vector() + for batch_idx in range(4): + keys = [f"doc:{batch_idx * 2}", f"doc:{batch_idx * 2 + 1}"] + originals = {k: {"embedding": vec} for k in keys} + backup.write_batch(batch_idx, keys, originals) + assert backup.header.dump_completed_batches == 4 + + def test_mark_dump_complete_transitions_to_ready(self, tmp_path): + backup = self._make_backup(tmp_path, batch_size=2) + vec = self._fake_vector() + backup.write_batch( + 0, ["doc:0", "doc:1"], {k: {"embedding": vec} for k in ["doc:0", "doc:1"]} + ) + backup.mark_dump_complete() + assert backup.header.phase == "ready" + + def test_iter_batches_returns_all_dumped_data(self, tmp_path): + backup = self._make_backup(tmp_path, batch_size=2) + vec = self._fake_vector() + + # Write 2 batches + for batch_idx in range(2): + keys = [f"doc:{batch_idx * 2}", f"doc:{batch_idx * 2 + 1}"] + originals = {k: {"embedding": vec} for k in keys} + backup.write_batch(batch_idx, keys, originals) + backup.mark_dump_complete() + + # Read them back + batches = list(backup.iter_batches()) + assert len(batches) == 2 + batch_keys, batch_data = batches[0] + assert batch_keys == ["doc:0", "doc:1"] + assert batch_data["doc:0"]["embedding"] == vec + assert batch_data["doc:1"]["embedding"] == 
vec + + def test_write_batch_wrong_phase_raises(self, tmp_path): + backup = self._make_backup(tmp_path, batch_size=2) + vec = self._fake_vector() + backup.write_batch( + 0, ["doc:0", "doc:1"], {k: {"embedding": vec} for k in ["doc:0", "doc:1"]} + ) + backup.mark_dump_complete() + # Now in "ready" phase — writing another batch should fail + with pytest.raises(ValueError, match="Cannot write batch.*phase"): + backup.write_batch(1, ["doc:2"], {"doc:2": {"embedding": vec}}) + + +class TestVectorBackupQuantize: + """Test quantize phase progress tracking.""" + + def _make_dumped_backup(self, tmp_path, num_keys=8, batch_size=2, dims=4): + """Create a backup that has completed the dump phase.""" + import struct + + from redisvl.migration.backup import VectorBackup + + backup_path = str(tmp_path / "test_backup") + backup = VectorBackup.create( + path=backup_path, + index_name="myindex", + fields={ + "embedding": {"source": "float32", "target": "float16", "dims": dims} + }, + batch_size=batch_size, + ) + vec = struct.pack(f"<{dims}f", *[float(i) for i in range(dims)]) + num_batches = (num_keys + batch_size - 1) // batch_size + for batch_idx in range(num_batches): + start = batch_idx * batch_size + end = min(start + batch_size, num_keys) + keys = [f"doc:{j}" for j in range(start, end)] + originals = {k: {"embedding": vec} for k in keys} + backup.write_batch(batch_idx, keys, originals) + backup.mark_dump_complete() + return backup + + def test_mark_batch_quantized(self, tmp_path): + backup = self._make_dumped_backup(tmp_path) + backup.start_quantize() # ready → active + assert backup.header.phase == "active" + backup.mark_batch_quantized(0) + assert backup.header.quantize_completed_batches == 1 + backup.mark_batch_quantized(1) + assert backup.header.quantize_completed_batches == 2 + + def test_mark_complete(self, tmp_path): + backup = self._make_dumped_backup(tmp_path, num_keys=4) + backup.start_quantize() + backup.mark_batch_quantized(0) + backup.mark_batch_quantized(1) + 
backup.mark_complete() + assert backup.header.phase == "completed" + + def test_iter_batches_skips_completed(self, tmp_path): + """After marking batches 0 and 1 as quantized, iter_remaining_batches + should only yield batches 2 and 3.""" + backup = self._make_dumped_backup(tmp_path) # 8 keys, batch_size=2 → 4 batches + backup.start_quantize() + backup.mark_batch_quantized(0) + backup.mark_batch_quantized(1) + + remaining = list(backup.iter_remaining_batches()) + assert len(remaining) == 2 + # Batch 2 starts at doc:4 + batch_keys, _ = remaining[0] + assert batch_keys[0] == "doc:4" + + +class TestVectorBackupResume: + """Test loading a backup file and resuming from where it left off.""" + + def _make_dumped_backup(self, tmp_path, num_keys=8, batch_size=2, dims=4): + import struct + + from redisvl.migration.backup import VectorBackup + + backup_path = str(tmp_path / "test_backup") + backup = VectorBackup.create( + path=backup_path, + index_name="myindex", + fields={ + "embedding": {"source": "float32", "target": "float16", "dims": dims} + }, + batch_size=batch_size, + ) + vec = struct.pack(f"<{dims}f", *[float(i) for i in range(dims)]) + num_batches = (num_keys + batch_size - 1) // batch_size + for batch_idx in range(num_batches): + start = batch_idx * batch_size + end = min(start + batch_size, num_keys) + keys = [f"doc:{j}" for j in range(start, end)] + originals = {k: {"embedding": vec} for k in keys} + backup.write_batch(batch_idx, keys, originals) + backup.mark_dump_complete() + return backup, backup_path + + def test_load_returns_none_if_no_file(self, tmp_path): + from redisvl.migration.backup import VectorBackup + + result = VectorBackup.load(str(tmp_path / "nonexistent")) + assert result is None + + def test_load_restores_header(self, tmp_path): + from redisvl.migration.backup import VectorBackup + + backup, path = self._make_dumped_backup(tmp_path) + loaded = VectorBackup.load(path) + assert loaded is not None + assert loaded.header.index_name == "myindex" + 
assert loaded.header.phase == "ready" + assert loaded.header.dump_completed_batches == 4 + + def test_load_and_resume_quantize(self, tmp_path): + """Simulate crash: dump complete, 2 batches quantized, then crash. + On reload, iter_remaining_batches should skip the 2 completed.""" + from redisvl.migration.backup import VectorBackup + + backup, path = self._make_dumped_backup(tmp_path) + backup.start_quantize() + backup.mark_batch_quantized(0) + backup.mark_batch_quantized(1) + # "crash" — drop the object, reload from disk + del backup + + loaded = VectorBackup.load(path) + assert loaded is not None + assert loaded.header.phase == "active" + assert loaded.header.quantize_completed_batches == 2 + + remaining = list(loaded.iter_remaining_batches()) + assert len(remaining) == 2 + batch_keys, _ = remaining[0] + assert batch_keys[0] == "doc:4" + + def test_load_and_resume_dump(self, tmp_path): + """Simulate crash during dump: 2 of 4 batches dumped. + On reload, should see phase=dump, dump_completed_batches=2.""" + import struct + + from redisvl.migration.backup import VectorBackup + + backup_path = str(tmp_path / "test_backup") + backup = VectorBackup.create( + path=backup_path, + index_name="myindex", + fields={"embedding": {"source": "float32", "target": "float16", "dims": 4}}, + batch_size=2, + ) + vec = struct.pack("<4f", 0.0, 1.0, 2.0, 3.0) + # Write only 2 of 4 expected batches + for batch_idx in range(2): + keys = [f"doc:{batch_idx * 2}", f"doc:{batch_idx * 2 + 1}"] + originals = {k: {"embedding": vec} for k in keys} + backup.write_batch(batch_idx, keys, originals) + # "crash" — don't call mark_dump_complete + del backup + + loaded = VectorBackup.load(backup_path) + assert loaded is not None + assert loaded.header.phase == "dump" + assert loaded.header.dump_completed_batches == 2 + # Can read back the 2 completed batches + batches = list(loaded.iter_batches()) + assert len(batches) == 2 + + +class TestVectorBackupRollback: + """Test reading originals for 
rollback.""" + + def test_rollback_reads_all_originals(self, tmp_path): + import struct + + from redisvl.migration.backup import VectorBackup + + backup_path = str(tmp_path / "test_backup") + backup = VectorBackup.create( + path=backup_path, + index_name="myindex", + fields={"embedding": {"source": "float32", "target": "float16", "dims": 4}}, + batch_size=2, + ) + vecs = {} + for i in range(4): + vec = struct.pack("<4f", *[float(i * 10 + j) for j in range(4)]) + vecs[f"doc:{i}"] = vec + + # Write 2 batches with distinct vectors + backup.write_batch( + 0, + ["doc:0", "doc:1"], + { + "doc:0": {"embedding": vecs["doc:0"]}, + "doc:1": {"embedding": vecs["doc:1"]}, + }, + ) + backup.write_batch( + 1, + ["doc:2", "doc:3"], + { + "doc:2": {"embedding": vecs["doc:2"]}, + "doc:3": {"embedding": vecs["doc:3"]}, + }, + ) + backup.mark_dump_complete() + + # Read all batches and verify originals are preserved + all_originals = {} + for batch_keys, batch_data in backup.iter_batches(): + all_originals.update(batch_data) + + assert len(all_originals) == 4 + for key in ["doc:0", "doc:1", "doc:2", "doc:3"]: + assert all_originals[key]["embedding"] == vecs[key] + + +class TestRollbackCLI: + """Tests for the rvl migrate rollback CLI command path derivation and restore logic.""" + + def _create_backup_with_data(self, tmp_path, name="test_idx"): + """Helper: create a backup with 2 batches of data.""" + from redisvl.migration.backup import VectorBackup + + bp = str(tmp_path / f"migration_backup_{name}") + vecs = { + "doc:0": struct.pack("<4f", 1.0, 2.0, 3.0, 4.0), + "doc:1": struct.pack("<4f", 5.0, 6.0, 7.0, 8.0), + } + backup = VectorBackup.create( + path=bp, + index_name=name, + fields={"embedding": {"source": "float32", "target": "float16", "dims": 4}}, + batch_size=1, + ) + backup.write_batch(0, ["doc:0"], {"doc:0": {"embedding": vecs["doc:0"]}}) + backup.write_batch(1, ["doc:1"], {"doc:1": {"embedding": vecs["doc:1"]}}) + backup.mark_dump_complete() + return bp, vecs + + def 
test_header_path_derivation_no_removesuffix(self, tmp_path): + """Verify path derivation works without str.removesuffix (Python 3.8 compat).""" + from pathlib import Path + + bp, _ = self._create_backup_with_data(tmp_path) + header_files = sorted(Path(tmp_path).glob("*.header")) + assert len(header_files) == 1 + # This is how the CLI derives backup paths — must not use removesuffix + derived = str(header_files[0].with_suffix("")) + assert derived == bp + + def test_rollback_restores_via_iter_batches(self, tmp_path): + """Verify rollback reads all batches and gets correct original vectors.""" + from redisvl.migration.backup import VectorBackup + + bp, vecs = self._create_backup_with_data(tmp_path) + backup = VectorBackup.load(bp) + assert backup is not None + + restored = {} + for batch_keys, originals in backup.iter_batches(): + for key in batch_keys: + if key in originals: + restored[key] = originals[key] + + assert len(restored) == 2 + assert restored["doc:0"]["embedding"] == vecs["doc:0"] + assert restored["doc:1"]["embedding"] == vecs["doc:1"] + + def test_rollback_nonexistent_dir(self): + """Verify error handling for missing backup directory.""" + import os + + assert not os.path.isdir("/nonexistent/backup/dir/xyz123") + + def test_rollback_empty_dir(self, tmp_path): + """Verify no header files found in empty directory.""" + from pathlib import Path + + header_files = sorted(Path(tmp_path).glob("*.header")) + assert len(header_files) == 0 + + def test_rollback_unloadable_backup_returns_none(self, tmp_path): + """VectorBackup.load returns None for corrupt/missing data.""" + from redisvl.migration.backup import VectorBackup + + # Create header but no data file + bp = str(tmp_path / "bad_backup") + result = VectorBackup.load(bp) + assert result is None + + def test_rollback_skips_incomplete_backup_phase(self, tmp_path): + """Backups in 'dump' phase should be skipped without --force.""" + from redisvl.migration.backup import VectorBackup + + bp = str(tmp_path / 
"migration_backup_partial") + backup = VectorBackup.create( + path=bp, + index_name="partial_idx", + fields={"embedding": {"source": "float32", "target": "float16", "dims": 4}}, + batch_size=1, + ) + # Write one batch but don't mark dump complete — phase stays "dump" + backup.write_batch(0, ["doc:0"], {"doc:0": {"embedding": b"\x00" * 16}}) + # Phase is "dump" — not in safe rollback phases + assert backup.header.phase == "dump" + safe_phases = frozenset({"ready", "active", "completed"}) + assert backup.header.phase not in safe_phases + + def test_rollback_index_filter(self, tmp_path): + """--index filter should match only backups for the specified index.""" + self._create_backup_with_data(tmp_path, name="idx_a") + self._create_backup_with_data(tmp_path, name="idx_b") + + from pathlib import Path + + from redisvl.migration.backup import VectorBackup + + header_files = sorted(Path(tmp_path).glob("*.header")) + assert len(header_files) == 2 + + # Filter for idx_a only + backup_paths = [str(h.with_suffix("")) for h in header_files] + filtered = [] + for bp in backup_paths: + backup = VectorBackup.load(bp) + if backup and backup.header.index_name == "idx_a": + filtered.append(bp) + assert len(filtered) == 1 + assert "idx_a" in filtered[0] + + def test_rollback_multi_index_requires_flag(self, tmp_path): + """Multiple distinct indexes should require --index or --yes.""" + self._create_backup_with_data(tmp_path, name="idx_a") + self._create_backup_with_data(tmp_path, name="idx_b") + + from pathlib import Path + + from redisvl.migration.backup import VectorBackup + + header_files = sorted(Path(tmp_path).glob("*.header")) + backup_paths = [str(h.with_suffix("")) for h in header_files] + backups = [] + for bp in backup_paths: + backup = VectorBackup.load(bp) + if backup: + backups.append(backup) + distinct = {b.header.index_name for b in backups} + assert len(distinct) > 1 # Multi-index — should require --index or --yes + + +class TestBackupCleanup: + """Tests for tightened 
backup file cleanup.""" + + def test_cleanup_only_removes_known_extensions(self, tmp_path): + """Cleanup should only remove .header and .data files.""" + # Create files with various extensions + (tmp_path / "migration_backup_test.header").touch() + (tmp_path / "migration_backup_test.data").touch() + (tmp_path / "migration_backup_test.log").touch() # should NOT be deleted + (tmp_path / "migration_backup_test_shard_0.header").touch() + (tmp_path / "migration_backup_test_shard_0.data").touch() + (tmp_path / "unrelated_file.txt").touch() # should NOT be deleted + + # Simulate the cleanup logic + base_prefix = "migration_backup_test" + known_suffixes = (".header", ".data") + deleted = [] + for entry in tmp_path.iterdir(): + if not entry.is_file(): + continue + name = entry.name + if not name.startswith(base_prefix): + continue + if not any(name.endswith(s) for s in known_suffixes): + continue + remainder = name[len(base_prefix) :] + if remainder and remainder[0] not in (".", "_"): + continue + deleted.append(name) + + assert "migration_backup_test.header" in deleted + assert "migration_backup_test.data" in deleted + assert "migration_backup_test_shard_0.header" in deleted + assert "migration_backup_test_shard_0.data" in deleted + assert "migration_backup_test.log" not in deleted + assert "unrelated_file.txt" not in deleted + + def test_cleanup_does_not_match_similar_prefix(self, tmp_path): + """migration_backup_foo should not match migration_backup_foobar.""" + (tmp_path / "migration_backup_foo.header").touch() + (tmp_path / "migration_backup_foobar.header").touch() + + base_prefix = "migration_backup_foo" + known_suffixes = (".header", ".data") + deleted = [] + for entry in tmp_path.iterdir(): + name = entry.name + if not name.startswith(base_prefix): + continue + if not any(name.endswith(s) for s in known_suffixes): + continue + remainder = name[len(base_prefix) :] + if remainder and remainder[0] not in (".", "_"): + continue + deleted.append(name) + + assert 
"migration_backup_foo.header" in deleted + assert "migration_backup_foobar.header" not in deleted