# DrGM-FastViT-Multimodal-FER
## Model Description
This is a State-of-the-Art (SOTA) Multimodal Facial Emotion Recognition (FER) model designed for extreme efficiency on edge devices.
It combines two powerful architectures:

- **Vision Branch:** `fastvit_sa24` (Apple's SOTA efficient Transformer/CNN hybrid) for visual texture analysis.
- **Geometric Branch:** A lightweight MLP processing 478 3D facial landmarks extracted via MediaPipe.
This fusion allows the model to achieve >91% accuracy while maintaining real-time inference speeds (30-100+ FPS) on low-power devices.
## ⚠️ License & Usage
This model is released under the CC-BY-NC-4.0 license.
- Personal Use: ✅ Allowed. You can use this for personal projects, research, and education.
- Commercial Use: ❌ Forbidden without prior permission.
- Commissions: If you wish to use this model for commercial applications or commissions, please contact the author for licensing.
## Performance
The model achieves exceptional performance on the Facial Emotion Expressions dataset.
### Classification Report (Test Set)
| Class | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| Angry | 0.89 | 0.95 | 0.92 | 1797 |
| Disgust | 1.00 | 1.00 | 1.00 | 1798 |
| Fear | 0.88 | 0.93 | 0.90 | 1798 |
| Happy | 0.92 | 0.89 | 0.91 | 1798 |
| Neutral | 0.83 | 0.84 | 0.84 | 1798 |
| Sad | 0.88 | 0.76 | 0.82 | 1798 |
| Surprise | 0.95 | 0.97 | 0.96 | 1798 |
| **Accuracy** | | | 0.91 | 12585 |
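Each F1 score in the table is the harmonic mean of its row's precision and recall; for example, the "Sad" row (P=0.88, R=0.76) works out to 0.82:

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)

# "Sad" row from the report: P=0.88, R=0.76
print(round(f1_score(0.88, 0.76), 2))  # → 0.82
```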
### Confusion Matrix
## Usage (Inference)
To use this model, you need `torch`, `torchvision`, `timm`, `mediapipe`, `numpy`, and `Pillow`.
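Assuming standard PyPI package names (not specified in the card), the dependencies can be installed with:

```shell
# Package names assumed from PyPI; pin versions for reproducibility.
pip install torch torchvision timm mediapipe numpy pillow
```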
```python
import torch
import timm
import mediapipe as mp
import numpy as np
from PIL import Image
from torchvision import transforms

# 1. Define Model Architecture (same as training)
class MultimodalFERModel(torch.nn.Module):
    def __init__(self, num_classes=7):
        super().__init__()
        self.vision_backbone = timm.create_model('fastvit_sa24.apple_in1k', num_classes=0)
        self.landmark_encoder = torch.nn.Sequential(
            torch.nn.Linear(478 * 3, 512), torch.nn.BatchNorm1d(512), torch.nn.ReLU(),
            torch.nn.Linear(512, 256), torch.nn.BatchNorm1d(256), torch.nn.ReLU()
        )
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.vision_backbone.num_features + 256, 512),
            torch.nn.BatchNorm1d(512), torch.nn.ReLU(),
            torch.nn.Linear(512, num_classes)
        )

    def forward(self, pixel_values, landmarks):
        v = self.vision_backbone(pixel_values)
        l = self.landmark_encoder(landmarks)
        return self.classifier(torch.cat((v, l), dim=1))

# 2. Load Model
model = MultimodalFERModel()
# Download the weights from the Hub (manually or via the API) and load them:
# model.load_state_dict(torch.hub.load_state_dict_from_url('...'))
model.eval()

# 3. Preprocessing (MediaPipe + torchvision transforms)
# refine_landmarks=True is required to get all 478 landmarks
# (the default mesh has only 468).
mp_face = mp.solutions.face_mesh.FaceMesh(
    static_image_mode=True, max_num_faces=1, refine_landmarks=True
)
val_tf = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def predict(image_path):
    img = Image.open(image_path).convert('RGB')
    # Vision input
    pixel_values = val_tf(img).unsqueeze(0)
    # Landmark input
    results = mp_face.process(np.array(img))
    if results.multi_face_landmarks:
        lm = np.array([[p.x, p.y, p.z]
                       for p in results.multi_face_landmarks[0].landmark]).flatten()
    else:
        lm = np.zeros(478 * 3)  # fall back to zeros when no face is detected
    landmarks = torch.tensor(lm, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        logits = model(pixel_values, landmarks)
    return logits.argmax(1).item()
```
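`predict` returns a class index. A label list maps it to an emotion name; the ordering below is an assumption taken from the alphabetical order of the classification report, so verify it against the label ordering used at training time:

```python
# Class order assumed alphabetical, matching the report above;
# verify against the training-time label ordering.
EMOTIONS = ["Angry", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]

def predict_label(image_path: str) -> str:
    """Run predict() and map the class index to a human-readable emotion."""
    return EMOTIONS[predict(image_path)]
```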
## Model tree for DrGM/DrGM-FastViT-nano-Multimodal-Facial-Emotion-Recognition

Base model: `timm/fastvit_sa24.apple_in1k`