# simple-gclm-implementation / modeling_gclm.py
# Author: umm-dev — commit 1d31d8f "Create modeling_gclm.py" (verified)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
# ============================================================
# Configuration class is assumed to live in configuration_gclm.py
# ============================================================
# Expected fields in GCLMConfig:
# - vocab_size
# - d_model
# - n_layers
# - max_seq_len
# - local_kernel_size
# - global_kernel_size
# - fft_size
# - use_global_every_n_layers
# - layer_norm_eps
# ============================================================
# Global FFT Convolution
# ============================================================
class GlobalConv1D(nn.Module):
    """Causal per-channel long convolution computed with FFT overlap-save.

    Each of the ``d_model`` channels has its own learned kernel of length
    ``kernel_size``. For an input ``x`` of shape ``[B, C, T]`` the output is

        y[b, c, t] = sum_{j=0}^{K-1} kernel[c, j] * x[b, c, t - j]

    where ``x[..., t] = 0`` for ``t < 0`` and ``K = min(kernel_size, T)``.
    Position ``t`` therefore only depends on positions ``<= t``.

    The (left-padded) sequence is cut into windows of ``fft_size`` samples
    that overlap by ``K - 1`` samples. Each window yields
    ``fft_size - (K - 1)`` valid outputs.

    Args:
        d_model: number of channels ``C``.
        kernel_size: maximum kernel length.
        fft_size: FFT window length. It must be at least the effective kernel
            length ``min(kernel_size, T)`` of every sequence passed in.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        """Apply the causal long convolution.

        Args:
            x: tensor of shape ``[B, C, T]`` with ``C == d_model``.

        Returns:
            Tensor of shape ``[B, C, T]``.

        Raises:
            ValueError: if ``fft_size`` is smaller than the effective kernel
                length ``min(kernel_size, T)``. Overlap-save cannot produce
                any valid output per window in that case.
        """
        B, C, T = x.shape
        if T == 0:
            return x.new_zeros(B, C, 0)

        K = min(self.kernel_size, T)
        n_fft = self.fft_size
        overlap = K - 1
        block = n_fft - overlap  # valid outputs produced by each window
        if block < 1:
            raise ValueError(
                f"fft_size ({n_fft}) must be >= the effective kernel length "
                f"min(kernel_size, T) = {K}"
            )

        n_blocks = -(-T // block)  # ceil(T / block)
        # Left-pad with `overlap` zeros (empty causal history). Right-pad so
        # the final window is complete. Padded length = n_blocks*block + overlap.
        x = F.pad(x, (overlap, n_blocks * block - T))
        # Every window at once: [B, C, n_blocks, n_fft], stepping by `block`.
        windows = x.unfold(-1, n_fft, block)

        # rfft zero-pads the kernel to n_fft. The unsqueeze turns
        # [C, F] into [C, 1, F] so it broadcasts over windows.
        k_f = torch.fft.rfft(self.kernel[:, :K], n=n_fft).unsqueeze(1)
        y = torch.fft.irfft(torch.fft.rfft(windows, n=n_fft) * k_f, n=n_fft)

        # The first `overlap` samples of each window contain circular
        # wrap-around and are discarded.
        y = y[..., overlap:]
        return y.reshape(B, C, n_blocks * block)[..., :T]
# ============================================================
# Local Convolution
# ============================================================
class LocalConv1D(nn.Module):
    """Causal depthwise-separable 1-D convolution over ``[B, C, T]`` inputs.

    The input is left-padded by ``k - 1`` zeros, so output position ``t``
    only sees input positions ``t - k + 1 .. t``. It is then passed through
    a per-channel (depthwise) conv of width ``k``, a ReLU, and a 1x1
    pointwise conv that mixes channels. The sequence length is preserved.

    Args:
        d_model: number of channels.
        k: depthwise kernel width.
    """

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        # groups == channels: one independent filter per channel.
        self.dw = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=k,
            groups=d_model,
        )
        # 1x1 convolution: mixes information across channels at each step.
        self.pw = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=1,
        )

    def forward(self, x):
        history = self.k - 1
        causal_input = F.pad(x, (history, 0))  # pad on the left only
        hidden = self.dw(causal_input)
        hidden = F.relu(hidden)
        return self.pw(hidden)
# ============================================================
# GCLM Block
# ============================================================
class GCLMBlock(nn.Module):
    """One pre-norm residual GCLM layer.

    Each sub-layer is applied as ``x = x + f(norm(x))``:
      1. ``LocalConv1D`` (always present),
      2. ``GlobalConv1D`` (only when ``use_global`` is true),
      3. a position-wise MLP ``d_model -> 4 * d_model -> d_model`` with GELU.

    The hidden state is ``[B, T, d_model]``. The convolution sub-layers work
    channels-first, so their input is transposed to ``[B, d_model, T]`` and
    the result is transposed back.

    Args:
        config: object providing ``d_model``, ``layer_norm_eps``,
            ``local_kernel_size``, ``global_kernel_size`` and ``fft_size``.
        use_global: whether this layer includes the global FFT convolution.
    """

    def __init__(self, config, use_global):
        super().__init__()
        width = config.d_model
        eps = config.layer_norm_eps
        self.use_global = use_global

        self.ln1 = nn.LayerNorm(width, eps=eps)
        self.local = LocalConv1D(width, config.local_kernel_size)

        if use_global:
            self.ln2 = nn.LayerNorm(width, eps=eps)
            self.global_conv = GlobalConv1D(
                width,
                config.global_kernel_size,
                config.fft_size,
            )

        self.ln3 = nn.LayerNorm(width, eps=eps)
        expanded = width * 4
        self.ff = nn.Sequential(
            nn.Linear(width, expanded),
            nn.GELU(),
            nn.Linear(expanded, width),
        )

    @staticmethod
    def _channels_first(module, h):
        """Run a ``[B, C, T]`` module on a ``[B, T, C]`` tensor and transpose back."""
        return module(h.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        x = x + self._channels_first(self.local, self.ln1(x))
        if self.use_global:
            x = x + self._channels_first(self.global_conv, self.ln2(x))
        return x + self.ff(self.ln3(x))
# ============================================================
# Base GCLM Model
# ============================================================
class GCLMModel(PreTrainedModel):
    """GCLM backbone: embeddings, a stack of ``GCLMBlock`` layers, final norm.

    Token embeddings and learned absolute position embeddings are summed.
    The result goes through ``config.n_layers`` blocks and a final LayerNorm.
    Layer ``i`` includes the global FFT convolution when
    ``i % config.use_global_every_n_layers == 0``, so layer 0 always does.
    """

    # NOTE(review): left unset here. Presumably the config class is bound
    # through the hub auto_map / AutoConfig registration; confirm before
    # calling GCLMModel.from_pretrained directly.
    config_class = None
    base_model_prefix = "gclm"

    def __init__(self, config):
        super().__init__(config)
        self.emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos = nn.Embedding(config.max_seq_len, config.d_model)
        self.layers = nn.ModuleList([
            GCLMBlock(
                config,
                use_global=(i % config.use_global_every_n_layers == 0),
            )
            for i in range(config.n_layers)
        ])
        self.ln = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.post_init()

    def forward(self, input_ids):
        """Encode a batch of token ids.

        Args:
            input_ids: integer token ids of shape ``[B, T]``.

        Returns:
            Hidden states of shape ``[B, T, d_model]``.

        Raises:
            ValueError: if ``T`` exceeds the position-embedding table size
                (``config.max_seq_len``). Without this check the embedding
                lookup fails with an opaque index error, or a device-side
                assert on CUDA.
        """
        _, T = input_ids.shape
        max_len = self.pos.num_embeddings
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len ({max_len})"
            )
        positions = torch.arange(T, device=input_ids.device)
        # [T, d_model] position embeddings broadcast over the batch dimension.
        h = self.emb(input_ids) + self.pos(positions)
        for layer in self.layers:
            h = layer(h)
        return self.ln(h)
# ============================================================
# Causal LM Head
# ============================================================
class GCLMForCausalLM(PreTrainedModel):
    """GCLM backbone with a linear language-modeling head.

    ``lm_head`` is a separate, bias-free projection from ``d_model`` to
    ``vocab_size``. It is not tied to the input embedding.
    """

    # NOTE(review): left unset, as on GCLMModel. Presumably bound via
    # AutoConfig / auto_map registration; confirm.
    config_class = None
    base_model_prefix = "gclm"

    def __init__(self, config):
        super().__init__(config)
        self.gclm = GCLMModel(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def forward(
        self,
        input_ids,
        labels=None,
        **kwargs
    ):
        """Compute logits and, optionally, the cross-entropy LM loss.

        Args:
            input_ids: token ids of shape ``[B, T]``.
            labels: optional targets of shape ``[B, T]``, aligned position by
                position with ``logits``: ``labels[b, t]`` is the target for
                ``logits[b, t]``. Unlike many Hugging Face causal-LM heads,
                labels are NOT shifted here. Callers must pass pre-shifted
                targets, e.g. ``input_ids=ids[:, :-1]`` and
                ``labels=ids[:, 1:]``. Positions equal to ``-100`` are ignored.
            **kwargs: accepted for API compatibility and ignored (for example
                ``attention_mask``). No masking is applied.

        Returns:
            ``CausalLMOutput`` with ``logits`` of shape ``[B, T, vocab_size]``
            and ``loss``, which is ``None`` when ``labels`` is ``None``.
        """
        hidden = self.gclm(input_ids)
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # reshape (not view): pre-shifted labels such as ids[:, 1:] are
            # non-contiguous, and .view(-1) raises on them.
            # The loss is computed in float32 for stability with half-precision
            # weights.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                labels.reshape(-1),
                ignore_index=-100,
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )