"""Configuration for the ``gemma3omni`` model type."""

from typing import Any, Dict, Optional, Union

from transformers import Gemma3TextConfig, PretrainedConfig, SiglipVisionConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Gemma3OmniConfig(PretrainedConfig):
    """Composite config holding a Gemma3 text config and a SigLIP vision config.

    Besides the two sub-configs, it records the special-token indices used as
    image/audio placeholders and (presumably) begin/end-of-image markers. The
    ``*_token_id`` names declared in ``attribute_map`` are aliases for the
    matching ``*_token_index`` attributes.

    Args:
        text_config: A ``Gemma3TextConfig``, a dict of keyword arguments for
            one, or ``None`` to fall back to a default ``Gemma3TextConfig``.
        vision_config: A ``SiglipVisionConfig``, a dict of keyword arguments
            for one, or ``None`` to fall back to a default
            ``SiglipVisionConfig``.
        mm_tokens_per_image: Stored unchanged; presumably the number of
            multimodal tokens per image (not used in this file — confirm).
        boi_token_index: Stored as ``boi_token_index`` (alias ``boi_token_id``).
        eoi_token_index: Stored as ``eoi_token_index`` (alias ``eoi_token_id``).
        image_token_index: Stored as ``image_token_index``
            (alias ``image_token_id``).
        audio_token_index: Stored as ``audio_token_index``
            (alias ``audio_token_id``).
        initializer_range: Stored unchanged; presumably the std used for
            weight initialisation by the model code — confirm.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): an audio token index is configured, but no audio sub-config
    is declared in ``sub_configs``.
    """

    model_type = "gemma3omni"

    # Alias name -> canonical attribute name.
    attribute_map = dict(
        image_token_id="image_token_index",
        audio_token_id="audio_token_index",
        boi_token_id="boi_token_index",
        eoi_token_id="eoi_token_index",
    )

    # Nested config classes, keyed by the attribute that holds each one.
    sub_configs = dict(
        text_config=Gemma3TextConfig,
        vision_config=SiglipVisionConfig,
    )

    def __init__(
        self,
        text_config: Optional[Union[Gemma3TextConfig, Dict[str, Any]]] = None,
        vision_config: Optional[Union[SiglipVisionConfig, Dict[str, Any]]] = None,
        mm_tokens_per_image: int = 256,
        boi_token_index: int = 255_999,
        eoi_token_index: int = 256_000,
        image_token_index: int = 262_152,
        audio_token_index: int = 262_151,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        # Normalise the text sub-config: dict -> instance, None -> default.
        if isinstance(text_config, dict):
            resolved_text = Gemma3TextConfig(**text_config)
        elif text_config is None:
            resolved_text = Gemma3TextConfig()
            logger.info("text_config is None, using default Gemma3TextConfig text config.")
        else:
            resolved_text = text_config

        # Same normalisation for the vision sub-config.
        if vision_config is None:
            resolved_vision = SiglipVisionConfig()
            logger.info("vision_config is None, using default SiglipVisionConfig vision config.")
        elif isinstance(vision_config, dict):
            resolved_vision = SiglipVisionConfig(**vision_config)
        else:
            resolved_vision = vision_config

        # Keep this assignment order; the base initializer runs last and
        # receives any extra keyword arguments.
        self.text_config = resolved_text
        self.vision_config = resolved_vision
        self.mm_tokens_per_image = mm_tokens_per_image
        self.boi_token_index = boi_token_index
        self.eoi_token_index = eoi_token_index
        self.image_token_index = image_token_index
        self.audio_token_index = audio_token_index
        self.initializer_range = initializer_range

        super().__init__(**kwargs)