import csv
import os
from io import BytesIO, StringIO
from threading import Lock
from typing import cast

import numpy as np
import torch
from torch import Tensor
from torch.nn import Parameter
from torch.nn.functional import sigmoid
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import requests

from model import discover_extensions, load_model, process_image, patchify_image
from image import unpatchify

EXT_DIR = "extensions/jtp-3-hydra"
PATCH_SIZE = 16
MAX_SEQ_LEN = 1024

device = "cuda" if torch.cuda.is_available() else "cpu"

# Prefer the unified fp32-precision knob on newer PyTorch; fall back to the
# per-backend TF32 flags on older releases.
if hasattr(torch.backends, "fp32_precision"):
    torch.backends.fp32_precision = "tf32"
else:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

model_lock = Lock()
model, tag_list, ext_info = load_model(
    "models/jtp-3-hydra.safetensors",
    extensions=discover_extensions(EXT_DIR) if os.path.isdir(EXT_DIR) else (),
    device=device
)
model.requires_grad_(False)


def rewrite_tag(tag: str) -> str:
    return tag.replace("_", " ").replace("vulva", "pussy")


tags = {
    rewrite_tag(tag): idx
    for idx, tag in enumerate(tag_list)
}
tag_list = list(tags.keys())

FONT = ImageFont.load_default(24)


@torch.no_grad()
def run_classifier(image: Image.Image, cam_depth: int):
    patches, patch_coords, patch_valid = patchify_image(image, PATCH_SIZE, MAX_SEQ_LEN)
    patches = patches.unsqueeze(0).to(device=device, non_blocking=True)
    patch_coords = patch_coords.unsqueeze(0).to(device=device, non_blocking=True)
    patch_valid = patch_valid.unsqueeze(0).to(device=device, non_blocking=True)
    patches = patches.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
    patch_coords = patch_coords.to(dtype=torch.int32)
    with model_lock:
        features = cast(dict[str, Tensor], model.forward_intermediates(
            patches,
            patch_coord=patch_coords,
            patch_valid=patch_valid,
            indices=cam_depth,
            output_dict=True,
            output_fmt='NLC'
        ))
        logits = model.forward_head(features["image_features"], patch_valid=patch_valid)
    # Drop the bulky final feature map; keep coords and mask for later CAM runs.
    del features["image_features"]
    features["patch_coords"] = patch_coords
    features["patch_valid"] = patch_valid
    del patches, patch_coords, patch_valid
    probits = sigmoid(logits[0].to(dtype=torch.float32))
    probits.mul_(2.0).sub_(1.0)  # rescale from [0, 1] to [-1, 1]
    values, indices = probits.cpu().topk(250)
    predictions = {
        tag_list[idx.item()]: val.item()
        for idx, val in sorted(
            zip(indices, values),
            key=lambda item: item[1].item(),
            reverse=True
        )
    }
    return features, predictions
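

# Usage sketch for run_classifier outside the Gradio flow. The image path and
# cam_depth value are illustrative, not part of the app:
#
#     img = process_image(Image.open("example.png"), PATCH_SIZE, MAX_SEQ_LEN)
#     _, preds = run_classifier(img, cam_depth=1)
#     for tag, score in list(preds.items())[:10]:
#         print(f"{tag}: {score:.3f}")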


@torch.no_grad()
def run_cam(
    display_image: Image.Image,
    image: Image.Image,
    features: dict[str, Tensor],
    tag_idx: int,
    cam_depth: int
):
    intermediates = features["image_intermediates"]
    if len(intermediates) < cam_depth:
        # Not enough cached intermediates; rerun the classifier at this depth.
        features, _ = run_classifier(image, cam_depth)
        intermediates = features["image_intermediates"]
    elif len(intermediates) > cam_depth:
        intermediates = intermediates[-cam_depth:]
    patch_coords = features["patch_coords"]
    patch_valid = features["patch_valid"]
    with model_lock:
        # Temporarily restrict the attention-pool head to the selected tag so
        # the backward pass attributes only that tag's logit.
        saved_q = model.attn_pool.q
        saved_p = model.attn_pool.out_proj.weight
        try:
            model.attn_pool.q = Parameter(saved_q[:, [tag_idx], :], requires_grad=False)
            model.attn_pool.out_proj.weight = Parameter(saved_p[[tag_idx], :, :], requires_grad=False)
            with torch.enable_grad():
                for intermediate in intermediates:
                    intermediate.requires_grad_(True).retain_grad()
                    model.forward_head(intermediate, patch_valid=patch_valid)[0, 0].backward()
        finally:
            model.attn_pool.q = saved_q
            model.attn_pool.out_proj.weight = saved_p
    # Accumulate per-patch attributions across the selected depths.
    cam_1d: Tensor | None = None
    for intermediate in intermediates:
        patch_grad = (intermediate.grad.float() * intermediate.sign()).sum(dim=(0, 2))
        intermediate.grad = None
        if cam_1d is None:
            cam_1d = patch_grad
        else:
            cam_1d.add_(patch_grad)
    assert cam_1d is not None
    cam_2d = unpatchify(cam_1d, patch_coords, patch_valid).cpu().numpy()
    return cam_composite(display_image, cam_2d), features
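

# `unpatchify` (imported from image.py) maps the per-patch CAM vector back onto
# a 2-D grid. A rough sketch of the assumed behavior, guessing a (1, L, 2)
# coordinate layout of (y, x) patch indices and a (1, L) validity mask:
#
#     def unpatchify_sketch(values: Tensor, coords: Tensor, valid: Tensor) -> Tensor:
#         mask = valid[0]
#         ys = coords[0, mask, 0].long()
#         xs = coords[0, mask, 1].long()
#         grid = torch.zeros(int(ys.max()) + 1, int(xs.max()) + 1)
#         grid[ys, xs] = values[mask]
#         return grid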


def cam_composite(image: Image.Image, cam: np.ndarray):
    """
    Overlays the CAM on the image and returns a PIL image.

    Args:
        image: PIL Image (RGB)
        cam: 2D numpy array (activation map)

    Returns:
        PIL.Image.Image with overlay
    """
    cam_abs = np.abs(cam)
    cam_scale = cam_abs.max()
    # Negative activations render red, positive green; alpha scales with
    # magnitude up to 50% opacity.
    cam_rgba = np.dstack((
        (cam < 0).astype(np.float32),
        (cam > 0).astype(np.float32),
        np.zeros_like(cam, dtype=np.float32),
        cam_abs * (0.5 / cam_scale),
    ))  # shape: (H, W, 4)
    cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8))
    cam_pil = cam_pil.resize(image.size, resample=Image.Resampling.NEAREST)
    # Desaturate the base image slightly so the overlay stands out.
    image = Image.blend(
        image.convert('RGBA'),
        image.convert('L').convert('RGBA'),
        0.33
    )
    image = Image.alpha_composite(image, cam_pil)
    draw = ImageDraw.Draw(image)
    draw.text(
        (image.width - 7, image.height - 7),
        f"{cam_scale.item():.4g}",
        anchor="rd",
        font=FONT,
        fill=(32, 32, 255, 255)
    )
    return image


def filter_tags(predictions: dict[str, float], threshold: float, calibration: dict[str, float] | None):
    if calibration is None:
        predictions = {
            key: value
            for key, value in predictions.items()
            if value >= threshold
        }
    else:
        # With a calibration table, tags absent from it can never pass.
        predictions = {
            key: value
            for key, value in predictions.items()
            if value >= calibration.get(key, float("inf"))
        }
    tag_str = ", ".join(predictions.keys())
    return tag_str, predictions


def resize_image(image: Image.Image) -> Image.Image:
    longest_side = max(image.height, image.width)
    if longest_side < 1080:
        return image
    scale = 1080 / longest_side
    return image.resize(
        (
            int(round(image.width * scale)),
            int(round(image.height * scale)),
        ),
        resample=Image.Resampling.LANCZOS,
        reducing_gap=3.0
    )


def image_upload(image: Image.Image):
    display_image = resize_image(image)
    processed_image = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)
    if display_image is not image and processed_image is not image:
        image.close()
    return (
        "",
        {},
        "None",
        "",
        gr.skip() if display_image is image else display_image,
        display_image,
        processed_image,
    )


def url_submit(url: str):
    resp = requests.get(url, timeout=10)
    resp.raise_for_status()
    image = Image.open(BytesIO(resp.content))
    display_image = resize_image(image)
    processed_image = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)
    if display_image is not image and processed_image is not image:
        image.close()
    return (
        "",
        {},
        "None",
        display_image,
        display_image,
        processed_image,
    )


def image_changed(image: Image.Image, threshold: float, calibration: dict[str, float] | None, cam_depth: int):
    features, predictions = run_classifier(image, cam_depth)
    return *filter_tags(predictions, threshold, calibration), features, predictions


def image_clear():
    return (
        "",
        {},
        "None",
        "",
        None,
        None,
        None,
        None,
        {},
    )


def threshold_input(predictions: dict[str, float], threshold: float):
    return (
        *filter_tags(predictions, threshold, None),
        None,
        gr.Slider(label="Tag Threshold", elem_classes=[]),
        gr.Textbox(label="Upload Calibration")
    )


def parse_calibration(data) -> dict[str, float]:
    return {
        rewrite_tag(row["tag"]): float(row["threshold"])
        for row in csv.DictReader(data)
    }


def calibration_load(predictions: dict[str, float]):
    try:
        # Bind the file to a distinct name so it doesn't shadow the csv module.
        with open("calibration.csv", "r", encoding="utf-8", newline="") as csv_file:
            calibration = parse_calibration(csv_file)
    except Exception:
        return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.Textbox(label="Invalid Calibration File")
    return (
        *filter_tags(predictions, 0.0, calibration),
        calibration,
        gr.Slider(label="Using Default Calibration", elem_classes=["inactive-slider"]),
        gr.Textbox(label="Change Calibration")
    )
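

# A minimal example of the calibration file format parse_calibration expects:
# a header row with "tag" and "threshold" columns. The tag names and values
# below are illustrative only:
#
#     tag,threshold
#     hair,0.52
#     smile,0.47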
"
""
""
"RedRocket – JTP-3 Hydra"
""
"