import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

# ============================================================
# Configuration class is assumed to live in configuration_gclm.py
# ============================================================
# Expected fields in GCLMConfig:
# - vocab_size
# - d_model
# - n_layers
# - max_seq_len
# - local_kernel_size
# - global_kernel_size
# - fft_size
# - use_global_every_n_layers
# - layer_norm_eps


# ============================================================
# Global FFT Convolution
# ============================================================
class GlobalConv1D(nn.Module):
    """Causal long convolution, one learned kernel per channel, computed
    blockwise in the frequency domain (overlap-save)."""

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        # fft_size must exceed kernel_size so each block yields at least
        # one valid output sample.
        self.fft_size = fft_size

    def forward(self, x):
        # x: [B, C, T]
        B, C, T = x.shape
        K = min(self.kernel_size, T)
        overlap = K - 1
        block = self.fft_size - overlap

        # Left-pad so the convolution is causal: output t sees x[t-K+1 .. t].
        x = F.pad(x, (overlap, 0))

        # Zero-pad the kernel to the FFT length and transform it once.
        k = self.kernel[:, :K]
        k = F.pad(k, (0, self.fft_size - K))
        k_f = torch.fft.rfft(k, n=self.fft_size)

        # Overlap-save: each iteration produces `block` new output samples,
        # keeping only positions where circular and linear convolution agree.
        outs = []
        pos = 0
        while pos < T:
            seg = x[..., pos:pos + self.fft_size]
            if seg.shape[-1] < self.fft_size:
                seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))
            y = torch.fft.irfft(
                torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0),
                n=self.fft_size,
            )
            outs.append(y[..., overlap:overlap + block])
            pos += block
        return torch.cat(outs, dim=-1)[..., :T]


# ============================================================
# Local Convolution
# ============================================================
class LocalConv1D(nn.Module):
    """Causal depthwise convolution followed by a pointwise projection."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        # x: [B, C, T]; left-pad so the depthwise convolution stays causal.
        x = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(x)))


# ============================================================
# GCLM Block
# ============================================================
class GCLMBlock(nn.Module):
    """Pre-norm residual block: local conv, optional global FFT conv,
    then a feed-forward sublayer."""

    def __init__(self, config, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.local = LocalConv1D(config.d_model, config.local_kernel_size)
        if use_global:
            self.ln2 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
            self.global_conv = GlobalConv1D(
                config.d_model,
                config.global_kernel_size,
                config.fft_size,
            )
        self.ln3 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.ff = nn.Sequential(
            nn.Linear(config.d_model, config.d_model * 4),
            nn.GELU(),
            nn.Linear(config.d_model * 4, config.d_model),
        )

    def forward(self, x):
        # x: [B, T, C]; the convolution sublayers operate on [B, C, T].
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1, 2)).transpose(1, 2)
        return x + self.ff(self.ln3(x))
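
# ============================================================
# Sanity check for GlobalConv1D (illustrative sketch, not used by the model)
# ============================================================
# The overlap-save path above should match a direct causal depthwise
# convolution. The helper name and the test sizes below are arbitrary
# illustrative additions, not part of the GCLM architecture.
def _check_global_conv_equivalence():
    torch.manual_seed(0)
    d_model, kernel_size, fft_size = 8, 16, 64
    conv = GlobalConv1D(d_model, kernel_size, fft_size)
    x = torch.randn(2, d_model, 100)  # [B, C, T]
    y_fft = conv(x)
    # Reference: F.conv1d computes cross-correlation, so flip the kernel to
    # reproduce the convolution computed in the frequency domain.
    w = conv.kernel.flip(-1).unsqueeze(1)  # [C, 1, K]
    y_ref = F.conv1d(F.pad(x, (kernel_size - 1, 0)), w, groups=d_model)
    assert torch.allclose(y_fft, y_ref, atol=1e-4)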
# ============================================================
# Base GCLM Model
# ============================================================
class GCLMModel(PreTrainedModel):
    config_class = None  # set externally, e.g. when registering with AutoConfig/AutoModel
    base_model_prefix = "gclm"

    def __init__(self, config):
        super().__init__(config)
        self.emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos = nn.Embedding(config.max_seq_len, config.d_model)
        self.layers = nn.ModuleList([
            GCLMBlock(
                config,
                use_global=(i % config.use_global_every_n_layers == 0),
            )
            for i in range(config.n_layers)
        ])
        self.ln = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.post_init()

    def forward(self, input_ids):
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device)
        h = self.emb(input_ids) + self.pos(pos)
        for layer in self.layers:
            h = layer(h)
        return self.ln(h)


# ============================================================
# Causal LM Head
# ============================================================
class GCLMForCausalLM(PreTrainedModel):
    config_class = None  # set externally, e.g. when registering with AutoConfig/AutoModel
    base_model_prefix = "gclm"

    def __init__(self, config):
        super().__init__(config)
        self.gclm = GCLMModel(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids, labels=None, **kwargs):
        hidden = self.gclm(input_ids)
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # Standard causal LM objective: the logits at position t predict
            # the token at position t + 1, so shift both by one.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutput(loss=loss, logits=logits)
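
# ============================================================
# Example usage (illustrative sketch)
# ============================================================
# Assumes GCLMConfig lives in configuration_gclm.py with the fields listed
# at the top of this file; the import path and all concrete values below
# are placeholder assumptions, not part of the model definition.
if __name__ == "__main__":
    from configuration_gclm import GCLMConfig  # assumed module / class name

    config = GCLMConfig(
        vocab_size=32000,
        d_model=256,
        n_layers=4,
        max_seq_len=1024,
        local_kernel_size=4,
        global_kernel_size=256,
        fft_size=512,
        use_global_every_n_layers=2,
        layer_norm_eps=1e-5,
    )
    model = GCLMForCausalLM(config)

    # Quick forward pass with the labels reused as targets; the internal
    # shift handles next-token alignment.
    input_ids = torch.randint(0, config.vocab_size, (2, 128))
    out = model(input_ids, labels=input_ids)
    print(out.logits.shape, out.loss.item())

    # To use the Auto classes, register the pair first, e.g.:
    #   AutoConfig.register("gclm", GCLMConfig)
    #   AutoModelForCausalLM.register(GCLMConfig, GCLMForCausalLM)
    # (the "gclm" model_type string is an assumption here).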