import logging
import os
from typing import Any, Dict, List, Tuple

from langchain.chains import RetrievalQA
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_groq import ChatGroq


class LLMManager:
    DEFAULT_MODEL = "gemma2-9b-it"

    def __init__(self):
        self.generation_llm = None
        logging.info("LLMManager initialized")

        try:
            self.initialize_generation_llm(self.DEFAULT_MODEL)
            logging.info(f"Initialized default LLM model: {self.DEFAULT_MODEL}")
        except ValueError as e:
            logging.error(f"Failed to initialize default LLM model: {str(e)}")

    def initialize_generation_llm(self, model_name: str) -> None:
        """
        Initialize the generation LLM using the Groq API.

        Args:
            model_name (str): The name of the model to use for generation.

        Raises:
            ValueError: If GROQ_API_KEY is not set.
        """
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError("GROQ_API_KEY is not set. Please add it to your environment variables.")

        self.generation_llm = ChatGroq(model=model_name, temperature=0.7)
        self.generation_llm.name = model_name
        logging.info(f"Generation LLM {model_name} initialized")

    def reinitialize_llm(self, model_name: str) -> str:
        """
        Reinitialize the LLM with a new model name.

        Args:
            model_name (str): The name of the new model to initialize.

        Returns:
            str: Status message indicating success or failure.
        """
        try:
            self.initialize_generation_llm(model_name)
            return f"LLM model changed to {model_name}"
        except ValueError as e:
            logging.error(f"Failed to reinitialize LLM with model {model_name}: {str(e)}")
            return f"Error: Failed to change LLM model: {str(e)}"

    def generate_response(self, question: str, relevant_docs: List[Dict[str, Any]]) -> Tuple[str, List[Document]]:
        """
        Generate a response using the generation LLM based on the question and relevant documents.

        Args:
            question (str): The user's query.
            relevant_docs (List[Dict[str, Any]]): List of relevant document chunks with text, metadata, and scores.

        Returns:
            Tuple[str, List[Document]]: The LLM's response and the source documents used.

        Raises:
            ValueError: If the generation LLM is not initialized.
            Exception: If there's an error during the QA chain invocation.
        """
        if not self.generation_llm:
            raise ValueError("Generation LLM is not initialized. Call initialize_generation_llm first.")

        documents = [
            Document(page_content=doc['text'], metadata=doc['metadata'])
            for doc in relevant_docs
        ]

        class SimpleRetriever(BaseRetriever):
            # BaseRetriever is a pydantic model, so the documents are declared
            # as a field rather than assigned in a custom __init__.
            docs: List[Document]

            def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
                logging.debug(f"SimpleRetriever._get_relevant_documents called with query: {query}")
                return self.docs

            async def _aget_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
                logging.debug(f"SimpleRetriever._aget_relevant_documents called with query: {query}")
                return self.docs

        retriever = SimpleRetriever(docs=documents)

        qa_chain = RetrievalQA.from_chain_type(
            llm=self.generation_llm,
            retriever=retriever,
            return_source_documents=True
        )

        try:
            result = qa_chain.invoke({"query": question})
            response = result['result']
            source_docs = result['source_documents']
            return response, source_docs
        except Exception as e:
            logging.error(f"Error during QA chain invocation: {str(e)}")
            raise
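
    # A minimal usage sketch for generate_response (illustrative values only;
    # the chunk dicts assume the same 'text'/'metadata' shape used above):
    #
    #   manager = LLMManager()
    #   answer, sources = manager.generate_response(
    #       "What does the Communication Server do?",
    #       [{"text": "The Communication Server relays messages.", "metadata": {"page": 1}}],
    #   )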

    def generate_summary_v0(self, chunks: Any) -> str:
        logging.info("Generating summary ...")

        # Cap the number of chunks so the stuffed prompt stays small.
        limited_chunks = chunks[:30]

        full_text = "\n".join(chunk['text'] for chunk in limited_chunks)
        text_length = len(full_text)
        logging.info(f"Total text length (characters): {text_length}")

        MAX_CHAR_LIMIT = 3200
        if text_length > MAX_CHAR_LIMIT:
            logging.warning(f"Input text too long ({text_length} chars), truncating to {MAX_CHAR_LIMIT} chars.")
            full_text = full_text[:MAX_CHAR_LIMIT]

        custom_prompt_template = """
        You are an expert summarizer. Summarize the following text into a concise summary using bullet points.
        Ensure that the final summary is no longer than 20-30 bullet points and fits within 15-20 lines.
        Focus only on the most critical points.

        Text to summarize:
        {text}

        Summary:
        """
        prompt = PromptTemplate(input_variables=["text"], template=custom_prompt_template)

        # "stuff" places the whole (truncated) text into a single prompt.
        chain = load_summarize_chain(self.generation_llm, chain_type="stuff", prompt=prompt)
        docs = [Document(page_content=full_text)]

        summary = chain.invoke(docs)
        return summary['output_text']

    def generate_questions(self, chunks: Any) -> List[str]:
        logging.info("Generating sample questions ...")

        # Cap the number of chunks so the stuffed prompt stays small.
        limited_chunks = chunks[:30]

        full_text = "\n".join(chunk['text'] for chunk in limited_chunks)
        text_length = len(full_text)
        logging.info(f"Total text length for questions: {text_length}")

        MAX_CHAR_LIMIT = 3200
        if text_length > MAX_CHAR_LIMIT:
            logging.warning(f"Input text too long ({text_length} chars), truncating to {MAX_CHAR_LIMIT} chars.")
            full_text = full_text[:MAX_CHAR_LIMIT]

        question_prompt_template = """
        You are an AI expert at creating questions from documents.

        Based on the text below, generate at least 20 insightful and highly relevant sample questions that a user might ask to better understand the content.

        **Instructions:**
        - Questions must be specific to the document's content and context.
        - Avoid generic questions like 'What is this document about?'
        - Do not include numbers, prefixes (e.g., '1.', '2.'), or explanations (e.g., '(Clarifies...)').
        - Each question should be a single, clear sentence ending with a question mark.
        - Focus on key concepts, processes, components, or use cases mentioned in the text.

        Text:
        {text}

        Output format:
        What is the purpose of the Communication Server in Collateral Management?
        How does the system handle data encryption for secure communication?
        ...
        """
        prompt = PromptTemplate(input_variables=["text"], template=question_prompt_template)

        chain = load_summarize_chain(self.generation_llm, chain_type="stuff", prompt=prompt)
        docs = [Document(page_content=full_text)]

        try:
            result = chain.invoke(docs)
            question_output = result.get("output_text", "").strip()

            questions = []
            for line in question_output.split("\n"):
                # Strip list markers and numbering the model may add despite the
                # instructions, e.g. "3. How does encryption work? (Clarifies security)"
                # becomes "How does encryption work?".
                cleaned_line = line.strip().strip("-*1234567890. ").rstrip(".")
                # Drop trailing parenthetical explanations.
                cleaned_line = cleaned_line.split("(")[0].strip()
                if cleaned_line and cleaned_line.endswith("?"):
                    questions.append(cleaned_line)

            # Ask for 20+ so there is slack for filtering, then keep the first 10.
            questions = questions[:10]
            logging.info(f"Generated questions: {questions}")
            return questions
        except Exception as e:
            logging.error(f"Error generating questions: {e}")
            return []

    def generate_summary(self, chunks: Any, toc_text: Any, summary_type: str = "medium", k: int = 20) -> str:
        """
        Generate a summary of the document using LangChain's summarization chains.

        Args:
            chunks: Document chunks (dicts with a 'text' key) to summarize.
            toc_text: Table-of-contents text to guide the summary, if available.
            summary_type (str): Type of summary ("small", "medium", "detailed").
            k (int): Maximum number of chunks to summarize.

        Returns:
            str: Generated summary.

        Raises:
            ValueError: If summary_type is invalid or the generation LLM is not initialized.
        """
        if not self.generation_llm:
            raise ValueError("Generation LLM is not initialized. Call initialize_generation_llm first.")
        if summary_type not in ("small", "medium", "detailed"):
            raise ValueError(f"Invalid summary_type: {summary_type}")

        # Pick the chunk budget, chain type, and target length per summary type.
        if summary_type == "small":
            k = min(k, 3)
            chain_type = "stuff"
            word_count = "50-100"
        elif summary_type == "medium":
            k = min(k, 10)
            chain_type = "map_reduce"
            word_count = "200-400"
        else:
            k = min(k, 20)
            chain_type = "map_reduce"
            word_count = "500-1000"

        if chain_type == "stuff":
            # {{text}} is escaped so it survives the .format() call below as the
            # literal {text} placeholder that PromptTemplate fills in later.
            prompt = PromptTemplate(
                input_variables=["text"],
                template=(
                    "Generate a {summary_type} summary ({word_count} words) of the following document excerpts. "
                    "Focus on key points and ensure clarity. Stick strictly to the provided text:\n\n"
                    "{toc_prompt}{{text}}"
                ).format(
                    summary_type=summary_type,
                    word_count=word_count,
                    toc_prompt=f"Table of Contents:\n{toc_text}\n\n" if toc_text else ""
                )
            )
            chain = load_summarize_chain(
                llm=self.generation_llm,
                chain_type="stuff",
                prompt=prompt
            )
        else:
            map_prompt = PromptTemplate(
                input_variables=["text"],
                template=(
                    "Summarize the following document excerpt in 1-2 sentences, focusing on key points. "
                    "Consider the document's structure from this table of contents:\n\n"
                    "Table of Contents:\n{toc_text}\n\nExcerpt:\n{{text}}"
                ).format(toc_text=toc_text if toc_text else "Not provided")
            )
            combine_prompt = PromptTemplate(
                input_variables=["text"],
                template=(
                    "Combine the following summaries into a cohesive {summary_type} summary "
                    "({word_count} words) of the document. Ensure clarity, avoid redundancy, and "
                    "organize by key themes or sections if applicable:\n\n{{text}}"
                ).format(summary_type=summary_type, word_count=word_count)
            )
            chain = load_summarize_chain(
                llm=self.generation_llm,
                chain_type="map_reduce",
                map_prompt=map_prompt,
                combine_prompt=combine_prompt,
                return_intermediate_steps=False
            )
        # The chain expects Document objects, so wrap the raw chunk dicts,
        # keeping at most k of them.
        docs = [
            Document(page_content=chunk['text'], metadata=chunk.get('metadata', {}))
            for chunk in chunks[:k]
        ]

        try:
            logging.info(f"Generating {summary_type} summary with {len(docs)} chunks")
            summary = chain.invoke(docs)["output_text"]
            logging.info(f"{summary_type.capitalize()} summary generated successfully")
            return summary
        except Exception as e:
            logging.error(f"Error generating summary: {str(e)}")
            return f"Error generating summary: {str(e)}"