"""
Data processors for images and text
"""
import re
import torch
from PIL import Image
import torchvision.transforms as transforms
from typing import List, Dict, Any
import logging

logger = logging.getLogger(__name__)
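
# Expected config layout, inferred from the keys accessed below. The values
# shown are illustrative placeholders, not the project's actual defaults:
#
#   data:
#     image_size: 224
#     image_mean: [0.48145466, 0.4578275, 0.40821073]   # CLIP mean
#     image_std: [0.26862954, 0.26130258, 0.27577711]   # CLIP std
#     max_length: 512
#     augmentation:
#       enabled: true
#       random_resized_crop: 0.8   # minimum crop scale
#       color_jitter: 0.4          # brightness jitter strength
#       horizontal_flip: 0.5       # flip probability
#     conversation:
#       system_message: ""
#       user_prefix: "Human: "
#       assistant_prefix: "Assistant: "
#       turn_separator: "\n"
#   special_tokens:
#     image_token: "<image>"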


class ImageProcessor:
    """Image preprocessing for CLIP vision encoder"""
    
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.image_size = config["data"]["image_size"]
        
        # CLIP normalization values
        self.mean = config["data"]["image_mean"]
        self.std = config["data"]["image_std"]
        
        # Setup transforms
        self.transform = self._setup_transforms()
        
    def _setup_transforms(self):
        """Setup image transformations"""
        transform_list = [
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std)
        ]
        
        # Add augmentations if enabled
        if self.config["data"]["augmentation"]["enabled"]:
            aug_transforms = []
            
            # Random resized crop
            if self.config["data"]["augmentation"].get("random_resized_crop"):
                scale = self.config["data"]["augmentation"]["random_resized_crop"]
                aug_transforms.append(
                    transforms.RandomResizedCrop(
                        self.image_size, 
                        scale=(scale, 1.0)
                    )
                )
            
            # Color jitter
            if self.config["data"]["augmentation"].get("color_jitter"):
                brightness = self.config["data"]["augmentation"]["color_jitter"]
                aug_transforms.append(
                    transforms.ColorJitter(brightness=brightness)
                )
            
            # Horizontal flip
            if self.config["data"]["augmentation"].get("horizontal_flip"):
                prob = self.config["data"]["augmentation"]["horizontal_flip"]
                aug_transforms.append(
                    transforms.RandomHorizontalFlip(p=prob)
                )
            
            # Insert augmentations after Resize but before ToTensor, since
            # these torchvision augmentations operate on PIL images
            transform_list = (
                transform_list[:1] +   # Resize
                aug_transforms +
                transform_list[1:]     # ToTensor, Normalize
            )
        
        return transforms.Compose(transform_list)
    
    def __call__(self, image: Image.Image) -> torch.Tensor:
        """Process a single image"""
        if not isinstance(image, Image.Image):
            raise ValueError(f"Expected PIL Image, got {type(image)}")
        
        # Convert to RGB so grayscale/RGBA inputs match the 3-channel normalization
        if image.mode != "RGB":
            image = image.convert("RGB")
        
        return self.transform(image)
    
    def process_batch(self, images: List[Image.Image]) -> torch.Tensor:
        """Process a batch of images"""
        processed = []
        for img in images:
            processed.append(self(img))
        return torch.stack(processed)


class TextProcessor:
    """Text preprocessing for conversations"""
    
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.max_length = config["data"]["max_length"]
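        # Note: max_length is stored for downstream use (presumably by the
        # tokenizer); format_conversation itself does not truncate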
        
        # Conversation formatting
        conv_config = config["data"]["conversation"]
        self.system_message = conv_config.get("system_message", "")
        self.user_prefix = conv_config.get("user_prefix", "Human: ")
        self.assistant_prefix = conv_config.get("assistant_prefix", "Assistant: ")
        self.turn_separator = conv_config.get("turn_separator", "\n")
        
    def format_conversation(self, conversations: List[Dict[str, str]]) -> str:
        """Format conversation into training text with robust error handling"""
        formatted_parts = []

        # Add system message if present
        if self.system_message:
            formatted_parts.append(self.system_message)

        # Ensure conversations is a valid list
        if not isinstance(conversations, list):
            conversations = []

        # Process conversation turns with error handling
        for turn in conversations:
            try:
                if not isinstance(turn, dict):
                    continue

                role = turn.get("from", "").lower().strip()
                content = turn.get("value", "")

                # Clean and validate content
                if not isinstance(content, str):
                    content = str(content) if content else ""

                content = content.strip()
                if not content:
                    continue

                # Strip NUL bytes and collapse runs of 3+ newlines, which can
                # confuse downstream tokenization
                content = content.replace('\x00', '')
                content = re.sub(r'\n{3,}', '\n\n', content)

                if role in ["human", "user"]:
                    formatted_parts.append(f"{self.user_prefix}{content}")
                elif role in ["gpt", "assistant", "ai"]:
                    formatted_parts.append(f"{self.assistant_prefix}{content}")
                else:
                    # Default to human if role is unclear
                    formatted_parts.append(f"{self.user_prefix}{content}")

            except Exception as e:
                logger.debug(f"Error processing conversation turn: {e}")
                continue

        # Ensure we have at least some content
        if not formatted_parts:
            return f"{self.user_prefix}What do you see in this image?{self.turn_separator}{self.assistant_prefix}I can see an image."

        return self.turn_separator.join(formatted_parts)
    
    def add_image_token(self, text: str, has_image: bool = True) -> str:
        """Add image token to text if image is present"""
        if has_image:
            image_token = self.config.get("special_tokens", {}).get("image_token", "<image>")
            return f"{image_token}\n{text}"
        return text
    
    def validate_text(self, text: str) -> bool:
        """Check that text meets basic filtering criteria (intentionally lenient)"""
        if not isinstance(text, str):
            return False

        # Basic cleanup
        text = text.strip()

        # Check for completely empty content
        if not text:
            return False

        # Lenient length bounds: reject only extreme outliers
        text_length = len(text)
        if text_length < 5:  # Very short
            return False
        if text_length > 2000:  # Very long
            return False

        # Check for basic structure (should have some content)
        if len(text.split()) < 2:  # Less than 2 words
            return False

        return True
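

# ---------------------------------------------------------------------------
# Minimal usage sketch. The config below is an illustrative stand-in that
# matches the layout documented at the top of this file, not the project's
# real configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = {
        "data": {
            "image_size": 224,
            "image_mean": [0.48145466, 0.4578275, 0.40821073],
            "image_std": [0.26862954, 0.26130258, 0.27577711],
            "max_length": 512,
            "augmentation": {"enabled": False},
            "conversation": {},
        }
    }

    image_processor = ImageProcessor(config)
    text_processor = TextProcessor(config)

    # A blank RGB image stands in for a real sample
    pixel_values = image_processor(Image.new("RGB", (640, 480)))
    print(pixel_values.shape)  # torch.Size([3, 224, 224])

    # Format a two-turn conversation and prepend the <image> token
    conversation = [
        {"from": "human", "value": "What is in this picture?"},
        {"from": "gpt", "value": "A plain test image."},
    ]
    text = text_processor.add_image_token(
        text_processor.format_conversation(conversation)
    )
    print(text)
    print("valid:", text_processor.validate_text(text))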