Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Oct 12, 2024
1 parent 5f6b4e9 commit 78e0d9b
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 51 deletions.
5 changes: 3 additions & 2 deletions examples/chat_with_X/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ async def upsert_documents(documents: list[Document]):

@flow(flow_run_name="{repo}")
async def ingest_repo(repo: str):
"""repo should be in the format 'owner/repo'"""
documents = await gather_documents(repo)
await upsert_documents(documents)

Expand Down Expand Up @@ -69,4 +68,6 @@ async def chat_with_repo(initial_message: str | None = None, clean_up: bool = Tr

if __name__ == "__main__":
warnings.filterwarnings("ignore", category=UserWarning)
run_coro_as_sync(chat_with_repo("lets chat about zzstoatzz/prefect-bot"))
run_coro_as_sync(
chat_with_repo("lets chat about zzstoatzz/prefect-bot - please ingest it")
)
23 changes: 15 additions & 8 deletions examples/refresh_chroma/refresh_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Literal

from bs4 import BeautifulSoup
from chromadb.api.models.Collection import Document as ChromaDocument
from prefect import flow, task
from prefect.tasks import task_input_hash

Expand Down Expand Up @@ -50,6 +51,18 @@ async def run_loader(loader: Loader) -> list[Document]:
return await loader.load()


@task
async def add_documents(
    chroma: Chroma, documents: list[Document], mode: Literal["upsert", "reset"]
) -> list[ChromaDocument]:
    """Write ``documents`` to the Chroma collection according to ``mode``.

    Args:
        chroma: The open Chroma vector store wrapper to write to.
        documents: The documents to store.
        mode: ``"reset"`` calls ``chroma.reset_collection()`` and then
            ``chroma.add(documents)``; ``"upsert"`` calls
            ``chroma.upsert(documents)``.

    Returns:
        The value returned by ``chroma.add`` or ``chroma.upsert``.

    Raises:
        ValueError: If ``mode`` is neither ``"upsert"`` nor ``"reset"``.
    """
    if mode == "reset":
        await chroma.reset_collection()
        return await chroma.add(documents)
    if mode == "upsert":
        return await chroma.upsert(documents)
    # The Literal annotation is not enforced at runtime, so reject unknown
    # modes explicitly instead of failing later on an unbound local.
    raise ValueError(f"Unknown mode: {mode!r} (expected 'upsert' or 'reset')")


@flow(name="Update Knowledge", log_prints=True)
async def refresh_chroma(
collection_name: str = "default",
Expand All @@ -68,13 +81,7 @@ async def refresh_chroma(
async with Chroma(
collection_name=collection_name, client_type=chroma_client_type
) as chroma:
if mode == "reset":
await chroma.reset_collection()
docs = await chroma.add(documents)
elif mode == "upsert":
docs = await task(chroma.upsert)(documents)
else:
raise ValueError(f"Unknown mode: {mode!r} (expected 'upsert' or 'reset')")
docs = await add_documents(chroma, documents, mode)

print(f"Added {len(docs)} documents to the {collection_name} collection.") # type: ignore

Expand All @@ -83,5 +90,5 @@ async def refresh_chroma(
import asyncio

asyncio.run(
refresh_chroma(collection_name="docs", chroma_client_type="cloud", mode="reset")
refresh_chroma(collection_name="test", chroma_client_type="cloud", mode="reset") # type: ignore
)
2 changes: 1 addition & 1 deletion examples/refresh_chroma/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
chromadb
git+https://github.com/prefecthq/prefect.git@main
prefect
trafilatura
12 changes: 3 additions & 9 deletions src/raggy/documents.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import asyncio
import inspect
from functools import partial
from typing import Annotated, Iterable
from typing import Annotated

from jinja2 import Environment, Template
from pydantic import BaseModel, ConfigDict, Field, model_validator

from raggy.utilities.collections import distinct
from raggy.utilities.ids import generate_prefixed_uuid
from raggy.utilities.text import count_tokens, extract_keywords, hash_text, split_text

Expand Down Expand Up @@ -39,7 +38,7 @@ class Document(BaseModel):
keywords: list[str] = Field(default_factory=list)

@model_validator(mode="after")
def validate_tokens(self):
def ensure_tokens(self):
if self.tokens is None:
self.tokens = count_tokens(self.text)
return self
Expand All @@ -51,7 +50,7 @@ def __hash__(self) -> int:

EXCERPT_TEMPLATE = jinja_env.from_string(
inspect.cleandoc(
"""The following is an excerpt from a document
"""This is an excerpt from a document
{% if document.metadata %}\n\n# Document metadata
{{ document.metadata }}
{% endif %}
Expand Down Expand Up @@ -126,8 +125,3 @@ async def _create_excerpt(
metadata=document.metadata if document.metadata else {},
tokens=count_tokens(excerpt_text),
)


def get_distinct_documents(documents: Iterable["Document"]) -> Iterable["Document"]:
"""Return a list of distinct documents."""
return distinct(documents, key=lambda doc: hash(doc.text))
20 changes: 20 additions & 0 deletions src/raggy/utilities/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,30 @@
from typing import overload

from openai import APIConnectionError, AsyncOpenAI
from openai.types import CreateEmbeddingResponse
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

import raggy


# Typing-only overload: a single string `input_` is annotated as returning one
# embedding vector (`list[float]`). The body is `...`; the decorated
# implementation follows the overloads.
@overload
async def create_openai_embeddings(
input_: str,
timeout: int = 60,
model: str = raggy.settings.openai_embeddings_model,
) -> list[float]:
...


# Typing-only overload: a list of strings is annotated as returning a list of
# embedding vectors (`list[list[float]]`). Defaults mirror the overload above.
@overload
async def create_openai_embeddings(
input_: list[str],
timeout: int = 60,
model: str = raggy.settings.openai_embeddings_model,
) -> list[list[float]]:
...


@retry(
retry=retry_if_exception_type(APIConnectionError),
stop=stop_after_attempt(3),
Expand Down
79 changes: 48 additions & 31 deletions src/raggy/vectorstores/chroma.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import asyncio
import re
from typing import Iterable, Literal
from typing import Literal

from raggy.utilities.collections import distinct

try:
from chromadb import Client, CloudClient, HttpClient
from chromadb.api import ClientAPI
from chromadb.api.models.Collection import Collection
from chromadb.api.types import Include, QueryResult
from chromadb.api.models.Collection import Document as ChromaDocument
from chromadb.api.types import QueryResult
from chromadb.utils.batch_utils import create_batches
except ImportError:
raise ImportError(
"You must have `chromadb` installed to use the Chroma vector store. "
"Install it with `pip install 'raggy[chroma]'`."
)

from raggy.documents import Document, get_distinct_documents
from raggy.documents import Document as RaggyDocument
from raggy.settings import settings
from raggy.utilities.asyncutils import run_sync_in_worker_thread
from raggy.utilities.embeddings import create_openai_embeddings
Expand Down Expand Up @@ -69,7 +74,7 @@ async def delete(
self,
ids: list[str] | None = None,
where: dict | None = None,
where_document: Document | None = None,
where_document: ChromaDocument | None = None,
):
await run_sync_in_worker_thread(
self.collection.delete,
Expand All @@ -78,27 +83,36 @@ async def delete(
where_document=where_document,
)

async def add(self, documents: list[Document]) -> Iterable[Document]:
documents = list(get_distinct_documents(documents))
kwargs = dict(
ids=[document.id for document in documents],
documents=[document.text for document in documents],
metadatas=[
document.metadata.model_dump(exclude_none=True) or None
for document in documents
],
embeddings=await create_openai_embeddings(
[document.text for document in documents]
),
)
async def add(self, documents: list[RaggyDocument]) -> list[ChromaDocument]:
unique_documents = list(distinct(documents, key=lambda doc: doc.text))

await run_sync_in_worker_thread(self.collection.add, **kwargs)
ids = [doc.id for doc in unique_documents]
texts = [doc.text for doc in unique_documents]
metadatas = [
doc.metadata.model_dump(exclude_none=True) for doc in unique_documents
]

get_result = await run_sync_in_worker_thread(
self.collection.get, ids=kwargs["ids"]
embeddings = await create_openai_embeddings(texts)

data = {
"ids": ids,
"documents": texts,
"metadatas": metadatas,
"embeddings": embeddings,
}

batched_data: list[tuple] = create_batches(
get_client(self.client_type),
**data,
)

await asyncio.gather(
*(asyncio.to_thread(self.collection.add, *batch) for batch in batched_data)
)

return get_result.get("documents")
get_result = await asyncio.to_thread(self.collection.get, ids=ids)

return get_result.get("documents") or []

async def query(
self,
Expand All @@ -107,7 +121,7 @@ async def query(
n_results: int = 10,
where: dict | None = None,
where_document: dict | None = None,
include: "Include" = ["metadatas"],
include: list[str] = ["metadatas"],
**kwargs,
) -> "QueryResult":
return await run_sync_in_worker_thread(
Expand All @@ -124,8 +138,8 @@ async def query(
async def count(self) -> int:
return await run_sync_in_worker_thread(self.collection.count)

async def upsert(self, documents: list[Document]):
documents = list(get_distinct_documents(documents))
async def upsert(self, documents: list[RaggyDocument]) -> list[ChromaDocument]:
documents = list(distinct(documents, key=lambda doc: hash(doc.text)))
kwargs = dict(
ids=[document.id for document in documents],
documents=[document.text for document in documents],
Expand All @@ -143,7 +157,7 @@ async def upsert(self, documents: list[Document]):
self.collection.get, ids=kwargs["ids"]
)

return get_result.get("documents")
return get_result.get("documents") or []

async def reset_collection(self):
client = get_client(self.client_type)
Expand All @@ -160,7 +174,7 @@ async def reset_collection(self):

def ok(self) -> bool:
try:
version = self.client.get_version()
version = get_client(self.client_type).get_version()
except Exception as e:
self.logger.error_kv("Connection error", f"Cannot connect to Chroma: {e}")
if re.match(r"^\d+\.\d+\.\d+$", version):
Expand All @@ -177,6 +191,7 @@ async def query_collection(
where: dict | None = None,
where_document: dict | None = None,
max_tokens: int = 500,
client_type: ChromaClientType = "base",
) -> str:
"""Query a Chroma collection.
Expand All @@ -194,7 +209,9 @@ async def query_collection(
print(await query_collection("How to create a flow in Prefect?"))
```
"""
async with Chroma(collection_name=collection_name) as chroma:
async with Chroma(
collection_name=collection_name, client_type=client_type
) as chroma:
query_embedding = query_embedding or await create_openai_embeddings(query_text)

query_result = await chroma.query(
Expand All @@ -205,8 +222,8 @@ async def query_collection(
include=["documents"],
)

concatenated_result = "\n".join(
doc for doc in query_result.get("documents", [])
)
assert (
result := query_result.get("documents")
) is not None, "No documents found"

return slice_tokens(concatenated_result, max_tokens)
return slice_tokens("\n".join(result[0]), max_tokens)

0 comments on commit 78e0d9b

Please sign in to comment.