| """ | |
| Data processors for images and text | |
| """ | |
| import torch | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from typing import List, Dict, Any, Optional | |
| import logging | |
| logger = logging.getLogger(__name__) | |


class ImageProcessor:
    """Image preprocessing for the CLIP vision encoder."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.image_size = config["data"]["image_size"]
        # CLIP normalization values (per-channel mean/std)
        self.mean = config["data"]["image_mean"]
        self.std = config["data"]["image_std"]
        # Build the transform pipeline once at construction time
        self.transform = self._setup_transforms()
    def _setup_transforms(self) -> transforms.Compose:
        """Build the image transformation pipeline."""
        transform_list = [
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ]

        # Add augmentations if enabled
        aug_config = self.config["data"]["augmentation"]
        if aug_config["enabled"]:
            aug_transforms = []

            # Random resized crop; the config value is the lower bound of the crop scale
            if aug_config.get("random_resized_crop"):
                scale = aug_config["random_resized_crop"]
                aug_transforms.append(
                    transforms.RandomResizedCrop(self.image_size, scale=(scale, 1.0))
                )

            # Color jitter; the config value is the brightness jitter strength
            if aug_config.get("color_jitter"):
                brightness = aug_config["color_jitter"]
                aug_transforms.append(transforms.ColorJitter(brightness=brightness))

            # Horizontal flip; the config value is the flip probability
            if aug_config.get("horizontal_flip"):
                prob = aug_config["horizontal_flip"]
                aug_transforms.append(transforms.RandomHorizontalFlip(p=prob))

            # Insert the PIL-based augmentations after Resize but before
            # ToTensor/Normalize, which operate on tensors
            transform_list = (
                transform_list[:-2]    # Resize
                + aug_transforms
                + transform_list[-2:]  # ToTensor, Normalize
            )

        return transforms.Compose(transform_list)
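
    # Illustrative sketch of the augmentation config this pipeline reads; the
    # exact keys and values below are assumptions for documentation purposes,
    # not the project's canonical config:
    #
    #   config["data"]["augmentation"] = {
    #       "enabled": True,
    #       "random_resized_crop": 0.8,  # lower bound of the crop scale range
    #       "color_jitter": 0.4,         # brightness jitter strength
    #       "horizontal_flip": 0.5,      # flip probability
    #   }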
    def __call__(self, image: Image.Image) -> torch.Tensor:
        """Process a single image."""
        if not isinstance(image, Image.Image):
            raise ValueError(f"Expected PIL Image, got {type(image)}")
        return self.transform(image)
    def process_batch(self, images: List[Image.Image]) -> torch.Tensor:
        """Process a batch of images into one stacked tensor."""
        processed = [self(img) for img in images]
        return torch.stack(processed)
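
    # Hedged usage sketch; the config values below (the standard CLIP mean/std
    # and 224px input) are illustrative assumptions, not the project's defaults:
    #
    #   config = {"data": {
    #       "image_size": 224,
    #       "image_mean": [0.48145466, 0.4578275, 0.40821073],
    #       "image_std": [0.26862954, 0.26130258, 0.27577711],
    #       "augmentation": {"enabled": False},
    #   }}
    #   processor = ImageProcessor(config)
    #   batch = processor.process_batch([Image.new("RGB", (640, 480))])
    #   # batch.shape == torch.Size([1, 3, 224, 224])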


class TextProcessor:
    """Text preprocessing for conversations."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.max_length = config["data"]["max_length"]
        # Conversation formatting options
        conv_config = config["data"]["conversation"]
        self.system_message = conv_config.get("system_message", "")
        self.user_prefix = conv_config.get("user_prefix", "Human: ")
        self.assistant_prefix = conv_config.get("assistant_prefix", "Assistant: ")
        self.turn_separator = conv_config.get("turn_separator", "\n")
    def format_conversation(self, conversations: List[Dict[str, str]]) -> str:
        """Format a conversation into training text with robust error handling."""
        formatted_parts = []

        # Add the system message if present
        if self.system_message:
            formatted_parts.append(self.system_message)

        # Ensure conversations is a valid list
        if not isinstance(conversations, list):
            conversations = []

        # Process conversation turns, skipping malformed entries
        for turn in conversations:
            try:
                if not isinstance(turn, dict):
                    continue

                role = turn.get("from", "").lower().strip()
                content = turn.get("value", "")

                # Clean and validate the content
                if not isinstance(content, str):
                    content = str(content) if content else ""
                content = content.strip()
                if not content:
                    continue

                # Remove problematic characters that might cause issues downstream
                content = content.replace("\x00", "").replace("\n\n\n", "\n\n")

                if role in ("human", "user"):
                    formatted_parts.append(f"{self.user_prefix}{content}")
                elif role in ("gpt", "assistant", "ai"):
                    formatted_parts.append(f"{self.assistant_prefix}{content}")
                else:
                    # Default to the user role if the role is unrecognized
                    formatted_parts.append(f"{self.user_prefix}{content}")
            except Exception as e:
                logger.debug(f"Error processing conversation turn: {e}")
                continue

        # Fall back to a minimal exchange if no valid turns were found
        if not formatted_parts:
            return (
                f"{self.user_prefix}What do you see in this image?"
                f"{self.turn_separator}{self.assistant_prefix}I can see an image."
            )

        return self.turn_separator.join(formatted_parts)
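
    # Worked example with the default prefixes and separator (an illustration
    # of the output shape, not taken from the project's docs):
    #
    #   format_conversation([
    #       {"from": "human", "value": "What is in the picture?"},
    #       {"from": "gpt", "value": "A dog on a beach."},
    #   ])
    #   -> "Human: What is in the picture?\nAssistant: A dog on a beach."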
    def add_image_token(self, text: str, has_image: bool = True) -> str:
        """Prepend the image token to the text if an image is present."""
        if has_image:
            image_token = self.config.get("special_tokens", {}).get("image_token", "<image>")
            return f"{image_token}\n{text}"
        return text
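
    # E.g. add_image_token("Describe the scene.") returns
    # "<image>\nDescribe the scene." when no custom image token is configured.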
    def validate_text(self, text: str) -> bool:
        """Validate that text meets the deliberately lenient filtering criteria."""
        if not isinstance(text, str):
            return False

        # Basic cleanup
        text = text.strip()

        # Reject completely empty content
        if not text:
            return False

        # Lenient length check: reject only absurdly short or long text
        text_length = len(text)
        if text_length < 5:
            return False
        if text_length > 2000:
            return False

        # Require at least minimal structure (two or more words)
        if len(text.split()) < 2:
            return False

        return True
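

if __name__ == "__main__":
    # Minimal smoke test for TextProcessor. The config below is an assumed
    # shape for illustration only, not the project's canonical config.
    demo_config = {
        "data": {
            "max_length": 512,
            "conversation": {},  # fall back to the default prefixes/separator
        }
    }
    tp = TextProcessor(demo_config)
    text = tp.format_conversation(
        [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello there!"}]
    )
    text = tp.add_image_token(text)
    print(text)
    print("valid:", tp.validate_text(text))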