"""RoFormer model with projection head for classification.
This module provides a RoFormer-based model with a projection head for
contrastive learning, enabling both classification and embedding-based
similarity search for file type detection.
"""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RoFormerModel, RoFormerPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
try:
from .configuration_roformer_classification import RoFormerClassificationConfig
except ImportError:
from configuration_roformer_classification import RoFormerClassificationConfig


class RoFormerForSequenceClassificationWithProjection(RoFormerPreTrainedModel):
    """RoFormer with projection head for file type classification.

    This model extends RoFormer with a projection head that produces
    L2-normalized embeddings suitable for similarity search, alongside a
    linear classifier over the pooled CLS representation. The two paths are:

        Classification: RoFormer -> CLS pooling -> Classifier
        Embeddings:     RoFormer -> CLS pooling -> Projection -> L2 Norm

    The projection head enables contrastive learning and produces
    embeddings for similarity-based file type matching.
    """

    config_class = RoFormerClassificationConfig

    def __init__(self, config: RoFormerClassificationConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.projection_dim = getattr(config, "projection_dim", 256)

        self.roformer = RoFormerModel(config)

        # Projection head for contrastive learning embeddings
        self.projection = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, self.projection_dim),
        )

        # Classifier on the pooled output (hidden_size, not projection_dim):
        # the hidden representation is used for classification, while the
        # projection is reserved for embedding similarity search.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutput]:
        """Forward pass for classification.

        Args:
            input_ids: Input token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)
            head_mask: Head mask for attention (optional)
            inputs_embeds: Input embeddings (optional, alternative to input_ids)
            labels: Labels for computing loss [batch_size]
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return hidden states
            return_dict: Whether to return a SequenceClassifierOutput

        Returns:
            SequenceClassifierOutput with loss, logits, and optional hidden states
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pool using CLS token
        sequence_output = outputs[0]
        pooled_output = sequence_output[:, 0, :]

        # Classify from pooled output directly
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        if not return_dict:
            # RoFormerModel has no pooler, so the extra outputs (hidden states,
            # attentions) start at index 1 of the base model's tuple output.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def get_embeddings(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Get normalized projection embeddings for similarity search.

        Args:
            input_ids: Input token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)

        Returns:
            L2-normalized embeddings [batch_size, projection_dim]
        """
        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )

        # CLS pooling, then project and L2-normalize so that dot products
        # between embeddings correspond to cosine similarity.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        projections = self.projection(pooled_output)
        return F.normalize(projections, p=2, dim=1)
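

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# It assumes RoFormerClassificationConfig accepts the standard RoFormer config
# fields plus `num_labels` and `projection_dim`; the sizes below are made-up
# toy values, and in practice the model and its tokenizer would be loaded with
# `from_pretrained`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = RoFormerClassificationConfig(
        vocab_size=1024,           # hypothetical vocabulary size
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=512,
        num_labels=8,              # hypothetical number of file types
        projection_dim=64,
    )
    model = RoFormerForSequenceClassificationWithProjection(config)
    model.eval()

    # Random token IDs standing in for tokenized file contents.
    input_ids = torch.randint(0, config.vocab_size, (2, 32))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        # Classification path: logits over file-type labels.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predicted = outputs.logits.argmax(dim=-1)

        # Embedding path: L2-normalized vectors, so the dot product between
        # two embeddings is their cosine similarity.
        embeddings = model.get_embeddings(input_ids=input_ids, attention_mask=attention_mask)
        similarity = embeddings @ embeddings.T

    print("predicted labels:", predicted.tolist())
    print("pairwise similarity matrix shape:", tuple(similarity.shape))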