-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR: Search with image and added more input sources
Search with image and added more input sources
- Loading branch information
Showing
11 changed files
with
236 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
def get_adapter_model(in_shape, out_shape):
    """Build the adapter MLP mapping ``in_shape`` features to ``out_shape``.

    The network is three linear layers with a fixed hidden width of 1024 and
    a ReLU after each of the first two layers.
    """
    hidden_width = 1024
    # Layers are created in the same order as their position in the stack so
    # the resulting state-dict keys ("0.*", "2.*", "4.*") are unchanged.
    layers = [
        nn.Linear(in_shape, hidden_width),
        nn.ReLU(),
        nn.Linear(hidden_width, hidden_width),
        nn.ReLU(),
        nn.Linear(hidden_width, out_shape),
    ]
    return nn.Sequential(*layers)
|
||
|
||
def load_adapter_model():
    """Create the 512 -> 384 adapter and populate it with the saved weights.

    The checkpoint is read from ``./weights/adapter_model.pt`` and its tensors
    are mapped onto the CPU before being loaded into the model.
    """
    adapter = get_adapter_model(512, 384)
    cpu_device = torch.device('cpu')
    state = torch.load("./weights/adapter_model.pt", map_location=cpu_device)
    adapter.load_state_dict(state)
    return adapter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,20 @@ | ||
