SakibRumu committed
Commit 2894e20 · verified · 1 parent: e432586

Update app.py

Files changed (1): app.py +55 -56
app.py CHANGED
@@ -1,50 +1,72 @@
 import torch
-import timm
 import torch.nn as nn
 import gradio as gr
+from torchvision import models, transforms
 from PIL import Image
-from torchvision import transforms
+from transformers import ViTModel

-# Define your custom model architecture (HybridCNNTransformer in this case)
+# Define HybridCNNTransformer Model
 class HybridCNNTransformer(nn.Module):
     def __init__(self, num_classes=7):
         super(HybridCNNTransformer, self).__init__()
-
-        # Example: Using ResNet50 from timm as a CNN feature extractor
-        self.backbone = timm.create_model('resnet50', pretrained=True)
-
-        # Example Transformer part (modify according to your model)
-        self.transformer = nn.Transformer(d_model=2048, nhead=8, num_encoder_layers=6)
-
-        # Final fully connected layer (7 classes for emotion recognition)
-        self.fc = nn.Linear(2048, num_classes)
+
+        # CNN Feature Extractor (ResNet50)
+        self.cnn = models.resnet50(pretrained=True)
+        self.cnn = nn.Sequential(*list(self.cnn.children())[:-2])  # Remove avgpool and FC layers
+
+        # Reduce channels (2048 -> 64)
+        self.channel_reduction = nn.Conv2d(in_channels=2048, out_channels=64, kernel_size=1)
+
+        # Convert to 3 channels for ViT
+        self.to_rgb = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1)
+
+        # Vision Transformer
+        self.transformer = ViTModel.from_pretrained("google/vit-base-patch16-224")
+
+        # Fully Connected Layers (Classifier Head)
+        self.fc = nn.Sequential(
+            nn.Linear(768, 512),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(512, num_classes)
+        )

     def forward(self, x):
-        # CNN feature extraction
-        cnn_features = self.backbone(x)
-
-        # Transformer encoding (if applicable, you might not need this part)
-        transformer_features = self.transformer(cnn_features, cnn_features)
-
-        # Final classification layer
-        output = self.fc(transformer_features)
+        cnn_features = self.cnn(x)
+        reduced_features = self.channel_reduction(cnn_features)
+        rgb_features = self.to_rgb(reduced_features)
+        resized_features = nn.functional.interpolate(rgb_features, size=(224, 224), mode="bilinear", align_corners=False)
+
+        # Use the [CLS] token embedding as the image representation
+        transformer_output = self.transformer(pixel_values=resized_features).last_hidden_state[:, 0, :]
+        output = self.fc(transformer_output)
         return output

-# Load the model
+# Load Model
 model = HybridCNNTransformer(num_classes=7)
-
-# Load the weights from the .pth file
-model_path = "transformer_emotion_recognition_model.pth"  # Replace with the path to your .pth file
-model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))  # For CPU; change 'cpu' to 'cuda' for GPU
-
-# Move the model to the appropriate device (CUDA or CPU)
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)
-
-# Set the model to evaluation mode
-model.eval()
+state_dict = torch.load("transformer_emotion_recognition_model.pth", map_location=torch.device('cpu'))
+model.load_state_dict(state_dict, strict=False)
+model.eval()
+
+# Define Preprocessing Transform
+transform = transforms.Compose([
+    transforms.Resize((224, 224)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
+
+# Define Prediction Function
+def predict_emotion(image):
+    image = transform(image).unsqueeze(0)  # Add batch dimension
+    with torch.no_grad():
+        output = model(image)
+        probabilities = torch.nn.functional.softmax(output, dim=1)
+        confidence, predicted_class = torch.max(probabilities, 1)
+
+    class_labels = ["Angry", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]
+    predicted_emotion = class_labels[predicted_class.item()]
+    return predicted_emotion, f"{confidence.item() * 100:.2f}%"

-# Custom CSS for layout styling
+# Custom CSS for UI Styling
 css = """
 body {
     background-color: #1e1e1e;
@@ -76,29 +98,6 @@ body {
 }
 """

-# Image Preprocessing for the model (assuming the model was trained with resized and normalized images)
-preprocess = transforms.Compose([
-    transforms.Resize((224, 224)),  # Adjust according to your model's input size
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Standard ImageNet normalization
-])
-
-# Prediction function
-def predict_emotion(image):
-    # Preprocess the image
-    image_tensor = preprocess(image).unsqueeze(0).to(device)  # Add batch dimension and move to device
-
-    # Make prediction
-    with torch.no_grad():
-        output = model(image_tensor)
-        _, predicted = torch.max(output, 1)  # Get the predicted class
-        confidence = torch.nn.functional.softmax(output, dim=1).max().item()  # Confidence score
-
-    # Return the predicted emotion label and confidence score
-    emotions = ["Anger", "Disgust", "Fear", "Happiness", "Sadness", "Surprise", "Neutral"]  # Modify labels as per your model
-    predicted_emotion = emotions[predicted.item()]
-    return predicted_emotion, confidence
-
 # Gradio Interface
 iface = gr.Interface(
     fn=predict_emotion,
@@ -106,10 +105,10 @@ iface = gr.Interface(
     outputs=[gr.Textbox(label="Predicted Emotion"), gr.Textbox(label="Confidence")],
     live=True,
     title="Emotion Classification",
-    description="Upload an image to predict the emotion expressed in the image using a fine-tuned SE-ResNet50 model.",
+    description="Upload an image to predict the emotion expressed in the image using a fine-tuned ResNet50 + Vision Transformer model.",
     css=css
 )

 # Launch the app
 if __name__ == "__main__":
-    iface.launch()
+    iface.launch()
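The rewritten forward pass chains several shape transitions that are easy to get wrong: ResNet50 without its head emits a 2048-channel 7x7 feature map for a 224x224 input, the two 1x1 convolutions squeeze that to 3 channels, and interpolate() scales it back up to 224x224 so the ViT will accept it. A minimal smoke test of those transitions, as a sketch: it assumes the HybridCNNTransformer class from the diff above is in scope, and only the class name and input size come from the commit.

import torch

# Instantiate without the fine-tuned checkpoint; the ImageNet ResNet50
# and ViT weights are downloaded on first use.
model = HybridCNNTransformer(num_classes=7)
model.eval()

with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224)  # one fake RGB image at the expected input size
    logits = model(dummy)

# The classifier head maps the 768-dim [CLS] embedding to one logit per emotion class.
assert logits.shape == (1, 7)
print(logits.shape)  # torch.Size([1, 7])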
 
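One detail worth flagging: the commit loads the checkpoint with strict=False, so checkpoint keys that do not match the new module names are skipped without any error. A quick check, again a sketch assuming the model object built as in the diff, relies on the fact that load_state_dict() returns the missing and unexpected keys:

import torch

state_dict = torch.load("transformer_emotion_recognition_model.pth",
                        map_location=torch.device("cpu"))

# With strict=False, mismatches are silent; the returned named tuple
# records which checkpoint keys were skipped and which modules kept
# their randomly initialized or ImageNet-pretrained weights.
result = model.load_state_dict(state_dict, strict=False)
print("Missing keys:   ", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)

If either list is long, the checkpoint likely came from a different module layout (the old timm-based definition, for instance), and the app's predictions would then rest mostly on pretrained rather than fine-tuned weights.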