""" HunyuanOCR Model Wrapper Provides an easy-to-use interface for text detection and recognition """ import re import os import torch from typing import Dict, List, Tuple, Optional from PIL import Image from transformers import AutoProcessor, HunYuanVLForConditionalGeneration from transformers.modeling_outputs import CausalLMOutputWithPast import requests from io import BytesIO # Monkey-patch HunYuanVLForConditionalGeneration.generate to fix dtype issue def patched_generate( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, imgs: Optional[list[torch.FloatTensor]] = None, imgs_pos: Optional[list[int]] = None, token_type_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, image_grid_thw: Optional[list[int]] = None, **kwargs, ) -> CausalLMOutputWithPast: if "inputs_embeds" in kwargs: raise NotImplementedError("`inputs_embeds` is not supported") inputs_embeds = self.model.embed_tokens(input_ids) if self.vit is not None and pixel_values is not None: # PATCH: Use model's dtype instead of forcing bfloat16 pixel_values = pixel_values.to(self.dtype) image_embeds = self.vit(pixel_values, image_grid_thw) # ViT may be deployed on different GPUs from those used by LLMs, due to auto-mapping of accelerate. image_embeds = image_embeds.to(input_ids.device, non_blocking=True) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds ) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) return super(HunYuanVLForConditionalGeneration, self).generate( inputs=input_ids, position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs, ) HunYuanVLForConditionalGeneration.generate = patched_generate class HunyuanOCR: """Wrapper class for HunyuanOCR model for text spotting tasks""" def __init__(self, model_path: str = "tencent/HunyuanOCR", device: Optional[str] = None): """ Initialize the HunyuanOCR model Args: model_path: Path or name of the model (default: "tencent/HunyuanOCR") device: Device to load model on (cuda/cpu). Auto-detected if None. 
""" # Check if local model exists when using default path if model_path == "tencent/HunyuanOCR" and os.path.exists("HunyuanOCR"): print("Found local HunyuanOCR model, using it instead of downloading...") model_path = "HunyuanOCR" self.model_path = model_path # Auto-detect device if not specified if device is None: if torch.cuda.is_available(): self.device = "cuda" elif torch.backends.mps.is_available(): self.device = "mps" else: self.device = "cpu" else: self.device = device print(f"Loading HunyuanOCR model on {self.device}...") # Load processor self.processor = AutoProcessor.from_pretrained(model_path, use_fast=False) # Determine dtype based on device if self.device == "cuda": torch_dtype = torch.bfloat16 elif self.device == "mps": torch_dtype = torch.float16 else: torch_dtype = torch.float32 # Load model self.model = HunYuanVLForConditionalGeneration.from_pretrained( model_path, attn_implementation="eager", torch_dtype=torch_dtype, device_map="auto" if self.device == "cuda" else None ) if self.device != "cuda": self.model = self.model.to(self.device) print("Model loaded successfully!") def clean_repeated_substrings(self, text: str) -> str: """ Clean repeated substrings in text output Args: text: Input text to clean Returns: Cleaned text """ n = len(text) if n < 8000: return text for length in range(2, n // 10 + 1): candidate = text[-length:] count = 0 i = n - length while i >= 0 and text[i:i + length] == candidate: count += 1 i -= length if count >= 10: return text[:n - length * (count - 1)] return text def load_image(self, image_source: str) -> Image.Image: """ Load image from URL or file path Args: image_source: URL or file path to image Returns: PIL Image object """ if image_source.startswith(('http://', 'https://')): response = requests.get(image_source) response.raise_for_status() return Image.open(BytesIO(response.content)) else: return Image.open(image_source) def detect_text(self, image: Image.Image, prompt: Optional[str] = None) -> str: """ Detect and recognize text in image with bounding boxes Args: image: PIL Image object prompt: Custom prompt (default: text spotting prompt in Chinese) Returns: Model response with detected text and coordinates """ # Default prompt for text spotting if prompt is None: prompt = "检测并识别图片中的文字,将文本内容与坐标格式化输出。" # Prepare messages messages = [ { "role": "user", "content": [ {"type": "image"}, {"type": "text", "text": prompt}, ], } ] # Apply chat template text = self.processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) # Process inputs inputs = self.processor( text=[text], images=[image], padding=True, return_tensors="pt", ) # Generate with torch.no_grad(): # Get model's dtype model_dtype = next(self.model.parameters()).dtype if self.device == "cuda": device = next(self.model.parameters()).device inputs = inputs.to(device) else: # Move to device and cast floating point tensors to model's dtype new_inputs = {} for k, v in inputs.items(): if torch.is_tensor(v): v = v.to(self.device) if v.dtype in [torch.float16, torch.bfloat16, torch.float32]: v = v.to(dtype=model_dtype) new_inputs[k] = v else: new_inputs[k] = v inputs = new_inputs generated_ids = self.model.generate( **inputs, max_new_tokens=2048, do_sample=False ) # Decode output if "input_ids" in inputs: input_ids = inputs["input_ids"] else: input_ids = inputs["inputs"] generated_ids_trimmed = [ out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids) ] output_text = self.processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, 
            clean_up_tokenization_spaces=False,
        )[0]

        # Clean repeated substrings
        output_text = self.clean_repeated_substrings(output_text)

        return output_text

    def parse_detection_results(self, response: str, image_width: int, image_height: int) -> List[Dict]:
        """
        Parse the detection response into a structured format with denormalized coordinates.

        Args:
            response: Model output text
            image_width: Image width in pixels
            image_height: Image height in pixels

        Returns:
            List of dictionaries with 'text', 'x1', 'y1', 'x2', 'y2' keys
        """
        results = []

        # Pattern to match text and coordinates: text(x1,y1),(x2,y2)
        pattern = r'([^()]+?)(\(\d+,\d+\),\(\d+,\d+\))'
        matches = re.finditer(pattern, response)

        for match in matches:
            try:
                text = match.group(1).strip()
                coords = match.group(2)

                # Parse coordinates
                coord_pattern = r'\((\d+),(\d+)\)'
                coord_matches = re.findall(coord_pattern, coords)

                if len(coord_matches) == 2:
                    # Coordinates are normalized to [0, 1000]; denormalize them
                    x1_norm, y1_norm = float(coord_matches[0][0]), float(coord_matches[0][1])
                    x2_norm, y2_norm = float(coord_matches[1][0]), float(coord_matches[1][1])

                    # Denormalize to image dimensions
                    x1 = int(x1_norm * image_width / 1000)
                    y1 = int(y1_norm * image_height / 1000)
                    x2 = int(x2_norm * image_width / 1000)
                    y2 = int(y2_norm * image_height / 1000)

                    results.append({
                        'text': text,
                        'x1': x1,
                        'y1': y1,
                        'x2': x2,
                        'y2': y2,
                    })
            except Exception as e:
                print(f"Error parsing detection result: {str(e)}")
                continue

        return results

    def process_image(self, image_source: str, prompt: Optional[str] = None) -> Tuple[str, List[Dict], Image.Image]:
        """
        Complete pipeline: load image, detect text, parse results.

        Args:
            image_source: Path or URL of the image
            prompt: Custom prompt for detection

        Returns:
            Tuple of (raw_response, parsed_results, image)
        """
        # Load image
        image = self.load_image(image_source)
        image_width, image_height = image.size

        # Detect text
        response = self.detect_text(image, prompt)

        # Parse results
        parsed_results = self.parse_detection_results(response, image_width, image_height)

        return response, parsed_results, image
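

# Minimal usage sketch of the full pipeline: load an image, run text spotting, and print the
# denormalized boxes. "sample.jpg" is a placeholder path (an assumption, not shipped with this module).
if __name__ == "__main__":
    ocr = HunyuanOCR()
    raw_response, boxes, img = ocr.process_image("sample.jpg")
    print(raw_response)
    for box in boxes:
        print(f"{box['text']}: ({box['x1']}, {box['y1']}) -> ({box['x2']}, {box['y2']})")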