from fastapi import UploadFile, File, Form, HTTPException, APIRouter
from typing import List, Optional, Dict, Tuple
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import pandas as pd
from utils import process_pdf_to_chunks
import hashlib
import uuid
import json
from datetime import datetime
from pydantic import BaseModel
import logging

# All endpoints in this module are mounted under /rag and grouped under the "rag" tag.
router = APIRouter(prefix="/rag", tags=["rag"])

# Module-wide LanceDB connection and embedding model, both created once at import time.
# NOTE(review): data under /tmp is usually not durable across reboots — confirm this path is intended.
_DB_URI = "/tmp/db"
_EMBEDDING_MODEL_NAME = "Snowflake/snowflake-arctic-embed-xs"

db = lancedb.connect(_DB_URI)

# CPU sentence-transformers embedder; DocumentChunk uses it to size and fill its vector column.
_embedding_provider = get_registry().get("sentence-transformers")
model = _embedding_provider.create(
    name=_EMBEDDING_MODEL_NAME,
    device="cpu",
)

def get_user_collection(user_id: str, collection_name: str) -> str:
    """Build the LanceDB table name that stores a user's collection.

    The user id and the collection name are joined with a single underscore,
    e.g. ``get_user_collection("alice", "notes") == "alice_notes"``.
    """
    return "{}_{}".format(user_id, collection_name)

class DocumentChunk(LanceModel):
    """LanceDB table schema: one row per indexed chunk of an uploaded document.

    ``text`` is declared as the embedding source field and ``vector`` as the vector field
    of the module-level ``model``. Rows are inserted without a ``vector`` value and queries
    pass raw strings to ``table.search``, so LanceDB fills in embeddings from ``text``.
    The field order below defines the table's column order.
    """
    text: str = model.SourceField()  # chunk text; input to the embedding model
    vector: Vector(model.ndims()) = model.VectorField()  # embedding of `text`, sized to the model's output
    document_id: str  # id of the source document (returned by process_file)
    chunk_index: int  # position of this chunk within its document
    file_name: str  # original upload file name
    file_type: str  # lower-cased file extension, e.g. "pdf" or "txt"
    created_date: str  # ISO-8601 timestamp string taken from the chunk metadata
    collection_id: str  # full table name, i.e. "{user_id}_{collection_name}"
    user_id: str  # uploader; used to filter rows per user
    metadata_json: str  # the chunk's complete metadata dict serialized with json.dumps
    char_start: int  # start character offset of the chunk (from metadata["location"])
    char_end: int  # end character offset of the chunk (from metadata["location"])
    page_numbers: List[int]  # pages the chunk came from (metadata["location"]["pages"])
    images: List[str]  # metadata["images"]; presumably image references for the chunk — TODO confirm format

class QueryInput(BaseModel):
    """Request body for a semantic search over one of a user's collections."""
    collection_id: str  # collection name WITHOUT the "{user_id}_" table prefix
    query: str  # free-text query passed to table.search
    top_k: Optional[int] = 3  # maximum number of hits to return
    user_id: str  # owner of the collection; also used to filter rows

class SearchResult(BaseModel):
    """A single search hit: chunk text, its distance to the query and its metadata."""
    text: str  # chunk text as stored in the table
    distance: float  # LanceDB `_distance` of the hit (smaller = closer to the query)
    metadata: Dict  # chunk metadata decoded from the stored `metadata_json` column

class SearchResponse(BaseModel):
    """Response body of the query endpoint: hits in the order LanceDB returned them."""
    results: List[SearchResult]

