"""Gradio app: OCR an uploaded image, then classify the text as Spam / Not Spam."""

import functools
import json
import os

import cv2
import easyocr
import gradio as gr
import keras_ocr
import numpy as np
import torch
import torch.nn.functional as F
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

from save_results import save_results_to_repo

# Paths
MODEL_PATH = "./distilbert_spam_model"
RESULTS_JSON = "ocr_results.json"

# File names `save_pretrained` may produce. Newer transformers releases write
# `model.safetensors` by default, so checking only `pytorch_model.bin` would
# miss a saved model and overwrite it with the base checkpoint.
_WEIGHT_FILES = (
    "model.safetensors",
    "pytorch_model.bin",
    "model.safetensors.index.json",
    "pytorch_model.bin.index.json",
)

LABEL_MAP = {0: "Not Spam", 1: "Spam"}


def _has_weights(path):
    """Return True if *path* contains model weights in any format transformers saves."""
    return any(os.path.exists(os.path.join(path, name)) for name in _WEIGHT_FILES)


# Ensure model exists
if not _has_weights(MODEL_PATH):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # The base checkpoint's 2-label classification head is randomly initialised,
    # so its predictions are arbitrary until the model is fine-tuned.
    print("⚠️ Using base DistilBERT: the classification head is untrained, labels will be arbitrary.")
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Set the model to evaluation mode to disable dropout layers
model.eval()


# --- OCR engines -------------------------------------------------------------
# Each engine loads model weights when constructed, so build it once on first
# use and reuse it for every later request instead of per call.

@functools.lru_cache(maxsize=None)
def _get_paddle_ocr():
    """Return the shared PaddleOCR engine (English, angle classifier on)."""
    return PaddleOCR(lang='en', use_angle_cls=True)


@functools.lru_cache(maxsize=None)
def _get_keras_pipeline():
    """Return the shared keras-ocr detection+recognition pipeline."""
    return keras_ocr.pipeline.Pipeline()


@functools.lru_cache(maxsize=None)
def _get_easy_reader():
    """Return the shared EasyOCR reader for English."""
    return easyocr.Reader(['en'])


def _to_rgb(img):
    """Return *img* as a 3-channel RGB array.

    Gradio delivers RGB arrays; single-channel (grayscale) and 4-channel (RGBA)
    inputs are converted so every OCR helper can assume H x W x 3 RGB.
    """
    img = np.asarray(img)
    channels = 1 if img.ndim == 2 else img.shape[2]
    if channels == 1:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    if channels == 4:
        return cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    return img


def ocr_with_paddle(img):
    """Extract text from an RGB image with PaddleOCR; return "" if nothing is found."""
    # PaddleOCR follows the OpenCV convention and expects BGR arrays.
    result = _get_paddle_ocr().ocr(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    # PaddleOCR returns [None] (not an empty list) when no text is detected.
    lines = result[0] if result else None
    if not lines:
        return ""
    # Each line is [box, (text, confidence)].
    return ' '.join(line[1][0] for line in lines)


def ocr_with_keras(img):
    """Extract text from an RGB image with keras-ocr."""
    images = [keras_ocr.tools.read(img)]
    predictions = _get_keras_pipeline().recognize(images)
    # predictions[0] is a list of (word, box) pairs for the single input image.
    return ' '.join(text for text, _ in predictions[0])


def ocr_with_easy(img):
    """Extract text from an RGB image with EasyOCR (run on a grayscale copy)."""
    # The input is RGB, so use RGB2GRAY to get correct luminance weights.
    gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    results = _get_easy_reader().readtext(gray_image, detail=0)
    return ' '.join(results)


OCR_METHODS = {
    "PaddleOCR": ocr_with_paddle,
    "EasyOCR": ocr_with_easy,
    "KerasOCR": ocr_with_keras,
}


def _classify_text(text):
    """Return "Spam" or "Not Spam" for *text* using the DistilBERT classifier."""
    # Truncate to DistilBERT's 512-token position limit.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = F.softmax(outputs.logits, dim=1)  # Convert logits to probabilities
    prediction = torch.argmax(probs, dim=1).item()
    return LABEL_MAP[prediction]


# OCR Function
def generate_ocr(method, img):
    """OCR *img* with the selected engine and classify the extracted text.

    Args:
        method: "PaddleOCR", "EasyOCR" or "KerasOCR". Any other value falls
            back to KerasOCR, as before.
        img: Image from ``gr.Image(type="numpy")``, an RGB array.

    Returns:
        ``(extracted_text, label)``. If no text is found, returns
        ``("No text detected!", "Cannot classify")``. Otherwise the result is
        also persisted via ``save_results_to_repo``.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if img is None:
        raise gr.Error("Please upload an image!")

    img = _to_rgb(img)
    ocr_fn = OCR_METHODS.get(method, ocr_with_keras)
    text_output = ocr_fn(img).strip()
    if not text_output:
        return "No text detected!", "Cannot classify"

    label = _classify_text(text_output)

    # Save results using the external save function
    save_results_to_repo(text_output, label)

    return text_output, label


def save_to_json(text, label):
    """Overwrite RESULTS_JSON with the given text/label and return a status message."""
    data = {"Extracted Text": text, "Spam Classification": label}
    with open(RESULTS_JSON, "w", encoding="utf-8") as json_file:
        json.dump(data, json_file, indent=4, ensure_ascii=False)
    return f"Results saved to {RESULTS_JSON}"


# Create the Gradio UI. Event listeners such as `.click` must be attached inside
# a Blocks context (calling them outside one raises at startup), so the UI is
# built with gr.Blocks rather than gr.Interface plus a detached button.
with gr.Blocks(title="OCR Spam Classifier") as demo:
    gr.Markdown(
        "# OCR Spam Classifier\n"
        "Upload an image, extract text, and classify it as Spam or Not Spam."
    )
    with gr.Row():
        with gr.Column():
            method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR", label="OCR Method")
            image_input = gr.Image(type="numpy")
            submit_button = gr.Button("Submit")
        with gr.Column():
            output_text = gr.Textbox(label="Extracted Text")
            output_label = gr.Textbox(label="Spam Classification")
            save_button = gr.Button("Save to JSON")
            save_output = gr.Textbox(label="Save Status")

    submit_button.click(
        fn=generate_ocr,
        inputs=[method_input, image_input],
        outputs=[output_text, output_label],
    )
    save_button.click(
        fn=save_to_json,
        inputs=[output_text, output_label],
        outputs=[save_output],
    )


if __name__ == "__main__":
    demo.launch()