Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DuckDBAdapter: Batch OpenAI calls #62

Merged
merged 3 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 103 additions & 32 deletions src/curate_gpt/store/duckdb_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from dataclasses import dataclass, field
from typing import Any, Callable, ClassVar, Dict, Iterable, Iterator, List, Mapping, Optional, Union

import llm
import duckdb
import numpy as np
import openai
Expand All @@ -33,6 +33,8 @@
IDS,
METADATAS,
MODEL_DIMENSIONS,
MODEL_MAP,
DEFAULT_MODEL,
MODELS,
OBJECT,
OPENAI_MODEL_DIMENSIONS,
Expand Down Expand Up @@ -174,7 +176,7 @@ def create_index(self, collection: str):
"""
self.conn.execute(create_index_sql)

def _embedding_function(self, texts: Union[str, List[str]], model: str = None) -> list:
def _embedding_function(self, texts: Union[str, List[str], List[List[str]]], model: str = None) -> list:
"""
Get the embeddings for the given texts using the specified model
:param texts: A single text or a list of texts to embed
Expand All @@ -192,12 +194,12 @@ def _embedding_function(self, texts: Union[str, List[str]], model: str = None) -
if model.startswith("openai:"):
self._initialize_openai_client()
openai_model = model.split(":", 1)[1]
if openai_model == "" or openai_model not in MODELS:
if openai_model == "" or openai_model not in MODEL_MAP.keys():
logger.info(
f"The model {openai_model} is not "
f"one of {MODELS}. Defaulting to {MODELS[1]}"
f"one of {[MODEL_MAP.keys()]}. Defaulting to {DEFAULT_MODEL}"
)
openai_model = MODELS[1]
openai_model = DEFAULT_MODEL

