# SakibRumu
# Update app.py
# 2b183b5 verified
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.figure import Figure
from PIL import Image
from torchvision import transforms, models
# Class Mapping for RAF-DB Dataset (7 classes)
# Maps classifier output index -> human-readable label; the index order must
# match the order of the model's output logits.
class_mapping = dict(
    enumerate(
        (
            "Surprise",
            "Fear",
            "Disgust",
            "Happiness",
            "Sadness",
            "Anger",
            "Neutral",
        )
    )
)
# Inference-time preprocessing (same as the test-time transform):
# resize to 112x112 -- the size the downstream heads were built for, since
# they assume a 7x7 backbone feature grid -- then convert to a tensor and
# apply ImageNet mean/std normalization.
transform = transforms.Compose(
    [
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Feature Extraction Backbone
class IR50(nn.Module):
    """Truncated ResNet-50 used as the shared feature extractor.

    Runs ResNet-50's stem (conv1 / bn1 / relu / maxpool), ``layer1`` and
    ``layer2`` (512 output channels), then a stride-2 1x1 convolution plus
    BatchNorm that projects to 256 channels and halves the spatial size once
    more. The rest of this model assumes the result is a ``(B, 256, 7, 7)``
    map for the 112x112 inputs produced by the inference transform.

    Args:
        weights: torchvision weight specification forwarded to
            ``models.resnet50``. Defaults to ``'IMAGENET1K_V1'`` (the original
            behavior, which fetches ImageNet weights on construction). Pass
            ``None`` to skip that download when every parameter is going to be
            overwritten by ``load_state_dict`` anyway.
    """

    def __init__(self, weights='IMAGENET1K_V1'):
        super().__init__()
        resnet = models.resnet50(weights=weights)
        # Reuse only the early ResNet stages; layer3/layer4 and the classifier
        # head are discarded.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        # 512 -> 256 channels with stride 2 (new, randomly initialised layers).
        self.downsample = nn.Conv2d(512, 256, 1, stride=2)
        self.bn_downsample = nn.BatchNorm2d(256, eps=1e-5)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.downsample(x)
        x = self.bn_downsample(x)
        return x
# HLA Stream
class HLA(nn.Module):
    """Spatial-then-channel attention block; output shape equals input shape.

    Spatial gate: two parallel 1x1 convolutions squeeze the input to
    ``in_channels // reduction`` channels, their element-wise maximum goes
    through a sigmoid, and a third 1x1 convolution restores ``in_channels``.
    Channel gate: a squeeze-and-excitation style branch (global average pool,
    1x1 down, ReLU, 1x1 up, sigmoid) re-weights the spatially gated features.
    The result is batch-normalised.
    """

    def __init__(self, in_channels=256, reduction=4):
        super().__init__()
        squeezed = in_channels // reduction
        # Submodule names and creation order are part of the checkpoint
        # format (state_dict keys) -- keep them stable.
        self.spatial_branch1 = nn.Conv2d(in_channels, squeezed, 1)
        self.spatial_branch2 = nn.Conv2d(in_channels, squeezed, 1)
        self.sigmoid = nn.Sigmoid()
        self.channel_restore = nn.Conv2d(squeezed, in_channels, 1)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, squeezed, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(squeezed, in_channels, 1, bias=False),
            nn.Sigmoid(),
        )
        self.bn = nn.BatchNorm2d(in_channels, eps=1e-5)

    def forward(self, x):
        # The restore conv is applied AFTER the sigmoid, so the spatial gate
        # is not confined to [0, 1].
        squeezed_max = torch.maximum(self.spatial_branch1(x), self.spatial_branch2(x))
        spatial_gate = self.channel_restore(self.sigmoid(squeezed_max))
        gated = x * spatial_gate
        return self.bn(gated * self.channel_attention(gated))
# ViT Stream
class ViT(nn.Module):
    """Transformer stream over the backbone feature grid; returns a CLS embedding.

    The feature map is patch-embedded by a convolution (kernel and stride
    ``patch_size``), flattened into tokens of shape
    ``(B, num_patches, embed_dim)``, prefixed with a learned CLS token, offset
    by learned positional embeddings and run through ``num_layers``
    ``nn.TransformerEncoderLayer``s. The CLS output is LayerNorm-ed and then
    BatchNorm-ed, giving shape ``(B, embed_dim)``.

    Args:
        in_channels: Channels of the incoming feature map.
        patch_size: Side of each square patch (conv kernel size and stride).
        embed_dim: Token width.
        num_layers: Number of encoder layers.
        num_heads: Attention heads per encoder layer.
        grid_size: Spatial side of the incoming feature map; sizes the
            positional embedding. Defaults to 7, the backbone grid the rest
            of this model assumes.
        batch_first: Layout flag forwarded to every encoder layer. ``forward``
            builds tokens batch-first, ``(B, seq, embed_dim)``. With the legacy
            default ``False`` the layers read the batch axis as the sequence
            and the token axis as the batch: tokens of one image never attend
            to each other (attention mixes the same token position across the
            images of a batch), and because every CLS position holds the same
            vector the CLS output does not depend on the input image.
            ``False`` is kept so existing checkpoints reproduce their trained
            behavior; use ``True`` -- and retrain -- for real token-level
            attention. The flag adds no parameters, so state_dict keys are
            identical either way.
    """

    def __init__(self, in_channels=256, patch_size=1, embed_dim=768, num_layers=8, num_heads=12,
                 grid_size=7, batch_first=False):  # 8 layers as in the 82.93% version
        super().__init__()
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        num_patches = (grid_size // patch_size) ** 2
        # +1 position for the CLS token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.transformer = nn.ModuleList([
            nn.TransformerEncoderLayer(
                embed_dim,
                num_heads,
                dim_feedforward=1536,
                activation="gelu",
                batch_first=batch_first,
            )
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim, eps=1e-5)
        # Initialize weights
        nn.init.xavier_uniform_(self.patch_embed.weight)
        nn.init.zeros_(self.patch_embed.bias)
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        for layer in self.transformer:
            x = layer(x)
        x = x[:, 0]  # CLS token
        x = self.ln(x)
        x = self.bn(x)
        return x
# Intensity Stream
class IntensityStream(nn.Module):
    """Edge/texture stream built on Sobel-initialised depthwise convolutions.

    Returns a tuple ``(out, grad_magnitude, variance)``:
        out: ``(B, 256)`` -- globally pooled conv features (128) concatenated
            with the mean of a single-head self-attention over spatial
            positions (128).
        grad_magnitude: Sobel gradient magnitude, same shape as the input.
        variance: per-pixel variance across channels, flattened to
            ``(B, H*W)``.
    """

    def __init__(self, in_channels=256):
        super().__init__()
        kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        kernel_y = kernel_x.t()  # the vertical Sobel kernel is the transpose of the horizontal one
        self.sobel_x = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels)
        self.sobel_y = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels)
        # Every depthwise filter (shape (in_channels, 1, 3, 3)) starts as the
        # same Sobel kernel; copy_ broadcasts the 3x3 kernel over all channels.
        with torch.no_grad():
            self.sobel_x.weight.copy_(kernel_x)
            self.sobel_y.weight.copy_(kernel_y)
        self.conv = nn.Conv2d(in_channels, 128, 3, padding=1)
        self.bn = nn.BatchNorm2d(128, eps=1e-5)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.MultiheadAttention(embed_dim=128, num_heads=1)
        # Initialize weights
        nn.init.xavier_uniform_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        grad_x, grad_y = self.sobel_x(x), self.sobel_y(x)
        grad_magnitude = torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)  # eps keeps sqrt differentiable at 0

        centered = x - x.mean(dim=1, keepdim=True)
        variance = centered.pow(2).mean(dim=1).flatten(1)

        features = self.bn(F.relu(self.conv(grad_magnitude)))
        texture_out = self.pool(features).flatten(1)  # (B, 128)

        # nn.MultiheadAttention defaults to sequence-first: (H*W, B, 128).
        tokens = features.flatten(2).permute(2, 0, 1)
        tokens = tokens / (tokens.norm(dim=-1, keepdim=True) + 1e-8)
        attended, _ = self.attention(tokens, tokens, tokens)
        context_out = attended.mean(dim=0)  # average over spatial positions -> (B, 128)

        return torch.cat([texture_out, context_out], dim=1), grad_magnitude, variance
