abeat committed on
Commit
cd28c9d
·
verified ·
1 Parent(s): a9b6b1a

Update blahblahtron_1_1B.py

Browse files
Files changed (1) hide show
  1. blahblahtron_1_1B.py +8 -4
blahblahtron_1_1B.py CHANGED
@@ -131,7 +131,7 @@ def apply_rope(q: torch.Tensor, k: torch.Tensor, rope_cos: torch.Tensor, rope_si
131
  k_embed = (k * rope_cos) + (rotate_half(k) * rope_sin)
132
  return q_embed, k_embed
133
 
134
-
135
  class Attention(nn.Module):
136
  def __init__(self, cfg: HFWrapperConfig):
137
  super().__init__()
@@ -313,12 +313,16 @@ class MyCustomModelForCausalLM(PreTrainedModel):
313
  rope_cos = self.rope_cos[T_past:T].to(x.dtype)
314
  rope_sin = self.rope_sin[T_past:T].to(x.dtype)
315
 
316
- print("RECEIVED USE CACHE: ",use_cache)
317
  present_key_values = [] if use_cache else None
318
 
319
  for i, block in enumerate(self.blocks):
320
- past_kv = past_key_values[i] if past_key_values is not None else None
321
-
 
 
 
 
322
  if self.training and self._gradient_checkpointing:
323
  def block_only_x(x_, rc, rs):
324
  out_x, _ = block(x_, rc, rs, past_kv=None, use_cache=False)
 
131
  k_embed = (k * rope_cos) + (rotate_half(k) * rope_sin)
132
  return q_embed, k_embed
133
 
134
+
135
  class Attention(nn.Module):
136
  def __init__(self, cfg: HFWrapperConfig):
137
  super().__init__()
 
313
  rope_cos = self.rope_cos[T_past:T].to(x.dtype)
314
  rope_sin = self.rope_sin[T_past:T].to(x.dtype)
315
 
316
+
317
  present_key_values = [] if use_cache else None
318
 
319
  for i, block in enumerate(self.blocks):
320
+ #past_kv = past_key_values[i] if past_key_values is not None else None
321
+ if past_key_values is not None and i < len(past_key_values):
322
+ past_kv = past_key_values[i]
323
+ else:
324
+ past_kv = None
325
+
326
  if self.training and self._gradient_checkpointing:
327
  def block_only_x(x_, rc, rs):
328
  out_x, _ = block(x_, rc, rs, past_kv=None, use_cache=False)