Skip to content

Commit

Permalink
improved text_search folder testing
Browse files Browse the repository at this point in the history
  • Loading branch information
eavanvalkenburg committed Dec 16, 2024
1 parent 7c25ac4 commit 9edfaa2
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 26 deletions.
2 changes: 1 addition & 1 deletion python/semantic_kernel/data/text_search/text_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ async def _map_results(
return [self._default_map_to_string(result) async for result in results.results]

@staticmethod
def _default_map_to_string(result: Any) -> str:
def _default_map_to_string(result: BaseModel | object) -> str:
"""Default mapping function for text search results."""
if isinstance(result, BaseModel):
return result.model_dump_json()
Expand Down
28 changes: 13 additions & 15 deletions python/semantic_kernel/data/text_search/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

import logging
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Protocol

from pydantic import ValidationError
Expand All @@ -8,6 +10,8 @@
from semantic_kernel.data.search_options import SearchOptions
from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata

logger = logging.getLogger(__name__)


class OptionsUpdateFunctionType(Protocol):
"""Type definition for the options update function in Text Search."""
Expand All @@ -20,7 +24,7 @@ def __call__(
**kwargs: Any,
) -> tuple[str, "SearchOptions"]:
"""Signature of the function."""
...
... # pragma: no cover


def create_options(
Expand All @@ -44,30 +48,24 @@ def create_options(
SearchOptions: The options.
"""
new_options = options_class()
if options:
if not isinstance(options, options_class):
inputs = None
try:
# Validate the options in one go
new_options = options_class.model_validate(
options.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True),
)
except ValidationError:
# if that fails, go one by one
new_options = options_class()
for key, value in options.model_dump(
exclude_none=True, exclude_defaults=True, exclude_unset=True
).items():
setattr(new_options, key, value)
inputs = options.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True)
except Exception:
logger.warning("Options are not valid. Creating new options.")
if inputs:
new_options = options_class.model_validate(inputs)
else:
new_options = options
for key, value in kwargs.items():
if key in new_options.model_fields:
setattr(new_options, key, value)
else:
try:
with suppress(ValidationError):
new_options = options_class(**kwargs)
except ValidationError:
new_options = options_class()
return new_options


Expand Down
97 changes: 87 additions & 10 deletions python/tests/unit/data/test_text_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import patch

Expand All @@ -21,6 +22,7 @@
)
from semantic_kernel.exceptions import TextSearchException
from semantic_kernel.functions import KernelArguments, KernelParameterMetadata
from semantic_kernel.utils.list_handler import desync_list


def test_text_search():
Expand All @@ -33,7 +35,7 @@ class TestSearch(TextSearch):
async def search(self, **kwargs) -> KernelSearchResults[Any]:
"""Test search function."""

async def generator() -> str:
async def generator() -> AsyncGenerator[str, None]:
yield "test"

return KernelSearchResults(results=generator(), metadata=kwargs)
Expand All @@ -43,7 +45,7 @@ async def get_text_search_results(
) -> KernelSearchResults[TextSearchResult]:
"""Test get text search result function."""

async def generator() -> TextSearchResult:
async def generator() -> AsyncGenerator[TextSearchResult, None]:
yield TextSearchResult(value="test")

return KernelSearchResults(results=generator(), metadata=kwargs)
Expand All @@ -53,7 +55,7 @@ async def get_search_results(
) -> KernelSearchResults[Any]:
"""Test get search result function."""

async def generator() -> str:
async def generator() -> AsyncGenerator[str, None]:
yield "test"

return KernelSearchResults(results=generator(), metadata=kwargs)
Expand Down Expand Up @@ -190,12 +192,18 @@ async def test_create_kernel_function_inner_update_options(kernel: Kernel):
called = False
args = {}

def update_options(**kwargs: Any) -> tuple[str, SearchOptions]:
kwargs["options"].filter.equal_to("address/city", kwargs.get("city"))
def update_options(
query: str,
options: "SearchOptions",
parameters: list["KernelParameterMetadata"] | None = None,
**kwargs: Any,
) -> tuple[str, SearchOptions]:
options.filter.equal_to("address/city", kwargs.get("city", ""))
nonlocal called, args
called = True
args = kwargs
return kwargs["query"], kwargs["options"]
args = {"query": query, "options": options, "parameters": parameters}
args.update(kwargs)
return query, options

kernel_function = test_search._create_kernel_function(
search_function="search",
Expand Down Expand Up @@ -225,14 +233,29 @@ def update_options(**kwargs: Any) -> tuple[str, SearchOptions]:
assert "parameters" in args


def test_default_map_to_string():
async def test_default_map_to_string():
test_search = TestSearch()
assert test_search._default_map_to_string("test") == "test"
assert (await test_search._map_results(results=KernelSearchResults(results=desync_list(["test"])))) == ["test"]

class TestClass(BaseModel):
test: str

assert test_search._default_map_to_string(TestClass(test="test")) == '{"test":"test"}'
assert (
await test_search._map_results(results=KernelSearchResults(results=desync_list([TestClass(test="test")])))
) == ['{"test":"test"}']


async def test_custom_map_to_string():
test_search = TestSearch()

class TestClass(BaseModel):
test: str

assert (
await test_search._map_results(
results=KernelSearchResults(results=desync_list([TestClass(test="test")])), string_mapper=lambda x: x.test
)
) == ["test"]


def test_create_options():
Expand All @@ -253,6 +276,27 @@ def test_create_options_none():
assert new_options.top == 1


def test_create_options_vector_to_text():
options = VectorSearchOptions(top=2, skip=1, include_vectors=True)
options_class = TextSearchOptions
new_options = create_options(options_class, options, top=1)
assert new_options is not None
assert isinstance(new_options, options_class)
assert new_options.top == 1
assert getattr(new_options, "include_vectors", None) is None


def test_create_options_from_dict():
options = {"skip": 1}
options_class = TextSearchOptions
new_options = create_options(options_class, options, top=1) # type: ignore
assert new_options is not None
assert isinstance(new_options, options_class)
assert new_options.top == 1
# if a non SearchOptions object is passed in, it should be ignored
assert new_options.skip == 0


def test_default_options_update_function():
options = SearchOptions()
params = [
Expand All @@ -267,3 +311,36 @@ def test_default_options_update_function():
assert options.filter.filters[0].value == "test"
assert options.filter.filters[1].field_name == "test2"
assert options.filter.filters[1].value == "test2"


def test_public_create_functions_search():
test_search = TestSearch()
function = test_search.create_search()
assert function is not None
assert function.name == "search"
assert (
function.description == "Perform a search for content related to the specified query and return string results"
)
assert len(function.parameters) == 3


def test_public_create_functions_get_text_search_results():
test_search = TestSearch()
function = test_search.create_get_text_search_results()
assert function is not None
assert function.name == "search"
assert (
function.description == "Perform a search for content related to the specified query and return string results"
)
assert len(function.parameters) == 3


def test_public_create_functions_get_search_results():
test_search = TestSearch()
function = test_search.create_get_search_results()
assert function is not None
assert function.name == "search"
assert (
function.description == "Perform a search for content related to the specified query and return string results"
)
assert len(function.parameters) == 3
82 changes: 82 additions & 0 deletions python/tests/unit/data/test_vector_store_text_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@

from unittest.mock import patch

from pydantic import BaseModel
from pytest import fixture, raises

from semantic_kernel.connectors.ai.open_ai import AzureTextEmbedding
from semantic_kernel.data import VectorStoreTextSearch
from semantic_kernel.data.text_search.text_search_result import TextSearchResult
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult
from semantic_kernel.exceptions import VectorStoreTextSearchValidationError
from semantic_kernel.utils.list_handler import desync_list


@fixture
Expand All @@ -28,6 +33,7 @@ async def test_from_vectorizable_text_search(vector_collection):
assert search is not None
assert text_search_result is not None
assert search_result is not None
assert vsts.options_class is VectorSearchOptions


async def test_from_vector_text_search(vector_collection):
Expand Down Expand Up @@ -67,3 +73,79 @@ def test_validation_no_embedder_for_vectorized_search(vector_collection):
def test_validation_no_collections():
with raises(VectorStoreTextSearchValidationError):
VectorStoreTextSearch()


async def test_get_results_as_string(vector_collection):
test_search = VectorStoreTextSearch.from_vector_text_search(vector_text_search=vector_collection)
results = [
res
async for res in test_search._get_results_as_strings(results=desync_list([VectorSearchResult(record="test")]))
]
assert results == ["test"]

class TestClass(BaseModel):
test: str

results = [
res
async for res in test_search._get_results_as_strings(
results=desync_list([VectorSearchResult(record=TestClass(test="test"))])
)
]

assert results == ['{"test":"test"}']

test_search = VectorStoreTextSearch.from_vector_text_search(
vector_text_search=vector_collection, string_mapper=lambda x: x.test
)

class TestClass(BaseModel):
test: str

results = [
res
async for res in test_search._get_results_as_strings(
results=desync_list([VectorSearchResult(record=TestClass(test="test"))])
)
]

assert results == ["test"]


async def test_get_results_as_test_search_result(vector_collection):
test_search = VectorStoreTextSearch.from_vector_text_search(vector_text_search=vector_collection)
results = [
res
async for res in test_search._get_results_as_text_search_result(
results=desync_list([VectorSearchResult(record="test")])
)
]
assert results == [TextSearchResult(value="test")]

class TestClass(BaseModel):
test: str

results = [
res
async for res in test_search._get_results_as_text_search_result(
results=desync_list([VectorSearchResult(record=TestClass(test="test"))])
)
]

assert results == [TextSearchResult(value='{"test":"test"}')]

test_search = VectorStoreTextSearch.from_vector_text_search(
vector_text_search=vector_collection, text_search_results_mapper=lambda x: TextSearchResult(value=x.test)
)

class TestClass(BaseModel):
test: str

results = [
res
async for res in test_search._get_results_as_text_search_result(
results=desync_list([VectorSearchResult(record=TestClass(test="test"))])
)
]

assert results == [TextSearchResult(value="test")]

0 comments on commit 9edfaa2

Please sign in to comment.