Image Classification
English
furry
e621
Not-For-All-Audiences
JTP-3 / app.py
RedHotTensors's picture
Extension Training
dd7d624
raw
history blame
16.5 kB
import csv
import os
from io import BytesIO, StringIO
from threading import Lock
from typing import cast
import numpy as np
import torch
from torch import Tensor
from torch.nn import Parameter
from torch.nn.functional import sigmoid
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import requests
from model import discover_extensions, load_model, process_image, patchify_image
from image import unpatchify
# Directory scanned for optional model extensions; skipped if it does not exist.
EXT_DIR = "extensions/jtp-3-hydra"
# Patch edge length (pixels) and maximum patch count, passed to
# process_image / patchify_image.
PATCH_SIZE = 16
MAX_SEQ_LEN = 1024
device = "cuda" if torch.cuda.is_available() else "cpu"
# Allow TF32 for float32 math. Newer PyTorch exposes the global
# torch.backends.fp32_precision switch; older versions only have the
# per-backend allow_tf32 flags.
if hasattr(torch.backends, "fp32_precision"):
    torch.backends.fp32_precision = "tf32"
else:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# Serializes all use of the shared model: run_cam temporarily swaps
# attn_pool parameters, so no other request may run the model meanwhile.
model_lock = Lock()
model, tag_list, ext_info = load_model(
    "models/jtp-3-hydra.safetensors",
    extensions=discover_extensions(EXT_DIR) if os.path.isdir(EXT_DIR) else (),
    device=device
)
# Inference only: model weights never need gradients (run_cam requests
# gradients on intermediate activations instead).
model.requires_grad_(False)
def rewrite_tag(tag: str) -> str:
    """Convert a raw model tag into the form shown in the UI.

    Underscores become spaces, and the substring "vulva" is displayed as
    "pussy".

    >>> rewrite_tag("blue_eyes")
    'blue eyes'
    """
    display = tag.replace("_", " ")
    display = display.replace("vulva", "pussy")
    return display
# Display names for the model outputs, index-aligned with the logits:
# tag_list[i] is the rewritten name of output i. It is built per index rather
# than from the keys of a dict, because if rewrite_tag() ever maps two raw
# tags to the same name a dict would collapse them and every later position
# would be shifted relative to the model's output indices (run_classifier
# looks names up by output index).
tag_list = [rewrite_tag(tag) for tag in tag_list]
# Reverse lookup: display name -> model output index (used for CAM). On a
# name collision the later index wins.
tags = {
    tag: idx
    for idx, tag in enumerate(tag_list)
}
# Font for the CAM scale label drawn by cam_composite.
FONT = ImageFont.load_default(24)
@torch.no_grad()
def run_classifier(image: Image.Image, cam_depth: int):
    """Classify *image* and keep the intermediate activations needed for CAM.

    Args:
        image: classifier input (the UI passes the ``process_image`` output
            stored in ``image_state``).
        cam_depth: forwarded as ``indices`` to ``model.forward_intermediates``;
            presumably the number of trailing blocks to capture (run_cam
            slices ``[-cam_depth:]`` consistently with that) — TODO confirm.

    Returns:
        ``(features, predictions)``. ``features`` is the intermediates dict
        from the model with ``image_features`` removed and ``patch_coords`` /
        ``patch_valid`` (batch dim 1, on ``device``) added. ``predictions``
        maps up to 250 tag names to scores in [-1, 1], highest first.
    """
    patches, patch_coords, patch_valid = patchify_image(image, PATCH_SIZE, MAX_SEQ_LEN)
    patches = patches.unsqueeze(0).to(device=device, non_blocking=True)
    patch_coords = patch_coords.unsqueeze(0).to(device=device, non_blocking=True)
    patch_valid = patch_valid.unsqueeze(0).to(device=device, non_blocking=True)
    # Map pixel values to [-1, 1]; assumes patches hold 0..255 values — TODO confirm.
    patches = patches.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
    patch_coords = patch_coords.to(dtype=torch.int32)
    with model_lock:
        features = cast(dict[str, Tensor], model.forward_intermediates(
            patches,
            patch_coord=patch_coords,
            patch_valid=patch_valid,
            indices=cam_depth,
            output_dict=True,
            output_fmt='NLC'
        ))
        logits = model.forward_head(features["image_features"], patch_valid=patch_valid)
    # Final features are not needed for CAM; keep only what run_cam reads.
    del features["image_features"]
    features["patch_coords"] = patch_coords
    features["patch_valid"] = patch_valid
    del patches, patch_coords, patch_valid
    probits = sigmoid(logits[0].to(dtype=torch.float32))
    probits.mul_(2.0).sub_(1.0)  # scale to -1 to 1
    # Clamp k so a model with fewer than 250 outputs does not make topk fail.
    # topk returns values already sorted in descending order, so no re-sort.
    k = min(250, probits.numel())
    values, indices = probits.cpu().topk(k)
    predictions = {
        tag_list[idx]: val
        for idx, val in zip(indices.tolist(), values.tolist())
    }
    return features, predictions
@torch.no_grad()
def run_cam(
    display_image: Image.Image,
    image: Image.Image, features: dict[str, Tensor] | None,
    tag_idx: int, cam_depth: int
):
    """Compute a gradient-based activation map for one tag and overlay it.

    For each of the last *cam_depth* cached intermediate activations, the
    model head is run for tag *tag_idx* alone and backpropagated. Each
    intermediate contributes ``grad * sign(activation)`` summed over dims 0
    and 2 (batch and channel for the 'NLC' layout requested by
    run_classifier). The contributions are added into one per-patch map.

    Args:
        display_image: image the overlay is drawn on.
        image: classifier input. It is used only to recompute *features* when
            they are missing or hold fewer than *cam_depth* intermediates.
        features: dict produced by run_classifier, or None.
        tag_idx: model output index of the tag to explain.
        cam_depth: number of trailing intermediates to combine.

    Returns:
        ``(overlay, features)``: the composited PIL image and the feature dict
        actually used. The dict is recomputed if the cached one was missing or
        too shallow, so the caller can store it back into state.

    Raises:
        RuntimeError: if there are no intermediates to build a map from.
    """
    if features is None or len(features["image_intermediates"]) < cam_depth:
        features, _ = run_classifier(image, cam_depth)
    # Keep only the last cam_depth intermediates (all of them when there are
    # exactly that many).
    intermediates = features["image_intermediates"][-cam_depth:]
    patch_coords = features["patch_coords"]
    patch_valid = features["patch_valid"]

    with model_lock:
        saved_q = model.attn_pool.q
        saved_p = model.attn_pool.out_proj.weight
        try:
            # Temporarily narrow the pooling query and output projection to
            # the requested tag, so forward_head yields a single logit [0, 0].
            model.attn_pool.q = Parameter(saved_q[:, [tag_idx], :], requires_grad=False)
            model.attn_pool.out_proj.weight = Parameter(saved_p[[tag_idx], :, :], requires_grad=False)
            with torch.enable_grad():
                for intermediate in intermediates:
                    # Drop any stale gradient (e.g. from an earlier call that
                    # failed mid-way) so it is not accumulated into this map.
                    intermediate.grad = None
                    intermediate.requires_grad_(True).retain_grad()
                    model.forward_head(intermediate, patch_valid=patch_valid)[0, 0].backward()
        finally:
            model.attn_pool.q = saved_q
            model.attn_pool.out_proj.weight = saved_p

        # Read and clear the gradients before releasing the lock. The
        # intermediates come from per-session state, so an overlapping request
        # on the same tensors could otherwise add its gradients first.
        cam_1d: Tensor | None = None
        for intermediate in intermediates:
            patch_grad = (intermediate.grad.float() * intermediate.sign()).sum(dim=(0, 2))
            intermediate.grad = None
            cam_1d = patch_grad if cam_1d is None else cam_1d.add_(patch_grad)

    if cam_1d is None:
        raise RuntimeError("No intermediate activations available for CAM.")
    cam_2d = unpatchify(cam_1d, patch_coords, patch_valid).cpu().numpy()
    return cam_composite(display_image, cam_2d), features
