Sriramsr3 committed
Commit cd859d1 · verified · 1 Parent(s): e02fea8

Update rag_chain.py

Files changed (1)
  1. rag_chain.py +93 -91
rag_chain.py CHANGED
@@ -1,91 +1,93 @@
- # rag_chain.py
-
- import os
- import requests
- from dotenv import load_dotenv
- from langchain_community.document_loaders import PyPDFLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.embeddings import SentenceTransformerEmbeddings
- from langchain_community.vectorstores import Chroma
- from langchain.prompts import ChatPromptTemplate
- from langchain.schema.runnable import RunnablePassthrough
- from langchain.schema.output_parser import StrOutputParser
- from tempfile import NamedTemporaryFile
-
- # Load environment variables (HF_TOKEN)
- load_dotenv()
- HF_TOKEN = os.getenv("HF_TOKEN")
-
- # Hugging Face LLaMA 3 API call
- def generate_response(prompt: str) -> str:
-     headers = {"Authorization": f"Bearer {HF_TOKEN}"}
-     payload = {
-         "model": "meta-llama/Llama-3.1-8B-Instruct:novita",
-         "messages": [
-             {"role": "system", "content": "You are a helpful health insurance assistant."},
-             {"role": "user", "content": prompt}
-         ],
-         "max_tokens": 300
-     }
-
-     response = requests.post(
-         "https://router.huggingface.co/v1/chat/completions",  # Use router or correct endpoint
-         headers=headers,
-         json=payload
-     )
-     response.raise_for_status()
-     return response.json()["choices"][0]["message"]["content"]
-
- # Prompt template for RAG
- template = """[INST]
- You are a professional Health Insurance Assistant.
- Provide a short and policy-specific answer in one sentence using only verified content from this policy.
- Do not include any explanations or formatting.
-
- Policy Text:
- {context}
-
- User Question:
- {query}
- [/INST]"""
- prompt = ChatPromptTemplate.from_template(template)
-
- def load_remote_pdf(url: str) -> str:
-     # Optional: basic sanity check (skip .endswith('.pdf'))
-     headers = {
-         "User-Agent": "Mozilla/5.0"
-     }
-     response = requests.get(url, stream=True, headers=headers)
-     response.raise_for_status()
-
-     content_type = response.headers.get("Content-Type", "")
-     if "application/pdf" not in content_type:
-         raise ValueError("URL did not return a PDF file.")
-
-     with NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
-         for chunk in response.iter_content(chunk_size=8192):
-             tmp.write(chunk)
-     return tmp.name  # Return local path to temp PDF
-
- # RAG chain build function
- def build_rag_chain(pdf_path: str):
-     # Load and split PDF
-     loader = PyPDFLoader(pdf_path)
-     docs = loader.load()
-
-     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
-     chunks = splitter.split_documents(docs)
-
-     # Embeddings & Vector Store
-     embeddings = SentenceTransformerEmbeddings(model_name="intfloat/e5-small-v2")
-     vectorstore = Chroma.from_documents(chunks, embedding=embeddings)
-     retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
-
-     # RAG pipeline
-     return (
-         {"context": retriever, "query": RunnablePassthrough()}
-         | prompt
-         | (lambda chat_prompt: generate_response(chat_prompt.to_string()))  # render the prompt to plain text for the API
-         | StrOutputParser()
-     )
-
+ # rag_chain.py
+ import os
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
+
+ import requests
+ from dotenv import load_dotenv
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.embeddings import SentenceTransformerEmbeddings
+ from langchain_community.vectorstores import Chroma
+ from langchain.prompts import ChatPromptTemplate
+ from langchain.schema.runnable import RunnablePassthrough
+ from langchain.schema.output_parser import StrOutputParser
+ from tempfile import NamedTemporaryFile
+
+ # Load environment variables (HF_TOKEN)
+ load_dotenv()
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ # Hugging Face LLaMA 3 API call
+ def generate_response(prompt: str) -> str:
+     headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+     payload = {
+         "model": "meta-llama/Llama-3.1-8B-Instruct:novita",
+         "messages": [
+             {"role": "system", "content": "You are a helpful health insurance assistant."},
+             {"role": "user", "content": prompt}
+         ],
+         "max_tokens": 300
+     }
+
+     response = requests.post(
+         "https://router.huggingface.co/v1/chat/completions",  # Use router or correct endpoint
+         headers=headers,
+         json=payload
+     )
+     response.raise_for_status()
+     return response.json()["choices"][0]["message"]["content"]
+
+ # Prompt template for RAG
+ template = """[INST]
+ You are a professional Health Insurance Assistant.
+ Provide a short and policy-specific answer in one sentence using only verified content from this policy.
+ Do not include any explanations or formatting.
+
+ Policy Text:
+ {context}
+
+ User Question:
+ {query}
+ [/INST]"""
+ prompt = ChatPromptTemplate.from_template(template)
+
+ def load_remote_pdf(url: str) -> str:
+     # Optional: basic sanity check (skip .endswith('.pdf'))
+     headers = {
+         "User-Agent": "Mozilla/5.0"
+     }
+     response = requests.get(url, stream=True, headers=headers)
+     response.raise_for_status()
+
+     content_type = response.headers.get("Content-Type", "")
+     if "application/pdf" not in content_type:
+         raise ValueError("URL did not return a PDF file.")
+
+     with NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
+         for chunk in response.iter_content(chunk_size=8192):
+             tmp.write(chunk)
+     return tmp.name  # Return local path to temp PDF
+
+ # RAG chain build function
+ def build_rag_chain(pdf_path: str):
+     # Load and split PDF
+     loader = PyPDFLoader(pdf_path)
+     docs = loader.load()
+
+     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
+     chunks = splitter.split_documents(docs)
+
+     # Embeddings & Vector Store
+     embeddings = SentenceTransformerEmbeddings(model_name="intfloat/e5-small-v2")
+     vectorstore = Chroma.from_documents(chunks, embedding=embeddings)
+     retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
+
+     # RAG pipeline
+     return (
+         {"context": retriever, "query": RunnablePassthrough()}
+         | prompt
+         | (lambda chat_prompt: generate_response(chat_prompt.to_string()))  # render the prompt to plain text for the API
+         | StrOutputParser()
+     )
+
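For reference, a minimal driver for the updated module might look like the sketch below. This is illustrative and not part of the commit: the policy URL is a placeholder, and the sketch assumes HF_TOKEN is available in the environment (via .env or otherwise). Only load_remote_pdf, build_rag_chain, and LangChain's standard .invoke() come from the file above.

# usage sketch (hypothetical driver; the URL below is a placeholder)
from rag_chain import load_remote_pdf, build_rag_chain

# Download the policy PDF to a temp file; raises ValueError if the URL is not a PDF.
pdf_path = load_remote_pdf("https://example.com/sample-policy.pdf")

# Build the retrieval chain over the downloaded document.
chain = build_rag_chain(pdf_path)

# The dict at the head of the chain fans the question out: the retriever fills
# {context} with the top-5 chunks and RunnablePassthrough fills {query} verbatim.
answer = chain.invoke("What is the waiting period for pre-existing conditions?")
print(answer)

One design note: the retriever returns a list of Document objects, which the prompt template stringifies wholesale (page_content plus metadata) into {context}. A common refinement, not applied in this commit, is to pipe the retriever through a small formatter, e.g.

def format_docs(docs):
    # Keep only the chunk text; drop metadata noise from the prompt.
    return "\n\n".join(doc.page_content for doc in docs)

and build the head of the chain as {"context": retriever | format_docs, "query": RunnablePassthrough()} so the model sees clean policy text.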