import gradio as gr
import torch
import json
import os
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch.nn.functional as F
from PIL import Image
import pytesseract
import io

# Import save function
from save_results import save_results_to_repo  

# Paths
MODEL_PATH = "./distilbert_spam_model"

# Weight files that mark a usable checkpoint inside MODEL_PATH.
# `save_pretrained` writes "model.safetensors" by default in current
# transformers releases and "pytorch_model.bin" in older ones; checking for
# the .bin file alone would treat a safetensors checkpoint as missing and
# overwrite it with the base model below.
_WEIGHT_FILENAMES = ("model.safetensors", "pytorch_model.bin")

# Ensure LLM Model exists
if not any(
    os.path.exists(os.path.join(MODEL_PATH, weight_file))
    for weight_file in _WEIGHT_FILENAMES
):
    # NOTE(review): "distilbert-base-uncased" is a base checkpoint, so the
    # 2-label head created here is presumably untrained -- confirm that a
    # fine-tuned spam model is what actually gets deployed to MODEL_PATH.
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    # A checkpoint is already on disk: load model and tokenizer from it.
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Ensure model is in evaluation mode
model.eval()

# Function to process image for OCR
def preprocess_image(image):
    """Return a NumPy array built from ``image``.

    The conversion is a plain ``np.array`` call: the data is copied and no
    channel reordering or colour conversion is applied.
    """
    as_array = np.array(image)
    return as_array

# OCR Functions (same as ocr-api)
def ocr_with_paddle(img):
    """Run PaddleOCR on ``img`` and return the recognised lines with scores.

    NOTE: a later ``def ocr_with_paddle`` in this file rebinds the name, so
    module-level callers get that string-returning version instead of this
    one.

    Args:
        img: Image handed straight to ``PaddleOCR.ocr``.

    Returns:
        ``(extracted_text, confidences)``: two parallel lists holding the
        text and the score of each recognised line. Both are empty when
        nothing was recognised.
    """
    # NOTE(review): a new PaddleOCR engine is constructed on every call.
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)

    # PaddleOCR 2.x can report an image without text as ``[None]``;
    # iterating that entry would raise TypeError.
    if not result or not result[0]:
        return [], []

    # Each entry is (box, (text, confidence)); only the second part is kept.
    recognised = [line[1] for line in result[0]]
    extracted_text = [text for text, _ in recognised]
    confidences = [confidence for _, confidence in recognised]
    return extracted_text, confidences

def ocr_with_keras(img):
    """Run the keras-ocr pipeline on ``img`` and split its predictions.

    NOTE: a later ``def ocr_with_keras`` in this file rebinds the name, so
    module-level callers get that string-returning version instead of this
    one.

    Args:
        img: Image passed to ``keras_ocr.tools.read``.

    Returns:
        ``(extracted_text, confidences)``: the first and second element of
        every prediction tuple for the image, as two parallel lists.
        NOTE(review): keras-ocr documents its predictions as (word, box)
        pairs, so the second list may hold boxes rather than confidence
        scores -- confirm before relying on it.
    """
    # NOTE(review): a new pipeline is constructed on every call.
    recognizer = keras_ocr.pipeline.Pipeline()
    batch = [keras_ocr.tools.read(img)]
    # Only one image is submitted, so only the first prediction group matters.
    first_image = recognizer.recognize(batch)[0]

    extracted_text, confidences = [], []
    for text, confidence in first_image:
        extracted_text.append(text)
        confidences.append(confidence)
    return extracted_text, confidences

def ocr_with_easy(img):
    """Run EasyOCR on a grayscale copy of ``img``.

    NOTE: a later ``def ocr_with_easy`` in this file rebinds the name, so
    module-level callers get that string-returning version instead of this
    one.

    Args:
        img: Colour image array; it is converted with ``cv2.COLOR_BGR2GRAY``,
            i.e. it is treated as being in BGR channel order.

    Returns:
        ``(extracted_text, confidences)``: two parallel lists with the text
        and the score of every detection returned by ``readtext``.
    """
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # NOTE(review): a new English reader is constructed on every call.
    reader = easyocr.Reader(['en'])
    detections = reader.readtext(grayscale)

    extracted_text, confidences = [], []
    # Each detection is (box, text, confidence); the box is not used.
    for _, text, confidence in detections:
        extracted_text.append(text)
        confidences.append(confidence)
    return extracted_text, confidences

def ocr_with_tesseract(img):
    """Run Tesseract on a grayscale copy of ``img`` and return its text lines.

    Args:
        img: Colour image array; it is converted with ``cv2.COLOR_BGR2GRAY``,
            i.e. it is treated as being in BGR channel order.

    Returns:
        ``(extracted_text, confidences)``: the non-blank output lines with
        surrounding whitespace removed, and a list of ``1.0`` placeholders of
        the same length.
    """
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    raw_output = pytesseract.image_to_string(grayscale)

    extracted_text = []
    for raw_line in raw_output.split("\n"):
        cleaned = raw_line.strip()
        if cleaned:  # drop blank / whitespace-only lines
            extracted_text.append(cleaned)

    # image_to_string yields text only, so every line gets a placeholder
    # score of 1.0 to keep the (texts, confidences) return shape.
    confidences = [1.0 for _ in extracted_text]
    return extracted_text, confidences

