umm-dev committed (verified)
Commit 1d31d8f · 1 Parent(s): 95546e8

Create modeling_gclm.py

Files changed (1):
modeling_gclm.py  +195 -0
modeling_gclm.py ADDED
@@ -0,0 +1,195 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

from .configuration_gclm import GCLMConfig


# ============================================================
# Configuration class is assumed to live in configuration_gclm.py
# ============================================================

# Expected fields in GCLMConfig:
# - vocab_size
# - d_model
# - n_layers
# - max_seq_len
# - local_kernel_size
# - global_kernel_size
# - fft_size
# - use_global_every_n_layers
# - layer_norm_eps

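For reference, a configuration_gclm.py consistent with the fields listed above might look like the following sketch. The class name GCLMConfig and the field names come from the comments above; the model_type string and every default value are placeholders rather than values taken from this repository.

from transformers import PretrainedConfig


class GCLMConfig(PretrainedConfig):
    model_type = "gclm"

    def __init__(
        self,
        vocab_size=32000,
        d_model=512,
        n_layers=12,
        max_seq_len=2048,
        local_kernel_size=7,
        global_kernel_size=1024,
        fft_size=2048,
        use_global_every_n_layers=3,
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.local_kernel_size = local_kernel_size
        self.global_kernel_size = global_kernel_size
        self.fft_size = fft_size
        self.use_global_every_n_layers = use_global_every_n_layers
        self.layer_norm_eps = layer_norm_eps
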
# ============================================================
# Global FFT Convolution
# ============================================================

class GlobalConv1D(nn.Module):
    """Causal depthwise long convolution computed blockwise with FFTs (overlap-save)."""

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        # x: [B, C, T]
        B, C, T = x.shape
        K = min(self.kernel_size, T)

        # Overlap-save blocking: each FFT block of length fft_size overlaps the
        # previous one by K - 1 samples and yields `block` valid output samples.
        overlap = K - 1
        block = self.fft_size - overlap
        assert block > 0, "fft_size must be larger than kernel_size - 1"

        # Left-pad the input so the convolution is causal; zero-pad the kernel
        # to the FFT length.
        x = F.pad(x, (overlap, 0))
        k = self.kernel[:, :K]
        k = F.pad(k, (0, self.fft_size - K))

        k_f = torch.fft.rfft(k, n=self.fft_size)

        outs = []
        pos = 0
        while pos < T:
            seg = x[..., pos:pos + self.fft_size]
            if seg.shape[-1] < self.fft_size:
                seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))

            # Circular convolution via the FFT; the first `overlap` samples of
            # each block are contaminated by wrap-around and are discarded below.
            y = torch.fft.irfft(
                torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0),
                n=self.fft_size,
            )
            outs.append(y[..., overlap:overlap + block])
            pos += block

        return torch.cat(outs, dim=-1)[..., :T]

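The overlap-save blocking above is easy to get subtly wrong, so a direct cross-check against an explicit causal depthwise convolution is worth keeping around. The snippet below is a standalone sanity check with made-up sizes; F.conv1d computes a cross-correlation, so the kernel is flipped to match the FFT path, and the two results should agree to floating-point tolerance.

torch.manual_seed(0)
d_model, kernel_size, fft_size = 8, 16, 64
conv = GlobalConv1D(d_model, kernel_size, fft_size)
x = torch.randn(2, d_model, 100)            # [B, C, T]

with torch.no_grad():
    out = conv(x)

    # Reference: explicit causal depthwise convolution. F.conv1d cross-correlates,
    # so flip the learned kernel, and left-pad by K - 1 to keep it causal.
    w = conv.kernel.flip(-1).unsqueeze(1)   # [C, 1, K]
    ref = F.conv1d(F.pad(x, (kernel_size - 1, 0)), w, groups=d_model)

print(torch.allclose(out, ref, atol=1e-4))  # expect: True
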
# ============================================================
# Local Convolution
# ============================================================

class LocalConv1D(nn.Module):
    """Causal depthwise-separable convolution over a short local window."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        # x: [B, C, T]; left-pad so that position t only sees inputs <= t.
        x = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(x)))

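A similar spot check confirms that the left-only padding in LocalConv1D is causal: perturbing late timesteps must leave earlier outputs unchanged. The sizes below are arbitrary.

lc = LocalConv1D(d_model=8, k=5)
a = torch.randn(1, 8, 32)                   # [B, C, T]
b = a.clone()
b[..., 20:] += 1.0                          # perturb positions t >= 20 only

with torch.no_grad():
    ya, yb = lc(a), lc(b)

print(torch.allclose(ya[..., :20], yb[..., :20]))  # expect: True
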
# ============================================================
# GCLM Block
# ============================================================

class GCLMBlock(nn.Module):
    """Pre-norm residual block: local conv, optional global FFT conv, feed-forward."""

    def __init__(self, config, use_global):
        super().__init__()
        self.use_global = use_global

        self.ln1 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.local = LocalConv1D(
            config.d_model,
            config.local_kernel_size,
        )

        if use_global:
            self.ln2 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
            self.global_conv = GlobalConv1D(
                config.d_model,
                config.global_kernel_size,
                config.fft_size,
            )

        self.ln3 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.ff = nn.Sequential(
            nn.Linear(config.d_model, config.d_model * 4),
            nn.GELU(),
            nn.Linear(config.d_model * 4, config.d_model),
        )

    def forward(self, x):
        # x: [B, T, d_model]; the conv branches expect [B, C, T], hence the transposes.
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1, 2)).transpose(1, 2)
        return x + self.ff(self.ln3(x))

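GCLMBlock keeps the [B, T, d_model] layout at its interface and only moves to channels-first inside the convolution branches. A minimal shape check, using a plain namespace with the config fields the block actually reads (all values are placeholders):

from types import SimpleNamespace

cfg = SimpleNamespace(d_model=8, local_kernel_size=5, global_kernel_size=16,
                      fft_size=64, layer_norm_eps=1e-5)

blk = GCLMBlock(cfg, use_global=True)
h = torch.randn(2, 40, cfg.d_model)   # [B, T, d_model]
print(blk(h).shape)                   # expect: torch.Size([2, 40, 8])
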
# ============================================================
# Base GCLM Model
# ============================================================

class GCLMModel(PreTrainedModel):
    config_class = GCLMConfig
    base_model_prefix = "gclm"

    def __init__(self, config):
        super().__init__(config)

        self.emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos = nn.Embedding(config.max_seq_len, config.d_model)

        # Every use_global_every_n_layers-th block (starting at layer 0) carries
        # the global FFT branch; the remaining blocks are local-only.
        self.layers = nn.ModuleList([
            GCLMBlock(
                config,
                use_global=(i % config.use_global_every_n_layers == 0),
            )
            for i in range(config.n_layers)
        ])

        self.ln = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

        self.post_init()

    def forward(self, input_ids):
        # input_ids: [B, T]; learned absolute positions up to config.max_seq_len.
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device)

        h = self.emb(input_ids) + self.pos(pos)

        for layer in self.layers:
            h = layer(h)

        return self.ln(h)

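The use_global test means layer 0 always carries the global FFT branch, followed by every use_global_every_n_layers-th layer after it; with n_layers=12 and use_global_every_n_layers=3, that is layers 0, 3, 6 and 9. The check below makes the pattern explicit, assuming the GCLMConfig sketched earlier and placeholder sizes.

cfg = GCLMConfig(vocab_size=100, d_model=8, n_layers=12, max_seq_len=128,
                 local_kernel_size=5, global_kernel_size=16, fft_size=64,
                 use_global_every_n_layers=3, layer_norm_eps=1e-5)

model = GCLMModel(cfg)
print([i for i, layer in enumerate(model.layers) if layer.use_global])
# expect: [0, 3, 6, 9]
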
# ============================================================
# Causal LM Head
# ============================================================

class GCLMForCausalLM(PreTrainedModel):
    config_class = GCLMConfig
    base_model_prefix = "gclm"

    def __init__(self, config):
        super().__init__(config)

        self.gclm = GCLMModel(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.post_init()

    def forward(
        self,
        input_ids,
        labels=None,
        **kwargs,
    ):
        hidden = self.gclm(input_ids)
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # Standard causal LM objective: shift so that position t predicts
            # token t + 1 (labels are usually just a copy of input_ids).
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )
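End-to-end usage follows the usual transformers pattern: build a config, construct GCLMForCausalLM, and pass input_ids with labels equal to input_ids (the shift to next-token targets happens inside forward). The sizes below are small placeholders for a smoke test, not an intended model scale, and again assume the GCLMConfig sketched above.

cfg = GCLMConfig(vocab_size=1000, d_model=64, n_layers=4, max_seq_len=256,
                 local_kernel_size=7, global_kernel_size=32, fft_size=128,
                 use_global_every_n_layers=2, layer_norm_eps=1e-5)

model = GCLMForCausalLM(cfg)
input_ids = torch.randint(0, cfg.vocab_size, (2, 128))

out = model(input_ids, labels=input_ids)
print(out.logits.shape)   # expect: torch.Size([2, 128, 1000])
print(out.loss.item())    # expect roughly log(vocab_size) ≈ 6.9 before training

If the classes are meant to load through the Auto API with trust_remote_code, the configuration would typically also carry an auto_map entry pointing at GCLMConfig and GCLMForCausalLM; that wiring lives outside this file.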