Image Classification
English
furry
e621
Not-For-All-Audiences
RedHotTensors committed on
Commit
2711812
·
1 Parent(s): 724a033

Add calibration support to WebUI.

Browse files
Files changed (1) hide show
  1. app.py +111 -16
app.py CHANGED
@@ -1,4 +1,6 @@
1
- from io import BytesIO
 
 
2
  from threading import Lock
3
 
4
  import numpy as np
@@ -31,8 +33,11 @@ model_lock = Lock()
31
  model, tag_list = load_model("models/jtp-3-hydra.safetensors", device=device)
32
  model.requires_grad_(False)
33
 
 
 
 
34
  tags = {
35
- tag.replace("_", " ").replace("vulva", "pussy"): idx
36
  for idx, tag in enumerate(tag_list)
37
  }
38
  tag_list = list(tags.keys())
@@ -169,12 +174,19 @@ def cam_composite(image: Image.Image, cam: np.ndarray):
169
 
170
  return image
171
 
172
- def filter_tags(predictions: dict[str, float], threshold: float):
173
- predictions = {
174
- key: value
175
- for key, value in predictions.items()
176
- if value >= threshold
177
- }
 
 
 
 
 
 
 
178
 
179
  tag_str = ", ".join(predictions.keys())
180
  return tag_str, predictions
@@ -224,9 +236,9 @@ def url_submit(url: str):
224
  processed_image,
225
  )
226
 
227
- def image_changed(image: Image.Image, threshold: float, cam_depth: int):
228
  features, predictions = run_classifier(image, cam_depth)
229
- return *filter_tags(predictions, threshold), features, predictions
230
 
231
  def image_clear():
