sagar007 commited on
Commit
34b253d
·
verified ·
1 Parent(s): 085f1c9

Upload folder using huggingface_hub

Browse files
src/data/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .dataset import LLaVADataset, MultimodalCollator
2
+ from .datamodule import LLaVADataModule
3
+ from .processors import ImageProcessor, TextProcessor
4
+
5
+ __all__ = [
6
+ "LLaVADataset",
7
+ "MultimodalCollator",
8
+ "LLaVADataModule",
9
+ "ImageProcessor",
10
+ "TextProcessor"
11
+ ]
src/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (448 Bytes). View file
 
src/data/__pycache__/datamodule.cpython-311.pyc ADDED
Binary file (8.05 kB). View file
 
src/data/__pycache__/dataset.cpython-311.pyc ADDED
Binary file (25.2 kB). View file
 
src/data/__pycache__/processors.cpython-311.pyc ADDED
Binary file (8.77 kB). View file
 
src/data/datamodule.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch Lightning DataModule for LLaVA dataset
3
+ """
4
+ import lightning as L
5
+ import torch
6
+ from torch.utils.data import DataLoader, random_split
7
+ from typing import Optional, Dict, Any
8
+ import logging
9
+
10
+ from .dataset import LLaVADataset, MultimodalCollator
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class LLaVADataModule(L.LightningDataModule):
16
+ """Lightning DataModule for LLaVA dataset"""
17
+
18
+ def __init__(
19
+ self,
20
+ tokenizer,
21
+ vision_processor,
22
+ config: Dict[str, Any]
23
+ ):
24
+ super().__init__()
25
+ self.tokenizer = tokenizer
26
+ self.vision_processor = vision_processor
27
+ self.config = config
28
+
29
+ # Data configuration
30
+ data_config = config["data"]
31
+ self.batch_size = config["training"]["batch_size"]
32
+ self.num_workers = data_config.get("num_workers", 4)
33
+ self.pin_memory = data_config.get("pin_memory", True)
34
+ self.persistent_workers = data_config.get("persistent_workers", True)
35
+
36
+ # Dataset splits
37
+ self.train_split = data_config.get("train_split", "train")
38
+ self.val_split = data_config.get("val_split", "train") # LLaVA doesn't have separate val
39
+ self.val_size = data_config.get("val_size", 0.02)
40
+
41
+ # Initialize datasets to None
42
+ self.train_dataset = None
43
+ self.val_dataset = None
44
+
45
+ # Create collator
46
+ self.collator = MultimodalCollator(
47
+ tokenizer=self.tokenizer,
48
+ vision_processor=self.vision_processor,
49
+ config=self.config
50
+ )
51
+
52
+ logger.info("LLaVADataModule initialized")
53
+
54
+ def prepare_data(self) -> None:
55
+ """Download and prepare data (called only on rank 0)"""
56
+ # This will download the dataset if not already cached
57
+ try:
58
+ LLaVADataset(
59
+ config=self.config,
60
+ split=self.train_split
61
+ )
62
+ logger.info("Dataset preparation completed")
63
+ except Exception as e:
64
+ logger.error(f"Failed to prepare dataset: {e}")
65
+ raise
66
+
67
+ def setup(self, stage: Optional[str] = None) -> None:
68
+ """Setup datasets for training/validation/testing"""
69
+
70
+ if stage == "fit" or stage is None:
71
+ # Load full training dataset
72
+ full_dataset = LLaVADataset(
73
+ config=self.config,
74
+ split=self.train_split
75
+ )
76
+
77
+ # Split into train and validation
78
+ total_size = len(full_dataset)
79
+ val_size = int(total_size * self.val_size)
80
+ train_size = total_size - val_size
81
+
82
+ self.train_dataset, self.val_dataset = random_split(
83
+ full_dataset,
84
+ [train_size, val_size],
85
+ generator=torch.Generator().manual_seed(42) # For reproducibility
86
+ )
87
+
88
+ logger.info(f"Dataset split: {train_size} train, {val_size} validation")
89
+
90
+ if stage == "test":
91
+ # For testing, we'll use a small subset of the training data
92
+ self.test_dataset = LLaVADataset(
93
+ config=self.config,
94
+ split=self.train_split
95
+ )
96
+
97
+ if stage == "predict":
98
+ # For prediction, setup can be done dynamically
99
+ pass
100
+
101
+ def train_dataloader(self) -> DataLoader:
102
+ """Create training dataloader"""
103
+ if self.train_dataset is None:
104
+ raise RuntimeError("Train dataset not initialized. Call setup() first.")
105
+
106
+ return DataLoader(
107
+ self.train_dataset,
108
+ batch_size=self.batch_size,
109
+ shuffle=True,
110
+ num_workers=self.num_workers,
111
+ pin_memory=self.pin_memory,
112
+ persistent_workers=self.persistent_workers and self.num_workers > 0,
113
+ collate_fn=self.collator,
114
+ drop_last=True # Drop incomplete batches for consistent training
115
+ )
116
+
117
+ def val_dataloader(self) -> DataLoader:
118
+ """Create validation dataloader"""
119
+ if self.val_dataset is None:
120
+ raise RuntimeError("Validation dataset not initialized. Call setup() first.")
121
+
122
+ return DataLoader(
123
+ self.val_dataset,
124
+ batch_size=self.batch_size,
125
+ shuffle=False,
126
+ num_workers=self.num_workers,
127
+ pin_memory=self.pin_memory,
128
+ persistent_workers=self.persistent_workers and self.num_workers > 0,
129
+ collate_fn=self.collator,
130
+ drop_last=False
131
+ )
132
+
133
+ def test_dataloader(self) -> DataLoader:
134
+ """Create test dataloader"""
135
+ if not hasattr(self, 'test_dataset') or self.test_dataset is None:
136
+ raise RuntimeError("Test dataset not initialized. Call setup() first.")
137
+
138
+ return DataLoader(
139
+ self.test_dataset,
140
+ batch_size=self.batch_size,
141
+ shuffle=False,
142
+ num_workers=self.num_workers,
143
+ pin_memory=self.pin_memory,
144
+ collate_fn=self.collator,
145
+ drop_last=False
146
+ )
147
+
148
+ def predict_dataloader(self) -> DataLoader:
149
+ """Create prediction dataloader"""
150
+ # This can be implemented based on specific prediction needs
151
+ return self.val_dataloader()
152
+
153
+ def teardown(self, stage: Optional[str] = None) -> None:
154
+ """Clean up after training/testing"""
155
+ # Log dataset statistics if available
156
+ if hasattr(self, 'train_dataset') and self.train_dataset is not None:
157
+ if hasattr(self.train_dataset.dataset, 'get_stats'):
158
+ stats = self.train_dataset.dataset.get_stats()
159
+ logger.info(f"Training dataset stats: {stats}")
160
+
161
+ if hasattr(self, 'val_dataset') and self.val_dataset is not None:
162
+ if hasattr(self.val_dataset.dataset, 'get_stats'):
163
+ stats = self.val_dataset.dataset.get_stats()
164
+ logger.info(f"Validation dataset stats: {stats}")
165
+
166
+ def get_dataset_info(self) -> Dict[str, Any]:
167
+ """Get information about the loaded datasets"""
168
+ info = {}
169
+
170
+ if self.train_dataset is not None:
171
+ info["train_size"] = len(self.train_dataset)
172
+
173
+ if self.val_dataset is not None:
174
+ info["val_size"] = len(self.val_dataset)
175
+
176
+ info["batch_size"] = self.batch_size
177
+ info["num_workers"] = self.num_workers
178
+
179
+ return info
src/data/dataset.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset implementation for LLaVA multimodal training
3
+ """
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from datasets import load_dataset
7
+ import requests
8
+ from PIL import Image
9
+ import io
10
+ from typing import Dict, Any, List, Optional, Union
11
+ import logging
12
+ import time
13
+ from pathlib import Path
14
+
15
+ from .processors import ImageProcessor, TextProcessor
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class LLaVADataset(Dataset):
21
+ """LLaVA dataset for multimodal training"""
22
+
23
+ def __init__(
24
+ self,
25
+ config: Dict[str, Any],
26
+ split: str = "train",
27
+ transform: Optional[Any] = None
28
+ ):
29
+ self.config = config
30
+ self.split = split
31
+ self.transform = transform
32
+
33
+ # Initialize processors
34
+ self.image_processor = ImageProcessor(config)
35
+ self.text_processor = TextProcessor(config)
36
+
37
+ # Dataset configuration
38
+ data_config = config["data"]
39
+ self.cache_dir = data_config.get("cache_dir", "./data/cache")
40
+ self.image_size = data_config["image_size"]
41
+
42
+ # COCO configuration
43
+ coco_config = config.get("coco", {})
44
+ self.coco_base_url = coco_config.get("base_url", "http://images.cocodataset.org/train2017/")
45
+ self.download_timeout = coco_config.get("download_timeout", 30)
46
+ self.retry_attempts = coco_config.get("retry_attempts", 3)
47
+ self.fallback_size = tuple(coco_config.get("fallback_image_size", [224, 224]))
48
+ self.fallback_color = coco_config.get("fallback_image_color", "white")
49
+
50
+ # Load dataset
51
+ self._load_dataset()
52
+
53
+ # Apply filtering optimizations
54
+ if config["data"].get("filter_long_conversations", True):
55
+ self._filter_dataset()
56
+
57
+ # Statistics
58
+ self.successful_images = 0
59
+ self.failed_images = 0
60
+
61
+ logger.info(f"Initialized LLaVADataset with {len(self.dataset)} samples for split '{split}'")
62
+
63
+ def _load_dataset(self):
64
+ """Load the LLaVA dataset from HuggingFace"""
65
+ dataset_name = self.config["data"]["dataset_name"]
66
+
67
+ # Create cache directory
68
+ Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
69
+
70
+ # Try different loading approaches
71
+ loading_strategies = [
72
+ # Strategy 1: Simple loading without problematic parameters
73
+ lambda: load_dataset(
74
+ dataset_name,
75
+ split=self.split,
76
+ cache_dir=self.cache_dir
77
+ ),
78
+
79
+ # Strategy 2: With streaming disabled
80
+ lambda: load_dataset(
81
+ dataset_name,
82
+ split=self.split,
83
+ cache_dir=self.cache_dir,
84
+ streaming=False
85
+ ),
86
+
87
+ # Strategy 3: Different data format approach
88
+ lambda: self._load_alternative_format(dataset_name),
89
+
90
+ # Strategy 4: Load from local files if available
91
+ lambda: self._load_local_dataset(dataset_name)
92
+ ]
93
+
94
+ for i, strategy in enumerate(loading_strategies):
95
+ try:
96
+ logger.info(f"Trying dataset loading strategy {i+1}...")
97
+ self.dataset = strategy()
98
+
99
+ # Validate dataset
100
+ if len(self.dataset) == 0:
101
+ raise ValueError("Dataset is empty")
102
+
103
+ logger.info(f"Successfully loaded {len(self.dataset)} examples from {dataset_name}")
104
+ return
105
+
106
+ except Exception as e:
107
+ logger.warning(f"Strategy {i+1} failed: {e}")
108
+ # Continue to next strategy
109
+
110
+ # If all strategies fail, create a larger dummy dataset for development
111
+ logger.warning("All loading strategies failed, creating larger dummy dataset...")
112
+ self.dataset = self._create_development_dataset()
113
+
114
+ def _load_alternative_format(self, dataset_name):
115
+ """Try alternative loading format for LLaVA dataset"""
116
+ try:
117
+ # Try loading with explicit JSON format
118
+ from datasets import load_dataset, DownloadConfig
119
+
120
+ download_config = DownloadConfig(
121
+ resume_download=True,
122
+ force_download=False,
123
+ use_etag=False
124
+ )
125
+
126
+ return load_dataset(
127
+ "json",
128
+ data_files={
129
+ "train": "hf://datasets/liuhaotian/LLaVA-Instruct-150K/llava_instruct_150k.json"
130
+ },
131
+ split=self.split,
132
+ cache_dir=self.cache_dir,
133
+ download_config=download_config
134
+ )
135
+ except Exception as e:
136
+ logger.warning(f"Alternative format loading failed: {e}")
137
+ raise
138
+
139
+ def _load_local_dataset(self, dataset_name):
140
+ """Try to load dataset from local files or alternative sources"""
141
+ try:
142
+ # Try loading with minimal parameters
143
+ return load_dataset(
144
+ dataset_name,
145
+ split=self.split,
146
+ cache_dir=self.cache_dir
147
+ )
148
+ except Exception:
149
+ # If local loading fails, create dummy data
150
+ logger.warning("Local loading failed, using dummy dataset")
151
+ return self._create_dummy_dataset()
152
+
153
+ def _create_dummy_dataset(self):
154
+ """Create a small dummy dataset for testing"""
155
+ from datasets import Dataset
156
+
157
+ dummy_data = []
158
+ for i in range(100): # Small dataset for testing
159
+ # Use realistic COCO-style filenames that will trigger fallback
160
+ coco_filename = f"{str(i).zfill(12)}.jpg"
161
+ dummy_data.append({
162
+ "id": str(i),
163
+ "image": coco_filename,
164
+ "conversations": [
165
+ {
166
+ "from": "human",
167
+ "value": f"What do you see in image {i}?"
168
+ },
169
+ {
170
+ "from": "gpt",
171
+ "value": f"I can see an image numbered {i}."
172
+ }
173
+ ]
174
+ })
175
+
176
+ return Dataset.from_list(dummy_data)
177
+
178
+ def _create_development_dataset(self):
179
+ """Create a larger dummy dataset for development/testing"""
180
+ from datasets import Dataset
181
+ import random
182
+
183
+ # Create more realistic sample data for development
184
+ dummy_data = []
185
+
186
+ # Common visual questions and responses
187
+ questions = [
188
+ "What do you see in this image?",
189
+ "Describe the main objects in the picture.",
190
+ "What is the person doing?",
191
+ "What colors are prominent in this image?",
192
+ "Can you identify any animals in the picture?",
193
+ "What's the setting or location of this image?",
194
+ "Are there any vehicles visible?",
195
+ "What's the weather like in the image?",
196
+ "How many people are in the picture?",
197
+ "What objects are on the table?",
198
+ ]
199
+
200
+ responses = [
201
+ "I can see a person standing in a park with trees in the background.",
202
+ "The image shows a cat sitting on a windowsill, looking outside.",
203
+ "There's a red car parked on a street with buildings nearby.",
204
+ "I notice several people walking on a busy sidewalk.",
205
+ "The picture contains a bowl of fruit on a wooden table.",
206
+ "I can see a dog playing in a grassy field.",
207
+ "The image shows a bicycle leaning against a wall.",
208
+ "There's a group of children playing in a playground.",
209
+ "I can see mountains in the distance with a clear blue sky.",
210
+ "The picture shows a kitchen with modern appliances.",
211
+ ]
212
+
213
+ # Generate realistic sample size for development
214
+ num_samples = self.config["data"].get("subset_size", 10000) if self.config["data"].get("use_subset", False) else 50000
215
+
216
+ for i in range(num_samples):
217
+ # Use realistic COCO-style filenames
218
+ coco_filename = f"{str(i % 1000).zfill(12)}.jpg"
219
+ question = random.choice(questions)
220
+ response = random.choice(responses)
221
+
222
+ dummy_data.append({
223
+ "id": str(i),
224
+ "image": coco_filename,
225
+ "conversations": [
226
+ {
227
+ "from": "human",
228
+ "value": question
229
+ },
230
+ {
231
+ "from": "gpt",
232
+ "value": response
233
+ }
234
+ ]
235
+ })
236
+
237
+ logger.info(f"Created development dataset with {len(dummy_data)} samples")
238
+ return Dataset.from_list(dummy_data)
239
+
240
+ def _filter_dataset(self):
241
+ """Filter dataset for faster training"""
242
+ logger.info("Applying speed optimization filters...")
243
+
244
+ filtering_config = self.config["data"]["filtering"]
245
+ data_config = self.config["data"]
246
+
247
+ original_size = len(self.dataset)
248
+ filtered_indices = []
249
+
250
+ # Use subset for testing if enabled
251
+ if data_config.get("use_subset", False):
252
+ subset_size = data_config.get("subset_size", 10000)
253
+ indices = list(range(min(subset_size, original_size)))
254
+ logger.info(f"Using subset of {len(indices)} samples for testing")
255
+ else:
256
+ indices = list(range(original_size))
257
+
258
+ max_turns = data_config.get("max_conversation_turns", 6)
259
+ max_tokens = filtering_config.get("max_tokens_per_sample", 256)
260
+ max_length = filtering_config.get("max_length", 800)
261
+
262
+ for idx in indices:
263
+ try:
264
+ item = self.dataset[idx]
265
+ conversations = item.get("conversations", [])
266
+
267
+ # Filter by conversation length
268
+ if len(conversations) > max_turns:
269
+ continue
270
+
271
+ # Estimate token count (rough approximation: 1 token ≈ 4 chars)
272
+ total_text = ""
273
+ for conv in conversations:
274
+ total_text += conv.get("value", "")
275
+
276
+ estimated_tokens = len(total_text) // 4
277
+ if estimated_tokens > max_tokens:
278
+ continue
279
+
280
+ # Check if it's image-related (has visual keywords)
281
+ has_visual_content = any(
282
+ keyword in total_text.lower()
283
+ for keyword in ["see", "image", "picture", "photo", "visual", "look", "show", "appear", "visible"]
284
+ )
285
+
286
+ if filtering_config.get("min_image_questions", 1) > 0 and not has_visual_content:
287
+ continue
288
+
289
+ # Check final text length
290
+ if len(total_text) > max_length:
291
+ continue
292
+
293
+ filtered_indices.append(idx)
294
+
295
+ except Exception as e:
296
+ logger.debug(f"Error filtering item {idx}: {e}")
297
+ continue
298
+
299
+ # Apply filtering
300
+ if filtered_indices:
301
+ self.dataset = self.dataset.select(filtered_indices)
302
+
303
+ filtered_size = len(self.dataset)
304
+ reduction_pct = (1 - filtered_size / original_size) * 100
305
+
306
+ logger.info(f"Dataset filtered: {original_size:,} → {filtered_size:,} samples")
307
+ logger.info(f"Reduction: {reduction_pct:.1f}% (faster training!)")
308
+
309
+ return self.dataset
310
+
311
+ def __len__(self) -> int:
312
+ return len(self.dataset)
313
+
314
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
315
+ """Get a single sample from the dataset with improved error handling"""
316
+ try:
317
+ item = self.dataset[idx]
318
+
319
+ # Load and process image
320
+ image = self._load_image(item.get("image", ""))
321
+
322
+ # Process conversation text with robust handling
323
+ conversations = item.get("conversations", [])
324
+ if not conversations or not isinstance(conversations, list):
325
+ # Fallback if no valid conversations
326
+ conversations = [
327
+ {"from": "human", "value": "What do you see in this image?"},
328
+ {"from": "gpt", "value": "I can see an image that contains various visual elements."}
329
+ ]
330
+
331
+ formatted_text = self.text_processor.format_conversation(conversations)
332
+
333
+ # Add image token if image is present
334
+ formatted_text = self.text_processor.add_image_token(formatted_text, image is not None)
335
+
336
+ # More lenient validation - only reject if truly problematic
337
+ if not self.text_processor.validate_text(formatted_text):
338
+ # Create a better fallback based on original conversations
339
+ try:
340
+ # Try to extract any usable content
341
+ fallback_content = "What do you see in this image?"
342
+ if conversations and len(conversations) > 0:
343
+ first_conv = conversations[0]
344
+ if isinstance(first_conv, dict) and "value" in first_conv:
345
+ user_text = str(first_conv["value"]).strip()
346
+ if user_text and len(user_text) > 5:
347
+ fallback_content = user_text
348
+
349
+ formatted_text = f"<image>\nHuman: {fallback_content}\nAssistant: I can see an image."
350
+ except Exception:
351
+ formatted_text = "<image>\nHuman: What do you see?\nAssistant: I see an image."
352
+
353
+ return {
354
+ "image": image,
355
+ "text": formatted_text,
356
+ "conversations": conversations,
357
+ "id": item.get("id", f"sample_{idx}"),
358
+ "image_filename": item.get("image", ""),
359
+ "has_image": image is not None
360
+ }
361
+
362
+ except Exception as e:
363
+ logger.debug(f"Error processing item {idx}: {e}")
364
+ # Return a fallback sample (reduce logging level to debug)
365
+ return self._get_fallback_sample(idx)
366
+
367
+ def _load_image(self, image_filename: str) -> Optional[Image.Image]:
368
+ """Load image from COCO dataset with retry logic"""
369
+ if not image_filename or not image_filename.strip():
370
+ return None
371
+
372
+ # Check if it's a dummy image (contains "dummy_")
373
+ if "dummy_" in image_filename:
374
+ logger.debug(f"Using placeholder image for {image_filename}")
375
+ return self._create_fallback_image()
376
+
377
+ # For actual dummy filenames from our generated dataset (short numbers), use placeholder
378
+ filename_without_ext = image_filename.replace('.jpg', '').replace('.png', '')
379
+ if image_filename and filename_without_ext.isdigit() and len(filename_without_ext) <= 6:
380
+ logger.debug(f"Using placeholder image for dummy filename: {image_filename}")
381
+ return self._create_fallback_image()
382
+
383
+ # Check cache first
384
+ cache_path = Path(self.cache_dir) / "images" / image_filename
385
+ if cache_path.exists():
386
+ try:
387
+ image = Image.open(cache_path).convert('RGB')
388
+ self.successful_images += 1
389
+ return image
390
+ except Exception:
391
+ cache_path.unlink(missing_ok=True) # Remove corrupted cache
392
+
393
+ image_url = f"{self.coco_base_url}{image_filename}"
394
+
395
+ for attempt in range(self.retry_attempts):
396
+ try:
397
+ response = requests.get(
398
+ image_url,
399
+ timeout=self.download_timeout,
400
+ headers={'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
401
+ )
402
+ response.raise_for_status()
403
+
404
+ # Load and validate image
405
+ image = Image.open(io.BytesIO(response.content)).convert('RGB')
406
+
407
+ # Basic validation
408
+ if image.size[0] < 10 or image.size[1] < 10:
409
+ raise ValueError("Image too small")
410
+
411
+ # Cache the image
412
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
413
+ image.save(cache_path, "JPEG", quality=85)
414
+ logger.debug(f"Cached image: {cache_path}")
415
+
416
+ self.successful_images += 1
417
+ return image
418
+
419
+ except Exception as e:
420
+ if attempt == self.retry_attempts - 1:
421
+ logger.debug(f"Failed to load image {image_filename} after {self.retry_attempts} attempts: {e}")
422
+ self.failed_images += 1
423
+ return self._create_fallback_image()
424
+ else:
425
+ time.sleep(0.5) # Brief pause before retry
426
+
427
+ return self._create_fallback_image()
428
+
429
+ def _create_fallback_image(self) -> Image.Image:
430
+ """Create a fallback image when loading fails"""
431
+ return Image.new('RGB', self.fallback_size, color=self.fallback_color)
432
+
433
+ def _get_fallback_sample(self, idx: int) -> Dict[str, Any]:
434
+ """Get a fallback sample when processing fails"""
435
+ fallback_image = self._create_fallback_image()
436
+ fallback_text = "Human: What do you see in this image?\nAssistant: I can see a simple image."
437
+
438
+ return {
439
+ "image": fallback_image,
440
+ "text": fallback_text,
441
+ "conversations": [
442
+ {"from": "human", "value": "What do you see in this image?"},
443
+ {"from": "gpt", "value": "I can see a simple image."}
444
+ ],
445
+ "id": f"fallback_{idx}",
446
+ "image_filename": "",
447
+ "has_image": True
448
+ }
449
+
450
+ def get_stats(self) -> Dict[str, int]:
451
+ """Get dataset statistics"""
452
+ return {
453
+ "total_samples": len(self),
454
+ "successful_images": self.successful_images,
455
+ "failed_images": self.failed_images,
456
+ "success_rate": self.successful_images / (self.successful_images + self.failed_images) * 100
457
+ if (self.successful_images + self.failed_images) > 0 else 0
458
+ }
459
+
460
+
461
+ class MultimodalCollator:
462
+ """Custom collator for multimodal data batching"""
463
+
464
+ def __init__(
465
+ self,
466
+ tokenizer,
467
+ vision_processor,
468
+ config: Dict[str, Any],
469
+ max_length: Optional[int] = None
470
+ ):
471
+ self.tokenizer = tokenizer
472
+ self.vision_processor = vision_processor
473
+ self.config = config
474
+ self.max_length = max_length or config["data"]["max_length"]
475
+
476
+ # Image token for processing
477
+ self.image_token = config.get("special_tokens", {}).get("image_token", "<image>")
478
+
479
+ def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
480
+ """Collate a batch of samples"""
481
+
482
+ images = []
483
+ texts = []
484
+ has_images = []
485
+
486
+ for sample in batch:
487
+ # Collect images
488
+ if sample["image"] is not None:
489
+ images.append(sample["image"])
490
+ has_images.append(True)
491
+ else:
492
+ # Create placeholder image for samples without images
493
+ placeholder = Image.new('RGB', (224, 224), color='white')
494
+ images.append(placeholder)
495
+ has_images.append(False)
496
+
497
+ # Collect texts
498
+ texts.append(sample["text"])
499
+
500
+ # Process images using vision processor
501
+ try:
502
+ vision_inputs = self.vision_processor(
503
+ images=images,
504
+ return_tensors="pt"
505
+ )
506
+ pixel_values = vision_inputs["pixel_values"]
507
+ except Exception as e:
508
+ logger.error(f"Error processing images: {e}")
509
+ # Create dummy pixel values
510
+ pixel_values = torch.zeros(len(batch), 3, 224, 224)
511
+
512
+ # Tokenize texts
513
+ try:
514
+ text_inputs = self.tokenizer(
515
+ texts,
516
+ padding=True,
517
+ truncation=True,
518
+ max_length=self.max_length,
519
+ return_tensors="pt"
520
+ )
521
+ except Exception as e:
522
+ logger.error(f"Error tokenizing texts: {e}")
523
+ # Create dummy inputs
524
+ text_inputs = {
525
+ "input_ids": torch.zeros(len(batch), self.max_length, dtype=torch.long),
526
+ "attention_mask": torch.ones(len(batch), self.max_length, dtype=torch.long)
527
+ }
528
+
529
+ # Create labels (same as input_ids for causal LM)
530
+ labels = text_inputs["input_ids"].clone()
531
+
532
+ # Mask padding tokens in labels (-100 is ignored by loss function)
533
+ labels[labels == self.tokenizer.pad_token_id] = -100
534
+
535
+ batch_dict = {
536
+ "input_ids": text_inputs["input_ids"],
537
+ "attention_mask": text_inputs["attention_mask"],
538
+ "labels": labels,
539
+ "images": pixel_values,
540
+ "has_images": torch.tensor(has_images, dtype=torch.bool)
541
+ }
542
+
543
+ return batch_dict
src/data/processors.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data processors for images and text
3
+ """
4
+ import torch
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
+ from typing import List, Dict, Any, Optional
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class ImageProcessor:
14
+ """Image preprocessing for CLIP vision encoder"""
15
+
16
+ def __init__(self, config: Dict[str, Any]):
17
+ self.config = config
18
+ self.image_size = config["data"]["image_size"]
19
+
20
+ # CLIP normalization values
21
+ self.mean = config["data"]["image_mean"]
22
+ self.std = config["data"]["image_std"]
23
+
24
+ # Setup transforms
25
+ self.transform = self._setup_transforms()
26
+
27
+ def _setup_transforms(self):
28
+ """Setup image transformations"""
29
+ transform_list = [
30
+ transforms.Resize((self.image_size, self.image_size)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize(mean=self.mean, std=self.std)
33
+ ]
34
+
35
+ # Add augmentations if enabled
36
+ if self.config["data"]["augmentation"]["enabled"]:
37
+ aug_transforms = []
38
+
39
+ # Random resized crop
40
+ if self.config["data"]["augmentation"].get("random_resized_crop"):
41
+ scale = self.config["data"]["augmentation"]["random_resized_crop"]
42
+ aug_transforms.append(
43
+ transforms.RandomResizedCrop(
44
+ self.image_size,
45
+ scale=(scale, 1.0)
46
+ )
47
+ )
48
+
49
+ # Color jitter
50
+ if self.config["data"]["augmentation"].get("color_jitter"):
51
+ brightness = self.config["data"]["augmentation"]["color_jitter"]
52
+ aug_transforms.append(
53
+ transforms.ColorJitter(brightness=brightness)
54
+ )
55
+
56
+ # Horizontal flip
57
+ if self.config["data"]["augmentation"].get("horizontal_flip"):
58
+ prob = self.config["data"]["augmentation"]["horizontal_flip"]
59
+ aug_transforms.append(
60
+ transforms.RandomHorizontalFlip(p=prob)
61
+ )
62
+
63
+ # Insert augmentations before normalization
64
+ transform_list = (
65
+ transform_list[:-2] + # Resize, ToTensor
66
+ aug_transforms +
67
+ transform_list[-2:] # Normalize
68
+ )
69
+
70
+ return transforms.Compose(transform_list)
71
+
72
+ def __call__(self, image: Image.Image) -> torch.Tensor:
73
+ """Process a single image"""
74
+ if not isinstance(image, Image.Image):
75
+ raise ValueError(f"Expected PIL Image, got {type(image)}")
76
+
77
+ return self.transform(image)
78
+
79
+ def process_batch(self, images: List[Image.Image]) -> torch.Tensor:
80
+ """Process a batch of images"""
81
+ processed = []
82
+ for img in images:
83
+ processed.append(self(img))
84
+ return torch.stack(processed)
85
+
86
+
87
+ class TextProcessor:
88
+ """Text preprocessing for conversations"""
89
+
90
+ def __init__(self, config: Dict[str, Any]):
91
+ self.config = config
92
+ self.max_length = config["data"]["max_length"]
93
+
94
+ # Conversation formatting
95
+ conv_config = config["data"]["conversation"]
96
+ self.system_message = conv_config.get("system_message", "")
97
+ self.user_prefix = conv_config.get("user_prefix", "Human: ")
98
+ self.assistant_prefix = conv_config.get("assistant_prefix", "Assistant: ")
99
+ self.turn_separator = conv_config.get("turn_separator", "\n")
100
+
101
+ def format_conversation(self, conversations: List[Dict[str, str]]) -> str:
102
+ """Format conversation into training text with robust error handling"""
103
+ formatted_parts = []
104
+
105
+ # Add system message if present
106
+ if self.system_message:
107
+ formatted_parts.append(self.system_message)
108
+
109
+ # Ensure conversations is a valid list
110
+ if not isinstance(conversations, list):
111
+ conversations = []
112
+
113
+ # Process conversation turns with error handling
114
+ for turn in conversations:
115
+ try:
116
+ if not isinstance(turn, dict):
117
+ continue
118
+
119
+ role = turn.get("from", "").lower().strip()
120
+ content = turn.get("value", "")
121
+
122
+ # Clean and validate content
123
+ if not isinstance(content, str):
124
+ content = str(content) if content else ""
125
+
126
+ content = content.strip()
127
+ if not content:
128
+ continue
129
+
130
+ # Remove problematic characters that might cause issues
131
+ content = content.replace('\x00', '').replace('\n\n\n', '\n\n')
132
+
133
+ if role in ["human", "user"]:
134
+ formatted_parts.append(f"{self.user_prefix}{content}")
135
+ elif role in ["gpt", "assistant", "ai"]:
136
+ formatted_parts.append(f"{self.assistant_prefix}{content}")
137
+ else:
138
+ # Default to human if role is unclear
139
+ formatted_parts.append(f"{self.user_prefix}{content}")
140
+
141
+ except Exception as e:
142
+ logger.debug(f"Error processing conversation turn: {e}")
143
+ continue
144
+
145
+ # Ensure we have at least some content
146
+ if not formatted_parts:
147
+ return f"{self.user_prefix}What do you see in this image?{self.turn_separator}{self.assistant_prefix}I can see an image."
148
+
149
+ return self.turn_separator.join(formatted_parts)
150
+
151
+ def add_image_token(self, text: str, has_image: bool = True) -> str:
152
+ """Add image token to text if image is present"""
153
+ if has_image:
154
+ image_token = self.config.get("special_tokens", {}).get("image_token", "<image>")
155
+ return f"{image_token}\n{text}"
156
+ return text
157
+
158
+ def validate_text(self, text: str) -> bool:
159
+ """Validate text meets filtering criteria - more lenient validation"""
160
+ if not isinstance(text, str):
161
+ return False
162
+
163
+ # Basic cleanup
164
+ text = text.strip()
165
+
166
+ # Check for completely empty content
167
+ if not text:
168
+ return False
169
+
170
+ # More lenient length check - just ensure it's not absurdly long or short
171
+ text_length = len(text)
172
+ if text_length < 5: # Very short
173
+ return False
174
+ if text_length > 2000: # Very long
175
+ return False
176
+
177
+ # Check for basic structure (should have some content)
178
+ if len(text.split()) < 2: # Less than 2 words
179
+ return False
180
+
181
+ return True
src/training/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .callbacks import CustomCallback
2
+ from .utils import TrainingUtils
3
+
4
+ __all__ = [
5
+ "CustomCallback",
6
+ "TrainingUtils"
7
+ ]
src/training/callbacks.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Lightning callbacks
3
+ """
4
+ import lightning as L
5
+ from lightning.pytorch.callbacks import Callback
6
+ import torch
7
+ from typing import Any
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class CustomCallback(Callback):
14
+ """Custom callback for monitoring training progress"""
15
+
16
+ def __init__(self):
17
+ super().__init__()
18
+ self.start_time = None
19
+
20
+ def on_train_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
21
+ """Called when training starts"""
22
+ import time
23
+ self.start_time = time.time()
24
+ logger.info("Training started")
25
+
26
+ # Log model info
27
+ total_params = sum(p.numel() for p in pl_module.parameters())
28
+ trainable_params = sum(p.numel() for p in pl_module.parameters() if p.requires_grad)
29
+
30
+ logger.info(f"Total parameters: {total_params:,}")
31
+ logger.info(f"Trainable parameters: {trainable_params:,}")
32
+ logger.info(f"Trainable ratio: {trainable_params/total_params:.2%}")
33
+
34
+ def on_train_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
35
+ """Called when training ends"""
36
+ if self.start_time:
37
+ import time
38
+ duration = time.time() - self.start_time
39
+ logger.info(f"Training completed in {duration:.2f} seconds")
40
+
41
+ def on_train_epoch_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
42
+ """Called at the start of each training epoch"""
43
+ logger.info(f"Starting epoch {trainer.current_epoch + 1}/{trainer.max_epochs}")
44
+
45
+ def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
46
+ """Called at the end of validation epoch"""
47
+ if trainer.logged_metrics:
48
+ val_loss = trainer.logged_metrics.get("val/loss", None)
49
+ if val_loss is not None:
50
+ logger.info(f"Validation loss: {val_loss:.4f}")
51
+
52
+
53
+ class MemoryMonitorCallback(Callback):
54
+ """Monitor GPU memory usage during training"""
55
+
56
+ def __init__(self, log_every_n_steps: int = 100):
57
+ super().__init__()
58
+ self.log_every_n_steps = log_every_n_steps
59
+
60
+ def on_train_batch_end(
61
+ self,
62
+ trainer: L.Trainer,
63
+ pl_module: L.LightningModule,
64
+ outputs: Any,
65
+ batch: Any,
66
+ batch_idx: int
67
+ ) -> None:
68
+ """Log memory usage"""
69
+ if batch_idx % self.log_every_n_steps == 0 and torch.cuda.is_available():
70
+ memory_allocated = torch.cuda.memory_allocated() / 1024**3 # GB
71
+ memory_reserved = torch.cuda.memory_reserved() / 1024**3 # GB
72
+
73
+ pl_module.log("train/memory_allocated_gb", memory_allocated, on_step=True)
74
+ pl_module.log("train/memory_reserved_gb", memory_reserved, on_step=True)
src/training/utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training utilities
3
+ """
4
+ import torch
5
+ import logging
6
+ from typing import Dict, Any, Optional
7
+ from pathlib import Path
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class TrainingUtils:
13
+ """Utility functions for training"""
14
+
15
+ @staticmethod
16
+ def count_parameters(model: torch.nn.Module) -> Dict[str, int]:
17
+ """Count model parameters"""
18
+ total_params = sum(p.numel() for p in model.parameters())
19
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
20
+ frozen_params = total_params - trainable_params
21
+
22
+ return {
23
+ "total": total_params,
24
+ "trainable": trainable_params,
25
+ "frozen": frozen_params,
26
+ "trainable_percentage": (trainable_params / total_params) * 100 if total_params > 0 else 0
27
+ }
28
+
29
+ @staticmethod
30
+ def print_model_summary(model: torch.nn.Module, model_name: str = "Model") -> None:
31
+ """Print detailed model summary"""
32
+ params = TrainingUtils.count_parameters(model)
33
+
34
+ logger.info(f"\n{model_name} Summary:")
35
+ logger.info(f" Total parameters: {params['total']:,}")
36
+ logger.info(f" Trainable parameters: {params['trainable']:,}")
37
+ logger.info(f" Frozen parameters: {params['frozen']:,}")
38
+ logger.info(f" Trainable percentage: {params['trainable_percentage']:.2f}%")
39
+
40
+ @staticmethod
41
+ def save_model_state(
42
+ model: torch.nn.Module,
43
+ path: str,
44
+ additional_info: Optional[Dict[str, Any]] = None
45
+ ) -> None:
46
+ """Save model state with additional information"""
47
+ save_path = Path(path)
48
+ save_path.parent.mkdir(parents=True, exist_ok=True)
49
+
50
+ state_dict = {
51
+ "model_state_dict": model.state_dict(),
52
+ "model_class": model.__class__.__name__,
53
+ }
54
+
55
+ if additional_info:
56
+ state_dict.update(additional_info)
57
+
58
+ torch.save(state_dict, save_path)
59
+ logger.info(f"Model state saved to: {save_path}")
60
+
61
+ @staticmethod
62
+ def load_model_state(model: torch.nn.Module, path: str, strict: bool = True) -> Dict[str, Any]:
63
+ """Load model state and return additional information"""
64
+ checkpoint = torch.load(path, map_location="cpu")
65
+
66
+ if "model_state_dict" in checkpoint:
67
+ model.load_state_dict(checkpoint["model_state_dict"], strict=strict)
68
+ logger.info(f"Model state loaded from: {path}")
69
+
70
+ # Return additional info
71
+ additional_info = {k: v for k, v in checkpoint.items() if k != "model_state_dict"}
72
+ return additional_info
73
+ else:
74
+ # Assume the checkpoint is just the state dict
75
+ model.load_state_dict(checkpoint, strict=strict)
76
+ logger.info(f"Model state loaded from: {path}")
77
+ return {}
78
+
79
+ @staticmethod
80
+ def get_device_info() -> Dict[str, Any]:
81
+ """Get information about available devices"""
82
+ info = {
83
+ "cuda_available": torch.cuda.is_available(),
84
+ "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
85
+ }
86
+
87
+ if torch.cuda.is_available():
88
+ info["cuda_current_device"] = torch.cuda.current_device()
89
+ info["cuda_device_name"] = torch.cuda.get_device_name()
90
+ info["cuda_memory_total"] = torch.cuda.get_device_properties(0).total_memory / 1024**3 # GB
91
+
92
+ return info
93
+
94
+ @staticmethod
95
+ def log_device_info() -> None:
96
+ """Log device information"""
97
+ info = TrainingUtils.get_device_info()
98
+
99
+ logger.info("\nDevice Information:")
100
+ logger.info(f" CUDA Available: {info['cuda_available']}")
101
+
102
+ if info['cuda_available']:
103
+ logger.info(f" CUDA Device Count: {info['cuda_device_count']}")
104
+ logger.info(f" Current Device: {info['cuda_current_device']}")
105
+ logger.info(f" Device Name: {info['cuda_device_name']}")
106
+ logger.info(f" Total Memory: {info['cuda_memory_total']:.2f} GB")
107
+ else:
108
+ logger.info(" Using CPU for training")
109
+
110
+ @staticmethod
111
+ def cleanup_memory() -> None:
112
+ """Clean up GPU memory"""
113
+ if torch.cuda.is_available():
114
+ torch.cuda.empty_cache()
115
+ logger.info("GPU memory cache cleared")
src/utils/logging.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Logging utilities
3
+ """
4
+ import logging
5
+ import sys
6
+ from pathlib import Path
7
+ from typing import Optional
8
+ from rich.logging import RichHandler
9
+ from rich.console import Console
10
+
11
+
12
+ def setup_logging(
13
+ level: int = logging.INFO,
14
+ log_file: Optional[str] = None,
15
+ use_rich: bool = True
16
+ ) -> None:
17
+ """Setup logging configuration"""
18
+
19
+ # Create logs directory if needed
20
+ if log_file:
21
+ log_path = Path(log_file)
22
+ log_path.parent.mkdir(parents=True, exist_ok=True)
23
+
24
+ # Clear existing handlers
25
+ root_logger = logging.getLogger()
26
+ root_logger.handlers.clear()
27
+
28
+ # Setup formatters
29
+ if use_rich:
30
+ # Rich handler for console output
31
+ console_handler = RichHandler(
32
+ console=Console(stderr=True),
33
+ show_time=True,
34
+ show_path=True,
35
+ rich_tracebacks=True
36
+ )
37
+ console_handler.setLevel(level)
38
+ root_logger.addHandler(console_handler)
39
+ else:
40
+ # Standard console handler
41
+ console_handler = logging.StreamHandler(sys.stdout)
42
+ console_handler.setLevel(level)
43
+ console_formatter = logging.Formatter(
44
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
45
+ )
46
+ console_handler.setFormatter(console_formatter)
47
+ root_logger.addHandler(console_handler)
48
+
49
+ # File handler if specified
50
+ if log_file:
51
+ file_handler = logging.FileHandler(log_file)
52
+ file_handler.setLevel(level)
53
+ file_formatter = logging.Formatter(
54
+ '%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s'
55
+ )
56
+ file_handler.setFormatter(file_formatter)
57
+ root_logger.addHandler(file_handler)
58
+
59
+ # Set root logger level
60
+ root_logger.setLevel(level)
61
+
62
+ # Reduce noise from some libraries
63
+ logging.getLogger("transformers").setLevel(logging.WARNING)
64
+ logging.getLogger("datasets").setLevel(logging.WARNING)
65
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
66
+ logging.getLogger("requests").setLevel(logging.WARNING)
67
+
68
+ logging.info("Logging setup completed")
69
+
70
+
71
+ def get_logger(name: str) -> logging.Logger:
72
+ """Get a logger with the specified name"""
73
+ return logging.getLogger(name)