File size: 2,511 Bytes
19736cf
025580f
 
19736cf
 
 
 
025580f
 
 
 
 
 
 
 
 
 
19736cf
 
 
 
 
 
 
c4cf574
19736cf
 
025580f
19736cf
c4cf574
19736cf
 
 
 
 
025580f
 
 
 
 
 
19736cf
c4cf574
19736cf
025580f
19736cf
 
025580f
 
 
 
19736cf
025580f
19736cf
025580f
 
 
 
 
19736cf
025580f
 
 
 
 
 
c4cf574
025580f
 
 
 
c4cf574
025580f
 
 
 
a4bd204
 
025580f
a4bd204
19736cf
025580f
 
 
c4cf574
19736cf
025580f
 
 
 
 
19736cf
025580f
19736cf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import keras_ocr
import cv2
import easyocr
from paddleocr import PaddleOCR
import numpy as np

# Load the tokenizer from the local ./distilbert_spam_model directory.
tokenizer = DistilBertTokenizer.from_pretrained("./distilbert_spam_model")

# Build the sequence-classification model from the same directory, then
# overwrite its parameters with the state dict saved in model.pth. The
# tensors are mapped onto the CPU so no GPU is needed to load them.
# NOTE(review): torch.load unpickles the file; if model.pth is a plain state
# dict, consider passing weights_only=True -- confirm against the checkpoint.
model = DistilBertForSequenceClassification.from_pretrained("./distilbert_spam_model")
model.load_state_dict(torch.load("./distilbert_spam_model/model.pth", map_location=torch.device('cpu')))
# Put the model in evaluation mode before it is used for inference.
model.eval()

def ocr_with_paddle(img):
    """Extract English text from an image using PaddleOCR.

    Args:
        img: Image input, forwarded unchanged to ``PaddleOCR.ocr``.

    Returns:
        str: The recognized text lines concatenated into one string, each
        line preceded by a single space (so a non-empty result starts with
        a space). An empty string is returned when no text is detected.
    """
    # A fresh engine is created on every call; angle classification is
    # enabled through ``use_angle_cls=True``.
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)

    # Only the first entry of the result is consumed (one image was passed).
    # That entry may be None or empty when nothing is detected, which would
    # otherwise crash on len()/iteration -- return an empty string instead.
    lines = result[0] if result else None
    if not lines:
        return ''

    # For each detected line, index [1][0] holds the recognized text.
    return ''.join(' ' + line[1][0] for line in lines)

def ocr_with_keras(img):
    """Extract text from an image using the keras-ocr pipeline.

    Args:
        img: Image input, handed to ``keras_ocr.tools.read``.

    Returns:
        str: Every recognized word concatenated into one string, each word
        preceded by a single space. Empty when nothing is recognized.
    """
    recognizer = keras_ocr.pipeline.Pipeline()
    batch = [keras_ocr.tools.read(img)]

    # A single image is submitted, so only the first prediction group is read.
    word_predictions = recognizer.recognize(batch)[0]

    # Each prediction is a (word, box) pair; the box is not needed here.
    return ''.join(' ' + word for word, _box in word_predictions)

def ocr_with_easy(img):
    """Extract English text from an image using EasyOCR.

    Args:
        img: Image input, forwarded unchanged to ``Reader.readtext``.

    Returns:
        str: The strings returned by EasyOCR joined with single spaces.
    """
    english_reader = easyocr.Reader(['en'])
    # detail=0 and paragraph=True are passed so that the result can be
    # joined directly as a list of strings.
    paragraphs = english_reader.readtext(img, detail=0, paragraph=True)
    return ' '.join(paragraphs)

def generate_ocr_and_classify(Method, img):
    """Run the selected OCR engine on an image and classify the text as spam.

    Args:
        Method: Name of the OCR engine to use: ``'EasyOCR'``, ``'KerasOCR'``
            or ``'PaddleOCR'``.
        img: The uploaded image, passed unchanged to the chosen OCR function.

    Returns:
        tuple: ``(text, classification)``. ``classification`` is ``"Spam"``
        when the model's highest-scoring class is 1 and ``"Not Spam"``
        otherwise. When OCR yields no text, the classifier is skipped and
        ``"No text detected"`` is returned instead.

    Raises:
        gr.Error: If no image was supplied or ``Method`` is not a known
            OCR engine.
    """
    if img is None:
        raise gr.Error("Please upload an image!")

    # Map each selectable engine name to the function that implements it.
    ocr_engines = {
        'EasyOCR': ocr_with_easy,
        'KerasOCR': ocr_with_keras,
        'PaddleOCR': ocr_with_paddle,
    }
    run_ocr = ocr_engines.get(Method)
    if run_ocr is None:
        # Previously an unknown method silently classified an empty string.
        raise gr.Error(f"Unknown OCR method: {Method!r}")

    text_output = run_ocr(img)

    # A verdict on blank input would be meaningless, so skip the classifier.
    if not text_output.strip():
        return text_output, "No text detected"

    # Classify the extracted text; gradients are not needed for inference.
    inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    prediction = torch.argmax(outputs.logits, dim=1).item()
    classification = "Spam" if prediction == 1 else "Not Spam"

    return text_output, classification

# ---------------------------------------------------------------------------
# Gradio user interface
# ---------------------------------------------------------------------------

# Input components: the image to read and the OCR engine selector.
image = gr.Image()
method = gr.Radio(
    choices=["PaddleOCR", "EasyOCR", "KerasOCR"],
    value="PaddleOCR",
)

# Output components: the extracted text and the classification label.
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Label(label="Classification")

# Wire the components to the OCR-and-classify handler.
demo = gr.Interface(
    fn=generate_ocr_and_classify,
    inputs=[method, image],
    outputs=[output_text, output_label],
    title="OCR & Spam Classification",
    description="Upload an image with text, extract the text using OCR, and classify whether it is spam or not.",
)

# Start the web app as soon as the module is executed.
demo.launch()