import os
import io
import json
import time
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Union, Any
from datetime import datetime
import threading
import queue

# ====================== Additional Imports ======================
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ExifTags
from tqdm import tqdm
import gradio as gr
import pandas as pd

# Hugging Face Hub
from huggingface_hub import (
    hf_hub_download,
    login,
    whoami,
    create_repo,
    HfApi,
    InferenceClient,
)

# ====================== Configuration & Paths ======================
HF_USERNAME = os.environ.get("HF_USERNAME", "latterworks")
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # If not provided, fall back to the default Spaces token
DATASET_NAME = os.environ.get("DATASET_NAME", "geo-metadata")
DATASET_REPO = f"{HF_USERNAME}/{DATASET_NAME}"

# Relative local paths
LOCAL_STORAGE_PATH = Path("./data")
LOCAL_STORAGE_PATH.mkdir(exist_ok=True, parents=True)
METADATA_FILE = LOCAL_STORAGE_PATH / "metadata.jsonl"
IMAGES_DIR = Path("./images")  # place your images here
IMAGES_DIR.mkdir(exist_ok=True, parents=True)

# Checkpoints are stored here:
CHECKPOINTS_DIR = Path("./checkpoints")
CHECKPOINTS_DIR.mkdir(exist_ok=True, parents=True)
CHECKPOINT_PATH = CHECKPOINTS_DIR / "last_checkpoint.pth"

MAX_BATCH_SIZE = 25
SUPPORTED_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.heic', '.tiff', '.tif', '.bmp', '.webp']

# ====================== Queues and Threads ======================
process_queue = queue.Queue()
upload_queue = queue.Queue()

# ====================== EXIF Extraction Core ======================
def convert_to_degrees(value):
    """Convert GPS coords (degrees, minutes, seconds) to decimal degrees."""
    try:
        d, m, s = value
        return d + (m / 60.0) + (s / 3600.0)
    except (TypeError, ValueError):
        return value
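
# Worked example (hypothetical coordinates): EXIF stores latitude/longitude as
# (degrees, minutes, seconds), so (37, 46, 29.7) maps to
#   37 + 46/60 + 29.7/3600 ≈ 37.7749 decimal degrees.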

def extract_gps_info(gps_info):
    """Extract and decode GPS metadata from the raw EXIF GPSInfo dict."""
    if not gps_info or not isinstance(gps_info, dict):
        return None
    gps_data = {}
    for key, val in gps_info.items():
        tag_name = ExifTags.GPSTAGS.get(key, key)
        gps_data[tag_name] = val
    if 'GPSLatitude' in gps_data and 'GPSLongitude' in gps_data:
        lat = convert_to_degrees(gps_data['GPSLatitude'])
        lon = convert_to_degrees(gps_data['GPSLongitude'])
        if gps_data.get('GPSLatitudeRef') == 'S':
            lat = -lat
        if gps_data.get('GPSLongitudeRef') == 'W':
            lon = -lon
        gps_data['Latitude'] = lat
        gps_data['Longitude'] = lon
    return gps_data

def make_serializable(value):
    """Recursively convert EXIF values into JSON-serializable Python types."""
    if hasattr(value, 'numerator') and hasattr(value, 'denominator'):
        # PIL rational types (e.g. IFDRational)
        try:
            return float(value.numerator) / float(value.denominator)
        except (TypeError, ValueError, ZeroDivisionError):
            return str(value)
    elif isinstance(value, tuple) and len(value) == 2:
        # Legacy (numerator, denominator) rational tuples
        try:
            return float(value[0]) / float(value[1])
        except (TypeError, ValueError, ZeroDivisionError):
            return str(value)
    elif isinstance(value, (list, tuple)):
        return [make_serializable(v) for v in value]
    elif isinstance(value, dict):
        return {k: make_serializable(v) for k, v in value.items()}
    elif isinstance(value, bytes):
        try:
            return value.decode('utf-8')
        except UnicodeDecodeError:
            return str(value)
    # Final fallback: keep the value only if json can handle it
    try:
        json.dumps(value)
        return value
    except (TypeError, ValueError):
        return str(value)

def extract_metadata(image_path_or_obj, original_filename=None):
    """Extract EXIF and basic image metadata from a file path or a PIL Image."""
    try:
        if isinstance(image_path_or_obj, Image.Image):
            image = image_path_or_obj
            file_name = original_filename or "unknown.jpg"
            file_size = None
            file_extension = os.path.splitext(file_name)[1].lower()
        else:
            image_path = Path(image_path_or_obj)
            image = Image.open(image_path)
            file_name = str(image_path.name)
            file_size = image_path.stat().st_size
            file_extension = image_path.suffix.lower()

        metadata = {
            "file_name": file_name,
            "format": image.format,
            "size": list(image.size),
            "mode": image.mode,
            "extraction_timestamp": datetime.now().isoformat(),
            "file_extension": file_extension,
        }
        if file_size is not None:
            metadata["file_size"] = file_size

        try:
            exif_data = image._getexif()
        except Exception as e:
            metadata["exif_error"] = str(e)
            exif_data = None

        if exif_data:
            for tag_id, value in exif_data.items():
                try:
                    tag_name = ExifTags.TAGS.get(tag_id, f"tag_{tag_id}")
                    if tag_name == "GPSInfo":
                        gps_info = extract_gps_info(value)
                        if gps_info:
                            metadata["gps_info"] = make_serializable(gps_info)
                    else:
                        metadata[tag_name.lower()] = make_serializable(value)
                except Exception as e:
                    metadata[f"error_tag_{tag_id}"] = str(e)
        else:
            metadata["exif"] = "No EXIF data available"

        # Validate that the whole record is JSON-serializable
        try:
            json.dumps(metadata)
        except (TypeError, ValueError):
            # Fall back to a minimal record
            return {
                "file_name": metadata.get("file_name", "unknown"),
                "format": metadata.get("format", None),
                "size": metadata.get("size", None),
                "mode": metadata.get("mode", None),
                "file_extension": metadata.get("file_extension", None),
                "serialization_error": "Some metadata fields were removed.",
            }
        return metadata
    except Exception as e:
        return {
            "file_name": str(original_filename or "unknown"),
            "error": str(e),
            "extraction_timestamp": datetime.now().isoformat(),
        }
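
# Usage sketch (the path below is hypothetical): the returned dict is always
# JSON-serializable, and extraction errors are folded into the dict rather than
# raised, e.g.
#   extract_metadata("./images/IMG_0001.jpg")["file_name"]  -> "IMG_0001.jpg"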

# ====================== Save/Load JSONL ======================
def save_metadata_to_jsonl(metadata_list, append=True):
    """Append (or rewrite) metadata entries in the local JSONL file."""
    mode = 'a' if append and METADATA_FILE.exists() else 'w'
    success_count = 0
    with open(METADATA_FILE, mode) as f:
        for entry in metadata_list:
            try:
                json_str = json.dumps(entry)
                f.write(json_str + '\n')
                success_count += 1
            except Exception as e:
                print(f"Failed to serialize entry: {e}")
                simplified = {
                    "file_name": entry.get("file_name", "unknown"),
                    "error": "Serialization failed",
                }
                f.write(json.dumps(simplified) + '\n')
    return success_count, len(metadata_list)

def read_metadata_jsonl():
    """Read all entries from the local JSONL file, skipping malformed lines."""
    if not METADATA_FILE.exists():
        return []
    metadata_list = []
    with open(METADATA_FILE, 'r') as f:
        for line in f:
            try:
                metadata_list.append(json.loads(line))
            except json.JSONDecodeError:
                continue
    return metadata_list
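
# Each line of metadata.jsonl is one JSON object; the field names below match
# what extract_metadata emits, while the values are made up for illustration:
# {"file_name": "IMG_0001.jpg", "format": "JPEG", "size": [4032, 3024],
#  "mode": "RGB", "file_extension": ".jpg", "file_size": 2048576,
#  "gps_info": {"Latitude": 37.7749, "Longitude": -122.4194},
#  "extraction_timestamp": "2024-01-01T12:00:00"}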

# ====================== Pushing to HuggingFace Hub ======================
def push_to_hub(metadata_list=None, create_if_not_exists=True):
    """Merge local metadata with what is already on the Hub and upload it."""
    api = HfApi(token=HF_TOKEN)
    try:
        if metadata_list is None:
            metadata_list = read_metadata_jsonl()
        if not metadata_list:
            return "No metadata to push", "warning"

        repo_exists = True
        try:
            api.repo_info(repo_id=DATASET_REPO, repo_type="dataset")
        except Exception:
            repo_exists = False
            if create_if_not_exists:
                create_repo(repo_id=DATASET_REPO, repo_type="dataset", token=HF_TOKEN, private=False)
            else:
                return f"Dataset repo {DATASET_REPO} doesn't exist.", "error"

        # Pull any existing metadata so we can deduplicate by file name.
        existing_metadata = []
        if repo_exists:
            try:
                existing_file = hf_hub_download(
                    repo_id=DATASET_REPO,
                    filename="metadata.jsonl",
                    repo_type="dataset",
                    token=HF_TOKEN,
                )
                with open(existing_file, 'r') as f:
                    for line in f:
                        try:
                            existing_metadata.append(json.loads(line))
                        except json.JSONDecodeError:
                            pass
            except Exception as e:
                print(f"No existing metadata found or error reading: {e}")

        if existing_metadata:
            existing_filenames = {item.get("file_name") for item in existing_metadata}
            unique_new = [item for item in metadata_list
                          if item.get("file_name") not in existing_filenames]
            combined_metadata = existing_metadata + unique_new
        else:
            combined_metadata = metadata_list

        # Write the combined file to a temporary location and upload it.
        with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
            for entry in combined_metadata:
                f.write(json.dumps(entry) + '\n')
            temp_file = f.name
        api.upload_file(
            path_or_fileobj=temp_file,
            path_in_repo="metadata.jsonl",
            repo_id=DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
        )

        # Regenerate the README with the current entry count.
        try:
            readme_path = LOCAL_STORAGE_PATH / "README.md"
            readme_content = (
                f"# EXIF Metadata Dataset\n\n"
                f"This dataset contains EXIF metadata.\n\n"
                f"Last updated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
                f"Total entries: {len(combined_metadata)}"
            )
            with open(readme_path, 'w') as f:
                f.write(readme_content)
            api.upload_file(
                path_or_fileobj=str(readme_path),
                path_in_repo="README.md",
                repo_id=DATASET_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
            )
        except Exception as e:
            print(f"Error updating README: {e}")

        return f"Successfully pushed {len(metadata_list)} entries to {DATASET_REPO}", "success"
    except Exception as e:
        return f"Error pushing to Hub: {e}", "error"

# ====================== Background Processing Threads ======================
def process_worker():
    """Pull files off the process queue, extract metadata, and hand it to the uploader."""
    while True:
        try:
            task = process_queue.get()
            if task is None:
                break
            file_path, original_filename = task
            metadata = extract_metadata(file_path, original_filename)
            success, total = save_metadata_to_jsonl([metadata])
            if success:
                upload_queue.put(metadata)
            process_queue.task_done()
        except Exception as e:
            print(f"Error in process worker: {e}")
            process_queue.task_done()

def upload_worker():
    """Batch metadata entries and push to the Hub every MAX_BATCH_SIZE items or 5 minutes."""
    batch = []
    last_upload_time = time.time()
    while True:
        try:
            try:
                metadata = upload_queue.get(timeout=60)
            except queue.Empty:
                # Flush a partial batch that has been sitting for more than 5 minutes.
                if batch and (time.time() - last_upload_time) > 300:
                    push_to_hub(batch)
                    batch = []
                    last_upload_time = time.time()
                continue
            if metadata is None:
                break
            batch.append(metadata)
            upload_queue.task_done()
            if len(batch) >= MAX_BATCH_SIZE:
                push_to_hub(batch)
                batch = []
                last_upload_time = time.time()
        except Exception as e:
            # task_done() was already called for this item, so just log and continue.
            print(f"Error in upload worker: {e}")

process_thread = threading.Thread(target=process_worker, daemon=True)
process_thread.start()
upload_thread = threading.Thread(target=upload_worker, daemon=True)
upload_thread.start()
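
# Both workers exit when they receive a `None` sentinel. Nothing in this app
# calls it, but if you embed this module elsewhere a shutdown sketch could look like:
#
#   def stop_workers():
#       process_queue.put(None)
#       upload_queue.put(None)
#       process_thread.join(timeout=10)
#       upload_thread.join(timeout=10)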

# ====================== Gradio App ======================
def process_uploaded_files(files):
    """Handle a Gradio upload: extract metadata now and queue files for background upload."""
    if not files:
        return "No files uploaded"
    processed = 0
    metadata_list = []
    for file in files:
        try:
            # Gradio 3.x passes tempfile wrappers with a .name attribute;
            # newer versions pass plain file paths.
            if hasattr(file, 'name'):
                file_path = Path(file.name)
            else:
                file_path = Path(file)
            file_name = file_path.name
            if file_path.suffix.lower() not in SUPPORTED_EXTENSIONS:
                continue
            metadata = extract_metadata(file_path, file_name)
            metadata_list.append(metadata)
            processed += 1
            process_queue.put((file_path, file_name))
        except Exception as e:
            print(f"Error processing {file}: {e}")
    if metadata_list:
        success, total = save_metadata_to_jsonl(metadata_list)
        return (f"Processed {processed} files. "
                f"{success}/{total} metadata entries saved.")
    return f"No valid image files among the {len(files)} uploaded."

def view_metadata():
    """Build a status message and a flat DataFrame summary of the stored metadata."""
    metadata_list = read_metadata_jsonl()
    if not metadata_list:
        return "No metadata available", pd.DataFrame()
    display_data = []
    for entry in metadata_list:
        row = {
            "filename": entry.get("file_name", "unknown"),
            "width": None,
            "height": None,
            "format": entry.get("format"),
            "has_gps": "Yes" if entry.get("gps_info") else "No",
        }
        size = entry.get("size")
        if isinstance(size, list) and len(size) == 2:
            row["width"], row["height"] = size
        if entry.get("gps_info"):
            gps = entry["gps_info"]
            row["latitude"] = gps.get("Latitude")
            row["longitude"] = gps.get("Longitude")
        display_data.append(row)
    df = pd.DataFrame(display_data)
    return f"Found {len(metadata_list)} entries", df

def manual_push_to_hub():
    """Gradio click handler: push_to_hub returns (message, level); the Textbox needs only the message."""
    message, _level = push_to_hub()
    return message

with gr.Blocks(title="EXIF Extraction Pipeline") as app:
    gr.Markdown(f"""
    # EXIF Metadata Extraction Pipeline

    **Local storage**: `./data`
    **Images directory**: `./images`
    **Checkpoints**: `./checkpoints`
    **Supported formats**: {", ".join(SUPPORTED_EXTENSIONS)}

    Upload images to extract EXIF metadata (including GPS) and push it to the Hugging Face Hub.
    """)
    with gr.Tabs():
        with gr.TabItem("Upload Images"):
            file_input = gr.File(file_count="multiple", label="Upload Images")
            submit_btn = gr.Button("Process Images")
            output_status = gr.Textbox(label="Status")
            submit_btn.click(fn=process_uploaded_files, inputs=[file_input], outputs=[output_status])
        with gr.TabItem("View Metadata"):
            refresh_btn = gr.Button("Refresh Metadata")
            view_status = gr.Textbox(label="Status")
            results_df = gr.DataFrame(label="Metadata Overview")
            refresh_btn.click(fn=view_metadata, inputs=[], outputs=[view_status, results_df])
            app.load(fn=view_metadata, inputs=[], outputs=[view_status, results_df])
        with gr.TabItem("Hub Management"):
            push_btn = gr.Button("Push to HuggingFace Hub")
            push_status = gr.Textbox(label="Status")
            push_btn.click(fn=manual_push_to_hub, inputs=[], outputs=[push_status])

# ====================== PyTorch: Using GPS Data ======================
def load_exif_gps_metadata(metadata_file=METADATA_FILE):
    """Map file_name -> (latitude, longitude) for every entry that carries GPS info."""
    gps_map = {}
    if not os.path.exists(metadata_file):
        return gps_map
    with open(metadata_file, "r") as f:
        for line in f:
            try:
                entry = json.loads(line)
                gps_info = entry.get("gps_info")
                if gps_info and "Latitude" in gps_info and "Longitude" in gps_info:
                    lat = gps_info["Latitude"]
                    lon = gps_info["Longitude"]
                    gps_map[entry["file_name"]] = (lat, lon)
            except (json.JSONDecodeError, KeyError):
                continue
    return gps_map

class GPSImageDataset(Dataset):
    """Pairs images in `images_dir` with the (lat, lon) recorded for them in `gps_map`."""

    def __init__(self, images_dir, gps_map, transform=None):
        self.images_dir = Path(images_dir)
        self.transform = transform
        self.gps_map = gps_map
        # Keep only files that have matching GPS metadata
        self.file_names = [fn for fn in os.listdir(self.images_dir) if fn in gps_map]

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        img_path = self.images_dir / file_name
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        lat, lon = self.gps_map[file_name]
        gps_tensor = torch.tensor([lat, lon], dtype=torch.float)
        return image, gps_tensor
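
# Assumed layout: image files sit directly in ./images and their file names
# match the "file_name" field recorded in metadata.jsonl; images without GPS
# metadata are silently skipped by the dataset.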

def train_one_epoch(
    train_dataloader, model, optimizer, epoch, batch_size, device,
    scheduler=None, criterion=nn.CrossEntropyLoss()
):
    """Run one pass over the dataloader with a contrastive-style image/GPS objective."""
    print(f"\nStarting Epoch {epoch} ...")
    bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    for i, (imgs, gps) in bar:
        imgs, gps = imgs.to(device), gps.to(device)
        gps_queue = model.get_gps_queue()  # see ExampleGPSModel below
        optimizer.zero_grad()
        gps_all = torch.cat([gps, gps_queue], dim=0)
        model.dequeue_and_enqueue(gps)
        logits_img_gps = model(imgs, gps_all)
        # Placeholder targets (demonstration only), sized per batch so the
        # final, possibly smaller, batch does not break the loss.
        targets_img_gps = torch.arange(imgs.size(0), device=device).long()
        loss = criterion(logits_img_gps, targets_img_gps)
        loss.backward()
        optimizer.step()
        bar.set_description(f"Epoch {epoch} loss: {loss.item():.5f}")
    if scheduler:
        scheduler.step()

# ====================== Checkpoint Helpers ======================
def save_checkpoint(model, optimizer, epoch, path=CHECKPOINT_PATH):
    """Save model and optimizer state_dicts along with the current epoch to `path`."""
    ckpt = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }
    torch.save(ckpt, path)
    print(f"[Checkpoint] Saved at epoch={epoch} -> {path}")

def load_checkpoint(model, optimizer, path=CHECKPOINT_PATH, device="cpu"):
    """Load a checkpoint into model and optimizer; return the last completed epoch (0 if none)."""
    if not os.path.exists(path):
        print(f"No checkpoint found at {path}. Starting fresh.")
        return 0
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    print(f"[Checkpoint] Loaded from {path} (epoch={ckpt['epoch']})")
    return ckpt["epoch"]
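
# Typical resume flow (sketch; assumes `model`, `optimizer`, `device`, and a
# dataloader already exist); `continuous_train` below implements this loop:
#
#   last_epoch = load_checkpoint(model, optimizer, device=device)
#   for epoch in range(last_epoch + 1, max_epochs + 1):
#       train_one_epoch(train_dataloader, model, optimizer, epoch,
#                       train_dataloader.batch_size, device)
#       save_checkpoint(model, optimizer, epoch)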

# ====================== Continuous Trainer ======================
def continuous_train(
    train_dataloader,
    model,
    optimizer,
    device,
    start_epoch=1,
    max_epochs=5,
    scheduler=None
):
    """
    Load the last checkpoint if one exists, then train up to `max_epochs`,
    saving a new checkpoint at the end of each epoch.
    """
    loaded_epoch = load_checkpoint(model, optimizer, path=CHECKPOINT_PATH, device=device)
    # If a checkpoint for epoch 3 exists and max_epochs=5, resume at epoch 4;
    # otherwise start at `start_epoch`.
    current_epoch = max(start_epoch, loaded_epoch + 1)
    while current_epoch <= max_epochs:
        train_one_epoch(
            train_dataloader=train_dataloader,
            model=model,
            optimizer=optimizer,
            epoch=current_epoch,
            batch_size=train_dataloader.batch_size,
            device=device,
            scheduler=scheduler,
        )
        # Save a checkpoint each epoch
        save_checkpoint(model, optimizer, current_epoch, CHECKPOINT_PATH)
        current_epoch += 1

class ExampleGPSModel(nn.Module):
    """Toy model that fuses an image embedding with a pooled GPS embedding."""

    def __init__(self, gps_queue_len=10):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        self.fc_img = nn.Linear(16 * 224 * 224, 32)
        self.fc_gps = nn.Linear(2, 32)
        self.fc_out = nn.Linear(64, 10)
        self.gps_queue_len = gps_queue_len
        # Registered as a buffer so the queue moves with model.to(device).
        self.register_buffer("_gps_queue", torch.zeros((gps_queue_len, 2), dtype=torch.float))

    def forward(self, imgs, gps_all):
        x = self.conv(imgs)
        x = F.relu(x)
        x = self.flatten(x)
        x = self.fc_img(x)
        g = self.fc_gps(gps_all)
        # Pool all GPS embeddings into one vector, then pair it with every image.
        if g.dim() == 2:
            g = g.mean(dim=0, keepdim=True)
        combined = torch.cat([x, g.repeat(x.size(0), 1)], dim=1)
        out = self.fc_out(combined)
        return out

    def get_gps_queue(self):
        return self._gps_queue

    def dequeue_and_enqueue(self, new_gps):
        B = new_gps.shape[0]
        self._gps_queue = torch.roll(self._gps_queue, shifts=-B, dims=0)
        self._gps_queue[-B:] = new_gps
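
# The queue is a fixed-length FIFO of recent GPS coordinates: torch.roll shifts
# every row up by the batch size B, then the newest B rows overwrite the tail.
# For example, with gps_queue_len=4 and B=2, rows [r0, r1, r2, r3] become
# [r2, r3, new0, new1]. This is a toy stand-in for a momentum-style GPS queue,
# not a faithful implementation of any published model.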

if __name__ == "__main__":
    # ========== Example usage: build dataset/dataloader ==========
    gps_map = load_exif_gps_metadata(METADATA_FILE)  # from ./data/metadata.jsonl
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    train_dataset = GPSImageDataset(IMAGES_DIR, gps_map, transform=transform)

    if len(train_dataset) > 0:
        train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

        # ========== Create model & optimizer ==========
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = ExampleGPSModel().to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        # ========== Continuous training example (5 epochs) ==========
        continuous_train(
            train_dataloader=train_dataloader,
            model=model,
            optimizer=optimizer,
            device=device,
            start_epoch=1,
            max_epochs=5,
        )
        print("Done training. Launching Gradio app...")
    else:
        # Nothing to train on yet (no images with GPS metadata); just serve the app.
        print("No GPS-tagged images found. Skipping training and launching Gradio app...")

    # ========== Launch Gradio ==========
    app.launch(server_name="0.0.0.0", server_port=7860)