Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ qdrant = [
redis = [
"redis[hiredis] >= 6,< 8",
"types-redis ~= 4.6.0.20240425",
"redisvl ~= 0.4"
"redisvl >= 0.5"
]
realtime = [
"websockets >= 13, < 16",
Expand Down
12 changes: 8 additions & 4 deletions python/semantic_kernel/connectors/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from redisvl.query.filter import FilterExpression, Num, Tag, Text
from redisvl.query.query import BaseQuery, VectorQuery
from redisvl.redis.utils import array_to_buffer, buffer_to_array, convert_bytes
from redisvl.schema import StorageType
from redisvl.schema import IndexSchema as _RedisVLIndexSchema, StorageType

from semantic_kernel.connectors.ai.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.data.vector import (
Expand Down Expand Up @@ -321,7 +321,10 @@ async def _inner_search(
results = await self.redis_database.ft(self.collection_name).search( # type: ignore
query=query.query, query_params=query.params
)
processed = process_results(results, query, STORAGE_TYPE_MAP[self.collection_type])
schema = _RedisVLIndexSchema.from_dict(
{"index": {"name": self.collection_name, "storage_type": STORAGE_TYPE_MAP[self.collection_type].value}}
)
processed = process_results(results, query, schema)
return KernelSearchResults(
results=self._get_vector_search_results_from_results(desync_list(processed)),
total_count=results.total,
Expand Down Expand Up @@ -616,8 +619,9 @@ def _deserialize_store_models_to_dicts(
case FieldTypes.KEY:
rec[field.name] = self._unget_redis_key(rec[field.name])
case "vector":
dtype = DATATYPE_MAP_VECTOR[field.type_ or "default"]
rec[field.name] = buffer_to_array(rec[field.name], dtype)
if field.name in rec:
dtype = DATATYPE_MAP_VECTOR[field.type_ or "default"]
rec[field.name] = buffer_to_array(rec[field.name], dtype)
results.append(rec)
return results

Expand Down
43 changes: 43 additions & 0 deletions python/tests/unit/connectors/memory/test_redis_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,46 @@ async def test_create_index_manual(collection_hash, mock_ensure_collection_exist
async def test_create_index_fail(collection_hash, mock_ensure_collection_exists):
with raises(VectorStoreOperationException, match="Invalid index type supplied."):
await collection_hash.ensure_collection_exists(index_definition="index_definition", fields="fields")


def test_deserialize_hashset_skips_missing_vector_field(collection_hash):
    """With include_vectors=False the search result has no vector key;
    deserialization must simply omit it rather than raise KeyError."""
    raw = [{"id": "id1", "content": "hello"}]
    deserialized = collection_hash._deserialize_store_models_to_dicts(raw)
    assert len(deserialized) == 1
    record = deserialized[0]
    assert record["id"] == "id1"
    assert record["content"] == "hello"
    assert "vector" not in record


@mark.parametrize("type_", ["hashset", "json"])
async def test_inner_search_passes_index_schema_to_process_results(
    collection_hash, collection_json, type_
):
    """_inner_search must hand process_results an IndexSchema whose storage
    type matches the collection flavor (redisvl >= 0.5 signature)."""
    from unittest.mock import MagicMock

    from redisvl.schema import IndexSchema, StorageType

    from semantic_kernel.data.vector import SearchType, VectorSearchOptions

    # Select the collection under test and the storage type we expect to see.
    if type_ == "hashset":
        target, storage = collection_hash, StorageType.HASH
    else:
        target, storage = collection_json, StorageType.JSON

    fake_results = MagicMock()
    fake_results.docs = []
    fake_results.total = 0

    with patch("redis.commands.search.AsyncSearch.search", new=AsyncMock(return_value=fake_results)):
        with patch(
            "semantic_kernel.connectors.redis.process_results", return_value=[]
        ) as process_mock:
            await target._inner_search(
                search_type=SearchType.VECTOR,
                options=VectorSearchOptions(vector_property_name="vector", top=3),
                vector=[1.0, 2.0, 3.0, 4.0, 5.0],
            )

    process_mock.assert_called_once()
    _, _, passed_schema = process_mock.call_args.args
    assert isinstance(passed_schema, IndexSchema), "process_results must receive an IndexSchema, not a StorageType"
    assert passed_schema.index.storage_type == storage
Loading