Skip to content

Commit

Permalink
refactor: add batch_size to FAISS __init__ (#6401)
Browse files Browse the repository at this point in the history
* refactor: add batch_size to FAISS __init__

* refactor: add batch_size to FAISS __init__

* add release note to refactor: add batch_size to FAISS __init__

* fix release note

* add batch_size to docstrings

---------

Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
  • Loading branch information
pandasar13 and anakin87 authored Nov 23, 2023
1 parent 4ec6a60 commit edb40b6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
21 changes: 16 additions & 5 deletions haystack/document_stores/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
ef_search: int = 20,
ef_construction: int = 80,
validate_index_sync: bool = True,
batch_size: int = 10_000,
):
"""
:param sql_url: SQL connection URL for the database. The default value is "sqlite:///faiss_document_store.db"`. It defaults to a local, file-based SQLite DB. For large scale deployment, we recommend Postgres.
Expand Down Expand Up @@ -103,6 +104,8 @@ def __init__(
:param ef_search: Used only if `index_factory == "HNSW"`.
:param ef_construction: Used only if `index_factory == "HNSW"`.
:param validate_index_sync: Checks if the document count equals the embedding count at initialization time.
:param batch_size: Number of Documents to index at once / Number of queries to execute at once. If you face
memory issues, decrease the batch_size.
"""
faiss_import.check()
# special case if we want to load an existing index from disk
Expand Down Expand Up @@ -152,6 +155,7 @@ def __init__(

self.return_embedding = return_embedding
self.embedding_field = embedding_field
self.batch_size = batch_size

self.progress_bar = progress_bar

Expand Down Expand Up @@ -216,7 +220,7 @@ def write_documents(
self,
documents: Union[List[dict], List[Document]],
index: Optional[str] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
duplicate_documents: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
) -> None:
Expand All @@ -240,6 +244,8 @@ def write_documents(
raise NotImplementedError("FAISSDocumentStore does not support headers.")

index = index or self.index
batch_size = batch_size or self.batch_size

duplicate_documents = duplicate_documents or self.duplicate_documents
assert (
duplicate_documents in self.duplicate_documents_options
Expand Down Expand Up @@ -324,7 +330,7 @@ def update_embeddings(
index: Optional[str] = None,
update_existing_embeddings: bool = True,
filters: Optional[FilterType] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
):
"""
Updates the embeddings in the the document store using the encoding model specified in the retriever.
Expand All @@ -342,6 +348,7 @@ def update_embeddings(
:return: None
"""
index = index or self.index
batch_size = batch_size or self.batch_size

if update_existing_embeddings is True:
if filters is None:
Expand Down Expand Up @@ -404,9 +411,10 @@ def get_all_documents(
index: Optional[str] = None,
filters: Optional[FilterType] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> List[Document]:
batch_size = batch_size or self.batch_size
if headers:
raise NotImplementedError("FAISSDocumentStore does not support headers.")

Expand All @@ -421,7 +429,7 @@ def get_all_documents_generator(
index: Optional[str] = None,
filters: Optional[FilterType] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> Generator[Document, None, None]:
"""
Expand All @@ -440,6 +448,7 @@ def get_all_documents_generator(
raise NotImplementedError("FAISSDocumentStore does not support headers.")

index = index or self.index
batch_size = batch_size or self.batch_size
documents = super(FAISSDocumentStore, self).get_all_documents_generator(
index=index, filters=filters, batch_size=batch_size, return_embedding=False
)
Expand All @@ -455,13 +464,15 @@ def get_documents_by_id(
self,
ids: List[str],
index: Optional[str] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> List[Document]:
if headers:
raise NotImplementedError("FAISSDocumentStore does not support headers.")

index = index or self.index
batch_size = batch_size or self.batch_size

documents = super(FAISSDocumentStore, self).get_documents_by_id(ids=ids, index=index, batch_size=batch_size)
if self.return_embedding:
for doc in documents:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Add batch_size to the __init__ method of FAISS Document Store. This works as the default value for all methods of
FAISS Document Store that support batch_size.

0 comments on commit edb40b6

Please sign in to comment.