# Imports: langchain components for document loading, splitting, embedding,
# vector storage, retrieval-augmented QA, and conversational memory.
import os

from dotenv import load_dotenv
from langchain.document_loaders import PyPDFLoader
# from langchain.chat_models import ChatCohere
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
from langchain.vectorstores import FAISS, Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.llms import HuggingFaceTextGenInference
from langchain.chains.conversation.memory import (
    ConversationBufferMemory,
    ConversationBufferWindowMemory,
)

# Set API keys: load them from a local env file rather than hardcoding.
load_dotenv("API.env")  # put all the API tokens here, such as openai, huggingface...
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
# Set inference link; use this hosted endpoint for easier reproduction.
inference_api_url = 'https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta'
# Recommend using better LLMs, such as Mixtral 7x8B

llm = HuggingFaceTextGenInference(
    verbose=True,         # Provides detailed logs of operation
    max_new_tokens=1024,  # Maximum number of tokens that can be generated
    top_p=0.95,           # Nucleus-sampling threshold controlling randomness
    typical_p=0.95,       # Typical-decoding probability mass
    temperature=0.1,      # Low temperature: prefer high-probability words
    inference_server_url=inference_api_url,  # URL of the inference server
    timeout=120,          # Timeout (seconds) for the connection to the endpoint
)

# Alternatively, you can load a model locally, e.g.:
# model_path = "where/you/store/local/models/zephyr-7b-beta"  # change this to your model path
# model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
# tokenizer = AutoTokenizer.from_pretrained(model_path)
# pipe = pipeline(
#     "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024, model_kwargs={"temperature": 0.1}
# )
# llm = HuggingFacePipeline(pipeline=pipe)


def load_pdf_as_docs(pdf_path, loader_module=None):
    """Read a single PDF file, or every PDF in a directory, as langchain documents.

    Args:
        pdf_path: Path to a ``.pdf`` file, or to a directory containing PDFs.
        loader_module: Loader class instantiated once per file; defaults to
            ``PyPDFLoader``.

    Returns:
        A flat list of loaded documents (``loader.load()`` output, concatenated).
    """
    if pdf_path.endswith('.pdf'):  # single file
        pdf_docs = [pdf_path]
    else:  # a directory
        # sorted(): os.listdir returns files in arbitrary order, which would make
        # the corpus (and everything built from it) differ between runs.
        pdf_docs = sorted(
            os.path.join(pdf_path, f)
            for f in os.listdir(pdf_path)
            if f.endswith('.pdf')
        )

    if loader_module is None:  # default PDF loader
        loader_module = PyPDFLoader

    docs = []
    for pdf in pdf_docs:
        loader = loader_module(pdf)
        docs.extend(loader.load())

    return docs


def get_doc_chunks(docs, splitter=None):
    """Split docs into overlapping chunks suitable for embedding/retrieval.

    Defaults to a recursive character splitter (chunk_size=256, overlap=128).
    """
    if splitter is None:
        splitter = RecursiveCharacterTextSplitter(
            separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=128
        )
    return splitter.split_documents(docs)


# Specify the directory containing your PDFs
# directory = "C:\\Orga\\FestBatt\\FB2\\LISA\\Literature"
directory = "FestbattLiterature"  # change to your PDF directory

# Find and parse all PDFs in the directory
pdf_docs = load_pdf_as_docs(directory, PyPDFLoader)

document_chunks = get_doc_chunks(pdf_docs)

# Set embedding model (choose the one you like)
embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-base-en-v1.5')
# Set vectorstore, e.g. FAISS.
texts = ["LISA - Lithium Ion Solid-state Assistant"]
# Workaround: FAISS cannot be initialized via FAISS(embedding_function=embeddings),
# so seed it with one dummy text; waiting for a Langchain fix.
vectorstore = FAISS.from_texts(texts, embeddings)
# You may also use Chroma
# vectorstore = Chroma(embedding_function=embeddings)

# Create retrievers

# Some advanced RAG, with parent document retriever, hybrid-search and rerank

# 1. ParentDocumentRetriever. Note: this will take a long time (~several minutes)

from langchain.storage import InMemoryStore
from langchain.retrievers import ParentDocumentRetriever
# For local storage, ref: https://stackoverflow.com/questions/77385587/persist-parentdocumentretriever-of-langchain
store = InMemoryStore()  # in-memory docstore holding the (larger) parent chunks

# Children (256 chars) are what get embedded/searched; parents (512 chars)
# are what get returned, giving the LLM more surrounding context per hit.
parent_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=512, chunk_overlap=256)
child_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=256, chunk_overlap=128)

