Upload modeling_qwen3_jasper.py
modeling_qwen3_jasper.py (CHANGED, +17 -18)
@@ -9,9 +9,9 @@ from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP
 
 class TokenCompressor(nn.Module):
     """
-
-
-
+    Adaptive Token Compression Module
+    For sequences exceeding the threshold length, use adaptive_avg_pool1d for compression
+    Compressed length = threshold + excess_part * compression_ratio
     """
 
     def __init__(self, length_threshold: int = 512, compression_ratio: float = 0.3):
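The target-length formula stated in the new docstring can be checked with a quick sketch, using the default length_threshold=512 and compression_ratio=0.3 (the real_length value below is hypothetical):

length_threshold = 512
compression_ratio = 0.3
real_length = 2048  # hypothetical sequence longer than the threshold

# compressed length = threshold + excess_part * compression_ratio
target_length = int(length_threshold + (real_length - length_threshold) * compression_ratio)
print(target_length)  # 512 + 1536 * 0.3 = 972.8 -> 972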
@@ -23,28 +23,28 @@ class TokenCompressor(nn.Module):
         self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
-
+        Perform adaptive compression on token embeddings
         Args:
             token_embeddings: [batch_size, seq_len, hidden_size]
             attention_mask: [batch_size, seq_len]
         Returns:
-            compressed_embeddings:
-            compressed_mask:
+            compressed_embeddings: Compressed embeddings
+            compressed_mask: Compressed attention mask
         """
         padding_side = 'right' if (attention_mask[:, -1] == 0).any() else 'left'
 
         compressed_embeddings_list = []
         compressed_masks_list = []
         for text_idx in range(token_embeddings.shape[0]):
-            #
+            # Get the effective length of current sample
             real_length = int(attention_mask[text_idx].sum().item())
             if real_length <= self.length_threshold:
-                #
+                # Extract valid token embeddings based on padding direction
                 if padding_side == 'left':
-                    #
+                    # Left padding: valid tokens are on the right
                     valid_embeddings = token_embeddings[text_idx:text_idx + 1, -real_length:, :]
                 else:
-                    #
+                    # Right padding: valid tokens are on the left
                     valid_embeddings = token_embeddings[text_idx:text_idx + 1, :real_length, :]
                 compressed_embeddings_list.append(valid_embeddings)
                 compressed_masks_list.append([1] * real_length)
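For readers unfamiliar with the padding handling in the hunk above, here is a minimal, self-contained sketch (the 2-sample batch and shapes are hypothetical): with left padding the valid tokens sit at the end of the sequence, so slicing with [-real_length:] keeps exactly the non-padded positions.

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
token_embeddings = torch.randn(2, 5, 8)

# A batch is treated as left-padded unless some sequence ends in padding.
padding_side = 'right' if (attention_mask[:, -1] == 0).any() else 'left'  # 'left' here

real_length = int(attention_mask[0].sum().item())  # 3 valid tokens in sample 0
if padding_side == 'left':
    valid = token_embeddings[0:1, -real_length:, :]  # valid tokens are at the end
else:
    valid = token_embeddings[0:1, :real_length, :]   # valid tokens are at the start

print(padding_side, valid.shape)  # left torch.Size([1, 3, 8])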
@@ -52,15 +52,15 @@ class TokenCompressor(nn.Module):
                 target_length = int(
                     self.length_threshold + (real_length - self.length_threshold) * self.compression_ratio
                 )
-                #
+                # Extract valid token embeddings based on padding direction
                 if padding_side == 'left':
-                    #
+                    # Left padding: valid tokens are on the right
                     valid_embeddings = token_embeddings[text_idx:text_idx + 1, -real_length:, :]
                 else:
-                    #
+                    # Right padding: valid tokens are on the left
                     valid_embeddings = token_embeddings[text_idx:text_idx + 1, :real_length, :]
 
-                #
+                # Use adaptive_avg_pool1d for compression
                 compressed_embeddings_list.append(
                     F.adaptive_avg_pool1d(
                         valid_embeddings.transpose(1, 2), target_length
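A quick sketch of the pooling call in the hunk above: F.adaptive_avg_pool1d pools over the last dimension, so the sequence axis is moved into that position first. The transpose back shown here is an assumption, since the closing lines of the append call fall outside this hunk; shapes are hypothetical.

import torch
import torch.nn.functional as F

valid_embeddings = torch.randn(1, 2048, 64)    # [1, real_length, hidden_size]
target_length = 972

compressed = F.adaptive_avg_pool1d(
    valid_embeddings.transpose(1, 2),           # [1, hidden_size, real_length]
    target_length,
).transpose(1, 2)                               # assumed: back to [1, target_length, hidden_size]

print(compressed.shape)  # torch.Size([1, 972, 64])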
@@ -69,7 +69,7 @@ class TokenCompressor(nn.Module):
                 # print("valid_embeddings.shape,target_length,compressed_embeddings_list[-1].shape",valid_embeddings.shape,target_length,compressed_embeddings_list[-1].shape)
                 compressed_masks_list.append([1] * target_length)
 
-        #
+        # Reassemble token_embeddings and attention_mask
         new_seq_len = max((len(_mask) for _mask in compressed_masks_list))
         new_attention_mask = torch.tensor(
             [
@@ -83,7 +83,7 @@ class TokenCompressor(nn.Module):
             device=token_embeddings.device
         )
 
-        #
+        # Generate new token_embeddings
         batch_size = token_embeddings.shape[0]
         hidden_size = token_embeddings.shape[2]
         new_token_embeddings = torch.zeros(
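The two hunks above begin the re-batching of the per-sample results. The lines in between are not part of this diff, so the following is only an illustrative sketch of that step, with names, padding side, and shapes all assumed: pad every mask to the longest compressed length and copy each compressed sample into a zero-initialized batch tensor.

import torch

compressed_masks_list = [[1] * 512, [1] * 972]                      # two samples of different lengths
compressed_embeddings_list = [torch.randn(1, 512, 64), torch.randn(1, 972, 64)]

new_seq_len = max(len(m) for m in compressed_masks_list)            # 972
new_attention_mask = torch.tensor(
    [m + [0] * (new_seq_len - len(m)) for m in compressed_masks_list]
)
new_token_embeddings = torch.zeros(len(compressed_masks_list), new_seq_len, 64)
for i, emb in enumerate(compressed_embeddings_list):
    new_token_embeddings[i, :emb.shape[1], :] = emb[0]              # right-padding assumed here

print(new_attention_mask.shape, new_token_embeddings.shape)
# torch.Size([2, 972]) torch.Size([2, 972, 64])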
@@ -103,7 +103,6 @@ class TokenCompressor(nn.Module):
         return new_token_embeddings, new_attention_mask
 
 
-
 class JasperV2Encoder(Qwen3PreTrainedModel):
 
     def __init__(self, config: Qwen3Config):
@@ -134,7 +133,7 @@ class JasperV2Encoder(Qwen3PreTrainedModel):
             inputs_embeds=compressed_token_embeddings, attention_mask=attention_mask
         )["last_hidden_state"]
 
-        #
+        # Generate sentence vector
        input_mask_expanded = (
             attention_mask.unsqueeze(-1).expand(compressed_token_embeddings.size()).to(
                 compressed_token_embeddings.dtype)
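The expanded mask in the last hunk is the usual ingredient of masked mean pooling over the encoder output. The lines that follow are not shown in this diff, so the reduction below is only a sketch of the most common pattern, with hypothetical shapes:

import torch

last_hidden_state = torch.randn(2, 972, 64)            # encoder output
attention_mask = torch.ones(2, 972, dtype=torch.long)

input_mask_expanded = (
    attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).to(last_hidden_state.dtype)
)
# Masked mean over the sequence dimension -> one vector per input text.
sentence_embeddings = (last_hidden_state * input_mask_expanded).sum(dim=1) / \
    input_mask_expanded.sum(dim=1).clamp(min=1e-9)

print(sentence_embeddings.shape)  # torch.Size([2, 64])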