async def process_file(file: UploadFile, collection_id: str, user_id: str) -> Tuple[List[dict], str]:
    """Read one uploaded file and split it into indexable chunks.

    The lower-cased text after the last ``.`` of the file name selects the parser:
    ``pdf`` files are handed to ``process_pdf_to_chunks``; ``txt`` files become a single
    chunk covering the whole decoded text.

    Args:
        file: The uploaded file; its whole content is read into memory.
        collection_id: Target collection id. Not used by this function.
        user_id: Recorded in the metadata of plain-text chunks.

    Returns:
        ``(chunks, doc_id)``. Each chunk is ``{"text": str, "metadata": dict}``. For ``txt``
        files the metadata holds ``created_date``, ``file_name``, ``document_id``,
        ``user_id``, ``location`` (chunk index, char offsets, pages, total chunks) and
        ``images``.

    Raises:
        ValueError: If the extension is neither ``pdf`` nor ``txt``, or the name is missing.
        UnicodeDecodeError: If a ``txt`` file is not valid UTF-8.
    """
    content = await file.read()
    # The upload may come without a file name; treat that as "no extension".
    file_type = (file.filename or "").split('.')[-1].lower()

    if file_type == 'pdf':
        chunks, doc_id = process_pdf_to_chunks(
            pdf_content=content,
            file_name=file.filename
        )
        return chunks, doc_id

    if file_type == 'txt':
        # NOTE(review): only the first 4 hex digits (65,536 values) of the SHA-256 are kept,
        # so different documents can share a document_id — confirm short ids are intended.
        doc_id = hashlib.sha256(content).hexdigest()[:4]
        # 'utf-8-sig' drops a leading byte-order mark (if any) so it is not indexed as text.
        text_content = content.decode('utf-8-sig')
        chunks = [{
            "text": text_content,
            "metadata": {
                "created_date": datetime.now().isoformat(),
                "file_name": file.filename,
                "document_id": doc_id,
                "user_id": user_id,
                "location": {
                    "chunk_index": 0,
                    "char_start": 0,
                    "char_end": len(text_content),
                    "pages": [1],
                    "total_chunks": 1
                },
                "images": []
            }
        }]
        return chunks, doc_id

    # Fail loudly: returning no chunks would let the caller report the file as
    # successfully processed even though nothing from it gets stored.
    raise ValueError(
        f"Unsupported file type '{file_type}' for file {file.filename!r}; "
        "only .pdf and .txt files are supported"
    )

def _chunk_to_row(chunk: dict, file_type: str, collection_id: str, user_id: str) -> dict:
    """Flatten one chunk produced by ``process_file`` into a ``DocumentChunk`` row dict."""
    metadata = chunk["metadata"]
    location = metadata["location"]
    return {
        "text": chunk["text"],
        "document_id": metadata["document_id"],
        "chunk_index": location["chunk_index"],
        "file_name": metadata["file_name"],
        "file_type": file_type,
        "created_date": metadata["created_date"],
        "collection_id": collection_id,
        "user_id": user_id,
        "metadata_json": json.dumps(metadata),
        "char_start": location["char_start"],
        "char_end": location["char_end"],
        "page_numbers": location["pages"],
        "images": metadata.get("images", []),
    }


def _open_or_create_table(collection_id: str):
    """Open the LanceDB table ``collection_id``, creating it with the DocumentChunk schema if absent.

    Raises:
        HTTPException: 500 if the table can neither be opened nor created.
    """
    try:
        return db.open_table(collection_id)
    except Exception as e:
        # A missing table is the normal case for a brand-new collection, not an error.
        logging.info(f"Table {collection_id} not found, creating it: {str(e)}")

    try:
        table = db.create_table(
            collection_id,
            schema=DocumentChunk,
            mode="create"
        )
        # Create FTS index on the text column for hybrid search support

        # table.create_fts_index(
        #     field_names="text",
        #     replace=True,
        #     tokenizer_name="en_stem",  # Use English stemming
        #     lower_case=True,  # Convert text to lowercase
        #     remove_stop_words=True,  # Remove common words like "the", "is", "at"
        #     writer_heap_size=1024 * 1024 * 1024  # 1GB heap size
        # )
        return table
    except Exception as e:
        logging.error(f"Error creating table: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error creating database table: {str(e)}"
        ) from e