# NOTE(review): "retriver" is a typo of "retriever", kept because later cells
# reference this exact name.
parent_doc_retriver = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)
# Embeds all child chunks and stores parents — the slow step mentioned above.
parent_doc_retriver.add_documents(pdf_docs)

# 2.
# 2. Hybrid search: sparse (BM25) retriever to complement the dense one.
from langchain.retrievers import BM25Retriever

bm25_retriever = BM25Retriever.from_documents(document_chunks, k=5)  # 1/2 of dense retriever, experimental value

# 3. Rerank
"""
Ref:
https://medium.aiplanet.com/advanced-rag-cohere-re-ranker-99acc941601c
https://github.com/langchain-ai/langchain/issues/13076
good to read:
https://teemukanstren.com/2023/12/25/llmrag-based-question-answering/
"""
# NOTE(review): `from __future__ import annotations` is legal here only because
# each notebook cell is compiled separately; in a .py module it must be first.
from __future__ import annotations
from typing import Dict, Optional, Sequence
from langchain.schema import Document
from langchain.pydantic_v1 import Extra, root_validator

from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor

from sentence_transformers import CrossEncoder

model_name = "BAAI/bge-reranker-large"


class BgeRerank(BaseDocumentCompressor):
    """Document compressor that reranks retrieved docs with a BGE cross-encoder."""

    model_name: str = model_name
    """Model name to use for reranking."""
    top_n: int = 10
    """Number of documents to return."""
    # NOTE(review): instantiated at class-definition time, so the model is
    # downloaded/loaded as soon as this class body runs, and the single
    # CrossEncoder instance is shared by all BgeRerank objects.
    model: CrossEncoder = CrossEncoder(model_name)
    """CrossEncoder instance to use for reranking."""

    def bge_rerank(self, query, docs):
        # Score every (query, doc) pair, then return (index, score) tuples of
        # the top_n highest-scoring docs, best first.
        model_inputs = [[query, doc] for doc in docs]
        scores = self.model.predict(model_inputs)
        results = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
        return results[:self.top_n]

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using BAAI/bge-reranker models.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            A sequence of compressed documents.
        """
        if len(documents) == 0:  # to avoid empty api call
            return []
        doc_list = list(documents)
        _docs = [d.page_content for d in doc_list]
        results = self.bge_rerank(query, _docs)
        final_results = []
        for r in results:
            doc = doc_list[r[0]]
            # NOTE(review): mutates the retrieved Document's metadata in place.
            doc.metadata["relevance_score"] = r[1]
            final_results.append(doc)
        return final_results


from langchain.retrievers import ContextualCompressionRetriever

# Stack all the retrievers together
from langchain.retrievers import EnsembleRetriever
# Ensemble: combine sparse BM25 and dense parent-document retrieval, equal weight.
ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, parent_doc_retriver], weights=[0.5, 0.5])

# Re-rank the ensemble's candidates with the BGE cross-encoder above.
compressor = BgeRerank()
rerank_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=ensemble_retriever
)


## Now begin to build Q&A system
class RAGChain:
    """Factory for a ConversationalRetrievalChain with windowed chat memory."""

    def __init__(
        self, memory_key="chat_history", output_key="answer", return_messages=True
    ):
        # Keys under which the chain stores chat history and its answer.
        self.memory_key = memory_key
        self.output_key = output_key
        self.return_messages = return_messages

    def create(self, retriver, llm):
        """Build a conversational retrieval chain over `retriver` using `llm`."""
        # Windowed memory keeps only the most recent turns
        # (vs. unbounded ConversationBufferMemory).
        memory = ConversationBufferWindowMemory(  # ConversationBufferMemory(
            memory_key=self.memory_key,
            return_messages=self.return_messages,
            output_key=self.output_key,
        )

        # https://github.com/langchain-ai/langchain/issues/4608
        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriver,
            memory=memory,
            return_source_documents=True,
            rephrase_question=False,  # disable rephrase, for test purpose
            get_chat_history=lambda x: x,
        )

        return conversation_chain


rag_chain = RAGChain()
lisa_qa_conversation = rag_chain.create(rerank_retriever, llm)
# Now begin to ask a question.
# Use .invoke(...) instead of calling the chain directly: this cell's own
# captured output shows LangChainDeprecationWarning — `__call__` was deprecated
# in LangChain 0.1.0 ("Use invoke instead") and is removed in 0.2.0.
question = "Please name two common solid electrolytes."
result = lisa_qa_conversation.invoke({"question": question, "chat_history": []})
print(result["answer"])
" ], "text/plain": [ "