DHEIVER committed on
Commit
2fd6061
·
verified ·
1 Parent(s): 2913a25

Update app.py

Files changed (1)
  1. app.py +133 -1
app.py CHANGED
@@ -1,3 +1,132 @@
+import gradio as gr
+import os
+api_token = os.getenv("API_TOKEN")
+
+from langchain_community.vectorstores import FAISS
+from langchain_community.document_loaders import PyPDFLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.chains import ConversationalRetrievalChain
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain.memory import ConversationBufferMemory
+from langchain_community.llms import HuggingFaceEndpoint
+import torch
+
+list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
+list_llm_simple = [os.path.basename(llm) for llm in list_llm]
+
+# Load and split PDF document
+def load_doc(list_file_path):
+    loaders = [PyPDFLoader(x) for x in list_file_path]
+    pages = []
+    for loader in loaders:
+        pages.extend(loader.load())
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=1024,
+        chunk_overlap=64
+    )
+    doc_splits = text_splitter.split_documents(pages)
+    return doc_splits
+
+# Create vector database
+def create_db(splits):
+    embeddings = HuggingFaceEmbeddings()
+    vectordb = FAISS.from_documents(splits, embeddings)
+    return vectordb
+
+# Initialize langchain LLM chain
+def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+    if llm_model == "meta-llama/Meta-Llama-3-8B-Instruct":
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            huggingfacehub_api_token=api_token,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
+            top_k=top_k,
+        )
+    else:
+        llm = HuggingFaceEndpoint(
+            huggingfacehub_api_token=api_token,
+            repo_id=llm_model,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
+            top_k=top_k,
+        )
+
+    memory = ConversationBufferMemory(
+        memory_key="chat_history",
+        output_key='answer',
+        return_messages=True
+    )
+
+    retriever = vector_db.as_retriever()
+    qa_chain = ConversationalRetrievalChain.from_llm(
+        llm,
+        retriever=retriever,
+        chain_type="stuff",
+        memory=memory,
+        return_source_documents=True,
+        verbose=False,
+    )
+    return qa_chain
+
+# Initialize database
+def initialize_database(list_file_obj, progress=gr.Progress()):
+    list_file_path = [x.name for x in list_file_obj if x is not None]
+    doc_splits = load_doc(list_file_path)
+    vector_db = create_db(doc_splits)
+    return vector_db, "Database created!"
+
+# Initialize LLM
+def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+    llm_name = list_llm[llm_option]
+    print("llm_name: ", llm_name)
+    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
+    return qa_chain, "QA chain initialized. Chatbot is ready!"
+
+def format_chat_history(message, chat_history):
+    formatted_chat_history = []
+    for user_message, bot_message in chat_history:
+        formatted_chat_history.append(f"User: {user_message}")
+        formatted_chat_history.append(f"Assistant: {bot_message}")
+    return formatted_chat_history
+
+# Updated conversation function with language selection
+def conversation(qa_chain, message, history, language):
+    formatted_chat_history = format_chat_history(message, history)
+
+    # Prepend language instruction to the message
+    if language == "Portuguese":
+        language_instruction = "Responda em português: "
+    else:  # Default to English
+        language_instruction = "Answer in English: "
+
+    modified_message = language_instruction + message
+
+    # Generate response using QA chain
+    response = qa_chain.invoke({"question": modified_message, "chat_history": formatted_chat_history})
+    response_answer = response["answer"]
+    if response_answer.find("Helpful Answer:") != -1:
+        response_answer = response_answer.split("Helpful Answer:")[-1]
+    response_sources = response["source_documents"]
+    response_source1 = response_sources[0].page_content.strip()
+    response_source2 = response_sources[1].page_content.strip()
+    response_source3 = response_sources[2].page_content.strip()
+    # Langchain sources are zero-based
+    response_source1_page = response_sources[0].metadata["page"] + 1
+    response_source2_page = response_sources[1].metadata["page"] + 1
+    response_source3_page = response_sources[2].metadata["page"] + 1
+
+    # Append user message and response to chat history
+    new_history = history + [(message, response_answer)]
+    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
+
+def upload_file(file_obj):
+    list_file_path = []
+    for idx, file in enumerate(file_obj):
+        file_path = file.name
+        list_file_path.append(file_path)
+    return list_file_path
+
 def demo():
     with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="sky")) as demo:
         vector_db = gr.State()
@@ -70,4 +199,7 @@ def demo():
         submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot, language_selector], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
         clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
 
-demo.queue().launch(debug=True)
+    demo.queue().launch(debug=True)
+
+if __name__ == "__main__":
+    demo()
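
For review, the pipeline added in this commit can be smoke-tested without the Gradio UI. A minimal sketch, assuming app.py is importable from the working directory, API_TOKEN is exported in the environment (app.py reads it via os.getenv), and sample.pdf is a placeholder path:

import os
assert os.getenv("API_TOKEN"), "export API_TOKEN before running"

from app import load_doc, create_db, initialize_llmchain, list_llm

# Index a placeholder PDF, build the conversational chain, ask one question.
splits = load_doc(["sample.pdf"])
vector_db = create_db(splits)
qa_chain = initialize_llmchain(list_llm[0], temperature=0.5, max_tokens=512,
                               top_k=3, vector_db=vector_db)
response = qa_chain.invoke({"question": "Answer in English: What is this document about?",
                            "chat_history": []})
print(response["answer"])
print(response["source_documents"][0].metadata["page"] + 1)  # metadata pages are zero-based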