infgrad committed
Commit 06a100f · verified · 1 Parent(s): f47c8f7

Upload modeling_qwen3_jasper.py

Files changed (1)
  1. modeling_qwen3_jasper.py +17 -18
modeling_qwen3_jasper.py CHANGED
@@ -9,9 +9,9 @@ from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP
 
 class TokenCompressor(nn.Module):
     """
-    自适应Token压缩模块
-    对于长度超过阈值的序列,使用adaptive_avg_pool1d进行压缩
-    压缩后长度 = 阈值 + 超出部分 * compression_ratio
+    Adaptive Token Compression Module
+    For sequences exceeding the threshold length, use adaptive_avg_pool1d for compression
+    Compressed length = threshold + excess_part * compression_ratio
     """
 
     def __init__(self, length_threshold: int = 512, compression_ratio: float = 0.3):
@@ -23,28 +23,28 @@ class TokenCompressor(nn.Module):
             self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
-        对token embeddings进行自适应压缩
+        Perform adaptive compression on token embeddings
         Args:
             token_embeddings: [batch_size, seq_len, hidden_size]
             attention_mask: [batch_size, seq_len]
         Returns:
-            compressed_embeddings: 压缩后的embeddings
-            compressed_mask: 压缩后的attention mask
+            compressed_embeddings: Compressed embeddings
+            compressed_mask: Compressed attention mask
         """
         padding_side = 'right' if (attention_mask[:, -1] == 0).any() else 'left'
 
         compressed_embeddings_list = []
         compressed_masks_list = []
         for text_idx in range(token_embeddings.shape[0]):
-            # 获取当前样本的有效长度
+            # Get the effective length of current sample
             real_length = int(attention_mask[text_idx].sum().item())
             if real_length <= self.length_threshold:
-                # 根据padding方向提取有效的token embeddings
+                # Extract valid token embeddings based on padding direction
                 if padding_side == 'left':
-                    # 左填充:有效tokens在右边
+                    # Left padding: valid tokens are on the right
                     valid_embeddings = token_embeddings[text_idx:text_idx + 1, -real_length:, :]
                 else:
-                    # 右填充:有效tokens在左边
+                    # Right padding: valid tokens are on the left
                     valid_embeddings = token_embeddings[text_idx:text_idx + 1, :real_length, :]
                 compressed_embeddings_list.append(valid_embeddings)
                 compressed_masks_list.append([1] * real_length)
@@ -52,15 +52,15 @@ class TokenCompressor(nn.Module):
                 target_length = int(
                     self.length_threshold + (real_length - self.length_threshold) * self.compression_ratio
                 )
-                # 根据padding方向提取有效的token embeddings
+                # Extract valid token embeddings based on padding direction
                 if padding_side == 'left':
-                    # 左填充:有效tokens在右边
+                    # Left padding: valid tokens are on the right
                     valid_embeddings = token_embeddings[text_idx:text_idx + 1, -real_length:, :]
                 else:
-                    # 右填充:有效tokens在左边
+                    # Right padding: valid tokens are on the left
                     valid_embeddings = token_embeddings[text_idx:text_idx + 1, :real_length, :]
 
-                # 使用adaptive_avg_pool1d进行压缩
+                # Use adaptive_avg_pool1d for compression
                 compressed_embeddings_list.append(
                     F.adaptive_avg_pool1d(
                         valid_embeddings.transpose(1, 2), target_length
@@ -69,7 +69,7 @@ class TokenCompressor(nn.Module):
                 # print("valid_embeddings.shape,target_length,compressed_embeddings_list[-1].shape",valid_embeddings.shape,target_length,compressed_embeddings_list[-1].shape)
                 compressed_masks_list.append([1] * target_length)
 
-        # 重新组合为token_embeddings和attention_mask
+        # Reassemble token_embeddings and attention_mask
         new_seq_len = max((len(_mask) for _mask in compressed_masks_list))
         new_attention_mask = torch.tensor(
             [
@@ -83,7 +83,7 @@ class TokenCompressor(nn.Module):
                 device=token_embeddings.device
             )
 
-        # 生成新的token_embeddings
+        # Generate new token_embeddings
         batch_size = token_embeddings.shape[0]
         hidden_size = token_embeddings.shape[2]
         new_token_embeddings = torch.zeros(
@@ -103,7 +103,6 @@ class TokenCompressor(nn.Module):
         return new_token_embeddings, new_attention_mask
 
 
-
 class JasperV2Encoder(Qwen3PreTrainedModel):
 
     def __init__(self, config: Qwen3Config):
@@ -134,7 +133,7 @@ class JasperV2Encoder(Qwen3PreTrainedModel):
             inputs_embeds=compressed_token_embeddings, attention_mask=attention_mask
         )["last_hidden_state"]
 
-        # 生成句向量
+        # Generate sentence vector
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(compressed_token_embeddings.size()).to(
                compressed_token_embeddings.dtype)
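
For reference (not part of the commit): a minimal sketch of the compression rule stated in the TokenCompressor docstring above, using the default length_threshold=512 and compression_ratio=0.3. Only the length arithmetic and the F.adaptive_avg_pool1d call visible in the diff are reproduced; the full module lives in the uploaded modeling_qwen3_jasper.py, and the tensor sizes below are made-up illustration values.

import torch
import torch.nn.functional as F

length_threshold = 512    # default from TokenCompressor.__init__
compression_ratio = 0.3   # default from TokenCompressor.__init__

# Compressed length = threshold + excess_part * compression_ratio
for real_length in (300, 512, 1000, 4096):
    if real_length <= length_threshold:
        target_length = real_length  # short sequences pass through uncompressed
    else:
        target_length = int(length_threshold + (real_length - length_threshold) * compression_ratio)
    print(real_length, "->", target_length)   # 300 -> 300, 512 -> 512, 1000 -> 658, 4096 -> 1587

# The per-sample pooling step shown in the diff: pool over the sequence axis
# after moving hidden_size into the channel dimension expected by adaptive_avg_pool1d.
valid_embeddings = torch.randn(1, 1000, 64)   # [1, real_length, hidden_size], dummy values
pooled = F.adaptive_avg_pool1d(valid_embeddings.transpose(1, 2), 658)
print(pooled.shape)                           # torch.Size([1, 64, 658])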