Commit ·
ad95cec
1
Parent(s): 47b17c7
Upload 7 files
Browse files- README.md +69 -3
- modeling_rwkv5.py +56 -37
- tokenization_rwkv_world.py +142 -12
README.md
CHANGED
|
@@ -85,7 +85,7 @@ Assistant:"""
|
|
| 85 |
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True, torch_dtype=torch.float16).to(0)
|
| 86 |
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True)
|
| 87 |
|
| 88 |
-
text = "
|
| 89 |
prompt = generate_prompt(text)
|
| 90 |
|
| 91 |
inputs = tokenizer(prompt, return_tensors="pt").to(0)
|
|
@@ -100,8 +100,74 @@ User: hi
|
|
| 100 |
|
| 101 |
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
|
| 102 |
|
| 103 |
-
User:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
Assistant: 乌兰察布市是中国新疆维吾尔自治区的一个地级市,位于新疆维吾尔自治区西南部,毗邻青海省。乌兰察布市是新疆维吾尔自治区的重要城市之一,也是新疆维吾尔自治区的第二大城市。乌兰察布市是新疆的重要经济中心之一,拥有丰富的自然资源和人口密度,是新疆的重要交通枢纽和商
|
| 106 |
```
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True, torch_dtype=torch.float16).to(0)
|
| 86 |
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True)
|
| 87 |
|
| 88 |
+
text = "介绍一下大熊猫"
|
| 89 |
prompt = generate_prompt(text)
|
| 90 |
|
| 91 |
inputs = tokenizer(prompt, return_tensors="pt").to(0)
|
|
|
|
| 100 |
|
| 101 |
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
|
| 102 |
|
| 103 |
+
User: 介绍一下大熊猫
|
| 104 |
+
|
| 105 |
+
Assistant: 大熊猫是一种中国特有的哺乳动物,也是中国的国宝之一。它们的外貌特征是圆形的黑白相间的身体,有着黑色的毛发和白色的耳朵。大熊猫的食物主要是竹子,它们会在竹林中寻找竹子,并且会将竹子放在竹笼中进行储存。大熊猫的寿命约为20至30年,但由于栖息地的丧失和人类活动的
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
#### Batch Inference
|
| 109 |
+
|
| 110 |
+
```python
|
| 111 |
+
import torch
|
| 112 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 113 |
+
|
| 114 |
+
def generate_prompt(instruction, input=""):
|
| 115 |
+
instruction = instruction.strip().replace('\r\n', '\n').replace('\n\n', '\n')
|
| 116 |
+
input = input.strip().replace('\r\n', '\n').replace('\n\n', '\n')
|
| 117 |
+
if input:
|
| 118 |
+
return f"""Instruction: {instruction}
|
| 119 |
+
|
| 120 |
+
Input: {input}
|
| 121 |
+
|
| 122 |
+
Response:"""
|
| 123 |
+
else:
|
| 124 |
+
return f"""User: hi
|
| 125 |
+
|
| 126 |
+
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
|
| 127 |
+
|
| 128 |
+
User: {instruction}
|
| 129 |
+
|
| 130 |
+
Assistant:"""
|
| 131 |
+
|
| 132 |
+
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True).to(torch.float32)
|
| 133 |
+
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True)
|
| 134 |
+
|
| 135 |
+
texts = ["请介绍北京的旅游景点", "介绍一下大熊猫", "乌兰察布"]
|
| 136 |
+
prompts = [generate_prompt(text) for text in texts]
|
| 137 |
+
|
| 138 |
+
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
|
| 139 |
+
outputs = model.generate(inputs["input_ids"], max_new_tokens=128, do_sample=True, temperature=1.0, top_p=0.3, top_k=0, )
|
| 140 |
+
|
| 141 |
+
for output in outputs:
|
| 142 |
+
print(tokenizer.decode(output.tolist(), skip_special_tokens=True))
|
| 143 |
|
|
|
|
| 144 |
```
|
| 145 |
|
| 146 |
+
output:
|
| 147 |
+
|
| 148 |
+
```shell
|
| 149 |
+
User: hi
|
| 150 |
+
|
| 151 |
+
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
|
| 152 |
+
|
| 153 |
+
User: 请介绍北京的旅游景点
|
| 154 |
+
|
| 155 |
+
Assistant: 北京是中国的首都,拥有丰富的旅游资源和历史文化遗产。以下是一些北京的旅游景点:
|
| 156 |
+
1. 故宫:位于北京市中心,是明清两代的皇宫,是中国最大的古代宫殿建筑群之一。
|
| 157 |
+
2. 天安门广场:位于北京市中心,是中国最著名的城市广场之一,也是中国最大的城市广场。
|
| 158 |
+
3. 颐和
|
| 159 |
+
User: hi
|
| 160 |
+
|
| 161 |
+
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
|
| 162 |
+
|
| 163 |
+
User: 介绍一下大熊猫
|
| 164 |
+
|
| 165 |
+
Assistant: 大熊猫是一种生活在中国中部地区的哺乳动物,也是中国的国宝之一。它们的外貌特征是圆形的黑白相间的身体,有着黑色的毛发和圆圆的眼睛。大熊猫是一种濒危物种,目前只有在野外的几个保护区才能看到它们的身影。大熊猫的食物主要是竹子,它们会在竹子上寻找食物,并且可以通
|
| 166 |
+
User: hi
|
| 167 |
+
|
| 168 |
+
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
|
| 169 |
+
|
| 170 |
+
User: 乌兰察布
|
| 171 |
+
|
| 172 |
+
Assistant: 乌兰察布是中国新疆维吾尔自治区的一个县级市,位于新疆维吾尔自治区中部,是新疆的第二大城市。乌兰察布市是新疆的第一大城市,也是新疆的重要城市之一。乌兰察布市是新疆的经济中心,也是新疆的重要交通枢纽之一。乌兰察布市的人口约为2.5万人,其中汉族占绝大多数。乌
|
| 173 |
+
```
|
modeling_rwkv5.py
CHANGED
|
@@ -85,33 +85,46 @@ def rwkv_linear_attention_v5_0(H, S, T, hidden, time_decay, time_first, receptan
|
|
| 85 |
|
| 86 |
return out, state
|
| 87 |
|
| 88 |
-
|
|
|
|
|
|
|
| 89 |
time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1,1,1).reshape(n_head, -1, 1)
|
| 90 |
time_first = time_first.float().reshape(-1,1,1).reshape(n_head, -1, 1)
|
| 91 |
lxw = lxw.float()
|
| 92 |
lxb = lxb.float()
|
| 93 |
-
if seq_mode:
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
else:
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
return out, state
|
| 117 |
|
|
@@ -153,7 +166,7 @@ class RwkvSelfAttention(nn.Module):
|
|
| 153 |
self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
|
| 154 |
|
| 155 |
# TODO: maybe jit, otherwise move inside forward
|
| 156 |
-
def extract_key_value(self, H, S, T, hidden, state=None):
|
| 157 |
# Mix hidden with the previous timestep to produce key, value, receptance
|
| 158 |
if hidden.size(1) == 1 and state is not None:
|
| 159 |
shifted = state[0][:, :, self.layer_id]
|
|
@@ -161,25 +174,27 @@ class RwkvSelfAttention(nn.Module):
|
|
| 161 |
shifted = self.time_shift(hidden)
|
| 162 |
if state is not None:
|
| 163 |
shifted[:, 0] = state[0][:, :, self.layer_id]
|
|
|
|
|
|
|
| 164 |
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
|
| 165 |
value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
|
| 166 |
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
|
| 167 |
if self.config.model_version == "5_2":
|
| 168 |
gate = hidden* self.time_mix_gate + shifted * (1 - self.time_mix_gate)
|
| 169 |
|
| 170 |
-
if hidden.size(1) == 1 and state is not None:
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
else:
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
|
| 180 |
if self.config.model_version == "5_2":
|
| 181 |
gate = F.silu(self.gate(gate))
|
| 182 |
-
|
| 183 |
if state is not None:
|
| 184 |
state[0][:, :, self.layer_id] = hidden[:, -1]
|
| 185 |
|
|
@@ -188,17 +203,19 @@ class RwkvSelfAttention(nn.Module):
|
|
| 188 |
return receptance, key, value, state
|
| 189 |
|
| 190 |
def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
|
|
|
|
| 191 |
H = self.time_decay.shape[0]
|
| 192 |
S = hidden.shape[-1] // H
|
| 193 |
T = hidden.shape[1]
|
| 194 |
|
| 195 |
if self.config.model_version == "5_2":
|
| 196 |
-
receptance, key, value, gate, state = self.extract_key_value(H, S, T, hidden, state=state)
|
| 197 |
else:
|
| 198 |
receptance, key, value, state = self.extract_key_value(H, S, T, hidden, state=state)
|
| 199 |
layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
|
| 200 |
if self.config.model_version == "5_2":
|
| 201 |
rwkv, layer_state = rwkv_linear_attention_v5_2(
|
|
|
|
| 202 |
H,
|
| 203 |
S,
|
| 204 |
T,
|
|
@@ -273,6 +290,8 @@ class RwkvFeedForward(nn.Module):
|
|
| 273 |
shifted = self.time_shift(hidden)
|
| 274 |
if state is not None:
|
| 275 |
shifted[:, 0] = state[2][:, :, self.layer_id]
|
|
|
|
|
|
|
| 276 |
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
|
| 277 |
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
|
| 278 |
|
|
@@ -594,7 +613,8 @@ class RwkvModel(RwkvPreTrainedModel):
|
|
| 594 |
|
| 595 |
|
| 596 |
hidden_states = inputs_embeds
|
| 597 |
-
|
|
|
|
| 598 |
all_self_attentions = () if output_attentions else None
|
| 599 |
all_hidden_states = () if output_hidden_states else None
|
| 600 |
for idx, block in enumerate(self.blocks):
|
|
@@ -645,7 +665,6 @@ class RwkvModel(RwkvPreTrainedModel):
|
|
| 645 |
|
| 646 |
self.layers_are_rescaled = not self.training
|
| 647 |
|
| 648 |
-
|
| 649 |
@add_start_docstrings(
|
| 650 |
"""
|
| 651 |
The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
|
|
|
| 85 |
|
| 86 |
return out, state
|
| 87 |
|
| 88 |
+
cnt = 0
|
| 89 |
+
|
| 90 |
+
def rwkv_linear_attention_v5_2(B, H, S, T, n_head, hidden, time_decay, time_first, receptance, key, value, gate, lxw, lxb, ow, state, return_state=False, seq_mode=True):
    """Run the batched RWKV v5.2 linear-attention recurrence over ``T`` steps.

    The recurrence is evaluated one timestep at a time for the whole batch:
    ``out_t = r_t @ (time_first * (k_t @ v_t) + state)`` followed by
    ``state = k_t @ v_t + time_decay * state``. The stacked outputs are then
    group-normalised per head, gated and projected with ``ow``.

    Args:
        B: Batch size.
        H: Number of heads (also the number of group-norm groups).
        S: Per-head size.
        T: Number of timesteps to process.
        n_head: Head count used to reshape ``time_decay`` / ``time_first``
            to ``(n_head, -1, 1)``.
        hidden: Only its ``dtype`` is read; the normalised output is cast to it.
        time_decay: Raw decay parameter; transformed with ``exp(-exp(x))``.
        time_first: Bonus applied to the current step's key/value product.
        receptance: 4-D tensor indexed with time on axis 2
            (expected ``(B, H, T, S)``).
        key: 4-D tensor indexed with time on axis 3 (expected ``(B, H, S, T)``).
        value: 4-D tensor indexed with time on axis 2
            (expected ``(B, H, T, S)``).
        gate: Multiplied element-wise with the normalised ``(B, T, H*S)`` output.
        lxw: Group-norm weight (cast to float32).
        lxb: Group-norm bias (cast to float32).
        ow: Output projection matrix, applied as ``out @ ow``.
        state: Recurrent state broadcastable against ``k_t @ v_t``
            (``(B, H, S, S)``), or ``None`` to start from an all-zero state.
        return_state: Accepted for interface compatibility; currently unused.
        seq_mode: Accepted for interface compatibility; currently unused. The
            step-by-step loop handles ``T == 1`` and ``T > 1`` alike.

    Returns:
        Tuple ``(out, state)`` where ``out`` has shape ``(B, T, H*S)`` and
        ``state`` is the recurrent state after the last timestep.
    """
    time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
    time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
    lxw = lxw.float()
    lxb = lxb.float()

    # Without a cached state, start from zeros instead of failing on
    # `tensor + None` inside the loop.
    if state is None:
        state = torch.zeros((B, H, S, S), dtype=torch.float32, device=key.device)

    out = torch.empty((B, T, H, S), dtype=receptance.dtype, device=receptance.device)
    for t in range(T):
        rt = receptance[:, :, t:t + 1, :]
        kt = key[:, :, :, t:t + 1]
        vt = value[:, :, t:t + 1, :]
        at = kt @ vt  # outer product of this step's key and value, per head
        out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
        state = at + time_decay * state

    # Group norm expects (N, C): fold batch and time together, one group per head.
    out = out.reshape(B * T, H * S)
    out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
    out = out.to(dtype=hidden.dtype) * gate
    out = out @ ow

    return out, state
