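"""Streamlit RAG demo: answers questions about two public AI-governance PDFs
(the Blueprint for an AI Bill of Rights and NIST AI 600-1) by retrieving
relevant chunks from a Chroma vector store and passing them to an OpenAI
chat model."""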
import streamlit as st
import asyncio
import os
from io import BytesIO

import aiohttp
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_community.vectorstores import Chroma
# ChatOpenAI and OpenAIEmbeddings now live in the langchain-openai package
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
# Load the OpenAI API key from Streamlit secrets
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
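# (Locally the secret lives in .streamlit/secrets.toml; on a hosted Space
# it is set in the app's secrets settings.)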
# Set up prompts
system_template = (
    "Use the following context to answer a user's question. "
    "If you cannot find the answer in the context, say you don't know the answer."
)
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
human_template = "Context:\n{context}\n\nQuestion:\n{question}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
# Retrieval-augmented QA pipeline: retrieve relevant chunks, then answer
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db: Chroma) -> None:
        self.llm = llm
        self.vector_db = vector_db

    async def arun_pipeline(self, user_query: str):
        # Retrieve the two chunks most similar to the query
        context_docs = self.vector_db.similarity_search(user_query, k=2)
        context_list = [doc.page_content for doc in context_docs]
        context_prompt = "\n".join(context_list)

        # Cap the context by character count (a rough proxy for tokens)
        max_context_length = 12000
        if len(context_prompt) > max_context_length:
            context_prompt = context_prompt[:max_context_length]

        messages = chat_prompt.format_prompt(context=context_prompt, question=user_query).to_messages()
        response = await self.llm.agenerate([messages])
        return {"response": response.generations[0][0].text}
# PDF processing functions
async def fetch_pdf(session, url):
    # Download a PDF; return its raw bytes, or None on a non-200 response
    async with session.get(url) as response:
        if response.status == 200:
            return await response.read()
        return None
async def process_pdf(pdf_content):
    # Extract the text of every page (extract_text() may return None for
    # image-only pages) and split it into overlapping chunks for embedding
    pdf_reader = PdfReader(BytesIO(pdf_content))
    text = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
    return text_splitter.split_text(text)
@st.cache_resource
def initialize_pipeline():
    # Build the pipeline once and reuse it across Streamlit reruns, so the
    # PDFs are not re-downloaded and re-embedded on every interaction
    return asyncio.run(main())
# Main execution: download the source PDFs, chunk them, and build the vector store
async def main():
    pdf_urls = [
        "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf",
        "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf",
    ]
    all_chunks = []
    async with aiohttp.ClientSession() as session:
        # Fetch all PDFs concurrently
        pdf_contents = await asyncio.gather(*[fetch_pdf(session, url) for url in pdf_urls])
    for pdf_content in pdf_contents:
        if pdf_content:
            chunks = await process_pdf(pdf_content)
            all_chunks.extend(chunks)
    embeddings = OpenAIEmbeddings()
    vector_db = Chroma.from_texts(all_chunks, embeddings)
    chat_openai = ChatOpenAI()
    return RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)
# Streamlit UI
st.title("Ask About AI!")
pipeline = initialize_pipeline()
user_query = st.text_input("Enter your question about AI:")
if user_query:
    with st.spinner("Generating response..."):
        result = asyncio.run(pipeline.arun_pipeline(user_query))
    st.write("Response:")
    st.write(result["response"])
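# To run locally (assuming the script is saved as app.py and dependencies
# such as streamlit, aiohttp, PyPDF2, chromadb, langchain, and
# langchain-openai are installed):
#   streamlit run app.py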