import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os
import tempfile
# Class Mapping for RAF-DB Dataset (7 classes)
class_mapping = {
    0: "Surprise",
    1: "Fear",
    2: "Disgust",
    3: "Happiness",
    4: "Sadness",
    5: "Anger",
    6: "Neutral"
}
# Transformations for inference (same as test transform)
transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
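# The normalization constants above are the standard ImageNet statistics, matching the
# ImageNet-pretrained ResNet-50 weights loaded by the backbone; predict_emotion reuses
# the same values to denormalize the tensor when rendering the input image.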
# Feature Extraction Backbone
class IR50(nn.Module):
    def __init__(self):
        super(IR50, self).__init__()
        resnet = models.resnet50(weights='IMAGENET1K_V1')
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.downsample = nn.Conv2d(512, 256, 1, stride=2)
        self.bn_downsample = nn.BatchNorm2d(256, eps=1e-5)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.downsample(x)
        x = self.bn_downsample(x)
        return x
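# Shape trace for the backbone: input (B, 3, 112, 112) -> conv1 + maxpool -> (B, 64, 28, 28)
# -> layer1 -> (B, 256, 28, 28) -> layer2 -> (B, 512, 14, 14) -> 1x1 downsample (stride 2)
# -> (B, 256, 7, 7), the feature map shared by all three streams below. The 7x7 grid is
# what the ViT stream's num_patches and the fc_hla input size (256*7*7) assume.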
# HLA Stream
class HLA(nn.Module):
    def __init__(self, in_channels=256, reduction=4):
        super(HLA, self).__init__()
        reduced_channels = in_channels // reduction
        self.spatial_branch1 = nn.Conv2d(in_channels, reduced_channels, 1)
        self.spatial_branch2 = nn.Conv2d(in_channels, reduced_channels, 1)
        self.sigmoid = nn.Sigmoid()
        self.channel_restore = nn.Conv2d(reduced_channels, in_channels, 1)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False),
            nn.Sigmoid()
        )
        self.bn = nn.BatchNorm2d(in_channels, eps=1e-5)

    def forward(self, x):
        b1 = self.spatial_branch1(x)
        b2 = self.spatial_branch2(x)
        spatial_attn = self.sigmoid(torch.max(b1, b2))
        spatial_attn = self.channel_restore(spatial_attn)
        spatial_out = x * spatial_attn
        channel_attn = self.channel_attention(spatial_out)
        out = spatial_out * channel_attn
        out = self.bn(out)
        return out
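# How the HLA block combines its two attentions: the spatial map is the element-wise max
# of two 1x1-conv branches, squashed by a sigmoid and projected back to 256 channels
# before gating x; an SE-style channel branch then reweights the spatially gated features.
# The output keeps the (B, 256, 7, 7) shape, which lets predict_emotion average it over
# channels to render the HLA heatmap.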
# ViT Stream
class ViT(nn.Module):
    def __init__(self, in_channels=256, patch_size=1, embed_dim=768, num_layers=8, num_heads=12):  # 8 layers as in the 82.93% version
        super(ViT, self).__init__()
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        num_patches = (7 // patch_size) * (7 // patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.transformer = nn.ModuleList([
            nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=1536, activation="gelu")
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim, eps=1e-5)
        # Initialize weights
        nn.init.xavier_uniform_(self.patch_embed.weight)
        nn.init.zeros_(self.patch_embed.bias)
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        for layer in self.transformer:
            x = layer(x)
        x = x[:, 0]
        x = self.ln(x)
        x = self.bn(x)
        return x
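# Token layout in the ViT stream: (B, 256, 7, 7) -> 1x1 patch embedding -> 49 tokens of
# width 768, plus a learned [CLS] token and positional embeddings; the [CLS] output is
# what gets fused downstream. Note that nn.TransformerEncoderLayer defaults to
# batch_first=False while the tensors here are (batch, tokens, dim); the layout is left
# untouched because the released checkpoint was presumably trained with exactly this code.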
# Intensity Stream
class IntensityStream(nn.Module):
    def __init__(self, in_channels=256):
        super(IntensityStream, self).__init__()
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        self.sobel_x = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels)
        self.sobel_y = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels)
        self.sobel_x.weight.data = sobel_x.repeat(in_channels, 1, 1, 1)
        self.sobel_y.weight.data = sobel_y.repeat(in_channels, 1, 1, 1)
        self.conv = nn.Conv2d(in_channels, 128, 3, padding=1)
        self.bn = nn.BatchNorm2d(128, eps=1e-5)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.MultiheadAttention(embed_dim=128, num_heads=1)
        # Initialize weights
        nn.init.xavier_uniform_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        gx = self.sobel_x(x)
        gy = self.sobel_y(x)
        grad_magnitude = torch.sqrt(gx**2 + gy**2 + 1e-8)
        variance = ((x - x.mean(dim=1, keepdim=True))**2).mean(dim=1).flatten(1)
        cnn_out = F.relu(self.conv(grad_magnitude))
        cnn_out = self.bn(cnn_out)
        texture_out = self.pool(cnn_out).squeeze(-1).squeeze(-1)
        attn_in = cnn_out.flatten(2).permute(2, 0, 1)
        attn_in = attn_in / (attn_in.norm(dim=-1, keepdim=True) + 1e-8)
        attn_out, _ = self.attention(attn_in, attn_in, attn_in)
        context_out = attn_out.mean(dim=0)
        out = torch.cat([texture_out, context_out], dim=1)
        return out, grad_magnitude, variance
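# The intensity stream applies fixed depthwise Sobel kernels (groups=in_channels) to
# estimate per-channel gradient magnitude, then summarizes it two ways: a pooled 128-d
# texture vector and a 128-d context vector from single-head self-attention over the 49
# spatial positions (MultiheadAttention with batch_first=False expects the (L, B, E)
# layout the permute produces). Concatenated, they give the 256-d descriptor that
# fc_intensity consumes; grad_magnitude and variance are returned only for inspection.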
# Full Model (Single-Label Prediction)
class TripleStreamHLAViT(nn.Module):
    def __init__(self, num_classes=7):
        super(TripleStreamHLAViT, self).__init__()
        self.backbone = IR50()
        self.hla = HLA()
        self.vit = ViT()
        self.intensity = IntensityStream()
        self.fc_hla = nn.Linear(256*7*7, 768)
        self.fc_intensity = nn.Linear(256, 768)
        self.fusion_fc = nn.Linear(768*3, 512)
        self.bn_fusion = nn.BatchNorm1d(512, eps=1e-5)
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(512, num_classes)
        # Initialize weights
        nn.init.xavier_uniform_(self.fc_hla.weight)
        nn.init.zeros_(self.fc_hla.bias)
        nn.init.xavier_uniform_(self.fc_intensity.weight)
        nn.init.zeros_(self.fc_intensity.bias)
        nn.init.xavier_uniform_(self.fusion_fc.weight)
        nn.init.zeros_(self.fusion_fc.bias)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, x):
        features = self.backbone(x)
        hla_out = self.hla(features)
        vit_out = self.vit(features)
        intensity_out, grad_magnitude, variance = self.intensity(features)
        hla_flat = self.fc_hla(hla_out.view(-1, 256*7*7))
        intensity_flat = self.fc_intensity(intensity_out)
        fused = torch.cat([hla_flat, vit_out, intensity_flat], dim=1)
        fused = F.relu(self.fusion_fc(fused))
        fused = self.bn_fusion(fused)
        fused = self.dropout(fused)
        logits = self.classifier(fused)
        return logits, hla_out, vit_out, grad_magnitude, variance
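# Fusion recap: each stream is projected to 768 dimensions (HLA via fc_hla on the
# flattened 256*7*7 map, intensity via fc_intensity, ViT already 768 from its [CLS]
# token), concatenated to 2304, and passed through fusion_fc -> BatchNorm -> dropout
# to a 512-d vector before the 7-way classifier. Inference below uses only the logits
# and the HLA map; the other returned tensors exist for visualization and debugging.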
# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = TripleStreamHLAViT(num_classes=7).to(device)
model_path = "triple_stream_model_rafdb.pth"  # Ensure this file is in the Hugging Face Space repository
try:
    # Map the weights onto whichever device is available
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()
    print("Model loaded successfully")
except Exception as e:
    print(f"Error loading model: {e}")
    raise
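# Optional smoke test (a sketch; safe to remove): a dummy forward pass at startup
# surfaces shape or device mismatches immediately rather than on the first user request.
with torch.no_grad():
    _smoke_logits, *_ = model(torch.randn(1, 3, 112, 112, device=device))
    assert _smoke_logits.shape == (1, 7), f"unexpected logits shape {_smoke_logits.shape}"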
# Inference and Visualization Function
def predict_emotion(image):
    # Convert the input image (from Gradio) to a PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Force three channels so grayscale or RGBA uploads don't break the 3-channel Normalize
    image = image.convert("RGB")
    # Preprocess the image
    image_tensor = transform(image).unsqueeze(0).to(device)
    # Run inference
    with torch.no_grad():
        outputs, hla_out, _, grad_magnitude, _ = model(image_tensor)
        probs = F.softmax(outputs, dim=1)
        pred_label = torch.argmax(probs, dim=1).item()
    pred_label_name = class_mapping[pred_label]
    probabilities = probs.cpu().numpy()[0]
    # Create probability dictionary
    prob_dict = {class_mapping[i]: float(prob) for i, prob in enumerate(probabilities)}
    # Generate HLA heatmap by averaging the attention features over channels
    heatmap = hla_out[0].mean(dim=0).detach().cpu().numpy()
    # Denormalize the image for visualization
    img = image_tensor[0].permute(1, 2, 0).detach().cpu().numpy()
    img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img = np.clip(img, 0, 1)
    # Plot the input image and heatmap side by side
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    axs[0].imshow(img)
    axs[0].set_title(f"Input Image\nPredicted: {pred_label_name}")
    axs[0].axis("off")
    axs[1].imshow(heatmap, cmap="jet")
    axs[1].set_title("HLA Heatmap")
    axs[1].axis("off")
    plt.tight_layout()
    # Save the plot to a unique temporary file so concurrent requests don't overwrite each other
    fd, plot_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    plt.savefig(plot_path)
    plt.close(fig)
    return pred_label_name, prob_dict, plot_path
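# The three return values above map one-to-one onto the three output components declared
# below: gr.Textbox takes the label string, gr.Label renders the probability dict as a
# ranked confidence list, and gr.Image loads the saved figure from its file path.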
# Gradio Interface
iface = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[
        gr.Textbox(label="Predicted Emotion"),
        gr.Label(label="Probabilities"),
        gr.Image(label="Input Image and HLA Heatmap")
    ],
    title="Facial Emotion Recognition with TripleStreamHLAViT",
    description="Upload an image to predict the facial emotion (Surprise, Fear, Disgust, Happiness, Sadness, Anger, Neutral). This model achieves 82.93% test accuracy on the RAF-DB dataset. The HLA heatmap shows where the model focuses.",
    examples=[
        ["examples/surprise.jpg"],
        ["examples/sadness.jpg"]
    ]
)
# Launch the interface
if __name__ == "__main__":
    iface.launch(share=False)