@router.post("/upload_files")
async def upload_files(
    files: List[UploadFile] = File(...),
    collection_name: Optional[str] = Form(None),
    user_id: str = Form(...)
):
    """Chunk the uploaded files and append them to one of the user's collections.

    Args:
        files: Uploaded files; each is parsed by ``process_file``.
        collection_name: Collection to append to; a random ``col_<8 hex>`` name is used if omitted.
        user_id: Owner of the collection; stamped on every stored row.

    Returns:
        ``message``, ``collection_id`` (the full table name, including the user prefix),
        ``total_chunks``, ``user_id`` and ``document_ids`` (document id -> file name).

    Raises:
        HTTPException: 400 if a file cannot be processed or no chunks were extracted at all;
            500 if the table cannot be created or written, or on any unexpected error.
    """
    try:
        collection_id = get_user_collection(
            user_id,
            collection_name if collection_name else f"col_{uuid.uuid4().hex[:8]}"
        )
        all_chunks = []
        doc_ids = {}
        for file in files:
            try:
                chunks, doc_id = await process_file(file, collection_id, user_id)
                file_type = file.filename.split('.')[-1].lower()
                all_chunks.extend(
                    _chunk_to_row(chunk, file_type, collection_id, user_id)
                    for chunk in chunks
                )
                doc_ids[doc_id] = file.filename
            except Exception as e:
                logging.error(f"Error processing file {file.filename}: {str(e)}")
                raise HTTPException(
                    status_code=400,
                    detail=f"Error processing file {file.filename}: {str(e)}"
                ) from e

        if not all_chunks:
            # Nothing to store. An empty DataFrame has no columns, so it cannot match the
            # table schema, and creating the table first would leave an empty one behind.
            raise HTTPException(
                status_code=400,
                detail="No content could be extracted from the uploaded files"
            )

        table = _open_or_create_table(collection_id)

        try:
            table.add(data=pd.DataFrame(all_chunks))
        except Exception as e:
            logging.error(f"Error adding data to table: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error adding data to database: {str(e)}"
            ) from e

        return {
            "message": f"Successfully processed {len(files)} files",
            # NOTE(review): this is the full "{user_id}_{name}" table name, while the other
            # endpoints expect the name without the user prefix — confirm what clients send back.
            "collection_id": collection_id,
            "total_chunks": len(all_chunks),
            "user_id": user_id,
            "document_ids": doc_ids
        }

    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error during file upload: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        ) from e

@router.get("/get_document/{collection_id}/{document_id}")
async def get_document(
    collection_id: str,
    document_id: str,
    user_id: str
):
    """Fetch every stored chunk of a single document, ordered by chunk index.

    Args:
        collection_id: Collection name without the ``{user_id}_`` table prefix.
        document_id: Value of the ``document_id`` column to look up.
        user_id: Owner of the collection; rows are additionally filtered on it.

    Returns:
        ``{"document_id", "file_name", "chunks"}``. ``file_name`` comes from the first
        chunk; each chunk is ``{"text": ..., "metadata": <decoded metadata_json>}``.

    Raises:
        HTTPException: 404 when the table cannot be opened or no rows match;
            500 for any other failure while reading or decoding the rows.
    """
    table_name = get_user_collection(user_id, collection_id)
    try:
        table = db.open_table(table_name)
    except Exception as e:
        logging.error(f"Error opening table: {str(e)}")
        raise HTTPException(
            status_code=404,
            detail=f"Collection not found: {str(e)}"
        )

    try:
        rows = table.to_pandas()
        wanted = (rows['document_id'] == document_id) & (rows['user_id'] == user_id)
        doc_chunks = rows.loc[wanted].sort_values('chunk_index')

        if doc_chunks.empty:
            raise HTTPException(
                status_code=404,
                detail=f"Document {document_id} not found in collection {collection_id}"
            )

        first_file_name = doc_chunks['file_name'].iloc[0]
        chunk_payloads = [
            {"text": chunk_text, "metadata": json.loads(raw_meta)}
            for chunk_text, raw_meta in zip(doc_chunks['text'], doc_chunks['metadata_json'])
        ]
        return {
            "document_id": document_id,
            "file_name": first_file_name,
            "chunks": chunk_payloads,
        }
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Error retrieving document: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error retrieving document: {str(e)}"
        )

