"""Gradio front end for semantic search over articles stored in Qdrant.

Builds a `semantic_search` Blocks section (query box, result count, buttons
and a dynamically rendered result list) and mounts it inside the top-level
`scientific_papers_demo` app, which is launched when run as a script.
"""

import os

import gradio as gr
import torch
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer

from src.backend.database.qdrant import QdrantDatabase
from src.frontend.responses import QdrantQueryResponses, QdrantArticleResponse

# Number of documents returned when no explicit count is supplied.
DEFAULT_TOP_K = 5

device = 'cuda' if torch.cuda.is_available() else 'cpu'

qdrant_url = os.environ.get("QDRANT_URL", "http://localhost:6333")
qdrant_api_key = os.environ.get("QDRANT_API_KEY", "")

# Hand the client None rather than a blank string when no key is configured,
# so it is not given an empty credential.
client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key or None)
model = SentenceTransformer(
    "sentence-transformers/multi-qa-mpnet-base-dot-v1",
    device=device,
)
qdrant_database = QdrantDatabase(client, model)


def query_database(query: str, k=5):
    """Run a semantic search for *query* and return the wrapped responses.

    Args:
        query: Free-text query forwarded to `qdrant_database.query`.
        k: Number of documents to request. The UI number field can deliver a
            float or `None` (cleared field); the value is coerced to `int`,
            and `None` falls back to `DEFAULT_TOP_K`.

    Returns:
        The `query_responses` attribute of the `QdrantQueryResponses` built
        from the database result for this query.
    """
    docs_per_query = DEFAULT_TOP_K if k is None else int(k)
    results_by_query = qdrant_database.query(
        query,
        paragraphs_per_document=1,
        docs_per_query=docs_per_query,
    )
    # The database result is keyed by the query text.
    hits = results_by_query[query]
    return QdrantQueryResponses(query, hits).query_responses


# NOTE(review): created outside any Blocks context and not referenced again
# in this file; kept because other modules may import it.
load_dataset_button = gr.Button("Load Example Dataset")

css = """
.highlight-paragraph {background-color: rgba(167, 246, 243, 0.28);}
gradio-app {height: 100vh;}
#article-add-article {flex: 0 1 0; flex-grow: 0}
#article-add-menu {align-items: center}
#app-tab-section {flex: 1 0 0; flex-grow: 1}
#article-search {display: flex;}
#article-search-query {}
#article-search-btns {}
#article-search-output {flex: 1 0 0; overflow-y: auto;}
.article-full-text {
    height: 30vh;
    overflow-y: scroll;
    scrollbar-width: none;
    -ms-overflow-style: none;
}
.article-progress-container {
    width: 100%;
    height: 8px;
    background: #ccc;
}
.article-progress-bar {
    height: 8px;
    background: linear-gradient(120deg, var(--secondary-600) 0%, var(--primary-500) 60%, var(--primary-600) 100%);;
    width: 0%;
}
"""

# NOTE(review): this head snippet is whitespace-only as seen here; any
# script markup it was meant to carry is not visible - confirm.
js = """ """

# NOTE(review): this fragment is empty as seen here although the CSS above
# styles .article-progress-container / .article-progress-bar - confirm the
# intended markup.
progress_bar = """
"""

with gr.Blocks() as semantic_search:
    topk_default = DEFAULT_TOP_K
    # The states hold the *submitted* values, so results re-render only when
    # "Submit" is clicked, not on every keystroke.
    query_state = gr.State("")
    topk_state = gr.State(topk_default)

    with gr.Row(elem_id="article-search-query"):
        query = gr.Textbox(placeholder="Your query", label="Query")
        # precision=0 makes the component deliver an integer count.
        topk = gr.Number(
            topk_default,
            label="Number of returned documents",
            interactive=True,
            precision=0,
        )

    with gr.Row(elem_id="article-search-btns"):
        clear_btn = gr.ClearButton()
        load_example_btn = gr.Button("Load Example")
        submit_btn = gr.Button("Submit", variant="primary")

    with gr.Row(elem_id="article-search-output"):

        @gr.render(inputs=[query_state, topk_state])
        def render_output(query, k):
            """Render one group per article returned for the submitted query."""
            with gr.Group(elem_id="article-container"):
                if query:
                    for article in query_database(query, k=k):
                        with gr.Group(elem_classes="article-container-item"):
                            gr.HTML(article.article_link)
                            gr.HTML(article.html_most_relevant_paragraph)
                            with gr.Accordion(label="View full article", open=False):
                                gr.HTML(
                                    article.html_article,
                                    elem_classes=["article-full-text"],
                                )
                                # NOTE(review): nesting level of the progress
                                # bar is assumed (inside the accordion, under
                                # the full text) - confirm.
                                gr.HTML(progress_bar)

    clear_btn.add([query, query_state])
    load_example_btn.click(
        lambda: "venous thrombosis",
        inputs=[],
        outputs=[query],
        show_progress="hidden",
    )
    # Copy the current inputs into the states that drive render_output.
    submit_btn.click(
        lambda x, y: (x, y),
        inputs=[query, topk],
        outputs=[query_state, topk_state],
    )

scientific_papers_demo = gr.Blocks(
    css=css,
    head=js,
    theme=gr.themes.Ocean(),
    fill_height=True,
)

with scientific_papers_demo:
    # NOTE(review): the heading carries no markup as seen here - confirm
    # whether surrounding tags are expected.
    gr.HTML("\n\nFind relevant articles using text queries\n\n")
    semantic_search.render()
    # this does not work for some reason
    # with gr.Tabs():
    #     with gr.TabItem("Semantic Search", elem_id="article-search"):
    #         semantic_search.render()
    #     with gr.TabItem("RAG", elem_id="article-rag"):
    #         pass

if __name__ == "__main__":
    # load_dataset_sample()
    scientific_papers_demo.launch()