"""Streamlit app that answers questions about AI using retrieval over two policy PDFs."""

import asyncio
import logging
import os
from io import BytesIO

import aiohttp
import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from PyPDF2 import PdfReader

logger = logging.getLogger(__name__)

# Set up API key. ChatOpenAI() / OpenAIEmbeddings() below are constructed
# without an explicit key, so they rely on this environment variable.
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]

# Set up prompts
system_template = (
    "You are an AI assistant answering questions about AI. "
    "Use the following context to answer the user's question. "
    "If you cannot find the answer in the context, say you don't know the answer "
    "but you can try to help with related information."
)
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
human_template = "Context:\n{context}\n\nQuestion:\n{question}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
# NOTE: arun_pipeline builds its messages directly from the raw templates above;
# chat_prompt is kept for any external users but is not used in this file.
chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])


class RetrievalAugmentedQAPipeline:
    """Answer questions with an LLM, grounding each answer in retrieved text chunks."""

    def __init__(
        self,
        llm: ChatOpenAI,
        vector_db: Chroma,
        k: int = 2,
        max_context_length: int = 12000,
    ) -> None:
        """Store the LLM and vector store used to answer questions.

        Args:
            llm: Chat model that generates the final answer.
            vector_db: Vector store queried for context chunks.
            k: Number of chunks retrieved per question.
            max_context_length: Maximum number of characters of retrieved
                context placed in the prompt (a character cap, not tokens).
        """
        self.llm = llm
        self.vector_db = vector_db
        self.k = k
        self.max_context_length = max_context_length

    async def arun_pipeline(self, user_query: str, chat_history: list) -> dict:
        """Retrieve context for *user_query* and ask the LLM to answer it.

        Args:
            user_query: The new question.
            chat_history: Prior conversation messages, inserted between the
                system prompt and the new question. It should NOT already
                contain *user_query*, or the model will see it twice.

        Returns:
            ``{"response": <generated answer text>}``.
        """
        context_docs = self.vector_db.similarity_search(user_query, k=self.k)
        context_prompt = "\n".join(doc.page_content for doc in context_docs)
        # Slicing is a no-op when the text is already short enough.
        context_prompt = context_prompt[: self.max_context_length]

        messages = [
            SystemMessage(content=system_template),
            *chat_history,
            HumanMessage(
                content=human_template.format(context=context_prompt, question=user_query)
            ),
        ]
        response = await self.llm.agenerate([messages])
        return {"response": response.generations[0][0].text}


# PDF processing functions
async def fetch_pdf(session, url):
    """Download *url* with *session*; return the raw bytes, or ``None`` on failure.

    Both non-200 responses and client/network errors yield ``None``, so a single
    unreachable document does not abort the other downloads in ``gather``.
    """
    try:
        async with session.get(url) as response:
            if response.status == 200:
                return await response.read()
            logger.warning("Skipping %s: HTTP status %s", url, response.status)
    except (aiohttp.ClientError, asyncio.TimeoutError) as err:
        logger.warning("Skipping %s: %s", url, err)
    return None


async def process_pdf(pdf_content):
    """Extract the text of a PDF given as bytes and split it into ~500-char chunks."""
    pdf_reader = PdfReader(BytesIO(pdf_content))
    # Defensive: treat a page that yields no text (e.g. image-only) as empty
    # so the join never receives a non-string.
    text = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
    return text_splitter.split_text(text)


@st.cache_resource
def initialize_pipeline():
    """Build the pipeline once; st.cache_resource reuses the result across reruns."""
    return asyncio.run(main())


# Main execution
async def main():
    """Download the source PDFs, index their chunks in Chroma, and build the pipeline.

    Raises:
        RuntimeError: if no PDF could be downloaded and parsed, since an empty
            vector store cannot answer anything.
    """
    pdf_urls = [
        "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf",
        "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf",
    ]

    all_chunks = []
    async with aiohttp.ClientSession() as session:
        pdf_contents = await asyncio.gather(*(fetch_pdf(session, url) for url in pdf_urls))

    for pdf_content in pdf_contents:
        if pdf_content:
            all_chunks.extend(await process_pdf(pdf_content))

    if not all_chunks:
        raise RuntimeError(
            "None of the source PDFs could be downloaded and parsed; "
            "cannot build the vector store."
        )

    embeddings = OpenAIEmbeddings()
    vector_db = Chroma.from_texts(all_chunks, embeddings)
    chat_openai = ChatOpenAI()
    return RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)


def _clear_chat_history():
    """Button callback: runs before the rerun, so the page renders an empty history."""
    st.session_state.chat_history = []


def _render_message(message):
    """Write one chat message with a speaker prefix; other message types are skipped."""
    if isinstance(message, HumanMessage):
        st.write("You:", message.content)
    elif isinstance(message, AIMessage):
        st.write("AI:", message.content)


# Streamlit UI
st.title("Ask About AI!")

# Initialize session state for chat history
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

pipeline = initialize_pipeline()

# Reserve space above the form so the history (including the answer produced
# in this run) is rendered in order once the submission has been handled.
history_container = st.container()

# A form with clear_on_submit only reports a submission on the run right after
# the user submits, so later reruns (e.g. the clear button) do not re-ask.
with st.form("question_form", clear_on_submit=True):
    user_query = st.text_input("Enter your question about AI:")
    submitted = st.form_submit_button("Ask")

if submitted and user_query.strip():
    with st.spinner("Generating response..."):
        try:
            # Pass the history *before* this question: arun_pipeline appends
            # the question itself (with retrieved context).
            # NOTE(review): each asyncio.run() creates a new event loop; confirm
            # the LLM's async client tolerates being awaited from a fresh loop.
            result = asyncio.run(
                pipeline.arun_pipeline(user_query, st.session_state.chat_history)
            )
        except Exception as err:  # UI boundary: report instead of crashing the page
            logger.exception("Failed to answer question %r", user_query)
            st.error(f"Sorry, something went wrong while generating a response: {err}")
        else:
            # Record the turn only once it succeeded, so history never holds
            # an unanswered question.
            st.session_state.chat_history.append(HumanMessage(content=user_query))
            st.session_state.chat_history.append(AIMessage(content=result["response"]))

# Display chat history
with history_container:
    for message in st.session_state.chat_history:
        _render_message(message)

# Add a button to clear chat history
st.button("Clear Chat History", on_click=_clear_chat_history)