# RAG agent with conversation-aware retrieval and enhanced context management.
from typing import List, Dict, Optional, Tuple | |
import uuid | |
from .excel_aware_rag import ExcelAwareRAGAgent | |
from .enhanced_context_manager import EnhancedContextManager | |
from ..llms.base_llm import BaseLLM | |
from src.embeddings.base_embedding import BaseEmbedding | |
from src.vectorstores.base_vectorstore import BaseVectorStore | |
from src.utils.conversation_manager import ConversationManager | |
from src.db.mongodb_store import MongoDBStore | |
from src.models.rag import RAGResponse | |
from src.utils.logger import logger | |
class RAGAgent(ExcelAwareRAGAgent):
    """Retrieval-augmented generation agent with conversation-aware context."""

    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore,
        mongodb: MongoDBStore,
        max_history_tokens: int = 4000,
        max_history_messages: int = 10
    ):
        """Initialize RAG Agent with enhanced context management.

        Args:
            llm: Language model used to generate answers.
            embedding: Embedding model used to vectorise queries.
            vector_store: Store searched for similar documents.
            mongodb: Persistence layer holding conversation history.
            max_history_tokens: Token budget handed to the conversation manager.
            max_history_messages: Cap on history messages the manager considers.
        """
        super().__init__()  # Initialize ExcelAwareRAGAgent
        # Core collaborators.
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store
        self.mongodb = mongodb
        # Prunes/windows conversation history when building prompts.
        self.conversation_manager = ConversationManager(
            max_tokens=max_history_tokens,
            max_messages=max_history_messages,
        )
        # Enhanced context tracking layered on top of the existing flow.
        self.context_manager = EnhancedContextManager()
        logger.info("RAGAgent initialized with enhanced context management")
    async def generate_response(
        self,
        query: str,
        conversation_id: Optional[str],
        temperature: float,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None,
        stream: bool = False,
        custom_roles: Optional[List[Dict[str, str]]] = None
    ) -> RAGResponse:
        """
        Generate a response with comprehensive context and role management

        Pipeline: role steering -> introduction short-circuit -> history
        retrieval -> query enhancement -> context retrieval/fallback ->
        Excel handling -> prompt build -> LLM generation -> cleaning.

        Args:
            query (str): User query
            conversation_id (Optional[str]): Conversation identifier
            temperature (float): LLM temperature for response generation
            max_tokens (Optional[int]): Maximum tokens for response
            context_docs (Optional[List[str]]): Pre-retrieved context documents
            stream (bool): Whether to stream the response
            custom_roles (Optional[List[Dict[str, str]]]): Custom role instructions

        Returns:
            RAGResponse: Generated response with context and metadata

        Raises:
            Exception: re-raised after logging when any stage of the
                pipeline fails.
        """
        try:
            logger.info(f"Generating response for query: {query}")

            # Apply custom roles if provided: each recognised role name
            # appends steering text to the query before retrieval/generation.
            if custom_roles:
                for role in custom_roles:
                    # Modify query or context based on role
                    if role.get('name') == 'introduction_specialist':
                        query += " Provide a concise, welcoming response."
                    elif role.get('name') == 'knowledge_based_specialist':
                        query += " Ensure response is precise and directly from available knowledge."

            # Introduction Handling: greeting messages are detected by the
            # fixed template substrings the caller embeds in the query.
            is_introduction = (
                "wants support" in query and
                "This is Introduction" in query and
                ("A new user with name:" in query or "An old user with name:" in query)
            )
            if is_introduction:
                logger.info("Processing introduction message")
                welcome_message = self._handle_contact_query(query)
                # Introductions short-circuit the RAG pipeline entirely.
                return RAGResponse(
                    response=welcome_message,
                    context_docs=[],
                    sources=[],
                    scores=None
                )

            # Conversation History Processing
            history = []
            last_context = None
            if conversation_id:
                logger.info(f"Retrieving conversation history for ID: {conversation_id}")
                history = await self.mongodb.get_recent_messages(
                    conversation_id,
                    limit=self.conversation_manager.max_messages
                )

                # Process history for conversation manager
                history = self.conversation_manager.get_relevant_history(
                    messages=history,
                    current_query=query
                )

                # Replay history into the enhanced context manager so it can
                # resolve references (e.g. pronouns) against earlier turns.
                for msg in history:
                    self.context_manager.process_turn(
                        msg.get('query', ''),
                        msg.get('response', '')
                    )

                # Get last context if available
                if history and history[-1].get('response'):
                    last_context = history[-1]['response']

            # Query Enhancement
            enhanced_query = self.context_manager.enhance_query(query)

            # Manual Pronoun Handling Fallback: if the context manager left
            # the query untouched but it contains a pronoun, seed the manager
            # with the last response and retry the enhancement once.
            # NOTE(review): `replacement` is unused — the map only serves as
            # a pronoun detection set here.
            if enhanced_query == query:
                pronoun_map = {
                    'his': 'he',
                    'her': 'she',
                    'their': 'they'
                }
                words = query.lower().split()
                for pronoun, replacement in pronoun_map.items():
                    if pronoun in words:
                        # Try to use last context
                        if last_context:
                            self.context_manager.record_last_context(last_context)
                            enhanced_query = self.context_manager.enhance_query(query)
                        break
            logger.info(f"Enhanced query: {enhanced_query}")

            # Context Retrieval — skipped when the caller supplied documents.
            # NOTE(review): caller-supplied docs arrive without sources/scores.
            if not context_docs:
                logger.info("Retrieving context for enhanced query")
                context_docs, sources, scores = await self.retrieve_context(
                    enhanced_query,
                    conversation_history=history
                )
            else:
                sources = []
                scores = None

            # Context Fallback Mechanism
            if not context_docs:
                # If no context and last context exists, use it
                if last_context:
                    context_docs = [last_context]
                    sources = [{"source": "previous_context"}]
                    scores = [1.0]
                else:
                    logger.info("No relevant context found")
                    return RAGResponse(
                        response="Information about this is not available, do you want to inquire about something else?",
                        context_docs=[],
                        sources=[],
                        scores=None
                    )

            # Excel-specific Content Handling — 'Sheet:' markers flag
            # spreadsheet-derived documents (presumably injected by the
            # Excel ingestion path; confirm against ExcelAwareRAGAgent).
            has_excel_content = any('Sheet:' in doc for doc in context_docs)
            if has_excel_content:
                logger.info("Processing Excel-specific content")
                try:
                    context_docs = self._process_excel_context(context_docs, enhanced_query)
                except Exception as e:
                    # Best-effort: fall back to the unprocessed documents.
                    logger.warning(f"Error processing Excel context: {str(e)}")

            # Prompt Generation with Conversation History
            prompt = self.conversation_manager.generate_prompt_with_history(
                current_query=enhanced_query,
                history=history,
                context_docs=context_docs
            )

            # Streaming Response Generation
            if stream:
                # TODO: Implement actual streaming logic
                # This is a placeholder and needs proper implementation;
                # execution currently falls through to the standard path.
                logger.warning("Streaming not fully implemented")

            # Standard Response Generation
            # NOTE(review): not awaited — assumes BaseLLM.generate is
            # synchronous; confirm against the BaseLLM interface.
            response = self.llm.generate(
                prompt=prompt,
                temperature=temperature,
                max_tokens=max_tokens
            )

            # Response Cleaning
            cleaned_response = self._clean_response(response)

            # Excel Response Enhancement (best-effort: keeps the cleaned
            # response when enhancement fails or returns a falsy value)
            if has_excel_content:
                try:
                    enhanced_response = await self.enhance_excel_response(
                        query=enhanced_query,
                        response=cleaned_response,
                        context_docs=context_docs
                    )
                    if enhanced_response:
                        cleaned_response = enhanced_response
                except Exception as e:
                    logger.warning(f"Error enhancing Excel response: {str(e)}")

            # Context Tracking: record this turn for future reference
            # resolution.
            self.context_manager.process_turn(query, cleaned_response)

            # Metadata Generation
            metadata = {
                'llm_provider': getattr(self.llm, 'model_name', 'unknown'),
                'temperature': temperature,
                'conversation_id': conversation_id,
                'context_sources': sources,
                'has_excel_content': has_excel_content
            }

            logger.info("Successfully generated response")
            return RAGResponse(
                response=cleaned_response,
                context_docs=context_docs,
                sources=sources,
                scores=scores,
                metadata=metadata  # Added metadata
            )
        except Exception as e:
            logger.error(f"Error in generate_response: {str(e)}")
            raise
