Skip to content

Commit

Permalink
tests for MongoDB change
Browse files Browse the repository at this point in the history
  • Loading branch information
mattbeardey committed Dec 18, 2024
1 parent 26d2711 commit 3ca0a7a
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions test/agentchat/contrib/vectordb/test_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from autogen.agentchat.contrib.vectordb.base import Document
from autogen.agentchat.contrib.vectordb.mongodb import MongoDocument

try:
import pymongo
Expand Down Expand Up @@ -96,14 +97,22 @@ def db():
_empty_collections_and_delete_indexes(database)


def generate_embeddings(n: int = 384, *, seed=None) -> List[float]:
    """Return a list of ``n`` pseudo-random floats to use as a fake embedding.

    Args:
        n: Number of values (the embedding dimension). Defaults to 384.
        seed: Optional seed. When given, a private ``random.Random(seed)``
            generator is used so the same seed always yields the same
            vector. When ``None`` (the default), values are drawn from the
            module-level ``random.random``.

    Returns:
        A list of ``n`` floats, each in the half-open interval ``[0.0, 1.0)``.
    """
    # A private generator keeps seeded calls from disturbing the global RNG state.
    draw = random.random if seed is None else random.Random(seed).random
    return [draw() for _ in range(n)]


@pytest.fixture
def example_documents() -> List[Document]:
    """Return the sample documents shared by the tests.

    Note the mix of integers and strings as ids. MongoDocuments are added for
    testing: some are created without an ``id`` and some carry a
    pre-computed ``embedding``, so every combination of the two is present.
    """
    return [
        Document(id=1, content="Dogs are tough.", metadata={"a": 1}),
        Document(id=2, content="Cats have fluff.", metadata={"b": 1}),
        Document(id="1", content="What is a sandwich?", metadata={"c": 1}),
        Document(id="2", content="A sandwich makes a great lunch.", metadata={"d": 1, "e": 2}),
        # No id, no embedding.
        MongoDocument(content="Stars are big.", metadata={"a": 1}),
        # No id, embedding supplied.
        MongoDocument(content="Atoms are small.", metadata={"b": 1}, embedding=generate_embeddings()),
        # Explicit id, no embedding.
        MongoDocument(id="123", content="I hate grass", metadata={"c": 1}),
        # Explicit id, embedding supplied.
        MongoDocument(id="321", content="I love sand", metadata={"d": 1, "e": 2}, embedding=generate_embeddings()),
    ]


Expand Down Expand Up @@ -207,25 +216,32 @@ def test_insert_docs(db, collection_name, example_documents):
# Check that documents have correct fields, including "_id" and "embedding" but not "id"
assert all([set(doc.keys()) == {"_id", "content", "metadata", "embedding"} for doc in found])
# Check ids
assert {doc["_id"] for doc in found} == {1, "1", 2, "2"}
assert {doc["_id"] for doc in found} == {1, "1", 2, "2", found[4]["_id"], found[5]["_id"], "123", "321"}
# Check embedding lengths
assert len(found[0]["embedding"]) == 384

db.delete_collection(collection_name)
collection = db.create_collection(collection_name)
example_documents[0].embedding = [random.random() for _ in range(10)]
# Ensuring different size embeddings are not inserted
with pytest.raises(AssertionError, match=r"Embedding Vectors are not all equal in length. Sizes:"):
db.insert_docs(example_documents, collection_name=collection_name, upsert=False)


def test_update_docs(db_with_indexed_clxn, example_documents):
db, collection = db_with_indexed_clxn
# Use update_docs to insert new documents
db.update_docs(example_documents, collection.name, upsert=True)
# Test that no changes were made to example_documents
assert set(example_documents[0].keys()) == {"id", "content", "metadata"}
assert collection.count_documents({}) == len(example_documents)
assert collection.count_documents({}) == len([doc for doc in example_documents if doc.get("id") is not None])
found = list(collection.find({}))
# Check that documents have correct fields, including "_id" and "embedding" but not "id"
assert all([set(doc.keys()) == {"_id", "content", "metadata", "embedding"} for doc in found])
assert all([isinstance(doc["embedding"][0], float) for doc in found])
assert all([len(doc["embedding"]) == db.dimensions for doc in found])
# Check ids
assert {doc["_id"] for doc in found} == {1, "1", 2, "2"}
assert {doc["_id"] for doc in found} == {1, "1", 2, "2", "123", "321"}

# Update an *existing* Document
updated_doc = Document(id=1, content="Cats are tough.", metadata={"a": 10})
Expand Down Expand Up @@ -254,7 +270,8 @@ def test_delete_docs(db_with_indexed_clxn, example_documents):
# Delete the 1s
db.delete_docs(ids=[1, "1"], collection_name=clxn.name)
# Confirm just the 2s remain
assert {2, "2"} == {doc["_id"] for doc in clxn.find({})}
result_set = {doc["_id"] for doc in clxn.find({})}
assert 2 in result_set and "2" in result_set


def test_get_docs_by_ids(db_with_indexed_clxn, example_documents):
Expand Down Expand Up @@ -359,8 +376,8 @@ def results_ready():

assert len(results) == len(queries)
assert all([len(res) == n_results for res in results])
assert {doc[0]["id"] for doc in results[0]} == {1, 2}
assert {doc[0]["id"] for doc in results[1]} == {"1", "2"}
assert {1, 2} <= {doc[0]["id"] for doc in results[0]}
assert {"1", "2"} <= {doc[0]["id"] for doc in results[1]}


def test_retrieve_docs_with_threshold(db_with_indexed_clxn, example_documents):
Expand Down Expand Up @@ -397,6 +414,6 @@ def test_wait_until_document_ready(collection_name, example_documents):
wait_until_document_ready=TIMEOUT,
)
vectorstore.insert_docs(example_documents)
assert vectorstore.retrieve_docs(queries=["Cats"], n_results=4)
assert vectorstore.retrieve_docs(queries=["Cats"], n_results=8)
finally:
_empty_collections_and_delete_indexes(database, [collection_name])

0 comments on commit 3ca0a7a

Please sign in to comment.