| import os |
| import io |
| import base64 |
| import tempfile |
| import zipfile |
| import logging |
| import sys |
| import time |
| from typing import Dict, Any, Optional |
| from pathlib import Path |
| import json |
|
|
| import torch |
| import numpy as np |
| from PIL import Image |
| import cv2 |
|
|
| |
| |
| |
class Float32Autocast:
    """No-op autocast that forces float32.

    Stand-in for ``torch.autocast`` / ``torch.cuda.amp.autocast`` /
    ``torch.amp.autocast``. It accepts the call shapes used with those
    entry points (context manager, decorator, with or without an explicit
    device type) but never enables mixed precision.

    Args:
        device_type: Device string. Defaults to ``"cuda"`` so the legacy
            ``torch.cuda.amp.autocast(enabled=...)`` form, which has no
            device argument, keeps working.
        dtype: Accepted for signature compatibility and ignored.
        enabled: Accepted for signature compatibility and ignored.
        cache_enabled: Accepted for signature compatibility and ignored.
    """

    def __init__(self, device_type="cuda", dtype=None, enabled=True, cache_enabled=None):
        # The legacy CUDA entry point takes `enabled` as its first positional
        # argument; anything that is not a device string is treated that way.
        if not isinstance(device_type, str):
            device_type = "cuda"
        self.device_type = device_type
        # Whatever the caller asked for, report float32 and stay disabled.
        self.dtype = torch.float32
        self.enabled = False

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def __call__(self, func):
        """Support decorator usage: ``@autocast(...)`` returns *func* wrapped in this no-op context."""
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return wrapper
|
|
| |
# Keep a handle on the real implementation, then point every autocast entry
# point that exists on this torch build at the float32 no-op.
_ORIGINAL_AUTOCAST = torch.autocast

for _autocast_owner in (
    torch,
    getattr(torch.cuda, "amp", None),
    getattr(torch, "amp", None),
):
    if _autocast_owner is not None:
        _autocast_owner.autocast = Float32Autocast
del _autocast_owner
|
|
| |
# Send timestamped INFO-and-above records to stdout.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)

logger.info("✓ Patched torch.autocast globally before SAM3 import")
|
|
| |
| |
| from sam3.model_builder import build_sam3_video_predictor |
|
|
| |
# huggingface_hub is optional; record whether it could be imported.
try:
    from huggingface_hub import HfApi
except ImportError:
    HF_HUB_AVAILABLE = False
else:
    HF_HUB_AVAILABLE = True
