Skip to content

Commit

Permalink
Python: Qdrant - fix in filter and 100% test coverage (#9982)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->
There was a small error in the filter creation logic, and improved test
coverage for Qdrant.

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
eavanvalkenburg authored Dec 16, 2024
1 parent 5874188 commit 62a50f3
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ async def _inner_search(
else:
query_vector = vector
if query_vector is None:
raise VectorSearchExecutionException("Search requires either a vector.")
raise VectorSearchExecutionException("Search requires a vector.")
results = await self.qdrant_client.search(
collection_name=self.collection_name,
query_vector=query_vector,
Expand All @@ -214,7 +214,7 @@ def _get_score_from_result(self, result: ScoredPoint) -> float:
def _create_filter(self, options: VectorSearchOptions) -> Filter:
return Filter(
must=[
FieldCondition(key=filter.field_name, match=MatchAny(any=filter.value))
FieldCondition(key=filter.field_name, match=MatchAny(any=[filter.value]))
for filter in options.filter.filters
]
)
Expand Down
84 changes: 71 additions & 13 deletions python/tests/unit/connectors/memory/qdrant/test_qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@

from pytest import fixture, mark, raises
from qdrant_client.async_qdrant_client import AsyncQdrantClient
from qdrant_client.models import Datatype, Distance, VectorParams
from qdrant_client.models import Datatype, Distance, FieldCondition, Filter, MatchAny, VectorParams

from semantic_kernel.connectors.memory.qdrant.qdrant_collection import QdrantCollection
from semantic_kernel.connectors.memory.qdrant.qdrant_store import QdrantStore
from semantic_kernel.data.record_definition.vector_store_record_fields import VectorStoreRecordVectorField
from semantic_kernel.data.vector_search.vector_search_filter import VectorSearchFilter
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
from semantic_kernel.exceptions.memory_connector_exceptions import (
MemoryConnectorException,
MemoryConnectorInitializationError,
VectorStoreModelValidationError,
)
from semantic_kernel.exceptions.search_exceptions import VectorSearchExecutionException

BASE_PATH = "qdrant_client.async_qdrant_client.AsyncQdrantClient"

Expand Down Expand Up @@ -119,9 +121,10 @@ def mock_search():
yield mock_search


def test_vector_store_defaults(vector_store):
assert vector_store.qdrant_client is not None
assert vector_store.qdrant_client._client.rest_uri == "http://localhost:6333"
async def test_vector_store_defaults(vector_store):
async with vector_store:
assert vector_store.qdrant_client is not None
assert vector_store.qdrant_client._client.rest_uri == "http://localhost:6333"


def test_vector_store_with_client():
Expand Down Expand Up @@ -162,18 +165,18 @@ def test_get_collection(vector_store, data_model_definition, qdrant_unit_test_en
assert vector_store.vector_record_collections["test"] == collection


def test_collection_init(data_model_definition, qdrant_unit_test_env):
collection = QdrantCollection(
async def test_collection_init(data_model_definition, qdrant_unit_test_env):
async with QdrantCollection(
data_model_type=dict,
collection_name="test",
data_model_definition=data_model_definition,
env_file_path="test.env",
)
assert collection.collection_name == "test"
assert collection.qdrant_client is not None
assert collection.data_model_type is dict
assert collection.data_model_definition == data_model_definition
assert collection.named_vectors
) as collection:
assert collection.collection_name == "test"
assert collection.qdrant_client is not None
assert collection.data_model_type is dict
assert collection.data_model_definition == data_model_definition
assert collection.named_vectors


def test_collection_init_fail(data_model_definition):
Expand Down Expand Up @@ -275,8 +278,63 @@ async def test_create_index_fail(collection_to_use, request):
await collection.create_collection()


async def test_search(collection):
async def test_search(collection, mock_search):
results = await collection._inner_search(vector=[1.0, 2.0, 3.0], options=VectorSearchOptions(include_vectors=False))
async for result in results.results:
assert result.record["id"] == "id1"
break

assert mock_search.call_count == 1
mock_search.assert_called_with(
collection_name="test",
query_vector=[1.0, 2.0, 3.0],
query_filter=Filter(must=[]),
with_vectors=False,
limit=3,
offset=0,
)


async def test_search_named_vectors(collection, mock_search):
collection.named_vectors = True
results = await collection._inner_search(
vector=[1.0, 2.0, 3.0], options=VectorSearchOptions(vector_field_name="vector", include_vectors=False)
)
async for result in results.results:
assert result.record["id"] == "id1"
break

assert mock_search.call_count == 1
mock_search.assert_called_with(
collection_name="test",
query_vector=("vector", [1.0, 2.0, 3.0]),
query_filter=Filter(must=[]),
with_vectors=False,
limit=3,
offset=0,
)


async def test_search_filter(collection, mock_search):
results = await collection._inner_search(
vector=[1.0, 2.0, 3.0],
options=VectorSearchOptions(include_vectors=False, filter=VectorSearchFilter.equal_to("id", "id1")),
)
async for result in results.results:
assert result.record["id"] == "id1"
break

assert mock_search.call_count == 1
mock_search.assert_called_with(
collection_name="test",
query_vector=[1.0, 2.0, 3.0],
query_filter=Filter(must=[FieldCondition(key="id", match=MatchAny(any=["id1"]))]),
with_vectors=False,
limit=3,
offset=0,
)


async def test_search_fail(collection):
with raises(VectorSearchExecutionException, match="Search requires a vector."):
await collection._inner_search(options=VectorSearchOptions(include_vectors=False))

0 comments on commit 62a50f3

Please sign in to comment.