# src/agents/system_instructions_rag.py
import re
from datetime import datetime
from typing import List, Dict, Optional, Tuple

import spacy

from src.agents.rag_agent import RAGAgent
from src.llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.db.mongodb_store import MongoDBStore
from src.models.rag import RAGResponse
from src.utils.logger import logger
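
# NOTE: the spaCy model loaded below must be installed separately, e.g.
#   python -m spacy download en_core_web_sm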

class SystemInstructionsRAGAgent(RAGAgent):
    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore,
        mongodb: MongoDBStore,
        max_history_tokens: int = 4000,
        max_history_messages: int = 10
    ):
        """Initialize SystemInstructionsRAGAgent with enhanced context management"""
        super().__init__(
            llm=llm,
            embedding=embedding,
            vector_store=vector_store,
            mongodb=mongodb,
            max_history_tokens=max_history_tokens,
            max_history_messages=max_history_messages
        )
        self.nlp = spacy.load("en_core_web_sm")

    async def generate_response(
        self,
        query: str,
        conversation_id: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None,
        stream: bool = False
    ) -> RAGResponse:
        """Generate response with guaranteed context handling"""
        try:
            logger.info(f"Processing query: {query}")

            # Store original context if provided
            original_context = context_docs

            # Handle introduction queries
            if self._is_introduction_query(query):
                welcome_message = self._handle_contact_query(query)
                return RAGResponse(
                    response=welcome_message,
                    context_docs=[],
                    sources=[],
                    scores=None
                )

            # Get and process conversation history
            history = []
            if conversation_id:
                history = await self.mongodb.get_recent_messages(
                    conversation_id,
                    limit=self.conversation_manager.max_messages
                )

            # Process history in context manager
            for msg in history:
                if msg.get('query') and msg.get('response'):
                    self.context_manager.process_turn(msg['query'], msg['response'])

            # Initialize context tracking
            current_context = None
            sources = []
            scores = None

            # Multi-stage context retrieval
            if original_context:
                current_context = original_context
            else:
                # Try with original query first
                current_context, sources, scores = await self.retrieve_context(
                    query,
                    conversation_history=history
                )

                # If no context, try with enhanced query
                if not current_context:
                    enhanced_query = self.context_manager.enhance_query(query)
                    if enhanced_query != query:
                        current_context, sources, scores = await self.retrieve_context(
                            enhanced_query,
                            conversation_history=history
                        )

                # If still no context, try history fallback
                if not current_context:
                    current_context, sources = self._get_context_from_history(history)

            logger.info(f"Retrieved {len(current_context) if current_context else 0} context documents")

            # Check context relevance
            has_relevant_context = self._check_context_relevance(query, current_context or [])
            logger.info(f"Context relevance check result: {has_relevant_context}")

            # Handle no context case
            if not has_relevant_context:
                return self._create_no_info_response()

            # Generate response
            prompt = self._create_response_prompt(query, current_context)
            response_text = self.llm.generate(
                prompt=prompt,
                temperature=temperature,
                max_tokens=max_tokens
            )

            # Process and validate response
            cleaned_response = self._clean_response(response_text)
            if self._is_no_info_response(cleaned_response):
                return self._create_no_info_response()

            # Update context tracking
            self.context_manager.process_turn(query, cleaned_response)

            # For Excel content, enhance the response
            if any('Sheet:' in doc for doc in (current_context or [])):
                try:
                    cleaned_response = await self.enhance_excel_response(
                        query=query,
                        response=cleaned_response,
                        context_docs=current_context
                    )
                except Exception as e:
                    logger.warning(f"Error enhancing Excel response: {str(e)}")

            return RAGResponse(
                response=cleaned_response,
                context_docs=current_context,
                sources=sources,
                scores=scores
            )

        except Exception as e:
            logger.error(f"Error in generate_response: {str(e)}")
            raise

    def _convert_metadata_to_strings(self, metadata: Dict) -> Dict:
        """Convert metadata values to strings, preserving None values"""
        return {
            key: str(value) if value is not None else None
            for key, value in metadata.items()
        }
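
    # e.g. _convert_metadata_to_strings({"chunk_index": 3, "page": None})
    #      returns {"chunk_index": "3", "page": None}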

    async def retrieve_context(
        self,
        query: str,
        conversation_history: Optional[List[Dict]] = None
    ) -> Tuple[List[str], List[Dict], Optional[List[float]]]:
        """Enhanced context retrieval with proper metadata type handling"""
        try:
            logger.info(f"Processing query for context retrieval: {query}")

            collection_data = self.vector_store.collection.get()
            if not collection_data or 'documents' not in collection_data:
                logger.warning("No documents found in ChromaDB")
                return [], [], None

            documents = collection_data['documents']
            metadatas = collection_data.get('metadatas', [])

            # Clean and enhance query with date variations
            clean_query = query.lower().strip()

            # Extract and enhance date information
            date_pattern = r'(?:jan|feb|mar|apr|may|jun|jul|aug|sep|oct|nov|dec)[a-z]* \d{1,2},? \d{4}'
            dates = re.findall(date_pattern, clean_query)
            enhanced_query = clean_query
            target_date = None

            if dates:
                try:
                    date_obj = datetime.strptime(dates[0], '%b %d, %Y')
                    target_date = date_obj.strftime('%b %d, %Y')
                    date_variations = [
                        date_obj.strftime('%B %d, %Y'),
                        date_obj.strftime('%d/%m/%Y'),
                        date_obj.strftime('%Y-%m-%d'),
                        target_date
                    ]
                    enhanced_query = f"{clean_query} {' '.join(date_variations)}"
                except ValueError as e:
                    logger.warning(f"Error parsing date: {str(e)}")

            # First try exact date matching
            exact_matches = []
            exact_metadata = []

            if target_date:
                for i, doc in enumerate(documents):
                    if target_date in doc:
                        logger.info(f"Found exact date match in document {i}")
                        exact_matches.append(doc)
                        if metadatas:
                            # Convert metadata values to strings
                            exact_metadata.append(self._convert_metadata_to_strings(metadatas[i]))

            if exact_matches:
                logger.info(f"Found {len(exact_matches)} exact date matches")
                document_id = exact_metadata[0].get('document_id') if exact_metadata else None

                if document_id:
                    all_related_chunks = []
                    all_related_metadata = []
                    all_related_scores = []

                    for i, doc in enumerate(documents):
                        if metadatas[i].get('document_id') == document_id:
                            all_related_chunks.append(doc)
                            # Convert metadata values to strings
                            all_related_metadata.append(self._convert_metadata_to_strings(metadatas[i]))
                            all_related_scores.append(1.0)

                    # Sort chunks by their index
                    sorted_results = sorted(
                        zip(all_related_chunks, all_related_metadata, all_related_scores),
                        key=lambda x: int(x[1].get('chunk_index', '0'))  # Convert to int for sorting
                    )
                    sorted_chunks, sorted_metadata, sorted_scores = zip(*sorted_results)

                    logger.info(f"Returning {len(sorted_chunks)} chunks from document {document_id}")
                    return list(sorted_chunks), list(sorted_metadata), list(sorted_scores)

            # If no exact matches, use enhanced query for embedding search
            logger.info("No exact matches found, using enhanced query for embedding search")
            query_embedding = self.embedding.embed_query(enhanced_query)
            results = self.vector_store.similarity_search(
                query_embedding,
                top_k=5
            )

            if not results:
                logger.warning("No results found in similarity search")
                return [], [], None

            context_docs = []
            sources = []
            scores = []

            sorted_results = sorted(results, key=lambda x: x.get('score', 0), reverse=True)
            for result in sorted_results:
                score = result.get('score', 0)
                # Keep only results above a heuristic similarity threshold
                if score > 0.3:
                    context_docs.append(result.get('text', ''))
                    # Convert metadata values to strings
                    sources.append(self._convert_metadata_to_strings(result.get('metadata', {})))
                    scores.append(score)

            if context_docs:
                logger.info(f"Returning {len(context_docs)} documents from similarity search")
                return context_docs, sources, scores

            logger.warning("No relevant documents found")
            return [], [], None

        except Exception as e:
            logger.error(f"Error in retrieve_context: {str(e)}")
            logger.exception("Full traceback:")
            return [], [], None

    def _is_introduction_query(self, query: str) -> bool:
        """Check if query is an introduction message"""
        return (
            "wants support" in query and
            "This is Introduction" in query and
            ("A new user with name:" in query or "An old user with name:" in query)
        )
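
    # Introduction queries are expected to embed the phrases checked above,
    # e.g. text containing 'A new user with name: "..."' (or the "old user"
    # variant) together with "wants support" and "This is Introduction".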

    def _get_context_from_history(
        self,
        history: List[Dict]
    ) -> Tuple[Optional[List[str]], Optional[List[Dict]]]:
        """Extract context from conversation history"""
        for msg in reversed(history):
            if msg.get('context') and not self._is_no_info_response(msg.get('response', '')):
                return msg['context'], msg.get('sources', [])
        return None, None

    def _create_response_prompt(self, query: str, context_docs: List[str]) -> str:
        """Create prompt for response generation"""
        formatted_context = '\n\n'.join(
            f"Context {i+1}:\n{doc.strip()}"
            for i, doc in enumerate(context_docs)
            if doc and doc.strip()
        )

        return f"""
Use ONLY the following context to provide information about: {query}

{formatted_context}

Instructions:
1. Use ONLY information present in the context above
2. If the information is found in the context, provide a direct and concise response
3. Do not make assumptions or add information not present in the context
4. Ensure the response is clear and complete based on available information
5. If you cannot find relevant information about the specific query in the context,
   respond exactly with: "Information about this is not available, do you want to inquire about something else?"

Query: {query}

Response:"""

    def _create_no_info_response(self) -> RAGResponse:
        """Create standard response for no information case"""
        return RAGResponse(
            response="Information about this is not available, do you want to inquire about something else?",
            context_docs=[],
            sources=[],
            scores=None
        )

    def _clean_response(self, response: str) -> str:
        """Clean response text"""
        if not response:
            return response

        phrases_to_remove = [
            "Based on the context provided,",
            "According to the documents,",
            "From the information available,",
            "I can tell you that",
            "Let me help you with that",
            "I understand you're asking about",
            "To answer your question,",
            "The documents indicate that",
            "Based on the available information,",
            "As per the provided context,",
            "I would be happy to help you with that",
            "Let me provide you with information about",
            "Here's what I found:",
            "Here's the information you requested:",
            "According to the provided information,",
            "The information suggests that",
            "From what I can see,",
            "Let me explain",
            "To clarify,",
            "It appears that",
            "I can see that",
            "Sure,",
            "Well,",
            "I apologize,"
        ]

        cleaned_response = response
        for phrase in phrases_to_remove:
            cleaned_response = cleaned_response.replace(phrase, "").strip()
        cleaned_response = " ".join(cleaned_response.split())

        if not cleaned_response:
            return response

        if cleaned_response[0].islower():
            cleaned_response = cleaned_response[0].upper() + cleaned_response[1:]

        return cleaned_response
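
    # e.g. _clean_response("Based on the context provided, the office opens at 9 AM.")
    #      returns "The office opens at 9 AM."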

    def _is_no_info_response(self, response: str) -> bool:
        """Check if response indicates no information available"""
        no_info_indicators = [
            "i do not have",
            "i don't have",
            "no information",
            "not available",
            "could not find",
            "couldn't find",
            "cannot find",
            "don't know",
            "do not know",
            "unable to find",
            "no data",
            "no relevant"
        ]
        response_lower = response.lower()
        return any(indicator in response_lower for indicator in no_info_indicators)

    def _check_context_relevance(self, query: str, context_docs: List[str]) -> bool:
        """Enhanced context relevance checking"""
        if not context_docs:
            return False

        # Clean and prepare query
        clean_query = query.lower().strip()
        query_terms = set(
            word for word in clean_query.split()
            if word not in {'tell', 'me', 'about', 'what', 'is', 'the'}
        )

        for doc in context_docs:
            if not doc:
                continue

            doc_lower = doc.lower()

            # For CSV-like content, check each line
            lines = doc_lower.split('\n')
            for line in lines:
                # Check if any query term appears in the line
                if any(term in line for term in query_terms):
                    return True

            # Also check the whole document for good measure
            if any(term in doc_lower for term in query_terms):
                return True

        return False

    def _handle_contact_query(self, query: str) -> str:
        """Handle contact/introduction queries"""
        try:
            # find() returns -1 when the marker is missing, so name_start == 6
            # only when 'name: "' was not found
            name_start = query.find('name: "') + 7
            name_end = query.find('"', name_start)
            name = query[name_start:name_end] if name_start > 6 and name_end != -1 else "there"

            is_returning = (
                "An old user with name:" in query and
                "wants support again" in query
            )

            return f"Welcome back {name}, How can I help you?" if is_returning else f"Welcome {name}, How can I help you?"
        except Exception as e:
            logger.error(f"Error handling contact query: {str(e)}")
            return "Welcome, How can I help you?"