Gopikanth123 committed on
Commit
34cc5b3
·
verified ·
1 Parent(s): 3e21064

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +77 -211
main.py CHANGED
@@ -1,211 +1,77 @@
1
- import os
2
- import shutil
3
- from flask import Flask, render_template, request, jsonify
4
- from whoosh.index import create_in, open_dir
5
- from whoosh.fields import Schema, TEXT
6
- from whoosh.qparser import QueryParser
7
- from transformers import AutoTokenizer, AutoModel
8
- from deep_translator import GoogleTranslator
9
-
10
# Ensure the necessary directories exist
PERSIST_DIR = "db"
PDF_DIRECTORY = 'data'
os.makedirs(PDF_DIRECTORY, exist_ok=True)
os.makedirs(PERSIST_DIR, exist_ok=True)

# Load the XLM-R tokenizer and model.
# BUG FIX: the original fused the `model = ...` assignment with an
# accidentally pasted chat-assistant response (prose plus a ```python
# markdown fence), which is a SyntaxError. The assignment is restored here.
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
model = AutoModel.from_pretrained("xlm-roberta-base")
23
- import os
24
- import shutil
25
- import torch
26
- from flask import Flask, render_template, request, jsonify
27
- from whoosh.index import create_in, open_dir
28
- from whoosh.fields import Schema, TEXT
29
- from whoosh.qparser import QueryParser
30
- from transformers import AutoTokenizer, AutoModel
31
- from deep_translator import GoogleTranslator
32
-
33
# Working directories for raw documents and the Whoosh index.
PERSIST_DIR = "db"
PDF_DIRECTORY = 'data'
os.makedirs(PDF_DIRECTORY, exist_ok=True)
os.makedirs(PERSIST_DIR, exist_ok=True)

# XLM-R tokenizer/model, used by get_embeddings() for optional embeddings.
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
model = AutoModel.from_pretrained("xlm-roberta-base")

# Whoosh schema: both fields are stored so search hits can return them.
schema = Schema(title=TEXT(stored=True), content=TEXT(stored=True))
45
-
46
# Create the Whoosh index storage on first run, or reopen an existing one.
def create_index():
    """Return a Whoosh index handle, creating the storage dir if needed."""
    if os.path.exists(PERSIST_DIR):
        return open_dir(PERSIST_DIR)
    os.makedirs(PERSIST_DIR)
    return create_in(PERSIST_DIR, schema)

index = create_index()
55
-
56
# Function to load documents from a directory
def load_documents(directory=None):
    """Load every .txt file from *directory*.

    Args:
        directory: folder to scan; defaults to the module-level
            PDF_DIRECTORY when None, so existing call sites are unchanged.

    Returns:
        list[dict]: one {'title': filename, 'content': text} entry per file.
    """
    if directory is None:
        directory = PDF_DIRECTORY
    documents = []
    for filename in os.listdir(directory):
        if filename.endswith(".txt"):  # Assuming documents are in .txt format
            with open(os.path.join(directory, filename), 'r', encoding='utf-8') as file:
                content = file.read()
            documents.append({'title': filename, 'content': content})
            # BUG FIX: the original f-string printed the literal text
            # "(unknown)"; report the file that was actually loaded.
            print(f"Loaded document: {filename}")  # Debugging line
    return documents
66
-
67
# Write a batch of documents into the global Whoosh index.
def index_documents(documents):
    """Add each {'title', 'content'} dict in *documents* to the index."""
    writer = index.writer()
    for entry in documents:
        writer.add_document(title=entry['title'], content=entry['content'])
    writer.commit()
73
-
74
# Rebuild the Whoosh index from scratch using the files currently on disk.
def data_ingestion_from_directory():
    """Wipe the persisted index, reload documents, and re-index them."""
    global index

    # Drop any previously persisted index data before rebuilding.
    if os.path.exists(PERSIST_DIR):
        shutil.rmtree(PERSIST_DIR)
    os.makedirs(PERSIST_DIR, exist_ok=True)

    new_documents = load_documents()
    if not new_documents:
        print("No documents found to index.")
        return

    # Fresh index over the freshly loaded documents.
    index = create_index()
    index_documents(new_documents)
92
-
93
# Search the Whoosh index and return matching (title, content) pairs.
def retrieve_documents(query):
    """Parse *query* against the 'content' field and run the search."""
    parser = QueryParser("content", index.schema)
    parsed = parser.parse(query)
    with index.searcher() as searcher:
        hits = searcher.search(parsed)
        # Materialize results while the searcher is still open.
        return [(hit['title'], hit['content']) for hit in hits]
100
-
101
# Compute a mean-pooled XLM-R embedding for a piece of text.
# (Unused by the request path; kept as a utility.)
def get_embeddings(text):
    """Return a 1-D numpy vector: mean of the last hidden state's tokens."""
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        output = model(**encoded)
        pooled = output.last_hidden_state.mean(dim=1)  # Average pooling
    return pooled.squeeze().numpy()
108
-
109
# Turn a user query into a text answer built from retrieved documents.
def handle_query(query):
    """Search the index and format the top hits into a reply string."""
    retrieved_docs = retrieve_documents(query)
    if not retrieved_docs:
        return "Sorry, I couldn't find an answer."

    # One snippet per hit: title plus the first 100 chars of content.
    snippets = [
        f"Title: {title}\nContent: {content[:100]}..."
        for title, content in retrieved_docs
    ]
    return "Here are some insights based on your query:\n" + "\n".join(snippets)
121
-
122
# Flask application instance.
app = Flask(__name__)