# Full Model (Single-Label Prediction)
class TripleStreamHLAViT(nn.Module):
    """Three-stream facial-expression classifier over a shared IR50 backbone.

    The backbone feature map feeds three parallel streams:
        * HLA attention -- flattened (256*7*7) and projected to 768,
        * ViT -- CLS embedding (768),
        * intensity -- 256-d descriptor projected to 768.
    The three 768-d vectors are concatenated, fused to 512 dims
    (Linear -> ReLU -> BatchNorm -> Dropout) and classified into
    ``num_classes`` logits.

    ``forward`` returns ``(logits, hla_out, vit_out, grad_magnitude,
    variance)``; the extra tensors are returned for inspection and
    visualisation.
    """

    def __init__(self, num_classes=7):
        super().__init__()
        self.backbone = IR50()
        self.hla = HLA()
        self.vit = ViT()
        self.intensity = IntensityStream()
        self.fc_hla = nn.Linear(256*7*7, 768)
        self.fc_intensity = nn.Linear(256, 768)
        self.fusion_fc = nn.Linear(768*3, 512)
        self.bn_fusion = nn.BatchNorm1d(512, eps=1e-5)
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(512, num_classes)
        # Initialize weights
        nn.init.xavier_uniform_(self.fc_hla.weight)
        nn.init.zeros_(self.fc_hla.bias)
        nn.init.xavier_uniform_(self.fc_intensity.weight)
        nn.init.zeros_(self.fc_intensity.bias)
        nn.init.xavier_uniform_(self.fusion_fc.weight)
        nn.init.zeros_(self.fusion_fc.bias)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, x):
        features = self.backbone(x)
        hla_out = self.hla(features)
        vit_out = self.vit(features)
        intensity_out, grad_magnitude, variance = self.intensity(features)
        # flatten(1) keeps the batch axis fixed: unlike view(-1, 256*7*7) it
        # cannot silently fold a wrong spatial size into the batch dimension
        # (fc_hla then reports the mismatch), and it also accepts
        # non-contiguous inputs such as channels_last tensors.
        hla_flat = self.fc_hla(hla_out.flatten(1))
        intensity_flat = self.fc_intensity(intensity_out)
        fused = torch.cat([hla_flat, vit_out, intensity_flat], dim=1)
        fused = F.relu(self.fusion_fc(fused))
        fused = self.bn_fusion(fused)
        fused = self.dropout(fused)
        logits = self.classifier(fused)
        return logits, hla_out, vit_out, grad_magnitude, variance
# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = TripleStreamHLAViT(num_classes=7).to(device)
# Resolved relative to the current working directory.
model_path = "triple_stream_model_rafdb.pth"  # Ensure this file is in the Hugging Face Space repository
# Map every stored tensor straight onto the inference device. Leaving this as
# None on CUDA machines would restore tensors onto whatever device they were
# saved from, which fails if that device does not exist here.
map_location = device
try:
    # weights_only=True restricts unpickling to tensors/plain containers
    # instead of arbitrary pickled objects.
    state_dict = torch.load(model_path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
except Exception as e:
    # Top-level startup boundary: report, then fail loudly.
    print(f"Error loading model: {e}")
    raise
model.eval()
print("Model loaded successfully")
# Inference and Visualization Function
def _render_visualization(img, heatmap, title):
    """Plot ``img`` beside ``heatmap`` and save the figure to a unique PNG file.

    Uses an explicit matplotlib ``Figure`` instead of ``pyplot`` so that
    concurrently running requests do not share pyplot's global figure state,
    and writes each call to its own temporary file so requests cannot
    overwrite each other's output.

    Args:
        img: H x W x 3 float array in [0, 1], shown on the left.
        heatmap: 2-D array, shown on the right with the ``jet`` colormap.
        title: Title of the left panel.

    Returns:
        Path of the written PNG file.
    """
    fig = Figure(figsize=(8, 4))
    axs = fig.subplots(1, 2)
    axs[0].imshow(img)
    axs[0].set_title(title)
    axs[0].axis("off")
    axs[1].imshow(heatmap, cmap="jet")
    axs[1].set_title("HLA Heatmap")
    axs[1].axis("off")
    fig.tight_layout()
    with tempfile.NamedTemporaryFile(prefix="visualization_", suffix=".png", delete=False) as tmp:
        fig.savefig(tmp, format="png")
    return tmp.name


def predict_emotion(image):
    """Classify the facial emotion in ``image`` and render an HLA heatmap.

    Args:
        image: PIL image (as supplied by the Gradio input) or a numpy array
            accepted by ``Image.fromarray``. Any PIL mode is accepted; it is
            converted to RGB before preprocessing.

    Returns:
        Tuple of (predicted label name, ``{label: probability}`` for every
        class, path to a PNG showing the input next to the HLA heatmap).

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading.
        raise gr.Error("Please upload an image.")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Grayscale / RGBA / palette uploads would otherwise give 1 or 4 channels
    # and break the 3-channel Normalize and backbone.
    image = image.convert("RGB")

    image_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs, hla_out, _, _, _ = model(image_tensor)
        probs = F.softmax(outputs, dim=1)
    pred_label = torch.argmax(probs, dim=1).item()
    pred_label_name = class_mapping[pred_label]
    probabilities = probs.cpu().numpy()[0]
    prob_dict = {class_mapping[i]: float(prob) for i, prob in enumerate(probabilities)}

    # Channel-averaged HLA activations of the single image in the batch.
    heatmap = hla_out[0].mean(dim=0).cpu().numpy()
    # Undo the ImageNet normalization so the input can be displayed.
    img = image_tensor[0].permute(1, 2, 0).cpu().numpy()
    img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img = np.clip(img, 0, 1)

    vis_path = _render_visualization(img, heatmap, f"Input Image\nPredicted: {pred_label_name}")
    return pred_label_name, prob_dict, vis_path
# Gradio Interface
# Offer only the example images that are actually present, so a missing demo
# asset is reported instead of breaking app startup.
_example_paths = ["examples/surprise.jpg", "examples/sadness.jpg"]
_missing_examples = [path for path in _example_paths if not os.path.isfile(path)]
if _missing_examples:
    print(f"Warning: example images not found, skipping: {_missing_examples}")
_examples = [[path] for path in _example_paths if path not in _missing_examples] or None

iface = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[
        gr.Textbox(label="Predicted Emotion"),
        gr.Label(label="Probabilities"),
        gr.Image(label="Input Image and HLA Heatmap"),
    ],
    title="Facial Emotion Recognition with TripleStreamHLAViT",
    # The class list is derived from class_mapping so the text cannot drift
    # from the labels the model actually predicts.
    description=(
        "Upload an image to predict the facial emotion "
        f"({', '.join(class_mapping.values())}). "
        "This model achieves 82.93% test accuracy on the RAF-DB dataset. "
        "The HLA heatmap shows where the model focuses."
    ),
    examples=_examples,
)

# Launch the interface
if __name__ == "__main__":
    iface.launch(share=False)