From 90d99f787e1ffd91ec4fe6d03062068e474d4044 Mon Sep 17 00:00:00 2001 From: chasingegg Date: Thu, 9 Jan 2025 17:25:33 +0800 Subject: [PATCH] enhance: support recalls for milvus_client Signed-off-by: chasingegg --- pymilvus/client/abstract.py | 8 +++++++- pymilvus/client/types.py | 18 +++++++++++++++--- pymilvus/milvus_client/async_milvus_client.py | 2 +- pymilvus/milvus_client/milvus_client.py | 2 +- pyproject.toml | 2 +- 5 files changed, 25 insertions(+), 7 deletions(-) diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index b1eafdc1d..1349e82d0 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -541,7 +541,13 @@ def __iter__(self) -> SequenceIterator: def __str__(self) -> str: """Only print at most 10 query results""" reminder = f" ... and {len(self) - 10} results remaining" if len(self) > 10 else "" - recall_msg = f", recalls: {list(map(str, self.recalls))}" if len(self.recalls) > 0 else "" + recall_msg = ( + f", recalls: {list(map(str, self.recalls[:10]))}" if len(self.recalls) > 0 else "" + ) + ( + f" ... and {len(self.recalls) - 10} recall results remaining" + if len(self.recalls) > 10 + else "" + ) cost_msg = f", cost: {self.cost}" if self.cost else "" return f"data: {list(map(str, self[:10]))}{reminder}{recall_msg}{cost_msg}" diff --git a/pymilvus/client/types.py b/pymilvus/client/types.py index 3654a6e8b..e8b696c32 100644 --- a/pymilvus/client/types.py +++ b/pymilvus/client/types.py @@ -995,15 +995,27 @@ class ExtraList(list): ExtraList([1, 2, 3], extra={"total": 3}) """ - def __init__(self, *args, extra: Optional[Dict] = None, **kwargs) -> None: + def __init__( + self, *args, extra: Optional[Dict] = None, recalls: Optional[List[float]] = None, **kwargs + ) -> None: super().__init__(*args, **kwargs) self.extra = OmitZeroDict(extra or {}) + self.recalls = recalls def __str__(self) -> str: """Only print at most 10 query results""" + recall_msg = ( + f", recalls: {list(map(str, self.recalls[:10]))}" + if self.recalls is not None and len(self.recalls) > 0 + else "" + ) + ( + f" ... and {len(self.recalls) - 10} recall results remaining" + if self.recalls is not None and len(self.recalls) > 10 + else "" + ) if self.extra and self.extra.omit_zero_len() != 0: - return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}, extra_info: {self.extra}" - return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}" + return f"data: {list(map(str, self[:10]))}{' ...' if len(self) > 10 else ''}{recall_msg}, extra_info: {self.extra}" + return f"data: {list(map(str, self[:10]))}{' ...' if len(self) > 10 else ''}{recall_msg}" __repr__ = __str__ diff --git a/pymilvus/milvus_client/async_milvus_client.py b/pymilvus/milvus_client/async_milvus_client.py index 2544caeb0..f95ade792 100644 --- a/pymilvus/milvus_client/async_milvus_client.py +++ b/pymilvus/milvus_client/async_milvus_client.py @@ -400,7 +400,7 @@ async def search( query_result.append(hit.to_dict()) ret.append(query_result) - return ExtraList(ret, extra=construct_cost_extra(res.cost)) + return ExtraList(ret, extra=construct_cost_extra(res.cost), recalls=res.recalls) async def query( self, diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py index 4a6b3f2f8..833f99670 100644 --- a/pymilvus/milvus_client/milvus_client.py +++ b/pymilvus/milvus_client/milvus_client.py @@ -419,7 +419,7 @@ def search( query_result.append(hit.to_dict()) ret.append(query_result) - return ExtraList(ret, extra=construct_cost_extra(res.cost)) + return ExtraList(ret, extra=construct_cost_extra(res.cost), recalls=res.recalls) def query( self, diff --git a/pyproject.toml b/pyproject.toml index 3164ae034..1d3ea9877 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -199,4 +199,4 @@ builtins-ignorelist = [ "dict", # TODO "filter", ] - +builtins-allowed-modules = ["types"]