Skip to content

Commit

Permalink
enhance: support recalls for milvus_client
Browse files Browse the repository at this point in the history
Signed-off-by: chasingegg <chao.gao@zilliz.com>
  • Loading branch information
chasingegg committed Jan 9, 2025
1 parent b37111b commit 595998b
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
14 changes: 11 additions & 3 deletions pymilvus/client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,15 +995,23 @@ class ExtraList(list):
ExtraList([1, 2, 3], extra={"total": 3})
"""

def __init__(self, *args, extra: Optional[Dict] = None, **kwargs) -> None:
def __init__(
self, *args, extra: Optional[Dict] = None, recalls: Optional[List[float]] = None, **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.extra = OmitZeroDict(extra or {})
self.recalls = recalls

def __str__(self) -> str:
"""Only print at most 10 query results"""
recall_msg = (
f", recalls: {list(map(str, self.recalls))}"
if self.recalls is not None and len(self.recalls) > 0
else ""
)
if self.extra and self.extra.omit_zero_len() != 0:
return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}, extra_info: {self.extra}"
return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}"
return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}{recall_msg}, extra_info: {self.extra}"
return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}{recall_msg}"

__repr__ = __str__

Expand Down
2 changes: 1 addition & 1 deletion pymilvus/milvus_client/async_milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ async def search(
query_result.append(hit.to_dict())
ret.append(query_result)

return ExtraList(ret, extra=construct_cost_extra(res.cost))
return ExtraList(ret, extra=construct_cost_extra(res.cost), recalls=res.recalls)

async def query(
self,
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def search(
query_result.append(hit.to_dict())
ret.append(query_result)

return ExtraList(ret, extra=construct_cost_extra(res.cost))
return ExtraList(ret, extra=construct_cost_extra(res.cost), recalls=res.recalls)

def query(
self,
Expand Down

0 comments on commit 595998b

Please sign in to comment.