@router.post("/query_collection", response_model=SearchResponse)
async def query_collection(input_data: QueryInput):
    """Run a semantic search over one of the caller's collections.

    Args:
        input_data: Collection name (without the user prefix), query text, ``top_k``
            and ``user_id``.

    Returns:
        ``SearchResponse`` with up to ``top_k`` hits. Each hit carries the chunk text, its
        ``_distance`` and the decoded ``metadata_json``.

    Raises:
        HTTPException: 400 if ``top_k`` is not a positive integer; 404 if the collection
            table cannot be opened; 500 if the search or result decoding fails.
    """
    # top_k is Optional in the request model. A missing or non-positive limit makes the
    # vector search fail, so fall back to the model default and reject bad values up front.
    top_k = 3 if input_data.top_k is None else input_data.top_k
    if top_k < 1:
        raise HTTPException(
            status_code=400,
            detail="top_k must be a positive integer"
        )

    try:
        collection_id = get_user_collection(input_data.user_id, input_data.collection_id)

        try:
            table = db.open_table(collection_id)
        except Exception as e:
            logging.error(f"Error opening table: {str(e)}")
            raise HTTPException(
                status_code=404,
                detail=f"Collection not found: {str(e)}"
            ) from e

        # user_id comes from the request body: double any single quotes so it always stays
        # inside the SQL string literal of the filter instead of altering the expression.
        escaped_user_id = input_data.user_id.replace("'", "''")
        try:
            results = (
                table.search(input_data.query)
                .where(f"user_id = '{escaped_user_id}'")
                .limit(top_k)
                .to_list()
            )
        except Exception as e:
            logging.error(f"Error searching collection: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error searching collection: {str(e)}"
            ) from e

        return SearchResponse(results=[
            SearchResult(
                text=r['text'],
                distance=float(r['_distance']),
                metadata=json.loads(r['metadata_json'])
            )
            for r in results
        ])
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error during query: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        ) from e



@router.get("/list_collections")
async def list_collections(user_id: str):
    """List the user's collections and the documents stored in each.

    Args:
        user_id: Owner whose ``{user_id}_*`` tables are listed.

    Returns:
        ``{"collections": [{"collection_id": ..., "documents": [...]}, ...]}``.
        ``collection_id`` is the table name with the leading ``{user_id}_`` removed.
        Each document is ``{"document_id", "file_name", "created_date"}``, taken from
        the document's first row.

    Raises:
        HTTPException: 500 if the table names cannot be listed. A collection that fails
            to load is logged and skipped.
    """
    prefix = f"{user_id}_"
    try:
        # NOTE(review): depending on the installed LanceDB version, table_names() may return
        # only one page of names (default limit 10) — confirm every table is listed.
        user_collections = [
            name for name in db.table_names()
            if name.startswith(prefix)
        ]

        # Get documents for each collection
        collections_info = []
        for collection_name in user_collections:
            try:
                df = db.open_table(collection_name).to_pandas()

                # The prefix alone is ambiguous: "alice_bob_notes" matches user "alice" but
                # may belong to user "alice_bob". Every row records its uploader, so keep only
                # this user's rows and skip non-empty tables that contain none of them.
                owned = df[df['user_id'] == user_id]
                if owned.empty and not df.empty:
                    continue

                # Group by document_id to get unique documents
                documents = owned.groupby('document_id').agg({
                    'file_name': 'first',
                    'created_date': 'first'
                }).reset_index()

                collections_info.append({
                    # Slice off only the leading prefix; str.replace would also remove any
                    # later "{user_id}_" occurring inside the collection name itself.
                    "collection_id": collection_name[len(prefix):],
                    "documents": [
                        {
                            "document_id": row['document_id'],
                            "file_name": row['file_name'],
                            "created_date": row['created_date']
                        }
                        for _, row in documents.iterrows()
                    ]
                })
            except Exception as e:
                logging.error(f"Error processing collection {collection_name}: {str(e)}")
                continue

        return {"collections": collections_info}
    except Exception as e:
        logging.error(f"Error listing collections for user {user_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

@router.delete("/delete_collection/{collection_id}")
async def delete_collection(collection_id: str, user_id: str):
    """Drop one of the user's collections.

    Args:
        collection_id: Collection name without the ``{user_id}_`` table prefix.
        user_id: User requesting the deletion.

    Returns:
        A confirmation message and the ``collection_id`` that was deleted.

    Raises:
        HTTPException: 404 if the table does not exist; 403 if it contains rows uploaded
            by a different user; 500 if dropping fails or on any unexpected error.
    """
    try:
        full_collection_id = get_user_collection(user_id, collection_id)

        # Check if collection exists
        try:
            table = db.open_table(full_collection_id)
        except Exception as e:
            logging.error(f"Collection not found: {str(e)}")
            raise HTTPException(
                status_code=404,
                detail=f"Collection {collection_id} not found"
            ) from e

        # Verify ownership. A table-name prefix check proves nothing here, because the name
        # was just built from user_id. Also, "alice_bob_notes" is reachable as user "alice" +
        # collection "bob_notes" even when user "alice_bob" created it. Every row records its
        # uploader, so refuse to drop a table that holds anyone else's rows.
        owners = set(table.to_pandas()['user_id'])
        if owners - {user_id}:
            logging.error(f"Unauthorized deletion attempt for collection {collection_id} by user {user_id}")
            raise HTTPException(
                status_code=403,
                detail="Not authorized to delete this collection"
            )

        try:
            db.drop_table(full_collection_id)
        except Exception as e:
            logging.error(f"Error deleting collection {collection_id}: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error deleting collection: {str(e)}"
            ) from e

        return {
            "message": f"Collection {collection_id} deleted successfully",
            "collection_id": collection_id
        }

    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error deleting collection {collection_id}: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        ) from e

