# ocr-llm-test / app.py
# Last commit 025580f by winamnd: "Update app.py" (2.51 kB).
import functools

import gradio as gr
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import keras_ocr
import cv2
import easyocr
from paddleocr import PaddleOCR
import numpy as np
# Spam classifier: the tokenizer and the model architecture/config are read
# from the local checkpoint directory, then the fine-tuned weights stored in
# model.pth (mapped onto the CPU) are loaded over the freshly built model.
tokenizer = DistilBertTokenizer.from_pretrained("./distilbert_spam_model")
model = DistilBertForSequenceClassification.from_pretrained(
    "./distilbert_spam_model"
)
# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be placed here; if model.pth is a plain state dict, consider
# passing weights_only=True.
model.load_state_dict(
    torch.load(
        "./distilbert_spam_model/model.pth",
        map_location=torch.device('cpu'),
    )
)
# Switch to evaluation mode (e.g. disables dropout) for inference.
model.eval()
"""
Paddle OCR
"""
def ocr_with_paddle(img):
finaltext = ''
ocr = PaddleOCR(lang='en', use_angle_cls=True)
result = ocr.ocr(img)
for i in range(len(result[0])):
text = result[0][i][1][0]
finaltext += ' ' + text
return finaltext
"""
Keras OCR
"""
def ocr_with_keras(img):
output_text = ''
pipeline = keras_ocr.pipeline.Pipeline()
images = [keras_ocr.tools.read(img)]
predictions = pipeline.recognize(images)
for text, _ in predictions[0]:
output_text += ' ' + text
return output_text
"""
Easy OCR
"""
def ocr_with_easy(img):
reader = easyocr.Reader(['en'])
bounds = reader.readtext(img, paragraph=True, detail=0)
return ' '.join(bounds)
"""
Generate OCR and classify spam
"""
def generate_ocr_and_classify(Method, img):
if img is None:
raise gr.Error("Please upload an image!")
# Perform OCR
text_output = ''
if Method == 'EasyOCR':
text_output = ocr_with_easy(img)
elif Method == 'KerasOCR':
text_output = ocr_with_keras(img)
elif Method == 'PaddleOCR':
text_output = ocr_with_paddle(img)
# Classify extracted text
inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
outputs = model(**inputs)
prediction = torch.argmax(outputs.logits, dim=1).item()
classification = "Spam" if prediction == 1 else "Not Spam"
return text_output, classification
"""
Create user interface
"""
image = gr.Image()
method = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR")
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Label(label="Classification")
demo = gr.Interface(
generate_ocr_and_classify,
[method, image],
[output_text, output_label],
title="OCR & Spam Classification",
description="Upload an image with text, extract the text using OCR, and classify whether it is spam or not.",
)
demo.launch()