sagar007 committed on
Commit
ffb63d1
·
verified ·
1 Parent(s): 9fc75bd

Upload projectors.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. projectors.py +93 -0
projectors.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Projection layers for multimodal fusion
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Optional
7
+
8
+
9
class VisionProjector(nn.Module):
    """Project vision features into the language model embedding space.

    The projection is a two-layer MLP followed by layer normalization::

        Linear(vision_dim, hidden_dim) -> GELU -> Dropout
            -> Linear(hidden_dim, language_dim) -> LayerNorm(language_dim)

    Args:
        vision_dim: Size of the last dimension of the incoming features.
        language_dim: Size of the last dimension of the projected output.
        hidden_dim: Width of the intermediate layer. A falsy value
            (``None`` or ``0``) falls back to ``language_dim``.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(
        self,
        vision_dim: int,
        language_dim: int,
        hidden_dim: Optional[int] = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Falsy values (None or 0) select the default intermediate width.
        hidden_dim = hidden_dim or language_dim

        # The attribute name and the position of each layer determine the
        # state_dict keys (projector.0.*, projector.3.*, projector.4.*), so
        # the order must not change if saved weights are to keep loading.
        self.projector = nn.Sequential(
            nn.Linear(vision_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, language_dim),
            nn.LayerNorm(language_dim),
        )

        # Replace the default layer initialization with the scheme below.
        self._init_weights()

    def _init_weights(self):
        """Initialize projection weights.

        Linear weights are drawn from a normal distribution with std 0.02
        and their biases are zeroed; LayerNorm is set to weight 1, bias 0.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
        """Project vision features.

        Args:
            vision_features: [..., vision_dim], e.g. [batch_size, vision_dim].
                Every layer acts on the last dimension only, so any leading
                dimensions are carried through unchanged.
        Returns:
            projected_features: [..., language_dim]
        """
        return self.projector(vision_features)
52
+
53
+
54
class AudioProjector(nn.Module):
    """Project audio features into the language model embedding space.

    The projection is a single linear layer followed by an activation,
    dropout and layer normalization::

        Linear(audio_dim, language_dim) -> GELU -> Dropout
            -> LayerNorm(language_dim)

    Args:
        audio_dim: Size of the last dimension of the incoming features.
        language_dim: Size of the last dimension of the projected output.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(
        self,
        audio_dim: int,
        language_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        # The attribute name and the position of each layer determine the
        # state_dict keys (projector.0.*, projector.3.*), so the order must
        # not change if saved weights are to keep loading.
        self.projector = nn.Sequential(
            nn.Linear(audio_dim, language_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(language_dim),
        )

        # Replace the default layer initialization with the scheme below.
        self._init_weights()

    def _init_weights(self):
        """Initialize projection weights.

        Linear weights are drawn from a normal distribution with std 0.02
        and their biases are zeroed; LayerNorm is set to weight 1, bias 0.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Project audio features.

        Args:
            audio_features: [..., audio_dim], e.g. [batch_size, audio_dim].
                Every layer acts on the last dimension only, so any leading
                dimensions are carried through unchanged.
        Returns:
            projected_features: [..., language_dim]
        """
        return self.projector(audio_features)