From 9edfaa2e6b5db4f8f010ce66449211dd9ed544aa Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 16 Dec 2024 12:42:04 +0100 Subject: [PATCH] improved text_search folder testing --- .../data/text_search/text_search.py | 2 +- .../semantic_kernel/data/text_search/utils.py | 28 +++--- python/tests/unit/data/test_text_search.py | 97 +++++++++++++++++-- .../data/test_vector_store_text_search.py | 82 ++++++++++++++++ 4 files changed, 183 insertions(+), 26 deletions(-) diff --git a/python/semantic_kernel/data/text_search/text_search.py b/python/semantic_kernel/data/text_search/text_search.py index ff7b6c416435..d40f7169786c 100644 --- a/python/semantic_kernel/data/text_search/text_search.py +++ b/python/semantic_kernel/data/text_search/text_search.py @@ -294,7 +294,7 @@ async def _map_results( return [self._default_map_to_string(result) async for result in results.results] @staticmethod - def _default_map_to_string(result: Any) -> str: + def _default_map_to_string(result: BaseModel | object) -> str: """Default mapping function for text search results.""" if isinstance(result, BaseModel): return result.model_dump_json() diff --git a/python/semantic_kernel/data/text_search/utils.py b/python/semantic_kernel/data/text_search/utils.py index eb60f87b3d82..44e432f1ec36 100644 --- a/python/semantic_kernel/data/text_search/utils.py +++ b/python/semantic_kernel/data/text_search/utils.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +import logging +from contextlib import suppress from typing import TYPE_CHECKING, Any, Protocol from pydantic import ValidationError @@ -8,6 +10,8 @@ from semantic_kernel.data.search_options import SearchOptions from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata +logger = logging.getLogger(__name__) + class OptionsUpdateFunctionType(Protocol): """Type definition for the options update function in Text Search.""" @@ -20,7 +24,7 @@ def __call__( **kwargs: Any, ) -> tuple[str, "SearchOptions"]: """Signature of the function.""" - ... + ... # pragma: no cover def create_options( @@ -44,30 +48,24 @@ def create_options( SearchOptions: The options. """ + new_options = options_class() if options: if not isinstance(options, options_class): + inputs = None try: - # Validate the options in one go - new_options = options_class.model_validate( - options.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True), - ) - except ValidationError: - # if that fails, go one by one - new_options = options_class() - for key, value in options.model_dump( - exclude_none=True, exclude_defaults=True, exclude_unset=True - ).items(): - setattr(new_options, key, value) + inputs = options.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True) + except Exception: + logger.warning("Options are not valid. Creating new options.") + if inputs: + new_options = options_class.model_validate(inputs) else: new_options = options for key, value in kwargs.items(): if key in new_options.model_fields: setattr(new_options, key, value) else: - try: + with suppress(ValidationError): new_options = options_class(**kwargs) - except ValidationError: - new_options = options_class() return new_options diff --git a/python/tests/unit/data/test_text_search.py b/python/tests/unit/data/test_text_search.py index 5b03b67e52e9..74a10909317b 100644 --- a/python/tests/unit/data/test_text_search.py +++ b/python/tests/unit/data/test_text_search.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +from collections.abc import AsyncGenerator from typing import Any from unittest.mock import patch @@ -21,6 +22,7 @@ ) from semantic_kernel.exceptions import TextSearchException from semantic_kernel.functions import KernelArguments, KernelParameterMetadata +from semantic_kernel.utils.list_handler import desync_list def test_text_search(): @@ -33,7 +35,7 @@ class TestSearch(TextSearch): async def search(self, **kwargs) -> KernelSearchResults[Any]: """Test search function.""" - async def generator() -> str: + async def generator() -> AsyncGenerator[str, None]: yield "test" return KernelSearchResults(results=generator(), metadata=kwargs) @@ -43,7 +45,7 @@ async def get_text_search_results( ) -> KernelSearchResults[TextSearchResult]: """Test get text search result function.""" - async def generator() -> TextSearchResult: + async def generator() -> AsyncGenerator[TextSearchResult, None]: yield TextSearchResult(value="test") return KernelSearchResults(results=generator(), metadata=kwargs) @@ -53,7 +55,7 @@ async def get_search_results( ) -> KernelSearchResults[Any]: """Test get search result function.""" - async def generator() -> str: + async def generator() -> AsyncGenerator[str, None]: yield "test" return KernelSearchResults(results=generator(), metadata=kwargs) @@ -190,12 +192,18 @@ async def test_create_kernel_function_inner_update_options(kernel: Kernel): called = False args = {} - def update_options(**kwargs: Any) -> tuple[str, SearchOptions]: - kwargs["options"].filter.equal_to("address/city", kwargs.get("city")) + def update_options( + query: str, + options: "SearchOptions", + parameters: list["KernelParameterMetadata"] | None = None, + **kwargs: Any, + ) -> tuple[str, SearchOptions]: + options.filter.equal_to("address/city", kwargs.get("city", "")) nonlocal called, args called = True - args = kwargs - return kwargs["query"], kwargs["options"] + args = {"query": query, "options": options, "parameters": parameters} + args.update(kwargs) + return query, options kernel_function = test_search._create_kernel_function( search_function="search", @@ -225,14 +233,29 @@ def update_options(**kwargs: Any) -> tuple[str, SearchOptions]: assert "parameters" in args -def test_default_map_to_string(): +async def test_default_map_to_string(): test_search = TestSearch() - assert test_search._default_map_to_string("test") == "test" + assert (await test_search._map_results(results=KernelSearchResults(results=desync_list(["test"])))) == ["test"] class TestClass(BaseModel): test: str - assert test_search._default_map_to_string(TestClass(test="test")) == '{"test":"test"}' + assert ( + await test_search._map_results(results=KernelSearchResults(results=desync_list([TestClass(test="test")]))) + ) == ['{"test":"test"}'] + + +async def test_custom_map_to_string(): + test_search = TestSearch() + + class TestClass(BaseModel): + test: str + + assert ( + await test_search._map_results( + results=KernelSearchResults(results=desync_list([TestClass(test="test")])), string_mapper=lambda x: x.test + ) + ) == ["test"] def test_create_options(): @@ -253,6 +276,27 @@ def test_create_options_none(): assert new_options.top == 1 +def test_create_options_vector_to_text(): + options = VectorSearchOptions(top=2, skip=1, include_vectors=True) + options_class = TextSearchOptions + new_options = create_options(options_class, options, top=1) + assert new_options is not None + assert isinstance(new_options, options_class) + assert new_options.top == 1 + assert getattr(new_options, "include_vectors", None) is None + + +def test_create_options_from_dict(): + options = {"skip": 1} + options_class = TextSearchOptions + new_options = create_options(options_class, options, top=1) # type: ignore + assert new_options is not None + assert isinstance(new_options, options_class) + assert new_options.top == 1 + # if a non SearchOptions object is passed in, it should be ignored + assert new_options.skip == 0 + + def test_default_options_update_function(): options = SearchOptions() params = [ @@ -267,3 +311,36 @@ def test_default_options_update_function(): assert options.filter.filters[0].value == "test" assert options.filter.filters[1].field_name == "test2" assert options.filter.filters[1].value == "test2" + + +def test_public_create_functions_search(): + test_search = TestSearch() + function = test_search.create_search() + assert function is not None + assert function.name == "search" + assert ( + function.description == "Perform a search for content related to the specified query and return string results" + ) + assert len(function.parameters) == 3 + + +def test_public_create_functions_get_text_search_results(): + test_search = TestSearch() + function = test_search.create_get_text_search_results() + assert function is not None + assert function.name == "search" + assert ( + function.description == "Perform a search for content related to the specified query and return string results" + ) + assert len(function.parameters) == 3 + + +def test_public_create_functions_get_search_results(): + test_search = TestSearch() + function = test_search.create_get_search_results() + assert function is not None + assert function.name == "search" + assert ( + function.description == "Perform a search for content related to the specified query and return string results" + ) + assert len(function.parameters) == 3 diff --git a/python/tests/unit/data/test_vector_store_text_search.py b/python/tests/unit/data/test_vector_store_text_search.py index 0f485349d098..e03a104492eb 100644 --- a/python/tests/unit/data/test_vector_store_text_search.py +++ b/python/tests/unit/data/test_vector_store_text_search.py @@ -2,11 +2,16 @@ from unittest.mock import patch +from pydantic import BaseModel from pytest import fixture, raises from semantic_kernel.connectors.ai.open_ai import AzureTextEmbedding from semantic_kernel.data import VectorStoreTextSearch +from semantic_kernel.data.text_search.text_search_result import TextSearchResult +from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions +from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult from semantic_kernel.exceptions import VectorStoreTextSearchValidationError +from semantic_kernel.utils.list_handler import desync_list @fixture @@ -28,6 +33,7 @@ async def test_from_vectorizable_text_search(vector_collection): assert search is not None assert text_search_result is not None assert search_result is not None + assert vsts.options_class is VectorSearchOptions async def test_from_vector_text_search(vector_collection): @@ -67,3 +73,79 @@ def test_validation_no_embedder_for_vectorized_search(vector_collection): def test_validation_no_collections(): with raises(VectorStoreTextSearchValidationError): VectorStoreTextSearch() + + +async def test_get_results_as_string(vector_collection): + test_search = VectorStoreTextSearch.from_vector_text_search(vector_text_search=vector_collection) + results = [ + res + async for res in test_search._get_results_as_strings(results=desync_list([VectorSearchResult(record="test")])) + ] + assert results == ["test"] + + class TestClass(BaseModel): + test: str + + results = [ + res + async for res in test_search._get_results_as_strings( + results=desync_list([VectorSearchResult(record=TestClass(test="test"))]) + ) + ] + + assert results == ['{"test":"test"}'] + + test_search = VectorStoreTextSearch.from_vector_text_search( + vector_text_search=vector_collection, string_mapper=lambda x: x.test + ) + + class TestClass(BaseModel): + test: str + + results = [ + res + async for res in test_search._get_results_as_strings( + results=desync_list([VectorSearchResult(record=TestClass(test="test"))]) + ) + ] + + assert results == ["test"] + + +async def test_get_results_as_test_search_result(vector_collection): + test_search = VectorStoreTextSearch.from_vector_text_search(vector_text_search=vector_collection) + results = [ + res + async for res in test_search._get_results_as_text_search_result( + results=desync_list([VectorSearchResult(record="test")]) + ) + ] + assert results == [TextSearchResult(value="test")] + + class TestClass(BaseModel): + test: str + + results = [ + res + async for res in test_search._get_results_as_text_search_result( + results=desync_list([VectorSearchResult(record=TestClass(test="test"))]) + ) + ] + + assert results == [TextSearchResult(value='{"test":"test"}')] + + test_search = VectorStoreTextSearch.from_vector_text_search( + vector_text_search=vector_collection, text_search_results_mapper=lambda x: TextSearchResult(value=x.test) + ) + + class TestClass(BaseModel): + test: str + + results = [ + res + async for res in test_search._get_results_as_text_search_result( + results=desync_list([VectorSearchResult(record=TestClass(test="test"))]) + ) + ] + + assert results == [TextSearchResult(value="test")]