import os | ||
import streamlit as st | ||
import sys | ||
from vectordb import add_image_to_index, add_pdf_to_index | ||
|
||
from data_upload.input_sources_utils import image_util, pdf_util, website_util | ||
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | ||
|
||
|
||
def data_upload(clip_model, preprocess, text_embedding_model):
    """Render the "Data Upload" page and dispatch to the chosen input source.

    The user picks an upload type from a select box; the matching helper in
    ``data_upload.input_sources_utils`` then renders its own widgets and
    performs the ingestion.

    Args:
        clip_model: Image model forwarded unchanged to the source helpers.
        preprocess: Image preprocessing callable forwarded to the helpers.
        text_embedding_model: Text model forwarded to the PDF and website
            helpers (the image helpers do not receive it).
    """
    st.title("Data Upload")
    upload_choice = st.selectbox(
        options=["Upload Image", "Add Image from URL / Link", "Upload PDF", "Website Link"],
        label="Select Upload Type",
    )
    # Each option is handled by a dedicated helper module; the per-source UI
    # is no longer inlined in this function.
    if upload_choice == "Upload Image":
        image_util.upload_image(clip_model, preprocess)
    elif upload_choice == "Add Image from URL / Link":
        image_util.image_from_url(clip_model, preprocess)
    elif upload_choice == "Upload PDF":
        pdf_util.upload_pdf(clip_model, preprocess, text_embedding_model)
    elif upload_choice == "Website Link":
        website_util.data_from_website(clip_model, preprocess, text_embedding_model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import os | ||
import requests | ||
import streamlit as st | ||
import sys | ||
from vectordb import add_image_to_index, add_pdf_to_index | ||
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | ||
|
||
def image_from_url(clip_model, preprocess):
    """Preview an image given by URL and, on request, add it to the index.

    Renders a text input for the URL, shows a preview, and offers an
    "Add Image" button. When pressed, the image bytes are downloaded and
    passed to ``add_image_to_index``.

    Args:
        clip_model: Image model forwarded to ``add_image_to_index``.
        preprocess: Preprocessing callable forwarded to ``add_image_to_index``.
    """
    st.title("Image from URL")
    url = st.text_input("Enter Image URL")
    if not url:
        return

    try:
        st.image(url)
    except Exception:
        # Anything st.image rejects is reported as a bad URL; a bare
        # ``except:`` would also trap BaseException (e.g. KeyboardInterrupt).
        st.error("Invalid URL")
        return

    if not st.button("Add Image"):
        return

    try:
        # A timeout keeps an unresponsive host from hanging the app.
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        # Connection errors, timeouts and malformed URLs are reported to the
        # user instead of surfacing as an unhandled traceback.
        st.error("Invalid URL")
        return

    if response.status_code == 200:
        add_image_to_index(response.content, clip_model, preprocess)
        st.success("Image Added to Database")
    else:
        st.error("Invalid URL")
|
||
def upload_image(clip_model, preprocess):
    """Let the user upload image files and add them to the index.

    Shows a preview of up to four selected images, a note with the number of
    images not previewed, and an "Add Images" button that indexes every
    selected file while updating a progress bar.

    Args:
        clip_model: Image model forwarded to ``add_image_to_index``.
        preprocess: Preprocessing callable forwarded to ``add_image_to_index``.
    """
    st.subheader("Add Image to Database")
    images = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
    if not images:
        return

    preview_count = 4
    cols = st.columns(5, vertical_alignment="center")
    for col, image in zip(cols, images[:preview_count]):
        with col:
            st.image(image)

    # Only ``preview_count`` images are shown, so that is the number to
    # subtract; the fifth column holds the "and more" note.
    hidden = len(images) - preview_count
    if hidden > 0:
        with cols[4]:
            st.info(f"and more {hidden} images...")
    st.info(f"Total {len(images)} files selected.")

    if st.button("Add Images"):
        total = len(images)
        progress_bar = st.progress(0)
        # enumerate gives the position directly instead of searching the list
        # with images.index() on every iteration.
        for done, image in enumerate(images, start=1):
            add_image_to_index(image, clip_model, preprocess)
            progress_bar.progress(done / total, f"{done}/{total}")
        st.success("Images Added to Database")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import os | ||
import streamlit as st | ||
import sys | ||
from vectordb import add_image_to_index, add_pdf_to_index | ||
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | ||
|
||
def upload_pdf(clip_model, preprocess, text_embedding_model):
    """Let the user upload PDF files and add each one to the index.

    Args:
        clip_model: Forwarded to ``add_pdf_to_index``.
        preprocess: Forwarded to ``add_pdf_to_index``.
        text_embedding_model: Forwarded to ``add_pdf_to_index``.
    """
    st.subheader("Add PDF to Database")
    st.warning("Please note that the images in the PDF will also be extracted and added to the database.")
    pdfs = st.file_uploader("Upload PDF", type=["pdf"], accept_multiple_files=True)
    if not pdfs:
        return

    st.info(f"Total {len(pdfs)} files selected.")
    if not st.button("Add PDF"):
        return

    # Settings shared by every add_pdf_to_index call.
    shared_kwargs = {
        "clip_model": clip_model,
        "preprocess": preprocess,
        "text_embedding_model": text_embedding_model,
    }
    for uploaded_pdf in pdfs:
        add_pdf_to_index(pdf=uploaded_pdf, **shared_kwargs)
    st.success("PDF Added to Database")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import bs4 | ||
import os | ||
from langchain_text_splitters import CharacterTextSplitter | ||
import requests | ||
import streamlit as st | ||
import sys | ||
from vectordb import add_image_to_index, add_pdf_to_index, update_vectordb | ||
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | ||
|
||
|
||
def process_text(text: str, text_embedding_model):
    """Split ``text`` into chunks, embed them, and add each to the text index.

    The text is split on newlines into chunks of up to 1200 characters with a
    200-character overlap. Each chunk and its embedding are passed to
    ``update_vectordb`` for the ``text_index.index`` index.

    Args:
        text: Raw text to index.
        text_embedding_model: Object whose ``encode(chunks)`` returns one
            embedding per chunk.

    Returns:
        The value returned by the last ``update_vectordb`` call, or ``None``
        when the text produced no chunks.
    """
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1200,
        chunk_overlap=200,
        length_function=len,
        is_separator_regex=False,
    )
    chunks = text_splitter.split_text(text)

    # Empty or whitespace-only text yields no chunks; without this default the
    # final return would raise UnboundLocalError.
    index = None
    if not chunks:
        return index

    text_embeddings = text_embedding_model.encode(chunks)
    for chunk, embedding in zip(chunks, text_embeddings):
        index = update_vectordb(index_path="text_index.index", embedding=embedding, text_content=chunk)
    return index
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import bs4 | ||
import os | ||
import requests | ||
import streamlit as st | ||
import sys | ||
from vectordb import add_image_to_index, add_pdf_to_index | ||
from data_upload.input_sources_utils import text_util | ||
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | ||
|
||
|
||
def data_from_website(clip_model, preprocess, text_embedding_model):
    """Scrape a web page and add its images and text to the database.

    Renders a URL input and an "Extract and Add Data" button. When pressed,
    the page is downloaded, every ``<img>`` that can be fetched is passed to
    ``add_image_to_index`` (the first four are previewed), and the page text
    is passed to ``text_util.process_text``.

    Args:
        clip_model: Image model forwarded to ``add_image_to_index``.
        preprocess: Preprocessing callable forwarded to ``add_image_to_index``.
        text_embedding_model: Text model forwarded to ``text_util.process_text``.
    """
    # Function-scope import keeps this block self-contained.
    from urllib.parse import urljoin

    st.title("Data from Website")
    website_url = st.text_input("Enter Website URL")
    if not website_url:
        return
    st.write(f"URL: {website_url}")
    if not st.button("Extract and Add Data"):
        return

    try:
        # A timeout keeps an unresponsive host from hanging the app.
        response = requests.get(website_url, timeout=10)
    except requests.RequestException:
        st.error("Invalid URL")
        return
    if response.status_code != 200:
        # Stop here: there is no usable page to parse after a failed request.
        st.error("Invalid URL")
        return
    st.success("Data Extracted Successfully")

    soup = bs4.BeautifulSoup(response.content, features="lxml")
    images = soup.find_all("img")
    if not images:
        st.info("No Images Found!")
    else:
        total = len(images)
        st.info(f"Found {total} Images")
        progress_bar = st.progress(0, f"Extracting Images... | 0/{total}")
        cols = st.columns(5)
        added = 0  # number of images successfully fetched and indexed
        for count, image in enumerate(images):
            src = image.get("src")
            if src:
                # urljoin resolves protocol-relative ("//host/x.png"), relative
                # ("/x.png") and absolute URLs against the page URL without
                # mangling the "//" that follows an existing scheme.
                image_url = urljoin(website_url, src)
                try:
                    image_response = requests.get(image_url, timeout=10)
                    if image_response.status_code == 200:
                        add_image_to_index(image_response.content, clip_model, preprocess)
                        added += 1
                        if added <= 4:
                            with cols[added - 1]:
                                st.image(image_url, caption=image_url, use_container_width=True)
                        elif added == 5:
                            with cols[4]:
                                st.info(f"and more {total - 4} images...")
                except Exception:
                    # Best effort: one unreachable or undecodable image must
                    # not abort the rest of the page. Unlike a bare
                    # ``except:``, this lets KeyboardInterrupt/SystemExit
                    # propagate.
                    pass
            progress_bar.progress((count + 1) / total, f"Extracting Images... | {count + 1}/{total}")
        progress_bar.empty()

    # Not every page has a <main> element; fall back to <body>, then to the
    # whole document, instead of failing on ``None.text``.
    main_content = soup.find("main")
    if main_content is None:
        main_content = soup.body if soup.body is not None else soup
    with st.spinner("Processing Text..."):
        text_util.process_text(main_content.text, text_embedding_model)
    st.success("Data Added to Database")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.