import os import io import json import time import tempfile from pathlib import Path from typing import Dict, List, Optional, Union, Any from datetime import datetime import threading import queue # ====================== Additional Imports ====================== import torch from torch import nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image, ExifTags from tqdm import tqdm import gradio as gr import pandas as pd # Hugging Face Hub from huggingface_hub import ( hf_hub_download, login, whoami, create_repo, HfApi, InferenceClient, ) # ====================== Configuration & Paths ====================== HF_USERNAME = os.environ.get("HF_USERNAME", "latterworks") HF_TOKEN = os.environ.get("HF_TOKEN", None) # If not provided, use default Spaces token DATASET_NAME = os.environ.get("DATASET_NAME", "geo-metadata") DATASET_REPO = f"{HF_USERNAME}/{DATASET_NAME}" # Relative local paths LOCAL_STORAGE_PATH = Path("./data") LOCAL_STORAGE_PATH.mkdir(exist_ok=True, parents=True) METADATA_FILE = LOCAL_STORAGE_PATH / "metadata.jsonl" IMAGES_DIR = Path("./images") # place your images here IMAGES_DIR.mkdir(exist_ok=True, parents=True) # We’ll store checkpoints here: CHECKPOINTS_DIR = Path("./checkpoints") CHECKPOINTS_DIR.mkdir(exist_ok=True, parents=True) CHECKPOINT_PATH = CHECKPOINTS_DIR / "last_checkpoint.pth" MAX_BATCH_SIZE = 25 SUPPORTED_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.heic', '.tiff', '.tif', '.bmp', '.webp'] # ====================== Queues and Threads ====================== process_queue = queue.Queue() upload_queue = queue.Queue() # ====================== EXIF Extraction Core ====================== def convert_to_degrees(value): """Convert GPS coords to decimal degrees.""" try: d, m, s = value return d + (m / 60.0) + (s / 3600.0) except (TypeError, ValueError): return value def extract_gps_info(gps_info): """Extract and format GPS metadata from EXIF.""" if not gps_info or not isinstance(gps_info, dict): return None gps_data = {} for key, val in gps_info.items(): tag_name = ExifTags.GPSTAGS.get(key, key) gps_data[tag_name] = val if 'GPSLatitude' in gps_data and 'GPSLongitude' in gps_data: lat = convert_to_degrees(gps_data['GPSLatitude']) lon = convert_to_degrees(gps_data['GPSLongitude']) if gps_data.get('GPSLatitudeRef') == 'S': lat = -lat if gps_data.get('GPSLongitudeRef') == 'W': lon = -lon gps_data['Latitude'] = lat gps_data['Longitude'] = lon return gps_data def make_serializable(value): """Convert objects to JSON-serializable.""" if hasattr(value, 'numerator') and hasattr(value, 'denominator'): try: return float(value.numerator) / float(value.denominator) except: return str(value) elif isinstance(value, tuple) and len(value) == 2: try: return float(value[0]) / float(value[1]) except: return str(value) elif isinstance(value, (list, tuple)): return [make_serializable(v) for v in value] elif isinstance(value, dict): return {k: make_serializable(v) for k, v in value.items()} elif isinstance(value, bytes): try: return value.decode('utf-8') except UnicodeDecodeError: return str(value) # final fallback try: json.dumps(value) return value except: return str(value) def extract_metadata(image_path_or_obj, original_filename=None): """ Extract EXIF & metadata from a file or PIL Image. """ try: if isinstance(image_path_or_obj, Image.Image): image = image_path_or_obj file_name = original_filename or "unknown.jpg" file_size = None file_extension = os.path.splitext(file_name)[1].lower() else: image_path = Path(image_path_or_obj) image = Image.open(image_path) file_name = str(image_path.name) file_size = image_path.stat().st_size file_extension = image_path.suffix.lower() metadata = { "file_name": file_name, "format": image.format, "size": list(image.size), "mode": image.mode, "extraction_timestamp": datetime.now().isoformat(), "file_extension": file_extension } if file_size: metadata["file_size"] = file_size try: exif_data = image._getexif() except Exception as e: metadata["exif_error"] = str(e) exif_data = None if exif_data: for tag_id, value in exif_data.items(): try: tag_name = ExifTags.TAGS.get(tag_id, f"tag_{tag_id}") if tag_name == "GPSInfo": gps_info = extract_gps_info(value) if gps_info: metadata["gps_info"] = make_serializable(gps_info) else: metadata[tag_name.lower()] = make_serializable(value) except Exception as e: metadata[f"error_tag_{tag_id}"] = str(e) else: metadata["exif"] = "No EXIF data available" # Validate serializability try: json.dumps(metadata) except: # fallback basic_metadata = { "file_name": metadata.get("file_name", "unknown"), "format": metadata.get("format", None), "size": metadata.get("size", None), "mode": metadata.get("mode", None), "file_extension": metadata.get("file_extension", None), } basic_metadata["serialization_error"] = "Some metadata were removed." return basic_metadata return metadata except Exception as e: return { "file_name": str(original_filename or "unknown"), "error": str(e), "extraction_timestamp": datetime.now().isoformat() } # ====================== Save/Load JSONL ====================== def save_metadata_to_jsonl(metadata_list, append=True): mode = 'a' if append and METADATA_FILE.exists() else 'w' success_count = 0 with open(METADATA_FILE, mode) as f: for entry in metadata_list: try: json_str = json.dumps(entry) f.write(json_str + '\n') success_count += 1 except Exception as e: print(f"Failed to serialize entry: {e}") simplified = { "file_name": entry.get("file_name", "unknown"), "error": "Serialization failed" } f.write(json.dumps(simplified) + '\n') return success_count, len(metadata_list) def read_metadata_jsonl(): if not METADATA_FILE.exists(): return [] metadata_list = [] with open(METADATA_FILE, 'r') as f: for line in f: try: metadata_list.append(json.loads(line)) except json.JSONDecodeError: continue return metadata_list # ====================== Pushing to HuggingFace Hub ====================== def push_to_hub(metadata_list=None, create_if_not_exists=True): api = HfApi(token=HF_TOKEN) try: if metadata_list is None: metadata_list = read_metadata_jsonl() if not metadata_list: return "No metadata to push", "warning" repo_exists = True try: api.repo_info(repo_id=DATASET_REPO, repo_type="dataset") except Exception: repo_exists = False if create_if_not_exists: create_repo(repo_id=DATASET_REPO, repo_type="dataset", token=HF_TOKEN, private=False) else: return f"Dataset repo {DATASET_REPO} doesn't exist.", "error" existing_metadata = [] if repo_exists: try: existing_file = hf_hub_download( repo_id=DATASET_REPO, filename="metadata.jsonl", repo_type="dataset", token=HF_TOKEN ) with open(existing_file, 'r') as f: for line in f: try: existing_metadata.append(json.loads(line)) except: pass except Exception as e: print(f"No existing metadata found or error reading: {e}") if existing_metadata: existing_filenames = {item.get("file_name") for item in existing_metadata} unique_new = [item for item in metadata_list if item.get("file_name") not in existing_filenames] combined_metadata = existing_metadata + unique_new else: combined_metadata = metadata_list temp_file = Path(tempfile.mktemp(suffix=".jsonl")) with open(temp_file, 'w') as f: for entry in combined_metadata: f.write(json.dumps(entry) + '\n') api.upload_file( path_or_fileobj=str(temp_file), path_in_repo="metadata.jsonl", repo_id=DATASET_REPO, repo_type="dataset", token=HF_TOKEN ) readme_path = LOCAL_STORAGE_PATH / "README.md" if not readme_path.exists(): with open(readme_path, 'w') as f: f.write( f"# EXIF Metadata Dataset\n\n" f"This dataset contains EXIF metadata.\n\n" f"Last updated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" f"Total entries: {len(combined_metadata)}" ) try: with open(readme_path, 'r') as f: readme_content = f.read() updated_readme = ( f"# EXIF Metadata Dataset\n\n" f"This dataset contains EXIF metadata.\n\n" f"Last updated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" f"Total entries: {len(combined_metadata)}" ) with open(readme_path, 'w') as f: f.write(updated_readme) api.upload_file( path_or_fileobj=str(readme_path), path_in_repo="README.md", repo_id=DATASET_REPO, repo_type="dataset", token=HF_TOKEN ) except Exception as e: print(f"Error updating README: {e}") return f"Successfully pushed {len(metadata_list)} entries to {DATASET_REPO}", "success" except Exception as e: return f"Error pushing to Hub: {e}", "error" # ====================== Background Processing Threads ====================== def process_worker(): while True: try: task = process_queue.get() if task is None: break file_path, original_filename = task metadata = extract_metadata(file_path, original_filename) success, total = save_metadata_to_jsonl([metadata]) if success: upload_queue.put(metadata) process_queue.task_done() except Exception as e: print(f"Error in process worker: {e}") process_queue.task_done() def upload_worker(): batch = [] last_upload_time = time.time() while True: try: try: metadata = upload_queue.get(timeout=60) except queue.Empty: if batch and (time.time() - last_upload_time) > 300: push_to_hub(batch) batch = [] last_upload_time = time.time() continue if metadata is None: break batch.append(metadata) upload_queue.task_done() if len(batch) >= MAX_BATCH_SIZE: push_to_hub(batch) batch = [] last_upload_time = time.time() except Exception as e: print(f"Error in upload worker: {e}") if metadata: upload_queue.task_done() process_thread = threading.Thread(target=process_worker, daemon=True) process_thread.start() upload_thread = threading.Thread(target=upload_worker, daemon=True) upload_thread.start() # ====================== Gradio App ====================== def process_uploaded_files(files): if not files: return "No files uploaded", "warning" processed = 0 metadata_list = [] for file in files: try: # If using Gradio 3.x if hasattr(file, 'name'): file_path = Path(file.name) file_name = file_path.name else: # If using Gradio 4.x => (path, orig_name) file_path = Path(file) file_name = file_path.name if file_path.suffix.lower() not in SUPPORTED_EXTENSIONS: continue metadata = extract_metadata(file_path, file_name) metadata_list.append(metadata) processed += 1 process_queue.put((file_path, file_name)) except Exception as e: print(f"Error processing {file_path}: {e}") if metadata_list: success, total = save_metadata_to_jsonl(metadata_list) return (f"Processed {processed} files. " f"{success}/{total} metadata entries saved."), "success" else: return f"No valid image files among the {len(files)} uploaded.", "warning" def view_metadata(): metadata_list = read_metadata_jsonl() if not metadata_list: return "No metadata available", pd.DataFrame() display_data = [] for entry in metadata_list: row = { "filename": entry.get("file_name", "unknown"), "width": None, "height": None, "format": entry.get("format"), "has_gps": "Yes" if entry.get("gps_info") else "No" } size = entry.get("size") if isinstance(size, list) and len(size) == 2: row["width"], row["height"] = size if entry.get("gps_info"): gps = entry["gps_info"] row["latitude"] = gps.get("Latitude") row["longitude"] = gps.get("Longitude") display_data.append(row) df = pd.DataFrame(display_data) return f"Found {len(metadata_list)} entries", df def manual_push_to_hub(): return push_to_hub() with gr.Blocks(title="EXIF Extraction Pipeline") as app: gr.Markdown(f""" # EXIF Metadata Extraction Pipeline **Local storage**: `./data` **Images directory**: `./images` **Checkpoints**: `./checkpoints` **Supported formats**: {", ".join(SUPPORTED_EXTENSIONS)} Upload images to extract EXIF metadata (including GPS) and push to HuggingFace Hub. """) with gr.Tabs(): with gr.TabItem("Upload Images"): file_input = gr.File(file_count="multiple", label="Upload Images") submit_btn = gr.Button("Process Images") output_status = gr.Textbox(label="Status") submit_btn.click(fn=process_uploaded_files, inputs=[file_input], outputs=[output_status]) with gr.TabItem("View Metadata"): refresh_btn = gr.Button("Refresh Metadata") view_status = gr.Textbox(label="Status") results_df = gr.DataFrame(label="Metadata Overview") refresh_btn.click(fn=view_metadata, inputs=[], outputs=[view_status, results_df]) app.load(fn=view_metadata, inputs=[], outputs=[view_status, results_df]) with gr.TabItem("Hub Management"): push_btn = gr.Button("Push to HuggingFace Hub") push_status = gr.Textbox(label="Status") push_btn.click(fn=manual_push_to_hub, inputs=[], outputs=[push_status]) # ====================== PyTorch: Using GPS Data ====================== def load_exif_gps_metadata(metadata_file=METADATA_FILE): gps_map = {} if not os.path.exists(metadata_file): return gps_map with open(metadata_file, "r") as f: for line in f: try: entry = json.loads(line) gps_info = entry.get("gps_info") if gps_info and "Latitude" in gps_info and "Longitude" in gps_info: lat = gps_info["Latitude"] lon = gps_info["Longitude"] gps_map[entry["file_name"]] = (lat, lon) except: pass return gps_map class GPSImageDataset(Dataset): def __init__(self, images_dir, gps_map, transform=None): self.images_dir = Path(images_dir) self.transform = transform self.gps_map = gps_map # Filter to only files that have GPS data self.file_names = [] for fn in os.listdir(self.images_dir): if fn in gps_map: # ensure we have matching metadata self.file_names.append(fn) def __len__(self): return len(self.file_names) def __getitem__(self, idx): file_name = self.file_names[idx] img_path = self.images_dir / file_name image = Image.open(img_path).convert("RGB") if self.transform: image = self.transform(image) lat, lon = self.gps_map[file_name] gps_tensor = torch.tensor([lat, lon], dtype=torch.float) return image, gps_tensor def train_one_epoch( train_dataloader, model, optimizer, epoch, batch_size, device, scheduler=None, criterion=nn.CrossEntropyLoss() ): print(f"\nStarting Epoch {epoch} ...") bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader)) # Create some placeholder targets (for demonstration only). targets_img_gps = torch.arange(0, batch_size).long().to(device) for i, (imgs, gps) in bar: imgs, gps = imgs.to(device), gps.to(device) gps_queue = model.get_gps_queue() # Hypothetical in your model optimizer.zero_grad() gps_all = torch.cat([gps, gps_queue], dim=0) model.dequeue_and_enqueue(gps) logits_img_gps = model(imgs, gps_all) loss = criterion(logits_img_gps, targets_img_gps) loss.backward() optimizer.step() bar.set_description(f"Epoch {epoch} loss: {loss.item():.5f}") if scheduler: scheduler.step() # ====================== Checkpoint Helpers ====================== def save_checkpoint(model, optimizer, epoch, path=CHECKPOINT_PATH): """ Saves model + optimizer state_dict along with current epoch to `path`. """ ckpt = { "epoch": epoch, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), } torch.save(ckpt, path) print(f"[Checkpoint] Saved at epoch={epoch} -> {path}") def load_checkpoint(model, optimizer, path=CHECKPOINT_PATH, device="cpu"): """ Loads checkpoint into model + optimizer, returns the last epoch. """ if not os.path.exists(path): print(f"No checkpoint found at {path}. Starting fresh.") return 0 ckpt = torch.load(path, map_location=device) model.load_state_dict(ckpt["model_state"]) optimizer.load_state_dict(ckpt["optimizer_state"]) print(f"[Checkpoint] Loaded from {path} (epoch={ckpt['epoch']})") return ckpt["epoch"] # ====================== Continuous Trainer ====================== def continuous_train( train_dataloader, model, optimizer, device, start_epoch=1, max_epochs=5, scheduler=None ): """ Loads checkpoint if available, then trains up to `max_epochs`. Saves new checkpoint at the end of each epoch. """ # Attempt to load from existing checkpoint loaded_epoch = load_checkpoint(model, optimizer, path=CHECKPOINT_PATH, device=device) # If loaded_epoch=3 and user says max_epochs=5, we continue from epoch 4, 5 current_epoch = loaded_epoch + 1 final_epoch = max(loaded_epoch + 1, max_epochs) # ensure we do something # Example: train from current_epoch -> max_epochs while current_epoch <= max_epochs: train_one_epoch( train_dataloader=train_dataloader, model=model, optimizer=optimizer, epoch=current_epoch, batch_size=train_dataloader.batch_size, device=device, scheduler=scheduler ) # Save checkpoint each epoch save_checkpoint(model, optimizer, current_epoch, CHECKPOINT_PATH) current_epoch += 1 class ExampleGPSModel(nn.Module): def __init__(self, gps_queue_len=10): super().__init__() self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1) self.flatten = nn.Flatten() self.fc_img = nn.Linear(16 * 224 * 224, 32) self.fc_gps = nn.Linear(2, 32) self.fc_out = nn.Linear(64, 10) self.gps_queue_len = gps_queue_len self._gps_queue = torch.zeros((gps_queue_len, 2), dtype=torch.float) def forward(self, imgs, gps_all): x = self.conv(imgs) x = F.relu(x) x = self.flatten(x) x = self.fc_img(x) g = self.fc_gps(gps_all) # Average all GPS embeddings if g.dim() == 2: g = g.mean(dim=0, keepdim=True) combined = torch.cat([x, g.repeat(x.size(0), 1)], dim=1) out = self.fc_out(combined) return out def get_gps_queue(self): return self._gps_queue def dequeue_and_enqueue(self, new_gps): B = new_gps.shape[0] self._gps_queue = torch.roll(self._gps_queue, shifts=-B, dims=0) self._gps_queue[-B:] = new_gps if __name__ == "__main__": # ========== Example usage: build dataset/dataloader ========== gps_map = load_exif_gps_metadata(METADATA_FILE) # from ./data/metadata.jsonl transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), ]) train_dataset = GPSImageDataset(IMAGES_DIR, gps_map, transform=transform) train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True) # ========== Create model & optimizer ========== device = "cuda" if torch.cuda.is_available() else "cpu" model = ExampleGPSModel().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) # ========== Continuous training example (5 epochs) ========== continuous_train( train_dataloader=train_dataloader, model=model, optimizer=optimizer, device=device, start_epoch=1, # not used if there's a checkpoint max_epochs=5 ) print("Done training. Launching Gradio app...") # ========== Launch Gradio ========== app.launch(server_name="0.0.0.0", server_port=7860)