File size: 5,206 Bytes
a6c26b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ab5b15
 
a6c26b1
5ab5b15
a6c26b1
 
5ab5b15
 
a6c26b1
 
 
 
 
5ab5b15
a6c26b1
 
5ab5b15
 
 
a6c26b1
aa94ed8
a6c26b1
aa94ed8
a6c26b1
 
5ab5b15
a6c26b1
5ab5b15
 
 
 
aa94ed8
a6c26b1
aa94ed8
a6c26b1
aa94ed8
5ab5b15
 
a6c26b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa94ed8
5ab5b15
aa94ed8
5ab5b15
a84e3d2
a6c26b1
 
a84e3d2
a6c26b1
 
 
 
a84e3d2
 
a6c26b1
 
a84e3d2
 
 
a6c26b1
 
a84e3d2
5ab5b15
aa94ed8
a6c26b1
 
 
 
a84e3d2
a6c26b1
a84e3d2
5ab5b15
a6c26b1
 
 
5ab5b15
a6c26b1
aa94ed8
 
 
a6c26b1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import sys
import yaml
import gradio as gr

# Absolute directory that contains this script; config and data paths are
# resolved relative to it so the app does not depend on the working directory.
current_dir = os.path.split(os.path.abspath(__file__))[0]
print(current_dir)

from src.document_retrieval import DocumentRetrieval
from utils.visual.env_utils import env_input_fields, initialize_env_variables, are_credentials_set, save_credentials
from utils.parsing.sambaparse import parse_doc_universal # added Petro
from utils.vectordb.vector_db import VectorDb

# Paths are anchored to the script directory so the app works from any CWD.
CONFIG_PATH = os.path.join(current_dir, 'config.yaml')
PERSIST_DIRECTORY = os.path.join(current_dir, "data/my-vector-db")

# Chat turns as (question, answer) tuples, appended to by handle_userinput.
# NOTE(review): this list lives at module level, so it is shared by every
# browser session served by this process — consider per-session gr.State.
chat_history = []

def handle_userinput(user_question, conversation, history=None):
    """Answer a user question with the QA chain and record the exchange.

    Args:
        user_question: Text typed by the user; empty or whitespace-only input
            is ignored.
        conversation: The QA chain built by ``process_documents`` (``None``
            until a document has been processed).
        history: Optional list of ``(question, answer)`` tuples to append to.
            Defaults to the module-level ``chat_history``.

    Returns:
        ``(history, "")``: the chat history for the Chatbot component and an
        empty string that clears the input textbox.
    """
    if history is None:
        history = chat_history

    if not user_question or not user_question.strip():
        # Nothing to ask: leave the chat unchanged. Returning a bare string here
        # would replace the Chatbot value, which expects a list of turns.
        return history, ""

    if conversation is None:
        history.append((user_question, "Please process a document before asking questions."))
        return history, ""

    try:
        response = conversation.invoke({"question": user_question})
        answer = response["answer"]
    except Exception as e:
        # Show the failure as the assistant's reply so the chat stays intact.
        answer = f"An error occurred: {e}"

    history.append((user_question, answer))
    return history, ""

def process_documents(files, document_retrieval, vectorstore, conversation, save_location=None):
    """Parse an uploaded document, index it, and build a retrieval QA chain.

    Args:
        files: Value of the ``gr.File`` upload component (falsy when nothing
            has been uploaded).
        document_retrieval: Current DocumentRetrieval session state.
        vectorstore: Current vector store session state.
        conversation: Current QA chain session state.
        save_location: Optional location forwarded to ``create_vector_store``
            as ``output_db``.

    Returns:
        ``(conversation, vectorstore, document_retrieval, status_message)`` in
        the order wired to the UI outputs. On failure the previous session
        state is returned unchanged alongside an error message.
    """
    if not files:
        return conversation, vectorstore, document_retrieval, "Please upload a PDF file before processing."

    try:
        retriever = DocumentRetrieval()
        _, _, text_chunks = parse_doc_universal(doc=files)
        embeddings = retriever.load_embedding_model()
        # Uses the module-level prod_mode (read with a False default) so a
        # config without 'prod_mode' no longer raises KeyError here.
        # In prod mode no collection name is given — presumably
        # create_vector_store picks one; confirm.
        collection_name = default_collection if not prod_mode else None
        new_vectorstore = retriever.create_vector_store(
            text_chunks, embeddings, output_db=save_location, collection_name=collection_name
        )
        retriever.init_retriever(new_vectorstore)
        new_conversation = retriever.get_qa_retrieval_chain()
    except Exception as e:
        # Keep the old, consistent state rather than a half-built one.
        return conversation, vectorstore, document_retrieval, f"An error occurred while processing: {str(e)}"

    return new_conversation, new_vectorstore, retriever, "Complete! You can now ask questions."

