liubangwei committed
Commit d37eb96 · 1 Parent(s): 1855cc2

init IDMR demo

src/vlm_backbone/llava_next/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modeling_llava_next import LlavaNextForConditionalGeneration
src/vlm_backbone/llava_next/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (255 Bytes).
 
src/vlm_backbone/llava_next/__pycache__/modeling_llava_next.cpython-310.pyc ADDED
Binary file (35.2 kB).
 
src/vlm_backbone/llava_next/modeling_llava_next.py ADDED
@@ -0,0 +1,987 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Llava-NeXT model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.generation import GenerationMixin
28
+ from transformers.image_processing_utils import select_best_resolution
29
+ from transformers.modeling_outputs import ModelOutput
30
+ from transformers.modeling_utils import PreTrainedModel
31
+ from transformers.utils import (
32
+ add_start_docstrings,
33
+ add_start_docstrings_to_model_forward,
34
+ logging,
35
+ replace_return_docstrings,
36
+ )
37
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
38
+ from transformers.models.llava_next.configuration_llava_next import LlavaNextConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ _CONFIG_FOR_DOC = "LlavaNextConfig"
44
+
45
+
46
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
47
+ """
48
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
49
+
50
+ Args:
51
+ image_size (`tuple`):
52
+ The size of the input image in the format (width, height).
53
+ grid_pinpoints (`List`):
54
+ A list containing possible resolutions. Each item in the list should be a tuple or list
55
+ of the form `(height, width)`.
56
+ patch_size (`int`):
57
+ The size of each image patch.
58
+
59
+ Returns:
60
+ tuple: The shape of the image patch grid in the format (height, width).
61
+ """
62
+ if not isinstance(grid_pinpoints, list):
63
+ raise TypeError("grid_pinpoints should be a list of tuples or lists")
64
+
65
+ # ! VERY IMPORTANT: if image_size is a tensor, it must be converted into a tuple, otherwise it will cause a wrong calculation
66
+ if not isinstance(image_size, (list, tuple)):
67
+ if not isinstance(image_size, (torch.Tensor, np.ndarray)):
68
+ raise TypeError(
69
+ f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor"
70
+ )
71
+ image_size = image_size.tolist()
72
+
73
+ height, width = select_best_resolution(image_size, grid_pinpoints)
74
+ return height // patch_size, width // patch_size
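+ # Worked example (illustrative values only, not tied to any particular checkpoint): if
+ # select_best_resolution picks (672, 1008) for an image and patch_size is 336, the returned
+ # grid shape is (672 // 336, 1008 // 336) = (2, 3).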
75
+
76
+
77
+ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
78
+ """
79
+ Calculate the number of patches after the preprocessing for images of any resolution.
80
+
81
+ Args:
82
+ image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
83
+ The size of the input image in the format (height, width).
84
+ grid_pinpoints (`List`):
85
+ A list containing possible resolutions. Each item in the list should be a tuple or list
86
+ of the form `(height, width)`.
87
+ patch_size (`int`):
88
+ The size of each image patch.
89
+
90
+ Returns:
91
+ int: the number of patches
92
+ """
93
+ if not isinstance(grid_pinpoints, list):
94
+ raise TypeError("grid_pinpoints should be a list of tuples or lists")
95
+
96
+ # ! VERY IMPORTANT: if image_size is a tensor, it must be converted into a tuple, otherwise it will cause a wrong calculation
97
+ if not isinstance(image_size, (list, tuple)):
98
+ if not isinstance(image_size, (torch.Tensor, np.ndarray)):
99
+ raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}")
100
+ image_size = image_size.tolist()
101
+
102
+ best_resolution = select_best_resolution(image_size, grid_pinpoints)
103
+ height, width = best_resolution
104
+ num_patches = 0
105
+ # consider changing this to ceil(height/patch_size)*ceil(width/patch_size) + 1
106
+ for i in range(0, height, patch_size):
107
+ for j in range(0, width, patch_size):
108
+ num_patches += 1
109
+ # add the base patch
110
+ num_patches += 1
111
+ return num_patches
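+ # Worked example (illustrative values, continuing the sketch above): a best resolution of
+ # (672, 1008) with patch_size=336 gives 2 * 3 = 6 grid patches, plus 1 base patch = 7.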
112
+
113
+
114
+ def unpad_image(tensor, original_size):
115
+ """
116
+ Unpads a PyTorch tensor of a padded and resized image.
117
+
118
+ Args:
119
+ tensor (`torch.Tensor`):
120
+ The image tensor, assumed to be of shape (num_channels, height, width).
121
+ original_size (`tuple`):
122
+ The original size of the image (height, width).
123
+
124
+ Returns:
125
+ `torch.Tensor`: The unpadded image tensor.
126
+ """
127
+ if not isinstance(original_size, (list, tuple)):
128
+ if not isinstance(original_size, (torch.Tensor, np.ndarray)):
129
+ raise TypeError(
130
+ f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
131
+ )
132
+ original_size = original_size.tolist()
133
+ original_height, original_width = original_size
134
+ current_height, current_width = tensor.shape[1:]
135
+
136
+ original_aspect_ratio = original_width / original_height
137
+ current_aspect_ratio = current_width / current_height
138
+
139
+ if original_aspect_ratio > current_aspect_ratio:
140
+ scale_factor = current_width / original_width
141
+ new_height = int(original_height * scale_factor)
142
+ padding = (current_height - new_height) // 2
143
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
144
+ else:
145
+ scale_factor = current_height / original_height
146
+ new_width = int(original_width * scale_factor)
147
+ padding = (current_width - new_width) // 2
148
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
149
+
150
+ return unpadded_tensor
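+ # Worked example (hypothetical numbers): a padded tensor of shape (dim, 336, 672) coming from an
+ # original image of size (400, 1000): original aspect ratio 1000/400 = 2.5 > current 672/336 = 2.0,
+ # so scale_factor = 672/1000 = 0.672, new_height = int(400 * 0.672) = 268,
+ # padding = (336 - 268) // 2 = 34, and rows 34:302 are kept, returning a tensor of shape (dim, 268, 672).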
151
+
152
+
153
+ @dataclass
154
+ class LlavaNextCausalLMOutputWithPast(ModelOutput):
155
+ """
156
+ Base class for LlavaNext causal language model (or autoregressive) outputs.
157
+
158
+ Args:
159
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
160
+ Language modeling loss (for next-token prediction).
161
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
162
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
163
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
164
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
165
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
166
+
167
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
168
+ `past_key_values` input) to speed up sequential decoding.
169
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
170
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
171
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
172
+
173
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
174
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
175
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
176
+ sequence_length)`.
177
+
178
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
179
+ heads.
180
+ image_hidden_states (`torch.FloatTensor`, *optional*):
181
+ A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`.
182
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
183
+ """
184
+
185
+ loss: Optional[torch.FloatTensor] = None
186
+ logits: torch.FloatTensor = None
187
+ past_key_values: Optional[List[torch.FloatTensor]] = None
188
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
189
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
190
+ image_hidden_states: Optional[torch.FloatTensor] = None
191
+
192
+
193
+ # Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
194
+ class LlavaNextMultiModalProjector(nn.Module):
195
+ def __init__(self, config: LlavaNextConfig):
196
+ super().__init__()
197
+
198
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
199
+ self.act = ACT2FN[config.projector_hidden_act]
200
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
201
+
202
+ def forward(self, image_features):
203
+ hidden_states = self.linear_1(image_features)
204
+ hidden_states = self.act(hidden_states)
205
+ hidden_states = self.linear_2(hidden_states)
206
+ return hidden_states
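+ # Shape sketch (example dimensions, not from a real config): with a vision hidden size of 1024 and a
+ # text hidden size of 4096, this projector maps image features of shape (..., 1024) to (..., 4096).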
207
+
208
+
209
+ LLAVA_NEXT_START_DOCSTRING = r"""
210
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
211
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
212
+ etc.)
213
+
214
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
215
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
216
+ and behavior.
217
+
218
+ Parameters:
219
+ config ([`LlavaNextConfig`] or [`LlavaNextVisionConfig`]):
220
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
221
+ load the weights associated with the model, only the configuration. Check out the
222
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
223
+ """
224
+
225
+
226
+ @add_start_docstrings(
227
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
228
+ LLAVA_NEXT_START_DOCSTRING,
229
+ )
230
+ # Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->LlavaNext,llava->llava_next
231
+ class LlavaNextPreTrainedModel(PreTrainedModel):
232
+ config_class = LlavaNextConfig
233
+ base_model_prefix = "model"
234
+ supports_gradient_checkpointing = True
235
+ _no_split_modules = ["LlavaNextVisionAttention"]
236
+ _skip_keys_device_placement = "past_key_values"
237
+ _supports_flash_attn_2 = True
238
+ _supports_cache_class = True
239
+
240
+ def _init_weights(self, module):
241
+ # important: this ported version of LlavaNext isn't meant for training from scratch - only
242
+ # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
243
+ # https://github.com/haotian-liu/LLaVA/tree/main/llava_next should serve for that purpose
244
+ std = (
245
+ self.config.initializer_range
246
+ if hasattr(self.config, "initializer_range")
247
+ else self.config.text_config.initializer_range
248
+ )
249
+
250
+ if hasattr(module, "class_embedding"):
251
+ module.class_embedding.data.normal_(mean=0.0, std=std)
252
+
253
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
254
+ module.weight.data.normal_(mean=0.0, std=std)
255
+ if module.bias is not None:
256
+ module.bias.data.zero_()
257
+ elif isinstance(module, nn.Embedding):
258
+ module.weight.data.normal_(mean=0.0, std=std)
259
+ if module.padding_idx is not None:
260
+ module.weight.data[module.padding_idx].zero_()
261
+
262
+ @property
263
+ def _supports_sdpa(self):
264
+ """
265
+ Retrieve language_model's attribute to check whether the model supports
266
+ SDPA or not.
267
+ """
268
+ return self.language_model._supports_sdpa
269
+
270
+
271
+ LLAVA_NEXT_INPUTS_DOCSTRING = r"""
272
+ Args:
273
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
274
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
275
+ it.
276
+
277
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
278
+ [`PreTrainedTokenizer.__call__`] for details.
279
+
280
+ [What are input IDs?](../glossary#input-ids)
281
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
282
+ The tensors corresponding to the input images. Pixel values can be obtained using
283
+ [`AutoImageProcessor`]. See [`LlavaNextImageProcessor.__call__`] for details. [`LlavaProcessor`] uses
284
+ [`LlavaNextImageProcessor`] for processing images.
285
+ image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
286
+ The sizes of the images in the batch, being (height, width) for each image.
287
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
288
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
289
+
290
+ - 1 for tokens that are **not masked**,
291
+ - 0 for tokens that are **masked**.
292
+
293
+ [What are attention masks?](../glossary#attention-mask)
294
+
295
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
296
+ [`PreTrainedTokenizer.__call__`] for details.
297
+
298
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
299
+ `past_key_values`).
300
+
301
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
302
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
303
+ information on the default strategy.
304
+
305
+ - 1 indicates the head is **not masked**,
306
+ - 0 indicates the head is **masked**.
307
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
308
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
309
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
310
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
311
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
312
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
313
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
314
+
315
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
316
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
317
+
318
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
319
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
320
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
321
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
322
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
323
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
324
+ model's internal embedding lookup matrix.
325
+ vision_feature_layer (`int`, *optional*, defaults to -2):
326
+ The index of the layer to select the vision feature.
327
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
328
+ The feature selection strategy used to select the vision feature from the vision backbone.
329
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
330
+ If `"full"`, the full vision features are used.
331
+ use_cache (`bool`, *optional*):
332
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
333
+ `past_key_values`).
334
+ output_attentions (`bool`, *optional*):
335
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
336
+ tensors for more detail.
337
+ output_hidden_states (`bool`, *optional*):
338
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
339
+ more detail.
340
+ return_dict (`bool`, *optional*):
341
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
342
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
343
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
344
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
345
+ the complete sequence length.
346
+ """
347
+
348
+
349
+ @add_start_docstrings(
350
+ """The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
351
+ LLAVA_NEXT_START_DOCSTRING,
352
+ )
353
+ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
354
+ def __init__(self, config: LlavaNextConfig):
355
+ super().__init__(config)
356
+ self.vision_tower = AutoModel.from_config(config.vision_config)
357
+
358
+ self.multi_modal_projector = LlavaNextMultiModalProjector(config)
359
+ embed_std = 1 / math.sqrt(config.text_config.hidden_size)
360
+ self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
361
+
362
+ self.vocab_size = config.text_config.vocab_size
363
+ self.language_model = AutoModelForCausalLM.from_config(
364
+ config.text_config, attn_implementation=config._attn_implementation
365
+ )
366
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
367
+ self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
368
+ self.post_init()
369
+
370
+ @property
371
+ def padding_side(self):
372
+ return self._padding_side
373
+
374
+ @padding_side.setter
375
+ def padding_side(self, padding_side: str):
376
+ if padding_side not in ["left", "right"]:
377
+ raise ValueError(f"{padding_side} is not `left` or `right`.")
378
+ self._padding_side = padding_side
379
+
380
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
381
+ def get_input_embeddings(self):
382
+ return self.language_model.get_input_embeddings()
383
+
384
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
385
+ def set_input_embeddings(self, value):
386
+ self.language_model.set_input_embeddings(value)
387
+
388
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
389
+ def get_output_embeddings(self):
390
+ return self.language_model.get_output_embeddings()
391
+
392
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
393
+ def set_output_embeddings(self, new_embeddings):
394
+ self.language_model.set_output_embeddings(new_embeddings)
395
+
396
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
397
+ def set_decoder(self, decoder):
398
+ self.language_model.set_decoder(decoder)
399
+
400
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
401
+ def get_decoder(self):
402
+ return self.language_model.get_decoder()
403
+
404
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
405
+ def tie_weights(self):
406
+ return self.language_model.tie_weights()
407
+
408
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
409
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
410
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
411
+ # update vocab size
412
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
413
+ self.vocab_size = model_embeds.num_embeddings
414
+ return model_embeds
415
+
416
+ def _merge_input_ids_with_image_features(
417
+ self,
418
+ image_features,
419
+ feature_lens,
420
+ inputs_embeds,
421
+ input_ids,
422
+ attention_mask,
423
+ position_ids=None,
424
+ labels=None,
425
+ image_token_index=None,
426
+ ignore_index=-100,
427
+ ):
428
+ """
429
+ Merge input_ids with image features into final embeddings
430
+
431
+ Args:
432
+ image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
433
+ All vision vectors of all images in the batch
434
+ feature_lens (`torch.LongTensor` of shape `(num_images)`):
435
+ The length of visual embeddings of each image as stacked in `image_features`
436
+ inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
437
+ Token embeddings before merging with visual embeddings
438
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
439
+ Input_ids of tokens, possibly filled with image token
440
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
441
+ Mask to avoid performing attention on padding token indices.
442
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
443
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
444
+ config.n_positions - 1]`.
445
+ labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
446
+ Labels need to be recalculated to support training (if provided)
447
+ image_token_index (`int`, *optional*)
448
+ Token id used to indicate the special "image" token. Defaults to `config.image_token_index`
449
+ ignore_index (`int`, *optional*)
450
+ Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100.
451
+ Returns:
452
+ final_embedding, final_attention_mask, position_ids, final_labels
453
+
454
+ Explanation:
455
+ each image has variable length embeddings, with length specified by feature_lens
456
+ image_features is concatenation of all visual embed vectors
457
+ task: fill each <image> with the correct number of visual embeddings
458
+ Example:
459
+ X (5 patches), Y (3 patches), Z (8)
460
+ X, Y are in the same sequence (in-context learning)
461
+ if right padding
462
+ input_ids: [
463
+ a b c d e f X g h i j k Y l m
464
+ o p q r Z s t u v _ _ _ _ _ _
465
+ ]
466
+ input_ids should be: [
467
+ a b c d e f X X X X X g h i j k Y Y Y l m
468
+ o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
469
+ ]
470
+ labels should be: [
471
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
472
+ o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
473
+ ]
474
+ elif left padding
475
+ input_ids: [
476
+ a b c d e f X g h i j k Y l m
477
+ _ _ _ _ _ _ o p q r Z s t u v
478
+ ]
479
+ input_ids should be: [
480
+ a b c d e f X X X X X g h i j k Y Y Y l m
481
+ _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
482
+ ]
483
+ labels should be: [
484
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
485
+ _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
486
+ ]
487
+ Edge cases:
488
+ * If tokens are same but image token sizes are different, then cannot infer left or right padding
489
+ ```python
490
+ cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
491
+ chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw)
492
+ prompts = [
493
+ "[INST] <image>\nWhat is shown in this image? [/INST]",
494
+ "[INST] <image>\nWhat is shown in this image? [/INST]",
495
+ ]
496
+ inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda")
497
+ chart_img has 2634 tokens, while cat_img has 2340 tokens
498
+ ```
499
+
500
+ input_ids: [
501
+ a b c d X g h
502
+ i j Y k l m n
503
+ ]
504
+ where X is 3 tokens while Y is 5, this means that after merging
505
+ if left-padding (batched generation)
506
+ input_ids should be: [
507
+ _ _ a b c d X X X g h
508
+ i j Y Y Y Y Y k l m n
509
+ ]
510
+ elif (right padding) (training)
511
+ input_ids should be: [
512
+ a b c d X X X g h _ _
513
+ i j Y Y Y Y Y k l m n
514
+ ]
515
+ """
516
+ image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
517
+ ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
518
+
519
+ if self.training and self.padding_side == "left":
520
+ logger.warning_once(
521
+ "Padding side is set to 'left' but the model is in training mode. For training "
522
+ "it is recommended to set `model.padding_side='right'` and `processor.tokenizer.padding_side='right'`. "
523
+ "If that's intended, ignore this warning"
524
+ )
525
+ if not self.training and self.padding_side == "right":
526
+ logger.warning_once(
527
+ "Padding side is set to 'right' but the model is in inference mode. For correct "
528
+ "generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. "
529
+ "If that's intended, ignore this warning"
530
+ )
531
+
532
+ with torch.no_grad():
533
+ # ! in llava 1.6, number of patches is variable
534
+ num_images = feature_lens.size(0)
535
+ num_image_features, embed_dim = image_features.shape
536
+ if feature_lens.sum() != num_image_features:
537
+ raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
538
+ batch_size = input_ids.shape[0]
539
+ _left_padding = torch.any(attention_mask[:, 0] == 0)
540
+ _right_padding = torch.any(attention_mask[:, -1] == 0)
541
+
542
+ left_padding = self.padding_side == "left"
543
+ if batch_size > 1:
544
+ if _left_padding and _right_padding:
545
+ raise ValueError(f"both sides of attention_mask have zeros, which is invalid: {attention_mask}")
546
+ elif _right_padding and left_padding:
547
+ left_padding = False
548
+ elif _left_padding and not left_padding:
549
+ left_padding = True
550
+
551
+ # Whether to turn off right padding
552
+ # 1. Create a mask to know where special image tokens are
553
+ special_image_token_mask = input_ids == image_token_index
554
+ # special_image_token_mask: [bsz, seqlen]
555
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
556
+ # num_special_image_tokens: [bsz]
557
+ # Reserve for padding of num_images
558
+ total_num_special_image_tokens = torch.sum(special_image_token_mask)
559
+
560
+ # we have dummy images, so skip this assert
561
+ # if total_num_special_image_tokens != num_images:
562
+ # raise ValueError(
563
+ # f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
564
+ # )
565
+
566
+ # Compute the maximum embed dimension
567
+ # max_image_feature_lens is max_feature_lens per batch
568
+ feature_lens = feature_lens.to(input_ids.device)
569
+ feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
570
+ feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device)
571
+ embed_sequence_lengths = (
572
+ (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
573
+ )
574
+ max_embed_dim = embed_sequence_lengths.max()
575
+
576
+ batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
577
+ # 2. Compute the positions where text should be written
578
+ # Calculate new positions for text tokens in merged image-text sequence.
579
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
580
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
581
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
582
+ # ! instead of special_image_token_mask * (num_image_patches - 1)
583
+ # special_image_token_mask * (num_feature_len - 1)
584
+ special_image_token_mask = special_image_token_mask.long()
585
+ special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
586
+ new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
587
+ if left_padding:
588
+ # shift right token positions so that they are ending at the same number
589
+ # the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
590
+ new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
591
+
592
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
593
+
594
+ # 3. Create the full embedding, already padded to the maximum position
595
+ final_embedding = torch.zeros(
596
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
597
+ )
598
+ final_attention_mask = torch.zeros(
599
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
600
+ )
601
+ final_input_ids = torch.full(
602
+ (batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
603
+ )
604
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
605
+ # set the corresponding tensors into their correct target device.
606
+ target_device = inputs_embeds.device
607
+ batch_indices, non_image_indices, text_to_overwrite = (
608
+ batch_indices.to(target_device),
609
+ non_image_indices.to(target_device),
610
+ text_to_overwrite.to(target_device),
611
+ )
612
+ attention_mask = attention_mask.to(target_device)
613
+ input_ids = input_ids.to(target_device)
614
+
615
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
616
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
617
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
618
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
619
+ final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices]
620
+ final_labels = None
621
+ if labels is not None:
622
+ labels = labels.to(target_device)
623
+ final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
624
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
625
+
626
+ # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
627
+ with torch.no_grad():
628
+ image_to_overwrite = torch.full(
629
+ (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
630
+ )
631
+ image_to_overwrite[batch_indices, text_to_overwrite] = False
632
+ embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
633
+ embed_indices = embed_indices.expand(batch_size, max_embed_dim)
634
+ embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
635
+
636
+ if left_padding:
637
+ # exclude padding on the left
638
+ max_embed_dim = max_embed_dim.to(target_device)
639
+ val = (max_embed_dim - embed_indices) <= embed_seq_lens
640
+ else:
641
+ # exclude padding on the right
642
+ val = embed_indices < embed_seq_lens
643
+ image_to_overwrite &= val
644
+
645
+ if image_to_overwrite.sum() != num_image_features:
646
+ raise ValueError(
647
+ f"{image_to_overwrite.sum()=} != {num_image_features=} The inputs provided to the model are wrong. "
648
+ f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
649
+ f" the number of images given to the model is {num_images}. "
650
+ f"This prevents correct indexing and breaks batch generation."
651
+ )
652
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
653
+ final_attention_mask |= image_to_overwrite
654
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
655
+ return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
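+ # Worked example (derived from the docstring above): with feature_lens = [5, 3, 8], the first sequence
+ # has 15 tokens of which 2 are <image> placeholders, so it expands to 15 - 2 + (5 + 3) = 21 positions;
+ # the second has 9 real tokens with 1 placeholder and expands to 9 - 1 + 8 = 16, padded up to
+ # max_embed_dim = 21.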
656
+
657
+ def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
658
+ """
659
+ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
660
+
661
+ Args:
662
+ image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
663
+ List of image feature tensor, each contains all the visual feature of all patches.
664
+ image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
665
+ Actual image size of each images (H, W).
666
+ vision_feature_select_strategy (`str`)
667
+ The feature selection strategy used to select the vision feature from the vision backbone.
668
+ image_newline (`torch.Tensor` of shape `(embed_dim)`)
669
+ New line embedding vector.
670
+ Returns:
671
+ image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
672
+ feature_lens (`List[int]`)
673
+ token length of each image in image_features
674
+ """
675
+ new_image_features = []
676
+ feature_lens = []
677
+ for image_idx, image_feature in enumerate(image_features):
678
+ if image_feature.shape[0] > 1:
679
+ base_image_feature = image_feature[0]
680
+ image_feature = image_feature[1:]
681
+ height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
682
+
683
+ if vision_feature_select_strategy == "default":
684
+ expected_num_patches = height * width
685
+ elif vision_feature_select_strategy == "full":
686
+ expected_num_patches = height * width + 1
687
+ if expected_num_patches != base_image_feature.shape[0]:
688
+ raise ValueError("The number of patches is not consistent with the image size.")
689
+
690
+ num_patch_height, num_patch_width = get_anyres_image_grid_shape(
691
+ image_sizes[image_idx],
692
+ self.config.image_grid_pinpoints,
693
+ self.config.vision_config.image_size,
694
+ )
695
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
696
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
697
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
698
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
699
+ if image_newline is not None:
700
+ image_feature = torch.cat(
701
+ (
702
+ image_feature,
703
+ image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype),
704
+ ),
705
+ dim=-1,
706
+ )
707
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
708
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
709
+ else:
710
+ image_feature = image_feature[0]
711
+ if image_newline is not None:
712
+ image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
713
+ new_image_features.append(image_feature)
714
+ feature_lens.append(image_feature.size(0))
715
+ image_features = torch.cat(new_image_features, dim=0)
716
+ feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
717
+ return image_features, feature_lens
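+ # Shape sketch (illustrative numbers, assuming a 336px vision tower with 14px patches, i.e. a 24x24
+ # grid of 576 base tokens): for a 2x3 anyres patch grid the patch features are reshaped to
+ # (dim, 48, 72), unpadded, and one image_newline column is appended per row; if unpadding crops
+ # nothing, the packed length is 576 + 48 * (72 + 1) = 4080 tokens for that image.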
718
+
719
+ @add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
720
+ @replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
721
+ def forward(
722
+ self,
723
+ input_ids: torch.LongTensor = None,
724
+ pixel_values: torch.FloatTensor = None,
725
+ image_sizes: Optional[torch.LongTensor] = None,
726
+ attention_mask: Optional[torch.Tensor] = None,
727
+ position_ids: Optional[torch.LongTensor] = None,
728
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
729
+ inputs_embeds: Optional[torch.FloatTensor] = None,
730
+ vision_feature_layer: Optional[int] = None,
731
+ vision_feature_select_strategy: Optional[str] = None,
732
+ labels: Optional[torch.LongTensor] = None,
733
+ use_cache: Optional[bool] = None,
734
+ output_attentions: Optional[bool] = None,
735
+ output_hidden_states: Optional[bool] = None,
736
+ return_dict: Optional[bool] = None,
737
+ cache_position: Optional[torch.LongTensor] = None,
738
+ num_logits_to_keep: int = 0,
739
+ ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
740
+ r"""
741
+ Args:
742
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
743
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
744
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
745
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
746
+
747
+ num_logits_to_keep (`int`, *optional*):
748
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
749
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
750
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
751
+
752
+ Returns:
753
+
754
+ Example:
755
+
756
+ ```python
757
+ >>> from PIL import Image
758
+ >>> import requests
759
+ >>> from transformers import AutoProcessor, LlavaNextForConditionalGeneration
760
+
761
+ >>> model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
762
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
763
+
764
+ >>> prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
765
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
766
+ >>> image = Image.open(requests.get(url, stream=True).raw)
767
+
768
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
769
+
770
+ >>> # Generate
771
+ >>> generate_ids = model.generate(**inputs, max_length=30)
772
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
773
+ "[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot (...)"
774
+ ```"""
775
+
776
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
777
+ output_hidden_states = (
778
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
779
+ )
780
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
781
+ vision_feature_layer = (
782
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
783
+ )
784
+ vision_feature_select_strategy = (
785
+ vision_feature_select_strategy
786
+ if vision_feature_select_strategy is not None
787
+ else self.config.vision_feature_select_strategy
788
+ )
789
+
790
+ if (input_ids is None) ^ (inputs_embeds is not None):
791
+ raise ValueError(
792
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
793
+ )
794
+
795
+ if pixel_values is not None and inputs_embeds is not None:
796
+ raise ValueError(
797
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
798
+ )
799
+
800
+ legacy_processing = False
801
+ has_image_input = pixel_values is not None and pixel_values.size(0) > 0 and pixel_values.norm() != 0
802
+ if inputs_embeds is None:
803
+ inputs_embeds = self.get_input_embeddings()(input_ids)
804
+
805
+ # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
806
+ # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
807
+ # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
808
+ # legacy_processing = (
809
+ # (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
810
+ # ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
811
+
812
+ legacy_processing = False # @ruimeng hardcode to False
813
+
814
+ if has_image_input:
815
+ # ! infer image_num_patches from image_sizes
816
+ image_num_patches = [
817
+ image_size_to_num_patches(
818
+ image_size=imsize,
819
+ grid_pinpoints=self.config.image_grid_pinpoints,
820
+ patch_size=self.config.vision_config.image_size,
821
+ )
822
+ for imsize in image_sizes
823
+ ]
824
+ # figure out if pixel_values is concatenated or stacked
825
+ if pixel_values.dim() == 5:
826
+ # stacking when input is (batch_size, num_patches, num_channels, height, width)
827
+ _pixel_values_list = [
828
+ pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
829
+ ]
830
+ pixel_values = torch.cat(_pixel_values_list, dim=0)
831
+ elif pixel_values.dim() != 4:
832
+ # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
833
+ raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
834
+
835
+ image_features = self.vision_tower(pixel_values, output_hidden_states=True)
836
+ selected_image_feature = image_features.hidden_states[vision_feature_layer]
837
+ if vision_feature_select_strategy == "default":
838
+ selected_image_feature = selected_image_feature[:, 1:]
839
+ elif vision_feature_select_strategy == "full":
840
+ selected_image_feature = selected_image_feature
841
+ image_features = self.multi_modal_projector(selected_image_feature)
842
+ image_features = torch.split(image_features, image_num_patches, dim=0)
843
+
844
+ # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
845
+ image_features, feature_lens = self.pack_image_features(
846
+ image_features,
847
+ image_sizes,
848
+ vision_feature_select_strategy=vision_feature_select_strategy,
849
+ image_newline=self.image_newline,
850
+ )
851
+ if legacy_processing:
852
+ logger.warning_once(
853
+ "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
854
+ "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
855
+ "with `processor.patch_size = {{patch_size}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
856
+ "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
857
+ )
858
+ if input_ids.shape[1] != 1:
859
+ inputs_embeds = inputs_embeds.to(image_features.dtype)
860
+ inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
861
+ image_features,
862
+ feature_lens,
863
+ inputs_embeds,
864
+ input_ids,
865
+ attention_mask,
866
+ position_ids,
867
+ labels=labels,
868
+ )
869
+ cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
870
+ else:
871
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
872
+ # that are set to 0
873
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
874
+
875
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
876
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
877
+
878
+ # Get the target length
879
+ target_length = input_ids.shape[1]
880
+ past_length = first_layer_past_key_value.shape[-1]
881
+
882
+ extended_attention_mask = torch.ones(
883
+ (attention_mask.shape[0], past_length),
884
+ dtype=attention_mask.dtype,
885
+ device=attention_mask.device,
886
+ )
887
+
888
+ # Filter out only the tokens that can be un-attended, this can happen
889
+ # if one uses Llava + Fused modules where the cache on the
890
+ # first iteration is already big enough, or if one passes custom cache
891
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
892
+ new_batch_index = batch_index[valid_indices]
893
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
894
+
895
+ # Zero-out the places where we don't need to attend
896
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
897
+ attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
898
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
899
+ cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
900
+ -target_length:
901
+ ]
902
+
903
+ # TODO: @raushan retain only the new behavior after v4.47
904
+ else:
905
+ special_image_mask = (
906
+ (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
907
+ )
908
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
909
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
910
+ outputs = self.language_model(
911
+ attention_mask=attention_mask,
912
+ position_ids=position_ids,
913
+ past_key_values=past_key_values,
914
+ inputs_embeds=inputs_embeds,
915
+ use_cache=use_cache,
916
+ output_attentions=output_attentions,
917
+ output_hidden_states=output_hidden_states,
918
+ return_dict=return_dict,
919
+ cache_position=cache_position,
920
+ num_logits_to_keep=num_logits_to_keep,
921
+ )
922
+
923
+ logits = outputs[0]
924
+
925
+ loss = None
926
+ if labels is not None:
927
+ # Shift so that tokens < n predict n
928
+ if attention_mask is not None:
929
+ shift_attention_mask = attention_mask[..., 1:]
930
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
931
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
932
+ else:
933
+ shift_logits = logits[..., :-1, :].contiguous()
934
+ shift_labels = labels[..., 1:].contiguous()
935
+ # Flatten the tokens
936
+ loss_fct = nn.CrossEntropyLoss()
937
+ loss = loss_fct(
938
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
939
+ )
940
+
941
+ if not return_dict:
942
+ output = (logits,) + outputs[1:]
943
+ return (loss,) + output if loss is not None else output
944
+
945
+ return LlavaNextCausalLMOutputWithPast(
946
+ loss=loss,
947
+ logits=logits,
948
+ past_key_values=outputs.past_key_values,
949
+ hidden_states=outputs.hidden_states,
950
+ attentions=outputs.attentions,
951
+ image_hidden_states=image_features if has_image_input else None,
952
+ )
953
+
954
+ def prepare_inputs_for_generation(
955
+ self,
956
+ input_ids,
957
+ past_key_values=None,
958
+ inputs_embeds=None,
959
+ pixel_values=None,
960
+ image_sizes=None,
961
+ attention_mask=None,
962
+ cache_position=None,
963
+ num_logits_to_keep=None,
964
+ **kwargs,
965
+ ):
966
+ legacy_processing = (
967
+ input_ids is not None
968
+ and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
969
+ )
970
+
971
+ model_inputs = self.language_model.prepare_inputs_for_generation(
972
+ input_ids,
973
+ past_key_values=past_key_values,
974
+ inputs_embeds=inputs_embeds,
975
+ attention_mask=attention_mask,
976
+ cache_position=cache_position,
977
+ num_logits_to_keep=num_logits_to_keep,
978
+ **kwargs,
979
+ )
980
+
981
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
982
+ # Otherwise we need pixel values to be passed to model
983
+ if legacy_processing or cache_position[0] == 0:
984
+ model_inputs["pixel_values"] = pixel_values
985
+ model_inputs["image_sizes"] = image_sizes
986
+
987
+ return model_inputs
src/vlm_backbone/llava_next/processing_llava_next.py ADDED
@@ -0,0 +1,244 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for LLaVa-NeXT.
17
+ """
18
+
19
+ from typing import List, Union
20
+
21
+ from transformers.feature_extraction_utils import BatchFeature
22
+ from transformers.image_processing_utils import select_best_resolution
23
+ from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
24
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
25
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
26
+ from transformers.utils import logging
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class LlavaNextProcessorKwargs(ProcessingKwargs, total=False):
33
+ _defaults = {
34
+ "text_kwargs": {
35
+ "padding": False,
36
+ },
37
+ "images_kwargs": {
38
+ "do_pad": True,
39
+ },
40
+ }
41
+
42
+
43
+ class LlavaNextProcessor(ProcessorMixin):
44
+ r"""
45
+ Constructs a LLaVa-NeXT processor which wraps a LLaVa-NeXT image processor and a LLaMa tokenizer into a single processor.
46
+
47
+ [`LlavaNextProcessor`] offers all the functionalities of [`LlavaNextImageProcessor`] and [`LlamaTokenizerFast`]. See the
48
+ [`~LlavaNextProcessor.__call__`] and [`~LlavaNextProcessor.decode`] for more information.
49
+
50
+ Args:
51
+ image_processor ([`LlavaNextImageProcessor`], *optional*):
52
+ The image processor is a required input.
53
+ tokenizer ([`LlamaTokenizerFast`], *optional*):
54
+ The tokenizer is a required input.
55
+ patch_size (`int`, *optional*):
56
+ Patch size from the vision tower.
57
+ vision_feature_select_strategy (`str`, *optional*):
58
+ The feature selection strategy used to select the vision feature from the vision backbone.
59
+ Should be the same as in the model's config.
60
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
61
+ in a chat into a tokenizable string.
62
+ image_token (`str`, *optional*, defaults to `"<image>"`):
63
+ Special token used to denote image location.
64
+ num_additional_image_tokens (`int`, *optional*, defaults to 0):
65
+ Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other
66
+ extra tokens appended, no need to set this arg.
67
+ """
68
+
69
+ attributes = ["image_processor", "tokenizer"]
70
+ valid_kwargs = [
71
+ "chat_template",
72
+ "patch_size",
73
+ "vision_feature_select_strategy",
74
+ "image_token",
75
+ "num_additional_image_tokens",
76
+ ]
77
+ image_processor_class = "AutoImageProcessor"
78
+ tokenizer_class = "AutoTokenizer"
79
+
80
+ def __init__(
81
+ self,
82
+ image_processor=None,
83
+ tokenizer=None,
84
+ patch_size=None,
85
+ vision_feature_select_strategy=None,
86
+ chat_template=None,
87
+ image_token="<image>", # set the default and let users change if they have peculiar special tokens in rare cases
88
+ num_additional_image_tokens=0,
89
+ **kwargs,
90
+ ):
91
+ self.patch_size = patch_size
92
+ self.num_additional_image_tokens = num_additional_image_tokens
93
+ self.vision_feature_select_strategy = vision_feature_select_strategy
94
+ self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
95
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
96
+
97
+ def __call__(
98
+ self,
99
+ images: ImageInput = None,
100
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
101
+ audio=None,
102
+ videos=None,
103
+ **kwargs: Unpack[LlavaNextProcessorKwargs],
104
+ ) -> BatchFeature:
105
+ """
106
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
107
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
108
+ the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
109
+ LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
110
+ of the above two methods for more information.
111
+
112
+ Args:
113
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
114
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
115
+ tensor. Both channels-first and channels-last formats are supported.
116
+ text (`str`, `List[str]`, `List[List[str]]`):
117
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
118
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
119
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
120
+
121
+ Returns:
122
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
123
+
124
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
125
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
126
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
127
+ `None`).
128
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
129
+ """
130
+ if images is None and text is None:
131
+ raise ValueError("You have to specify at least images or text.")
132
+ # check if images and text inputs are reversed for BC
133
+ images, text = _validate_images_text_input_order(images, text)
134
+
135
+ output_kwargs = self._merge_kwargs(
136
+ LlavaNextProcessorKwargs,
137
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
138
+ **kwargs,
139
+ )
140
+ if images is not None:
141
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
142
+ else:
143
+ image_inputs = {}
144
+
145
+ if isinstance(text, str):
146
+ text = [text]
147
+ elif not isinstance(text, list) and not isinstance(text[0], str):
148
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
149
+
150
+ prompt_strings = text
151
+ if image_inputs:
152
+ if self.patch_size is None or self.vision_feature_select_strategy is None:
153
+ logger.warning_once(
154
+ "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
155
+ "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
156
+ "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
157
+ "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
158
+ )
159
+ else:
160
+ image_sizes = iter(image_inputs["image_sizes"])
161
+ height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
162
+ prompt_strings = []
163
+ for sample in text:
164
+ while self.image_token in sample:
165
+ image_size = next(image_sizes)
166
+ if not isinstance(image_size, (list, tuple)):
167
+ # cast to list to avoid numerical precision errors when calculating unpadding
168
+ image_size = image_size.tolist()
169
+ orig_height, orig_width = image_size
170
+ num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
171
+ if self.vision_feature_select_strategy == "default":
172
+ num_image_tokens -= self.num_additional_image_tokens
173
+ sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
174
+ prompt_strings.append(sample)
175
+ prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
176
+
177
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
178
+
179
+ return BatchFeature(data={**text_inputs, **image_inputs})
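+ # Minimal usage sketch (the checkpoint name follows the example in modeling_llava_next.py and is
+ # only illustrative for this vendored copy):
+ #   processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+ #   inputs = processor(images=image, text="[INST] <image>\nWhat is shown in this image? [/INST]",
+ #                      return_tensors="pt")
+ #   # -> BatchFeature with input_ids, attention_mask, pixel_values and image_sizes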
180
+
181
+ def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
182
+ image_grid_pinpoints = self.image_processor.image_grid_pinpoints
183
+
184
+ height_best_resolution, width_best_resolution = select_best_resolution(
185
+ [orig_height, orig_width], image_grid_pinpoints
186
+ )
187
+ scale_height, scale_width = height_best_resolution // height, width_best_resolution // width
188
+
189
+ patches_height = height // self.patch_size
190
+ patches_width = width // self.patch_size
191
+ unpadded_features, newline_features = self._get_unpadded_features(
192
+ orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
193
+ )
194
+ # The base patch covers the entire image (+1 for the CLS)
195
+ base_features = patches_height * patches_width + self.num_additional_image_tokens
196
+ num_image_tokens = unpadded_features + newline_features + base_features
197
+ return num_image_tokens
198
+
199
+ def _get_unpadded_features(self, height, width, patches_height, patches_width, scale_height, scale_width):
200
+ """
201
+ Get number of features for a given image with height/width. LLaVA-NeXT is different from LLaVA
202
+ because it divides each image into patches depending on its resolution. Therefore we need to calculate how many
203
+ patches an image is divided into and get the number of features from that.
204
+ """
205
+ current_height = patches_height * scale_height
206
+ current_width = patches_width * scale_width
207
+
208
+ original_aspect_ratio = width / height
209
+ current_aspect_ratio = current_width / current_height
210
+ if original_aspect_ratio > current_aspect_ratio:
211
+ new_height = (height * current_width) // width
212
+ padding = (current_height - new_height) // 2
213
+ current_height -= padding * 2
214
+ else:
215
+ new_width = (width * current_height) // height
216
+ padding = (current_width - new_width) // 2
217
+ current_width -= padding * 2
218
+
219
+ unpadded_features = current_height * current_width
220
+ newline_features = current_height
221
+ return (unpadded_features, newline_features)
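+ # Worked example (hypothetical numbers): an original image of (height, width) = (400, 1000) with
+ # patches_height = patches_width = 24 and a best resolution of 336x672 (scale_height=1, scale_width=2)
+ # gives a 24x48 grid; since 1000/400 = 2.5 > 48/24 = 2.0, new_height = (400 * 48) // 1000 = 19,
+ # padding = (24 - 19) // 2 = 2, current_height = 24 - 4 = 20, so unpadded_features = 20 * 48 = 960 and
+ # newline_features = 20. With num_additional_image_tokens = 0, _get_number_of_features then returns
+ # 960 + 20 + 24 * 24 = 1556 image tokens.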
222
+
223
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
224
+ def batch_decode(self, *args, **kwargs):
225
+ """
226
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
227
+ refer to the docstring of this method for more information.
228
+ """
229
+ return self.tokenizer.batch_decode(*args, **kwargs)
230
+
231
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
232
+ def decode(self, *args, **kwargs):
233
+ """
234
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
235
+ the docstring of this method for more information.
236
+ """
237
+ return self.tokenizer.decode(*args, **kwargs)
238
+
239
+ @property
240
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
241
+ def model_input_names(self):
242
+ tokenizer_input_names = self.tokenizer.model_input_names
243
+ image_processor_input_names = self.image_processor.model_input_names
244
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))