"""RoFormer model with projection head for classification. This module provides a RoFormer-based model with a projection head for contrastive learning, enabling both classification and embedding-based similarity search for file type detection. """ from typing import Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers import RoFormerModel, RoFormerPreTrainedModel from transformers.modeling_outputs import SequenceClassifierOutput try: from .configuration_roformer_classification import RoFormerClassificationConfig except ImportError: from configuration_roformer_classification import RoFormerClassificationConfig class RoFormerForSequenceClassificationWithProjection(RoFormerPreTrainedModel): """RoFormer with projection head for file type classification. This model extends RoFormer with a projection head that produces L2-normalized embeddings suitable for both classification and similarity search. The architecture is: RoFormer -> CLS pooling -> Projection -> L2 Norm -> Classifier The projection head enables contrastive learning and produces embeddings for similarity-based file type matching. """ config_class = RoFormerClassificationConfig def __init__(self, config: RoFormerClassificationConfig): super().__init__(config) self.num_labels = config.num_labels self.projection_dim = getattr(config, "projection_dim", 256) self.roformer = RoFormerModel(config) # Projection head for contrastive learning embeddings self.projection = nn.Sequential( nn.Linear(config.hidden_size, config.hidden_size), nn.ReLU(), nn.Linear(config.hidden_size, self.projection_dim), ) # Classifier on pooled output (hidden_size, not projection_dim) # This architecture uses hidden representation for classification # while projection is for embedding similarity search self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.post_init() def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutput]: """Forward pass for classification. 

        Args:
            input_ids: Input token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)
            head_mask: Head mask for attention (optional)
            inputs_embeds: Input embeddings (optional, alternative to input_ids)
            labels: Labels for computing loss [batch_size]
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return hidden states
            return_dict: Whether to return a SequenceClassifierOutput

        Returns:
            SequenceClassifierOutput with loss, logits, and optional hidden states
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pool using the CLS token (first position of the last hidden state)
        sequence_output = outputs[0]
        pooled_output = sequence_output[:, 0, :]

        # Classify from the pooled output directly
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        if not return_dict:
            # RoFormerModel has no pooler, so its tuple output is
            # (sequence_output, hidden_states?, attentions?); keep everything
            # after the sequence output.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def get_embeddings(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Get normalized projection embeddings for similarity search.

        Args:
            input_ids: Input token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)

        Returns:
            L2-normalized embeddings [batch_size, projection_dim]
        """
        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        pooled_output = outputs.last_hidden_state[:, 0, :]
        projections = self.projection(pooled_output)
        return F.normalize(projections, p=2, dim=1)
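

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of running the model for both classification and
# embedding-based similarity search. The checkpoint path "path/to/checkpoint",
# the use of AutoTokenizer, and the sample strings are assumptions for
# illustration; substitute whatever tokenizer and weights the project ships.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/checkpoint"  # hypothetical local or Hub checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = RoFormerForSequenceClassificationWithProjection.from_pretrained(checkpoint)
    model.eval()

    # Placeholder inputs standing in for tokenized file contents.
    samples = ["%PDF-1.7 ...", "PK\x03\x04 ..."]
    inputs = tokenizer(samples, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        # Classification: argmax over logits gives the predicted label id.
        logits = model(**inputs).logits
        predictions = logits.argmax(dim=-1)

        # Similarity search: embeddings are L2-normalized, so cosine
        # similarity reduces to a dot product between embedding vectors.
        embeddings = model.get_embeddings(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
        similarity = embeddings @ embeddings.T

    print(predictions.tolist())
    print(similarity)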