import os
import io
from io import BytesIO
from pathlib import Path

from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
from PyPDF2 import PdfReader
from docx import Document
from PIL import Image

# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
PORT = int(os.getenv("PORT", 7860))

app = FastAPI(
    title       = "AI-Powered Web-App API",
    description = "Backend for summarisation, captioning & QA",
    version     = "1.2.3",  # <-- bumped
)

app.add_middleware(
    CORSMiddleware,
    allow_origins     = ["*"],
    allow_credentials = True,
    allow_methods     = ["*"],
    allow_headers     = ["*"],
)

# -----------------------------------------------------------------------------
# OPTIONAL STATIC FILES
# -----------------------------------------------------------------------------
static_dir = Path("static")
if static_dir.exists():
    app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")

# -----------------------------------------------------------------------------
# HUGGING FACE INFERENCE CLIENTS
# -----------------------------------------------------------------------------
summary_client = InferenceClient(
    "facebook/bart-large-cnn",
    token   = HUGGINGFACE_TOKEN,
    timeout = 120,
)

# ➜ Upgraded QA model (higher accuracy than roberta-base)
qa_client = InferenceClient(
    "deepset/roberta-large-squad2",
    token   = HUGGINGFACE_TOKEN,
    timeout = 120,
)
# If you need multilingual support, swap for:
# qa_client = InferenceClient("deepset/xlm-roberta-large-squad2",
#                             token=HUGGINGFACE_TOKEN, timeout=120)

image_caption_client = InferenceClient(
    "nlpconnect/vit-gpt2-image-captioning",
    token   = HUGGINGFACE_TOKEN,
    timeout = 60,
)

# -----------------------------------------------------------------------------
# UTILITIES
# -----------------------------------------------------------------------------
def extract_text_from_pdf(content: bytes) -> str:
    """Extract plain text from every page of a PDF (empty string for image-only pages)."""
    reader = PdfReader(io.BytesIO(content))
    return "\n".join(page.extract_text() or "" for page in reader.pages).strip()


def extract_text_from_docx(content: bytes) -> str:
    """Extract plain text from a .docx file, one line per paragraph."""
    doc = Document(io.BytesIO(content))
    return "\n".join(p.text for p in doc.paragraphs).strip()


def process_uploaded_file(file: UploadFile) -> str:
    """Dispatch on file extension and return the extracted text."""
    content = file.file.read()
    ext = file.filename.split(".")[-1].lower()
    if ext == "pdf":
        return extract_text_from_pdf(content)
    if ext == "docx":
        return extract_text_from_docx(content)
    if ext == "txt":
        return content.decode("utf-8").strip()
    raise ValueError("Unsupported file type")

# -----------------------------------------------------------------------------
# ROUTES
# -----------------------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
async def serve_index():
    return FileResponse("index.html")


# -------------------- Summarisation ------------------------------------------
@app.post("/api/summarize")
async def summarize_document(file: UploadFile = File(...)):
    try:
        text = process_uploaded_file(file)
        if len(text) < 20:
            return {"result": "Document too short to summarise."}
        summary_raw = summary_client.summarization(text[:3000])
        # The client may return a list of dicts, a single dict, or another object.
        summary_txt = (
            summary_raw[0].get("summary_text") if isinstance(summary_raw, list)
            else summary_raw.get("summary_text") if isinstance(summary_raw, dict)
            else str(summary_raw)
        )
        return {"result": summary_txt}
    except Exception as exc:
        return JSONResponse(status_code=500,
                            content={"error": f"Summarisation failure: {exc}"})


# -------------------- Image Caption ------------------------------------------
@app.post("/api/caption")
async def caption_image(image: UploadFile = File(...)):
    """`image` field name matches frontend (was `file` before)."""
    try:
        img_bytes = await image.read()
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        img.thumbnail((1024, 1024))  # downscale large images before inference
        buf = BytesIO()
        img.save(buf, format="JPEG")
        result = image_caption_client.image_to_text(buf.getvalue())
        if isinstance(result, dict):
            caption = (result.get("generated_text")
                       or result.get("caption")
                       or "No caption found.")
        elif isinstance(result, list):
            caption = result[0].get("generated_text", "No caption found.")
        else:
            caption = str(result)
        return {"result": caption}
    except Exception as exc:
        return JSONResponse(status_code=500,
                            content={"error": f"Caption failure: {exc}"})


# -------------------- Question Answering -------------------------------------
@app.post("/api/qa")
async def question_answering(file: UploadFile = File(...),
                             question: str = Form(...)):
    try:
        # content_type may be None for some clients, so guard before startswith().
        if (file.content_type or "").startswith("image/"):
            # For images, caption the picture first and use the caption as context.
            img_bytes = await file.read()
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
            img.thumbnail((1024, 1024))
            buf = BytesIO()
            img.save(buf, format="JPEG")
            res = image_caption_client.image_to_text(buf.getvalue())
            context = (res.get("generated_text") if isinstance(res, dict)
                       else str(res))
        else:
            context = process_uploaded_file(file)[:3000]
        if not context:
            return {"result": "No context – cannot answer."}
        answer = qa_client.question_answering(question=question, context=context)
        return {"result": answer.get("answer", "No answer found.")}
    except Exception as exc:
        return JSONResponse(status_code=500,
                            content={"error": f"QA failure: {exc}"})


# -------------------- Health --------------------------------------------------
@app.get("/api/health")
async def health():
    return {"status": "healthy",
            "hf_token_set": bool(HUGGINGFACE_TOKEN),
            "version": app.version}

# -----------------------------------------------------------------------------
# ENTRYPOINT
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=PORT)
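
# -----------------------------------------------------------------------------
# EXAMPLE REQUESTS (illustrative only)
# -----------------------------------------------------------------------------
# A minimal sketch of how the endpoints above could be exercised with curl,
# assuming the server is running locally on the default PORT (7860); the file
# names below are placeholders, not files shipped with this repo.
#
#   curl -F "file=@report.pdf"  http://localhost:7860/api/summarize
#   curl -F "image=@photo.jpg"  http://localhost:7860/api/caption
#   curl -F "file=@report.pdf" -F "question=What is the main finding?" \
#        http://localhost:7860/api/qa
#   curl http://localhost:7860/api/health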