SakibRumu committed: Update app.py

app.py CHANGED
@@ -1,50 +1,72 @@
 import torch
-import timm
 import torch.nn as nn
 import gradio as gr
+from torchvision import models, transforms
 from PIL import Image
-from …
+from transformers import ViTModel
 
-# Define …
+# Define HybridCNNTransformer Model
 class HybridCNNTransformer(nn.Module):
     def __init__(self, num_classes=7):
         super(HybridCNNTransformer, self).__init__()
-
-        # …
-        self.…
-        … (remaining old __init__ lines truncated in the source view)
+
+        # CNN Feature Extractor (ResNet50)
+        self.cnn = models.resnet50(pretrained=True)
+        self.cnn = nn.Sequential(*list(self.cnn.children())[:-2])  # Remove FC layers
+
+        # Reduce channels (2048 → 64)
+        self.channel_reduction = nn.Conv2d(in_channels=2048, out_channels=64, kernel_size=1)
+
+        # Convert to 3 channels for ViT
+        self.to_rgb = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1)
+
+        # Vision Transformer
+        self.transformer = ViTModel.from_pretrained("google/vit-base-patch16-224")
+
+        # Fully Connected Layers (Classifier Head)
+        self.fc = nn.Sequential(
+            nn.Linear(768, 512),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(512, num_classes)
+        )
 
     def forward(self, x):
-        … (old forward body truncated in the source view)
-        output = self.fc(transformer_features)
+        cnn_features = self.cnn(x)
+        reduced_features = self.channel_reduction(cnn_features)
+        rgb_features = self.to_rgb(reduced_features)
+        resized_features = nn.functional.interpolate(rgb_features, size=(224, 224), mode="bilinear", align_corners=False)
+
+        transformer_output = self.transformer(pixel_values=resized_features).last_hidden_state[:, 0, :]
+        output = self.fc(transformer_output)
         return output
 
-# Load …
+# Load Model
 model = HybridCNNTransformer(num_classes=7)
+state_dict = torch.load("transformer_emotion_recognition_model.pth", map_location=torch.device('cpu'))
+model.load_state_dict(state_dict, strict=False)
+model.eval()
 
-# …
-… (old device-setup lines truncated in the source view)
-model.to(device)
+# Define Preprocessing Transform
+transform = transforms.Compose([
+    transforms.Resize((224, 224)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
 
-# …
-…
+# Define Prediction Function
+def predict_emotion(image):
+    image = transform(image).unsqueeze(0)  # Add batch dimension
+    with torch.no_grad():
+        output = model(image)
+        probabilities = torch.nn.functional.softmax(output, dim=1)
+        confidence, predicted_class = torch.max(probabilities, 1)
+
+    class_labels = ["Angry", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]
+    predicted_emotion = class_labels[predicted_class.item()]
+    return predicted_emotion, f"{confidence.item() * 100:.2f}%"
 
-# Custom CSS for …
+# Custom CSS for UI Styling
 css = """
 body {
     background-color: #1e1e1e;
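The interesting change here is the new forward path: ResNet50 feature maps are squeezed through two 1×1 convolutions and bilinearly resized back to 224×224 so they can be fed to the ViT as pseudo-images. A minimal shape sanity check, assuming a 224×224 RGB input and the class exactly as defined above (the dummy tensor and the printed shapes are illustrative only, not part of the commit):

```python
import torch

model = HybridCNNTransformer(num_classes=7)
model.eval()

x = torch.randn(1, 3, 224, 224)        # dummy batch of one RGB image
with torch.no_grad():
    f = model.cnn(x)                    # ResNet50 trunk   -> (1, 2048, 7, 7)
    f = model.channel_reduction(f)      # 1x1 conv         -> (1, 64, 7, 7)
    f = model.to_rgb(f)                 # 1x1 conv         -> (1, 3, 7, 7)
    f = torch.nn.functional.interpolate(f, size=(224, 224), mode="bilinear", align_corners=False)
    cls = model.transformer(pixel_values=f).last_hidden_state[:, 0, :]  # CLS token -> (1, 768)
    logits = model.fc(cls)              # classifier head  -> (1, 7)
print(logits.shape)                     # torch.Size([1, 7])
```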
@@ -76,29 +98,6 @@ body {
 }
 """
 
-# Image Preprocessing for the model (assuming the model was trained with resized and normalized images)
-preprocess = transforms.Compose([
-    transforms.Resize((224, 224)),  # Adjust according to your model's input size
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Standard ImageNet normalization
-])
-
-# Prediction function
-def predict_emotion(image):
-    # Preprocess the image
-    image_tensor = preprocess(image).unsqueeze(0).to(device)  # Add batch dimension and move to device
-
-    # Make prediction
-    with torch.no_grad():
-        output = model(image_tensor)
-        _, predicted = torch.max(output, 1)  # Get the predicted class
-        confidence = torch.nn.functional.softmax(output, dim=1).max().item()  # Confidence score
-
-    # Return the predicted emotion label and confidence score
-    emotions = ["Anger", "Disgust", "Fear", "Happiness", "Sadness", "Surprise", "Neutral"]  # Modify labels as per your model
-    predicted_emotion = emotions[predicted.item()]
-    return predicted_emotion, confidence
-
 # Gradio Interface
 iface = gr.Interface(
     fn=predict_emotion,
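The rewritten predict_emotion (moved above the CSS block) derives both outputs from one softmax and returns the confidence as a formatted percentage string instead of a raw float; the label list also changed order (Angry … Surprise), which must match the order used at training time. A minimal local smoke test, assuming the app module is importable (sample_face.jpg is a hypothetical test image):

```python
from PIL import Image

# Hypothetical test image; convert to RGB so the 3-channel Normalize applies.
img = Image.open("sample_face.jpg").convert("RGB")
emotion, confidence = predict_emotion(img)
print(emotion, confidence)  # e.g. "Happy 93.41%"
```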
@@ -106,10 +105,10 @@ iface = gr.Interface(
     outputs=[gr.Textbox(label="Predicted Emotion"), gr.Textbox(label="Confidence")],
     live=True,
     title="Emotion Classification",
-    description="Upload an image to predict the emotion expressed in the image using a fine-tuned …
+    description="Upload an image to predict the emotion expressed in the image using a fine-tuned ResNet50 + Vision Transformer model.",
     css=css
 )
 
 # Launch the app
 if __name__ == "__main__":
-    iface.launch()
+    iface.launch()
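One behavioural consequence of the rewrite: the old model.to(device) call is gone, so inference now runs on CPU, consistent with map_location=torch.device('cpu') in the loader. If GPU execution were wanted back, a hedged sketch (predict_emotion_on_device is a hypothetical helper, not part of the commit):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def predict_emotion_on_device(image):
    # Same logic as predict_emotion, but with input and model on a common device.
    x = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(x), dim=1)
    confidence, predicted_class = torch.max(probabilities, 1)
    class_labels = ["Angry", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]
    return class_labels[predicted_class.item()], f"{confidence.item() * 100:.2f}%"
```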