@router.post("/get_collection_files")
def get_collection_files(collection_id: str, user_id: str) -> str:
    """Return the distinct file names in a user's collection as one comma-separated string.

    Args:
        collection_id: Collection name without the ``{user_id}_`` table prefix.
        user_id: Owner of the collection; only rows uploaded by this user are considered.

    Returns:
        File names joined with ``", "``, in order of first appearance. On any failure the
        error is returned as text prefixed with ``"Error getting files: "`` instead of
        being raised.
    """
    try:
        # Get the full collection name
        collection_name = get_user_collection(user_id, collection_id)

        df = db.open_table(collection_name).to_pandas()
        # Like get_document, filter on the row's user_id. The table-name prefix alone can
        # resolve to another user's table (e.g. user "a" + collection "b_x" -> "a_b_x").
        owned = df[df['user_id'] == user_id]
        # Log a count only; logging rows would dump raw chunk text and vectors.
        logging.info(f"fetched {len(owned)} chunks from {collection_name}")

        return ", ".join(owned['file_name'].unique())
    except Exception as e:
        logging.error(f"Error getting collection files: {str(e)}")
        return f"Error getting files: {str(e)}"


@router.post("/query_collection_tool")
async def query_collection_tool(input_data: QueryInput):
    """Run ``query_collection`` and return a compact, stringified list of hits.

    Each hit is reduced to its text, distance, ``document_id`` and ``chunk_index``.
    The return value is ``str()`` of that list (a Python repr, not JSON).

    Raises:
        HTTPException: Any HTTPException from ``query_collection`` is passed through with
            its original status (e.g. 404 for a missing collection); other failures become 500.
    """
    try:
        response = await query_collection(input_data)
        results = [
            {
                "text": r.text,
                "distance": r.distance,
                "metadata": {
                    "document_id": r.metadata.get("document_id"),
                    "chunk_index": r.metadata.get("location", {}).get("chunk_index")
                }
            }
            for r in response.results
        ]
        return str(results)

    except HTTPException:
        # Keep the status chosen by query_collection instead of masking it as a 500.
        raise
    except Exception as e:
        logging.error(f"Unexpected error during query: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        ) from e