def cam_composite(image: Image.Image, cam: np.ndarray):
    """Overlay a signed CAM on *image* and return a new RGBA PIL image.

    Colours and opacity:
    - Negative CAM values are drawn red and positive values green.
    - Opacity grows with magnitude, reaching 50% at the peak.
    - The base image is blended 33% toward grayscale so the overlay stands out.
    - The peak magnitude is printed in the bottom-right corner.

    Args:
        image: PIL image to draw over; any mode (converted to RGBA).
        cam: 2D numpy array of CAM values. It is upscaled to ``image.size``
            with nearest-neighbour resampling.

    Returns:
        PIL.Image.Image (RGBA) with the overlay and scale label.
    """
    cam_abs = np.abs(cam)
    cam_scale = float(cam_abs.max())
    # An all-zero map would divide by zero and produce NaN alpha values;
    # draw it fully transparent instead.
    alpha_gain = 0.5 / cam_scale if cam_scale > 0 else 0.0
    cam_rgba = np.dstack((
        (cam < 0).astype(np.float32),
        (cam > 0).astype(np.float32),
        np.zeros_like(cam, dtype=np.float32),
        cam_abs * alpha_gain,
    ))  # Shape: (H, W, 4)
    cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8))
    cam_pil = cam_pil.resize(image.size, resample=Image.Resampling.NEAREST)
    image = Image.blend(
        image.convert('RGBA'),
        image.convert('L').convert('RGBA'),
        0.33
    )
    image = Image.alpha_composite(image, cam_pil)
    draw = ImageDraw.Draw(image)
    draw.text(
        (image.width - 7, image.height - 7),
        f"{cam_scale:.4g}",
        anchor="rd", font=FONT, fill=(32, 32, 255, 255)
    )
    return image
def filter_tags(predictions: dict[str, float], threshold: float, calibration: dict[str, float] | None):
if calibration is None:
predictions = {
key: value
for key, value in predictions.items()
if value >= threshold
}
else:
predictions = {
key: value
for key, value in predictions.items()
if value >= calibration.get(key, float("inf"))
}
tag_str = ", ".join(predictions.keys())
return tag_str, predictions
def resize_image(image: Image.Image, max_side: int = 1080) -> Image.Image:
    """Downscale *image* so its longest side is at most *max_side* pixels.

    An image already within the limit is returned unchanged (the same
    object); image_upload / url_submit in this file rely on that identity
    (``display_image is image``) to avoid re-sending or closing the original.

    Args:
        image: source PIL image.
        max_side: maximum allowed length of the longest side, in pixels.

    Returns:
        *image* itself, or a new LANCZOS-resampled image.
    """
    longest_side = max(image.height, image.width)
    # "<=" so an image exactly at the limit is not needlessly resampled.
    if longest_side <= max_side:
        return image
    scale = max_side / longest_side
    return image.resize(
        (
            # Keep at least 1 px: very thin images would otherwise round to 0,
            # which PIL rejects.
            max(1, round(image.width * scale)),
            max(1, round(image.height * scale)),
        ),
        resample=Image.Resampling.LANCZOS,
        reducing_gap=3.0
    )
def image_upload(image: Image.Image):
    """Prepare a freshly uploaded image and reset the tag/CAM outputs.

    The original image is closed when neither derived image reuses it. When
    the display copy is the original object, the image component is left
    untouched via ``gr.skip()``.

    Returns:
        A 7-tuple matching the ``image.upload`` outputs:
        ``(tag_string, tag_box, cam_tag, url, image, display_image_state,
        image_state)``.
    """
    display_image = resize_image(image)
    processed_image = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)
    display_is_original = display_image is image
    processed_is_original = processed_image is image
    if not (display_is_original or processed_is_original):
        image.close()
    image_update = gr.skip() if display_is_original else display_image
    return (
        "", {}, "None", "",
        image_update, display_image,
        processed_image,
    )
