"""LLM management for a Groq-backed RAG pipeline.

Defines LLMManager, which initializes a Groq chat model and generates
answers to user questions from pre-retrieved document chunks.
"""

import logging
import os
from typing import Any, Dict, List, Tuple

from langchain.chains import RetrievalQA
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_groq import ChatGroq


class LLMManager:
    """Manage a Groq-hosted generation LLM and answer questions over
    pre-retrieved document chunks."""

    DEFAULT_MODEL = "gemma2-9b-it"

    def __init__(self):
        self.generation_llm = None
        logging.info("LLMManager initialized")

        # Eagerly initialize the default model; a missing API key is logged
        # rather than raised so the manager can still be constructed.
        try:
            self.initialize_generation_llm(self.DEFAULT_MODEL)
            logging.info(f"Initialized default LLM model: {self.DEFAULT_MODEL}")
        except ValueError as e:
            logging.error(f"Failed to initialize default LLM model: {str(e)}")

    def initialize_generation_llm(self, model_name: str) -> None:
        """
        Initialize the generation LLM using the Groq API.

        Args:
            model_name (str): The name of the model to use for generation.

        Raises:
            ValueError: If GROQ_API_KEY is not set.
        """
        # Read the key from the environment rather than hardcoding it;
        # committing an API key to source control leaks the credential.
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError("GROQ_API_KEY is not set. Please add it in your environment variables.")

        self.generation_llm = ChatGroq(model=model_name, temperature=0.7, api_key=api_key)
        self.generation_llm.name = model_name
        logging.info(f"Generation LLM {model_name} initialized")

    def reinitialize_llm(self, model_name: str) -> str:
        """
        Reinitialize the LLM with a new model name.

        Args:
            model_name (str): The name of the new model to initialize.

        Returns:
            str: Status message indicating success or failure.
        """
        try:
            self.initialize_generation_llm(model_name)
            return f"LLM model changed to {model_name}"
        except ValueError as e:
            logging.error(f"Failed to reinitialize LLM with model {model_name}: {str(e)}")
            return f"Error: Failed to change LLM model: {str(e)}"

    def generate_response(self, question: str, relevant_docs: List[Dict[str, Any]]) -> Tuple[str, List[Document]]:
        """
        Generate a response using the generation LLM based on the question and relevant documents.

        Args:
            question (str): The user's query.
            relevant_docs (List[Dict[str, Any]]): List of relevant document chunks with text, metadata, and scores.

        Returns:
            Tuple[str, List[Document]]: The LLM's response and the source documents used.

        Raises:
            ValueError: If the generation LLM is not initialized.
            Exception: If there's an error during the QA chain invocation.
        """
        if not self.generation_llm:
            raise ValueError("Generation LLM is not initialized. Call initialize_generation_llm first.")

        # Wrap the pre-retrieved chunks as LangChain Documents so the QA chain
        # can consume them directly.
        documents = [
            Document(page_content=doc['text'], metadata=doc['metadata'])
            for doc in relevant_docs
        ]

        # Minimal retriever that ignores the query and always returns the
        # caller-supplied documents; retrieval already happened upstream.
        # BaseRetriever is a Pydantic model, so `docs` is declared as a field
        # instead of assigning an undeclared private attribute in __init__,
        # which Pydantic rejects.
        class SimpleRetriever(BaseRetriever):
            docs: List[Document]

            def _get_relevant_documents(self, query: str) -> List[Document]:
                logging.debug(f"SimpleRetriever._get_relevant_documents called with query: {query}")
                return self.docs

            async def _aget_relevant_documents(self, query: str) -> List[Document]:
                logging.debug(f"SimpleRetriever._aget_relevant_documents called with query: {query}")
                return self.docs

        retriever = SimpleRetriever(docs=documents)
        logging.debug(f"SimpleRetriever initialized with {len(documents)} documents")

        qa_chain = RetrievalQA.from_chain_type(
            llm=self.generation_llm,
            retriever=retriever,
            return_source_documents=True
        )

        try:
            result = qa_chain.invoke({"query": question})
            response = result['result']
            source_docs = result['source_documents']
            logging.info(f"Generated response for question '{question}': {response}")
            return response, source_docs
        except Exception as e:
            logging.error(f"Error during QA chain invocation: {str(e)}")
            raise
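

# ---------------------------------------------------------------------------
# Usage sketch (not part of the class): a minimal, hypothetical example of
# driving LLMManager end to end. The chunk texts, metadata, and scores below
# are illustrative assumptions, and running it requires GROQ_API_KEY to be
# set in the environment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = LLMManager()

    # Hypothetical pre-retrieved chunks in the shape generate_response
    # expects: each dict carries the chunk text and its metadata (a score,
    # if present, is ignored by generate_response).
    relevant_docs = [
        {
            "text": "Groq serves open models such as gemma2-9b-it via its API.",
            "metadata": {"source": "notes.md", "chunk": 0},
            "score": 0.82,
        },
        {
            "text": "RetrievalQA combines a retriever with an LLM to answer questions.",
            "metadata": {"source": "notes.md", "chunk": 1},
            "score": 0.74,
        },
    ]

    # Only query if the default model actually initialized (i.e., the API key
    # was present when the manager was constructed).
    if manager.generation_llm:
        answer, sources = manager.generate_response(
            "Which model does this project use by default?", relevant_docs
        )
        print(answer)
        print(f"Answered from {len(sources)} source chunk(s)")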