# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import nn

from transformers import (
    AutoModel,
    Cache,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.generation import GenerationMixin
from transformers.masking_utils import (
    create_causal_mask,
    create_masks_for_generate,
    create_sliding_window_causal_mask,
)
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3CausalLMOutputWithPast,
    Gemma3ModelOutputWithPast,
    Gemma3PreTrainedModel,
    Gemma3RMSNorm,
    Gemma3TextScaledWordEmbedding as _OrigEmb,
)
from transformers.utils import (
    is_torch_flex_attn_available,
    is_torchdynamo_compiling,
    logging,
)

try:
    from liger_kernel.transformers.fused_linear_cross_entropy import (
        LigerFusedLinearCrossEntropyLoss,
    )
except Exception:
    LigerFusedLinearCrossEntropyLoss = None

from .configuration_gemma3_omni import Gemma3OmniConfig
from .speech_conformer_encoder import ConformerEncoder

logger = logging.get_logger(__name__)

if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask


# Monkey-patch Gemma3TextScaledWordEmbedding so the embedding layer hands out a cloned
# tensor: the multimodal merge below overwrites placeholder positions in `inputs_embeds`,
# and the clone guarantees this never touches storage shared with the embedding module.
_orig_scaled_embedding_forward = _OrigEmb.forward


def _patched_forward(self: _OrigEmb, input_ids: torch.LongTensor) -> torch.Tensor:
    return _orig_scaled_embedding_forward(self, input_ids).clone()


_OrigEmb.forward = _patched_forward


class Gemma3AudioProjectorConfig(PretrainedConfig):
    model_type = "gemma3_audio"

    def __init__(
        self,
        hidden_size: int = 1024,
        num_hidden_layers: int = 24,
        sample_rate: int = 16_000,
        n_mels: int = 80,
        image_token_index: int = 0,
        n_fft: int = 400,
        hop_length: int = 160,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.image_token_index = image_token_index
        self.n_fft = n_fft
        self.hop_length = hop_length


class LayerWiseWeightedSum(nn.Module):
    """Softmax-weighted sum over the hidden states of all encoder layers."""

    def __init__(self, num_layers: int, learnable: bool = True):
        super().__init__()
        self.num_layers = num_layers
        if learnable:
            self.scalar_weights = nn.Parameter(torch.zeros(num_layers))
        else:
            self.register_buffer("scalar_weights", torch.zeros(num_layers))

    def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
        if len(hidden_states) != self.num_layers:
            raise ValueError(
                f"Expected {self.num_layers} hidden states, but got {len(hidden_states)}"
            )
        # (num_layers, 1, 1, 1) weights broadcast over (num_layers, batch, time, dim).
        norm_weights = torch.softmax(self.scalar_weights, dim=0).view(-1, 1, 1, 1)
        stacked_states = torch.stack(hidden_states, dim=0)
        return (norm_weights * stacked_states).sum(dim=0)


class Gemma3AudioProjector(PreTrainedModel):
    config_class = Gemma3AudioProjectorConfig
    base_model_prefix = "audio_projector"

    def __init__(self, config: Gemma3AudioProjectorConfig):
        super().__init__(config)
        # Conformer encoder hyperparameters. Note that the depth is fixed at 24 blocks
        # here; `config.num_hidden_layers` is not applied to the encoder.
        encoder_config = {
            "activation": "swish",
            "activation_checkpointing": "",
            "attention_dim": 1024,
            "attention_heads": 16,
            "batch_norm": False,
            "bias_in_glu": True,
            "causal": True,
            "chunk_size": -1,
            "conv_activation": "swish",
            "conv_glu_type": "swish",
            "depthwise_multiplier": 1,
            "depthwise_seperable_out_channel": 1024,
            "dropout_rate": 0.0,
            "encoder_embedding_config": {"input_size": config.n_mels},
            "ext_pw_kernel_size": 1,
            "ext_pw_out_channel": 1024,
            "input_layer": "nemo_conv",
            "input_size": config.n_mels,
            "kernel_size": 3,
            "left_chunk": 18,
            "linear_units": 1536,
            "nemo_conv_settings": {"conv_channels": 1024},
            "num_blocks": 24,
            "relative_attention_bias_args": {"t5_bias_max_distance": 500, "type": "t5"},
            "time_reduction": 8,
        }
        self.encoder = ConformerEncoder(**encoder_config)
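        # Mix all Conformer block outputs with learned softmax weights, RMS-normalize,
        # then project into the language model's embedding space.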
        self.layer_weighter = LayerWiseWeightedSum(num_layers=encoder_config["num_blocks"])
        self.norm = Gemma3RMSNorm(encoder_config["attention_dim"], eps=1e-6)
        self.proj = nn.Linear(encoder_config["attention_dim"], config.hidden_size, bias=False)

    def forward(self, mel: torch.Tensor, mel_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        _, out_mask, hidden_list = self.encoder(mel, mel_mask)
        features = self.layer_weighter(hidden_list)
        normalized_features = self.norm(features)
        projected_features = self.proj(normalized_features)
        return projected_features, out_mask


class Gemma3VisionProjector(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
        )
        self.mm_soft_emb_norm = Gemma3RMSNorm(
            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
        )
        self.patches_per_image = config.vision_config.image_size // config.vision_config.patch_size
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)

    def forward(self, vision_outputs: torch.Tensor):
        # vision_outputs: (batch, num_patches, vision_hidden). After the transpose the
        # unpacked `seq_len` is actually the hidden dimension.
        b, _, seq_len = vision_outputs.shape
        x = vision_outputs.transpose(1, 2).reshape(
            b, seq_len, self.patches_per_image, self.patches_per_image
        )
        # Average-pool the patch grid down to `mm_tokens_per_image` soft tokens.
        x = self.avg_pool(x).flatten(2).transpose(1, 2)
        x = self.mm_soft_emb_norm(x)
        return torch.matmul(x, self.mm_input_projection_weight).type_as(vision_outputs)


def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
    if token_type_ids is None:
        return None

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # Allow attention to any non-text (image/audio) key position.
        return token_type_ids[batch_idx, kv_idx] != 0

    return inner_mask


class Gemma3OmniModel(Gemma3PreTrainedModel):
    config_class = Gemma3OmniConfig

    def __init__(self, config):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = Gemma3VisionProjector(config)
        self.audio_projector = Gemma3AudioProjector(
            Gemma3AudioProjectorConfig(
                hidden_size=config.text_config.hidden_size,
                n_mels=config.audio_config.n_mels,
                num_hidden_layers=config.audio_config.num_hidden_layers,
            )
        )
        self.vocab_size = config.text_config.vocab_size
        self.language_model = AutoModel.from_config(config=config.text_config)
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        input_audio_embeds: Optional[torch.FloatTensor] = None,
        audio_attention_mask: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **lm_kwargs,
    ) -> Union[Tuple, Gemma3ModelOutputWithPast]:
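        """Embed the text tokens, splice in projected image/audio features, and run the decoder.

        Positions holding the image placeholder token (``config.image_token_id``) are
        overwritten with features from the vision tower + ``Gemma3VisionProjector``, and
        positions holding the audio placeholder token (``config.audio_token_index``) with
        features from ``Gemma3AudioProjector``. The merged embeddings are then passed to
        the Gemma3 text decoder.
        """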
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The image placeholder id may lie outside the text vocabulary; map it to a valid
        # id (0) for the embedding lookup. Those positions are overwritten below anyway.
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_id
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # Merge image features into the text embeddings (prefill only).
        if pixel_values is not None and past_key_values is None:
            vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
            image_features = self.multi_modal_projector(vision_outputs.hidden_states[-1])

            if input_ids is None:
                raise ValueError("`input_ids` are required when `pixel_values` are provided.")

            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                image_tokens_in_text = special_image_mask.sum(dim=1).sum(dim=0)[0]
                raise ValueError(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {image_tokens_in_text} image tokens in the text but "
                    f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings."
                )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features).contiguous()

        # Merge audio features into the text embeddings (prefill only).
        if input_audio_embeds is not None and past_key_values is None:
            audio_features, audio_feat_mask = self.audio_projector(input_audio_embeds, audio_attention_mask)

            if input_ids is None:
                raise ValueError("`input_ids` are required when `input_audio_embeds` are provided.")

            special_audio_mask = (input_ids == self.config.audio_token_index).unsqueeze(-1)
            special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            if (
                not is_torchdynamo_compiling()
                and inputs_embeds[special_audio_mask].numel() != audio_features.numel()
            ):
                audio_tokens_in_text = special_audio_mask.sum(dim=1).sum(dim=0)[0]
                raise ValueError(
                    f"Number of audio tokens in the text ({audio_tokens_in_text}) does not match "
                    f"the number of tokens from audio embeddings "
                    f"({audio_features.shape[0] * audio_features.shape[1]})."
                )
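            # Scatter the projected audio frames into the audio placeholder positions;
            # features are flattened over (batch, time) to match masked_scatter's
            # element-wise filling order.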
            audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features.flatten(0, 1))

        # Build the per-layer attention masks (full + sliding window) unless the caller
        # already passed a prepared mask mapping.
        if not isinstance(attention_mask, dict):
            mask_kwargs = {
                "config": self.config.get_text_config(),
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
            }
            if token_type_ids is not None and inputs_embeds.shape[1] != 1:
                # Let multimodal (non-text) tokens be attended to bidirectionally.
                mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
                    token_type_ids.to(cache_position.device)
                )
            attention_mask = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **lm_kwargs,
        )

        image_hidden_states = None
        if pixel_values is not None and past_key_values is None:
            image_hidden_states = vision_outputs.hidden_states[-1]

        return Gemma3ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values if use_cache else None,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_hidden_states,
        )


class Gemma3OmniForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
    config_class = Gemma3OmniConfig
    _checkpoint_conversion_mapping = {
        "^language_model.model": "model.language_model",
        "^vision_tower": "model.vision_tower",
        "^multi_modal_projector": "model.multi_modal_projector",
        "^language_model.lm_head": "lm_head",
    }
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Gemma3OmniModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        input_audio_embeds: Optional[torch.FloatTensor] = None,
        audio_attention_mask: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **lm_kwargs,
    ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
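        # Multimodal fusion and the decoder pass happen inside Gemma3OmniModel; this
        # wrapper only applies the LM head and, when labels are given, computes the loss.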
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            input_audio_embeds=input_audio_embeds,
            audio_attention_mask=audio_attention_mask,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **lm_kwargs,
        )

        hidden_states = outputs[0]
        # Only compute logits for the requested trailing positions (0 keeps them all).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            if LigerFusedLinearCrossEntropyLoss is not None:
                # Fused linear + cross-entropy path: shift hidden states and labels for
                # next-token prediction and let the fused kernel apply lm_head.weight.
                shift_hidden_states = hidden_states[..., :-1, :]
                shift_labels = labels[..., 1:]
                hidden_device = shift_hidden_states.device
                if attention_mask is not None:
                    shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1]:].to(hidden_device)
                    shift_hidden_states = shift_hidden_states[shift_attention_mask != 0].contiguous()
                    shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
                else:
                    shift_hidden_states = shift_hidden_states.contiguous()
                    shift_labels = shift_labels.contiguous()
                shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
                shift_labels = shift_labels.view(-1).to(hidden_device)
                loss_fct = LigerFusedLinearCrossEntropyLoss()
                loss = loss_fct(self.lm_head.weight, shift_hidden_states, shift_labels)
            else:
                # Standard cross-entropy fallback on the full logits.
                logits = logits.float()
                shift_logits = logits[..., :-1, :]
                shift_labels = labels[..., 1:]
                if attention_mask is not None:
                    shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:].to(logits.device)
                    shift_logits = shift_logits[shift_attention_mask != 0].contiguous()
                    shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
                else:
                    shift_logits = shift_logits.contiguous()
                    shift_labels = shift_labels.contiguous()
                flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
                flat_labels = shift_labels.view(-1).to(shift_logits.device)
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(flat_logits, flat_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Gemma3CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )


__all__ = [
    "Gemma3AudioProjectorConfig",
    "Gemma3AudioProjector",
    "Gemma3VisionProjector",
    "Gemma3OmniModel",
    "Gemma3OmniForConditionalGeneration",
]
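# Minimal usage sketch (assumptions: a checkpoint saved with this code exists at
# "path/to/checkpoint" and the accompanying processor in this repo produces the
# `input_ids` / `pixel_values` / `input_audio_embeds` tensors; names are illustrative):
#
#   config = Gemma3OmniConfig.from_pretrained("path/to/checkpoint")
#   model = Gemma3OmniForConditionalGeneration(config)
#   outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
#   outputs.loss.backward()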