def reset_conversation(chat_history=None):
    """Return empty values for the Chatbot and the message textbox.

    Args:
        chat_history: Ignored. It is optional so the function can be wired
            either with the Chatbot as an input or with ``inputs=[]``.
            NOTE: this parameter shadows the module-level ``chat_history``,
            so calling this function does not clear that shared list.

    Returns:
        ``([], "")``: a fresh empty history and an empty input string.
    """
    return [], ""

# Read config file (explicit UTF-8 so parsing does not depend on the locale).
with open(CONFIG_PATH, 'r', encoding='utf-8') as yaml_file:
    # safe_load returns None for an empty file; fall back to an empty mapping
    # so the .get() below does not raise AttributeError.
    config = yaml.safe_load(yaml_file) or {}

prod_mode = config.get('prod_mode', False)
default_collection = 'ekr_default_collection'

# Load env variables
initialize_env_variables(prod_mode)

caution_text = """⚠️ Note: depending on the size of your document, this could take several minutes.
"""

def _clear_chat():
    """Clear the shared chat history, then blank the Chatbot and message box.

    handle_userinput appends each turn to the module-level ``chat_history``
    list and returns it to the Chatbot. Clearing only the Chatbot display
    would make the old turns reappear on the next question, so the list is
    emptied in place.
    """
    chat_history.clear()
    return None, ""


with gr.Blocks() as demo:
    # Per-session state populated by process_documents.
    vectorstore = gr.State()
    conversation = gr.State()
    document_retrieval = gr.State()

    gr.Markdown("# Enterprise Knowledge Retriever", elem_id="title")

    gr.Markdown("Powered by LLama3.1-8B-Instruct on SambaNova Cloud. Get your API key [here](https://cloud.sambanova.ai/apis).")

    # NOTE(review): api_key is not wired into any event handler below, so a
    # key entered here currently has no effect — confirm the intended use.
    api_key = gr.Textbox(label="API Key", type="password", placeholder="(Optional) Enter your API key here for more availability")

    # Step 1: Add PDF file
    gr.Markdown("## 1️⃣ Upload PDF")
    docs = gr.File(label="Add PDF file (single)", file_types=["pdf"], file_count="single")

    # Step 2: Process PDF file
    gr.Markdown("## 2️⃣ Process document and create vector store")
    db_btn = gr.Radio(["ChromaDB"], label="Vector store type", value="ChromaDB", type="index", info="Choose your vector store")
    setup_output = gr.Textbox(label="Processing status", visible=True, value="None")
    process_btn = gr.Button("🔄 Process")
    gr.Markdown(caution_text)

    # Preprocessing events
    process_btn.click(
        process_documents,
        inputs=[docs, document_retrieval, vectorstore, conversation],
        outputs=[conversation, vectorstore, document_retrieval, setup_output],
        concurrency_limit=10,
    )

    # Step 3: Chat with your data
    gr.Markdown("## 3️⃣ Chat with your document")
    chatbot = gr.Chatbot(label="Chatbot", show_label=True, show_share_button=False, show_copy_button=True, likeable=True)
    msg = gr.Textbox(label="Ask questions about your data", show_label=True, placeholder="Enter your message...")
    clear_btn = gr.Button("Clear chat")
    sources_output = gr.Textbox(label="Sources", visible=False)

    # Chatbot events
    msg.submit(handle_userinput, inputs=[msg, conversation], outputs=[chatbot, msg], queue=False)
    # Clear both the displayed chat and the underlying shared history list.
    clear_btn.click(_clear_chat, inputs=None, outputs=[chatbot, msg], queue=False)

def main():
    """Start the Gradio server for the ``demo`` UI."""
    demo.launch()


if __name__ == "__main__":
    main()