def url_submit(url: str):
    """Download an image from *url* and prepare it like a manual upload.

    Returns:
        A 6-tuple matching the ``url.submit`` outputs:
        - cleared tag text and cleared tag list;
        - CAM dropdown reset to "None";
        - the display image twice (image component and
          ``display_image_state``);
        - the processed classifier input for ``image_state``.

    Raises:
        gr.Error: the URL is blank or not http(s), the download fails or
            times out, or the response cannot be decoded as an image.
    """
    # NOTE(review): this fetches an arbitrary user-supplied URL from the
    # server (SSRF risk if the host can reach internal services) and reads
    # the whole body into memory with no size cap.
    url = (url or "").strip()
    if not url.lower().startswith(("http://", "https://")):
        raise gr.Error("Please enter an http:// or https:// image URL.")
    try:
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
    except requests.RequestException as err:
        raise gr.Error(f"Could not download the image: {err}") from err
    try:
        image = Image.open(BytesIO(resp.content))
        # Image.open is lazy; decode now so corrupt data is reported here
        # rather than deep inside resizing/patchifying.
        image.load()
    except (OSError, Image.DecompressionBombError) as err:
        raise gr.Error(f"The URL did not return a readable image: {err}") from err
    display_image = resize_image(image)
    processed_image = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)
    if display_image is not image and processed_image is not image:
        image.close()
    return (
        "", {}, "None",
        display_image, display_image,
        processed_image,
    )
def image_changed(image: Image.Image, threshold: float, calibration: dict[str, float] | None, cam_depth: int):
    """Classify *image* and filter the results for display.

    Returns:
        ``(tag_string, visible_predictions, features, predictions)``: the
        filtered text and dict from filter_tags, plus the raw classifier
        outputs so they can be kept in session state.
    """
    features, predictions = run_classifier(image, cam_depth)
    tag_str, visible = filter_tags(predictions, threshold, calibration)
    return tag_str, visible, features, predictions
def image_clear():
    """Return reset values for every output touched by ``image.clear``.

    Order: tag_string, tag_box, cam_tag, url, image, display_image_state,
    image_state, features_state, predictions_state.
    """
    text_outputs = ("", {}, "None", "")
    image_outputs = (None, None)
    state_outputs = (None, None, {})
    return text_outputs + image_outputs + state_outputs
def threshold_input(predictions: dict[str, float], threshold: float):
    """Switch back to manual thresholding after the slider is moved.

    The function:
    - re-filters the predictions with *threshold*;
    - clears the calibration state;
    - restores the slider's normal label and style;
    - resets the upload button label.
    """
    tag_str, visible = filter_tags(predictions, threshold, None)
    slider_update = gr.Slider(label="Tag Threshold", elem_classes=[])
    upload_update = gr.Textbox(label="Upload Calibration")
    return tag_str, visible, None, slider_update, upload_update
def parse_calibration(data) -> dict[str, float]:
    """Parse a calibration CSV into ``{display_tag: threshold}``.

    Args:
        data: iterable of CSV text lines (an open text file or ``StringIO``).
            Its header must contain ``tag`` and ``threshold`` columns.

    Returns:
        Mapping from rewritten tag name to its float threshold; empty for an
        empty file.

    Raises:
        KeyError: a required column is missing.
        ValueError / TypeError: a threshold is non-numeric or missing.
    """
    reader = csv.DictReader(data)
    if reader.fieldnames is not None:
        # Tolerate a UTF-8 byte-order mark (common in spreadsheet exports)
        # and stray spaces around header names; either would otherwise make
        # row["tag"] raise KeyError and reject a valid file.
        reader.fieldnames = [name.lstrip("\ufeff").strip() for name in reader.fieldnames]
    return {
        rewrite_tag(row["tag"]): float(row["threshold"])
        for row in reader
    }
