File size: 5,504 Bytes
9d1a15a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""RoFormer model with projection head for classification.

This module provides a RoFormer-based model with a projection head for
contrastive learning, enabling both classification and embedding-based
similarity search for file type detection.
"""

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RoFormerModel, RoFormerPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

try:
    from .configuration_roformer_classification import RoFormerClassificationConfig
except ImportError:
    from configuration_roformer_classification import RoFormerClassificationConfig


class RoFormerForSequenceClassificationWithProjection(RoFormerPreTrainedModel):
    """RoFormer with a classification head and a separate projection head.

    Both heads read the same pooled representation (the hidden state of the
    first token), but they are independent branches:

        RoFormer -> CLS pooling -+-> Classifier                 (logits)
                                 +-> Projection -> L2 Norm      (embeddings)

    ``forward`` returns classification logits computed directly from the
    pooled hidden state. ``get_embeddings`` returns L2-normalized projection
    vectors intended for similarity search; the projection head is not used
    by ``forward``.
    """

    config_class = RoFormerClassificationConfig

    def __init__(self, config: RoFormerClassificationConfig):
        """Build the encoder, the projection head and the classifier.

        Args:
            config: Model configuration. ``hidden_size`` and ``num_labels``
                are read directly; ``projection_dim`` is optional and falls
                back to 256 when the config does not define it.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.projection_dim = getattr(config, "projection_dim", 256)

        self.roformer = RoFormerModel(config)

        # Projection head (two-layer MLP) producing the embedding vectors
        # returned by get_embeddings().
        self.projection = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, self.projection_dim),
        )

        # The classifier consumes the pooled hidden state (hidden_size), not
        # the projected embedding (projection_dim).
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutput]:
        """Forward pass for classification.

        Args:
            input_ids: Input token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)
            head_mask: Head mask for attention (optional)
            inputs_embeds: Input embeddings (optional, alternative to input_ids)
            labels: Labels for computing loss [batch_size]
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return hidden states
            return_dict: Whether to return a SequenceClassifierOutput;
                defaults to ``config.use_return_dict`` when None

        Returns:
            SequenceClassifierOutput with loss, logits, and optional hidden
            states/attentions when ``return_dict`` is true. Otherwise a tuple
            ``(loss, logits, *extras)`` with ``loss`` omitted when no labels
            were given, where ``extras`` are all encoder outputs after the
            last hidden state.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pool using the first (CLS) token of the last hidden state.
        sequence_output = outputs[0]
        pooled_output = sequence_output[:, 0, :]

        # Classify from the pooled output directly (projection is not involved).
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        if not return_dict:
            # RoFormerModel has no pooler output, so every element after
            # index 0 is an extra (hidden states, attentions, ...). Slicing
            # from 2 would silently drop the first of them.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def get_embeddings(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Get normalized projection embeddings for similarity search.

        Args:
            input_ids: Input token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)
            inputs_embeds: Input embeddings (optional, alternative to input_ids)

        Returns:
            L2-normalized embeddings [batch_size, projection_dim]
        """
        outputs = self.roformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
        )

        # Same CLS pooling as forward(), then project and normalize each row
        # to unit L2 norm.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        projections = self.projection(pooled_output)
        return F.normalize(projections, p=2, dim=1)