-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_collection.py
61 lines (52 loc) · 1.75 KB
/
generate_collection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import polars as pl
from tqdm import tqdm
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer
import logging
COLLECTION_NAME = "recipes"

# Inputs and endpoints used by this script, named in one place.
DATASET_PATH = "tmp/recipes.parquet"
ENCODER_MODEL = "all-MiniLM-L6-v2"
QDRANT_URL = "http://localhost:6333"

# Emit INFO-level progress messages through the root logger.
logging.basicConfig(level=logging.INFO)

logging.info("loading dataset")
df = pl.read_parquet(DATASET_PATH)

logging.info("setup clients")
# The same encoder instance is used for every text field that gets embedded.
encoder = SentenceTransformer(ENCODER_MODEL)
qdrant = QdrantClient(QDRANT_URL)
# Rebuild the collection from scratch: drop any existing one with this name
# before creating it again.
if qdrant.collection_exists(COLLECTION_NAME):
    logging.info("deleting existing collection")
    qdrant.delete_collection(COLLECTION_NAME)

logging.info("creating collection")
# Both named vectors come from the same encoder, so they share one size.
embedding_dim = encoder.get_sentence_embedding_dimension()
qdrant.create_collection(
    collection_name=COLLECTION_NAME,
    vectors_config={
        vector_name: models.VectorParams(
            size=embedding_dim,
            distance=models.Distance.COSINE,
        )
        for vector_name in ("title", "ingredients")
    },
)
logging.info("creating embeddings")
count = df.height

# Pull the two text columns out of the frame as plain Python lists.
# NOTE(review): assumes both columns hold strings the encoder can embed
# directly -- confirm against the parquet schema.
batch_titles, batch_ingredients = (
    df.get_column(column_name).to_list()
    for column_name in ("title", "ingredients")
)

# Identical encoding options for both fields; results are returned as tensors.
encode_options = {"batch_size": 32, "convert_to_tensor": True}
title_embeddings = encoder.encode(batch_titles, **encode_options)
ingredient_embeddings = encoder.encode(batch_ingredients, **encode_options)
logging.info("uploading points")
# Stream the points to Qdrant instead of building the full list up front:
# - `iter_rows(named=True)` yields the same per-row dicts as
#   `df.row(idx, named=True)` without a separate row lookup per index.
# - Passing a generator lets `upload_points` pull points as it batches them,
#   so the progress bar advances with the upload and the complete set of
#   PointStructs is never held in memory at once.
qdrant.upload_points(
    collection_name=COLLECTION_NAME,
    points=(
        models.PointStruct(
            # The row position in the dataframe doubles as the point id.
            id=idx,
            vector={
                "title": title_embeddings[idx].tolist(),
                "ingredients": ingredient_embeddings[idx].tolist(),
            },
            # Every column of the row is stored as the point payload.
            payload=row,
        )
        # `total` is given explicitly because enumerate() has no length.
        for idx, row in tqdm(enumerate(df.iter_rows(named=True)), total=count)
    ),
)