Skip to content

Commit

Permalink
Merge pull request #35 from monarch-initiative/byob_embeddings
Browse files Browse the repository at this point in the history
BYOE bring your own embeddings
  • Loading branch information
cmungall authored May 31, 2024
2 parents 2bc7879 + 31cf525 commit 4c8f3b7
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions src/curate_gpt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import json
import logging
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Union

import click
import pandas as pd
import requests
import yaml
from click_default_group import DefaultGroup
from linkml_runtime.dumpers import json_dumper
Expand Down Expand Up @@ -1683,6 +1685,76 @@ def _text_lookup(obj: Dict):
db.update_collection_metadata(collection, object_type="OntologyClass")


@main.group()
def embeddings():
"""Command group for handling embeddings."""
pass


def download_file(url):
"""
Helper function to download a file from a URL to a temporary file.
"""
local_filename = tempfile.mktemp()
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(local_filename, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
return local_filename


def load_embeddings(file_path, embedding_format=None):
"""
Helper function to load embeddings from a file. Supports Parquet and CSV formats.
"""
if file_path.endswith('.parquet') or file_path.endswith('.parquet.gz') or embedding_format == 'parquet':
df = pd.read_parquet(file_path)
elif file_path.endswith('.csv') or file_path.endswith('.csv.gz') or embedding_format == 'csv':
df = pd.read_csv(file_path)
else:
raise ValueError(
"Unsupported file type. Only Parquet and CSV files are supported.")
return df.to_dict(orient='records')


@embeddings.command(name="load")
@path_option
@collection_option
@model_option
@append_option
@click.option("--embedding-format", "-f",
type=click.Choice(['parquet', 'csv']), help="Format of the input file")
@click.argument("file_or_url")
def load_embeddings(path, collection, append, embedding_format, model, file_or_url):
"""
Index embeddings from a local file or URL into a ChromaDB collection.
"""
# Check if file_or_url is a URL
if file_or_url.startswith('http://') or file_or_url.startswith('https://'):
print(f"Downloading file from URL: {file_or_url}")
file_path = download_file(file_or_url)
else:
file_path = file_or_url

print(f"Loading embeddings from file: {file_path}")
embeddings = load_embeddings(file_path, embedding_format)

# Initialize the database adapter
db = ChromaDBAdapter(path)
if append:
if collection in db.list_collection_names():
print(
f"Collection '{collection}' already exists. Adding to the existing collection.")
else:
db.remove_collection(collection, exists_ok=True)

# Insert embeddings into the collection
db.insert(embeddings, model=model, collection=collection)
print(f"Successfully indexed embeddings into collection '{collection}'.")



@main.group()
def view():
"Virtual store/wrapper"
Expand Down

0 comments on commit 4c8f3b7

Please sign in to comment.