Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improved --path handling in CLI #64 fix ef and dim selection #65 #66

Merged
merged 7 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions src/curate_gpt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,6 @@ def index(
This will index the DataElementQueryResults from each file.

"""
if os.path.isdir(path) & (database_type == "duckdb"):
default_duckdb_path(path)
db = get_store(database_type, path)
db.text_lookup = text_field
if glob:
Expand Down Expand Up @@ -2301,10 +2299,6 @@ def index_ontology_command(
"""

s = time.time()

if os.path.isdir(path) & (database_type == "duckdb"):
default_duckdb_path(path)

oak_adapter = get_adapter(ont)
view = OntologyWrapper(oak_adapter=oak_adapter)
if branches:
Expand Down Expand Up @@ -2510,9 +2504,6 @@ def view_index(
curategpt -v index -p stagedb/hpoa.duckdb --batch-size 10 -V hpoa -c hpoa -m openai: -D duckdb

"""
if os.path.isdir(path) & (database_type == "duckdb"):
default_duckdb_path(path)

if init_with:
for k, v in yaml.safe_load(init_with).items():
kwargs[k] = v
Expand Down Expand Up @@ -2616,11 +2607,5 @@ def pubmed_ask(query, path, model, show_references, database_type, **kwargs):
print(ref_text)


def default_duckdb_path(path):
    """Return a default DuckDB database file path inside *path*.

    DuckDB requires a file path rather than a directory, so when the user
    supplies a directory this helper derives ``<path>/duck.duckdb`` from it.
    A message is echoed to the console informing the user of the fallback.

    :param path: directory supplied by the user on the command line
    :return: the derived database file path (``str``)

    NOTE(review): callers must use the *returned* value — this function does
    not mutate its argument in place.
    """
    db_file = os.path.join(path, "duck.duckdb")
    click.echo("You have to provide a path to a file : Defaulting to " + db_file)
    return db_file


# Script entry point: run the CLI's top-level command group when this
# module is executed directly (e.g. `python -m curate_gpt.cli`).
if __name__ == "__main__":
    main()
2 changes: 0 additions & 2 deletions src/curate_gpt/store/chromadb_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def _embedding_function(self, model: str = None) -> EmbeddingFunction:
:param model:
:return:
"""
logger.info(f"Getting embedding function for {model}")
if model is None:
raise ValueError("Model must be specified")
if model.startswith("openai:"):
Expand Down Expand Up @@ -272,7 +271,6 @@ def collection_metadata(
Parameters
----------
"""
logger.info(f"Getting metadata for {collection_name}")
collection_name = self._get_collection(collection_name)
try:
logger.info(f"Getting collection object {collection_name}")
Expand Down
32 changes: 17 additions & 15 deletions src/curate_gpt/store/duckdb_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@
EMBEDDINGS,
IDS,
METADATAS,
MODEL_DIMENSIONS,
MODEL_MAP,
DEFAULT_OPENAI_MODEL,
DEFAULT_MODEL,
MODELS,
OBJECT,
OPENAI_MODEL_DIMENSIONS,
PROJECTION,
QUERY,
SEARCH_RESULT,
Expand Down Expand Up @@ -65,8 +63,12 @@ class DuckDBAdapter(DBAdapter):

def __post_init__(self):
if not self.path:
self.path = "./duck.db"
os.makedirs(os.path.dirname(self.path), exist_ok=True)
self.path = "./db/db_file.duckdb"
if os.path.isdir(self.path):
self.path = os.path.join("./db", self.path, "db_file.duckdb")
os.makedirs(os.path.dirname(self.path), exist_ok=True)
logger.info(f"Path {self.path} is a directory. Using {self.path} as the database path\n\
as duckdb needs a file path")
self.ef_construction = self._validate_ef_construction(self.ef_construction)
self.ef_search = self._validate_ef_search(self.ef_search)
self.M = self._validate_m(self.M)
Expand Down Expand Up @@ -197,9 +199,10 @@ def _embedding_function(self, texts: Union[str, List[str], List[List[str]]], mod
if openai_model == "" or openai_model not in MODEL_MAP.keys():
logger.info(
f"The model {openai_model} is not "
f"one of {[MODEL_MAP.keys()]}. Defaulting to {DEFAULT_MODEL}"
f"one of {[MODEL_MAP.keys()]}. Defaulting to {DEFAULT_OPENAI_MODEL}"
)
openai_model = DEFAULT_MODEL
openai_model = DEFAULT_OPENAI_MODEL


responses = [
self.openai_client.embeddings.create(input=text, model=openai_model)
Expand Down Expand Up @@ -347,8 +350,8 @@ def _process_objects(
openai_model = model.split(":", 1)[1]
if openai_model == "" or openai_model not in MODEL_MAP.keys():
logger.info(f"The model {openai_model} is not "
f"one of {MODEL_MAP.keys()}. Defaulting to {DEFAULT_MODEL}")
openai_model = DEFAULT_MODEL #ada 002
f"one of {MODEL_MAP.keys()}. Defaulting to {DEFAULT_OPENAI_MODEL}")
openai_model = DEFAULT_OPENAI_MODEL #ada 002
else:
logger.error(f"Something went wonky ## model: {model}")
from transformers import GPT2Tokenizer
Expand Down Expand Up @@ -1018,16 +1021,15 @@ def _parse_where_clause(where: Dict[str, Any]) -> str:

def _get_embedding_dimension(self, model_name: str) -> int:
if model_name is None:
return MODEL_DIMENSIONS[self.default_model]
return DEFAULT_MODEL[self.default_model]
if isinstance(model_name, str):
logger.info("somehow here")
if model_name.startswith("openai:"):
model_key = model_name.split("openai:", 1)[1]
return OPENAI_MODEL_DIMENSIONS.get(
model_key, OPENAI_MODEL_DIMENSIONS["text-embedding-3-small"]
)
model_info = MODEL_MAP.get(model_key, DEFAULT_OPENAI_MODEL)
return MODEL_MAP[model_info][1]
else:
if model_name in MODEL_DIMENSIONS:
return MODEL_DIMENSIONS[model_name]
return MODEL_MAP[DEFAULT_OPENAI_MODEL][1]

@staticmethod
def _validate_ef_construction(value: int) -> int:
Expand Down
13 changes: 3 additions & 10 deletions src/curate_gpt/store/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,14 @@
DOCUMENTS = "documents"
DISTANCES = "distances"

MODEL_DIMENSIONS = {"all-MiniLM-L6-v2": 384}
OPENAI_MODEL_DIMENSIONS = {
"text-embedding-ada-002": 1536,
"text-embedding-3-small": 1536,
"text-embedding-3-large": 3072,
}
MODELS = ["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"]

MODEL_MAP = {
"text-embedding-ada-002": ("ada-002", 1536),
"text-embedding-3-small": ("3-small", 1536),
"text-embedding-3-large": ("3-large", 3072),
"text-embedding-3-small-512": ("3-small-512", 512),
"text-embedding-3-large-256": ("3-large-256", 256),
"text-embedding-3-large-1024": ("3-large-1024", 1024)
}
}

DEFAULT_MODEL = "text-embedding-ada-002"
DEFAULT_OPENAI_MODEL = "text-embedding-ada-002"
DEFAULT_MODEL= {"all-MiniLM-L6-v2":384}
Loading