import gradio as gr
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from src.backend.database.qdrant import QdrantDatabase
from src.frontend.responses import QdrantQueryResponses, QdrantArticleResponse
import torch
import os
# Run the embedding model on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Connection settings are read from the environment; the URL falls back to a
# Qdrant instance on localhost.
qdrant_url = os.environ.get("QDRANT_URL", "http://localhost:6333")
qdrant_api_key = os.environ.get("QDRANT_API_KEY", "")

# An unset/empty QDRANT_API_KEY means "no authentication". Hand the client
# None in that case instead of "", so an empty string is never treated as a
# real API key. The module-level `qdrant_api_key` value itself is unchanged.
client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key or None)

# Sentence-embedding model handed to the database wrapper together with the client.
model = SentenceTransformer(
    "sentence-transformers/multi-qa-mpnet-base-dot-v1",
    device=device,
)

# Shared database wrapper used by query_database() below.
qdrant_database = QdrantDatabase(client, model)
def query_database(query: str, k=5):
    """Search the Qdrant-backed database for ``query``.

    Args:
        query: Free-text search string.
        k: Forwarded to the database as ``docs_per_query``.

    Returns:
        The ``query_responses`` attribute of a ``QdrantQueryResponses`` built
        from the hits stored under ``query`` in the database result.
    """
    # One paragraph per document is requested; the result is indexed by the
    # query text itself.
    hits_by_query = qdrant_database.query(
        query,
        paragraphs_per_document=1,
        docs_per_query=k,
    )
    return QdrantQueryResponses(query, hits_by_query[query]).query_responses
# Button created at import time, before any gr.Blocks context is opened in this
# file. It is not referenced again in the visible source.
# NOTE(review): confirm whether it is still needed.
load_dataset_button = gr.Button("Load Example Dataset")
# Custom stylesheet passed as `css=` to the top-level gr.Blocks further down.
# The selectors target the elem_id / elem_classes values assigned to the
# components of the search UI.
# NOTE(review): the gradient declaration ends in a doubled semicolon and two
# rules are empty -- harmless, but probably unintended.
css = """
.highlight-paragraph {background-color: rgba(167, 246, 243, 0.28);}
gradio-app {height: 100vh;}
#article-add-article {flex: 0 1 0; flex-grow: 0}
#article-add-menu {align-items: center}
#app-tab-section {flex: 1 0 0; flex-grow: 1}
#article-search {display: flex;}
#article-search-query {}
#article-search-btns {}
#article-search-output {flex: 1 0 0; overflow-y: auto;}
.article-full-text {
height: 30vh;
overflow-y: scroll;
scrollbar-width: none;
-ms-overflow-style: none;
}
.article-progress-container {
width: 100%;
height: 8px;
background: #ccc;
}
.article-progress-bar {
height: 8px;
background: linear-gradient(120deg, var(--secondary-600) 0%, var(--primary-500) 60%, var(--primary-600) 100%);;
width: 0%;
}
"""
# Markup passed as `head=` to the top-level gr.Blocks. It currently holds only
# a newline.
# NOTE(review): looks like a placeholder (or its content was lost) -- confirm.
js = """
"""
# HTML snippet rendered via gr.HTML for every article in the search results.
# It currently holds only a newline.
# NOTE(review): presumably meant to contain the .article-progress-* elements
# styled above -- confirm.
progress_bar = """
"""
# Semantic-search UI: a query row, a button row, and a dynamically rendered
# result list. It is built as a standalone Blocks so that it can be embedded
# into the top-level app with `semantic_search.render()`.
with gr.Blocks() as semantic_search:
    topk_default = 5

    # The result list is driven by these states rather than by the input
    # widgets directly, so a search only runs when "Submit" copies the widget
    # values into them (not on every keystroke).
    query_state = gr.State("")
    topk_state = gr.State(topk_default)

    with gr.Row(elem_id="article-search-query"):
        query = gr.Textbox(placeholder="Your query", label="Query")
        # precision=0 makes Gradio round the value and hand it over as an int,
        # so a fractional document count never reaches the database;
        # minimum=1 rejects zero and negative counts in the UI.
        topk = gr.Number(
            topk_default,
            label="Number of returned documents",
            interactive=True,
            precision=0,
            minimum=1,
        )

    with gr.Row(elem_id="article-search-btns"):
        clear_btn = gr.ClearButton()
        load_example_btn = gr.Button("Load Example")
        submit_btn = gr.Button("Submit", variant="primary")

    with gr.Row(elem_id="article-search-output"):

        @gr.render(inputs=[query_state, topk_state])
        def render_output(query, k):
            """Build one result card per article returned for ``query``.

            Nothing is rendered inside the container while ``query`` is empty.
            """
            with gr.Group(elem_id="article-container"):
                if query:
                    for article in query_database(query, k=k):
                        with gr.Group(elem_classes="article-container-item"):
                            gr.HTML(article.article_link)
                            gr.HTML(article.html_most_relevant_paragraph)
                            with gr.Accordion(label="View full article", open=False):
                                gr.HTML(
                                    article.html_article,
                                    elem_classes=["article-full-text"],
                                )
                                # NOTE(review): assumed to sit inside the
                                # accordion next to the full text -- confirm.
                                gr.HTML(progress_bar)

    # Clearing resets both the textbox and the submitted query, which also
    # empties the rendered result list.
    clear_btn.add([query, query_state])

    # Fill the textbox with a sample query; the user still has to press Submit.
    load_example_btn.click(
        lambda: "venous thrombosis",
        inputs=[],
        outputs=[query],
        show_progress="hidden",
    )

    # Copy the current widget values into the states that drive render_output.
    submit_btn.click(
        lambda x, y: (x, y),
        inputs=[query, topk],
        outputs=[query_state, topk_state],
    )
# Top-level application: applies the custom stylesheet / head markup and the
# Ocean theme, then embeds the semantic-search UI defined above.
scientific_papers_demo = gr.Blocks(
    css=css,
    head=js,
    theme=gr.themes.Ocean(),
    fill_height=True,
)
with scientific_papers_demo:
    # Kept as a single valid literal; the line break that used to split the
    # string across two source lines is written as an explicit "\n".
    gr.HTML("Find relevant articles using text queries\n")
    semantic_search.render()

    # TODO: the tabbed layout below does not work (cause not yet investigated),
    # so the search UI is rendered directly above. Kept for reference:
    #
    # with gr.Tabs():
    #     with gr.TabItem("Semantic Search", elem_id="article-search"):
    #         semantic_search.render()
    #     with gr.TabItem("RAG", elem_id="article-rag"):
    #         pass

if __name__ == "__main__":
    # load_dataset_sample()
    scientific_papers_demo.launch()