sagar007 committed
Commit f4f545d · verified · 1 parent: 18b63c5

Upload folder using huggingface_hub
app.py CHANGED
@@ -12,7 +12,19 @@ from PIL import Image
 import io
 import time
 import logging
-from huggingface_hub import hf_hub_download
+import os
+from huggingface_hub import hf_hub_download, login
+
+# Try to login with HF token if available (for Spaces with secrets)
+try:
+    hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
+    if hf_token:
+        login(token=hf_token)
+        print("✅ Logged in to Hugging Face")
+    else:
+        print("⚠️ No HF token found - will try to load anyway")
+except Exception as e:
+    print(f"⚠️ HF login failed: {e}")
 
 # Model imports
 from src.models import MultimodalGemmaLightning
@@ -39,103 +51,38 @@ def download_and_load_model():
         cache_dir="./model_cache"
     )
 
-    print("📁 Loading model from checkpoint...")
+    # Download config files (same as local setup)
+    model_config_path = hf_hub_download(
+        repo_id="sagar007/multimodal-gemma-270m-llava",
+        filename="configs/model_config.yaml",
+        cache_dir="./model_cache"
+    )
+    training_config_path = hf_hub_download(
+        repo_id="sagar007/multimodal-gemma-270m-llava",
+        filename="configs/training_config.yaml",
+        cache_dir="./model_cache"
+    )
+    data_config_path = hf_hub_download(
+        repo_id="sagar007/multimodal-gemma-270m-llava",
+        filename="configs/data_config.yaml",
+        cache_dir="./model_cache"
+    )
 
-    # Load checkpoint data to inspect what's inside
-    checkpoint = torch.load(checkpoint_path, map_location="cpu")
-    print(f"Checkpoint keys: {list(checkpoint.keys())}")
-
-    # Extract the saved hyperparameters if they exist
-    if "hyper_parameters" in checkpoint:
-        saved_config = checkpoint["hyper_parameters"].get("config", {})
-        print("Found saved config in checkpoint")
-        # Override any gated models in the saved config
-        if "model" in saved_config and "gemma_model_name" in saved_config["model"]:
-            if "google/gemma" in saved_config["model"]["gemma_model_name"]:
-                print("Replacing gated Gemma model with accessible alternative")
-                saved_config["model"]["gemma_model_name"] = "microsoft/DialoGPT-medium"
-                saved_config["model"]["use_4bit"] = False  # Disable quantization for compatibility
-        config = saved_config
-    else:
-        print("No saved config found, creating minimal config")
-        # Create minimal config for loading
-        config = {
-            "model": {
-                "gemma_model_name": "microsoft/DialoGPT-medium",  # Use non-gated model
-                "vision_model_name": "openai/clip-vit-large-patch14",
-                "use_4bit": False,  # Disable quantization for loading
-                "projector_hidden_dim": 2048,
-                "lora": {"r": 16, "alpha": 32, "dropout": 0.1}
-            },
-            "special_tokens": {"image_token": "<image>"},
-            "training": {"projector_lr": 1e-3, "lora_lr": 1e-4}
-        }
-
-    try:
-        # First try: Use the checkpoint's config if available
-        model = MultimodalGemmaLightning.load_from_checkpoint(
-            checkpoint_path,
-            config=config,
-            strict=False,
-            map_location="cuda" if torch.cuda.is_available() else "cpu"
-        )
-        print("✅ Loaded with checkpoint config")
-    except Exception as e1:
-        print(f"Failed with checkpoint config: {e1}")
-        try:
-            # Second try: Minimal config with no quantization
-            minimal_config = {
-                "model": {
-                    "gemma_model_name": "microsoft/DialoGPT-small",  # Even smaller model
-                    "vision_model_name": "openai/clip-vit-base-patch32",  # Smaller CLIP
-                    "use_4bit": False,  # No quantization
-                    "projector_hidden_dim": 512,
-                    "lora": {"r": 8, "alpha": 16, "dropout": 0.1, "target_modules": ["q_proj", "v_proj"]}
-                },
-                "special_tokens": {"image_token": "<image>"},
-                "training": {"projector_lr": 1e-3, "lora_lr": 1e-4}
-            }
-            model = MultimodalGemmaLightning.load_from_checkpoint(
-                checkpoint_path,
-                config=minimal_config,
-                strict=False,
-                map_location="cuda" if torch.cuda.is_available() else "cpu"
-            )
-            print("✅ Loaded with minimal config")
-        except Exception as e2:
-            print(f"Failed with minimal config: {e2}")
-            try:
-                # Third try: Direct state dict loading
-                print("Attempting direct state dict loading...")
-                # Create a dummy model just to get the structure
-                dummy_config = {
-                    "model": {
-                        "gemma_model_name": "microsoft/DialoGPT-small",
-                        "vision_model_name": "openai/clip-vit-base-patch32",
-                        "use_4bit": False,
-                        "projector_hidden_dim": 512,
-                    },
-                    "special_tokens": {"image_token": "<image>"},
-                    "training": {"projector_lr": 1e-3, "lora_lr": 1e-4}
-                }
-                model = MultimodalGemmaLightning(dummy_config)
-
-                # Load only compatible weights
-                checkpoint_state = checkpoint['state_dict']
-                model_state = model.state_dict()
-
-                # Filter and load compatible weights
-                compatible_weights = {}
-                for key, value in checkpoint_state.items():
-                    if key in model_state and model_state[key].shape == value.shape:
-                        compatible_weights[key] = value
-
-                model.load_state_dict(compatible_weights, strict=False)
-                print(f"✅ Loaded {len(compatible_weights)} compatible weights")
-
-            except Exception as e3:
-                print(f"All loading methods failed: {e3}")
-                return f"❌ Model loading failed - checkpoint incompatible. Last error: {str(e3)}"
+    # Load configs exactly like local gradio_app.py
+    print("📁 Loading configs...")
+    model_config = load_config(model_config_path)
+    training_config = load_config(training_config_path)
+    data_config = load_config(data_config_path)
+    config = merge_configs([model_config, training_config, data_config])
+
+    print("📁 Loading model from checkpoint...")
+    # Load model exactly like local gradio_app.py
+    model = MultimodalGemmaLightning.load_from_checkpoint(
+        checkpoint_path,
+        config=config,
+        strict=False,
+        map_location="cuda" if torch.cuda.is_available() else "cpu"
+    )
     model.eval()
 
     # Move to appropriate device
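Note: the new loading path calls load_config and merge_configs, which come from the repo's own utilities and are not shown in this diff. A minimal sketch of what such helpers could look like, assuming plain YAML files and a recursive merge where later configs override earlier ones (the actual implementation in src may differ):

# Hypothetical helpers mirroring the calls made in app.py above.
# Assumes plain YAML configs where later files override earlier ones key by key.
import yaml


def load_config(path: str) -> dict:
    """Read one YAML config file into a dict."""
    with open(path, "r") as f:
        return yaml.safe_load(f) or {}


def merge_configs(configs: list) -> dict:
    """Merge config dicts left to right, recursing into nested mappings."""
    def merge_two(base: dict, override: dict) -> dict:
        merged = dict(base)
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = merge_two(merged[key], value)
            else:
                merged[key] = value
        return merged

    result = {}
    for cfg in configs:
        result = merge_two(result, cfg)
    return result

With helpers of this shape, merge_configs([model_config, training_config, data_config]) yields one dict whose model, training, and data sections mirror the three YAML files added below.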
configs/config.yaml ADDED
@@ -0,0 +1,12 @@
+# Main Hydra Configuration
+defaults:
+  - model_config
+  - training_config
+  - data_config
+
+# Override settings
+hydra:
+  run:
+    dir: ./logs/hydra/${now:%Y-%m-%d}/${now:%H-%M-%S}
+  job:
+    name: multimodal_gemma_training
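This main config only composes the three component configs through Hydra's defaults list; the training entrypoint that consumes it is not part of this commit. A sketch of a typical consumer, assuming a hypothetical train.py at the repo root:

# Hypothetical training entrypoint consuming configs/config.yaml via Hydra.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    # cfg holds the composed model/training/data sections,
    # e.g. cfg.model.gemma_model_name, cfg.training.batch_size, cfg.data.max_length.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()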
configs/data_config.yaml ADDED
@@ -0,0 +1,66 @@
+# Data Configuration
+data:
+  # Dataset settings - using more accessible multimodal dataset
+  dataset_name: "liuhaotian/LLaVA-Instruct-150K"
+  cache_dir: "./data/cache"
+  num_workers: 8  # Increased for faster loading
+  pin_memory: true
+  persistent_workers: true
+
+  # Data splits
+  train_split: "train"
+  val_split: "train"  # LLaVA doesn't have a separate val split
+  val_size: 0.02  # Use 2% of train data for validation
+
+  # Text processing - optimized for speed
+  max_length: 256  # Reduced from 512 for faster training
+  truncation: true
+  padding: true
+
+  # Speed optimizations
+  filter_long_conversations: true
+  max_conversation_turns: 6  # Limit to 6 turns (3 human + 3 assistant)
+  use_subset: false  # Set to true for quick testing
+  subset_size: 10000  # Use only 10K samples for testing
+
+  # Image processing
+  image_size: 224
+  image_mean: [0.48145466, 0.4578275, 0.40821073]  # CLIP normalization
+  image_std: [0.26862954, 0.26130258, 0.27577711]
+
+  # Data augmentation (for images)
+  augmentation:
+    enabled: false  # Start without augmentation
+    random_resized_crop: 0.9
+    color_jitter: 0.1
+    horizontal_flip: 0.5
+
+  # Conversation formatting
+  conversation:
+    system_message: ""
+    user_prefix: "Human: "
+    assistant_prefix: "Assistant: "
+    turn_separator: "\n"
+
+  # Data filtering - enhanced for speed
+  filtering:
+    min_length: 10  # Minimum text length
+    max_length: 800  # Reduced from 1000 for faster training
+    filter_empty_images: true
+    filter_corrupt_images: true
+    filter_long_conversations: true
+    max_tokens_per_sample: 256  # Skip samples that would exceed max_length
+    min_image_questions: 1  # Skip samples without image-related questions
+
+  # Preprocessing
+  preprocessing:
+    cache_processed_data: true
+    precompute_image_features: false  # Set to true to cache CLIP features
+
+  # COCO Images
+  coco:
+    base_url: "http://images.cocodataset.org/train2017/"
+    download_timeout: 30
+    retry_attempts: 3
+    fallback_image_size: [224, 224]
+    fallback_image_color: "white"
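The coco block controls how training images are fetched at data-loading time. A sketch of how those settings might be applied, assuming images are requested by COCO file name over HTTP and replaced by a plain fallback image once the retries are exhausted (the fetch_coco_image helper is illustrative, not from the repo):

# Hypothetical image fetcher applying the coco settings from data_config.yaml.
from io import BytesIO

import requests
from PIL import Image

COCO = {
    "base_url": "http://images.cocodataset.org/train2017/",
    "download_timeout": 30,
    "retry_attempts": 3,
    "fallback_image_size": (224, 224),
    "fallback_image_color": "white",
}


def fetch_coco_image(file_name: str) -> Image.Image:
    """Download one COCO train2017 image, retrying, with a blank fallback."""
    url = COCO["base_url"] + file_name
    for _ in range(COCO["retry_attempts"]):
        try:
            resp = requests.get(url, timeout=COCO["download_timeout"])
            resp.raise_for_status()
            return Image.open(BytesIO(resp.content)).convert("RGB")
        except Exception:
            continue
    # All attempts failed: return a plain fallback image so training can continue.
    return Image.new("RGB", COCO["fallback_image_size"], COCO["fallback_image_color"])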
configs/gemma_270m_a100.yaml ADDED
@@ -0,0 +1,147 @@
+# Optimized configuration for Gemma-270M on A100 GPU
+# This configuration maximizes the potential of the smaller 270M model
+
+# Model Configuration
+model:
+  gemma_model_name: "google/gemma-3-270m"  # 270M parameter model
+  vision_model_name: "openai/clip-vit-large-patch14"
+  audio_model_name: "openai/whisper-small"
+
+  enable_audio: false
+  projector_hidden_dim: 1024  # Optimized for 270M
+  audio_hidden_dim: 512
+
+  # LoRA configuration - can be more aggressive with smaller model
+  lora:
+    r: 32  # Higher rank for 270M model
+    alpha: 64  # Higher alpha for better learning
+    dropout: 0.1
+    target_modules:
+      - "q_proj"
+      - "v_proj"
+      - "k_proj"
+      - "o_proj"
+      - "gate_proj"
+      - "up_proj"
+      - "down_proj"
+
+  # Quantization (optional for 270M - could train in full precision)
+  use_4bit: false  # 270M is small enough for full precision
+  bnb_4bit_compute_dtype: "bfloat16"
+  bnb_4bit_quant_type: "nf4"
+  use_nested_quant: false
+
+# Training Configuration - Optimized for A100 + 270M
+training:
+  max_epochs: 5  # More epochs for smaller model
+  batch_size: 16  # Large batch size for 270M on A100
+  accumulate_grad_batches: 2  # Effective batch size = 16 * 2 = 32
+  gradient_clip_val: 1.0
+
+  # Learning rates - can be higher for smaller model
+  lora_lr: 5e-4  # Higher learning rate
+  projector_lr: 2e-3  # Higher learning rate
+  weight_decay: 0.01
+  warmup_ratio: 0.05  # More warmup
+
+  # Validation
+  val_check_interval: 0.25  # Check more frequently
+  limit_val_batches: 50
+
+  # Checkpointing
+  save_top_k: 5
+  monitor: "val/loss"
+  mode: "min"
+
+  # Precision
+  precision: "bf16-mixed"  # A100 optimized
+  strategy: "auto"
+
+  # Early stopping
+  patience: 3
+  min_delta: 0.0005
+
+# Data Configuration
+data:
+  dataset_name: "liuhaotian/LLaVA-Instruct-150K"
+  cache_dir: "./data/cache"
+  num_workers: 8  # More workers for A100
+  pin_memory: true
+  persistent_workers: true
+
+  train_split: "train"
+  val_split: "train"
+  val_size: 0.02
+
+  max_length: 512
+  truncation: true
+  padding: true
+
+  image_size: 224
+  image_mean: [0.48145466, 0.4578275, 0.40821073]
+  image_std: [0.26862954, 0.26130258, 0.27577711]
+
+  # No augmentation for initial training
+  augmentation:
+    enabled: false
+
+  conversation:
+    system_message: ""
+    user_prefix: "Human: "
+    assistant_prefix: "Assistant: "
+    turn_separator: "\n"
+
+  filtering:
+    min_length: 10
+    max_length: 1000
+    filter_empty_images: true
+    filter_corrupt_images: true
+
+  preprocessing:
+    cache_processed_data: true
+    precompute_image_features: false
+
+# Trainer settings
+trainer:
+  accelerator: "gpu"
+  devices: 1
+  num_nodes: 1
+  log_every_n_steps: 25
+  enable_checkpointing: true
+  enable_progress_bar: true
+  enable_model_summary: true
+
+  fast_dev_run: false
+  overfit_batches: 0
+  detect_anomaly: false
+
+  deterministic: false
+  benchmark: true
+
+# Optimization
+optimization:
+  compile_model: true  # Enable for PyTorch 2.0+ speedup
+  use_fused_adamw: true
+
+# Logging
+logging:
+  use_wandb: true
+  wandb_project: "multimodal-gemma-270m"
+  wandb_name: "gemma-270m-llava-a100-optimized"
+  log_model: true
+
+  use_tensorboard: true
+  tb_log_dir: "logs/tensorboard"
+
+# Special tokens
+special_tokens:
+  image_token: "<image>"
+  audio_token: "<audio>"
+  pad_token: "<pad>"
+
+# Tokenizer settings
+tokenizer:
+  padding_side: "right"
+  truncation: true
+  max_length: 512
+  add_special_tokens: true
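The lora block in these configs corresponds to a PEFT LoraConfig. A sketch of that mapping for the A100 settings above; whether MultimodalGemmaLightning builds it exactly this way is an assumption:

# Hypothetical mapping from the YAML lora block to a peft LoraConfig.
from peft import LoraConfig

lora_cfg = {
    "r": 32,
    "alpha": 64,
    "dropout": 0.1,
    "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj",
                       "gate_proj", "up_proj", "down_proj"],
}

peft_config = LoraConfig(
    r=lora_cfg["r"],
    lora_alpha=lora_cfg["alpha"],
    lora_dropout=lora_cfg["dropout"],
    target_modules=lora_cfg["target_modules"],
    bias="none",
    task_type="CAUSAL_LM",
)
# The adapted model would then come from peft.get_peft_model(base_model, peft_config).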
configs/model_config.yaml ADDED
@@ -0,0 +1,40 @@
+# Model Configuration
+model:
+  # Base models
+  gemma_model_name: "google/gemma-3-270m"  # 270M parameter model - fast training on A100
+  vision_model_name: "openai/clip-vit-large-patch14"
+
+  # Model settings - vision-language only
+  projector_hidden_dim: 2048  # Larger projection for better alignment
+
+  # LoRA configuration - optimized for multimodal
+  lora:
+    r: 64  # Higher rank for better multimodal understanding
+    alpha: 128  # Higher alpha for better learning
+    dropout: 0.1  # Slightly higher dropout for regularization
+    target_modules:
+      - "q_proj"
+      - "v_proj"
+      - "k_proj"
+      - "o_proj"
+      - "gate_proj"
+      - "up_proj"
+      - "down_proj"
+
+  # Quantization
+  use_4bit: true
+  bnb_4bit_compute_dtype: "bfloat16"
+  bnb_4bit_quant_type: "nf4"
+  use_nested_quant: false
+
+# Tokenizer settings
+tokenizer:
+  padding_side: "right"
+  truncation: true
+  max_length: 512
+  add_special_tokens: true
+
+# Special tokens
+special_tokens:
+  image_token: "<image>"
+  pad_token: "<pad>"
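Unlike the A100 config, this file enables 4-bit loading. The four quantization keys map directly onto a transformers BitsAndBytesConfig; a sketch of that correspondence (how the model class wires it in is assumed):

# Hypothetical mapping from the quantization keys to a BitsAndBytesConfig.
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # use_4bit: true
    bnb_4bit_compute_dtype=torch.bfloat16,  # bnb_4bit_compute_dtype: "bfloat16"
    bnb_4bit_quant_type="nf4",              # bnb_4bit_quant_type: "nf4"
    bnb_4bit_use_double_quant=False,        # use_nested_quant: false
)
# This would then be passed as quantization_config=bnb_config to
# AutoModelForCausalLM.from_pretrained when loading the gated google/gemma-3-270m backbone.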
configs/multimodal_optimized.yaml ADDED
@@ -0,0 +1,149 @@
+# Optimized configuration for Gemma-270M multimodal training
+# This addresses previous inference quality issues
+
+# Model Configuration - Optimized for Multimodal
+model:
+  gemma_model_name: "google/gemma-3-270m"
+  vision_model_name: "openai/clip-vit-large-patch14"
+  audio_model_name: "openai/whisper-small"
+
+  enable_audio: false
+  projector_hidden_dim: 2048  # Larger for better vision-language alignment
+  audio_hidden_dim: 512
+
+  # LoRA configuration - Higher capacity for multimodal
+  lora:
+    r: 128  # Much higher rank for complex multimodal relationships
+    alpha: 256  # Higher alpha for better adaptation
+    dropout: 0.1  # Regularization for better generalization
+    target_modules:
+      - "q_proj"
+      - "v_proj"
+      - "k_proj"
+      - "o_proj"
+      - "gate_proj"
+      - "up_proj"
+      - "down_proj"
+
+  # Keep 4-bit quantization for memory efficiency
+  use_4bit: true
+  bnb_4bit_compute_dtype: "bfloat16"
+  bnb_4bit_quant_type: "nf4"
+  use_nested_quant: false
+
+# Training Configuration - Optimized for Multimodal Quality
+training:
+  max_epochs: 15  # More epochs for better convergence
+  batch_size: 6  # Slightly smaller for stability
+  accumulate_grad_batches: 8  # Effective batch size = 6 * 8 = 48
+  gradient_clip_val: 1.0
+
+  # Better learning rate balance
+  lora_lr: 1e-3  # Higher for language adaptation
+  projector_lr: 5e-3  # Much higher for vision-language alignment
+  weight_decay: 0.01
+  warmup_ratio: 0.15  # More warmup for stable training
+
+  # Validation
+  val_check_interval: 0.33  # Check validation more frequently
+  limit_val_batches: 50
+
+  # Checkpointing
+  save_top_k: 5
+  monitor: "val/loss"
+  mode: "min"
+
+  # Precision and optimization
+  precision: "bf16-mixed"
+  strategy: "auto"
+
+  # Early stopping
+  patience: 5
+  min_delta: 0.0001
+
+# Data Configuration - Focus on quality
+data:
+  dataset_name: "liuhaotian/LLaVA-Instruct-150K"
+  cache_dir: "./data/cache"
+  num_workers: 6  # More workers for better data loading
+  pin_memory: true
+  persistent_workers: true
+
+  train_split: "train"
+  val_split: "train"
+  val_size: 0.05  # Larger validation set
+
+  max_length: 512
+  truncation: true
+  padding: true
+
+  image_size: 224
+  image_mean: [0.48145466, 0.4578275, 0.40821073]
+  image_std: [0.26862954, 0.26130258, 0.27577711]
+
+  augmentation:
+    enabled: true  # Enable augmentation for better generalization
+    random_resized_crop: 0.9
+    color_jitter: 0.2
+    horizontal_flip: 0.3
+
+  conversation:
+    system_message: ""
+    user_prefix: "Human: "
+    assistant_prefix: "Assistant: "
+    turn_separator: "\n"
+
+  filtering:
+    min_length: 20  # Filter very short conversations
+    max_length: 800  # Allow longer conversations
+    filter_empty_images: true
+    filter_corrupt_images: true
+
+  preprocessing:
+    cache_processed_data: true
+    precompute_image_features: false
+
+# Trainer settings
+trainer:
+  accelerator: "gpu"
+  devices: 1
+  num_nodes: 1
+  log_every_n_steps: 10
+  enable_checkpointing: true
+  enable_progress_bar: true
+  enable_model_summary: true
+
+  fast_dev_run: false
+  overfit_batches: 0
+  detect_anomaly: false
+
+  deterministic: false
+  benchmark: true
+
+# Optimization
+optimization:
+  compile_model: false
+  use_fused_adamw: true
+
+# Logging - Enable for monitoring
+logging:
+  use_wandb: true
+  wandb_project: "gemma-270m-multimodal-optimized"
+  wandb_name: "gemma-270m-llava-quality-training"
+  log_model: true
+
+  use_tensorboard: true
+  tb_log_dir: "logs/tensorboard"
+
+# Special tokens
+special_tokens:
+  image_token: "<image>"
+  audio_token: "<audio>"
+  pad_token: "<pad>"
+
+# Tokenizer settings
+tokenizer:
+  padding_side: "right"
+  truncation: true
+  max_length: 512
+  add_special_tokens: true
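This config keeps two learning rates: lora_lr for the adapted language layers and projector_lr for the vision-language projector. In Lightning that split is typically expressed as optimizer parameter groups; a sketch under the assumption that LoRA parameters carry "lora_" in their names and the projector lives in a "projector" submodule:

# Hypothetical configure_optimizers body using the two learning rates from this config.
import torch


def build_optimizer(model, lora_lr=1e-3, projector_lr=5e-3, weight_decay=0.01):
    """Split trainable parameters into LoRA and projector groups with separate LRs."""
    lora_params, projector_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "projector" in name:
            projector_params.append(param)
        elif "lora_" in name:
            lora_params.append(param)
    return torch.optim.AdamW(
        [
            {"params": lora_params, "lr": lora_lr},
            {"params": projector_params, "lr": projector_lr},
        ],
        weight_decay=weight_decay,
    )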
configs/training_config.yaml ADDED
@@ -0,0 +1,65 @@
+# Training Configuration
+training:
+  # Training hyperparameters - optimized for speed
+  max_epochs: 8  # Reduced epochs (shorter sequences = faster convergence)
+  batch_size: 16  # Increased batch size (shorter sequences = more GPU memory)
+  accumulate_grad_batches: 2  # Effective batch size = 16 * 2 = 32
+  gradient_clip_val: 1.0
+
+  # Learning rates - better balance for multimodal
+  lora_lr: 5e-4  # Higher for better adaptation
+  projector_lr: 2e-3  # Higher for vision-language alignment
+  weight_decay: 0.01
+  warmup_ratio: 0.1  # More warmup for stability
+
+  # Validation
+  val_check_interval: 0.5  # Check validation every half epoch
+  limit_val_batches: 100  # Limit validation batches for speed
+
+  # Checkpointing
+  save_top_k: 3
+  monitor: "val/loss"
+  mode: "min"
+
+  # Precision and optimization
+  precision: "bf16-mixed"  # Use mixed precision for A100
+  strategy: "auto"  # Let Lightning choose the best strategy
+
+  # Early stopping
+  patience: 2
+  min_delta: 0.001
+
+# Lightning Trainer settings
+trainer:
+  accelerator: "gpu"
+  devices: 1  # Single GPU training
+  num_nodes: 1
+  log_every_n_steps: 10
+  enable_checkpointing: true
+  enable_progress_bar: true
+  enable_model_summary: true
+
+  # Debugging and profiling
+  fast_dev_run: false
+  overfit_batches: 0
+  detect_anomaly: false
+
+  # Reproducibility
+  deterministic: false  # Set to true for reproducible results (slower)
+  benchmark: true  # Optimize for consistent input sizes
+
+# Optimization settings
+optimization:
+  compile_model: false  # Set to true for PyTorch 2.0+ compilation
+  use_fused_adamw: true  # Use fused AdamW for better performance
+
+# Logging and monitoring
+logging:
+  use_wandb: false  # Disable for now - needs API key
+  wandb_project: "multimodal-gemma"
+  wandb_name: "gemma-270m-llava-training"
+  log_model: false
+
+  # TensorBoard
+  use_tensorboard: true  # Use TensorBoard instead
+  tb_log_dir: "logs/tensorboard"
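The training and trainer blocks map onto pytorch_lightning.Trainer arguments plus the usual checkpoint and early-stopping callbacks. A sketch of that mapping for this file's values (how the repo actually constructs its Trainer is an assumption):

# Hypothetical Trainer construction from training_config.yaml values.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

checkpoint_cb = ModelCheckpoint(monitor="val/loss", mode="min", save_top_k=3)
early_stop_cb = EarlyStopping(monitor="val/loss", mode="min", patience=2, min_delta=0.001)

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    num_nodes=1,
    max_epochs=8,
    precision="bf16-mixed",
    strategy="auto",
    accumulate_grad_batches=2,
    gradient_clip_val=1.0,
    val_check_interval=0.5,
    limit_val_batches=100,
    log_every_n_steps=10,
    enable_checkpointing=True,
    enable_progress_bar=True,
    enable_model_summary=True,
    deterministic=False,
    benchmark=True,
    callbacks=[checkpoint_cb, early_stop_cb],
)
# trainer.fit(model, datamodule=...) would then run training with these settings.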