232
  return (
@@ -235,6 +247,44 @@ def image_clear():
235
  None, None, {},
236
  )
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  def cam_changed(
239
  display_image: Image.Image,
240
  image: Image.Image, features: dict[str, Tensor],
@@ -258,6 +308,11 @@ custom_css = """
258
  ) !important;
259
  background-size: 100% 100% !important;
260
  }
 
 
 
 
 
261
  #image_container-image {
262
  width: 100%;
263
  aspect-ratio: 1 / 1;
@@ -280,6 +335,7 @@ with gr.Blocks(
280
  image_state = gr.State()
281
  features_state = gr.State()
282
  predictions_state = gr.State(value={})
 
283
 
284
  gr.HTML(
285
  "<h1 style='display:flex; flex-flow: row nowrap; align-items: center;'>"
@@ -321,8 +377,24 @@ with gr.Blocks(
321
  )
322
 
323
  with gr.Column():
324
- threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Tag Threshold")
325
- tag_string = gr.Textbox(lines=3, label="Tags", show_label=True, show_copy_button=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  tag_box = gr.Label(num_top_classes=250, show_label=False, show_heading=False)
327
 
328
  image.upload(
@@ -337,7 +409,7 @@ with gr.Blocks(
337
  show_progress_on=[image]
338
  ).then(
339
  fn=image_changed,
340
- inputs=[image_state, threshold_slider, cam_depth],
341
  outputs=[
342
  tag_string, tag_box,
343
  features_state, predictions_state,
@@ -358,7 +430,7 @@ with gr.Blocks(
358
  show_progress_on=[url]
359
  ).then(
360
  fn=image_changed,
361
- inputs=[image_state, threshold_slider, cam_depth],
362
  outputs=[
363
  tag_string, tag_box,
364
  features_state, predictions_state,
@@ -379,13 +451,29 @@ with gr.Blocks(
379
  )
380
 
381
  threshold_slider.input(
382
- fn=filter_tags,
383
  inputs=[predictions_state, threshold_slider],
384
- outputs=[tag_string, tag_box],
385
  trigger_mode='always_last',
386
  show_progress='hidden'
387
  )
388
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  cam_tag.input(
390
  fn=cam_changed,
391
  inputs=[
@@ -430,5 +518,12 @@ with gr.Blocks(
430
  show_progress_on=[cam_tag]
431
  )
432
 
 
 
 
 
 
 
 
433
  if __name__ == "__main__":
434
  demo.launch()
 
1
+ import csv
2
+ import os
3
+ from io import BytesIO, StringIO
4
  from threading import Lock
5
 
6
  import numpy as np
 
33
  model, tag_list = load_model("models/jtp-3-hydra.safetensors", device=device)
34
  model.requires_grad_(False)
35
 
36
def rewrite_tag(tag: str) -> str:
    """Normalize a raw model tag for display.

    Underscores become spaces, and the e621 term "vulva" is rewritten
    to "pussy" (matching the tag vocabulary users expect).
    """
    spaced = tag.replace("_", " ")
    return spaced.replace("vulva", "pussy")
38
+
39
# Map display-form tag names to their index in the model's output head,
# then rebuild tag_list so it holds the display-form names.
tags = {rewrite_tag(name): index for index, name in enumerate(tag_list)}
tag_list = list(tags)
 
174
 
175
  return image
176
 
177
+ def filter_tags(predictions: dict[str, float], threshold: float, calibration: dict[str, float] | None):
178
+ if calibration is None:
179
+ predictions = {
180
+ key: value
181
+ for key, value in predictions.items()
182
+ if value >= threshold
183
+ }
184
+ else:
185
+ predictions = {
186
+ key: value
187
+ for key, value in predictions.items()
188
+ if value >= calibration.get(key, float("inf"))
189
+ }
190
 
191
  tag_str = ", ".join(predictions.keys())
192
  return tag_str, predictions
 
236
  processed_image,
237
  )
238
 
239
def image_changed(image: Image.Image, threshold: float, calibration: dict[str, float] | None, cam_depth: int):
    """Classify a newly loaded image and return the filtered tag outputs
    plus the raw features/predictions kept for downstream CAM rendering."""
    features, predictions = run_classifier(image, cam_depth)
    tag_str, shown = filter_tags(predictions, threshold, calibration)
    return tag_str, shown, features, predictions
242
 
243
  def image_clear():
244
  return (
 
247
  None, None, {},
248
  )
249
 
250
def threshold_input(predictions: dict[str, float], threshold: float):
    """Handle a manual threshold change: revert to plain (uncalibrated)
    filtering, clear any loaded calibration, and reset the widget labels.

    NOTE(review): the last value updates an UploadButton output with a
    gr.Textbox instance — presumably Gradio applies it as a prop update;
    confirm against the Gradio version in use.
    """
    tag_str, shown = filter_tags(predictions, threshold, None)
    return (
        tag_str, shown, None,
        gr.Slider(label="Tag Threshold", elem_classes=[]),
        gr.Textbox(label="Upload Calibration")
    )
256
+
257
def parse_calibration(data) -> dict[str, float]:
    """Parse a calibration CSV (columns: tag, threshold) from a text
    stream into a mapping of display-form tag -> per-tag threshold.

    Raises KeyError/ValueError on malformed rows; callers catch broadly.
    """
    table: dict[str, float] = {}
    for row in csv.DictReader(data):
        table[rewrite_tag(row["tag"])] = float(row["threshold"])
    return table
262
+
263
def calibration_load(predictions: dict[str, float]):
    """Load the server-side default "calibration.csv" and re-filter the
    current predictions with its per-tag thresholds.

    On any read/parse failure, only the upload button's label changes
    (to "Invalid Calibration File"); all other outputs are skipped.
    """
    try:
        # Fix: the file handle was bound `as csv`, shadowing the imported
        # csv module inside this function — renamed to avoid the fragile
        # shadowing (any direct csv.* use here would have broken).
        with open("calibration.csv", "r", encoding="utf-8", newline="") as file:
            calibration = parse_calibration(file)
    except Exception:
        return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.Textbox(label="Invalid Calibration File")

    return (
        *filter_tags(predictions, 0.0, calibration), calibration,
        gr.Slider(label="Using Default Calibration", elem_classes=["inactive-slider"]),
        gr.Textbox(label="Change Calibration")
    )
275
+
276
def calibration_changed(predictions: dict[str, float], calibration_data: bytes):
    """Apply an uploaded calibration CSV (raw bytes from the UploadButton)
    and re-filter the current predictions with its per-tag thresholds.

    On decode/parse failure, only the upload button's label changes
    (to "Invalid Calibration File"); all other outputs are skipped.
    """
    try:
        text = calibration_data.decode("utf-8")
        calibration = parse_calibration(StringIO(text))
    except Exception:
        return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.Textbox(label="Invalid Calibration File")

    return (
        *filter_tags(predictions, 0.0, calibration), calibration,
        gr.Slider(label="Using Uploaded Calibration", elem_classes=["inactive-slider"]),
        gr.Textbox(label="Change Calibration")
    )
287
+
288
  def cam_changed(
289
  display_image: Image.Image,
290
  image: Image.Image, features: dict[str, Tensor],
 
308
  ) !important;
309
  background-size: 100% 100% !important;
310
  }
311
+
312
+ .inactive-slider input[type=range] {
313
+ --slider-color: grey !important;
314
+ }
315
+
316
  #image_container-image {
317
  width: 100%;
318
  aspect-ratio: 1 / 1;
 
335
  image_state = gr.State()
336
  features_state = gr.State()
337
  predictions_state = gr.State(value={})
338
+ calibration_state = gr.State()
339
 
340
  gr.HTML(
341
  "<h1 style='display:flex; flex-flow: row nowrap; align-items: center;'>"
 
377
  )
378
 
379
  with gr.Column():
380
+ with gr.Row(variant="panel"):
381
+ threshold_slider = gr.Slider(
382
+ minimum=0.00, maximum=1.00, step=0.01, value=0.30,
383
+ label="Tag Threshold", scale=4
384
+ )
385
+
386
+ with gr.Column(), gr.Group():
387
+ calibration_default = gr.Button(
388
+ interactive=os.path.exists("calibration.csv"),
389
+ value="Default Calibration", size="lg",
390
+ )
391
+
392
+ calibration_upload = gr.UploadButton(
393
+ file_count="single", file_types=["text"], type="binary",
394
+ label="Upload Calibration", size="md", variant="secondary",
395
+ )
396
+
397
+ tag_string = gr.Textbox(lines=3, label="Tags", show_copy_button=True)
398
  tag_box = gr.Label(num_top_classes=250, show_label=False, show_heading=False)
399
 
400
  image.upload(
 
409
  show_progress_on=[image]
410
  ).then(
411
  fn=image_changed,
412
+ inputs=[image_state, threshold_slider, calibration_state, cam_depth],
413
  outputs=[
414
  tag_string, tag_box,
415
  features_state, predictions_state,
 
430
  show_progress_on=[url]
431
  ).then(
432
  fn=image_changed,
433
+ inputs=[image_state, threshold_slider, calibration_state, cam_depth],
434
  outputs=[
435
  tag_string, tag_box,
436
  features_state, predictions_state,
 
451
  )
452
 
453
  threshold_slider.input(
454
+ fn=threshold_input,
455
  inputs=[predictions_state, threshold_slider],
456
+ outputs=[tag_string, tag_box, calibration_state, threshold_slider, calibration_upload],
457
  trigger_mode='always_last',
458
  show_progress='hidden'
459
  )
460
 
461
+ calibration_default.click(
462
+ fn=calibration_load,
463
+ inputs=[predictions_state],
464
+ outputs=[tag_string, tag_box, calibration_state, threshold_slider, calibration_upload],
465
+ show_progress='hidden'
466
+ )
467
+
468
+ calibration_upload.upload(
469
+ fn=calibration_changed,
470
+ inputs=[predictions_state, calibration_upload],
471
+ outputs=[tag_string, tag_box, calibration_state, threshold_slider, calibration_upload],
472
+ trigger_mode='always_last',
473
+ show_progress='minimal',
474
+ show_progress_on=[calibration_upload],
475
+ )
476
+
477
  cam_tag.input(
478
  fn=cam_changed,
479
  inputs=[
 
518
  show_progress_on=[cam_tag]
519
  )
520
 
521
+ scan_timer = gr.Timer()
522
+ scan_timer.tick(
523
+ fn=lambda: gr.Button(interactive=os.path.exists("calibration.csv")),
524
+ outputs=[calibration_default],
525
+ show_progress='hidden'
526
+ )
527
+
528
  if __name__ == "__main__":
529
  demo.launch()