Commit a07de8e
Parent(s): cfd45a6

Patched it to work again. https://github.com/Q-Future/Q-Align/issues/31#issuecomment-2561704943

Files changed:
- README.md +7 -0
- config.json +1 -0
- modeling_llama2.py +23 -16
README.md CHANGED

@@ -3,6 +3,13 @@ license: mit
 pipeline_tag: zero-shot-image-classification
 ---
 
+## This fork
+This fork fixes some issues with the newest versions of transformers.
+[https://github.com/Q-Future/Q-Align/issues/31](https://github.com/Q-Future/Q-Align/issues/31#issuecomment-2561704943)
+
+
+## Upstream
+
 The model that corresponds to Q-Align (ICML2024).
 
 ## Quick Start with AutoModel
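The README's Quick Start points at AutoModel, so a minimal load sketch for this fork follows. The repository id and dtype below are placeholders/assumptions, not part of this commit; `trust_remote_code=True` is what lets transformers pick up the patched modeling_llama2.py bundled with the repo.

```python
# Minimal load sketch (repo id is a placeholder for this fork's Hub id).
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "your-namespace/one-align-fork",  # placeholder, not the real id
    trust_remote_code=True,           # loads the bundled custom modeling code
    torch_dtype=torch.float16,        # assumption: half precision for inference
)
```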
config.json CHANGED

@@ -33,6 +33,7 @@
   "transformers_version": "4.31.0",
   "tune_visual_abstractor": true,
   "use_cache": true,
+  "mlp_bias": false,
   "visual_abstractor_lr": null,
   "visual_config": {
     "visual_abstractor": {
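A note on the added key (my reading, not stated in the commit): recent transformers releases build the Llama MLP projections with `bias=config.mlp_bias`, a field that did not exist in the 4.31-era config this checkpoint ships with, so pinning it in config.json avoids an attribute lookup failure. A defensive equivalent in code, with a placeholder repo id:

```python
# Sketch: backfill the field newer transformers' LlamaMLP reads, mirroring
# the `"mlp_bias": false` entry added to config.json above.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("your-namespace/one-align-fork", trust_remote_code=True)
if not hasattr(cfg, "mlp_bias"):
    cfg.mlp_bias = False
```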
modeling_llama2.py CHANGED

@@ -22,8 +22,12 @@ from transformers.models.llama.modeling_llama import *
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
-from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
-from .configuration_mplug_owl2 import LlamaConfig
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
+from transformers.models.llama.configuration_llama import LlamaConfig
+
+#from .configuration_mplug_owl2 import LlamaConfig
+
 
 class MultiwayNetwork(nn.Module):
 

@@ -31,14 +35,14 @@ class MultiwayNetwork(nn.Module):
         super(MultiwayNetwork, self).__init__()
 
         self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])
-
+
     def forward(self, hidden_states, multiway_indices):
 
         if len(self.multiway) == 1:
             return self.multiway[0](hidden_states)
 
         output_hidden_states = torch.empty_like(hidden_states)
-
+
         for idx, subway in enumerate(self.multiway):
             local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
             hidden = hidden_states[local_indices].unsqueeze(1).contiguous()

@@ -48,9 +52,9 @@ class MultiwayNetwork(nn.Module):
             output = output[0]
             output = output.squeeze(1)
             output_hidden_states[local_indices] = output
-
+
         return output_hidden_states.contiguous()
-
+
 
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -65,7 +69,7 @@ class LlamaAttention(nn.Module):
                 "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                 "when creating this class."
             )
-
+
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads

@@ -145,7 +149,8 @@ class LlamaAttention(nn.Module):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        #cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:

@@ -193,7 +198,7 @@ class LlamaAttention(nn.Module):
         attn_weights = None
 
         return attn_output, attn_weights, past_key_value
-
+
 
 class LlamaFlashAttention2(LlamaAttention):
     """

@@ -248,7 +253,8 @@ class LlamaFlashAttention2(LlamaAttention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        #cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:

@@ -446,7 +452,8 @@ class LlamaSdpaAttention(LlamaAttention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        #cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 

@@ -596,7 +603,7 @@ def model_forward(
         batch_size, seq_length, _ = inputs_embeds.shape
     else:
         raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
+
     seq_length_with_past = seq_length
     past_key_values_length = 0
 

@@ -620,11 +627,11 @@ def model_forward(
         attention_mask = torch.ones(
             (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
         )
-
-    if self._use_flash_attention_2:
+
+    if False: #self._use_flash_attention_2:
         # 2d mask is passed through the layers
         attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-    elif self._use_sdpa and not output_attentions:
+    elif False: #self._use_sdpa and not output_attentions:
         # output_attentions=True can not be supported when using SDPA, and we fall back on
         # the manual implementation that requires a 4D causal mask in all cases.
         attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(

@@ -814,7 +821,7 @@ def replace_llama_modality_adaptive():
     transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward
     transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward
 
-
+
 if __name__ == "__main__":
     replace_llama_modality_adaptive()
     config = transformers.LlamaConfig.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/')
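The change repeated across LlamaAttention, LlamaFlashAttention2 and LlamaSdpaAttention is the rotary-embedding call: the old `seq_len=kv_seq_len` form is commented out in favour of passing `position_ids`, matching the signature newer transformers releases expect, while `apply_rotary_pos_emb` is still called as before. The flash-attention and SDPA mask shortcuts in `model_forward` are hard-disabled with `if False:`, so the eager `_prepare_4d_causal_attention_mask` path (now imported from `transformers.modeling_attn_mask_utils` instead of a local copy) is always taken. A version-tolerant sketch of the rotary call, assuming only the two signatures seen in this diff exist:

```python
# Sketch: call LlamaRotaryEmbedding with whichever signature the installed
# transformers provides (position_ids in newer releases, seq_len= in older ones).
import inspect

def rotary_cos_sin(rotary_emb, value_states, position_ids, kv_seq_len):
    if "position_ids" in inspect.signature(rotary_emb.forward).parameters:
        return rotary_emb(value_states, position_ids)     # newer transformers
    return rotary_emb(value_states, seq_len=kv_seq_len)   # older transformers
```

This fork pins the newer form directly instead, which is simpler but ties modeling_llama2.py to recent transformers versions.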