# src/utils/document_processor.py
from typing import List, Dict, Optional, Union
import PyPDF2
import docx
import pandas as pd
import json
from pathlib import Path
import hashlib
import mimetypes  # Add this instead
from bs4 import BeautifulSoup
import csv
from datetime import datetime
import threading
from queue import Queue
import tiktoken
from langchain.text_splitter import RecursiveCharacterTextSplitter
import logging
from bs4.element import ProcessingInstruction

from config.config import Settings
from .enhanced_excel_processor import EnhancedExcelProcessor


class DocumentProcessor:
    def __init__(
        self,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        max_file_size: Optional[int] = None,
        supported_formats: Optional[List[str]] = None
    ):
        """
        Create a processor, taking any option left as None from project settings.

        Args:
            chunk_size (Optional[int]): Target chunk length in characters;
                defaults to the ``chunk_size`` setting.
            chunk_overlap (Optional[int]): Characters shared between neighbouring
                chunks; defaults to the ``chunk_overlap`` setting.
            max_file_size (Optional[int]): Largest accepted file in bytes;
                defaults to the ``max_file_size`` setting.
            supported_formats (Optional[List[str]]): Accepted file extensions;
                defaults to the ``supported_formats`` setting.

        All four values are then clamped/normalised by ``_validate_settings``.
        """
        # NOTE(review): configuring the root logger from a library constructor
        # affects the whole application; consider moving this to the entry point.
        logging.basicConfig(
            level=logging.DEBUG,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )

        configured = Settings.get_document_processor_settings()

        def resolve(explicit, key):
            # An explicit argument wins; None falls back to the configured value.
            return configured[key] if explicit is None else explicit

        self.chunk_size = resolve(chunk_size, 'chunk_size')
        self.chunk_overlap = resolve(chunk_overlap, 'chunk_overlap')
        self.max_file_size = resolve(max_file_size, 'max_file_size')
        self.supported_formats = resolve(supported_formats, 'supported_formats')

        self._validate_settings()

        # Runtime state and collaborators.
        self.processing_queue = Queue()
        self.processed_docs = {}
        self._initialize_text_splitter()
        self.excel_processor = EnhancedExcelProcessor()

        # Probe optional dependencies; a missing one only produces a warning.
        try:
            import striprtf.striprtf  # noqa: F401
        except ImportError:
            logging.warning(
                "Warning: striprtf package not found. RTF support will be limited.")

        try:
            from bs4 import BeautifulSoup  # noqa: F401
            import lxml  # noqa: F401
        except ImportError:
            logging.warning(
                "Warning: beautifulsoup4 or lxml package not found. XML support will be limited.")

    def _validate_settings(self):
        """
        Clamp configured values into a usable range and normalise the formats.

        - ``chunk_size`` is raised to at least 100 characters.
        - ``chunk_overlap`` is kept within ``[0, chunk_size - 50]``. The lower
          bound matters: a negative overlap would be interpreted by slicing as
          "count from the end" in the chunk helpers.
        - ``max_file_size`` is raised to at least 1 MiB.
        - ``supported_formats`` falls back to ``Settings.DOCUMENT_PROCESSOR``
          when empty. Every entry is lower-cased, and a leading dot is added
          if it was missing.
        """
        self.chunk_size = max(100, self.chunk_size)

        # Leave room for new content in every chunk, and never go negative.
        self.chunk_overlap = max(0, min(self.chunk_overlap, self.chunk_size - 50))

        self.max_file_size = max(1024 * 1024, self.max_file_size)

        if not self.supported_formats:
            # Fallback to default supported formats if empty
            self.supported_formats = Settings.DOCUMENT_PROCESSOR['supported_formats']

        self.supported_formats = [
            fmt.lower() if fmt.startswith('.') else f".{fmt.lower()}"
            for fmt in self.supported_formats
        ]

    def _initialize_text_splitter(self):
        """
        Build ``self.text_splitter`` from the current chunk settings.

        Separators go from paragraph to line to word to character. Separators
        and surrounding whitespace are kept so that markdown layout survives
        splitting.
        """
        splitter_options = dict(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", " ", ""],
            keep_separator=True,
            add_start_index=True,
            strip_whitespace=False,
        )
        self.text_splitter = RecursiveCharacterTextSplitter(**splitter_options)

    def _find_break_point(self, text: str, prev_chunk: str) -> int:
        """
        Pick the index in ``text`` at which the overlap region should be cut.

        If the previous chunk contains markdown table rows (lines starting with
        '|'), the cut goes just after the first newline that follows the first
        '|' in ``text``.

        Otherwise, markers are tried in order: paragraph break, line break,
        sentence end, clause break, space. For each marker, occurrences are
        scanned right-to-left and the first one accepted by
        ``_is_valid_break_point`` wins. The cut lands after the marker, except
        for a plain space, where it lands on it.

        Args:
            text (str): Text to find a break point in (the overlap portion).
            prev_chunk (str): The complete previous chunk, used for context.

        Returns:
            int: Cut position. If no candidate qualifies, this is
            ``min(len(text), self.chunk_overlap)``.
        """
        previous_lines = prev_chunk.split('\n')

        has_table_row = '|' in prev_chunk and any(
            row.strip().startswith('|') for row in previous_lines
        )
        if has_table_row:
            pipe_at = text.find('|')
            if pipe_at >= 0:
                newline_at = text.find('\n', pipe_at)
                if newline_at >= 0:
                    return newline_at + 1  # include the newline

        tail_line = previous_lines[-1] if previous_lines else ""

        candidates = (
            ('\n\n', True),   # paragraph break, keep marker
            ('\n', True),     # line break, keep marker
            ('. ', True),     # sentence end, keep marker
            (', ', True),     # clause break, keep marker
            (' ', False),     # word break, cut before the space
        )
        for marker, keep in candidates:
            if marker not in text:
                continue
            width = len(marker)
            # Every occurrence (overlapping ones included), tried right-to-left.
            hits = [idx for idx in range(len(text)) if text.startswith(marker, idx)]
            for idx in reversed(hits):
                if self._is_valid_break_point(text, idx, tail_line):
                    return idx + width if keep else idx

        return min(len(text), self.chunk_overlap)

    def _is_valid_break_point(self, text: str, position: int, last_line: str) -> bool:
        """
        Check whether cutting ``text`` at ``position`` keeps document structure intact.

        Args:
            text (str): Text being checked.
            position (int): Index of the candidate break (where the break
                marker starts).
            last_line (str): Last line of the previous chunk; used to align
                table cells.

        Returns:
            bool: False when the break sits next to markdown formatting
            characters, splits a table row, or falls inside a fenced code
            block, an inline code span or a URL. True otherwise.
        """
        import re  # local: only this heuristic needs regular expressions

        before = text[:position]

        # Don't break right next to markdown formatting characters.
        markdown_markers = ['*', '_', '`', '[', ']', '(', ')', '#']
        if 0 < position < len(text) - 1:
            if text[position - 1] in markdown_markers or text[position + 1] in markdown_markers:
                return False

        # Don't break in the middle of a table row: the pipes before the break
        # must be a whole multiple of the pipes per row.
        if '|' in last_line:
            cell_count = last_line.count('|')
            if before.count('|') % cell_count != 0:
                return False

        # An odd number of ``` fences before the break means we are inside a
        # fenced code block.
        if before.count('```') % 2 == 1:
            return False

        # An odd number of single backticks (fences excluded) means we are
        # inside an inline code span.
        if before.replace('```', '').count('`') % 2 == 1:
            return False

        # Don't cut through a URL. Breaking right before it, or at the
        # whitespace that ends it, is fine.
        for match in re.finditer(r'https?://\S+', text):
            if match.start() < position < match.end():
                return False

        return True

    def _validate_chunks(self, original_text: str, chunks: List[str]) -> bool:
        """
        Check that stitching ``chunks`` back together reproduces ``original_text``.

        Each chunk after the first is trimmed by the overlap it actually shares
        with its predecessor: the longest suffix/prefix match, capped at
        ``self.chunk_overlap``. A fixed ``chunk_overlap`` is not assumed,
        because a predecessor shorter than the overlap contributes fewer
        shared characters. Whitespace differences are ignored in the
        comparison.

        Args:
            original_text (str): Text that was split.
            chunks (List[str]): Chunks produced from it, in order.

        Returns:
            bool: True when the reconstruction matches. With no chunks, True only
            if ``original_text`` is blank. False on mismatch or on any error
            (errors are logged).
        """
        try:
            if not chunks:
                # Zero chunks is only a faithful split of text without content.
                return not original_text.strip()

            max_overlap = max(0, self.chunk_overlap)
            pieces = [chunks[0]]
            for previous, current in zip(chunks, chunks[1:]):
                shared = min(self._calculate_overlap_size(previous, current), max_overlap)
                pieces.append(current[shared:])
            reconstructed = ''.join(pieces)

            # Clean both texts for comparison (collapse whitespace runs)
            clean_original = ' '.join(original_text.split())
            clean_reconstructed = ' '.join(reconstructed.split())

            return clean_original == clean_reconstructed
        except Exception as e:
            logging.error(f"Error validating chunks: {str(e)}")
            return False

    def _extract_content(self, file_path: Path) -> str:
        """Extract content from different file formats"""
        suffix = file_path.suffix.lower()

        try:
            if suffix == '.pdf':
                return self._extract_pdf(file_path)
            elif suffix == '.docx':
                return self._extract_docx(file_path)
            elif suffix == '.csv':
                return self._extract_csv(file_path)
            elif suffix == '.json':
                return self._extract_json(file_path)
            elif suffix == '.html':
                return self._extract_html(file_path)
            elif suffix == '.txt' or suffix == '.md':
                return self._extract_text(file_path)
            elif suffix == '.xml':
                return self._extract_xml(file_path)
            elif suffix == '.rtf':
                return self._extract_rtf(file_path)
            elif suffix in ['.xlsx', '.xls']:
                return self._extract_excel(file_path)
            else:
                raise ValueError(f"Unsupported format: {suffix}")
        except Exception as e:
            raise Exception(
                f"Error extracting content from {file_path}: {str(e)}")

    def _extract_text(self, file_path: Path) -> str:
        """Extract content from text-based files"""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except UnicodeDecodeError:
            with open(file_path, 'r', encoding='latin-1') as f:
                return f.read()

    def _extract_pdf(self, file_path: Path) -> str:
        """Extract text from PDF with advanced features"""
        text = ""
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            metadata = reader.metadata

            for page in reader.pages:
                text += page.extract_text() + "\n\n"

                # Extract images if available
                if '/XObject' in page['/Resources']:
                    for obj in page['/Resources']['/XObject'].get_object():
                        if page['/Resources']['/XObject'][obj]['/Subtype'] == '/Image':
                            pass

        return text.strip()

    def _extract_docx(self, file_path: Path) -> str:
        """
        Extract plain text from a .docx file.

        All paragraph texts come first, in document order. They are followed by
        every table row, rendered as its cell texts joined with " | ". Tables
        therefore lose their original position relative to the paragraphs. All
        pieces are joined with blank lines.
        """
        document = docx.Document(file_path)

        paragraph_texts = [paragraph.text for paragraph in document.paragraphs]
        table_row_texts = [
            " | ".join([cell.text for cell in row.cells])
            for table in document.tables
            for row in table.rows
        ]

        return "\n\n".join(paragraph_texts + table_row_texts)

    def _extract_csv(self, file_path: Path) -> str:
        """Convert CSV to structured text"""
        df = pd.read_csv(file_path)
        return df.to_string()

    def _extract_json(self, file_path: Path) -> str:
        """Convert JSON to readable text"""
        with open(file_path) as f:
            data = json.load(f)
        return json.dumps(data, indent=2)

    def _extract_html(self, file_path: Path) -> str:
        """Extract text from HTML with structure preservation"""
        with open(file_path) as f:
            soup = BeautifulSoup(f, 'html.parser')

        for script in soup(["script", "style"]):
            script.decompose()

        text = soup.get_text(separator='\n')
        lines = [line.strip() for line in text.splitlines() if line.strip()]
        return "\n\n".join(lines)

    def _extract_xml(self, file_path: Path) -> str:
        """Extract text from XML with structure preservation"""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                soup = BeautifulSoup(f, 'xml')

            for pi in soup.find_all(text=lambda text: isinstance(text, ProcessingInstruction)):
                pi.extract()

            text = soup.get_text(separator='\n')
            lines = [line.strip()
                     for line in text.splitlines() if line.strip()]
            return "\n\n".join(lines)
        except Exception as e:
            raise Exception(f"Error processing XML file: {str(e)}")

    def _extract_rtf(self, file_path: Path) -> str:
        """Extract text from RTF files"""
        try:
            import striprtf.striprtf as striprtf

            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                rtf_text = f.read()

            plain_text = striprtf.rtf_to_text(rtf_text)
            lines = [line.strip()
                     for line in plain_text.splitlines() if line.strip()]
            return "\n\n".join(lines)
        except ImportError:
            raise ImportError("striprtf package is required for RTF support.")
        except Exception as e:
            raise Exception(f"Error processing RTF file: {str(e)}")

    def _extract_excel(self, file_path: Path) -> str:
        """Extract content from Excel files with enhanced processing"""
        try:
            # Use enhanced Excel processor
            processed_content = self.excel_processor.process_excel(file_path)

            # If processing fails, fall back to basic processing
            if not processed_content:
                logging.warning(
                    f"Enhanced Excel processing failed for {file_path}, falling back to basic processing")
                return self._basic_excel_extract(file_path)

            return processed_content

        except Exception as e:
            logging.error(f"Error in enhanced Excel processing: {str(e)}")
            # Fall back to basic Excel processing
            return self._basic_excel_extract(file_path)

    def _basic_excel_extract(self, file_path: Path) -> str:
        """Basic Excel extraction as fallback"""
        try:
            excel_file = pd.ExcelFile(file_path)
            sheets_data = []

            for sheet_name in excel_file.sheet_names:
                df = pd.read_excel(excel_file, sheet_name=sheet_name)
                sheet_content = f"\nSheet: {sheet_name}\n"
                sheet_content += "=" * (len(sheet_name) + 7) + "\n"

                if df.empty:
                    sheet_content += "Empty Sheet\n"
                else:
                    sheet_content += df.fillna('').to_string(
                        index=False,
                        max_rows=None,
                        max_cols=None,
                        line_width=120
                    ) + "\n"

                sheets_data.append(sheet_content)

            return "\n\n".join(sheets_data)

        except Exception as e:
            raise Exception(f"Error in basic Excel processing: {str(e)}")

    def _get_mime_type(self, file_path: Path) -> str:
        """
        Get MIME type for a file based on its extension

        Args:
            file_path (Path): Path to the file

        Returns:
            str: MIME type of the file
        """
        # Standard MIME mappings for supported formats
        MIME_MAPPINGS = {
            '.pdf': 'application/pdf',
            '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
            '.doc': 'application/msword',
            '.csv': 'text/csv',
            '.json': 'application/json',
            '.html': 'text/html',
            '.txt': 'text/plain',
            '.md': 'text/markdown',
            '.xml': 'text/xml',
            '.rtf': 'application/rtf',
            '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
            '.xls': 'application/vnd.ms-excel'
        }

        suffix = file_path.suffix.lower()

        # Verify the file format is supported
        if suffix not in self.supported_formats:
            logging.warning(f"Unsupported file format: {suffix}")
            return 'application/octet-stream'

        # Return known MIME type or fall back to mimetypes module
        if suffix in MIME_MAPPINGS:
            return MIME_MAPPINGS[suffix]

        mime_type = mimetypes.guess_type(str(file_path))[0]
        return mime_type if mime_type else 'application/octet-stream'

    def _generate_metadata(
        self,
        file_path: Path,
        content: str,
        additional_metadata: Optional[Dict] = None
    ) -> Dict:
        """Generate comprehensive metadata"""
        file_stat = file_path.stat()

        metadata = {
            'filename': file_path.name,
            'file_type': file_path.suffix,
            'file_size': file_stat.st_size,
            'created_at': datetime.fromtimestamp(file_stat.st_ctime),
            'modified_at': datetime.fromtimestamp(file_stat.st_mtime),
            'content_hash': self._calculate_hash(content),
            'mime_type': self._get_mime_type(file_path),
            'word_count': len(content.split()),
            'character_count': len(content),
            'processing_timestamp': datetime.utcnow().isoformat()
        }

        # Add Excel-specific metadata if applicable
        if file_path.suffix.lower() in ['.xlsx', '.xls']:
            try:
                if hasattr(self.excel_processor, 'get_metadata'):
                    excel_metadata = self.excel_processor.get_metadata()
                    metadata.update({'excel_metadata': excel_metadata})
            except Exception as e:
                logging.warning(f"Could not extract Excel metadata: {str(e)}")

        if additional_metadata:
            metadata.update(additional_metadata)

        return metadata

    # def _generate_metadata(
    #     self,
    #     file_path: Path,
    #     content: str,
    #     additional_metadata: Optional[Dict] = None
    # ) -> Dict:
    #     """Generate comprehensive metadata"""
    #     file_stat = file_path.stat()

    #     metadata = {
    #         'filename': file_path.name,
    #         'file_type': file_path.suffix,
    #         'file_size': file_stat.st_size,
    #         'created_at': datetime.fromtimestamp(file_stat.st_ctime),
    #         'modified_at': datetime.fromtimestamp(file_stat.st_mtime),
    #         'content_hash': self._calculate_hash(content),
    #         'mime_type': magic.from_file(str(file_path), mime=True),
    #         'word_count': len(content.split()),
    #         'character_count': len(content),
    #         'processing_timestamp': datetime.now().isoformat()
    #     }

    #     # Add Excel-specific metadata if applicable
    #     if file_path.suffix.lower() in ['.xlsx', '.xls']:
    #         try:
    #             if hasattr(self.excel_processor, 'get_metadata'):
    #                 excel_metadata = self.excel_processor.get_metadata()
    #                 metadata.update({'excel_metadata': excel_metadata})
    #         except Exception as e:
    #             logging.warning(f"Could not extract Excel metadata: {str(e)}")

    #     if additional_metadata:
    #         metadata.update(additional_metadata)

    #     return metadata

    def _calculate_hash(self, text: str) -> str:
        """Return the hex SHA-256 digest of ``text`` encoded as UTF-8."""
        digest = hashlib.sha256()
        digest.update(text.encode())
        return digest.hexdigest()

    def _process_chunks(self, text: str) -> List[str]:
        """
        Split ``text`` into overlapping chunks.

        ``self.text_splitter`` is constructed with ``chunk_size`` and
        ``chunk_overlap`` (see ``_initialize_text_splitter``), so the splitter
        already carries up to ``chunk_overlap`` characters from one chunk into
        the next. Prepending a second copy of the overlap here would push
        nearly every chunk past ``chunk_size``. That in turn forces a re-split
        into redundant fragments, so the splitter output is returned as-is.

        Args:
            text (str): Full document text.

        Returns:
            List[str]: Chunks in document order.
        """
        return list(self.text_splitter.split_text(text))

    async def process_document(self, file_path: Union[str, Path]) -> Dict:
        """
        Validate, extract and chunk a single document.

        Args:
            file_path (Union[str, Path]): Document to process.

        Returns:
            Dict: ``content`` (the full extracted text), ``chunks`` (from
            ``_process_chunks``) and ``metadata`` (from ``_generate_metadata``).

        Raises:
            FileNotFoundError / ValueError: from ``_validate_file``.
            Exception: wrapped extraction errors from ``_extract_content``.

        Note: nothing in this coroutine is awaited, so all work runs
        synchronously on the calling (event-loop) thread.
        """
        path = Path(file_path)

        if not self._validate_file(path):
            raise ValueError(f"Invalid file: {path}")

        content = self._extract_content(path)
        return {
            'content': content,
            'chunks': self._process_chunks(content),
            'metadata': self._generate_metadata(path, content),
        }

    def _calculate_overlap_size(self, chunk1: str, chunk2: str) -> int:
        """
        Return the length of the longest suffix of ``chunk1`` that is also a
        prefix of ``chunk2``. Returns 0 when they share none.
        """
        longest_possible = min(len(chunk1), len(chunk2))
        return next(
            (size for size in range(longest_possible, 0, -1)
             if chunk2.startswith(chunk1[-size:])),
            0,
        )

    def _validate_file(self, file_path: Path) -> bool:
        """Validate file type, size, and content"""
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        if file_path.suffix.lower() not in self.supported_formats:
            raise ValueError(f"Unsupported file format: {file_path.suffix}")

        if file_path.stat().st_size > self.max_file_size:
            raise ValueError(f"File too large: {file_path}")

        if file_path.stat().st_size == 0:
            raise ValueError(f"Empty file: {file_path}")

        return True

    def _generate_statistics(self, content: str, chunks: List[str]) -> Dict:
        """
        Summarise a processed document.

        Args:
            content (str): Full extracted text.
            chunks (List[str]): Chunks produced from ``content``.

        Returns:
            Dict with these keys:

            - ``total_chunks``: number of chunks.
            - ``average_chunk_size``: mean chunk length in characters (0.0
              when there are no chunks).
            - ``token_estimate``: whitespace-separated word count, not a
              tokenizer count.
            - ``unique_words``: case-insensitive count of distinct words.
            - ``sentences``: non-blank segments between '.' characters.
        """
        total_chunks = len(chunks)
        # An empty document yields no chunks; avoid ZeroDivisionError.
        average_chunk_size = (
            sum(len(chunk) for chunk in chunks) / total_chunks if total_chunks else 0.0
        )
        return {
            'total_chunks': total_chunks,
            'average_chunk_size': average_chunk_size,
            'token_estimate': len(content.split()),
            'unique_words': len(set(content.lower().split())),
            'sentences': len([s for s in content.split('.') if s.strip()]),
        }

    async def batch_process(
        self,
        file_paths: List[Union[str, Path]],
        parallel: bool = True
    ) -> Dict[str, Dict]:
        """
        Process several documents and collect their results keyed by path.

        ``_process_and_store`` is a coroutine function, so it must be awaited
        on the event loop. Handing it to a ``threading.Thread`` as a target
        only creates coroutine objects that never run, which would leave the
        results empty.

        Args:
            file_paths (List[Union[str, Path]]): Documents to process.
            parallel (bool): When True, all documents are scheduled together
                with ``asyncio.gather``. When False, they are awaited one after
                another. ``process_document`` awaits nothing internally, so in
                both modes the work executes on the event-loop thread.

        Returns:
            Dict[str, Dict]: Maps ``str(path)`` to the ``process_document``
            result, or to ``{'error': message}`` when that file failed.
        """
        import asyncio  # local: the module does not import asyncio at top level

        results: Dict[str, Dict] = {}

        if parallel:
            await asyncio.gather(
                *(self._process_and_store(file_path, results) for file_path in file_paths)
            )
        else:
            for file_path in file_paths:
                await self._process_and_store(file_path, results)

        return results

    async def _process_and_store(
        self,
        file_path: Union[str, Path],
        results: Dict
    ):
        """
        Process one document and record the outcome in ``results``.

        The key is ``str(file_path)``. The value is the ``process_document``
        result, or ``{'error': message}`` if processing raised any Exception.
        """
        key = str(file_path)
        try:
            outcome = await self.process_document(file_path)
        except Exception as exc:
            outcome = {'error': str(exc)}
        results[key] = outcome