def calibration_load(predictions: dict[str, float]):
    """Apply the server-side default calibration from ``calibration.csv``.

    Returns:
        A 5-tuple for ``(tag_string, tag_box, calibration_state,
        threshold_slider, calibration_upload)``. If the file cannot be read
        or parsed, the first four are ``gr.skip()`` and the upload button is
        relabelled "Invalid Calibration File".
    """
    try:
        # "utf-8-sig" also accepts files saved with a byte-order mark. The
        # handle is named `file` so it does not shadow the csv module.
        with open("calibration.csv", "r", encoding="utf-8-sig", newline="") as file:
            calibration = parse_calibration(file)
    except Exception:
        # UI boundary: any missing/unreadable/malformed file is reported to
        # the user instead of raising.
        return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.Textbox(label="Invalid Calibration File")
    return (
        # The threshold argument is ignored by filter_tags when a calibration is given.
        *filter_tags(predictions, 0.0, calibration), calibration,
        gr.Slider(label="Using Default Calibration", elem_classes=["inactive-slider"]),
        gr.Textbox(label="Change Calibration")
    )
def calibration_changed(predictions: dict[str, float], calibration_data: bytes):
    """Apply a calibration CSV uploaded by the user.

    *calibration_data* is the raw uploaded file content, decoded as UTF-8.

    Returns:
        A 5-tuple for ``(tag_string, tag_box, calibration_state,
        threshold_slider, calibration_upload)``. If decoding or parsing fails,
        the first four are ``gr.skip()`` and the upload button is relabelled
        "Invalid Calibration File".
    """
    try:
        csv_text = calibration_data.decode("utf-8")
        calibration = parse_calibration(StringIO(csv_text))
    except Exception:
        return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.Textbox(label="Invalid Calibration File")
    tag_str, visible = filter_tags(predictions, 0.0, calibration)
    slider_update = gr.Slider(label="Using Uploaded Calibration", elem_classes=["inactive-slider"])
    upload_update = gr.Textbox(label="Change Calibration")
    return tag_str, visible, calibration, slider_update, upload_update
def cam_changed(
    display_image: Image.Image,
    image: Image.Image, features: dict[str, Tensor],
    tag: str, cam_depth: int
):
    """Redraw the image component for the selected CAM tag.

    Returns:
        ``(image_for_display, features)``. When no overlay can be drawn, the
        plain display image and unchanged features are returned.
    """
    # "None" means no overlay. No image (e.g. a tag picked after the image was
    # cleared) means there is nothing to explain; previously this crashed.
    if tag == "None" or image is None:
        return display_image, features
    tag_idx = tags.get(tag)
    if tag_idx is None:
        # Unknown label (not one of the model's tags): leave the view as is.
        return display_image, features
    return run_cam(display_image, image, features, tag_idx, cam_depth)
def tag_box_select(evt: gr.SelectData):
    """Return the label of the entry clicked in the tag list.

    The result is written into the CAM tag dropdown. Gradio supplies *evt*
    because of the ``gr.SelectData`` annotation, so the annotation must stay.
    """
    selected_tag = evt.value
    return selected_tag
