File size: 3,635 Bytes
a6db851
 
 
 
20b1d79
 
 
a6db851
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20b1d79
a6db851
 
 
 
 
20b1d79
 
 
 
 
a6db851
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20b1d79
a6db851
 
 
 
 
b3c3163
a6db851
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
import torch.nn as nn
# import pytesseract
import easyocr
import numpy as np

class CLIPMultimodalClassifier(nn.Module):
    """Two-class classifier built on a CLIP backbone with early image/text fusion.

    The backbone's projected image embedding and projected text embedding are
    concatenated (image first) and fed to a single linear layer producing two
    logits.

    The attribute names ``clip_model`` and ``classifier`` determine the keys of
    ``state_dict()``; keep them unchanged so saved checkpoints still load.
    """

    def __init__(self, clip_model):
        """Wrap ``clip_model`` and create the linear classification head.

        Args:
            clip_model: CLIP model (e.g. ``transformers.CLIPModel``) whose
                ``config.projection_dim`` is the size of each projected
                embedding it returns.
        """
        super().__init__()
        self.clip_model = clip_model
        proj_dim = clip_model.config.projection_dim
        # Image and text embeddings are concatenated, so the head sees twice the projection size.
        self.classifier = nn.Linear(in_features=2 * proj_dim, out_features=2)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return logits of shape (batch, 2) for a batch of image/text pairs.

        Args:
            input_ids: tokenized text ids passed straight to ``clip_model``.
            attention_mask: text attention mask passed straight to ``clip_model``.
            pixel_values: preprocessed images passed straight to ``clip_model``.
        """
        clip_out = self.clip_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
        )
        fused = torch.cat((clip_out.image_embeds, clip_out.text_embeds), dim=1)
        return self.classifier(fused)

MODEL_PATH = "hateful_meme_clip_model.pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the classifier on the pretrained CLIP ViT-B/16 backbone, then overwrite its
# parameters with the saved fine-tuned weights (load_state_dict is strict by default).
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code during loading.
model_infer = CLIPMultimodalClassifier(CLIPModel.from_pretrained("openai/clip-vit-base-patch16"))
model_infer.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model_infer.to(device)
model_infer.eval()  # evaluation mode for inference

# Load the processor saved in the local "clip_processor" directory
# (presumably the one used during training — confirm it matches the backbone above).
proc = CLIPProcessor.from_pretrained("clip_processor")
# EasyOCR defaults to gpu=True; keep OCR on the same device type as the classifier.
reader = easyocr.Reader(['en'], gpu=device.type == 'cuda')

# Define inference function
def classify_meme(image):
    """Classify a meme as hateful or not from its pixels and its OCR-extracted text.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A human-readable string with the extracted text, the decision, and the
        hateful-class probability that is compared against the 0.5 threshold.
    """
    if image is None:
        return "No image provided. Please upload a meme image."
    # Palette ("P"), grayscale ("L") and RGBA uploads are common for memes; a "P"
    # image becomes a 2-D index array under np.array, which OCR would misread.
    # Normalising to RGB gives both OCR and the CLIP processor 3-channel input.
    image = image.convert("RGB")

    # 1. Extract text with OCR; readtext yields 3-tuples whose middle item is the text.
    detections = reader.readtext(np.array(image))
    text = " ".join(fragment for _, fragment, _ in detections)
    if not text.strip():
        text = "<no text found>"  # give the text encoder something to embed

    # 2. Preprocess image and text. CLIP's text encoder accepts at most 77 tokens,
    # so long OCR output must be truncated or the forward pass fails.
    inputs = proc(text=[text], images=image, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    pixel_values = inputs["pixel_values"].to(device)

    # 3. Get model prediction.
    with torch.no_grad():
        logits = model_infer(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
        probs = logits.softmax(dim=1).cpu().numpy()[0]

    # 4. Interpret the result. Index 0 is 'Not Hateful', so index 1 is the hateful
    # probability; a value below 0.5 therefore means the meme is not hateful.
    confidence = probs[1]
    label = "Not Hateful" if confidence < 0.5 else "Hateful"
    return f"Text Extracted: {text}.\nDecision: Meme is {label}.\nConfidence of Decision: {confidence:.2f}."

# Create the Gradio interface: one PIL image in, one text box out, served by classify_meme.
iface = gr.Interface(fn=classify_meme,
                     inputs=gr.Image(type="pil"),
                     outputs=gr.Textbox(label="Prediction"),
                     title="Hateful Memes Classifier",
                     description="Upload a meme image to check if it's hateful or not. The model will analyze both the image and text in the meme. The model's decision threshold is 0.5, i.e., if the confidence is less than 0.5, the meme is not hateful; otherwise it is hateful.")

# Launch the interface at import time (starts a local web server).
# debug=True keeps the call blocking and surfaces errors in the console.
iface.launch(debug=True)