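"""Gradio demo app for hateful-meme classification.

EasyOCR extracts the text overlaid on the meme, a fine-tuned CLIP model embeds both the
image and the extracted text, and a linear head over the concatenated embeddings predicts
whether the meme is hateful.
"""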
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import easyocr
from transformers import CLIPProcessor, CLIPModel


class CLIPMultimodalClassifier(nn.Module):
    """Binary classifier that fuses CLIP image and text embeddings."""

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

        # CLIP projects both modalities into a shared embedding space of this size.
        embed_dim = clip_model.config.projection_dim

        # The fused embedding concatenates both modalities, hence 2x the width; two output classes.
        self.classifier = nn.Linear(embed_dim * 2, 2)

    def forward(self, input_ids, attention_mask, pixel_values):
        outputs = self.clip_model(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  pixel_values=pixel_values)

        # Projected per-modality embeddings from CLIP.
        img_embeds = outputs.image_embeds
        text_embeds = outputs.text_embeds

        # Concatenate image and text embeddings and classify the fused representation.
        fused_embeds = torch.cat([img_embeds, text_embeds], dim=1)

        logits = self.classifier(fused_embeds)
        return logits


MODEL_PATH = "hateful_meme_clip_model.pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild the architecture around the pretrained CLIP backbone, then load the fine-tuned weights.
model_infer = CLIPMultimodalClassifier(CLIPModel.from_pretrained("openai/clip-vit-base-patch16"))
model_infer.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model_infer.to(device)
model_infer.eval()

# Processor (tokenizer + image preprocessing) loaded from a local directory, plus an English OCR reader.
proc = CLIPProcessor.from_pretrained("clip_processor")
reader = easyocr.Reader(['en'])
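# Note (assumption): the weights file and processor directory loaded above are expected to
# come from a prior training run, saved roughly like:
#   torch.save(model.state_dict(), "hateful_meme_clip_model.pth")
#   processor.save_pretrained("clip_processor")
# Adjust MODEL_PATH and the processor path if your artifacts live elsewhere.

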
def classify_meme(image):
    # Run OCR on the uploaded image and join all detected text fragments into one string.
    ocr_results = reader.readtext(np.array(image))
    text = " ".join(detected for _, detected, _ in ocr_results)

    if text.strip() == "":
        text = "<no text found>"

    # Preprocess the (text, image) pair for CLIP and move the tensors to the target device.
    inputs = proc(text=[text], images=image, return_tensors="pt", padding=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    pixel_values = inputs["pixel_values"].to(device)

    with torch.no_grad():
        logits = model_infer(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
        probs = logits.softmax(dim=1).cpu().numpy()[0]

    # probs[0] is treated as the score of the "hateful" class (this must match the label
    # mapping used during training); scores of 0.5 or above are labelled "Hateful".
    confidence = probs[0]
    label = "Not Hateful" if confidence < 0.5 else "Hateful"

    return f"Text Extracted: {text}.\nDecision: Meme is {label}.\nConfidence of Decision: {confidence:.2f}."


iface = gr.Interface(fn=classify_meme,
                     inputs=gr.Image(type="pil"),
                     outputs=gr.Textbox(label="Prediction"),
                     title="Hateful Memes Classifier",
                     description="Upload a meme image to check whether it is hateful. The model analyzes both the image and the text in the meme. The decision threshold is 0.5: if the confidence is below 0.5, the meme is classified as not hateful; otherwise it is classified as hateful.")

iface.launch(debug=True) |
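# launch() serves the app locally (default http://127.0.0.1:7860); debug=True surfaces errors
# in the console, and share=True can be passed for a temporary public Gradio link.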