Update blahblahtron_1_1B.py
blahblahtron_1_1B.py  +8 -4

@@ -131,7 +131,7 @@ def apply_rope(q: torch.Tensor, k: torch.Tensor, rope_cos: torch.Tensor, rope_si
     k_embed = (k * rope_cos) + (rotate_half(k) * rope_sin)
     return q_embed, k_embed
 
-
+
 class Attention(nn.Module):
     def __init__(self, cfg: HFWrapperConfig):
         super().__init__()
@@ -313,12 +313,16 @@ class MyCustomModelForCausalLM(PreTrainedModel):
         rope_cos = self.rope_cos[T_past:T].to(x.dtype)
         rope_sin = self.rope_sin[T_past:T].to(x.dtype)
 
-
+
         present_key_values = [] if use_cache else None
 
         for i, block in enumerate(self.blocks):
-            past_kv = past_key_values[i] if past_key_values is not None else None
-
+            #past_kv = past_key_values[i] if past_key_values is not None else None
+            if past_key_values is not None and i < len(past_key_values):
+                past_kv = past_key_values[i]
+            else:
+                past_kv = None
+
             if self.training and self._gradient_checkpointing:
                 def block_only_x(x_, rc, rs):
                     out_x, _ = block(x_, rc, rs, past_kv=None, use_cache=False)
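
The first hunk is effectively a whitespace-only change on the blank line between apply_rope and the Attention class; the surrounding context lines show the rotary-embedding application itself. As a stand-alone illustration of what those context lines compute, here is a minimal sketch. The rotate_half body and the placeholder cos/sin tensors are assumptions made for the sketch (the file's own definitions are not part of this diff); only the q_embed/k_embed lines mirror the committed code.

import torch

# Assumed helper for the sketch: the file's rotate_half is not shown in this diff;
# this is the common "negate the second half, then swap halves" formulation.
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_sketch(q, k, rope_cos, rope_sin):
    # Mirrors the hunk-1 context: mix each tensor with its half-rotated copy
    # using precomputed cos/sin tables.
    q_embed = (q * rope_cos) + (rotate_half(q) * rope_sin)
    k_embed = (k * rope_cos) + (rotate_half(k) * rope_sin)
    return q_embed, k_embed

# Placeholder shapes only; in the model the tables come from self.rope_cos / self.rope_sin.
q = k = torch.randn(1, 4, 8, 64)   # (batch, heads, seq_len, head_dim)
cos = torch.rand(8, 64)            # broadcasts over batch and heads
sin = torch.rand(8, 64)
q_rot, k_rot = apply_rope_sketch(q, k, cos, sin)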
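
The second hunk carries the substantive change: the unguarded one-liner past_kv = past_key_values[i] is commented out and replaced with a bounds-checked lookup, so a past_key_values list that is shorter than self.blocks (or empty) now yields past_kv = None for the remaining layers instead of raising IndexError. A minimal sketch of that guarded lookup in isolation, with plain Python lists standing in for the model's blocks and its per-layer key/value cache:

# Placeholder data for the sketch: plain lists stand in for the transformer blocks
# and for the per-layer (key, value) cache entries.
blocks = ["block_0", "block_1", "block_2", "block_3"]
past_key_values = [("k0", "v0"), ("k1", "v1")]   # cache deliberately shorter than blocks

for i, block in enumerate(blocks):
    # Guarded lookup from the diff: only index the cache when entry i exists;
    # the replaced one-liner would raise IndexError once i >= len(past_key_values).
    if past_key_values is not None and i < len(past_key_values):
        past_kv = past_key_values[i]
    else:
        past_kv = None
    print(i, block, past_kv)

The behaviour for a full-length cache is unchanged; only short or empty caches take the new None fallback.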