# OCR helpers that return the recognised text as one space-joined string.
# These rebind the same-named functions defined above.
def ocr_with_paddle(img):
    """Run PaddleOCR on ``img`` and return the recognised text as one string.

    This definition rebinds the earlier ``ocr_with_paddle`` in this file
    (which returned texts and confidences separately).

    Args:
        img: Image handed straight to ``PaddleOCR.ocr``.

    Returns:
        The recognised lines joined with single spaces, or ``''`` when
        nothing was recognised.
    """
    # NOTE(review): a new PaddleOCR engine is constructed on every call.
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)

    # PaddleOCR 2.x can report an image without text as ``[None]``;
    # iterating that entry would raise TypeError instead of yielding "".
    if not result or not result[0]:
        return ''

    # Each entry is (box, (text, confidence)); keep only the text.
    return ' '.join(item[1][0] for item in result[0])

def ocr_with_keras(img):
    """Run the keras-ocr pipeline on ``img`` and return the words as a string.

    This definition rebinds the earlier ``ocr_with_keras`` in this file.

    Args:
        img: Image passed to ``keras_ocr.tools.read``.

    Returns:
        The first element of every prediction for the image, joined with
        single spaces.
    """
    # NOTE(review): a new pipeline is constructed on every call.
    recognizer = keras_ocr.pipeline.Pipeline()
    batch = [keras_ocr.tools.read(img)]
    # Only one image is submitted, so only the first prediction group matters.
    first_image = recognizer.recognize(batch)[0]

    words = []
    for word, _ in first_image:
        words.append(word)
    return ' '.join(words)

def ocr_with_easy(img):
    """Run EasyOCR on a grayscale copy of ``img`` and return the text.

    This definition rebinds the earlier ``ocr_with_easy`` in this file.

    Args:
        img: Image array. Colour input is treated as RGB, which is the
            channel order of the arrays ``generate_ocr`` builds from the
            uploaded image; a 2-D (already single-channel) array is used
            as is.

    Returns:
        The recognised text fragments joined with single spaces.
    """
    if np.ndim(img) == 2:
        # Already grayscale: cvtColor would reject a single-channel input.
        gray_image = img
    else:
        # RGB, not BGR: the BGR constant would swap the red and blue
        # luminance weights for images that did not come from cv2.imread.
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # NOTE(review): a new English reader is constructed on every call.
    reader = easyocr.Reader(['en'])
    # detail=0 makes readtext return plain strings without boxes or scores.
    results = reader.readtext(gray_image, detail=0)
    return ' '.join(results)

# OCR & Classification Function
def generate_ocr(method, img):
    """Extract text from an image with the chosen OCR engine and classify it.

    Args:
        method: Engine label selected in the UI. "PaddleOCR", "EasyOCR" and
            "TesseractOCR" select those engines; any other value (including
            "KerasOCR") uses keras-ocr.
        img: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        ``(text, label)`` where ``label`` is "Spam" or "Not Spam", or
        ``("No text detected!", "Cannot classify")`` when the OCR engine
        produced no text.

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    if img is None:
        raise gr.Error("Please upload an image!")

    # Convert PIL Image to OpenCV format
    img = np.array(img)

    # Select OCR method
    if method == "PaddleOCR":
        text_output = ocr_with_paddle(img)
    elif method == "EasyOCR":
        text_output = ocr_with_easy(img)
    elif method == "TesseractOCR":
        # ocr_with_tesseract returns (lines, confidences); only the text is
        # needed here. Without this branch the "TesseractOCR" choice would
        # fall through to keras-ocr.
        tesseract_lines, _ = ocr_with_tesseract(img)
        text_output = ' '.join(tesseract_lines)
    else:  # KerasOCR
        text_output = ocr_with_keras(img)

    # Preprocess text properly
    text_output = text_output.strip()
    if not text_output:
        return "No text detected!", "Cannot classify"

    # Tokenize text (inputs longer than 512 tokens are truncated)
    inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Perform inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=1)  # Convert logits to probabilities
        # assumes class index 1 is the spam label -- TODO confirm against
        # the label mapping used when the model was trained
        spam_prob = probs[0][1].item()

    # Explicit 0.5 decision threshold on the spam probability; kept as a
    # single place to tune the cut-off.
    label = "Spam" if spam_prob > 0.5 else "Not Spam"

    # Save results using external function
    save_results_to_repo(text_output, label)

    return text_output, label

# Gradio Interface
# Input widgets: the uploaded image and the OCR engine selector. The radio
# labels are the strings generate_ocr compares its `method` argument against.
image_input = gr.Image()
method_input = gr.Radio(
    choices=["PaddleOCR", "EasyOCR", "KerasOCR", "TesseractOCR"],
    value="PaddleOCR",
)

# Output widgets, in the order generate_ocr returns its two values.
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Textbox(label="Spam Classification")

# The inputs list mirrors generate_ocr's (method, img) parameter order.
demo = gr.Interface(
    fn=generate_ocr,
    inputs=[method_input, image_input],
    outputs=[output_text, output_label],
    title="OCR Spam Classifier",
    description="Upload an image, extract text using OCR, and classify it as Spam or Not Spam.",
    theme="compact",
)

# Launch App (only when executed as a script, not when imported)
if __name__ == "__main__":
    demo.launch()