responses = [
self.openai_client.embeddings.create(input=text, model=openai_model)
Expand Down Expand Up @@ -320,33 +322,102 @@ def _process_objects(
cumulative_len = 0
sql_command = self._generate_sql_command(collection, method)
sql_command = sql_command.format(collection=collection)
for next_objs in chunk(objs, batch_size):
next_objs = list(next_objs)
logger.info("Processing batch of objects in DuckDB process_objects ...")
docs = [self._text(o, text_field) for o in next_objs]
docs_len = sum([len(d) for d in docs])
cumulative_len += docs_len
if self._is_openai(collection) and cumulative_len > 3000000:
logger.warning(f"Cumulative length = {cumulative_len}, pausing ...")
time.sleep(60)
cumulative_len = 0
metadatas = [self._dict(o) for o in next_objs]
ids = [self._id(o, id_field) for o in next_objs]
embeddings = self._embedding_function(docs, cm.model)
try:
self.conn.execute("BEGIN TRANSACTION;")
self.conn.executemany(
sql_command, list(zip(ids, metadatas, embeddings, docs, strict=False))
)
self.conn.execute("COMMIT;")
except Exception as e:
self.conn.execute("ROLLBACK;")
logger.error(
f"Transaction failed: {e}, default model: {self.default_model}, model used: {model}, len(embeddings): {len(embeddings[0])}"
)
raise
finally:
self.create_index(collection)
if not self._is_openai(collection):
for next_objs in chunk(objs, batch_size):
next_objs = list(next_objs)
docs = [self._text(o, text_field) for o in next_objs]
docs_len = sum([len(d) for d in docs])
metadatas = [self._dict(o) for o in next_objs]
ids = [self._id(o, id_field) for o in next_objs]
embeddings = self._embedding_function(docs, cm.model)
try:
self.conn.execute("BEGIN TRANSACTION;")
self.conn.executemany(
sql_command, list(zip(ids, metadatas, embeddings, docs, strict=False))
)
self.conn.execute("COMMIT;")
except Exception as e:
self.conn.execute("ROLLBACK;")
logger.error(f"Transaction failed: {e}, default model: {self.default_model}, model used: {model}, len(embeddings): {len(embeddings[0])}")
raise
finally:
self.create_index(collection)
else:
if model.startswith("openai:"):
openai_model = model.split(":", 1)[1]
if openai_model == "" or openai_model not in MODEL_MAP.keys():
logger.info(f"The model {openai_model} is not "
f"one of {MODEL_MAP.keys()}. Defaulting to {DEFAULT_MODEL}")
openai_model = DEFAULT_MODEL #ada 002
else:
logger.error(f"Something went wonky ## model: {model}")
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
for next_objs in chunk(objs, batch_size): # Existing chunking
next_objs = list(next_objs)
docs = [self._text(o, text_field) for o in next_objs]
docs_len = sum([len(d) for d in docs])
metadatas = [self._dict(o) for o in next_objs]
ids = [self._id(o, id_field) for o in next_objs]

tokenized_docs = [tokenizer.encode(doc) for doc in docs]
current_batch = []
current_token_count = 0
batch_embeddings = []

i = 0
while i < len(tokenized_docs):
doc_tokens = tokenized_docs[i]
# peek
if current_token_count + len(doc_tokens) <= 8192:
current_batch.append(doc_tokens)
current_token_count += len(doc_tokens)
i += 1
else:
if current_batch:
logger.info(f"Tokens: {current_token_count}")
texts = [tokenizer.decode(tokens) for tokens in current_batch]
short_name, _ = MODEL_MAP[openai_model]
embedding_model = llm.get_embedding_model(short_name)
embeddings = list(embedding_model.embed_multi(texts))
logger.info(f"Number of Documents in batch: {len(embeddings)}")
batch_embeddings.extend(embeddings)

if len(doc_tokens) > 8192:
logger.warning(
f"Document with ID {ids[i]} exceeds the token limit alone and will be skipped.")
# try:
# embeddings = OpenAIEmbeddings(model=model, tiktoken_model_name=model).embed_query(texts,
# embeddings.average model)
# batch_embeddings.extend(embeddings)
# skipping
i += 1
continue
else:
current_batch = []
current_token_count = 0

if current_batch:
logger.info(f"Last batch, token count: {current_token_count}")
texts = [tokenizer.decode(tokens) for tokens in current_batch]
short_name, _ = MODEL_MAP[openai_model]
embedding_model = llm.get_embedding_model(short_name)
embeddings = list(embedding_model.embed_multi(texts))
batch_embeddings.extend(embeddings)
logger.info(f"Trying to insert: {len(ids)} IDS, {len(metadatas)} METADATAS, {len(batch_embeddings)} EMBEDDINGS")
try:
self.conn.execute("BEGIN TRANSACTION;")
self.conn.executemany(
sql_command, list(zip(ids, metadatas, batch_embeddings, docs, strict=False))
)
self.conn.execute("COMMIT;")
except Exception as e:
self.conn.execute("ROLLBACK;")
logger.error(
f"Transaction failed: {e}, default model: {self.default_model}, model used: {model}, len(embeddings): {len(embeddings[0])}")
raise
finally:
self.create_index(collection)

def remove_collection(self, collection: str = None, exists_ok=False, **kwargs):
"""
Expand Down
11 changes: 11 additions & 0 deletions src/curate_gpt/store/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,14 @@
"text-embedding-3-large": 3072,
}
MODELS = ["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"]

# Maps a full OpenAI embedding model name to a (short_alias, dimensions) tuple.
# The short alias is what gets passed to llm.get_embedding_model(...) when
# batching embedding requests; the int is the length of the returned vectors.
# NOTE(review): the "-512"/"-256"/"-1024" keys look like llm-plugin aliases for
# reduced-dimension variants rather than literal OpenAI model names — confirm
# against the llm OpenAI plugin's registered embedding-model aliases.
MODEL_MAP = {
    "text-embedding-ada-002": ("ada-002", 1536),
    "text-embedding-3-small": ("3-small", 1536),
    "text-embedding-3-large": ("3-large", 3072),
    "text-embedding-3-small-512": ("3-small-512", 512),
    "text-embedding-3-large-256": ("3-large-256", 256),
    "text-embedding-3-large-1024": ("3-large-1024", 1024)
}

# Fallback used when a requested OpenAI model is empty or not in MODEL_MAP.
DEFAULT_MODEL = "text-embedding-ada-002"
Loading