|
| 130 |
|
|
|
|
| 166 |
self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
|
| 167 |
|
| 168 |
# TODO: maybe jit, otherwise move inside forward
|
| 169 |
+
def extract_key_value(self, B, H, S, T, hidden, state=None):
|
| 170 |
# Mix hidden with the previous timestep to produce key, value, receptance
|
| 171 |
if hidden.size(1) == 1 and state is not None:
|
| 172 |
shifted = state[0][:, :, self.layer_id]
|
|
|
|
| 174 |
shifted = self.time_shift(hidden)
|
| 175 |
if state is not None:
|
| 176 |
shifted[:, 0] = state[0][:, :, self.layer_id]
|
| 177 |
+
if len(shifted.size()) == 2:
|
| 178 |
+
shifted = shifted.unsqueeze(1)
|
| 179 |
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
|
| 180 |
value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
|
| 181 |
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
|
| 182 |
if self.config.model_version == "5_2":
|
| 183 |
gate = hidden* self.time_mix_gate + shifted * (1 - self.time_mix_gate)
|
| 184 |
|
| 185 |
+
# if hidden.size(1) == 1 and state is not None:
|
| 186 |
+
# receptance = self.receptance(receptance).to(torch.float32).view(B, H, 1, S)
|
| 187 |
+
# key = self.key(key).to(torch.float32).view(B, H, S, 1)
|
| 188 |
+
# value = self.value(value).to(torch.float32).view(B, H, 1, S)
|
| 189 |
+
# else:
|
| 190 |
+
# https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
|
| 191 |
+
key = self.key(key).to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
|
| 192 |
+
value = self.value(value).to(torch.float32).view(B, T, H, S).transpose(1, 2)
|
| 193 |
+
receptance = self.receptance(receptance).to(torch.float32).view(B, T, H, S).transpose(1, 2)
|
| 194 |
|
| 195 |
if self.config.model_version == "5_2":
|
| 196 |
gate = F.silu(self.gate(gate))
|
| 197 |
+
|
| 198 |
if state is not None:
|
| 199 |
state[0][:, :, self.layer_id] = hidden[:, -1]
|
| 200 |
|
|
|
|
| 203 |
return receptance, key, value, state
|
| 204 |
|
| 205 |
def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
|
| 206 |
+
B = hidden.shape[0]
|
| 207 |
H = self.time_decay.shape[0]
|
| 208 |
S = hidden.shape[-1] // H
|
| 209 |
T = hidden.shape[1]
|
| 210 |
|
| 211 |
if self.config.model_version == "5_2":
|
| 212 |
+
receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
|
| 213 |
else:
|
| 214 |
receptance, key, value, state = self.extract_key_value(H, S, T, hidden, state=state)
|
| 215 |
layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
|
| 216 |
if self.config.model_version == "5_2":
|
| 217 |
rwkv, layer_state = rwkv_linear_attention_v5_2(
|
| 218 |
+
B,
|
| 219 |
H,
|
| 220 |
S,
|
| 221 |
T,
|
|
|
|
| 290 |
shifted = self.time_shift(hidden)
|
| 291 |
if state is not None:
|
| 292 |
shifted[:, 0] = state[2][:, :, self.layer_id]
|
| 293 |
+
if len(shifted.size()) == 2:
|
| 294 |
+
shifted = shifted.unsqueeze(1)
|
| 295 |
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
|
| 296 |
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
|
| 297 |
|
|
|
|
| 613 |
|
| 614 |
|
| 615 |
hidden_states = inputs_embeds
|
| 616 |
+
global cnt
|
| 617 |
+
cnt += 1
|
| 618 |
all_self_attentions = () if output_attentions else None
|
| 619 |
all_hidden_states = () if output_hidden_states else None
|
| 620 |
for idx, block in enumerate(self.blocks):
|
|
|
|
| 665 |
|
| 666 |
self.layers_are_rescaled = not self.training
|
| 667 |
|
|
|
|
| 668 |
@add_start_docstrings(
|
| 669 |
"""
|
| 670 |
The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
tokenization_rwkv_world.py
CHANGED
|
@@ -107,6 +107,7 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
| 107 |
self,
|
| 108 |
vocab_file,
|
| 109 |
errors="replace",
|
|
|
|
| 110 |
**kwargs
|
| 111 |
):
|
| 112 |
self.add_bos_token = False
|
|
@@ -122,11 +123,7 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
| 122 |
assert len(x) == int(l[l.rindex(' '):])
|
| 123 |
sorted += [x]
|
| 124 |
self.encoder[idx] = x
|
| 125 |
-
|
| 126 |
-
super().__init__(
|
| 127 |
-
errors=errors,
|
| 128 |
-
**kwargs,
|
| 129 |
-
)
|
| 130 |
self.decoder = {}
|
| 131 |
for k,v in self.encoder.items():
|
| 132 |
self.decoder[v] = int(k)
|
|
@@ -136,6 +133,14 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
| 136 |
_ = self.trie.add(t, val=(t, i))
|
| 137 |
self.errors = errors # how to handle errors in decoding
|
| 138 |
self.cache = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
@property
|
| 141 |
def vocab_size(self):
|
|
@@ -143,6 +148,22 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
| 143 |
|
| 144 |
def get_vocab(self):
|
| 145 |
return dict(self.encoder, **self.added_tokens_encoder)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
| 148 |
if self.add_bos_token:
|
|
@@ -219,14 +240,21 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
| 219 |
skip_special_tokens: bool = False,
|
| 220 |
**kwargs
|
| 221 |
) -> str:
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
# Convert inputs to python lists
|
| 224 |
token_ids = to_py_obj(token_ids)
|
|
|
|
| 225 |
if isinstance(token_ids, int):
|
| 226 |
if token_ids in self.all_special_ids and skip_special_tokens:
|
| 227 |
return ""
|
| 228 |
return self.encoder.get(token_ids, self.unk_token)
|
| 229 |
elif isinstance(token_ids, list):
|
|
|
|
| 230 |
out_str = ""
|
| 231 |
out_last = 0
|
| 232 |
out_tokens = []
|
|
@@ -268,6 +296,11 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
| 268 |
def prepare_for_tokenization(self, text, **kwargs):
|
| 269 |
return (text, kwargs)
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
def _encode_plus(
|
| 272 |
self,
|
| 273 |
text: Union[TextInput, EncodedInput],
|
|
@@ -352,19 +385,33 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
| 352 |
verbose: bool = True,
|
| 353 |
**kwargs
|
| 354 |
) -> BatchEncoding:
|
| 355 |
-
def get_input_ids(text):
|
|
|
|
|
|
|
|
|
|
| 356 |
if isinstance(text, str):
|
| 357 |
-
|
| 358 |
-
|
|
|
|
|
|
|
|
|
|
| 359 |
elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
|
| 360 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
|
|
|
|
|
|
|
| 362 |
return text
|
|
|
|
| 363 |
else:
|
| 364 |
raise ValueError(
|
| 365 |
"Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
|
| 366 |
)
|
| 367 |
|
|
|
|
| 368 |
if return_offsets_mapping:
|
| 369 |
raise NotImplementedError(
|
| 370 |
"return_offset_mapping is not available when using Python tokenizers. "
|
|
@@ -372,15 +419,29 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
| 372 |
"transformers.PreTrainedTokenizerFast."
|
| 373 |
)
|
| 374 |
|
| 375 |
-
|
|
|
|
| 376 |
for ids_or_pair_ids in batch_text_or_text_pairs:
|
| 377 |
if not isinstance(ids_or_pair_ids, (list, tuple)):
|
| 378 |
ids, pair_ids = ids_or_pair_ids, None
|
| 379 |
else:
|
| 380 |
ids, pair_ids = ids_or_pair_ids
|
| 381 |
-
|
| 382 |
first_ids = get_input_ids(ids)
|
| 383 |
second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
input_ids.append((first_ids, second_ids))
|
| 385 |
|
| 386 |
batch_outputs = self._batch_prepare_for_model(
|
|
@@ -401,6 +462,75 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
| 401 |
)
|
| 402 |
|
| 403 |
return BatchEncoding(batch_outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
| 406 |
input_ids = []
|
|
|
|
| 107 |
self,
|
| 108 |
vocab_file,
|
| 109 |
errors="replace",
|
| 110 |
+
pad_token="0",
|
| 111 |
**kwargs
|
| 112 |
):
|
| 113 |
self.add_bos_token = False
|
|
|
|
| 123 |
assert len(x) == int(l[l.rindex(' '):])
|
| 124 |
sorted += [x]
|
| 125 |
self.encoder[idx] = x
|
| 126 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
self.decoder = {}
|
| 128 |
for k,v in self.encoder.items():
|
| 129 |
self.decoder[v] = int(k)
|
|
|
|
| 133 |
_ = self.trie.add(t, val=(t, i))
|
| 134 |
self.errors = errors # how to handle errors in decoding
|
| 135 |
self.cache = {}
|
| 136 |
+
self.first_max_length = 0
|
| 137 |
+
|
| 138 |
+
# pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
| 139 |
+
super().__init__(
|
| 140 |
+
errors=errors,
|
| 141 |
+
# pad_token=pad_token,
|
| 142 |
+
**kwargs,
|
| 143 |
+
)
|
| 144 |
|
| 145 |
@property
|
| 146 |
def vocab_size(self):
|
|
|
|
| 148 |
|
| 149 |
def get_vocab(self):
|
| 150 |
return dict(self.encoder, **self.added_tokens_encoder)
|
| 151 |
+
|
| 152 |
+
def add_tokens(self, new_tokens, special_tokens: bool = False):
    """Register tokens in ``added_tokens_decoder`` under their existing ids.

    Each token is looked up with ``convert_tokens_to_ids`` and stored in
    ``self.added_tokens_decoder`` keyed by that id; the vocabulary itself is
    not extended.

    Args:
        new_tokens: A single token or a list/tuple of tokens. A bare token is
            wrapped in a list so a string is not iterated character by
            character.
        special_tokens: Accepted for signature compatibility with the base
            tokenizer API; not used here.

    Returns:
        int: Number of tokens registered, matching the base API's contract of
        returning a count rather than ``None``.
    """
    if not isinstance(new_tokens, (list, tuple)):
        new_tokens = [new_tokens]

    registered = 0
    for token in new_tokens:
        token_id = self.convert_tokens_to_ids(token)
        self.added_tokens_decoder[token_id] = token
        registered += 1
    return registered
|
| 156 |
+
|
| 157 |
+
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
    """Map token ids to tokens, preferring entries in ``added_tokens_decoder``.

    Args:
        ids: A single int id or an iterable of ids.
        skip_special_tokens: When true, ids found in ``self.all_special_ids``
            are omitted from the result. Previously this flag was accepted
            but ignored.

    Returns:
        list: One token per retained id. A single int id still yields a
        one-element list, as before, so existing callers keep working.
    """
    if isinstance(ids, int):
        ids = [ids]

    # Only touch `all_special_ids` when filtering is requested, so the default
    # path behaves exactly as it did before.
    special_ids = set(self.all_special_ids) if skip_special_tokens else ()

    tokens = []
    for id_ in ids:
        if skip_special_tokens and id_ in special_ids:
            continue
        if id_ in self.added_tokens_decoder:
            tokens.append(self.added_tokens_decoder[id_])
        else:
            tokens.append(self._convert_id_to_token(id_))
    return tokens
|
| 167 |
|
| 168 |
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
| 169 |
if self.add_bos_token:
|
|
|
|
| 240 |
skip_special_tokens: bool = False,
|
| 241 |
**kwargs
|
| 242 |
) -> str:
|
| 243 |
+
|
| 244 |
+
def remove_zeros_from_first_segment(token_ids, first_max_length):
|
| 245 |
+
first_segment = token_ids[:first_max_length]
|
| 246 |
+
first_segment_cleaned = [token for token in first_segment if token != 0]
|
| 247 |
+
return first_segment_cleaned + token_ids[first_max_length:]
|
| 248 |
+
|
| 249 |
# Convert inputs to python lists
|
| 250 |
token_ids = to_py_obj(token_ids)
|
| 251 |
+
token_ids = remove_zeros_from_first_segment(token_ids, self.first_max_length)
|
| 252 |
if isinstance(token_ids, int):
|
| 253 |
if token_ids in self.all_special_ids and skip_special_tokens:
|
| 254 |
return ""
|
| 255 |
return self.encoder.get(token_ids, self.unk_token)
|
| 256 |
elif isinstance(token_ids, list):
|
| 257 |
+
self.first_max_length
|
| 258 |
out_str = ""
|
| 259 |
out_last = 0
|
| 260 |
out_tokens = []
|
|
|
|
| 296 |
def prepare_for_tokenization(self, text, **kwargs):
|
| 297 |
return (text, kwargs)
|
| 298 |
|
| 299 |
+
def _get_padding_truncation_strategies(
|
| 300 |
+
self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
|
| 301 |
+
):
|
| 302 |
+
return PaddingStrategy.LONGEST, TruncationStrategy.DO_NOT_TRUNCATE, -1, kwargs
|
| 303 |
+
|
| 304 |
def _encode_plus(
|
| 305 |
self,
|
| 306 |
text: Union[TextInput, EncodedInput],
|
|
|
|
| 385 |
verbose: bool = True,
|
| 386 |
**kwargs
|
| 387 |
) -> BatchEncoding:
|
| 388 |
+
def get_input_ids(text, max_length=None, pad_token_id=0):
|
| 389 |
+
def pad_sequence(seq, max_len, pad_tok):
|
| 390 |
+
return [pad_tok] * (max_len - len(seq)) + seq
|
| 391 |
+
|
| 392 |
if isinstance(text, str):
|
| 393 |
+
tokens = self._tokenize(text)
|
| 394 |
+
if max_length is not None:
|
| 395 |
+
tokens = pad_sequence(tokens, max_length, pad_token_id)
|
| 396 |
+
return tokens
|
| 397 |
+
|
| 398 |
elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
|
| 399 |
+
tokenized_texts = [self._tokenize(t) for t in text]
|
| 400 |
+
if max_length is None:
|
| 401 |
+
max_length = max(len(t) for t in tokenized_texts)
|
| 402 |
+
return [pad_sequence(t, max_length, pad_token_id) for t in tokenized_texts]
|
| 403 |
+
|
| 404 |
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
|
| 405 |
+
if max_length is not None and len(text) < max_length:
|
| 406 |
+
return pad_sequence(text, max_length, pad_token_id)
|
| 407 |
return text
|
| 408 |
+
|
| 409 |
else:
|
| 410 |
raise ValueError(
|
| 411 |
"Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
|
| 412 |
)
|
| 413 |
|
| 414 |
+
|
| 415 |
if return_offsets_mapping:
|
| 416 |
raise NotImplementedError(
|
| 417 |
"return_offset_mapping is not available when using Python tokenizers. "
|
|
|
|
| 419 |
"transformers.PreTrainedTokenizerFast."
|
| 420 |
)
|
| 421 |
|
| 422 |
+
first_max_length = 0
|
| 423 |
+
second_max_length = 0
|
| 424 |
for ids_or_pair_ids in batch_text_or_text_pairs:
|
| 425 |
if not isinstance(ids_or_pair_ids, (list, tuple)):
|
| 426 |
ids, pair_ids = ids_or_pair_ids, None
|
| 427 |
else:
|
| 428 |
ids, pair_ids = ids_or_pair_ids
|
|
|
|
| 429 |
first_ids = get_input_ids(ids)
|
| 430 |
second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
|
| 431 |
+
first_max_length = max(first_max_length, len(first_ids))
|
| 432 |
+
if second_ids is not None:
|
| 433 |
+
second_max_length = max(second_max_length, len(second_ids))
|
| 434 |
+
|
| 435 |
+
self.first_max_length = first_max_length
|
| 436 |
+
input_ids = []
|
| 437 |
+
for ids_or_pair_ids in batch_text_or_text_pairs:
|
| 438 |
+
if not isinstance(ids_or_pair_ids, (list, tuple)):
|
| 439 |
+
ids, pair_ids = ids_or_pair_ids, None
|
| 440 |
+
else:
|
| 441 |
+
ids, pair_ids = ids_or_pair_ids
|
| 442 |
+
|
| 443 |
+
first_ids = get_input_ids(ids, max_length=first_max_length)
|
| 444 |
+
second_ids = get_input_ids(pair_ids, max_length=second_max_length) if pair_ids is not None else None
|
| 445 |
input_ids.append((first_ids, second_ids))
|
| 446 |
|
| 447 |
batch_outputs = self._batch_prepare_for_model(
|
|
|
|
| 462 |
)
|
| 463 |
|
| 464 |
return BatchEncoding(batch_outputs)
|
| 465 |
+
|
| 466 |
+
def decode(
|
| 467 |
+
self,
|
| 468 |
+
token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
|
| 469 |
+
skip_special_tokens: bool = False,
|
| 470 |
+
clean_up_tokenization_spaces: bool = None,
|
| 471 |
+
**kwargs,
|
| 472 |
+
) -> str:
|
| 473 |
+
"""
|
| 474 |
+
Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
|
| 475 |
+
tokens and clean up tokenization spaces.
|
| 476 |
+
|
| 477 |
+
Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
|
| 478 |
+
|
| 479 |
+
Args:
|
| 480 |
+
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
|
| 481 |
+
List of tokenized input ids. Can be obtained using the `__call__` method.
|
| 482 |
+
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 483 |
+
Whether or not to remove special tokens in the decoding.
|
| 484 |
+
clean_up_tokenization_spaces (`bool`, *optional*):
|
| 485 |
+
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
| 486 |
+
`self.clean_up_tokenization_spaces`.
|
| 487 |
+
kwargs (additional keyword arguments, *optional*):
|
| 488 |
+
Will be passed to the underlying model specific decode method.
|
| 489 |
+
|
| 490 |
+
Returns:
|
| 491 |
+
`str`: The decoded sentence.
|
| 492 |
+
"""
|
| 493 |
+
# Convert inputs to python lists
|
| 494 |
+
return self._decode(
|
| 495 |
+
token_ids=token_ids,
|
| 496 |
+
skip_special_tokens=skip_special_tokens,
|
| 497 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 498 |
+
**kwargs,
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
def batch_decode(
|
| 502 |
+
self,
|
| 503 |
+
sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
|
| 504 |
+
skip_special_tokens: bool = False,
|
| 505 |
+
clean_up_tokenization_spaces: bool = None,
|
| 506 |
+
**kwargs,
|
| 507 |
+
) -> List[str]:
|
| 508 |
+
"""
|
| 509 |
+
Convert a list of lists of token ids into a list of strings by calling decode.
|
| 510 |
+
|
| 511 |
+
Args:
|
| 512 |
+
sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
|
| 513 |
+
List of tokenized input ids. Can be obtained using the `__call__` method.
|
| 514 |
+
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 515 |
+
Whether or not to remove special tokens in the decoding.
|
| 516 |
+
clean_up_tokenization_spaces (`bool`, *optional*):
|
| 517 |
+
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
| 518 |
+
`self.clean_up_tokenization_spaces`.
|
| 519 |
+
kwargs (additional keyword arguments, *optional*):
|
| 520 |
+
Will be passed to the underlying model specific decode method.
|
| 521 |
+
|
| 522 |
+
Returns:
|
| 523 |
+
`List[str]`: The list of decoded sentences.
|
| 524 |
+
"""
|
| 525 |
+
return [
|
| 526 |
+
self.decode(
|
| 527 |
+
seq,
|
| 528 |
+
skip_special_tokens=skip_special_tokens,
|
| 529 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 530 |
+
**kwargs,
|
| 531 |
+
)
|
| 532 |
+
for seq in sequences
|
| 533 |
+
]
|
| 534 |
|
| 535 |
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
| 536 |
input_ids = []
|