Skip to content

Commit

Permalink
VoyageAIRerank constructor fix (#17343)
Browse files Browse the repository at this point in the history
  • Loading branch information
Adversarian authored Dec 29, 2024
1 parent 0f02701 commit ada192f
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@

class VoyageAIRerank(BaseNodePostprocessor):
model: str = Field(description="Name of the model to use.")
top_n: int = Field(
description="The number of most relevant documents to return. If not specified, the reranking results of all documents will be returned."
top_n: Optional[int] = Field(
description="The number of most relevant documents to return. If not specified, the reranking results of all documents will be returned.",
default=None,
)
truncation: bool = Field(
description="Whether to truncate the input to satisfy the 'context length limit' on the query and the documents."
description="Whether to truncate the input to satisfy the 'context length limit' on the query and the documents.",
default=True,
)

_client: Any = PrivateAttr()
Expand All @@ -29,7 +31,7 @@ def __init__(
model: str,
api_key: Optional[str] = None,
top_n: Optional[int] = None,
truncation: Optional[bool] = None,
truncation: bool = True,
# deprecated
top_k: Optional[int] = None,
):
Expand All @@ -55,7 +57,10 @@ def _postprocess_nodes(
) -> List[NodeWithScore]:
dispatcher.event(
ReRankStartEvent(
query=query_bundle, nodes=nodes, top_n=self.top_n, model_name=self.model
query=query_bundle,
nodes=nodes,
top_n=self.top_n or len(nodes),
model_name=self.model,
)
)

Expand All @@ -70,7 +75,7 @@ def _postprocess_nodes(
EventPayload.NODES: nodes,
EventPayload.MODEL_NAME: self.model,
EventPayload.QUERY_STR: query_bundle.query_str,
EventPayload.TOP_K: self.top_n,
EventPayload.TOP_K: self.top_n or len(nodes),
},
) as event:
texts = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ license = "MIT"
name = "llama-index-postprocessor-voyageai-rerank"
packages = [{include = "llama_index/"}]
readme = "README.md"
version = "0.3.1"
version = "0.3.2"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
import os

import pytest
from pytest_mock import MockerFixture
from voyageai.api_resources import VoyageResponse
from voyageai.object.reranking import RerankingObject

from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
from llama_index.postprocessor.voyageai_rerank import VoyageAIRerank
from voyageai.object.reranking import RerankingObject
from pytest_mock import MockerFixture

rerank_sample_response = {
"object": "list",
Expand All @@ -22,7 +25,11 @@ def test_class():
assert BaseNodePostprocessor.__name__ in names_of_base_classes


def test_rerank(mocker: MockerFixture) -> None:
@pytest.mark.parametrize(
"constructor_kwargs",
[{"top_n": 2, "truncation": True}, {"top_n": None}],
)
def test_rerank(mocker: MockerFixture, constructor_kwargs: dict) -> None:
# Mocked client with the desired behavior for embed_documents
result_object = RerankingObject(
documents=["0", "1"],
Expand All @@ -39,7 +46,7 @@ def test_rerank(mocker: MockerFixture) -> None:
)

voyageai_rerank = VoyageAIRerank(
api_key="api_key", top_n=2, model="rerank-lite-1", truncation=True
api_key="api_key", model="rerank-lite-1", **constructor_kwargs
)
result = voyageai_rerank.postprocess_nodes(
nodes=[
Expand All @@ -51,3 +58,20 @@ def test_rerank(mocker: MockerFixture) -> None:
assert len(result) == 2
assert result[0].text == "text2"
assert result[1].text == "text1"


def test_rerank_construction_with_no_optional_kwargs():
os.environ["VOYAGE_API_KEY"] = "mock_api_key"
reranker = VoyageAIRerank(model="rerank-2")
assert reranker.truncation
assert reranker.top_n is None
assert reranker.model == "rerank-2"


def test_rerank_construction_with_optional_kwargs():
reranker = VoyageAIRerank(
model="rerank-2", api_key="mock_api_key", top_n=10, truncation=False
)
assert not reranker.truncation
assert reranker.top_n == 10
assert reranker.model == "rerank-2"

0 comments on commit ada192f

Please sign in to comment.