import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
import torch.nn as nn
# import pytesseract
import easyocr
import numpy as np
class CLIPMultimodalClassifier(nn.Module):
    def __init__(self, clip_model):
        super(CLIPMultimodalClassifier, self).__init__()
        self.clip_model = clip_model
        # CLIP's projection dimension (512 for the ViT-B/32 and ViT-B/16 variants)
        embed_dim = clip_model.config.projection_dim
        # Classification layer: from (image_embed_dim + text_embed_dim) to 2 classes
        self.classifier = nn.Linear(embed_dim * 2, 2)

    def forward(self, input_ids, attention_mask, pixel_values):
        # Get image and text embeddings from CLIP
        outputs = self.clip_model(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  pixel_values=pixel_values)
        # CLIPModel outputs expose image_embeds and text_embeds
        img_embeds = outputs.image_embeds    # shape: (batch, embed_dim)
        text_embeds = outputs.text_embeds    # shape: (batch, embed_dim)
        # Concatenate the embeddings
        fused_embeds = torch.cat([img_embeds, text_embeds], dim=1)  # shape: (batch, 2*embed_dim)
        # Feed the fused embedding to the classifier to get logits
        logits = self.classifier(fused_embeds)
        return logits
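# Checkpoint containing the trained classifier's state dict (CLIP backbone + linear head)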
MODEL_PATH = "hateful_meme_clip_model.pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the saved model weights for inference
model_infer = CLIPMultimodalClassifier(
    CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
)
model_infer.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model_infer.to(device)
model_infer.eval()
# Load the saved processor
proc = CLIPProcessor.from_pretrained("clip_processor")
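# Initialize the EasyOCR reader used to extract English text from the uploaded meme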
reader = easyocr.Reader(['en'])
# Define inference function
def classify_meme(image):
    # image is a PIL Image input from Gradio
    # 1. Extract text from the image using OCR
    # text = pytesseract.image_to_string(image)
    ocr_results = reader.readtext(np.array(image))  # list of (bbox, text, confidence) tuples
    text = " ".join(detected for _, detected, _ in ocr_results)
    if text.strip() == "":
        text = "<no text found>"  # handle cases with no detected text
    # 2. Preprocess image and text (truncate OCR text to CLIP's 77-token limit)
    inputs = proc(text=[text], images=image, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    pixel_values = inputs["pixel_values"].to(device)
    # 3. Get model prediction
    with torch.no_grad():
        logits = model_infer(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
    probs = logits.softmax(dim=1).cpu().numpy()[0]
    # 4. Interpret the result
    confidence = probs[0]  # score compared against the 0.5 decision threshold below
    label = "Not Hateful" if confidence < 0.5 else "Hateful"
    # Return the extracted text, the label, and the confidence
    return f"Text Extracted: {text}.\nDecision: Meme is {label}.\nConfidence of Decision: {confidence:.2f}."
# Create Gradio interface
iface = gr.Interface(
    fn=classify_meme,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Prediction"),
    title="Hateful Memes Classifier",
    description=("Upload a meme image to check whether it is hateful. The model analyzes both the image and "
                 "the text in the meme. The decision threshold is 0.5: if the confidence score is below 0.5, "
                 "the meme is classified as not hateful; otherwise it is classified as hateful."),
)
# Launch interface for local testing (if running locally, this will start a web server)
iface.launch(debug=True)