# src/agents/excel_aware_rag.py
import re
from typing import List, Dict, Optional, Set

from src.utils.logger import logger

class ExcelAwareRAGAgent:
    """Extension of RAGAgent with enhanced Excel handling"""
    
    def _process_excel_context(self, context_docs: List[str], query: str) -> List[str]:
        """
        Process and enhance context for Excel-related queries.

        Documents containing a ``Sheet:`` marker are treated as Excel content.
        Only the sheets judged relevant to ``query`` are kept, each as its own
        context entry. When the query looks like a multi-sheet question, the
        document's relationship section follows them. All other documents
        pass through unchanged.

        If an Excel document yields no sheet context at all (no sheet matched
        the query, or no ``Sheet: `` header could be extracted), the original
        document is kept rather than dropped. A retrieved workbook therefore
        never disappears from the context.

        Args:
            context_docs (List[str]): Original context documents
            query (str): User query

        Returns:
            List[str]: Enhanced context documents
        """
        excel_context: List[str] = []

        for doc in context_docs:
            if 'Sheet:' not in doc:
                excel_context.append(doc)
                continue

            # Ordered dedup: a sheet header repeated in one document would
            # otherwise emit the same sheet context twice.
            relevant_sheets = dict.fromkeys(self._identify_relevant_sheets(doc, query))
            sheet_contexts: List[str] = []
            for sheet in relevant_sheets:
                sheet_context = self._extract_sheet_context(doc, sheet)
                if sheet_context:
                    sheet_contexts.append(sheet_context)

            if not sheet_contexts:
                # Nothing matched the query: keep the whole document instead
                # of silently discarding it. It already contains any
                # relationship section, so nothing else is appended.
                excel_context.append(doc)
                continue

            excel_context.extend(sheet_contexts)

            # Add relationship context if query suggests multi-sheet analysis
            if self._needs_relationship_context(query):
                relationship_context = self._extract_relationship_context(doc)
                if relationship_context:
                    excel_context.append(relationship_context)

        return excel_context

    def _identify_relevant_sheets(self, doc: str, query: str) -> List[str]:
        """
        Identify sheets in ``doc`` that are relevant to ``query``.

        Sheet names are read from lines starting with ``"Sheet: "``. Each
        relevant name is returned once, in order of first appearance. Headers
        with an empty name are ignored, because an empty string is a substring
        of every query and of every header line downstream.

        Args:
            doc (str): Document text containing ``Sheet: <name>`` header lines
            query (str): User query

        Returns:
            List[str]: Relevant sheet names, deduplicated, in document order
        """
        prefix = 'Sheet: '
        sheets: List[str] = []
        seen: Set[str] = set()

        for line in doc.split('\n'):
            if not line.startswith(prefix):
                continue
            # Slice off only the leading marker; str.replace would also remove
            # any later "Sheet: " that happens to be part of the name.
            sheet_name = line[len(prefix):].strip()
            if not sheet_name or sheet_name in seen:
                continue
            seen.add(sheet_name)
            if self._is_relevant_to_query(sheet_name, query):
                sheets.append(sheet_name)

        return sheets

    def _is_relevant_to_query(self, sheet_name: str, query: str) -> bool:
        """
        Check whether a sheet is relevant to the query.

        A sheet is relevant in either of two cases:

        * its full name appears in the query as a whole-word phrase;
        * the sheet name and the query share at least one significant word.

        Words are runs of word characters, so trailing punctuation ("sales?")
        does not block a match. Common filler words ("of", "the", ...) are
        ignored, so a sheet like "Summary of Costs" is not matched by every
        query containing "of". Comparison is case-insensitive.

        Args:
            sheet_name (str): Sheet name to test
            query (str): User query

        Returns:
            bool: True if the sheet appears relevant, False otherwise
        """
        stopwords = frozenset({
            'a', 'an', 'and', 'are', 'as', 'at', 'be', 'by', 'for', 'from',
            'in', 'is', 'it', 'of', 'on', 'or', 'the', 'to', 'with',
        })

        query_lower = query.lower()
        sheet_lower = sheet_name.lower().strip()

        # An empty name is a substring of every query; never treat it as a match.
        if not sheet_lower:
            return False

        # Direct mention of the full sheet name, bounded by non-word characters
        # so that e.g. "Data" does not match inside "database".
        mention = r'(?<!\w)' + re.escape(sheet_lower) + r'(?!\w)'
        if re.search(mention, query_lower):
            return True

        # Otherwise require at least one shared significant word.
        sheet_terms = set(re.findall(r'\w+', sheet_lower)) - stopwords
        query_terms = set(re.findall(r'\w+', query_lower)) - stopwords
        return bool(sheet_terms & query_terms)

    def _extract_sheet_context(self, doc: str, sheet_name: str) -> Optional[str]:
        """
        Extract the section of ``doc`` that belongs to one sheet.

        A section starts at a ``Sheet: <sheet_name>`` header line and runs until
        the next ``Sheet: `` header or the end of the document. The header's
        name must equal ``sheet_name`` exactly, ignoring surrounding whitespace.
        A prefix match is not enough, so "Sales" does not also absorb a
        "Sales 2023" sheet. If the same header occurs more than once, all of
        its sections are concatenated.

        NOTE(review): lines after the last sheet header that are not themselves
        ``Sheet: `` headers are included in that sheet's section. This applies,
        for example, to a trailing "Sheet Relationships:" block, if one exists.
        Confirm that this is intended.

        Args:
            doc (str): Document text with ``Sheet: <name>`` header lines
            sheet_name (str): Exact sheet name to extract

        Returns:
            Optional[str]: The sheet's lines joined by newlines, or None if
            the sheet is not present
        """
        prefix = 'Sheet: '
        target = sheet_name.strip()
        sheet_lines: List[str] = []
        in_target_sheet = False

        for line in doc.split('\n'):
            if line.startswith(prefix):
                # Every header resets the state; only an exact name match
                # turns capturing on.
                in_target_sheet = line[len(prefix):].strip() == target
                if in_target_sheet:
                    sheet_lines.append(line)
            elif in_target_sheet:
                sheet_lines.append(line)

        return '\n'.join(sheet_lines) if sheet_lines else None

    def _needs_relationship_context(self, query: str) -> bool:
        """
        Decide whether the query is about several sheets at once.

        The query is lower-cased and searched for cue phrases such as
        "compare", "between" or "all sheets". Matching is by substring, so
        inflected forms like "compared" or "linked" also count.

        Args:
            query (str): User query

        Returns:
            bool: True if any cue phrase occurs in the query
        """
        lowered = query.lower()
        cues = (
            'compare', 'relationship', 'between', 'across', 'correlation',
            'related', 'connection', 'link', 'join', 'combine',
            'multiple sheets', 'all sheets', 'different sheets',
        )
        for cue in cues:
            if cue in lowered:
                return True
        return False

    def _extract_relationship_context(self, doc: str) -> Optional[str]:
        """
        Extract the "Sheet Relationships:" section from a document.

        Capturing starts at the first line containing "Sheet Relationships:",
        and that line is kept as the header. Non-blank lines after it are
        collected. Capturing stops at the first line starting with
        ``"Sheet: "``.

        Args:
            doc (str): Document text

        Returns:
            Optional[str]: The header plus its non-blank entries joined by
            newlines, or None if no relationships section exists
        """
        collected: List[str] = []
        capturing = False

        for raw in doc.split('\n'):
            if 'Sheet Relationships:' in raw:
                capturing = True
                collected.append(raw)
                continue
            if not capturing:
                continue
            if raw.startswith('Sheet: '):
                break
            if raw.strip():
                collected.append(raw)

        if not collected:
            return None
        return '\n'.join(collected)

    async def enhance_excel_response(
        self,
        query: str,
        response: str,
        context_docs: List[str]
    ) -> str:
        """
        Append Excel-specific insight sections to a generated response.

        If no context document contains a ``Sheet:`` marker, the response is
        returned unchanged. Otherwise up to two sections may be appended, each
        only when it has content:

        * "Additional Sheet Insights:" when the query asks for a description
          or summary of sheet contents;
        * "Sheet Relationship Insights:" when the query looks like a
          multi-sheet question.

        Enhancement is best-effort: any exception raised while building the
        sections is logged and the original response is returned.

        Args:
            query (str): Original query
            response (str): Generated response
            context_docs (List[str]): Context documents

        Returns:
            str: The response, possibly followed by insight sections
        """
        has_excel = any('Sheet:' in doc for doc in context_docs)
        if not has_excel:
            return response

        try:
            sections = [response]

            if self._needs_sheet_specific_insights(query):
                sheet_lines = self._generate_sheet_insights(query, context_docs)
                if sheet_lines:
                    sections += ["\nAdditional Sheet Insights:", *sheet_lines]

            if self._needs_relationship_context(query):
                relation_lines = self._generate_relationship_insights(context_docs)
                if relation_lines:
                    sections += ["\nSheet Relationship Insights:", *relation_lines]

            return "\n".join(sections)
        except Exception as e:
            # Deliberate best-effort boundary: a failure in the optional
            # enhancement step must not cost the caller its answer.
            logger.error(f"Error enhancing Excel response: {str(e)}")
            return response

    def _needs_sheet_specific_insights(self, query: str) -> bool:
        """
        Decide whether the query asks for a description of sheet contents.

        Returns True when the lower-cased query contains any of the cue
        phrases below as a substring, e.g. "summarize", "tell me about" or
        "show me".

        Args:
            query (str): User query

        Returns:
            bool: True if any cue phrase occurs in the query
        """
        phrases = (
            'analyze', 'summarize', 'tell me about', 'what is in',
            'show me', 'describe', 'explain', 'give me details',
        )
        haystack = query.lower()
        matched = [phrase for phrase in phrases if phrase in haystack]
        return bool(matched)

    def _generate_sheet_insights(self, query: str, context_docs: List[str]) -> List[str]:
        """
        Generate insights for every sheet relevant to the query.

        Relevant sheet names are gathered from all Excel documents (those
        containing ``Sheet:``) and deduplicated in order of first appearance.
        Per-sheet insights are concatenated in that same order, so the output
        is deterministic. A plain ``set`` would iterate in hash order, which for
        strings varies between interpreter runs because of hash randomization.

        Args:
            query (str): User query
            context_docs (List[str]): Context documents

        Returns:
            List[str]: Insight lines for all relevant sheets
        """
        # dict keeps insertion order, acting as an ordered set of sheet names.
        relevant_sheets: Dict[str, None] = {}
        for doc in context_docs:
            if 'Sheet:' in doc:
                relevant_sheets.update(
                    dict.fromkeys(self._identify_relevant_sheets(doc, query))
                )

        insights: List[str] = []
        for sheet in relevant_sheets:
            sheet_insights = self._generate_single_sheet_insights(sheet, context_docs)
            if sheet_insights:
                insights.extend(sheet_insights)

        return insights

    def _generate_single_sheet_insights(
        self,
        sheet_name: str,
        context_docs: List[str]
    ) -> List[str]:
        """
        Generate insights for a single sheet.

        Uses the first context document from which a section for
        ``sheet_name`` can actually be extracted. A document that only
        mentions ``Sheet: <name>`` in passing (for example mid-line) yields no
        section; it is skipped instead of ending the search.

        Args:
            sheet_name (str): Name of the sheet
            context_docs (List[str]): Context documents

        Returns:
            List[str]: Insight lines. Empty if the sheet's section was not
            found or contained no numeric/categorical summary block.
        """
        header = f'Sheet: {sheet_name}'
        sheet_context: Optional[str] = None

        # The substring test is only a cheap pre-filter; extraction decides
        # whether this document really contains the sheet.
        for doc in context_docs:
            if header not in doc:
                continue
            sheet_context = self._extract_sheet_context(doc, sheet_name)
            if sheet_context:
                break

        if not sheet_context:
            return []

        # NOTE(review): _extract_numeric_insights / _extract_categorical_insights
        # are not defined in the code visible here. Presumably they are
        # provided later in this file or by a class this is combined with;
        # confirm.
        insights: List[str] = []
        if 'Numeric Columns Summary:' in sheet_context:
            numeric_insights = self._extract_numeric_insights(sheet_context)
            if numeric_insights:
                insights.extend(numeric_insights)

        if 'Categorical Columns Summary:' in sheet_context:
            categorical_insights = self._extract_categorical_insights(sheet_context)
            if categorical_insights:
                insights.extend(categorical_insights)

        return insights

    def _generate_relationship_insights(self, context_docs: List[str]) -> List[str]:
        """
        Turn each document's relationships section into bullet lines.

        For every document with a "Sheet Relationships:" section, the header
        line is dropped. Each remaining non-blank entry becomes ``"- <entry>"``,
        with surrounding whitespace stripped. Documents are processed in order.

        Args:
            context_docs (List[str]): Context documents

        Returns:
            List[str]: Bullet lines for all relationship entries found
        """
        bullets: List[str] = []
        for doc in context_docs:
            section = self._extract_relationship_context(doc)
            if not section:
                continue
            _, *entries = section.split('\n')  # first line is the header
            bullets.extend(
                f"- {entry.strip()}" for entry in entries if entry.strip()
            )
        return bullets