# Build the search index once at import/startup time.
data_ingestion_from_directory()
127
-
128
# Answer a query, then translate the reply into the requested language.
def generate_response(query, language):
    """Run handle_query and translate the result when *language* is supported.

    Returns the (possibly translated) reply string; on any failure, returns
    an error string instead of raising.
    """
    try:
        bot_response = handle_query(query)

        # Supported language names mapped to translator target codes.
        supported_languages = {
            "hindi": "hi",
            "bengali": "bn",
            "telugu": "te",
            "marathi": "mr",
            "tamil": "ta",
            "gujarati": "gu",
            "kannada": "kn",
            "malayalam": "ml",
            "punjabi": "pa",
            "odia": "or",
            "urdu": "ur",
            "assamese": "as",
            "sanskrit": "sa",
            "arabic": "ar",
            "australian": "en-AU",
            "bangla-india": "bn-IN",
            "chinese": "zh-CN",
            "dutch": "nl",
            "french": "fr",
            "filipino": "tl",
            "greek": "el",
            "indonesian": "id",
            "italian": "it",
            "japanese": "ja",
            "korean": "ko",
            "latin": "la",
            "nepali": "ne",
            "portuguese": "pt",
            "romanian": "ro",
            "russian": "ru",
            "spanish": "es",
            "swedish": "sv",
            "thai": "th",
            "ukrainian": "uk",
            "turkish": "tr"
        }

        # Default: pass the answer through untranslated.
        translated_text = bot_response
        try:
            target_lang = supported_languages.get(language)
            if target_lang is None:
                print(f"Unsupported language: {language}")
            else:
                translated_text = GoogleTranslator(source='auto', target=target_lang).translate(bot_response)
        except Exception as e:
            print(f"Translation error: {e}")
            translated_text = "Sorry, I couldn't translate the response."

        return translated_text
    except Exception as e:
        return f"Error fetching the response: {str(e)}"
190
-
191
# Homepage: serve the chat front-end template.
@app.route('/')
def index():
    """Render the chat UI."""
    return render_template('index.html')
195
-
196
# POST /chat: JSON {"message", "language"} in, JSON {"response"} out.
@app.route('/chat', methods=['POST'])
def chat():
    """Handle one chat turn; never raises — errors become JSON responses."""
    try:
        payload = request.json
        user_message = payload.get("message")
        language = payload.get("language")
        if not user_message:
            return jsonify({"response": "Please say something!"})

        return jsonify({"response": generate_response(user_message, language)})
    except Exception as e:
        return jsonify({"response": f"An error occurred: {str(e)}"})
209
-
210
# Dev-server entry point (Flask debug mode).
if __name__ == '__main__':
    app.run(debug=True)
 
1
import os
import shutil

from deep_translator import GoogleTranslator
from flask import Flask, request, jsonify
from huggingface_hub import InferenceClient
from llama_index import SimpleDirectoryReader, StorageContext, VectorStoreIndex, load_index_from_storage, ChatPromptTemplate
from transformers import AutoTokenizer, AutoModel
7
+
8
# Fail fast if the Hugging Face token is missing from the environment.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable not set.")

# Hugging Face model configuration
REPO_ID = "facebook/xlm-roberta-xl"  # Use xlm-roberta-xl model
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
model = AutoModel.from_pretrained(REPO_ID)

# Flask app
app = Flask(__name__)

# Storage locations: source documents and the persisted vector index.
PERSIST_DIR = "db"
PDF_DIRECTORY = "data"
for _path in (PDF_DIRECTORY, PERSIST_DIR):
    os.makedirs(_path, exist_ok=True)

# In-memory conversation log: [{"user": ..., "bot": ...}, ...]
chat_history = []
29
+
30
+ # Function to ingest documents
31
+ def data_ingestion_from_directory():
32
+ if os.path.exists(PERSIST_DIR):
33
+ os.system(f"rm -rf {PERSIST_DIR}") # Clear previous data
34
+ os.makedirs(PERSIST_DIR, exist_ok=True)
35
+
36
+ documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
37
+ index = VectorStoreIndex.from_documents(documents)
38
+ index.storage_context.persist(persist_dir=PERSIST_DIR)
39
+
40
# Function to handle queries
def handle_query(query):
    """Answer *query* from the persisted vector index.

    Reloads the index from PERSIST_DIR on every call and queries it with a
    simple user/assistant chat prompt template.
    """
    ctx = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    vector_index = load_index_from_storage(ctx)
    engine = vector_index.as_query_engine()

    prompt = ChatPromptTemplate.from_messages([
        ("user", "User asked: {query_str}"),
        ("assistant", "Answer: {response}"),
    ])

    # NOTE(review): passing prompt_template= to query() mirrors the original;
    # confirm this kwarg against the installed llama_index query-engine API.
    answer = engine.query(query, prompt_template=prompt)
    return answer.response if hasattr(answer, 'response') else "No relevant answer found."
53
+
54
# Route for homepage
@app.route("/")
def index():
    """Plain-text landing message confirming the service is up."""
    return "Welcome to the RAG Application using xlm-roberta-xl!"
58
+
59
# POST /chat: expects JSON {"message": ...}; returns JSON {"response": ...}.
@app.route("/chat", methods=["POST"])
def chat():
    """Handle one chat turn; errors are returned as JSON, never raised."""
    try:
        user_message = request.json.get("message")
        if not user_message:
            return jsonify({"response": "Please provide a message!"})

        answer = handle_query(user_message)
        # Keep a simple in-memory transcript of the conversation.
        chat_history.append({"user": user_message, "bot": answer})
        return jsonify({"response": answer})
    except Exception as e:
        return jsonify({"response": f"An error occurred: {str(e)}"})
73
+
74
+ if __name__ == "__main__":
75
+ # Ingest data before starting the app
76
+ data_ingestion_from_directory()
77
+ app.run(debug=True)