diff --git a/README.md b/README.md
index 6bf96a9..2120159 100644
--- a/README.md
+++ b/README.md
@@ -22,11 +22,10 @@ Experience the project in action:
 
 | ![Screenshot 2025-01-01 184852](https://github.com/user-attachments/assets/ad79d0f0-d200-4a82-8c2f-0890a9fe8189) | ![Screenshot 2025-01-01 222334](https://github.com/user-attachments/assets/7307857d-a41f-4f60-8808-00d6db6e8e3e) |
 | ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
-| Data Upload Page | Data Search / Retrieval |
-| | |
-| ![Screenshot 2025-01-01 222412](https://github.com/user-attachments/assets/e38273f4-426b-444d-80f0-501fa9563779) | ![Screenshot 2025-01-01 223948](https://github.com/user-attachments/assets/21724a92-ef79-44ae-83e6-25f8de29c45a)
-| Data Annotation Page | CLIP Fine-Tuning |
-
+| Data Upload Page | Data Search / Retrieval |
+| | |
+| ![Screenshot 2025-01-01 222412](https://github.com/user-attachments/assets/e38273f4-426b-444d-80f0-501fa9563779) | ![Screenshot 2025-01-01 223948](https://github.com/user-attachments/assets/21724a92-ef79-44ae-83e6-25f8de29c45a) |
+| Data Annotation Page | CLIP Fine-Tuning |
 
 ---
 
@@ -37,9 +36,9 @@ Experience the project in action:
 - 📤 **Upload Options**: Allows users to upload images and PDFs for AI-powered processing and retrieval
 - 🧠 **Embedding-Based Search**: Uses OpenAI's CLIP model to align text and image embeddings in a shared latent space
 - 🔍 **Augmented Text Generation**: Enhances text results using LLMs for contextually rich outputs
-- 🏷️ Image Annotation: Enables users to annotate uploaded images through an intuitive interface
-- 🎯 CLIP Fine-Tuning: Supports custom model training with configurable parameters including test dataset split size, learning rate, optimizer, and weight decay
-- 🔨 Fine-Tuned Model Integration: Seamlessly load and utilize fine-tuned CLIP models for enhanced search and retrieval
+- 🏷️ **Image Annotation**: Enables users to annotate uploaded images through an intuitive interface
+- 🎯 **CLIP Fine-Tuning**: Supports custom model training with configurable parameters including test dataset split size, learning rate, optimizer, and weight decay
+- 🔨 **Fine-Tuned Model Integration**: Seamlessly load and utilize fine-tuned CLIP models for enhanced search and retrieval
 
 ---
 
@@ -57,11 +56,13 @@
    - The system performs a nearest neighbor search in the vector database to retrieve relevant text and images
 
 3. **Response Generation**:
+
    - For text results: Optionally refined or augmented using a language model
    - For image results: Directly returned or enhanced with image captions
    - For PDFs: Extracts text content and provides relevant sections
 
 4. **Image Annotation**:
+
    - Dedicated annotation page for managing uploaded images
    - Support for creating and managing multiple datasets simultaneously
    - Flexible annotation workflow for efficient data labeling
diff --git a/data_search/adapter_utils.py b/data_search/adapter_utils.py
new file mode 100644
index 0000000..10f5988
--- /dev/null
+++ b/data_search/adapter_utils.py
@@ -0,0 +1,24 @@
+import torch
+import torch.nn as nn
+
+
+def get_adapter_model(in_shape, out_shape):
+    # A small MLP that projects CLIP image embeddings (512-d for ViT-B/32)
+    # into the text-embedding model's vector space (384-d here), so that an
+    # uploaded image can also be used to query the text index.
+    model = nn.Sequential(
+        nn.Linear(in_shape, 1024),
+        nn.ReLU(),
+        nn.Linear(1024, 1024),
+        nn.ReLU(),
+        nn.Linear(1024, out_shape)
+    )
+    return model
+
+
+def load_adapter_model():
+    # Load the pre-trained adapter weights shipped with this change;
+    # map_location keeps the load working on CPU-only machines.
+    model = get_adapter_model(512, 384)
+    model.load_state_dict(torch.load("./weights/adapter_model.pt", map_location=torch.device('cpu')))
+    return model
diff --git a/data_search/data_search_page.py b/data_search/data_search_page.py
index a051545..badf0e2 100644
--- a/data_search/data_search_page.py
+++ b/data_search/data_search_page.py
@@ -5,8 +5,10 @@
 import streamlit as st
 import sys
 import torch
 
+from PIL import Image
-from vectordb import search_image_index, search_text_index
+from vectordb import search_image_index, search_text_index, search_image_index_with_image, search_text_index_with_image
 from utils import load_image_index, load_text_index, get_local_files
+from data_search import adapter_utils
 
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))
@@ -18,6 +20,11 @@ def load_finetuned_model(file_name):
         model, preprocess = clip.load("ViT-B/32", device=device)
         model.load_state_dict(torch.load(f"annotations/{file_name}/finetuned_model.pt", weights_only=True))
         return model, preprocess
+
+    @st.cache_resource
+    def load_adapter():
+        adapter = adapter_utils.load_adapter_model()
+        return adapter
 
     st.title("Data Search")
 
@@ -51,8 +58,13 @@ def load_finetuned_model(file_name):
     else:
         st.info("Using Default Model")
 
+    adapter = load_adapter()
+    adapter.to(device)
+
     text_input = st.text_input("Search Database")
-    if st.button("Search", disabled=text_input.strip() == ""):
+    image_input = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
+
+    if st.button("Search", disabled=text_input.strip() == "" and image_input is None):
         if os.path.exists("./vectorstore/image_index.index"):
             image_index, image_data = load_image_index()
         if os.path.exists("./vectorstore/text_index.index"):
@@ -64,10 +76,22 @@ def load_finetuned_model(file_name):
         if not os.path.exists("./vectorstore/text_data.csv"):
             st.warning("No Text Index Found. So not searching for text.")
             text_index = None
-        if image_index is not None:
-            image_indices = search_image_index(text_input, image_index, clip_model, k=3)
-        if text_index is not None:
-            text_indices = search_text_index(text_input, text_index, text_embedding_model, k=3)
+        if image_input:
+            image = Image.open(image_input)
+            image = preprocess(image).unsqueeze(0).to(device)
+            with torch.no_grad():
+                image_features = clip_model.encode_image(image)
+                # CLIP returns float16 features on GPU; the adapter weights are float32
+                adapted_text_embeddings = adapter(image_features.float())
+            if image_index is not None:
+                image_indices = search_image_index_with_image(image_features, image_index, clip_model, k=3)
+            if text_index is not None:
+                text_indices = search_text_index_with_image(adapted_text_embeddings, text_index, text_embedding_model, k=3)
+        else:
+            if image_index is not None:
+                image_indices = search_image_index(text_input, image_index, clip_model, k=3)
+            if text_index is not None:
+                text_indices = search_text_index(text_input, text_index, text_embedding_model, k=3)
         if not image_index and not text_index:
             st.error("No Data Found! Please add data to the database.")
         st.subheader("Top 3 Results")
diff --git a/data_upload/data_upload_page.py b/data_upload/data_upload_page.py
index ff7f430..df1e473 100644
--- a/data_upload/data_upload_page.py
+++ b/data_upload/data_upload_page.py
@@ -1,44 +1,20 @@
 import os
 import streamlit as st
 import sys
-from vectordb import add_image_to_index, add_pdf_to_index
+
+from data_upload.input_sources_utils import image_util, pdf_util, website_util
 
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 
 
 def data_upload(clip_model, preprocess, text_embedding_model):
     st.title("Data Upload")
-    upload_choice = st.selectbox(options=["Upload Image", "Upload PDF"], label="Select Upload Type")
+    upload_choice = st.selectbox(options=["Upload Image", "Add Image from URL / Link", "Upload PDF", "Website Link"], label="Select Upload Type")
     if upload_choice == "Upload Image":
-        st.subheader("Add Image to Database")
-        images = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
-        if images:
-            cols = st.columns(5, vertical_alignment="center")
-            for count, image in enumerate(images[:4]):
-                with cols[count]:
-                    st.image(image)
-            with cols[4]:
-                if len(images) > 5:
-                    st.info(f"and more {len(images) - 5} images...")
-            st.info(f"Total {len(images)} files selected.")
-            if st.button("Add Images"):
-                progress_bar = st.progress(0)
-                for image in images:
-                    add_image_to_index(image, clip_model, preprocess)
-                    progress_bar.progress((images.index(image) + 1) / len(images), f"{images.index(image) + 1}/{len(images)}")
-                st.success("Images Added to Database")
-    else:
-        st.subheader("Add PDF to Database")
-        st.warning("Please note that the images in the PDF will also be extracted and added to the database.")
-        pdfs = st.file_uploader("Upload PDF", type=["pdf"], accept_multiple_files=True)
-        if pdfs:
-            st.info(f"Total {len(pdfs)} files selected.")
-            if st.button("Add PDF"):
-                for pdf in pdfs:
-                    add_pdf_to_index(
-                        pdf=pdf,
-                        clip_model=clip_model,
-                        preprocess=preprocess,
-                        text_embedding_model=text_embedding_model,
-                    )
-                st.success("PDF Added to Database")
\ No newline at end of file
+        image_util.upload_image(clip_model, preprocess)
+    elif upload_choice == "Add Image from URL / Link":
+        image_util.image_from_url(clip_model, preprocess)
+    elif upload_choice == "Upload PDF":
+        pdf_util.upload_pdf(clip_model, preprocess, text_embedding_model)
+    elif upload_choice == "Website Link":
+        website_util.data_from_website(clip_model, preprocess, text_embedding_model)
diff --git a/data_upload/input_sources_utils/image_util.py b/data_upload/input_sources_utils/image_util.py
new file mode 100644
index 0000000..4da6d90
--- /dev/null
+++ b/data_upload/input_sources_utils/image_util.py
@@ -0,0 +1,46 @@
+import os
+import requests
+import streamlit as st
+import sys
+from vectordb import add_image_to_index
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+
+def image_from_url(clip_model, preprocess):
+    st.title("Image from URL")
+    url = st.text_input("Enter Image URL")
+    correct_url = False
+    if url:
+        try:
+            st.image(url)
+            correct_url = True
+        except Exception:
+            st.error("Invalid URL")
+            correct_url = False
+    if correct_url:
+        if st.button("Add Image"):
+            response = requests.get(url)
+            if response.status_code == 200:
+                add_image_to_index(response.content, clip_model, preprocess)
+                st.success("Image Added to Database")
+            else:
+                st.error("Invalid URL")
+
+
+def upload_image(clip_model, preprocess):
+    st.subheader("Add Image to Database")
+    images = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
Image", type=["jpg", "jpeg", "png"], accept_multiple_files=True) + if images: + cols = st.columns(5, vertical_alignment="center") + for count, image in enumerate(images[:4]): + with cols[count]: + st.image(image) + with cols[4]: + if len(images) > 5: + st.info(f"and more {len(images) - 5} images...") + st.info(f"Total {len(images)} files selected.") + if st.button("Add Images"): + progress_bar = st.progress(0) + for image in images: + add_image_to_index(image, clip_model, preprocess) + progress_bar.progress((images.index(image) + 1) / len(images), f"{images.index(image) + 1}/{len(images)}") + st.success("Images Added to Database") \ No newline at end of file diff --git a/data_upload/input_sources_utils/pdf_util.py b/data_upload/input_sources_utils/pdf_util.py new file mode 100644 index 0000000..df77ff0 --- /dev/null +++ b/data_upload/input_sources_utils/pdf_util.py @@ -0,0 +1,22 @@ +import os +import streamlit as st +import sys +from vectordb import add_image_to_index, add_pdf_to_index + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +def upload_pdf(clip_model, preprocess, text_embedding_model): + st.subheader("Add PDF to Database") + st.warning("Please note that the images in the PDF will also be extracted and added to the database.") + pdfs = st.file_uploader("Upload PDF", type=["pdf"], accept_multiple_files=True) + if pdfs: + st.info(f"Total {len(pdfs)} files selected.") + if st.button("Add PDF"): + for pdf in pdfs: + add_pdf_to_index( + pdf=pdf, + clip_model=clip_model, + preprocess=preprocess, + text_embedding_model=text_embedding_model, + ) + st.success("PDF Added to Database") \ No newline at end of file diff --git a/data_upload/input_sources_utils/text_util.py b/data_upload/input_sources_utils/text_util.py new file mode 100644 index 0000000..b2e18fb --- /dev/null +++ b/data_upload/input_sources_utils/text_util.py @@ -0,0 +1,24 @@ +import bs4 +import os +from langchain_text_splitters import CharacterTextSplitter +import requests +import streamlit as st +import sys +from vectordb import add_image_to_index, add_pdf_to_index, update_vectordb + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + + +def process_text(text: str, text_embedding_model): + text_splitter = CharacterTextSplitter( + separator="\n", + chunk_size=1200, + chunk_overlap=200, + length_function=len, + is_separator_regex=False, + ) + chunks = text_splitter.split_text(text) + text_embeddings = text_embedding_model.encode(chunks) + for chunk, embedding in zip(chunks, text_embeddings): + index = update_vectordb(index_path="text_index.index", embedding=embedding, text_content=chunk) + return index diff --git a/data_upload/input_sources_utils/website_util.py b/data_upload/input_sources_utils/website_util.py new file mode 100644 index 0000000..002c2c7 --- /dev/null +++ b/data_upload/input_sources_utils/website_util.py @@ -0,0 +1,56 @@ +import bs4 +import os +import requests +import streamlit as st +import sys +from vectordb import add_image_to_index, add_pdf_to_index +from data_upload.input_sources_utils import text_util + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + + +def data_from_website(clip_model, preprocess, text_embedding_model): + st.title("Data from Website") + website_url = st.text_input("Enter Website URL") + if website_url: + st.write(f"URL: {website_url}") + if st.button("Extract and Add Data"): + response = requests.get(website_url) + if response.status_code == 200: + st.success("Data Extracted Successfully") + else: + st.error("Invalid URL") + + soup = 
+            soup = bs4.BeautifulSoup(response.content, features="lxml")
+            images = soup.find_all("img")
+            image_dict = []
+            if not images:
+                st.info("No Images Found!")
+            else:
+                st.info(f"Found {len(images)} Images")
+            progress_bar = st.progress(0, f"Extracting Images... | 0/{len(images)}")
+            cols = st.columns(5)
+            for count, image in enumerate(images):
+                try:
+                    image_url = image["src"]
+                    if image_url.startswith("//"):
+                        # protocol-relative URL, e.g. //cdn.example.com/pic.png
+                        image_url = "https:" + image_url
+                    response = requests.get(image_url)
+                    if response.status_code == 200:
+                        image_dict.append({"src": image_url, "content": response.content})
+                        add_image_to_index(response.content, clip_model, preprocess)
+                        len_image_dict = len(image_dict)
+                        if len_image_dict <= 4:
+                            with cols[len_image_dict - 1]:
+                                st.image(image_url, caption=image_url, use_container_width=True)
+                        elif len_image_dict == 5:
+                            with cols[4]:
+                                st.info(f"and {len(images) - 4} more images...")
+                except Exception:
+                    pass
+                progress_bar.progress((count + 1) / len(images), f"Extracting Images... | {count + 1}/{len(images)}")
+            progress_bar.empty()
+
+            main_content = soup.find('main') or soup.body  # fall back when there is no <main> tag
+            with st.spinner("Processing Text..."):
+                text_util.process_text(main_content.text, text_embedding_model)
+            st.success("Data Added to Database")
diff --git a/requirements.txt b/requirements.txt
index 5e7172e..4d76931 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@ annotated-types==0.7.0
 anyio==4.7.0
 async-timeout==4.0.3
 attrs==24.3.0
+beautifulsoup4==4.12.3
 blinker==1.9.0
 cachetools==5.5.0
 certifi==2024.12.14
@@ -47,6 +48,7 @@ langchain-core==0.3.28
 langchain-experimental==0.3.4
 langchain-text-splitters==0.3.4
 langsmith==0.1.147
+lxml==5.1.0
 markdown-it-py==3.0.0
 MarkupSafe==3.0.2
 marshmallow==3.23.2
@@ -91,6 +93,7 @@ sentence-transformers==3.3.1
 six==1.17.0
 smmap==5.0.1
 sniffio==1.3.1
+soupsieve==2.6
 SQLAlchemy==2.0.36
 streamlit==1.41.1
 streamlit-option-menu==0.4.0
diff --git a/vectordb.py b/vectordb.py
index 19fbd98..4147668 100644
--- a/vectordb.py
+++ b/vectordb.py
@@ -57,7 +57,10 @@ def update_vectordb(index_path: str, embedding: torch.Tensor, image_path: str =
 
 
 def add_image_to_index(image, model: clip.model.CLIP, preprocess):
-    image_name = image.name
+    if hasattr(image, "name"):
+        image_name = image.name
+    else:
+        image_name = f"{time.time()}.png"
     image_name = image_name.replace(" ", "_")
     os.makedirs("./images", exist_ok=True)
     os.makedirs("./vectorstore", exist_ok=True)
@@ -65,7 +68,10 @@ def add_image_to_index(image, model: clip.model.CLIP, preprocess):
         try:
             f.write(image.read())
         except:
-            image = io.BytesIO(image.data)
+            if hasattr(image, "data"):
+                image = io.BytesIO(image.data)
+            else:
+                image = io.BytesIO(image)
             f.write(image.read())
     image = Image.open(f"./images/{image_name}")
     with torch.no_grad():
@@ -106,7 +112,7 @@ def add_pdf_to_index(pdf, clip_model: clip.model.CLIP, preprocess, text_embeddin
             pdf_texts.append(page_text)
             if page_text != "" or page_text.strip() != "":
                 chunks = text_splitter.split_text(page_text)
-                text_embeddings: torch.Tensor = text_embedding_model.encode(chunks)
+                text_embeddings = text_embedding_model.encode(chunks)
                 for i, chunk in enumerate(chunks):
                     update_vectordb(index_path="text_index.index", embedding=text_embeddings[i], text_content=chunk)
                     pdf_pages_data.append({f"page_number": page_num, "content": chunk, "type": "text"})
@@ -114,6 +120,17 @@ def add_pdf_to_index(pdf, clip_model: clip.model.CLIP, preprocess, text_embeddin
         progress_bar.progress(percent_complete, f"Processing Page {page_num + 1}/{len(pdf_reader.pages)}")
     return pdf_pages_data
 
+def search_image_index_with_image(image_features, index: faiss.IndexFlatL2, clip_model: clip.model.CLIP, k: int = 3):
+    with torch.no_grad():
+        distances, indices = index.search(image_features.detach().float().cpu().numpy(), k)
+    return indices
+
+
+def search_text_index_with_image(text_embeddings, index: faiss.IndexFlatL2, text_embedding_model: SentenceTransformer, k: int = 3):
+    # FAISS expects float32 numpy input, so convert the adapter's torch output
+    distances, indices = index.search(text_embeddings.detach().float().cpu().numpy(), k)
+    return indices
+
 
 def search_image_index(text_input: str, index: faiss.IndexFlatL2, clip_model: clip.model.CLIP, k: int = 3):
     with torch.no_grad():
diff --git a/weights/adapter_model.pt b/weights/adapter_model.pt
new file mode 100644
index 0000000..9aed3b8
Binary files /dev/null and b/weights/adapter_model.pt differ
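Taken together, the pieces above add image-as-query search: CLIP ViT-B/32 encodes the uploaded image into a 512-d vector, the new adapter projects that vector into the 384-d space of the text-embedding model, and each vector then queries its own FAISS index. A minimal end-to-end sketch, assuming the module layout from this diff and that load_text_index mirrors load_image_index (the query image path is illustrative):

import clip
import torch
from PIL import Image

from data_search import adapter_utils
from utils import load_image_index, load_text_index
from vectordb import search_image_index_with_image, search_text_index_with_image

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)
adapter = adapter_utils.load_adapter_model().to(device)

image_index, image_data = load_image_index()
text_index, text_data = load_text_index()

# Encode the query image once and reuse the features for both searches.
image = preprocess(Image.open("query.png")).unsqueeze(0).to(device)
with torch.no_grad():
    image_features = clip_model.encode_image(image)
    adapted = adapter(image_features.float())  # 512-d CLIP space -> 384-d text space

image_hits = search_image_index_with_image(image_features, image_index, clip_model, k=3)
# text_embedding_model is accepted but unused by this helper, so None is fine here
text_hits = search_text_index_with_image(adapted, text_index, None, k=3)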
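The website ingestion path follows the same two-index split: every <img> on the page goes through add_image_to_index, while the page's main text is chunked by text_util.process_text and embedded into the text index. A condensed sketch of the text half, assuming a 384-d sentence-transformer such as all-MiniLM-L6-v2 stands in for the app's text_embedding_model (the URL is illustrative):

import bs4
import requests
from sentence_transformers import SentenceTransformer

from data_upload.input_sources_utils import text_util

html = requests.get("https://example.com/article").content
soup = bs4.BeautifulSoup(html, features="lxml")
main = soup.find("main") or soup.body  # same fallback as website_util

model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed 384-d model
text_util.process_text(main.text, model)  # chunk, embed, and index the text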
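The diff ships weights/adapter_model.pt but no training code. Purely for orientation, and as an assumption rather than the repository's actual recipe, an adapter with this architecture could be fit by regressing CLIP image embeddings onto the text embeddings of paired captions:

import torch
import torch.nn.functional as F

from data_search.adapter_utils import get_adapter_model

def train_adapter(clip_feats: torch.Tensor, text_feats: torch.Tensor,
                  epochs: int = 100, lr: float = 1e-4) -> torch.nn.Module:
    # clip_feats: (N, 512) CLIP image embeddings; text_feats: (N, 384)
    # sentence-transformer embeddings of the matching captions (hypothetical data).
    adapter = get_adapter_model(512, 384)
    opt = torch.optim.Adam(adapter.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        # MSE matches the L2 metric the FAISS text index uses at query time
        loss = F.mse_loss(adapter(clip_feats.float()), text_feats.float())
        loss.backward()
        opt.step()
    torch.save(adapter.state_dict(), "./weights/adapter_model.pt")
    return adapter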