duzx16 committed
Commit f81daa3 · Parent(s): 7bcdc71

Fix batch generation for vision model

modeling_chatglm.py  CHANGED  (+20 -5)
@@ -692,16 +692,16 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         """Initialize the weights."""
         return
 
-    def get_masks(self, input_ids, past_key_values, padding_mask=None):
-        batch_size, seq_length = input_ids.shape
-        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
+    def get_masks(self, input_embeds, past_key_values, padding_mask=None):
+        batch_size, seq_length, embed_size = input_embeds.shape
+        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_embeds.device)
         full_attention_mask.tril_()
         past_length = 0
         if past_key_values:
             past_length = past_key_values[0][0].shape[2]
         if past_length:
             full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
-                                                         device=input_ids.device), full_attention_mask), dim=-1)
+                                                         device=input_embeds.device), full_attention_mask), dim=-1)
         if padding_mask is not None:
             full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
         if not past_length and padding_mask is not None:
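The signature change above is the core of the fix: the mask dimensions are now taken from input_embeds rather than input_ids, because for the vision model the embedded sequence (with image patch slots spliced in between the BOI and EOI tokens) is longer than the raw token-id sequence. A minimal standalone sketch of the updated construction, with illustrative tensors (the helper name and demo shapes are not from the source):

import torch

def build_full_attention_mask(input_embeds, past_key_values=None, padding_mask=None):
    # Shapes come from the embeddings, so the mask matches the expanded sequence.
    batch_size, seq_length, _ = input_embeds.shape
    full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_embeds.device)
    full_attention_mask.tril_()  # causal: position i attends to positions <= i
    past_length = past_key_values[0][0].shape[2] if past_key_values else 0
    if past_length:
        # Prepend all-ones columns so new tokens attend to every cached position.
        full_attention_mask = torch.cat(
            (torch.ones(batch_size, seq_length, past_length, device=input_embeds.device),
             full_attention_mask), dim=-1)
    if padding_mask is not None:
        full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)  # zero out padded keys
    return full_attention_mask

embeds = torch.randn(2, 5, 8)  # batch 2, sequence length 5, hidden size 8
print(build_full_attention_mask(embeds).shape)  # torch.Size([2, 5, 5])

The tail of the real method (cut off in the hunk) goes on to finish the padding handling; the sketch stops at the line shown.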
@@ -887,7 +887,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if full_attention_mask is None:
             if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
-                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+                full_attention_mask = self.get_masks(inputs_embeds, past_key_values, padding_mask=attention_mask)
 
         # Rotary positional embeddings
         rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
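The call site is updated to match the new signature, passing inputs_embeds where it previously passed input_ids. Note that the full 2-D mask is only materialized when it is actually needed; a small sketch of that gating condition (function name and demo values are illustrative):

import torch

def needs_full_mask(attention_mask, past_key_values, seq_length):
    # Build the full mask only if some positions are padded, or if more than
    # one new token is being processed on top of a cache.
    return (attention_mask is not None and not attention_mask.all()) or \
           (bool(past_key_values) and seq_length != 1)

print(needs_full_mask(torch.tensor([1, 1, 1]), None, 3))  # False: no padding, no cache
print(needs_full_mask(torch.tensor([0, 1, 1]), None, 3))  # True: padded batch needs the mask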
@@ -976,6 +976,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         # only last token for input_ids if past is not None
         if position_ids is None:
             position_ids = self.get_position_ids(input_ids, device=input_ids.device)
+        if attention_mask is not None:
+            image_size: int = self.config.vision_config['image_size']
+            patch_size: int = self.config.vision_config['patch_size']
+            num_patches = (image_size // patch_size // 2) ** 2
+            new_attention_masks = []
+            for i in range(len(input_ids)):
+                input_id = input_ids[i].tolist()
+                boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(
+                    self.config.eoi_token_id)
+                assert eoi_token_pos - boi_token_pos == 2
+                new_attention_masks.append(torch.cat(
+                    (attention_mask[i, :boi_token_pos + 1], attention_mask.new_ones(num_patches),
+                     attention_mask[i, eoi_token_pos:])
+                ))
+            attention_mask = torch.stack(new_attention_masks, dim=0)
         if not is_first_forward:
             if past_key_values is not None:
                 position_ids = position_ids[..., -1:]
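This added block is what makes batched generation work for the vision model: each sample's attention mask is expanded with num_patches ones between its BOI and EOI tokens, so the mask lines up with the embedded sequence once the image patches are spliced in. Because the loop locates BOI/EOI per sample, the fix holds even when padding puts them at different offsets across the batch. For example, with image_size 1120 and patch_size 14 (illustrative config values), num_patches = (1120 // 14 // 2) ** 2 = 1600. A standalone sketch with dummy token ids (the helper name and the ids 90001/90002 are made up for the demo):

import torch

def expand_attention_mask(input_ids, attention_mask, boi_token_id, eoi_token_id, num_patches):
    # Insert `num_patches` ones right after each sample's BOI token so the mask
    # covers the image patch embeddings spliced in between BOI and EOI.
    new_masks = []
    for i in range(len(input_ids)):
        ids = input_ids[i].tolist()
        boi_pos, eoi_pos = ids.index(boi_token_id), ids.index(eoi_token_id)
        assert eoi_pos - boi_pos == 2  # exactly one image placeholder between BOI and EOI
        new_masks.append(torch.cat((
            attention_mask[i, :boi_pos + 1],       # text up to and including BOI
            attention_mask.new_ones(num_patches),  # image patches are always attended
            attention_mask[i, eoi_pos:],           # EOI and the remaining text
        )))
    # Every row grows by num_patches - 1, so the batch still stacks cleanly.
    return torch.stack(new_masks, dim=0)

# Two left-padded samples; 90001/90002 stand in for the BOI/EOI token ids.
input_ids = torch.tensor([[0, 90001, 7, 90002, 11],
                          [5, 90001, 7, 90002, 12]])
attention_mask = torch.tensor([[0, 1, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
print(expand_attention_mask(input_ids, attention_mask, 90001, 90002, num_patches=4).shape)
# torch.Size([2, 8])  ->  5 + 4 - 1 per row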
|