async def retrieve_context( | |
self, | |
query: str, | |
conversation_history: Optional[List[Dict]] = None, | |
top_k: int = 3 | |
) -> Tuple[List[str], List[Dict], Optional[List[float]]]: | |
"""Retrieve context with both original and enhanced handling""" | |
try: | |
logger.info(f"Retrieving context for query: {query}") | |
# Enhance query using both managers | |
if conversation_history: | |
# Get the last two messages for immediate context | |
recent_messages = conversation_history[-2:] | |
# Extract queries and responses for context | |
context_parts = [] | |
for msg in recent_messages: | |
if msg.get('query'): | |
context_parts.append(msg['query']) | |
if msg.get('response'): | |
response = msg['response'] | |
if "Information about this is not available" not in response: | |
context_parts.append(response) | |
# Combine with current query | |
enhanced_query = f"{' '.join(context_parts)} {query}".strip() | |
logger.info(f"Enhanced query with history: {enhanced_query}") | |
else: | |
enhanced_query = query | |
# Debug log the enhanced query | |
logger.info(f"Final enhanced query: {enhanced_query}") | |
# Embed the enhanced query | |
query_embedding = self.embedding.embed_query(enhanced_query) | |
# Debug log embedding shape | |
logger.info(f"Query embedding shape: {len(query_embedding)}") | |
# Retrieve similar documents | |
results = self.vector_store.similarity_search( | |
query_embedding, | |
top_k=top_k | |
) | |
# Debug log search results | |
logger.info(f"Number of search results: {len(results)}") | |
for i, result in enumerate(results): | |
logger.info(f"Result {i} score: {result.get('score', 'N/A')}") | |
logger.info(f"Result {i} text preview: {result.get('text', '')[:100]}...") | |
if not results: | |
logger.info("No results found in similarity search") | |
return [], [], None | |
# Process results | |
documents = [doc['text'] for doc in results] | |
sources = [self._convert_metadata_to_strings(doc['metadata']) | |
for doc in results] | |
scores = [doc['score'] for doc in results | |
if doc.get('score') is not None] | |
# Return scores only if available for all documents | |
if len(scores) != len(documents): | |
scores = None | |
logger.info(f"Retrieved {len(documents)} relevant documents") | |
return documents, sources, scores | |
except Exception as e: | |
logger.error(f"Error in retrieve_context: {str(e)}") | |
raise | |
def _clean_response(self, response: str) -> str: | |
"""Clean response text while preserving key information""" | |
if not response: | |
return response | |
# Keep only the most common phrases to remove | |
phrases_to_remove = [ | |
"Based on the context,", | |
"According to the documents,", | |
"From the information available,", | |
"Based on the provided information,", | |
"I apologize," | |
] | |
cleaned_response = response | |
for phrase in phrases_to_remove: | |
cleaned_response = cleaned_response.replace(phrase, "").strip() | |
cleaned_response = " ".join(cleaned_response.split()) | |
if not cleaned_response: | |
return response | |
if cleaned_response[0].islower(): | |
cleaned_response = cleaned_response[0].upper() + cleaned_response[1:] | |
return cleaned_response | |
def _convert_metadata_to_strings(self, metadata: Dict) -> Dict: | |
"""Convert metadata values to strings""" | |
try: | |
return { | |
key: str(value) if isinstance(value, (int, float)) else value | |
for key, value in metadata.items() | |
} | |
except Exception as e: | |
logger.error(f"Error converting metadata: {str(e)}") | |
return metadata | |
def _handle_contact_query(self, query: str) -> str: | |
"""Handle contact/introduction queries""" | |
try: | |
name_start = query.find('name: "') + 7 | |
name_end = query.find('"', name_start) | |
name = query[name_start:name_end] if name_start > 6 and name_end != -1 else "there" | |
is_returning = ( | |
"An old user with name:" in query and | |
"wants support again" in query | |
) | |
return f"Welcome back {name}, How can I help you?" if is_returning else f"Welcome {name}, How can I help you?" | |
except Exception as e: | |
logger.error(f"Error handling contact query: {str(e)}") | |
return "Welcome, How can I help you?" |