Commit ·
2711812
1
Parent(s): 724a033
Add calibration support to WebUI.
Browse files
app.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
from threading import Lock
|
| 3 |
|
| 4 |
import numpy as np
|
|
@@ -31,8 +33,11 @@ model_lock = Lock()
|
|
| 31 |
model, tag_list = load_model("models/jtp-3-hydra.safetensors", device=device)
|
| 32 |
model.requires_grad_(False)
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
tags = {
|
| 35 |
-
|
| 36 |
for idx, tag in enumerate(tag_list)
|
| 37 |
}
|
| 38 |
tag_list = list(tags.keys())
|
|
@@ -169,12 +174,19 @@ def cam_composite(image: Image.Image, cam: np.ndarray):
|
|
| 169 |
|
| 170 |
return image
|
| 171 |
|
| 172 |
-
def filter_tags(predictions: dict[str, float], threshold: float):
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
tag_str = ", ".join(predictions.keys())
|
| 180 |
return tag_str, predictions
|
|
@@ -224,9 +236,9 @@ def url_submit(url: str):
|
|
| 224 |
processed_image,
|
| 225 |
)
|
| 226 |
|
| 227 |
-
def image_changed(image: Image.Image, threshold: float, cam_depth: int):
|
| 228 |
features, predictions = run_classifier(image, cam_depth)
|
| 229 |
-
return *filter_tags(predictions, threshold), features, predictions
|
| 230 |
|
| 231 |
def image_clear():
|
| 232 |
return (
|
|
@@ -235,6 +247,44 @@ def image_clear():
|
|
| 235 |
None, None, {},
|
| 236 |
)
|
| 237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
def cam_changed(
|
| 239 |
display_image: Image.Image,
|
| 240 |
image: Image.Image, features: dict[str, Tensor],
|
|
@@ -258,6 +308,11 @@ custom_css = """
|
|
| 258 |
) !important;
|
| 259 |
background-size: 100% 100% !important;
|
| 260 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
#image_container-image {
|
| 262 |
width: 100%;
|
| 263 |
aspect-ratio: 1 / 1;
|
|
@@ -280,6 +335,7 @@ with gr.Blocks(
|
|
| 280 |
image_state = gr.State()
|
| 281 |
features_state = gr.State()
|
| 282 |
predictions_state = gr.State(value={})
|
|
|
|
| 283 |
|
| 284 |
gr.HTML(
|
| 285 |
"<h1 style='display:flex; flex-flow: row nowrap; align-items: center;'>"
|
|
@@ -321,8 +377,24 @@ with gr.Blocks(
|
|
| 321 |
)
|
| 322 |
|
| 323 |
with gr.Column():
|
| 324 |
-
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
tag_box = gr.Label(num_top_classes=250, show_label=False, show_heading=False)
|
| 327 |
|
| 328 |
image.upload(
|
|
@@ -337,7 +409,7 @@ with gr.Blocks(
|
|
| 337 |
show_progress_on=[image]
|
| 338 |
).then(
|
| 339 |
fn=image_changed,
|
| 340 |
-
inputs=[image_state, threshold_slider, cam_depth],
|
| 341 |
outputs=[
|
| 342 |
tag_string, tag_box,
|
| 343 |
features_state, predictions_state,
|
|
@@ -358,7 +430,7 @@ with gr.Blocks(
|
|
| 358 |
show_progress_on=[url]
|
| 359 |
).then(
|
| 360 |
fn=image_changed,
|
| 361 |
-
inputs=[image_state, threshold_slider, cam_depth],
|
| 362 |
outputs=[
|
| 363 |
tag_string, tag_box,
|
| 364 |
features_state, predictions_state,
|
|
@@ -379,13 +451,29 @@ with gr.Blocks(
|
|
| 379 |
)
|
| 380 |
|
| 381 |
threshold_slider.input(
|
| 382 |
-
fn=
|
| 383 |
inputs=[predictions_state, threshold_slider],
|
| 384 |
-
outputs=[tag_string, tag_box],
|
| 385 |
trigger_mode='always_last',
|
| 386 |
show_progress='hidden'
|
| 387 |
)
|
| 388 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
cam_tag.input(
|
| 390 |
fn=cam_changed,
|
| 391 |
inputs=[
|
|
@@ -430,5 +518,12 @@ with gr.Blocks(
|
|
| 430 |
show_progress_on=[cam_tag]
|
| 431 |
)
|
| 432 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
if __name__ == "__main__":
|
| 434 |
demo.launch()
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import os
|
| 3 |
+
from io import BytesIO, StringIO
|
| 4 |
from threading import Lock
|
| 5 |
|
| 6 |
import numpy as np
|
|
|
|
| 33 |
model, tag_list = load_model("models/jtp-3-hydra.safetensors", device=device)
|
| 34 |
model.requires_grad_(False)
|
| 35 |
|
| 36 |
+
def rewrite_tag(tag: str) -> str:
    """Normalize a raw model tag for display.

    Underscores become spaces, and the substring "vulva" is rewritten to
    "pussy" (applied after the underscore replacement, so multi-word tags
    are covered too).
    """
    spaced = tag.replace("_", " ")
    return spaced.replace("vulva", "pussy")
|
| 38 |
+
|
| 39 |
# Map each display-normalized tag name to its logit index in the model head.
# If two raw tags normalize to the same name, the later index wins (dict
# comprehension semantics), matching the original construction.
tags = {
    rewrite_tag(raw_tag): position
    for position, raw_tag in enumerate(tag_list)
}
# Rebuild tag_list with the normalized names, preserving insertion order.
tag_list = list(tags)
|
|
|
|
| 174 |
|
| 175 |
return image
|
| 176 |
|
| 177 |
+
def filter_tags(predictions: dict[str, float], threshold: float, calibration: dict[str, float] | None):
|
| 178 |
+
if calibration is None:
|
| 179 |
+
predictions = {
|
| 180 |
+
key: value
|
| 181 |
+
for key, value in predictions.items()
|
| 182 |
+
if value >= threshold
|
| 183 |
+
}
|
| 184 |
+
else:
|
| 185 |
+
predictions = {
|
| 186 |
+
key: value
|
| 187 |
+
for key, value in predictions.items()
|
| 188 |
+
if value >= calibration.get(key, float("inf"))
|
| 189 |
+
}
|
| 190 |
|
| 191 |
tag_str = ", ".join(predictions.keys())
|
| 192 |
return tag_str, predictions
|
|
|
|
| 236 |
processed_image,
|
| 237 |
)
|
| 238 |
|
| 239 |
+
def image_changed(image: Image.Image, threshold: float, calibration: dict[str, float] | None, cam_depth: int):
    """Classify a freshly selected image and produce the filtered tag views.

    Returns (tag string, filtered predictions, features, raw predictions) —
    the raw predictions are kept so later threshold/calibration changes can
    re-filter without re-running the model.
    """
    features, predictions = run_classifier(image, cam_depth)
    tag_str, visible = filter_tags(predictions, threshold, calibration)
    return tag_str, visible, features, predictions
|
| 242 |
|
| 243 |
def image_clear():
|
| 244 |
return (
|
|
|
|
| 247 |
None, None, {},
|
| 248 |
)
|
| 249 |
|
| 250 |
+
def threshold_input(predictions: dict[str, float], threshold: float):
    """Handle a manual threshold-slider move.

    Any active calibration is discarded (calibration state reset to None),
    the cached predictions are re-filtered with the plain threshold, and the
    slider/upload controls are restored to their uncalibrated labels.
    """
    tag_str, visible = filter_tags(predictions, threshold, None)
    slider_update = gr.Slider(label="Tag Threshold", elem_classes=[])
    upload_update = gr.Textbox(label="Upload Calibration")
    return tag_str, visible, None, slider_update, upload_update
|
| 256 |
+
|
| 257 |
+
def parse_calibration(data) -> dict[str, float]:
    """Parse calibration CSV rows into a tag -> threshold mapping.

    `data` is any iterable of CSV text lines (an open file or a StringIO)
    with "tag" and "threshold" columns.  Tag names are normalized through
    rewrite_tag so they match the keys used by the prediction dicts.
    Missing columns or non-numeric thresholds raise (KeyError/ValueError);
    callers wrap this in try/except.
    """
    thresholds: dict[str, float] = {}
    for row in csv.DictReader(data):
        thresholds[rewrite_tag(row["tag"])] = float(row["threshold"])
    return thresholds
|
| 262 |
+
|
| 263 |
+
def calibration_load(predictions: dict[str, float]):
    """Load the server-side default calibration from "calibration.csv".

    Re-filters the cached predictions with the per-tag calibrated cutoffs and
    returns updates for (tag_string, tag_box, calibration_state,
    threshold_slider, calibration_upload).  On any read or parse failure the
    current UI state is left untouched (gr.skip) and only the upload
    control's label is changed to flag the error.
    """
    try:
        # Bind the handle to `csv_file`, not `csv`: the previous name
        # shadowed the imported csv module inside this function.
        with open("calibration.csv", "r", encoding="utf-8", newline="") as csv_file:
            calibration = parse_calibration(csv_file)
    except Exception:
        # Best-effort load: the file may be missing or malformed.
        return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.Textbox(label="Invalid Calibration File")

    return (
        # Threshold 0.0 so the per-tag calibration cutoffs take over entirely.
        *filter_tags(predictions, 0.0, calibration), calibration,
        # Grey the slider out while a calibration is active.
        gr.Slider(label="Using Default Calibration", elem_classes=["inactive-slider"]),
        # NOTE(review): this slot updates a gr.UploadButton elsewhere in the
        # wiring, but a gr.Textbox update is returned — confirm Gradio accepts
        # cross-component updates here.
        gr.Textbox(label="Change Calibration")
    )
|
| 275 |
+
|
| 276 |
+
def calibration_changed(predictions: dict[str, float], calibration_data: bytes):
    """Apply a user-uploaded calibration file.

    Decodes the uploaded bytes as UTF-8 CSV, parses the per-tag thresholds,
    and re-filters the cached predictions.  Returns updates for (tag_string,
    tag_box, calibration_state, threshold_slider, calibration_upload).  A
    bad upload leaves the UI untouched and only flags the upload control.
    """
    try:
        text = calibration_data.decode("utf-8")
        calibration = parse_calibration(StringIO(text))
    except Exception:
        # Invalid upload: skip every state update, just relabel the button.
        return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.Textbox(label="Invalid Calibration File")

    tag_str, visible = filter_tags(predictions, 0.0, calibration)
    return (
        tag_str, visible, calibration,
        gr.Slider(label="Using Uploaded Calibration", elem_classes=["inactive-slider"]),
        gr.Textbox(label="Change Calibration"),
    )
|
| 287 |
+
|
| 288 |
def cam_changed(
|
| 289 |
display_image: Image.Image,
|
| 290 |
image: Image.Image, features: dict[str, Tensor],
|
|
|
|
| 308 |
) !important;
|
| 309 |
background-size: 100% 100% !important;
|
| 310 |
}
|
| 311 |
+
|
| 312 |
+
.inactive-slider input[type=range] {
|
| 313 |
+
--slider-color: grey !important;
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
#image_container-image {
|
| 317 |
width: 100%;
|
| 318 |
aspect-ratio: 1 / 1;
|
|
|
|
| 335 |
image_state = gr.State()
|
| 336 |
features_state = gr.State()
|
| 337 |
predictions_state = gr.State(value={})
|
| 338 |
+
calibration_state = gr.State()
|
| 339 |
|
| 340 |
gr.HTML(
|
| 341 |
"<h1 style='display:flex; flex-flow: row nowrap; align-items: center;'>"
|
|
|
|
| 377 |
)
|
| 378 |
|
| 379 |
with gr.Column():
|
| 380 |
+
with gr.Row(variant="panel"):
|
| 381 |
+
threshold_slider = gr.Slider(
|
| 382 |
+
minimum=0.00, maximum=1.00, step=0.01, value=0.30,
|
| 383 |
+
label="Tag Threshold", scale=4
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
with gr.Column(), gr.Group():
|
| 387 |
+
calibration_default = gr.Button(
|
| 388 |
+
interactive=os.path.exists("calibration.csv"),
|
| 389 |
+
value="Default Calibration", size="lg",
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
calibration_upload = gr.UploadButton(
|
| 393 |
+
file_count="single", file_types=["text"], type="binary",
|
| 394 |
+
label="Upload Calibration", size="md", variant="secondary",
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
tag_string = gr.Textbox(lines=3, label="Tags", show_copy_button=True)
|
| 398 |
tag_box = gr.Label(num_top_classes=250, show_label=False, show_heading=False)
|
| 399 |
|
| 400 |
image.upload(
|
|
|
|
| 409 |
show_progress_on=[image]
|
| 410 |
).then(
|
| 411 |
fn=image_changed,
|
| 412 |
+
inputs=[image_state, threshold_slider, calibration_state, cam_depth],
|
| 413 |
outputs=[
|
| 414 |
tag_string, tag_box,
|
| 415 |
features_state, predictions_state,
|
|
|
|
| 430 |
show_progress_on=[url]
|
| 431 |
).then(
|
| 432 |
fn=image_changed,
|
| 433 |
+
inputs=[image_state, threshold_slider, calibration_state, cam_depth],
|
| 434 |
outputs=[
|
| 435 |
tag_string, tag_box,
|
| 436 |
features_state, predictions_state,
|
|
|
|
| 451 |
)
|
| 452 |
|
| 453 |
threshold_slider.input(
|
| 454 |
+
fn=threshold_input,
|
| 455 |
inputs=[predictions_state, threshold_slider],
|
| 456 |
+
outputs=[tag_string, tag_box, calibration_state, threshold_slider, calibration_upload],
|
| 457 |
trigger_mode='always_last',
|
| 458 |
show_progress='hidden'
|
| 459 |
)
|
| 460 |
|
| 461 |
+
calibration_default.click(
|
| 462 |
+
fn=calibration_load,
|
| 463 |
+
inputs=[predictions_state],
|
| 464 |
+
outputs=[tag_string, tag_box, calibration_state, threshold_slider, calibration_upload],
|
| 465 |
+
show_progress='hidden'
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
calibration_upload.upload(
|
| 469 |
+
fn=calibration_changed,
|
| 470 |
+
inputs=[predictions_state, calibration_upload],
|
| 471 |
+
outputs=[tag_string, tag_box, calibration_state, threshold_slider, calibration_upload],
|
| 472 |
+
trigger_mode='always_last',
|
| 473 |
+
show_progress='minimal',
|
| 474 |
+
show_progress_on=[calibration_upload],
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
cam_tag.input(
|
| 478 |
fn=cam_changed,
|
| 479 |
inputs=[
|
|
|
|
| 518 |
show_progress_on=[cam_tag]
|
| 519 |
)
|
| 520 |
|
| 521 |
+
scan_timer = gr.Timer()
|
| 522 |
+
scan_timer.tick(
|
| 523 |
+
fn=lambda: gr.Button(interactive=os.path.exists("calibration.csv")),
|
| 524 |
+
outputs=[calibration_default],
|
| 525 |
+
show_progress='hidden'
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
if __name__ == "__main__":
|
| 529 |
demo.launch()
|