Skip to content

Commit

Permalink
PR: search with image and added more input sources
Browse files Browse the repository at this point in the history
feat: search with image and added more input sources
  • Loading branch information
NotShrirang authored Jan 3, 2025
2 parents ce7ccee + 740610e commit 5c527a0
Show file tree
Hide file tree
Showing 11 changed files with 236 additions and 51 deletions.
17 changes: 9 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ Experience the project in action:

| ![Screenshot 2025-01-01 184852](https://github.com/user-attachments/assets/ad79d0f0-d200-4a82-8c2f-0890a9fe8189) | ![Screenshot 2025-01-01 222334](https://github.com/user-attachments/assets/7307857d-a41f-4f60-8808-00d6db6e8e3e) |
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
| Data Upload Page | Data Search / Retrieval |
| | |
| ![Screenshot 2025-01-01 222412](https://github.com/user-attachments/assets/e38273f4-426b-444d-80f0-501fa9563779) | ![Screenshot 2025-01-01 223948](https://github.com/user-attachments/assets/21724a92-ef79-44ae-83e6-25f8de29c45a)
| Data Annotation Page | CLIP Fine-Tuning |

| Data Upload Page | Data Search / Retrieval |
| | |
| ![Screenshot 2025-01-01 222412](https://github.com/user-attachments/assets/e38273f4-426b-444d-80f0-501fa9563779) | ![Screenshot 2025-01-01 223948](https://github.com/user-attachments/assets/21724a92-ef79-44ae-83e6-25f8de29c45a) |
| Data Annotation Page | CLIP Fine-Tuning |

---

Expand All @@ -37,9 +36,9 @@ Experience the project in action:
- 📤 **Upload Options**: Allows users to upload images and PDFs for AI-powered processing and retrieval
- 🧠 **Embedding-Based Search**: Uses OpenAI's CLIP model to align text and image embeddings in a shared latent space
- 🔍 **Augmented Text Generation**: Enhances text results using LLMs for contextually rich outputs
- 🏷️ Image Annotation: Enables users to annotate uploaded images through an intuitive interface
- 🎯 CLIP Fine-Tuning: Supports custom model training with configurable parameters including test dataset split size, learning rate, optimizer, and weight decay
- 🔨 Fine-Tuned Model Integration: Seamlessly load and utilize fine-tuned CLIP models for enhanced search and retrieval
- 🏷️ **Image Annotation**: Enables users to annotate uploaded images through an intuitive interface
- 🎯 **CLIP Fine-Tuning**: Supports custom model training with configurable parameters including test dataset split size, learning rate, optimizer, and weight decay
- 🔨 **Fine-Tuned Model Integration**: Seamlessly load and utilize fine-tuned CLIP models for enhanced search and retrieval

---

Expand All @@ -57,11 +56,13 @@ Experience the project in action:
- The system performs a nearest neighbor search in the vector database to retrieve relevant text and images

3. **Response Generation**:

- For text results: Optionally refined or augmented using a language model
- For image results: Directly returned or enhanced with image captions
- For PDFs: Extracts text content and provides relevant sections

4. **Image Annotation**:

- Dedicated annotation page for managing uploaded images
- Support for creating and managing multiple datasets simultaneously
- Flexible annotation workflow for efficient data labeling
Expand Down
19 changes: 19 additions & 0 deletions data_search/adapter_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch
import torch.nn as nn


def get_adapter_model(in_shape, out_shape, hidden_dim=1024):
    """Build the MLP adapter that maps one embedding space onto another.

    The network is three linear layers separated by ReLU activations:
    ``in_shape -> hidden_dim -> hidden_dim -> out_shape``.

    Args:
        in_shape: Size of the input embedding vector.
        out_shape: Size of the output embedding vector.
        hidden_dim: Width of the two hidden layers. Defaults to 1024, the
            value that was previously hard-coded.

    Returns:
        An ``nn.Sequential`` holding the layers. The positions of the
        linear layers (0, 2, 4) are kept fixed so that state-dict keys stay
        ``"0.*"``, ``"2.*"`` and ``"4.*"``.
    """
    return nn.Sequential(
        nn.Linear(in_shape, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, out_shape),
    )


def load_adapter_model(weights_path="./weights/adapter_model.pt"):
    """Create the 512 -> 384 adapter and load its trained weights on CPU.

    Args:
        weights_path: Location of the saved state dict. Defaults to the
            path that was previously hard-coded.

    Returns:
        The adapter model with the checkpoint's parameters loaded.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
    """
    # NOTE(review): 512/384 presumably match the CLIP image-feature size and
    # the text-embedding size used by the search page -- confirm.
    model = get_adapter_model(512, 384)
    # weights_only=True restricts unpickling to tensors/primitive containers,
    # so a tampered checkpoint cannot execute arbitrary code on load.
    state_dict = torch.load(
        weights_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    return model
34 changes: 28 additions & 6 deletions data_search/data_search_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import streamlit as st
import sys
import torch
from vectordb import search_image_index, search_text_index
from vectordb import search_image_index, search_text_index, search_image_index_with_image, search_text_index_with_image
from utils import load_image_index, load_text_index, get_local_files
from data_search import adapter_utils

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

Expand All @@ -18,6 +19,11 @@ def load_finetuned_model(file_name):
model, preprocess = clip.load("ViT-B/32", device=device)
model.load_state_dict(torch.load(f"annotations/{file_name}/finetuned_model.pt", weights_only=True))
return model, preprocess

@st.cache_resource
def load_adapter():
    """Return the adapter model from ``adapter_utils``.

    Wrapped in ``st.cache_resource`` so the loaded model object is reused
    instead of being rebuilt on every call.
    """
    return adapter_utils.load_adapter_model()

st.title("Data Search")

Expand Down Expand Up @@ -51,8 +57,13 @@ def load_finetuned_model(file_name):
else:
st.info("Using Default Model")

adapter = load_adapter()
adapter.to(device)

text_input = st.text_input("Search Database")
if st.button("Search", disabled=text_input.strip() == ""):
image_input = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])

if st.button("Search", disabled=text_input.strip() == "" and image_input is None):
if os.path.exists("./vectorstore/image_index.index"):
image_index, image_data = load_image_index()
if os.path.exists("./vectorstore/text_index.index"):
Expand All @@ -64,10 +75,21 @@ def load_finetuned_model(file_name):
if not os.path.exists("./vectorstore/text_data.csv"):
st.warning("No Text Index Found. So not searching for text.")
text_index = None
if image_index is not None:
image_indices = search_image_index(text_input, image_index, clip_model, k=3)
if text_index is not None:
text_indices = search_text_index(text_input, text_index, text_embedding_model, k=3)
if image_input:
image = Image.open(image_input)
image = preprocess(image).unsqueeze(0).to(device)
with torch.no_grad():
image_features = clip_model.encode_image(image)
adapted_text_embeddings = adapter(image_features)
if image_index is not None:
image_indices = search_image_index_with_image(image_features, image_index, clip_model, k=3)
if text_index is not None:
text_indices = search_text_index_with_image(adapted_text_embeddings, text_index, text_embedding_model, k=3)
else:
if image_index is not None:
image_indices = search_image_index(text_input, image_index, clip_model, k=3)
if text_index is not None:
text_indices = search_text_index(text_input, text_index, text_embedding_model, k=3)
if not image_index and not text_index:
st.error("No Data Found! Please add data to the database.")
st.subheader("Top 3 Results")
Expand Down
44 changes: 10 additions & 34 deletions data_upload/data_upload_page.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,20 @@
import os
import streamlit as st
import sys
from vectordb import add_image_to_index, add_pdf_to_index

from data_upload.input_sources_utils import image_util, pdf_util, website_util

sys.path.append(os.path.dirname(os.path.abspath(__file__)))


def data_upload(clip_model, preprocess, text_embedding_model):
st.title("Data Upload")
upload_choice = st.selectbox(options=["Upload Image", "Upload PDF"], label="Select Upload Type")
upload_choice = st.selectbox(options=["Upload Image", "Add Image from URL / Link", "Upload PDF", "Website Link"], label="Select Upload Type")
if upload_choice == "Upload Image":
st.subheader("Add Image to Database")
images = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
if images:
cols = st.columns(5, vertical_alignment="center")
for count, image in enumerate(images[:4]):
with cols[count]:
st.image(image)
with cols[4]:
if len(images) > 5:
st.info(f"and more {len(images) - 5} images...")
st.info(f"Total {len(images)} files selected.")
if st.button("Add Images"):
progress_bar = st.progress(0)
for image in images:
add_image_to_index(image, clip_model, preprocess)
progress_bar.progress((images.index(image) + 1) / len(images), f"{images.index(image) + 1}/{len(images)}")
st.success("Images Added to Database")
else:
st.subheader("Add PDF to Database")
st.warning("Please note that the images in the PDF will also be extracted and added to the database.")
pdfs = st.file_uploader("Upload PDF", type=["pdf"], accept_multiple_files=True)
if pdfs:
st.info(f"Total {len(pdfs)} files selected.")
if st.button("Add PDF"):
for pdf in pdfs:
add_pdf_to_index(
pdf=pdf,
clip_model=clip_model,
preprocess=preprocess,
text_embedding_model=text_embedding_model,
)
st.success("PDF Added to Database")
image_util.upload_image(clip_model, preprocess)
elif upload_choice == "Add Image from URL / Link":
image_util.image_from_url(clip_model, preprocess)
elif upload_choice == "Upload PDF":
pdf_util.upload_pdf(clip_model, preprocess, text_embedding_model)
elif upload_choice == "Website Link":
website_util.data_from_website(clip_model, preprocess, text_embedding_model)
46 changes: 46 additions & 0 deletions data_upload/input_sources_utils/image_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
import requests
import streamlit as st
import sys
from vectordb import add_image_to_index, add_pdf_to_index

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

def image_from_url(clip_model, preprocess, timeout=10):
    """Render the "Image from URL" form and index the image on request.

    The URL is previewed with ``st.image``; when the user presses
    "Add Image" the bytes are downloaded and handed to
    ``add_image_to_index``.

    Args:
        clip_model: Model forwarded to ``add_image_to_index``.
        preprocess: Preprocessing callable forwarded to
            ``add_image_to_index``.
        timeout: Seconds to wait for the download before giving up.
    """
    st.title("Image from URL")
    url = st.text_input("Enter Image URL")
    if not url:
        return

    try:
        st.image(url)
    except Exception:
        # A bare ``except`` would also trap KeyboardInterrupt/SystemExit;
        # only genuine rendering errors should be reported as a bad URL.
        st.error("Invalid URL")
        return

    if not st.button("Add Image"):
        return

    try:
        # Without a timeout an unresponsive host would hang the page.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        st.error("Invalid URL")
        return

    if response.status_code == 200:
        add_image_to_index(response.content, clip_model, preprocess)
        st.success("Image Added to Database")
    else:
        st.error("Invalid URL")

def upload_image(clip_model, preprocess):
    """Render the image upload form and index the selected files.

    Up to four thumbnails are previewed; the fifth column reports how many
    further files were selected. Pressing "Add Images" indexes every file
    through ``add_image_to_index`` while updating a progress bar.

    Args:
        clip_model: Model forwarded to ``add_image_to_index``.
        preprocess: Preprocessing callable forwarded to
            ``add_image_to_index``.
    """
    st.subheader("Add Image to Database")
    images = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
    if not images:
        return

    preview_count = 4
    cols = st.columns(5, vertical_alignment="center")
    for col, image in zip(cols, images[:preview_count]):
        with col:
            st.image(image)

    # Only ``preview_count`` thumbnails are drawn, so everything beyond that
    # is "more" (the old ``> 5`` / ``- 5`` check under-reported by one).
    remaining = len(images) - preview_count
    if remaining > 0:
        with cols[preview_count]:
            st.info(f"and more {remaining} images...")
    st.info(f"Total {len(images)} files selected.")

    if st.button("Add Images"):
        total = len(images)
        progress_bar = st.progress(0)
        # enumerate avoids the O(n) ``images.index`` lookup per iteration.
        for done, image in enumerate(images, start=1):
            add_image_to_index(image, clip_model, preprocess)
            progress_bar.progress(done / total, f"{done}/{total}")
        st.success("Images Added to Database")
22 changes: 22 additions & 0 deletions data_upload/input_sources_utils/pdf_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
import streamlit as st
import sys
from vectordb import add_image_to_index, add_pdf_to_index

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

def upload_pdf(clip_model, preprocess, text_embedding_model):
    """Render the PDF upload form and index every selected document.

    Each uploaded file is passed to ``add_pdf_to_index`` together with the
    three model arguments once the user presses "Add PDF".

    Args:
        clip_model: Forwarded to ``add_pdf_to_index``.
        preprocess: Forwarded to ``add_pdf_to_index``.
        text_embedding_model: Forwarded to ``add_pdf_to_index``.
    """
    st.subheader("Add PDF to Database")
    st.warning("Please note that the images in the PDF will also be extracted and added to the database.")
    uploaded_pdfs = st.file_uploader("Upload PDF", type=["pdf"], accept_multiple_files=True)
    if not uploaded_pdfs:
        return

    st.info(f"Total {len(uploaded_pdfs)} files selected.")
    if not st.button("Add PDF"):
        return

    # The same models are used for every document; bundle them once.
    shared_models = {
        "clip_model": clip_model,
        "preprocess": preprocess,
        "text_embedding_model": text_embedding_model,
    }
    for document in uploaded_pdfs:
        add_pdf_to_index(pdf=document, **shared_models)
    st.success("PDF Added to Database")
24 changes: 24 additions & 0 deletions data_upload/input_sources_utils/text_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import bs4
import os
from langchain_text_splitters import CharacterTextSplitter
import requests
import streamlit as st
import sys
from vectordb import add_image_to_index, add_pdf_to_index, update_vectordb

sys.path.append(os.path.dirname(os.path.abspath(__file__)))


def process_text(text: str, text_embedding_model):
    """Split ``text`` into chunks, embed them and add them to the text index.

    Args:
        text: Raw text to index.
        text_embedding_model: Object whose ``encode`` method turns a list of
            chunks into one embedding per chunk.

    Returns:
        Whatever ``update_vectordb`` returned for the last chunk, or ``None``
        when ``text`` produced no chunks.
    """
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1200,
        chunk_overlap=200,
        length_function=len,
        is_separator_regex=False,
    )
    chunks = text_splitter.split_text(text)
    if not chunks:
        # Nothing to embed; previously this path hit ``return index`` with
        # ``index`` never assigned and raised UnboundLocalError.
        return None

    text_embeddings = text_embedding_model.encode(chunks)
    index = None
    for chunk, embedding in zip(chunks, text_embeddings):
        index = update_vectordb(index_path="text_index.index", embedding=embedding, text_content=chunk)
    return index
56 changes: 56 additions & 0 deletions data_upload/input_sources_utils/website_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import bs4
import os
import requests
import streamlit as st
import sys
from vectordb import add_image_to_index, add_pdf_to_index
from data_upload.input_sources_utils import text_util

sys.path.append(os.path.dirname(os.path.abspath(__file__)))


def data_from_website(clip_model, preprocess, text_embedding_model, timeout=10):
    """Render the "Data from Website" form and index a page's images and text.

    After the user presses "Extract and Add Data" the page is downloaded,
    every ``<img>`` is fetched and passed to ``add_image_to_index``, and the
    text of the ``<main>`` element (or ``<body>`` when there is no
    ``<main>``) is passed to ``text_util.process_text``.

    Args:
        clip_model: Forwarded to ``add_image_to_index``.
        preprocess: Forwarded to ``add_image_to_index``.
        text_embedding_model: Forwarded to ``text_util.process_text``.
        timeout: Seconds to wait for each HTTP request.
    """
    # Function-scope import keeps this block self-contained.
    from urllib.parse import urljoin

    st.title("Data from Website")
    website_url = st.text_input("Enter Website URL")
    if not website_url:
        return
    st.write(f"URL: {website_url}")
    if not st.button("Extract and Add Data"):
        return

    try:
        response = requests.get(website_url, timeout=timeout)
    except requests.RequestException:
        st.error("Invalid URL")
        return
    if response.status_code != 200:
        # Stop here: parsing an error page would index garbage.
        st.error("Invalid URL")
        return
    st.success("Data Extracted Successfully")

    soup = bs4.BeautifulSoup(response.content, features="lxml")
    images = soup.find_all("img")
    if not images:
        st.info("No Images Found!")
    else:
        total = len(images)
        st.info(f"Found {total} Images")
        progress_bar = st.progress(0, f"Extracting Images... | 0/{total}")
        cols = st.columns(5)
        added = 0
        skipped = 0
        for count, image in enumerate(images, start=1):
            try:
                # urljoin resolves relative and protocol-relative ("//host/x")
                # sources and leaves absolute URLs untouched; the previous
                # ``replace("//", "https://")`` mangled "https://..." sources.
                image_url = urljoin(website_url, image["src"])
                image_response = requests.get(image_url, timeout=timeout)
                if image_response.status_code == 200:
                    add_image_to_index(image_response.content, clip_model, preprocess)
                    added += 1
                    if added <= 4:
                        with cols[added - 1]:
                            st.image(image_url, caption=image_url, use_container_width=True)
                    elif added == 5:
                        with cols[4]:
                            st.info(f"and more {total - 4} images...")
                else:
                    skipped += 1
            except Exception:
                # Best effort: one broken image (missing src, network error,
                # undecodable data) must not abort the whole import.
                skipped += 1
            progress_bar.progress(count / total, f"Extracting Images... | {count}/{total}")
        progress_bar.empty()
        if skipped:
            st.warning(f"Skipped {skipped} images that could not be added.")

    # Not every page has a <main> element; fall back to <body>.
    main_content = soup.find("main")
    if main_content is None:
        main_content = soup.body
    page_text = main_content.text if main_content is not None else ""
    if page_text.strip():
        with st.spinner("Processing Text..."):
            text_util.process_text(page_text, text_embedding_model)
    else:
        st.info("No Text Found!")
    st.success("Data Added to Database")
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ annotated-types==0.7.0
anyio==4.7.0
async-timeout==4.0.3
attrs==24.3.0
beautifulsoup4==4.12.3
blinker==1.9.0
cachetools==5.5.0
certifi==2024.12.14
Expand Down Expand Up @@ -47,6 +48,7 @@ langchain-core==0.3.28
langchain-experimental==0.3.4
langchain-text-splitters==0.3.4
langsmith==0.1.147
lxml==5.1.0
markdown-it-py==3.0.0
MarkupSafe==3.0.2
marshmallow==3.23.2
Expand Down Expand Up @@ -91,6 +93,7 @@ sentence-transformers==3.3.1
six==1.17.0
smmap==5.0.1
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.36
streamlit==1.41.1
streamlit-option-menu==0.4.0
Expand Down
Loading

0 comments on commit 5c527a0

Please sign in to comment.