|
|
|
|
| class EndpointHandler: |
| """ |
| SAM3 Video Segmentation Handler for HuggingFace Inference Endpoints |
| |
| Processes video with text prompts and returns segmentation masks. |
| Uses SAM3 repository code directly from local sam3/ package. |
| """ |
| |
def __init__(self, path: str = ""):
    """
    Initialize SAM3 video predictor.

    Builds the predictor on GPU 0, casts its weights to float32, wraps the
    predictor's request entry points so float16/bfloat16 tensor inputs are
    upcast, and sets up a HuggingFace Hub client when ``HF_TOKEN`` is set
    and ``huggingface_hub`` is importable.

    Args:
        path: Path to model repository (not used - model loads from HF automatically)

    Raises:
        ValueError: If no CUDA device is available.
        Exception: Any error raised while building or converting the
            predictor is logged and re-raised unchanged.
    """
    banner = "=" * 80
    logger.info(banner)
    logger.info("INITIALIZING SAM3 VIDEO SEGMENTATION HANDLER")
    logger.info(banner)

    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Device detection: {self.device}")

    if self.device != "cuda":
        logger.error("FATAL: SAM3 requires GPU acceleration. No CUDA device found.")
        raise ValueError("SAM3 requires GPU acceleration. No CUDA device found.")

    # CUDA is guaranteed past this point.
    logger.info(f"GPU Device: {torch.cuda.get_device_name(0)}")
    logger.info(f"CUDA Version: {torch.version.cuda}")
    logger.info(f"Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    def convert_model_to_float32(model):
        """Cast *model* to float32 and return how many tensors needed the cast."""
        # Count BEFORE casting: once ``model.float()`` has run, every
        # floating-point parameter and buffer is already float32 and there
        # is nothing left to count.
        stale_names = [
            name
            for name, tensor in list(model.named_parameters()) + list(model.named_buffers())
            if tensor.is_floating_point() and tensor.dtype != torch.float32
        ]
        # Module.float() recurses into every registered submodule.
        model.float()
        for name in stale_names:
            logger.debug(f" Converted tensor: {name}")
        return len(stale_names)

    def ensure_float32(obj):
        """Recursively upcast float16/bfloat16 tensors inside dicts, lists and tuples."""
        if isinstance(obj, torch.Tensor):
            if obj.dtype in (torch.float16, torch.bfloat16):
                return obj.float()
            return obj
        if isinstance(obj, dict):
            return {k: ensure_float32(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            items = [ensure_float32(item) for item in obj]
            if hasattr(obj, "_fields"):
                # namedtuple constructors take positional fields, not an iterable
                return type(obj)(*items)
            return type(obj)(items)
        return obj

    try:
        logger.info("Building SAM3 video predictor...")
        start_time = time.time()

        bpe_path = self._ensure_bpe_file()
        logger.info(f"BPE tokenizer path: {bpe_path}")

        self.predictor = build_sam3_video_predictor(
            gpus_to_use=[0],
            bpe_path=bpe_path
        )

        logger.info("Converting model to float32 to avoid dtype mismatch...")
        total_conversions = 0

        model = getattr(self.predictor, "model", None)
        if model is not None:
            logger.info(" Converting main model...")
            total_conversions += convert_model_to_float32(model)

        # Components that may hang directly off the predictor.
        for attr_name in ['detector', 'tracker', 'image_encoder', 'text_encoder']:
            attr = getattr(self.predictor, attr_name, None)
            if attr is not None and hasattr(attr, 'float'):
                logger.info(f" Converting {attr_name}...")
                try:
                    total_conversions += convert_model_to_float32(attr)
                except Exception as e:
                    logger.warning(f" Could not convert {attr_name}: {e}")

        # Best-effort sweep over public attributes of the main model that
        # look like modules; failures are logged, never fatal.
        if model is not None:
            for attr_name in dir(model):
                if attr_name.startswith('_') or attr_name in ('model', 'detector', 'tracker'):
                    continue
                try:
                    attr = getattr(model, attr_name)
                    if hasattr(attr, 'parameters') and hasattr(attr, 'float'):
                        logger.debug(f" Found submodel: {attr_name}")
                        total_conversions += convert_model_to_float32(attr)
                except Exception as e:
                    logger.debug(f" Skipped attribute {attr_name}: {e}")

        if total_conversions > 0:
            logger.info(f"✓ Model converted to float32 ({total_conversions} tensors converted)")
        else:
            logger.warning("⚠ No tensors were converted - dtype fix may not have been applied correctly")

        # Wrap the request entry points so callers cannot feed
        # low-precision tensors back in.
        original_handle_request = self.predictor.handle_request

        def float32_handle_request(request):
            """Wrapper to ensure all tensor inputs are float32."""
            return original_handle_request(ensure_float32(request))

        self.predictor.handle_request = float32_handle_request

        if hasattr(self.predictor, 'handle_stream_request'):
            original_handle_stream_request = self.predictor.handle_stream_request

            def float32_handle_stream_request(request):
                """Wrapper to ensure all tensor inputs are float32."""
                yield from original_handle_stream_request(ensure_float32(request))

            self.predictor.handle_stream_request = float32_handle_stream_request

        logger.info("✓ Added float32 enforcement wrappers to predictor methods")

        elapsed = time.time() - start_time
        logger.info(f"✓ SAM3 video predictor loaded successfully in {elapsed:.2f}s")

    except Exception as e:
        logger.error(f"✗ Failed to load SAM3 predictor: {type(e).__name__}: {e}")
        logger.exception("Full traceback:")
        raise

    # Optional HuggingFace Hub client, used only for uploading results.
    self.hf_api = None
    hf_token = os.getenv("HF_TOKEN")

    if HF_HUB_AVAILABLE and hf_token:
        try:
            self.hf_api = HfApi(token=hf_token)
            logger.info("✓ HuggingFace Hub API initialized")
        except Exception as e:
            logger.warning(f"Failed to initialize HF API: {e}")
    else:
        reasons = []
        if not HF_HUB_AVAILABLE:
            reasons.append("huggingface_hub not installed")
        if not hf_token:
            reasons.append("HF_TOKEN not set")
        logger.info(f"HuggingFace Hub uploads disabled ({', '.join(reasons)})")

    logger.info(banner)
    logger.info("INITIALIZATION COMPLETE - READY FOR REQUESTS")
    logger.info(banner)
| |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Process video segmentation request using SAM3 video predictor API.

    Expected input format (HuggingFace Inference Toolkit standard):
    {
        "inputs": <base64_encoded_video>,
        "parameters": {
            "text_prompt": "object to segment",
            "return_format": "download_url" or "base64" or "metadata_only",  # optional
            "output_repo": "username/dataset-name",  # optional, for HF upload
        }
    }

    ``text_prompt``, ``output_repo`` and ``return_format`` are also read
    from the top level of ``data``; a top-level value wins over the one in
    ``parameters``.

    Returns:
    {
        "download_url": "https://...",  # only for return_format="download_url"
        "masks_zip_base64": "...",       # only for return_format="base64"
        "frame_count": 120,
        "video_metadata": {...},
        "compressed_size_mb": 15.3,
        "objects_detected": [1, 2, 3]  # object IDs
    }

    Failures never propagate: validation problems and processing errors
    are returned as a dict with an "error" key (plus "error_type" for
    exceptions).
    """
    request_start = time.time()
    banner = "=" * 80

    logger.info("")
    logger.info(banner)
    logger.info("NEW REQUEST RECEIVED")
    logger.info(banner)

    try:
        logger.info("Parsing request parameters...")

        logger.info(f" Received keys: {list(data.keys())}")
        # An explicit ``"parameters": null`` is treated like a missing key.
        parameters = data.get("parameters") or {}
        if "parameters" in data:
            logger.info(f" parameters dict keys: {list(parameters.keys())}")

        video_data = data.get("inputs")
        text_prompt = data.get("text_prompt") or parameters.get("text_prompt", "")
        output_repo = data.get("output_repo") or parameters.get("output_repo")
        return_format = data.get("return_format") or parameters.get("return_format", "metadata_only")

        logger.info(f" text_prompt: '{text_prompt}'")
        logger.info(f" return_format: {return_format}")
        logger.info(f" output_repo: {output_repo if output_repo else 'None'}")
        logger.info(f" video_data: {'Present' if video_data else 'Missing'} ({len(video_data) if video_data else 0} chars)")

        if not video_data:
            logger.error("✗ Validation failed: No video data provided")
            return {"error": "No video data provided. Include video as 'inputs' in request."}

        if not text_prompt:
            logger.error("✗ Validation failed: No text prompt provided")
            return {"error": "No text prompt provided. Include 'text_prompt' in 'parameters'."}

        if return_format not in ["metadata_only", "base64", "download_url"]:
            logger.warning(f"Invalid return_format '{return_format}', defaulting to 'metadata_only'")
            return_format = "metadata_only"

        if return_format == "download_url" and not output_repo:
            logger.error("✗ Validation failed: download_url requires output_repo")
            return {"error": "return_format='download_url' requires 'output_repo' parameter"}

        if return_format == "download_url" and not self.hf_api:
            # Fail before spending GPU time on a result that cannot be uploaded.
            logger.error("✗ Validation failed: HuggingFace Hub API not initialized")
            return {
                "error": "HuggingFace Hub API not initialized. Set HF_TOKEN environment variable.",
                "error_type": "ValueError"
            }

        logger.info("✓ Request validation passed")

        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir_path = Path(tmpdir)
            logger.info(f"Created temporary directory: {tmpdir}")

            # ---- Step 1: decode the upload to a file on disk ----
            logger.info("")
            logger.info("STEP 1/9: Decoding video data...")
            step_start = time.time()

            try:
                video_path = self._prepare_video(video_data, tmpdir_path)
                video_size_mb = video_path.stat().st_size / 1e6

                logger.info(f" Video saved to: {video_path}")
                logger.info(f" Video size: {video_size_mb:.2f} MB")
                logger.info(f"✓ Step 1 completed in {time.time() - step_start:.2f}s")

            except Exception as e:
                logger.error(f"✗ Step 1 failed: {type(e).__name__}: {e}")
                raise

            # Steps 2-8 run under try/finally so that a session opened in
            # step 2 is always closed (step 9), even when a later step fails.
            session_id = None
            try:
                # ---- Step 2: open a predictor session ----
                logger.info("")
                logger.info("STEP 2/9: Starting SAM3 session...")
                step_start = time.time()

                try:
                    session_response = self.predictor.handle_request(
                        request=dict(
                            type="start_session",
                            resource_path=str(video_path),
                        )
                    )
                    session_id = session_response["session_id"]

                    logger.info(f" Session ID: {session_id}")
                    logger.info(f"✓ Step 2 completed in {time.time() - step_start:.2f}s")

                except Exception as e:
                    logger.error(f"✗ Step 2 failed: {type(e).__name__}: {e}")
                    raise

                # ---- Step 3: attach the text prompt to frame 0 ----
                logger.info("")
                logger.info("STEP 3/9: Adding text prompt to first frame...")
                step_start = time.time()

                try:
                    self.predictor.handle_request(
                        request=dict(
                            type="add_prompt",
                            session_id=session_id,
                            frame_index=0,
                            text=text_prompt,
                        )
                    )

                    logger.info(f" Prompt: '{text_prompt}'")
                    logger.info(" Frame: 0")
                    logger.info(f"✓ Step 3 completed in {time.time() - step_start:.2f}s")

                except Exception as e:
                    logger.error(f"✗ Step 3 failed: {type(e).__name__}: {e}")
                    raise

                # ---- Step 4: propagate through the video ----
                logger.info("")
                logger.info("STEP 4/9: Propagating segmentation through video...")
                step_start = time.time()

                try:
                    outputs_per_frame = {}
                    last_log_frame = -1
                    log_interval = 10

                    for stream_response in self.predictor.handle_stream_request(
                        request=dict(
                            type="propagate_in_video",
                            session_id=session_id,
                        )
                    ):
                        frame_idx = stream_response["frame_index"]
                        outputs_per_frame[frame_idx] = stream_response["outputs"]

                        # Progress line roughly every `log_interval` frames.
                        if frame_idx - last_log_frame >= log_interval:
                            logger.info(f" Processing frame {frame_idx}...")
                            last_log_frame = frame_idx

                    logger.info(f" Total frames processed: {len(outputs_per_frame)}")
                    logger.info(f"✓ Step 4 completed in {time.time() - step_start:.2f}s")

                except Exception as e:
                    logger.error(f"✗ Step 4 failed: {type(e).__name__}: {e}")
                    raise

                # ---- Step 5: write one PNG per (frame, object) ----
                logger.info("")
                logger.info("STEP 5/9: Saving masks to PNG files...")
                step_start = time.time()

                try:
                    masks_dir = tmpdir_path / "masks"
                    masks_dir.mkdir()

                    all_object_ids = set()
                    mask_count = 0

                    for frame_idx, frame_output in outputs_per_frame.items():
                        mask_count += self._save_frame_masks(frame_output, masks_dir, frame_idx)

                        obj_ids = frame_output.get("object_ids")
                        if obj_ids is None:
                            continue
                        if torch.is_tensor(obj_ids):
                            obj_ids = obj_ids.cpu().tolist()
                        elif isinstance(obj_ids, np.ndarray):
                            obj_ids = obj_ids.tolist()

                        if isinstance(obj_ids, (list, tuple, set)):
                            all_object_ids.update(obj_ids)
                        else:
                            # a lone scalar id (e.g. from a 0-d tensor)
                            all_object_ids.add(obj_ids)

                    logger.info(f" Masks directory: {masks_dir}")
                    logger.info(f" Total mask files: {mask_count}")
                    logger.info(f" Unique objects: {sorted(all_object_ids)}")
                    logger.info(f"✓ Step 5 completed in {time.time() - step_start:.2f}s")

                except Exception as e:
                    logger.error(f"✗ Step 5 failed: {type(e).__name__}: {e}")
                    raise

                # ---- Step 6: bundle the PNGs ----
                logger.info("")
                logger.info("STEP 6/9: Creating ZIP archive...")
                step_start = time.time()

                try:
                    zip_path = tmpdir_path / "masks.zip"
                    self._create_zip(masks_dir, zip_path)

                    zip_size_mb = zip_path.stat().st_size / 1e6

                    logger.info(f" ZIP path: {zip_path}")
                    logger.info(f" ZIP size: {zip_size_mb:.2f} MB")
                    if video_size_mb > 0:
                        # Guarded: a zero-byte input would divide by zero.
                        logger.info(f" Compression ratio: {(1 - zip_size_mb / video_size_mb) * 100:.1f}%")
                    logger.info(f"✓ Step 6 completed in {time.time() - step_start:.2f}s")

                except Exception as e:
                    logger.error(f"✗ Step 6 failed: {type(e).__name__}: {e}")
                    raise

                # ---- Step 7: video metadata (non-fatal) ----
                logger.info("")
                logger.info("STEP 7/9: Extracting video metadata...")
                step_start = time.time()

                try:
                    video_metadata = self._get_video_metadata(video_path)

                    for key, value in video_metadata.items():
                        logger.info(f" {key}: {value}")
                    logger.info(f"✓ Step 7 completed in {time.time() - step_start:.2f}s")

                except Exception as e:
                    logger.warning(f"Step 7 partial failure: {e}")
                    video_metadata = {}

                # ---- Step 8: assemble the response ----
                logger.info("")
                logger.info("STEP 8/9: Preparing response...")
                step_start = time.time()

                result = {
                    "frame_count": len(outputs_per_frame),
                    "objects_detected": sorted(all_object_ids),
                    "compressed_size_mb": round(zip_size_mb, 2),
                    "video_metadata": video_metadata
                }

                if return_format == "download_url" and output_repo:
                    logger.info(f" Uploading to HuggingFace dataset: {output_repo}")
                    try:
                        download_url = self._upload_to_hf(zip_path, output_repo)
                        result["download_url"] = download_url
                        logger.info(f" ✓ Upload successful: {download_url}")
                    except Exception as e:
                        logger.error(f" ✗ Upload failed: {e}")
                        raise

                elif return_format == "base64":
                    logger.info(" Encoding ZIP to base64...")
                    try:
                        with open(zip_path, "rb") as f:
                            zip_bytes = f.read()
                        result["masks_zip_base64"] = base64.b64encode(zip_bytes).decode("utf-8")
                        logger.info(f" ✓ Encoded {len(result['masks_zip_base64'])} characters")
                    except Exception as e:
                        logger.error(f" ✗ Encoding failed: {e}")
                        raise

                else:
                    logger.info(" Returning metadata only (no mask data)")

                logger.info(f"✓ Step 8 completed in {time.time() - step_start:.2f}s")

            finally:
                # ---- Step 9: always release the session once it exists ----
                if session_id is not None:
                    logger.info("")
                    logger.info("STEP 9/9: Closing SAM3 session...")
                    step_start = time.time()

                    try:
                        self.predictor.handle_request(
                            request=dict(
                                type="close_session",
                                session_id=session_id,
                            )
                        )
                        logger.info(f"✓ Step 9 completed in {time.time() - step_start:.2f}s")

                    except Exception as e:
                        logger.warning(f"Step 9 partial failure (non-critical): {e}")

            total_time = time.time() - request_start
            logger.info("")
            logger.info(banner)
            logger.info("REQUEST COMPLETED SUCCESSFULLY")
            logger.info(f"Total processing time: {total_time:.2f}s")
            logger.info(f"Frames processed: {len(outputs_per_frame)}")
            logger.info(f"Objects detected: {len(all_object_ids)}")
            logger.info(banner)
            logger.info("")

            return result

    except Exception as e:
        total_time = time.time() - request_start

        logger.error("")
        logger.error(banner)
        logger.error("REQUEST FAILED")
        logger.error(f"Error type: {type(e).__name__}")
        logger.error(f"Error message: {str(e)}")
        logger.error(f"Time elapsed: {total_time:.2f}s")
        logger.error(banner)
        logger.exception("Full traceback:")
        logger.error("")

        return {
            "error": str(e),
            "error_type": type(e).__name__
        }
| |
def _ensure_bpe_file(self) -> str:
    """
    Ensure BPE tokenizer file exists. Download from HuggingFace if missing.

    Looks in a fixed list of candidate locations first. If none exists, it
    tries ``hf_hub_download`` and then a plain HTTPS download into
    ``/repository/assets``.

    Returns:
        Path to the BPE file, as a string.

    Raises:
        ValueError: If the file is missing and both download attempts fail.
    """
    logger.info("Checking for BPE tokenizer file...")

    possible_paths = [
        Path("/repository/assets/bpe_simple_vocab_16e6.txt.gz"),
        Path("./assets/bpe_simple_vocab_16e6.txt.gz"),
        Path("../assets/bpe_simple_vocab_16e6.txt.gz"),
        Path("/app/assets/bpe_simple_vocab_16e6.txt.gz"),
    ]

    for candidate in possible_paths:
        if candidate.exists():
            logger.info(f" ✓ BPE file found: {candidate}")
            return str(candidate)

    logger.warning(" BPE file not found in any expected location")

    assets_dir = Path("/repository/assets")
    bpe_file = assets_dir / "bpe_simple_vocab_16e6.txt.gz"

    logger.info(" Downloading from HuggingFace...")
    assets_dir.mkdir(parents=True, exist_ok=True)

    try:
        from huggingface_hub import hf_hub_download

        logger.info(" Attempting download via hf_hub_download...")
        downloaded_path = hf_hub_download(
            repo_id="facebook/sam3",
            filename="assets/bpe_simple_vocab_16e6.txt.gz",
            local_dir="/repository",
            local_dir_use_symlinks=False
        )

        logger.info(f" ✓ BPE file downloaded: {downloaded_path}")
        return downloaded_path

    except Exception as e:
        logger.warning(f" Primary download failed: {e}")
        logger.info(" Trying fallback download method...")

        import urllib.request
        url = "https://huggingface.co/facebook/sam3/resolve/main/assets/bpe_simple_vocab_16e6.txt.gz"

        # Download next to the target and rename on success: an interrupted
        # transfer must not leave a truncated file at the final path, where
        # the existence check above would accept it on the next start.
        partial_file = bpe_file.with_name(bpe_file.name + ".part")

        try:
            logger.info(f" Downloading from: {url}")
            urllib.request.urlretrieve(url, str(partial_file))
            partial_file.replace(bpe_file)
            logger.info(f" ✓ BPE file downloaded: {bpe_file}")
            return str(bpe_file)

        except Exception as e2:
            logger.error(f" ✗ Fallback download failed: {e2}")
            try:
                if partial_file.exists():
                    partial_file.unlink()
            except OSError as cleanup_error:
                logger.warning(f" Could not remove partial download {partial_file}: {cleanup_error}")
            raise ValueError(
                f"Could not download BPE tokenizer file. Please add assets/bpe_simple_vocab_16e6.txt.gz "
                f"to your repository. Download from: {url}"
            ) from e2
| |
def _prepare_video(self, video_data: str, tmpdir: Path) -> Path:
    """Decode base64 video and save to file.

    Args:
        video_data: Base64 text of the video. A leading data-URL header
            (``data:<mime>;base64,``) is stripped before decoding.
        tmpdir: Directory that receives ``input_video.mp4``.

    Returns:
        Path of the written video file.

    Raises:
        ValueError: If the payload cannot be decoded or decodes to zero bytes.
    """
    try:
        logger.info(" Decoding base64 data...")
        payload = video_data
        if isinstance(payload, str) and payload.startswith("data:") and "," in payload:
            # The lenient decoder would otherwise treat the header's letters
            # as base64 data and silently corrupt the start of the video.
            payload = payload.split(",", 1)[1]
        video_bytes = base64.b64decode(payload)
        logger.info(f" Decoded {len(video_bytes)} bytes")

    except Exception as e:
        logger.error(f" Base64 decode failed: {e}")
        raise ValueError(f"Failed to decode base64 video: {e}") from e

    if not video_bytes:
        logger.error(" Base64 decode failed: decoded payload is empty")
        raise ValueError("Failed to decode base64 video: decoded payload is empty")

    video_path = tmpdir / "input_video.mp4"
    video_path.write_bytes(video_bytes)

    return video_path
| |
| def _save_frame_masks(self, frame_output: Dict, masks_dir: Path, frame_idx: int) -> int: |
| """ |
| Save masks for a frame as PNG files. |
| Each object gets its own mask file: frame_XXXX_obj_Y.png |
| Returns the number of masks saved. |
| """ |
| if "masks" not in frame_output or frame_output["masks"] is None: |
| return 0 |
| |
| masks = frame_output["masks"] |
| object_ids = frame_output.get("object_ids", []) |
| |
| |
| if torch.is_tensor(object_ids): |
| object_ids = object_ids.cpu().tolist() |
| elif isinstance(object_ids, np.ndarray): |
| object_ids = object_ids.tolist() |
| elif not isinstance(object_ids, list): |
| object_ids = list(object_ids) if object_ids is not None else [] |
| |
| |
| if torch.is_tensor(masks): |
| masks = masks.cpu().numpy() |
| |
| |
| if len(masks.shape) == 4: |
| masks = masks[0] |
| |
| |
| saved_count = 0 |
| for i, obj_id in enumerate(object_ids): |
| if i < len(masks): |
| mask = masks[i] |
| |
| |
| mask_binary = (mask > 0.5).astype(np.uint8) * 255 |
| |
| |
| mask_img = Image.fromarray(mask_binary) |
| mask_filename = f"frame_{frame_idx:05d}_obj_{obj_id}.png" |
| mask_img.save(masks_dir / mask_filename, compress_level=9) |
| saved_count += 1 |
| |
| return saved_count |
| |
def _create_zip(self, masks_dir: Path, zip_path: Path):
    """Bundle every ``*.png`` directly inside *masks_dir* into a deflated ZIP at *zip_path*.

    Entries are added in sorted path order and stored under their bare
    file names (no directory component).
    """
    png_paths = list(masks_dir.glob("*.png"))
    png_paths.sort()
    logger.info(f" Creating ZIP with {len(png_paths)} files...")

    archive = zipfile.ZipFile(
        zip_path,
        mode='w',
        compression=zipfile.ZIP_DEFLATED,
        compresslevel=9,
    )
    with archive:
        for png_path in png_paths:
            archive.write(png_path, arcname=png_path.name)
| |
def _get_video_metadata(self, video_path: Path) -> Dict[str, Any]:
    """Extract video metadata using OpenCV.

    Args:
        video_path: Video file to inspect.

    Returns:
        Dict with ``width``, ``height``, ``fps`` and ``frame_count``, or an
        empty dict when the file cannot be opened or any error occurs
        (errors are logged as warnings, never raised).
    """
    cap = None
    try:
        cap = cv2.VideoCapture(str(video_path))

        if not cap.isOpened():
            logger.warning(f" Could not open video file: {video_path}")
            return {}

        return {
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            "fps": float(cap.get(cv2.CAP_PROP_FPS)),
            "frame_count": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
        }

    except Exception as e:
        logger.warning(f" Could not extract video metadata: {e}")
        return {}

    finally:
        # Release the capture on every path, including the early returns.
        if cap is not None:
            cap.release()
| |
| def _upload_to_hf(self, zip_path: Path, repo_id: str) -> str: |
| """Upload ZIP file to HuggingFace dataset repository.""" |
| if not self.hf_api: |
| raise ValueError("HuggingFace Hub API not initialized. Set HF_TOKEN environment variable.") |
| |
| try: |
| |
| import time |
| timestamp = int(time.time()) |
| filename = f"masks_{timestamp}.zip" |
| |
| logger.info(f" Uploading {zip_path.stat().st_size / 1e6:.2f} MB...") |
| |
| |
| url = self.hf_api.upload_file( |
| path_or_fileobj=str(zip_path), |
| path_in_repo=filename, |
| repo_id=repo_id, |
| repo_type="dataset", |
| ) |
| |
| |
| download_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{filename}" |
| return download_url |
| |
| except Exception as e: |
| logger.error(f" Upload error: {e}") |
| raise ValueError(f"Failed to upload to HuggingFace: {e}") |