# SigMamba-V1 / modeling_sigmamba.py
# Author: Vinay — uploaded via huggingface_hub (commit cb78aa5, verified)
"""
SigMamba: Unified Video Anomaly Detection Model
Combines SigLIP vision encoder with Mamba temporal reasoning.
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoConfig, AutoModel
from transformers.modeling_utils import PreTrainedModel

from .configuration_sigmamba import SigMambaConfig
try:
from mamba_ssm import Mamba
IS_OFFICIAL_MAMBA = True
except ImportError:
IS_OFFICIAL_MAMBA = False
Mamba = None
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no mean-centering, no bias)."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        # eps guards the rsqrt against an all-zero feature vector.
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # Divide by the RMS of the last dimension, then apply the learned gain.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)
class MambaSSM(nn.Module):
    """
    Pure PyTorch Mamba (S6) implementation.

    Fallback for systems without the official `mamba_ssm` CUDA kernels.
    Input and output have shape (batch, seq_len, d_model).
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        # Low-rank dimension of the input-dependent step size (delta);
        # hoisted to an attribute so __init__ and forward stay consistent.
        self.dt_rank = self.d_inner // 16
        # Projects the input to the gated pair (x, z).
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # Depthwise conv over time; padding=d_conv-1 plus truncation in
        # forward() keeps the convolution causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=True,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
        )
        # Produces (delta_lowrank, B, C) from the conv output.
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
        # A is stored via its log so -exp(A_log) stays strictly negative (stable decay).
        A = torch.arange(1, self.d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self._init_weights()

    def _init_weights(self):
        """Initialize dt_proj so softplus(dt_proj(...)) starts in [dt_min, dt_max]."""
        dt_min, dt_max = 0.001, 0.1
        dt_init_std = 0.02
        nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        # Sample dt log-uniformly over [dt_min, dt_max] ...
        dt = torch.exp(
            torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
        ).clamp(min=1e-4)
        # ... then apply the inverse of softplus so the bias reproduces dt.
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        nn.init.xavier_uniform_(self.in_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)

    def forward(self, x):
        """Run the gated selective-state-space block.

        Args:
            x: (batch, seq_len, d_model) input sequence.
        Returns:
            (batch, seq_len, d_model) output sequence.
        """
        batch, seq_len, _ = x.shape
        xz = self.in_proj(x)
        # Renamed from `x_proj` to avoid shadowing the self.x_proj module name.
        x_in, z = xz.chunk(2, dim=-1)
        # Causal depthwise convolution over the time axis.
        x_in = x_in.transpose(1, 2)
        x_in = self.conv1d(x_in)[:, :, :seq_len]
        x_in = x_in.transpose(1, 2)
        x_in = F.silu(x_in)
        # Input-dependent SSM parameters: step size delta plus B and C.
        x_dbl = self.x_proj(x_in)
        delta, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        delta = F.softplus(self.dt_proj(delta))
        A = -torch.exp(self.A_log)  # (d_inner, d_state), strictly negative
        y = self.selective_scan_seq(x_in, delta, A, B, C, self.D)
        y = y * F.silu(z)  # output gate
        return self.out_proj(y)

    def selective_scan_seq(self, u, delta, A, B, C, D):
        """Sequential selective scan (S6 recurrence).

        Args:
            u: (batch, seq_len, d_inner) inputs.
            delta: (batch, seq_len, d_inner) positive step sizes.
            A: (d_inner, d_state) negative state matrix.
            B, C: (batch, seq_len, d_state) input/output projections.
            D: (d_inner,) skip-connection scale.
        Returns:
            (batch, seq_len, d_inner) scan output.
        """
        b_size, l, d_in = u.shape
        d_state = A.shape[1]
        # FIX: match the input dtype so fp16/bf16 inputs don't mix dtypes
        # with a float32 state inside the einsum below.
        h = torch.zeros(b_size, d_in, d_state, device=u.device, dtype=u.dtype)
        # Discretize once for the whole sequence: ZOH-style for A, Euler for B.
        deltaA = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
        deltaB_u = torch.einsum('bld,bln,bld->bldn', delta, B, u)
        ys = []
        for i in range(l):
            h = h * deltaA[:, i] + deltaB_u[:, i]
            ys.append(torch.einsum('bdn,bn->bd', h, C[:, i]))
        y = torch.stack(ys, dim=1)
        # D acts as a learned per-channel skip connection.
        return y + u * D
class MambaBlock(nn.Module):
    """One pre-norm Mamba mixer wrapped in a residual connection."""

    def __init__(self, config, d_model, depth_idx=0):
        super().__init__()
        self.norm = RMSNorm(d_model)
        # Prefer the official CUDA implementation when it is importable;
        # otherwise fall back to the pure-PyTorch MambaSSM.
        mixer_cls = Mamba if IS_OFFICIAL_MAMBA else MambaSSM
        self.mixer = mixer_cls(
            d_model=d_model,
            d_state=config.d_state,
            d_conv=config.d_conv,
            expand=config.expand,
        )

    def forward(self, x):
        # Pre-norm residual: x + Mixer(Norm(x)).
        normed = self.norm(x)
        return x + self.mixer(normed)
class MambaEncoder(nn.Module):
    """Projects input features to d_model, then applies stacked Mamba blocks."""

    def __init__(self, config):
        super().__init__()
        self.embedding = nn.Linear(config.feature_dim, config.d_model)
        blocks = [
            MambaBlock(config, d_model=config.d_model, depth_idx=idx)
            for idx in range(config.depth)
        ]
        self.layers = nn.ModuleList(blocks)
        # Final norm after the residual stack.
        self.norm_f = RMSNorm(config.d_model)

    def forward(self, x):
        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.norm_f(hidden)
class SigMambaPreTrainedModel(PreTrainedModel):
    """Base class wiring SigMamba models into the HF save/load machinery."""

    config_class = SigMambaConfig
    base_model_prefix = "sigmamba"

    def _init_weights(self, module):
        # Normal(0, 0.02) init, matching the convention used across HF models.
        std = 0.02
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
class SigMambaForVideoClassification(SigMambaPreTrainedModel):
    """
    SigMamba model for video anomaly detection.

    Supports two input modes:
      - features: pre-extracted frame embeddings (B, T, 1024)
      - pixel_values: raw video frames (B, T, C, H, W)
    """
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Vision tower (structure only, weights loaded from checkpoint)
        vision_config = AutoConfig.from_pretrained(config.vision_model_id)
        self.vision_model = AutoModel.from_config(vision_config)
        # Temporal encoder over the per-frame feature sequence.
        self.mamba_encoder = MambaEncoder(config)
        # Classification head; the trailing Sigmoid bounds scores to (0, 1).
        self.fc_head = nn.Sequential(
            nn.Linear(config.d_model, 128),
            nn.LeakyReLU(negative_slope=5e-2),
            nn.Dropout(0.2),
            nn.Linear(128, config.num_classes),
            nn.Sigmoid()
        )
        self.post_init()

    def forward(self, features=None, pixel_values=None):
        """
        Args:
            features: Pre-extracted features (B, T, 1024). Takes precedence:
                if given, pixel_values is ignored.
            pixel_values: Raw video frames (B, T, C, H, W); a 4-D batch of
                frames (N, C, H, W) is treated as T=1.
        Returns:
            scores: Anomaly scores (B, T, num_classes); (B, T, 1) when
                config.num_classes == 1.
        Raises:
            ValueError: if neither features nor pixel_values is provided.
        """
        # Path A: Unified mode (pixels -> features -> scores)
        if pixel_values is not None and features is None:
            if pixel_values.dim() == 5:
                b, t, c, h, w = pixel_values.shape
                # Fold time into batch so the vision tower sees single frames.
                flat_pixels = pixel_values.view(b * t, c, h, w)
            else:
                flat_pixels = pixel_values
                b, t = flat_pixels.shape[0], 1
            # Extract and normalize features
            if hasattr(self.vision_model, 'get_image_features'):
                flat_features = self.vision_model.get_image_features(pixel_values=flat_pixels)
                if not isinstance(flat_features, torch.Tensor):
                    # NOTE(review): both getattr defaults evaluate eagerly, so
                    # flat_features[0] is computed even when pooler_output
                    # exists; a present-but-None pooler_output would also be
                    # returned as-is — confirm against the vision model's
                    # output type.
                    flat_features = getattr(flat_features, 'pooler_output',
                        getattr(flat_features, 'image_embeds', flat_features[0]))
                # L2-normalize embeddings (CLIP/SigLIP-style).
                flat_features = flat_features / flat_features.norm(p=2, dim=-1, keepdim=True)
            else:
                vision_outputs = self.vision_model(pixel_values=flat_pixels)
                # NOTE(review): this fallback path skips the L2 normalization
                # applied above — verify the asymmetry is intentional.
                flat_features = getattr(vision_outputs, 'pooler_output',
                    getattr(vision_outputs, 'image_embeds', vision_outputs[0]))
            # Reshape to (B, T, D)
            features = flat_features.view(b, t, -1) if pixel_values.dim() == 5 else flat_features.unsqueeze(1)
        # Path B: Modular mode (features -> scores)
        if features is None:
            raise ValueError("You must provide either 'features' or 'pixel_values'")
        # Encode and classify
        x = self.mamba_encoder(features)
        scores = self.fc_head(x)
        return scores