# Stylesheet passed to gr.Blocks(css=...). Rules:
# - .output-class: hidden; presumably gr.Label's own top-class heading — verify.
# - .inferno-slider: gradient track; not referenced by any component in the
#   visible part of this file.
# - .inactive-slider: greys out the threshold slider while a calibration file
#   is in use (set by calibration_load / calibration_changed).
# - #image_container: square, letterboxed image viewer (elem_id of the image
#   component).
# - .show-api / .show-api-divider: hidden; presumably Gradio's API footer links.
custom_css = """
.output-class { display: none; }
.inferno-slider input[type=range] {
background: linear-gradient(to right,
#000004, #1b0c41, #4a0c6b, #781c6d,
#a52c60, #cf4446, #ed6925, #fb9b06,
#f7d13d, #fcffa4
) !important;
background-size: 100% 100% !important;
}
.inactive-slider input[type=range] {
--slider-color: grey !important;
}
#image_container-image {
width: 100%;
aspect-ratio: 1 / 1;
max-height: 100%;
}
#image_container img {
object-fit: contain !important;
}
.show-api, .show-api-divider {
display: none !important;
}
"""
# Gradio UI: layout plus event wiring.
with gr.Blocks(
    title="RedRocket JTP-3 Hydra",
    css=custom_css,
    analytics_enabled=False,
) as demo:
    # Per-session state:
    #   display_image_state - resized copy of the input, backdrop for CAM overlays
    #   image_state         - process_image output, fed to the classifier
    #   features_state      - cached intermediates / patch layout from run_classifier
    #   predictions_state   - top-250 {tag: score} from the last classification
    #   calibration_state   - per-tag thresholds, or None when the slider is in use
    display_image_state = gr.State()
    image_state = gr.State()
    features_state = gr.State()
    predictions_state = gr.State(value={})
    calibration_state = gr.State()
    gr.HTML(
        "<h1 style='display:flex; flex-flow: row nowrap; align-items: center;'>"
        "<a href='https://huggingface.co/RedRocket' target='_blank'>"
        "<img src='https://huggingface.co/spaces/RedRocket/README/resolve/main/RedRocket.png' style='width: 2em; margin-right: 0.5em;'>"
        "</a>"
        "<span>"
        "<a href='https://huggingface.co/RedRocket' target='_blank'>RedRocket</a> &ndash; JTP-3 Hydra"
        "</span>"
        "</h1>"
    )
    with gr.Row():
        # Left column: image input, CAM controls, extension count.
        with gr.Column():
            with gr.Column():
                image = gr.Image(
                    sources=['upload', 'clipboard'], type='pil',
                    show_label=False,
                    show_download_button=False,
                    show_share_button=False,
                    elem_id="image_container"
                )
                url = gr.Textbox(
                    label="Upload Image via Url:",
                    placeholder="https://example.com/image.jpg",
                    max_lines=1,
                    submit_btn="⮝",
                )
            with gr.Column():
                cam_tag = gr.Dropdown(
                    value="None", choices=["None"] + tag_list,
                    label="CAM Attention Overlay (You can also click a tag on the right.)", show_label=True,
                )
                cam_depth = gr.Slider(
                    minimum=1, maximum=27, step=1, value=1,
                    label="CAM Depth (1=fastest, more precise; 27=slowest, more general)"
                )
            with gr.Column():
                gr.HTML(f"<div style=\"text-align: center;\">{len(ext_info)} extension{'s' if len(ext_info) != 1 else ''} loaded.</div>")
        # Right column: threshold / calibration controls and tag output.
        with gr.Column():
            with gr.Row(variant="panel"):
                threshold_slider = gr.Slider(
                    minimum=0.00, maximum=1.00, step=0.01, value=0.30,
                    label="Tag Threshold", scale=4
                )
                with gr.Column(), gr.Group():
                    # Enabled only while calibration.csv exists (re-checked by scan_timer below).
                    calibration_default = gr.Button(
                        interactive=os.path.exists("calibration.csv"),
                        value="Default Calibration", size="lg",
                    )
                    calibration_upload = gr.UploadButton(
                        file_count="single", file_types=["text"], type="binary",
                        label="Upload Calibration", size="md", variant="secondary",
                    )
            tag_string = gr.Textbox(lines=3, label="Tags", show_copy_button=True)
            tag_box = gr.Label(num_top_classes=250, show_label=False, show_heading=False)
    # Upload: reset outputs and store the prepared image, then classify it.
    image.upload(
        fn=image_upload,
        inputs=[image],
        outputs=[
            tag_string, tag_box, cam_tag, url,
            image, display_image_state,
            image_state,
        ],
        show_progress='minimal',
        show_progress_on=[image]
    ).then(
        fn=image_changed,
        inputs=[image_state, threshold_slider, calibration_state, cam_depth],
        outputs=[
            tag_string, tag_box,
            features_state, predictions_state,
        ],
        show_progress='minimal',
        show_progress_on=[tag_box]
    )
    # URL submit: same as upload, but the image comes from a download and the
    # URL box keeps its text.
    url.submit(
        fn=url_submit,
        inputs=[url],
        outputs=[
            tag_string, tag_box, cam_tag,
            image, display_image_state,
            image_state,
        ],
        show_progress='minimal',
        show_progress_on=[url]
    ).then(
        fn=image_changed,
        inputs=[image_state, threshold_slider, calibration_state, cam_depth],
        outputs=[
            tag_string, tag_box,
            features_state, predictions_state,
        ],
        show_progress='minimal',
        show_progress_on=[tag_box]
    )
    # Clearing the image resets every output and state slot.
    image.clear(
        fn=image_clear,
        inputs=[],
        outputs=[
            tag_string, tag_box, cam_tag, url,
            image, display_image_state,
            image_state, features_state, predictions_state,
        ],
        show_progress='hidden'
    )
    # Moving the slider switches back to a single manual threshold.
    threshold_slider.input(
        fn=threshold_input,
        inputs=[predictions_state, threshold_slider],
        outputs=[tag_string, tag_box, calibration_state, threshold_slider, calibration_upload],
        trigger_mode='always_last',
        show_progress='hidden'
    )
    calibration_default.click(
        fn=calibration_load,
        inputs=[predictions_state],
        outputs=[tag_string, tag_box, calibration_state, threshold_slider, calibration_upload],
        show_progress='hidden'
    )
    calibration_upload.upload(
        fn=calibration_changed,
        inputs=[predictions_state, calibration_upload],
        outputs=[tag_string, tag_box, calibration_state, threshold_slider, calibration_upload],
        trigger_mode='always_last',
        show_progress='minimal',
        show_progress_on=[calibration_upload],
    )
    # CAM overlay: recompute when the tag or depth changes. features_state is
    # written back because run_cam may recompute deeper intermediates.
    cam_tag.input(
        fn=cam_changed,
        inputs=[
            display_image_state,
            image_state, features_state,
            cam_tag, cam_depth,
        ],
        outputs=[image, features_state],
        trigger_mode='always_last',
        show_progress='minimal',
        show_progress_on=[cam_tag]
    )
    cam_depth.input(
        fn=cam_changed,
        inputs=[
            display_image_state,
            image_state, features_state,
            cam_tag, cam_depth,
        ],
        outputs=[image, features_state],
        trigger_mode='always_last',
        show_progress='minimal',
        show_progress_on=[cam_depth]
    )
    # Clicking a tag in the result list selects it in the CAM dropdown, then
    # draws its overlay.
    tag_box.select(
        fn=tag_box_select,
        inputs=[],
        outputs=[cam_tag],
        trigger_mode='always_last',
        show_progress='hidden',
    ).then(
        fn=cam_changed,
        inputs=[
            display_image_state,
            image_state, features_state,
            cam_tag, cam_depth,
        ],
        outputs=[image, features_state],
        show_progress='minimal',
        show_progress_on=[cam_tag]
    )
    # Periodically re-check whether calibration.csv exists and enable or
    # disable the default-calibration button (Gradio's default Timer interval).
    scan_timer = gr.Timer()
    scan_timer.tick(
        fn=lambda: gr.Button(interactive=os.path.exists("calibration.csv")),
        outputs=[calibration_default],
        show_progress='hidden'
    )
# Start the Gradio server when run as a script.
if __name__ == "__main__":
    demo.launch()