from typing import Any, Callable
from typing import cast as type_cast

import torch
from transformers.cache_utils import DynamicCache
from transformers.configuration_utils import PretrainedConfig
from transformers.generation.utils import GenerateOutput
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionTransformerPretrainedModel,
)

from .image_encoder import Qwen25VLEncoder
from .configuration_helium1_casa import Helium1CASAConfig
from .language_helium1_casa import (
    CausalHeliumOutput,
    Helium1CASAAttention,
    Helium1ForCausalLM,
    Helium1RMSNorm,
)


def meta_project(
    logits: torch.Tensor | list[torch.Tensor],
    projector: torch.nn.Module,
    norm: torch.nn.Module | None = None,
) -> torch.Tensor | list[torch.Tensor]:
    """Projection operation that handles both a single tensor and a list of tensors.

    Outputs either an (N, S, D) tensor (same-resolution images) or a list of N (S, D)
    tensors, where S can be a different sequence length per image.
    """
    split_sizes: list[int] | None = None
    if not isinstance(logits, torch.Tensor):
        split_sizes = [_x.shape[0] for _x in logits]
        logits = torch.cat(logits, dim=0)[None, :, :]
    logits = type_cast(torch.Tensor, logits)
    logits = projector(logits)

    assert isinstance(logits, torch.Tensor)
    if norm is not None:
        logits = norm(logits)
    if split_sizes is not None:
        return list(torch.split(type_cast(torch.Tensor, logits[0]), split_sizes, dim=0))
    return logits


class ImageProjection(torch.nn.Module):
    """Takes in a batch or sequence of images and returns embeddings
    which are then fed to the LM.

    :param config: Model configuration exposing a `vision_config` (e.g. Helium1CASAConfig)
    :param lm_model_dim: Output dimension (number of channels) for this module
    """

    def __init__(self, config: PretrainedConfig, lm_model_dim: int) -> None:
        super().__init__()
        self.config = config
        self.out_dim = lm_model_dim
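        # Qwen2.5-VL vision transformer wrapped by Qwen25VLEncoder, followed by a linear
        # projection into the LM hidden size and an RMSNorm.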
        visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
        self.enc = Qwen25VLEncoder(visual=visual)
        self.proj_extra = self.init_proj_module()
        self.norm_extra = Helium1RMSNorm(self.out_dim)

    def init_proj_module(self) -> torch.nn.Module:
        """Init the projection module for the inserted and/or cross-attended image tokens"""
        if self.config.vision_config.out_dim == self.out_dim:
            return torch.nn.Identity()
        return torch.nn.Linear(self.config.vision_config.out_dim, self.out_dim)

    def forward(
        self, x: torch.Tensor | list[torch.Tensor]
    ) -> dict[str, torch.Tensor | list[torch.Tensor]]:
        """Image embedding mapping

        :param x: Either a tensor with shape (Bi, C, H, W) or a list of Bi tensors
            with shape (C, H, W) (or (H, W, C) in the case of Qwen)

        :return: Either a tensor with shape (num_total_images, S, D) or, if images
            can have different sequence lengths, a list of `num_total_images` tensors
            with shape (S, D)
        """
        og_dtype = x[0].dtype
        encoded = self.enc(x)["image_embeds"]
        encoded = [_x.to(og_dtype) for _x in encoded]
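        # If every image produced the same number of tokens, stack into a single
        # (N, S, D) tensor; otherwise keep the ragged per-image list.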
        if all(_x.shape[0] == encoded[0].shape[0] for _x in encoded):
            encoded = torch.stack(encoded, dim=0)
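
        # Project into the LM embedding space and RMS-normalize.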
        image_embeds = meta_project(encoded, self.proj_extra, self.norm_extra)
        return {"image_embeds": image_embeds}


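# Helium-1 causal LM with an image prefix: depending on `config.casa_attention`, image
# embeddings are either inserted inline into the token sequence (see the cache handling
# in `forward`) or consumed by the CASA attention layers.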
class V2Helium1(Helium1ForCausalLM):
    config_class = Helium1CASAConfig

    def __init__(self, config: Helium1CASAConfig, **kwargs: Any) -> None:
        del kwargs
        super().__init__(config)
        self.image_prefix = ImageProjection(config=config, lm_model_dim=self.token_dim)

|
| | def get_device(self) -> str: |
| | """Return the device type of the model""" |
| | return next(self.parameters()).device.type |
| |
|
| | @property |
| | def token_dim(self) -> int: |
| | """Returns the number of dimensions for the token representation""" |
| | return self.config.hidden_size |
| |
|
| | @property |
| | def rotary_embed(self) -> Callable: |
| | """Returns the rotary embedding function of the underlying model""" |
| | return self.model.rotary_emb |
| |
|
    def _update_model_kwargs_for_generation(
        self,
        outputs: Any,
        model_kwargs: dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> dict[str, Any]:
        """This is required to handle multiple generate calls for subtitles"""
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )
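        # After the first forward pass the prompt (and any images) are already in the KV
        # cache, so subsequent steps must skip the multimodal prefill (see `forward`).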
        model_kwargs["__is_first_gen_call__"] = False
        return model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        past_key_values: DynamicCache | None = None,
        **kwargs: Any,
    ):
        __is_first_gen_call__ = kwargs.get("__is_first_gen_call__", True)
        if past_key_values is not None and (
            kwargs.get("cache_position") is None
            or type_cast(torch.Tensor, kwargs.get("cache_position")).shape[0] == 0
        ):
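            # Rebuild cache_position from the number of tokens already in the cache:
            # the full prompt length on the first call, a single new token afterwards.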
            past_length = past_key_values._seen_tokens
            kwargs["cache_position"] = torch.arange(
                past_length,
                past_length + (input_ids.shape[1] if __is_first_gen_call__ else 1),
                dtype=torch.long,
                device=input_ids.device,
            )

        return super().prepare_inputs_for_generation(
            type_cast(torch.LongTensor, input_ids),
            past_key_values=past_key_values,
            **kwargs,
        )

    def prepare_multimodal_inputs(
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        labels: torch.Tensor | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        **_kwargs: Any,
    ) -> dict:
        """Get a batch of data mixing text and image data"""
        del _kwargs

        processed_inputs = {
            "input_ids": input_ids,
            "inputs_embeds": inputs_embeds,
            "labels": labels,
            "attention_mask": attention_mask,
            "image_embeds_insertion_points": image_embeds_insertion_points,
        }
        if pixel_values is not None:
            processed_inputs.update(self.image_prefix(pixel_values))
            assert "image_embeds" in processed_inputs
            assert (
                isinstance(processed_inputs["image_embeds"], torch.Tensor)
                and processed_inputs["image_embeds"].ndim == 3
            ) or (
                isinstance(processed_inputs["image_embeds"], list)
                and all(_x.ndim == 2 for _x in processed_inputs["image_embeds"])
            )

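        # Counts of the fixed tokens placed before/after each image, forwarded as CASA
        # window information.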
        processed_inputs["casa_windows_info"] = {
            "num_post_image_tokens": 0 if post_image_tokens is None else len(post_image_tokens),
            "num_pre_image_tokens": 0 if pre_image_tokens is None else len(pre_image_tokens),
        }

        return processed_inputs

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        return_loss: bool = True,
        labels: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        pre_image_tokens: list[int] | None = None,
        post_image_tokens: list[int] | None = None,
        **kwargs: Any,
    ) -> CausalHeliumOutput:
        """Multi-modal forward pass"""
        assert input_ids is not None or inputs_embeds is not None

        if self.training:
            assert return_loss is True, (
                "Helium models always compute their own labels/losses in train mode"
            )

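        # First call (training, or the prefill step of generation): run the multimodal
        # preprocessing that encodes images and bundles them with the text inputs. On
        # later decode steps the images already live in the KV cache, so only the new
        # token needs to be embedded.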
        if kwargs.get("__is_first_gen_call__", True):
            processed_inputs = self.prepare_multimodal_inputs(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                image_embeds_insertion_points=image_embeds_insertion_points,
                pixel_values=pixel_values,
                labels=labels,
                pre_image_tokens=pre_image_tokens,
                post_image_tokens=post_image_tokens,
            )
            processed_inputs.pop("inputs_embeds", None)
        else:
            processed_inputs = {
                "inputs_embeds": self.model.embed_tokens(input_ids),
                "attention_mask": attention_mask,
            }

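        # Non-CASA path: image embeddings are spliced directly into the token sequence, so
        # the KV cache also holds image tokens. Shift cache_position by the number of
        # image tokens so that text positions stay consistent across generation steps
        # (the `// 4` presumably accounts for the vision encoder's 2x2 patch merging).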
        if (
            not self.config.casa_attention
            and (cp := kwargs.get("cache_position", None)) is not None
            and pixel_values is not None
        ):
            start = cp[0].item()
            num_image_tokens = (pixel_values[0].shape[0] * pixel_values[0].shape[1]) // 4
            num_tokens = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
            kwargs["cache_position"] = torch.arange(
                start + (0 if kwargs.get("__is_first_gen_call__", True) else num_image_tokens),
                start + num_tokens + num_image_tokens,
                dtype=cp.dtype,
                device=cp.device,
            )

        kwargs.pop("__is_first_gen_call__", True)
        out = super().forward(
            **processed_inputs,
            **kwargs,
        )
        return out

    @torch.no_grad()
    def generate_from_image(
        self,
        input_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_embeds_insertion_points: list[torch.Tensor] | None = None,
        pixel_values: torch.Tensor | list[torch.Tensor] | None = None,
        reset_streaming: bool = True,
        **kwargs: Any,
    ) -> "GenerateOutput | torch.LongTensor":
        assert input_ids is not None and inputs_embeds is None, (
            "Input IDs must be provided for generation"
        )

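        # Use an explicit DynamicCache so the image/prompt prefill is reused across
        # decoding steps (or across calls, if the caller passes the cache back in).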
        if kwargs.get("past_key_values", None) is None:
            kwargs["past_key_values"] = DynamicCache()

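        # Fall back to the EOS token as pad token when none is given (taking the first id
        # if `eos_token_id` is a list).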
        if kwargs.get("pad_token_id", None) is None:
            kwargs["pad_token_id"] = kwargs.get("eos_token_id", None)
            if isinstance(kwargs["pad_token_id"], (list, tuple)):
                kwargs["pad_token_id"] = kwargs["pad_token_id"][0]

        self.start_casa_streaming_states()
        outputs = self.generate(
            input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_embeds_insertion_points=image_embeds_insertion_points,
            use_cache=True,
            **kwargs,
        )
        if reset_streaming:
            self.reset_casa_streaming_states()
        return outputs

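    # Helpers toggling the per-layer streaming state of the CASA attention modules.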
    def reset_casa_streaming_states(self, clean_cache: bool = True) -> None:
        def __reset__(m: torch.nn.Module):
            if isinstance(m, Helium1CASAAttention):
                m._set_streaming(False, ())
                m.reset_streaming()
                if clean_cache:
                    del m.streaming_state.k
                    del m.streaming_state.v
                    del m.streaming_state.casa_handler

        self.apply(__reset__)

    def start_casa_streaming_states(self) -> None:
        def __start__(m: torch.nn.Module):
            if isinstance(m, Helium1CASAAttention):
                m._set_streaming(True, ())

        self.apply(__start__)
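

# Minimal usage sketch (illustrative only, not executed): assumes a V2Helium1 checkpoint
# and a processor producing `input_ids`, `pixel_values`, and
# `image_embeds_insertion_points` in the format expected above; the checkpoint path and
# the `processor` object are placeholders.
#
#   model = V2Helium1.from_pretrained("path/to/checkpoint", torch_dtype=torch.bfloat16)
#   batch = processor(images=[image], text=prompt, return_tensors="pt")
#   out = model.generate_from_image(
#       input_ids=batch["input_ids"],
#       attention_mask=batch.get("attention_mask"),
#       pixel_values=batch["pixel_values"],
#       image_embeds_insertion_points=batch["image_embeds_insertion_points"],
#       max_new_tokens=64,
#   )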