Image Classification · English · furry · e621 · Not-For-All-Audiences

RedHotTensors committed
Commit 2c02f46 · 1 Parent(s): 162684a

JTP-3 Hydra Release

Files changed (18)
  1. .gitattributes +2 -0
  2. .gitignore +2 -0
  3. README.md +101 -3
  4. app.bat +4 -0
  5. app.py +428 -0
  6. data/hydra.jpg +3 -0
  7. data/jtp-3-hydra-tags.csv +0 -0
  8. data/jtp-3-hydra-val.csv +3 -0
  9. glu.py +40 -0
  10. hydra_pool.py +581 -0
  11. image.py +271 -0
  12. inference.bat +4 -0
  13. inference.py +318 -0
  14. install.bat +4 -0
  15. loader.py +150 -0
  16. model.py +192 -0
  17. models/jtp-3-hydra.safetensors +3 -0
  18. requirements.txt +8 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/jtp-3-hydra-val.csv filter=lfs diff=lfs merge=lfs -text
+ data/hydra.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ venv
README.md CHANGED
@@ -1,3 +1,101 @@
- ---
- license: apache-2.0
- ---
+ ---
+ tags:
+ - furry
+ - e621
+ - not-for-all-audiences
+ pipeline_tag: image-classification
+ base_model: google/siglip2-so400m-patch16-naflex
+ library_name: timm
+ language:
+ - en
+ license: apache-2.0
+ ---
+
+ <div style="text-align: center;">
+ <img style="width: 60%; display: inline-block;" src="https://huggingface.co/RedRocket/JTP-3/resolve/main/data/hydra.jpg">
+
+ <h1 style="text-align: center; margin-bottom: 0;">JTP-3 Hydra</h1>
+ <span style="font-size: large;">e621 Image Classifier by <a href="https://huggingface.co/RedRocket/" style="font-size: large;">Project RedRocket</a></span>
+ </div>
+
+ JTP-3 Hydra is a finetune of the SigLIP2 image encoder with a custom classifier head, designed to predict 7,504 popular tags from [e621](https://e621.net).
+
+ ## Downloading
+ Follow the Hugging Face instructions to check out the repository using git.
+ If you are unable to do this, manually download all the `.py` files, as well as `models/jtp-3-hydra.safetensors` and `requirements.txt`.
+ If you are on Windows, also download the `.bat` files and follow the instructions below for easy installation.
+
+ ## Windows Installation and Usage
+ For Windows, ensure you have at least Python 3.12 [installed](https://www.python.org/downloads/windows/) and available on your path.
+ Then, double-click ``install.bat`` to run the installation, which will create a virtual environment and install all of the requirements into it.
+
+ You can run the WebUI by double-clicking ``app.bat`` and navigating your browser to the URL it shows. The link is not shared publicly.
+
+ On the command line, you can use ``inference.bat`` for bulk operations such as tagging entire directories. Run ``inference.bat --help`` for usage details.
+ If you provide a path to a file or directory, it will write ``.txt`` caption files beside each image using the default threshold of ``0.5``.
+
+ ## Linux Installation and Usage
+ Install:
+ ```sh
+ python -m venv venv
+ source venv/bin/activate
+ pip install -r requirements.txt
+ ```
+
+ Run the WebUI:
+ ```sh
+ source venv/bin/activate
+ python app.py
+ ```
+
+ Run command-line inference:
+ ```sh
+ source venv/bin/activate
+ python inference.py --help
+ ```
+
+ ## Usage Notes
+ The model predicts 7,501 e621 tags, as well as the added rating meta-tags ``safe``, ``questionable``, and ``explicit``.
+ The model is trained with implications, but its predictions are not constrained.
+ So, for example, it may rank ``tyrannosaurus rex`` as more likely than ``dinosaur``.
+
+ The model is trained only on images from e621, not on photographs of people or real animals.
+ While it has retained some ability to classify photos, this is not in any way supported.
+
+ The interactive interfaces use a threshold convention of -100% to 100%.
+ This differs from other classifier models, which generally range from 0% to 100%.
+
+ The model sees all transparency as a black background.
+
+ ## Technical Notes
+ The model consists of [SigLIP2 So400m Patch16 NAFlex](https://huggingface.co/google/siglip2-so400m-patch16-naflex) followed by a custom cross-attention transformer block with learned per-tag queries, SwiGLU feedforward, and per-tag SwiGLU output heads. The per-tag cross-attention mechanism is the origin of the moniker "hydra".
+
+ Subject to the preprocessing mentioned below, the initial set of training tags was all <span style="color:#2e76b4">general</span> tags with at least 1,200 examples, all <span style="color:#ed5d1f">species</span> and <span style="color:#00aa00">character</span> tags with at least 500 examples, a semi-automated selection of <span style="color:#dd00dd">copyright</span> and <span style="color:#666666">meta</span> tags, and a handful of manually-selected <span style="color:#228822">lore</span> tags which are sometimes discernible from the image.
+ This resulted in 8,067 tags. After training, tags with very poor validation performance were pruned, resulting in the final set of 7,504 tags.
+
+ Extensive semi-manual dataset curation was used to improve the quality of the training data.
+ The dataset preprocessing code consists of over 12,000 lines of code and data files.
+ In addition to correcting implications, manually-defined rules are used to detect common scenarios of missing, incomplete, or contradictory tagging and to selectively mask individual tags on a per-dataset-item basis.
+ This is responsible for JTP-3's excellent performance in detecting colors and "combo tags" such as `male_feral`.
+
+ Margin-focal cross-entropy loss based on ASL was used to mitigate the effects of inconsistent labeling on e621 and the extreme class imbalance.
+ The dataset was sampled in mini-epochs according to a self-entropy metric.
+ Loss weight for negative labels was logarithmically redistributed from images with few tags to those with many tags.
+
+ Raw validation performance metrics and tag lists are available in the ``data`` folder.
+ These can be used to create P/R curves, compute CTI or F<sub>1</sub> scores, or select automated thresholds for each tag.
+ The list of supported tags is also embedded in the safetensors metadata as ``classifier.labels``.
+
+ Internally, the model operates on logits as normal, and classification thresholds are expressed in the interval from 0.0 to 1.0.
+ This is reflected in the ``data`` files and the csv output of ``inference.py``.
+
+ ## Credits
+
+ RedHotTensors — Architecture design, dataset curation, infrastructure and training, testing, and release.<br>
+ DrHead — WebUI, multi-layer CAM, testing, and additional code.<br>
+ Thessalo — Advice and testing.<br>
+ Google Gemini — Hero image.<br>
+ [Furry Diffusion Community](https://discord.com/channels/1019133813105905664/1254974507819733017) — Beta feedback and WebUI testing.
+
+ ### Citations
+
+ Michael Tschannen, et al. [SigLIP 2.](https://arxiv.org/abs/2502.14786)<br>
+ Emanuel Ben-Baruch, et al. [Asymmetric Loss For Multi-Label Classification.](https://arxiv.org/abs/2009.14119)<br>
+ Noam Shazeer. [GLU Variants Improve Transformer.](https://arxiv.org/abs/2002.05202)
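The README notes that the interactive interfaces display scores on a -100% to 100% scale while thresholds are internally expressed in 0.0 to 1.0. The mapping between the two is a simple linear rescale of the sigmoid probability; a minimal sketch (helper names are mine, not from the repository):

```python
def prob_to_display(p: float) -> float:
    """Map an internal probability in [0.0, 1.0] to the UI's -100%..100% scale."""
    return (p * 2.0 - 1.0) * 100.0

def display_to_prob(d: float) -> float:
    """Inverse mapping: a UI percentage back to a 0.0..1.0 threshold."""
    return (d / 100.0 + 1.0) / 2.0
```

For example, the CLI default threshold of `0.5` corresponds to 0% on the interactive scale, and the WebUI default of `0.30` corresponds to -40%.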
app.bat ADDED
@@ -0,0 +1,4 @@
+ IF NOT EXIST venv call .\install.bat
+
+ call venv\Scripts\activate.bat
+ python app.py
app.py ADDED
@@ -0,0 +1,428 @@
+ from io import BytesIO
+ from threading import Lock
+
+ import numpy as np
+
+ import torch
+ from torch import Tensor
+ from torch.nn import Parameter
+ from torch.nn.functional import sigmoid
+
+ import gradio as gr
+
+ from PIL import Image, ImageDraw, ImageFont
+
+ import requests
+
+ from model import load_model, process_image, patchify_image
+ from image import unpatchify
+
+ device = "cuda"
+ PATCH_SIZE = 16
+ MAX_SEQ_LEN = 1024
+
+ model_lock = Lock()
+ model, tag_list = load_model("models/jtp-3-hydra.safetensors", device=device)
+ model.requires_grad_(False)
+
+ tags = {
+     tag.replace("_", " ").replace("vulva", "pussy"): idx
+     for idx, tag in enumerate(tag_list)
+ }
+ tag_list = list(tags.keys())
+
+ FONT = ImageFont.load_default(24)
+
+ @torch.no_grad()
+ def run_classifier(image: Image.Image, cam_depth: int):
+     patches, patch_coords, patch_valid = patchify_image(image, PATCH_SIZE, MAX_SEQ_LEN)
+     patches = patches.unsqueeze(0).to(device=device, non_blocking=True)
+     patch_coords = patch_coords.unsqueeze(0).to(device=device, non_blocking=True)
+     patch_valid = patch_valid.unsqueeze(0).to(device=device, non_blocking=True)
+
+     patches = patches.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
+     patch_coords = patch_coords.to(dtype=torch.int32)
+
+     with model_lock:
+         features = model.forward_intermediates(
+             patches,
+             patch_coord=patch_coords,
+             patch_valid=patch_valid,
+             indices=cam_depth,
+             output_dict=True,
+             output_fmt='NLC'
+         )
+
+         logits = model.forward_head(features["image_features"], patch_valid=patch_valid)
+         del features["image_features"]
+
+     features["patch_coords"] = patch_coords
+     features["patch_valid"] = patch_valid
+     del patches, patch_coords, patch_valid
+
+     probits = sigmoid(logits[0].to(dtype=torch.float32))
+     probits.mul_(2.0).sub_(1.0)  # scale to -1 to 1
+
+     values, indices = probits.cpu().topk(250)
+     predictions = {
+         tag_list[idx.item()]: val.item()
+         for idx, val in sorted(
+             zip(indices, values),
+             key=lambda item: item[1].item(),
+             reverse=True
+         )
+     }
+
+     return features, predictions
+
+ @torch.no_grad()
+ def run_cam(
+     display_image: Image.Image,
+     image: Image.Image, features: dict[str, Tensor],
+     tag_idx: int, cam_depth: int
+ ):
+     intermediates = features["image_intermediates"]
+     if len(intermediates) < cam_depth:
+         features, _ = run_classifier(image, cam_depth)
+         intermediates = features["image_intermediates"]
+     elif len(intermediates) > cam_depth:
+         intermediates = intermediates[-cam_depth:]
+
+     patch_coords = features["patch_coords"]
+     patch_valid = features["patch_valid"]
+
+     with model_lock:
+         saved_q = model.attn_pool.q
+         saved_p = model.attn_pool.out_proj.weight
+
+         try:
+             model.attn_pool.q = Parameter(saved_q[:, [tag_idx], :], requires_grad=False)
+             model.attn_pool.out_proj.weight = Parameter(saved_p[[tag_idx], :, :], requires_grad=False)
+
+             with torch.enable_grad():
+                 for intermediate in intermediates:
+                     intermediate.requires_grad_(True).retain_grad()
+                     model.forward_head(intermediate, patch_valid=patch_valid)[0, 0].backward()
+         finally:
+             model.attn_pool.q = saved_q
+             model.attn_pool.out_proj.weight = saved_p
+
+     cam_1d: Tensor | None = None
+     for intermediate in intermediates:
+         patch_grad = (intermediate.grad.float() * intermediate.sign()).sum(dim=(0, 2))
+         intermediate.grad = None
+
+         if cam_1d is None:
+             cam_1d = patch_grad
+         else:
+             cam_1d.add_(patch_grad)
+
+     assert cam_1d is not None
+
+     cam_2d = unpatchify(cam_1d, patch_coords, patch_valid).cpu().numpy()
+     return cam_composite(display_image, cam_2d), features
+
+ def cam_composite(image: Image.Image, cam: np.ndarray):
+     """
+     Overlays CAM on image and returns a PIL image.
+     Args:
+         image: PIL Image (RGB)
+         cam: 2D numpy array (activation map)
+
+     Returns:
+         PIL.Image.Image with overlay
+     """
+
+     cam_abs = np.abs(cam)
+     cam_scale = cam_abs.max()
+
+     cam_rgba = np.dstack((
+         (cam < 0).astype(np.float32),
+         (cam > 0).astype(np.float32),
+         np.zeros_like(cam, dtype=np.float32),
+         cam_abs * (0.5 / cam_scale),
+     ))  # Shape: (H, W, 4)
+
+     cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8))
+     cam_pil = cam_pil.resize(image.size, resample=Image.Resampling.NEAREST)
+
+     image = Image.blend(
+         image.convert('RGBA'),
+         image.convert('L').convert('RGBA'),
+         0.33
+     )
+
+     image = Image.alpha_composite(image, cam_pil)
+
+     draw = ImageDraw.Draw(image)
+     draw.text(
+         (image.width - 7, image.height - 7),
+         f"{cam_scale.item():.4g}",
+         anchor="rd", font=FONT, fill=(32, 32, 255, 255)
+     )
+
+     return image
+
+ def filter_tags(predictions: dict[str, float], threshold: float):
+     predictions = {
+         key: value
+         for key, value in predictions.items()
+         if value >= threshold
+     }
+
+     tag_str = ", ".join(predictions.keys())
+     return tag_str, predictions
+
+ def resize_image(image: Image.Image) -> Image.Image:
+     longest_side = max(image.height, image.width)
+     if longest_side < 1080:
+         return image
+
+     scale = 1080 / longest_side
+     return image.resize(
+         (
+             int(round(image.width * scale)),
+             int(round(image.height * scale)),
+         ),
+         resample=Image.Resampling.LANCZOS,
+         reducing_gap=3.0
+     )
+
+ def image_upload(image: Image.Image):
+     display_image = resize_image(image)
+     processed_image = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)
+
+     if display_image is not image and processed_image is not image:
+         image.close()
+
+     return (
+         "", {}, "None", "",
+         gr.skip() if display_image is image else display_image, display_image,
+         processed_image,
+     )
+
+ def url_submit(url: str):
+     resp = requests.get(url, timeout=10)
+     resp.raise_for_status()
+
+     image = Image.open(BytesIO(resp.content))
+     display_image = resize_image(image)
+     processed_image = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)
+
+     if display_image is not image and processed_image is not image:
+         image.close()
+
+     return (
+         "", {}, "None",
+         display_image, display_image,
+         processed_image,
+     )
+
+ def image_changed(image: Image.Image, threshold: float, cam_depth: int):
+     features, predictions = run_classifier(image, cam_depth)
+     return *filter_tags(predictions, threshold), features, predictions
+
+ def image_clear():
+     return (
+         "", {}, "None", "",
+         None, None,
+         None, None, {},
+     )
+
+ def cam_changed(
+     display_image: Image.Image,
+     image: Image.Image, features: dict[str, Tensor],
+     tag: str, cam_depth: int
+ ):
+     if tag == "None":
+         return display_image, features
+
+     return run_cam(display_image, image, features, tags[tag], cam_depth)
+
+ def tag_box_select(evt: gr.SelectData):
+     return evt.value
+
+ custom_css = """
+ .output-class { display: none; }
+ .inferno-slider input[type=range] {
+     background: linear-gradient(to right,
+         #000004, #1b0c41, #4a0c6b, #781c6d,
+         #a52c60, #cf4446, #ed6925, #fb9b06,
+         #f7d13d, #fcffa4
+     ) !important;
+     background-size: 100% 100% !important;
+ }
+ #image_container-image {
+     width: 100%;
+     aspect-ratio: 1 / 1;
+     max-height: 100%;
+ }
+ #image_container img {
+     object-fit: contain !important;
+ }
+ .show-api, .show-api-divider {
+     display: none !important;
+ }
+ """
+
+ with gr.Blocks(
+     title="RedRocket JTP-3 Hydra",
+     css=custom_css,
+     analytics_enabled=False,
+ ) as demo:
+     display_image_state = gr.State()
+     image_state = gr.State()
+     features_state = gr.State()
+     predictions_state = gr.State(value={})
+
+     gr.HTML(
+         "<h1 style='display:flex; flex-flow: row nowrap; align-items: center;'>"
+         "<a href='https://huggingface.co/RedRocket' target='_blank'>"
+         "<img src='https://huggingface.co/spaces/RedRocket/README/resolve/main/RedRocket.png' style='width: 2em; margin-right: 0.5em;'>"
+         "</a>"
+         "<span>"
+         "<a href='https://huggingface.co/RedRocket' target='_blank'>RedRocket</a> &ndash; JTP-3 Hydra"
+         "</span>"
+         "</h1>"
+     )
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Column():
+                 image = gr.Image(
+                     sources=['upload', 'clipboard'], type='pil',
+                     show_label=False,
+                     show_download_button=False,
+                     show_share_button=False,
+                     elem_id="image_container"
+                 )
+
+                 url = gr.Textbox(
+                     label="Upload Image via Url:",
+                     placeholder="https://example.com/image.jpg",
+                     max_lines=1,
+                     submit_btn="⮝",
+                 )
+
+             with gr.Column():
+                 cam_tag = gr.Dropdown(
+                     value="None", choices=["None"] + tag_list,
+                     label="CAM Attention Overlay (You can also click a tag on the right.)", show_label=True
+                 )
+                 cam_depth = gr.Slider(
+                     minimum=1, maximum=27, step=1, value=1,
+                     label="CAM Depth (1=fastest, more precise; 27=slowest, more general)"
+                 )
+
+         with gr.Column():
+             threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Tag Threshold")
+             tag_string = gr.Textbox(lines=3, label="Tags", show_label=True, show_copy_button=True)
+             tag_box = gr.Label(num_top_classes=250, show_label=False, show_heading=False)
+
+     image.upload(
+         fn=image_upload,
+         inputs=[image],
+         outputs=[
+             tag_string, tag_box, cam_tag, url,
+             image, display_image_state,
+             image_state,
+         ],
+         show_progress='minimal',
+         show_progress_on=[image]
+     ).then(
+         fn=image_changed,
+         inputs=[image_state, threshold_slider, cam_depth],
+         outputs=[
+             tag_string, tag_box,
+             features_state, predictions_state,
+         ],
+         show_progress='minimal',
+         show_progress_on=[tag_box]
+     )
+
+     url.submit(
+         fn=url_submit,
+         inputs=[url],
+         outputs=[
+             tag_string, tag_box, cam_tag,
+             image, display_image_state,
+             image_state,
+         ],
+         show_progress='minimal',
+         show_progress_on=[url]
+     ).then(
+         fn=image_changed,
+         inputs=[image_state, threshold_slider, cam_depth],
+         outputs=[
+             tag_string, tag_box,
+             features_state, predictions_state,
+         ],
+         show_progress='minimal',
+         show_progress_on=[tag_box]
+     )
+
+     image.clear(
+         fn=image_clear,
+         inputs=[],
+         outputs=[
+             tag_string, tag_box, cam_tag, url,
+             image, display_image_state,
+             image_state, features_state, predictions_state,
+         ],
+         show_progress='hidden'
+     )
+
+     threshold_slider.input(
+         fn=filter_tags,
+         inputs=[predictions_state, threshold_slider],
+         outputs=[tag_string, tag_box],
+         trigger_mode='always_last',
+         show_progress='hidden'
+     )
+
+     cam_tag.input(
+         fn=cam_changed,
+         inputs=[
+             display_image_state,
+             image_state, features_state,
+             cam_tag, cam_depth,
+         ],
+         outputs=[image, features_state],
+         trigger_mode='always_last',
+         show_progress='minimal',
+         show_progress_on=[cam_tag]
+     )
+
+     cam_depth.input(
+         fn=cam_changed,
+         inputs=[
+             display_image_state,
+             image_state, features_state,
+             cam_tag, cam_depth,
+         ],
+         outputs=[image, features_state],
+         trigger_mode='always_last',
+         show_progress='minimal',
+         show_progress_on=[cam_depth]
+     )
+
+     tag_box.select(
+         fn=tag_box_select,
+         inputs=[],
+         outputs=[cam_tag],
+         trigger_mode='always_last',
+         show_progress='hidden',
+     ).then(
+         fn=cam_changed,
+         inputs=[
+             display_image_state,
+             image_state, features_state,
+             cam_tag, cam_depth,
+         ],
+         outputs=[image, features_state],
+         show_progress='minimal',
+         show_progress_on=[cam_tag]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
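In `run_classifier` above, logits are squashed with a sigmoid and then rescaled from [0, 1] to [-1, 1] for the WebUI display. A torch-free sketch of that post-processing step, for illustration only:

```python
import math

def logits_to_display(logits):
    """Sigmoid each logit, then rescale from [0, 1] to [-1, 1],
    mirroring the probits.mul_(2.0).sub_(1.0) step in run_classifier."""
    return [2.0 / (1.0 + math.exp(-z)) - 1.0 for z in logits]
```

A logit of 0 thus lands exactly at the midpoint of the display scale, and strongly positive logits approach 1.0.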
data/hydra.jpg ADDED

Git LFS Details

  • SHA256: 358c8a585af297ce40474916514d32041367ae5f428db0789c590d5d03800ac0
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
data/jtp-3-hydra-tags.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/jtp-3-hydra-val.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7e9901ae3dec04942ed5ecb7dda0fbbf01afc4d27b1b5f80d509b664952ff77
+ size 42149079
glu.py ADDED
@@ -0,0 +1,40 @@
+ from abc import abstractmethod
+ from typing import Literal
+
+ from torch import Tensor
+ from torch.nn import Module
+ from torch.nn.functional import silu, gelu
+
+ class GatedUnit(Module):
+     def __init__(self, dim: int = -1) -> None:
+         super().__init__()
+
+         self.dim = dim
+
+     @abstractmethod
+     def _activation(self, x: Tensor) -> Tensor:
+         ...
+
+     def forward(self, x: Tensor) -> Tensor:
+         f, g = x.chunk(2, dim=self.dim)
+         return self._activation(f) * g
+
+ class SwiGLU(GatedUnit):
+     def __init__(self, dim: int = -1) -> None:
+         super().__init__(dim)
+
+     def _activation(self, x: Tensor) -> Tensor:
+         return silu(x)
+
+ class GeGLU(GatedUnit):
+     def __init__(
+         self,
+         dim: int = -1,
+         approximate: Literal["tanh", "none"] = "tanh"
+     ) -> None:
+         super().__init__(dim)
+
+         self.approximate = approximate
+
+     def _activation(self, x: Tensor) -> Tensor:
+         return gelu(x, approximate=self.approximate)
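`GatedUnit` splits its input in half along `dim` and multiplies the activated first half by the second half as a gate. A torch-free sketch of the SwiGLU case on a plain 1-D list, to show the numerics:

```python
import math

def swiglu(x):
    """SwiGLU on a flat vector: split in half, gate silu(first) by second.
    Mirrors GatedUnit.forward with a SiLU activation; illustration only."""
    h = len(x) // 2
    f, g = x[:h], x[h:]
    silu = lambda v: v / (1.0 + math.exp(-v))  # x * sigmoid(x)
    return [silu(a) * b for a, b in zip(f, g)]
```

Note the output is half the width of the input, which is why `HydraPool` and `_MidBlock` size their `ff_in` projections at `hidden_dim * 2`.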
hydra_pool.py ADDED
@@ -0,0 +1,581 @@
+ import re
+ from collections import defaultdict
+ from math import sqrt
+ from typing import Any, Iterable, Self, cast
+
+ import torch
+ from torch import Tensor
+ from torch.nn import (
+     Module, ModuleList, Parameter, Buffer,
+     Linear, LayerNorm, RMSNorm, Dropout, Flatten,
+     init
+ )
+ from torch.nn.functional import pad, scaled_dot_product_attention
+
+ from einops import rearrange
+
+ from glu import SwiGLU
+
+ class IndexedAdd(Module):
+     def __init__(
+         self,
+         n_indices: int,
+         dim: int,
+         weight_shape: tuple[int, ...] | None = None,
+         *,
+         inplace: bool = False,
+         device: torch.device | str | None = None,
+         dtype: torch.dtype | None = None,
+     ) -> None:
+         super().__init__()
+
+         self.dim = dim
+         self.inplace = inplace
+
+         self.index = Buffer(torch.empty(
+             2, n_indices,
+             device=device, dtype=torch.int32
+         ))
+
+         self.weight = Parameter(torch.ones(
+             *(sz if sz != -1 else n_indices for sz in weight_shape),
+             device=device, dtype=dtype
+         )) if weight_shape is not None else None
+
+     def _save_to_state_dict(
+         self,
+         destination: dict[str, Any],
+         prefix: str,
+         keep_vars: bool
+     ) -> None:
+         super()._save_to_state_dict(destination, prefix, keep_vars)
+
+         if keep_vars:
+             return
+
+         with torch.no_grad():
+             index_key = f"{prefix}index"
+             index = destination[index_key]
+
+             min_index = index.amin(None).item()
+             if min_index >= 0:
+                 max_index = index.amax(None).item()
+                 if max_index < (1 << 8):
+                     destination[index_key] = index.to(dtype=torch.uint8)
+                 elif max_index < (1 << 16):
+                     destination[index_key] = index.to(dtype=torch.uint16)
+
+     @torch.no_grad()
+     def load_indices(self, indices: Iterable[tuple[int, int]], *, mean: bool = False) -> None:
+         if mean:
+             if self.weight is None:
+                 raise ValueError("No weights to initialize with means.")
+
+             groups: dict[int, list[int]] = defaultdict(list)
+
+         idx = -1
+         for idx, (src, dst) in enumerate(indices):
+             self.index[0, idx] = src
+             self.index[1, idx] = dst
+
+             if mean:
+                 groups[dst].append(idx)
+
+         if (idx + 1) != self.index.size(1):
+             raise IndexError(f"Expected {self.index.size(1)} indices, but got {idx + 1}.")
+
+         if not mean:
+             return
+
+         assert self.weight is not None
+
+         for idxs in groups.values():
+             if len(idxs) < 2:
+                 continue
+
+             self.weight.index_fill_(
+                 self.dim,
+                 torch.tensor(idxs, device=self.weight.device, dtype=torch.int64),
+                 1.0 / len(idxs)
+             )
+
+     def forward(self, dst: Tensor, src: Tensor) -> Tensor:
+         src = src.index_select(self.dim, self.index[0])
+
+         if self.weight is not None:
+             src.mul_(self.weight)
+
+         return (
+             dst.index_add_(self.dim, self.index[1], src)
+             if self.inplace else
+             dst.index_add(self.dim, self.index[1], src)
+         )
+
+ class BatchLinear(Module):
+     def __init__(
+         self,
+         batch_shape: tuple[int, ...] | int,
+         in_features: int,
+         out_features: int,
+         *,
+         bias: bool = False,
+         flatten: bool = False,
+         bias_inplace: bool = True,
+         device: torch.device | str | None = None,
+         dtype: torch.dtype | None = None,
+     ) -> None:
+         super().__init__()
+
+         if isinstance(batch_shape, int):
+             batch_shape = (batch_shape,)
+         elif not batch_shape:
+             raise ValueError("At least one batch dimension is required.")
+
+         self.flatten = -(len(batch_shape) + 1) if flatten else 0
+
+         self.weight = Parameter(torch.empty(
+             *batch_shape, in_features, out_features,
+             device=device, dtype=dtype
+         ))
+
+         bt = self.weight.flatten(end_dim=-3).mT
+         for idx in range(bt.size(0)):
+             init.kaiming_uniform_(bt[idx], a=sqrt(5))
+
+         self.bias = Parameter(torch.zeros(
+             *batch_shape, out_features,
+             device=device, dtype=dtype
+         )) if bias else None
+
+         self.bias_inplace = bias_inplace
+
+     def forward(self, x: Tensor) -> Tensor:
+         # ... B... 1 I @ B... I O -> ... B... O
+         x = torch.matmul(x.unsqueeze(-2), self.weight).squeeze(-2)
+
+         if self.bias is not None:
+             if self.bias_inplace:
+                 x.add_(self.bias)
+             else:
+                 x = x + self.bias
+
+         if self.flatten:
+             x = x.flatten(self.flatten)
+
+         return x
+
+ class Mean(Module):
+     def __init__(self, dim: tuple[int, ...] | int = -1, *, keepdim: bool = False) -> None:
+         super().__init__()
+
+         self.dim = dim
+         self.keepdim = keepdim
+
+     def forward(self, x: Tensor) -> Tensor:
+         return x.mean(self.dim, self.keepdim)
+
+ class _MidBlock(Module):
+     def __init__(
+         self,
+         attn_dim: int,
+         head_dim: int,
+         n_classes: int,
+         *,
+         ff_ratio: float,
+         ff_dropout: float,
+         q_cls_inplace: bool = True,
+         device: torch.device | str | None,
+         dtype: torch.dtype | None,
+     ) -> None:
+         super().__init__()
+
+         self.head_dim = head_dim
+         self.q_cls_inplace = q_cls_inplace
+
+         hidden_dim = int(attn_dim * ff_ratio)
+
+         self.q_proj = Linear(
+             attn_dim, attn_dim, bias=False,
+             device=device, dtype=dtype
+         )
+
+         self.q_cls = Parameter(torch.zeros(
+             n_classes, attn_dim,
+             device=device, dtype=dtype
+         ))
+
+         self.q_norm = RMSNorm(head_dim, eps=1e-5, elementwise_affine=False)
+
+         self.attn_out = Linear(
+             attn_dim, attn_dim, bias=False,
+             device=device, dtype=dtype
+         )
+
+         self.ff_norm = LayerNorm(
+             attn_dim,
+             device=device, dtype=dtype
+         )
+         self.ff_in = Linear(
+             attn_dim, hidden_dim * 2, bias=False,
+             device=device, dtype=dtype
+         )
+         self.ff_act = SwiGLU()
+         self.ff_drop = Dropout(ff_dropout)
+         self.ff_out = Linear(
+             hidden_dim, attn_dim, bias=False,
+             device=device, dtype=dtype
+         )
+
+     def _forward_q(self, x: Tensor) -> Tensor:
+         x = self.q_proj(x)
+
+         if self.q_cls_inplace:
+             x.add_(self.q_cls)
+         else:
+             x = x + self.q_cls
+
+         x = self.q_norm(x)
+         x = rearrange(x, "... s (h e) -> ... h s e", e=self.head_dim)
+         return x
+
+     def _forward_attn(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None) -> Tensor:
+         a = scaled_dot_product_attention(
+             self._forward_q(x), k, v,
+             attn_mask=attn_mask
+         )
+         a = rearrange(a, "... h s e -> ... s (h e)")
+         a = self.attn_out(a)
+         return x + a
+
+     def _forward_ff(self, x: Tensor) -> Tensor:
+         f = self.ff_norm(x)
+         f = self.ff_in(f)
+         f = self.ff_act(f)
+         f = self.ff_drop(f)
+         f = self.ff_out(f)
+         return x + f
+
+     def forward(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None = None) -> Tensor:
+         x = self._forward_attn(x, k, v, attn_mask)
+         x = self._forward_ff(x)
+         return x
+
+ class HydraPool(Module):
+     def __init__(
+         self,
+         attn_dim: int,
+         head_dim: int,
+         n_classes: int,
+         *,
+         mid_blocks: int = 0,
+         roots: tuple[int, int, int] = (0, 0, 0),
+         ff_ratio: float = 3.0,
+         ff_dropout: float = 0.0,
+         input_dim: int = -1,
+         output_dim: int = 1,
+         device: torch.device | str | None = None,
+         dtype: torch.dtype | None = None,
+     ) -> None:
+         super().__init__()
+
+         if input_dim < 0:
+             input_dim = attn_dim
+
+         assert attn_dim % head_dim == 0
+         n_heads = attn_dim // head_dim
+
+         self.n_classes = n_classes
+         self.head_dim = head_dim
+         self.output_dim = output_dim
+
+         self._has_roots = False
+         self._has_ff = False
+
+         self.q: Parameter | Buffer
+         self._q_normed: bool | None
+
+         if roots != (0, 0, 0):
+             self._has_roots = True
+             n_roots, n_classroots, n_subclasses = roots
+
+             if n_classroots < n_roots:
+                 raise ValueError("Number of classroots cannot be less than the number of roots.")
+
+             self.cls = Parameter(torch.randn(
+                 n_heads, n_classes, head_dim,
+                 device=device, dtype=dtype
+             ))
+
+             self.roots = Parameter(torch.randn(
+                 n_heads, n_roots, head_dim,
+                 device=device, dtype=dtype
+             )) if n_roots > 0 else None
+
+             self.clsroots = IndexedAdd(
+                 n_classroots, dim=-2, weight_shape=(n_heads, -1, 1),
+                 device=device, dtype=dtype
+             ) if n_classroots > 0 else None
+
+             self.clscls = IndexedAdd(
+                 n_subclasses, dim=-2, weight_shape=(n_heads, -1, 1),
+                 inplace=True, device=device, dtype=dtype
+             ) if n_subclasses > 0 else None
+
+             self.q = Buffer(torch.empty(
+                 n_heads, n_classes, head_dim,
+                 device=device, dtype=dtype
+             ))
+             self._q_normed = None
+         else:
+             self.q = Parameter(torch.randn(
+                 n_heads, n_classes, head_dim,
+                 device=device, dtype=dtype
+             ))
+             self._q_normed = False
+
+         self.kv = Linear(
+             input_dim, attn_dim * 2, bias=False,
+             device=device, dtype=dtype
+         )
+         self.qk_norm = RMSNorm(
+             head_dim, eps=1e-5, elementwise_affine=False
+         )
+
+         if ff_ratio > 0.0:
+             self._has_ff = True
+             hidden_dim = int(attn_dim * ff_ratio)
+
+             self.ff_norm = LayerNorm(
+                 attn_dim,
+                 device=device, dtype=dtype
+             )
+             self.ff_in = Linear(
+                 attn_dim, hidden_dim * 2, bias=False,
+                 device=device, dtype=dtype
+             )
+             self.ff_act = SwiGLU()
+             self.ff_drop = Dropout(ff_dropout)
+             self.ff_out = Linear(
+                 hidden_dim, attn_dim, bias=False,
+                 device=device, dtype=dtype
+             )
+         elif mid_blocks > 0:
+             raise ValueError("Feedforward required with mid blocks.")
+
+         self.mid_blocks = ModuleList(
+             _MidBlock(
+                 attn_dim, head_dim, n_classes,
+                 ff_ratio=ff_ratio, ff_dropout=ff_dropout,
369
+ device=device, dtype=dtype
370
+ ) for _ in range(mid_blocks)
371
+ )
372
+
373
+ self.out_proj = BatchLinear(
374
+ n_classes, attn_dim, output_dim * 2,
375
+ device=device, dtype=dtype
376
+ )
377
+ self.out_act = SwiGLU()
378
+
379
+ @property
380
+ def has_roots(self) -> bool:
381
+ return self._has_roots
382
+
383
+ def get_extra_state(self) -> dict[str, Any]:
384
+ return { "q_normed": self._q_normed }
385
+
386
+ def set_extra_state(self, state: dict[str, Any]) -> None:
387
+ self._q_normed = state["q_normed"]
388
+
389
+ def create_head(self) -> Module:
390
+ if self.output_dim == 1:
391
+ return Flatten(-2)
392
+
393
+ return Mean(-1)
394
+
395
+ def train(self, mode: bool = True) -> Self:
396
+ super().train(mode)
397
+
398
+ if mode:
399
+ if self._has_roots:
400
+ self._q_normed = None
401
+ else:
402
+ self._q_normed = False
403
+ else:
404
+ if self._has_roots:
405
+ self._cache_query()
406
+
407
+ return self
408
+
409
+ def inference(self) -> Self:
410
+ super().train(False)
411
+ self._cache_query()
412
+
413
+ if self._has_roots:
414
+ self._has_roots = False
415
+ self.q = Parameter(self.q)
416
+
417
+ del self.cls, self.roots, self.clsroots, self.clscls
418
+
419
+ return self
420
+
421
+ def _cache_query(self) -> None:
422
+ assert not self.training
423
+
424
+ if self._q_normed:
425
+ return
426
+
427
+ with torch.no_grad():
428
+ self.q.to(device=self.kv.weight.device)
429
+ self.q.copy_(self._forward_q())
430
+ self._q_normed = True
431
+
432
+ def _forward_q(self) -> Tensor:
433
+ match self._q_normed:
434
+ case None:
435
+ assert self._has_roots
436
+
437
+ if self.roots is not None:
438
+ q = self.qk_norm(self.roots)
439
+ q = self.clsroots(self.cls, q)
440
+ else:
441
+ q = self.cls
442
+
443
+ if self.clscls is not None:
444
+ q = self.clscls(q, q.detach())
445
+
446
+ q = self.qk_norm(q)
447
+ return q
448
+
449
+ case False:
450
+ assert not self._has_roots
451
+ return self.qk_norm(self.q)
452
+
453
+ case True:
454
+ return self.q
455
+
456
+ def _forward_attn(self, x: Tensor, attn_mask: Tensor | None) -> tuple[Tensor, Tensor, Tensor]:
457
+ q = self._forward_q().expand(*x.shape[:-2], -1, -1, -1)
458
+
459
+ x = self.kv(x)
460
+ k, v = rearrange(x, "... s (n h e) -> n ... h s e", n=2, e=self.head_dim).unbind(0)
461
+ k = self.qk_norm(k)
462
+
463
+ x = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
464
+ return rearrange(x, "... h s e -> ... s (h e)"), k, v
465
+
466
+ def _forward_ff(self, x: Tensor) -> Tensor:
467
+ if not self._has_ff:
468
+ return x
469
+
470
+ f = self.ff_norm(x)
471
+ f = self.ff_in(f)
472
+ f = self.ff_act(f)
473
+ f = self.ff_drop(f)
474
+ f = self.ff_out(f)
475
+ return x + f
476
+
477
+ def _forward_out(self, x: Tensor) -> Tensor:
478
+ x = self.out_proj(x)
479
+ x = self.out_act(x)
480
+ return x
481
+
482
+ def forward(self, x: Tensor, attn_mask: Tensor | None = None) -> Tensor:
483
+ x, k, v = self._forward_attn(x, attn_mask)
484
+ x = self._forward_ff(x)
485
+
486
+ for block in self.mid_blocks:
487
+ x = block(x, k, v, attn_mask)
488
+
489
+ x = self._forward_out(x)
490
+ return x
491
+
492
+ def prune_roots(self, retain_classes: set[int]) -> tuple[list[int], list[int]]:
493
+ if not self._has_roots or self.roots is None:
494
+ raise TypeError("No roots to prune.")
495
+
496
+ if self.clscls is not None:
497
+ raise TypeError("Subclass roots cannot be pruned.")
498
+
499
+ used_roots: set[int] = set()
500
+ used_clsroots: list[int] = []
501
+
502
+ assert self.clsroots is not None
503
+ clsroots = [
504
+ cast(list[int], clsroot.tolist())
505
+ for clsroot in self.clsroots.index.cpu().unbind(1)
506
+ ]
507
+
508
+ for idx, (src, dest) in enumerate(clsroots):
509
+ if dest in retain_classes:
510
+ used_roots.add(src)
511
+ used_clsroots.append(idx)
512
+
513
+ sorted_roots = sorted(used_roots)
514
+ del used_roots
515
+
516
+ rootmap = {
517
+ root: idx
518
+ for idx, root in enumerate(sorted_roots)
519
+ }
520
+
521
+ clsmap = {
522
+ cls: idx
523
+ for idx, cls in enumerate(sorted(retain_classes))
524
+ }
525
+
526
+ for idx in used_clsroots:
527
+ src, dest = clsroots[idx]
528
+ self.clsroots.index[0, idx] = rootmap[src]
529
+ self.clsroots.index[1, idx] = clsmap[dest]
530
+
531
+ return sorted_roots, used_clsroots
532
+
533
+ @staticmethod
534
+ def for_state(
535
+ state_dict: dict[str, Any],
536
+ prefix: str = "",
537
+ *,
538
+ ff_dropout: float = 0.0,
539
+ device: torch.device | str | None = None,
540
+ dtype: torch.dtype | None = None,
541
+ ) -> "HydraPool":
542
+ n_heads, n_classes, head_dim = state_dict[f"{prefix}q"].shape
543
+ attn_dim = n_heads * head_dim
544
+
545
+ roots_t = state_dict.get(f"{prefix}roots")
546
+ clsroots_t = state_dict.get(f"{prefix}clsroots.index")
547
+ clscls_t = state_dict.get(f"{prefix}clscls.index")
548
+ roots = (
549
+ roots_t.size(1) if roots_t is not None else 0,
550
+ clsroots_t.size(1) if clsroots_t is not None else 0,
551
+ clscls_t.size(1) if clscls_t is not None else 0
552
+ )
553
+
554
+ input_dim = state_dict[f"{prefix}kv.weight"].size(1)
555
+ output_dim = state_dict[f"{prefix}out_proj.weight"].size(2) // 2
556
+
557
+ # avoid off-by-one issue due to truncation
558
+ ffout_t = state_dict.get(f"{prefix}ff_out.weight")
559
+ hidden_dim = ffout_t.size(1) + 0.5 if ffout_t is not None else 0
560
+ ff_ratio = hidden_dim / attn_dim
561
+
562
+ pattern = re.compile(rf"^{re.escape(prefix)}mid_blocks\.([0-9]+)\.")
563
+ mid_blocks = max([-1, *(
564
+ int(match[1])
565
+ for key in state_dict
566
+ if (match := pattern.match(key)) is not None
567
+ )]) + 1
568
+
569
+ return HydraPool(
570
+ attn_dim,
571
+ head_dim,
572
+ n_classes,
573
+ mid_blocks=mid_blocks,
574
+ roots=roots,
575
+ ff_ratio=ff_ratio,
576
+ ff_dropout=ff_dropout,
577
+ input_dim=input_dim,
578
+ output_dim=output_dim,
579
+ device=device,
580
+ dtype=dtype
581
+ )
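The `+ 0.5` in `for_state` above is easy to misread. `hidden_dim` is derived at construction time as `int(attn_dim * ff_ratio)`, and naively recomputing `ff_ratio = hidden_dim / attn_dim` can lose one unit to float truncation when the pool is rebuilt from a checkpoint. A minimal sketch of the round trip (plain Python, no torch required):

```python
# Round-trip sketch for the ff_ratio reconstruction in HydraPool.for_state.
# Construction truncates: hidden_dim = int(attn_dim * ff_ratio).
# Recovering the ratio as (hidden_dim + 0.5) / attn_dim means the later
# int(attn_dim * ratio) lands at int(hidden_dim + 0.5) == hidden_dim,
# regardless of float rounding in the division.
for attn_dim, ff_ratio in [(1152, 3.0), (1024, 3.5), (768, 2.7)]:
    hidden_dim = int(attn_dim * ff_ratio)          # at construction
    recovered = (hidden_dim + 0.5) / attn_dim      # in for_state
    assert int(attn_dim * recovered) == hidden_dim
```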
image.py ADDED
@@ -0,0 +1,271 @@
+ from io import BytesIO
+ from typing import Any, Callable, cast
+ from warnings import warn, catch_warnings, filterwarnings
+
+ import numpy as np
+ from torch import Tensor
+
+ from einops import rearrange
+
+ import PIL.Image as image
+ import PIL.ImageCms as image_cms
+
+ from PIL.Image import Image, Resampling
+ from PIL.ImageCms import (
+     Direction, Intent, ImageCmsProfile, PyCMSError,
+     createProfile, getDefaultIntent, isIntentSupported, profileToProfile
+ )
+ from PIL.ImageOps import exif_transpose
+
+ try:
+     import pillow_jxl
+ except ImportError:
+     pass
+
+ image.MAX_IMAGE_PIXELS = None
+
+ _SRGB = createProfile(colorSpace='sRGB')
+
+ _INTENT_FLAGS = {
+     Intent.PERCEPTUAL: image_cms.FLAGS["HIGHRESPRECALC"],
+     Intent.RELATIVE_COLORIMETRIC: (
+         image_cms.FLAGS["HIGHRESPRECALC"] |
+         image_cms.FLAGS["BLACKPOINTCOMPENSATION"]
+     ),
+     Intent.ABSOLUTE_COLORIMETRIC: image_cms.FLAGS["HIGHRESPRECALC"]
+ }
+
+ class CMSWarning(UserWarning):
+     def __init__(
+         self,
+         message: str,
+         *,
+         path: str | None = None,
+         cms_info: dict[str, Any] | None = None,
+         cause: Exception | None = None,
+     ):
+         super().__init__(message)
+         self.__cause__ = cause
+
+         self.path = path
+         self.cms_info = cms_info
+
+         self.add_note(f"path: {path}")
+         self.add_note(f"info: {cms_info}")
+
+ def _coalesce_intent(intent: Intent | int) -> Intent:
+     if isinstance(intent, Intent):
+         return intent
+
+     match intent:
+         case 0:
+             return Intent.PERCEPTUAL
+         case 1:
+             return Intent.RELATIVE_COLORIMETRIC
+         case 2:
+             return Intent.SATURATION
+         case 3:
+             return Intent.ABSOLUTE_COLORIMETRIC
+         case _:
+             raise ValueError("invalid intent")
+
+ def _add_info(info: dict[str, Any], source: object, key: str) -> None:
+     try:
+         if (value := getattr(source, key, None)) is not None:
+             info[key] = value
+     except Exception:
+         pass
+
+ def open_srgb(
+     path: str,
+     *,
+     resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None,
+     crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None,
+     expect: tuple[int, int] | None = None,
+ ) -> Image:
+     with open(path, "rb", buffering=(1024 * 1024)) as file:
+         img: Image = image.open(file)
+
+         try:
+             out = process_srgb(img, resize=resize, crop=crop, expect=expect)
+         except:
+             img.close()
+             raise
+
+         if img is not out:
+             img.close()
+
+         return out
+
+ def process_srgb(
+     img: Image,
+     *,
+     resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None,
+     crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None,
+     expect: tuple[int, int] | None = None,
+ ) -> Image:
+     img.load()
+
+     try:
+         exif_transpose(img, in_place=True)
+     except Exception:
+         pass # corrupt EXIF metadata is fine
+
+     size = (img.width, img.height)
+
+     if expect is not None and size != expect:
+         raise RuntimeError(
+             f"Image is {size[0]}x{size[1]}, "
+             f"but expected {expect[0]}x{expect[1]}."
+         )
+
+     if (icc_raw := img.info.get("icc_profile")) is not None:
+         cms_info: dict[str, Any] = {
+             "native_mode": img.mode,
+             "transparency": img.has_transparency_data,
+         }
+
+         try:
+             profile = ImageCmsProfile(BytesIO(icc_raw))
+             _add_info(cms_info, profile.profile, "profile_description")
+             _add_info(cms_info, profile.profile, "target")
+             _add_info(cms_info, profile.profile, "xcolor_space")
+             _add_info(cms_info, profile.profile, "connection_space")
+             _add_info(cms_info, profile.profile, "colorimetric_intent")
+             _add_info(cms_info, profile.profile, "rendering_intent")
+
+             working_mode = img.mode
+             if img.mode.startswith(("RGB", "BGR", "P")):
+                 working_mode = "RGBA" if img.has_transparency_data else "RGB"
+             elif img.mode.startswith(("L", "I", "F")) or img.mode == "1":
+                 working_mode = "LA" if img.has_transparency_data else "L"
+
+             if img.mode != working_mode:
+                 cms_info["working_mode"] = working_mode
+                 img = img.convert(working_mode)
+
+             mode = "RGBA" if img.has_transparency_data else "RGB"
+
+             intent = Intent.RELATIVE_COLORIMETRIC
+             if isIntentSupported(profile, intent, Direction.INPUT) != 1:
+                 intent = _coalesce_intent(getDefaultIntent(profile))
+
+             cms_info["conversion_intent"] = intent
+
+             if (flags := _INTENT_FLAGS.get(intent)) is None:
+                 raise RuntimeError("Unsupported intent")
+
+             if img.mode == mode:
+                 profileToProfile(
+                     img,
+                     profile,
+                     _SRGB,
+                     renderingIntent=intent,
+                     inPlace=True,
+                     flags=flags
+                 )
+             else:
+                 img = cast(Image, profileToProfile(
+                     img,
+                     profile,
+                     _SRGB,
+                     renderingIntent=intent,
+                     outputMode=mode,
+                     flags=flags
+                 ))
+         except Exception as ex:
+             warn(CMSWarning(
+                 "ICC color conversion failed; using image data as-is.",
+                 cms_info=cms_info, cause=ex
+             ))
+
+     if img.has_transparency_data:
+         if img.mode != "RGBa":
+             try:
+                 img = img.convert("RGBa")
+             except ValueError:
+                 img = img.convert("RGBA").convert("RGBa")
+     elif img.mode != "RGB":
+         img = img.convert("RGB")
+
+     if crop is not None and not isinstance(crop, tuple):
+         crop = crop(size)
+
+     if crop is not None:
+         left, top, right, bottom = crop
+         size = (right - left, bottom - top)
+
+     if resize is not None and not isinstance(resize, tuple):
+         resize = resize(size)
+
+     if resize is not None and size != resize:
+         img = img.resize(
+             resize,
+             Resampling.LANCZOS,
+             box=crop,
+             reducing_gap=3.0
+         )
+         crop = None
+
+     if crop is not None:
+         img = img.crop(crop)
+
+     return img
+
+ def put_srgb(img: Image, tensor: Tensor) -> None:
+     if img.mode not in ("RGB", "RGBA", "RGBa"):
+         raise ValueError(f"Image has non-RGB mode {img.mode}.")
+
+     np.copyto(tensor.numpy(), np.asarray(img)[:, :, :3], casting="no")
+
+ def put_srgb_patch(
+     img: Image,
+     patch_data: Tensor,
+     patch_coord: Tensor,
+     patch_valid: Tensor,
+     patch_size: int
+ ) -> None:
+     if img.mode not in ("RGB", "RGBA", "RGBa"):
+         raise ValueError(f"Image has non-RGB mode {img.mode}.")
+
+     patches = rearrange(
+         np.asarray(img)[:, :, :3],
+         "(h p1) (w p2) c -> h w (p1 p2 c)",
+         p1=patch_size, p2=patch_size
+     )
+
+     coords = np.stack(np.meshgrid(
+         np.arange(patches.shape[0], dtype=np.int16),
+         np.arange(patches.shape[1], dtype=np.int16),
+         indexing="ij"
+     ), axis=-1)
+
+     coords = rearrange(coords, "h w c -> (h w) c")
+     patches = rearrange(patches, "h w p -> (h w) p")
+     n = patches.shape[0]
+
+     np.copyto(patch_data[:n].numpy(), patches, casting="no")
+     np.copyto(patch_coord[:n].numpy(), coords, casting="no")
+     patch_valid[:n] = True
+
+ def unpatchify(input: Tensor, coords: Tensor, valid: Tensor) -> Tensor:
+     """
+     Scatter valid patches from (seqlen, ...) to (H, W, ...), using coords and valid mask.
+
+     Args:
+         input: Tensor of shape (seqlen, ...), patch data.
+         coords: Tensor of shape (1, seqlen, 2), spatial coordinates [y, x] for each patch.
+         valid: Tensor of shape (1, seqlen), boolean mask for valid patches.
+
+     Returns:
+         Tensor of shape (H, W, ...), with valid patches scattered to their spatial locations.
+     """
+
+     valid_coords = coords[0, valid[0]] # (n_valid, 2)
+     valid_patches = input[valid[0]]    # (n_valid, ...)
+
+     h = int(valid_coords[:, 0].max().item()) + 1
+     w = int(valid_coords[:, 1].max().item()) + 1
+
+     output_shape = (h, w) + input.shape[1:]
+     output = input.new_zeros(output_shape)
+
+     output[valid_coords[:, 0], valid_coords[:, 1]] = valid_patches
+     return output
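The `unpatchify` scatter above is easiest to see on toy tensors. The sketch below restates the function compactly (so it is self-contained) and feeds it four valid patches covering a 2x2 grid plus two padding slots; the toy shapes and values are invented for illustration:

```python
import torch
from torch import Tensor

def unpatchify(input: Tensor, coords: Tensor, valid: Tensor) -> Tensor:
    # restated from image.py for a self-contained demo
    valid_coords = coords[0, valid[0]]   # (n_valid, 2) [y, x] pairs
    valid_patches = input[valid[0]]      # (n_valid, ...) patch payloads
    h = int(valid_coords[:, 0].max().item()) + 1
    w = int(valid_coords[:, 1].max().item()) + 1
    output = input.new_zeros((h, w) + input.shape[1:])
    output[valid_coords[:, 0], valid_coords[:, 1]] = valid_patches
    return output

# four valid patches covering a 2x2 grid, plus two padding slots
patches = torch.arange(6, dtype=torch.float32).unsqueeze(-1)              # (6, 1)
coords = torch.tensor([[[0, 0], [0, 1], [1, 0], [1, 1], [0, 0], [0, 0]]])  # (1, 6, 2)
valid = torch.tensor([[True, True, True, True, False, False]])            # (1, 6)

grid = unpatchify(patches, coords, valid)
assert grid.shape == (2, 2, 1)
assert grid[1, 1, 0].item() == 3.0  # patch 3 landed at row 1, col 1
```

Padding slots are simply dropped by the mask, so their coordinates never influence the output size.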
inference.bat ADDED
@@ -0,0 +1,4 @@
+ IF NOT EXIST venv call .\install.bat
+
+ call venv\Scripts\activate.bat
+ python inference.py %*
inference.py ADDED
@@ -0,0 +1,318 @@
+ import argparse
+ import csv
+ import itertools
+ import os
+ import random
+ import sys
+
+ from typing import Any, Iterable
+
+ import torch
+ from torch import Tensor
+
+ from timm.models import NaFlexVit
+
+ from loader import Loader
+ from model import load_model, load_image
+
+ PATCH_SIZE = 16
+
+ def from_symmetric(threshold: float) -> float:
+     return (threshold + 1.0) / 2.0
+
+ def to_symmetric(threshold: float) -> float:
+     return (threshold - 0.5) * 2.0
+
+ def classify_output(output: Tensor, tags: list[str], threshold: float = 0.0) -> dict[str, float]:
+     return {
+         tag: prob
+         for tag, prob in zip(tags, output.tolist())
+         if prob >= threshold
+     }
+
+ def _run_interactive(
+     *,
+     model: NaFlexVit, tags: list[str],
+     seqlen: int, threshold: float,
+     device: str
+ ) -> None:
+     print(
+         "\n"
+         "JTP-3 Hydra Interactive Classifier\n"
+         " Type 'q' to quit, or 'h' for help.\n"
+         " For bulk operations, quit and run again with a path, or '-h' for help.\n"
+     )
+
+     while True:
+         print("> ", end="")
+         line = input().strip()
+
+         if line in ("q", "quit", "exit"):
+             break
+
+         if line in ("", "h", "help", "?"):
+             print(
+                 "Provide a file path to classify, or one of the following commands:\n"
+                 f" threshold T (-1.0 to 1.0, currently {threshold}, 0.2 to 0.8 recommended)\n"
+                 f" seqlen N (64 to 2048, currently {seqlen}, 1024 recommended)\n"
+                 " quit (or 'q', 'exit')"
+             )
+             continue
+
+         if line.startswith("threshold "):
+             try:
+                 parsed = float(line[10:])
+             except Exception as ex:
+                 print(ex)
+                 continue
+
+             if -1.0 <= parsed <= 1.0:
+                 threshold = parsed
+             else:
+                 print("Threshold must be between -1.0 and 1.0.")
+
+             continue
+
+         if line.startswith("seqlen "):
+             try:
+                 parsed = int(line[7:])
+             except Exception as ex:
+                 print(ex)
+                 continue
+
+             if 64 <= parsed <= 2048:
+                 seqlen = parsed
+             else:
+                 print("Sequence length must be between 64 and 2048.")
+
+             continue
+
+         try:
+             p_t, pc_t, pv_t = load_image(line, PATCH_SIZE, seqlen, False)
+         except Exception as ex:
+             print(ex)
+             continue
+
+         p_d = p_t.unsqueeze(0).to(device=device, non_blocking=True)
+         pc_d = pc_t.unsqueeze(0).to(device=device, non_blocking=True)
+         pv_d = pv_t.unsqueeze(0).to(device=device, non_blocking=True)
+
+         p_d = p_d.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
+         pc_d = pc_d.to(dtype=torch.int32)
+
+         o_d = model(p_d, pc_d, pv_d).float().sigmoid()
+         del p_d, pc_d, pv_d
+
+         classes = classify_output(o_d[0], tags, from_symmetric(threshold))
+         for cls, prob in sorted(classes.items(), key=lambda item: (-item[1], item[0])):
+             print(f" {to_symmetric(prob)*100:6.1f}% {cls}")
+
+         del classes
+         del o_d
+         del p_t, pc_t, pv_t
+
+ def _run_batched(
+     *,
+     model: NaFlexVit, tags: list[str],
+     paths: list[str], recursive: bool,
+     threshold: float, writer: Any, prefix: str,
+     batch_size: int, seqlen: int,
+     n_workers: int, share_memory: bool,
+     device: str,
+ ) -> None:
+     loader = Loader(
+         n_workers,
+         patch_size=PATCH_SIZE, max_seqlen=seqlen,
+         share_memory=share_memory
+     )
+
+     def dir_iter(path: str) -> Iterable[str]:
+         for entry in os.scandir(path):
+             if (
+                 not entry.name.startswith(".")
+                 and entry.name != "__pycache__"
+             ):
+                 if entry.is_file():
+                     if not entry.name.endswith((
+                         ".txt", ".csv", ".json",
+                         ".py", ".safetensors",
+                     )):
+                         yield entry.path
+                 elif recursive and entry.is_dir():
+                     yield from dir_iter(entry.path)
+
+     def paths_iter() -> Iterable[str]:
+         for path in paths:
+             if os.path.isdir(path):
+                 yield from dir_iter(path)
+             else:
+                 yield path
+
+     for batch in itertools.batched(paths_iter(), batch_size):
+         patches: list[Tensor] = []
+         patch_coords: list[Tensor] = []
+         patch_valid: list[Tensor] = []
+         batch_paths: list[str] = []
+
+         for path, result in loader.load(batch).items():
+             if isinstance(result, Exception):
+                 print(f"{repr(path)}: {result}", file=sys.stderr)
+                 continue
+
+             batch_paths.append(path)
+             patches.append(result[0])
+             patch_coords.append(result[1])
+             patch_valid.append(result[2])
+
+         if not patches:
+             continue
+
+         p_d = torch.stack(patches).to(device=device, non_blocking=True)
+         pc_d = torch.stack(patch_coords).to(device=device, non_blocking=True)
+         pv_d = torch.stack(patch_valid).to(device=device, non_blocking=True)
+
+         p_d = p_d.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
+         pc_d = pc_d.to(dtype=torch.int32)
+
+         o_d = model(p_d, pc_d, pv_d).float().sigmoid()
+         del p_d, pc_d, pv_d
+
+         for path, output in zip(batch_paths, o_d.cpu()):
+             if writer is None:
+                 with open(
+                     f"{os.path.splitext(path)[0]}.txt", "w",
+                     encoding="utf-8"
+                 ) as file:
+                     classes = list(classify_output(output, tags, threshold).keys())
+                     random.shuffle(classes)
+
+                     if prefix:
+                         try:
+                             classes.remove(prefix)
+                         except ValueError:
+                             pass
+
+                         classes.insert(0, prefix)
+
+                     file.write(', '.join(classes))
+             else:
+                 writer.writerow((path, *(f"{prob.item():.4f}" for prob in output)))
+
+         del o_d
+
+     loader.shutdown()
+
+
+ @torch.inference_mode()
+ def main() -> None:
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True
+     torch.backends.cudnn.benchmark = True
+
+     default_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     parser = argparse.ArgumentParser(
+         description="JTP-3 Hydra",
+         epilog="By Project RedRocket. Visit https://huggingface.co/spaces/RedRocket/JTP-3 for more information."
+     )
+     parser.add_argument("--model", type=str, default="models/jtp-3-hydra.safetensors",
+                         help="Path to model file.")
+     parser.add_argument("-b", "--batch", type=int, default=1,
+                         help="Batch size.")
+     parser.add_argument("-w", "--workers", type=int, default=-1,
+                         help="Number of dataloader workers. (Default: number of cores)")
+     parser.add_argument("--seqlen", type=int, default=1024,
+                         help="NaFlex sequence length. (Default: 1024)")
+     parser.add_argument("-t", "--threshold", type=float, default=0.5,
+                         help="Classification threshold. (-1.0 to 1.0)")
+     parser.add_argument("--no-shm", dest="shm", action="store_false",
+                         help="Disable shared memory between workers.")
+     parser.add_argument("-d", "--device", type=str, default=default_device,
+                         help=f"Torch device. (Default: {default_device})")
+     parser.add_argument("-r", "--recursive", action="store_true",
+                         help="Classify directories recursively. (Dotfiles will be ignored.)")
+     parser.add_argument("-O", "--original-tags", action="store_true",
+                         help="Do not rewrite tags for compatibility with diffusion models.")
+     parser.add_argument("-o", "--output", type=str,
+                         help="Path for CSV output, or '-' for standard output. If not specified, individual .txt caption files are written.")
+     parser.add_argument("-p", "--prefix", type=str, default="",
+                         help="Prefix all .txt caption files with the specified text. If the prefix matches a tag, the tag will not be repeated.")
+     parser.add_argument("paths", nargs="*",
+                         help="Path to files and directories to classify. If none are specified, run interactively."
+     )
+
+     args = parser.parse_args()
+
+     if args.batch < 1:
+         parser.error("--batch must be at least 1")
+     if not 64 <= args.seqlen <= 2048:
+         parser.error("--seqlen must be between 64 and 2048 (1024 strongly recommended)")
+     if not -1.0 <= args.threshold <= 1.0:
+         parser.error("--threshold must be between -1.0 and 1.0")
+
+     print(f"Loading {repr(args.model)} ...", end="", file=sys.stderr)
+     model, tags = load_model(args.model, device=args.device)
+     print(f" {len(tags)} tags", file=sys.stderr)
+
+     def rewrite_tag(tag: str) -> str:
+         if not args.original_tags:
+             tag = tag.replace("vulva", "pussy")
+
+             if args.output is None: # caption files
+                 tag = tag.replace("_", " ")
+                 tag = tag.replace("(", r"\(")
+                 tag = tag.replace(")", r"\)")
+
+         return tag
+
+     for idx in range(len(tags)):
+         tags[idx] = rewrite_tag(tags[idx])
+
+     if args.paths:
+         file: Any = None
+         writer: Any = None
+
+         match args.output:
+             case None:
+                 pass
+
+             case "-":
+                 writer = csv.writer(sys.stdout)
+
+             case _:
+                 file = open(
+                     args.output, "w",
+                     buffering=(1024 * 1024),
+                     encoding="utf-8",
+                     newline="",
+                 )
+                 writer = csv.writer(file)
+                 writer.writerow(("filename", *tags))
+
+         try:
+             _run_batched(
+                 model=model,
+                 tags=tags,
+                 paths=args.paths,
+                 recursive=args.recursive,
+                 threshold=from_symmetric(args.threshold),
+                 writer=writer, prefix=args.prefix,
+                 batch_size=args.batch,
+                 seqlen=args.seqlen,
+                 n_workers=args.workers,
+                 share_memory=args.shm,
+                 device=args.device,
+             )
+         finally:
+             if file is not None:
+                 file.close()
+     else:
+         _run_interactive(
+             model=model,
+             tags=tags,
+             seqlen=args.seqlen,
+             threshold=args.threshold,
+             device=args.device,
+         )
+
+ if __name__ == "__main__":
+     main()
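The CLI above takes thresholds on a symmetric [-1, 1] scale while the model emits sigmoid probabilities in [0, 1]; `from_symmetric` and `to_symmetric` map between the two. A quick sketch of the convention, with the helpers restated from `inference.py`:

```python
# Symmetric threshold convention: -1.0 keeps everything (p >= 0.0),
# 0.0 is the neutral decision boundary (p >= 0.5), 1.0 keeps nothing
# below certainty (p >= 1.0).
def from_symmetric(threshold: float) -> float:
    return (threshold + 1.0) / 2.0

def to_symmetric(threshold: float) -> float:
    return (threshold - 0.5) * 2.0

assert from_symmetric(-1.0) == 0.0
assert from_symmetric(0.0) == 0.5   # neutral threshold -> p >= 0.5
assert from_symmetric(1.0) == 1.0
assert to_symmetric(from_symmetric(0.5)) == 0.5  # exact round trip
```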
install.bat ADDED
@@ -0,0 +1,4 @@
+ python -m venv venv
+
+ call venv\Scripts\activate.bat
+ pip install -r requirements.txt
loader.py ADDED
@@ -0,0 +1,150 @@
+ from os import environ, process_cpu_count
+ from typing import Iterable, Self
+
+ from threading import Thread
+
+ import multiprocessing
+ from multiprocessing.queues import SimpleQueue
+
+ from torch import Tensor
+ from torch.multiprocessing.queue import SimpleQueue as TorchQueue
+
+ from model import load_image
+
+ class EnvScope:
+     __slots__ = ("env", "saved")
+
+     def __init__(self, env: dict[str, str | int | float | None]) -> None:
+         self.env = {
+             env: None if value is None else str(value)
+             for env, value in env.items()
+         }
+
+         self.saved: dict[str, str | None]
+
+     def __enter__(self) -> Self:
+         if hasattr(self, "saved"):
+             raise RuntimeError("EnvScope is already in use.")
+
+         self.saved = {}
+         for env, value in self.env.items():
+             self.saved[env] = environ.get(env, None)
+
+             if value is None:
+                 environ.pop(env, None)  # tolerate variables that are already unset
+             else:
+                 environ[env] = value
+
+         return self
+
+     def __exit__(self, exc_type, exc_value, tb) -> None:
+         for env, value in self.saved.items():
+             if value is None:
+                 environ.pop(env, None)
+             else:
+                 environ[env] = value
+
+         del self.saved
+
+ class Loader:
+     def __init__(
+         self, n_workers: int = -1, *,
+         patch_size: int = 16, max_seqlen: int = 1024,
+         share_memory: bool = True
+     ) -> None:
+         ctx = multiprocessing.get_context("spawn")
+
+         self.patch_size = patch_size
+         self.max_seqlen = max_seqlen
+
+         if n_workers < 0:
+             n_workers = process_cpu_count() or 1
+
+         if n_workers == 0:
+             self._workers = []
+             return
+
+         self._submission_queue: SimpleQueue[str | None] = SimpleQueue(ctx=ctx)
+         self._completion_queue: SimpleQueue[tuple[str, tuple[Tensor, Tensor, Tensor] | Exception] | None] = TorchQueue(ctx=ctx)
+         self._workers = [
+             ctx.Process(
+                 target=_worker_fn,
+                 args=(
+                     self._submission_queue,
+                     self._completion_queue,
+                     patch_size,
+                     max_seqlen,
+                     share_memory,
+                 ),
+                 name=f"loader-{idx}",
+                 daemon=True
+             )
+             for idx in range(n_workers)
+         ]
+
+         threads = [
+             Thread(
+                 target=proc.start,
+                 name=f"pstart-{proc.name}",
+                 daemon=True,
+             ) for proc in self._workers
+         ]
+
+         with EnvScope({
+             "OMP_NUM_THREADS": 1,
+             "OPENBLAS_NUM_THREADS": 1,
+             "CUDA_VISIBLE_DEVICES": "",
+         }):
+             for thread in threads:
+                 thread.start()
+
+             for thread in threads:
+                 thread.join()
+
+     def load(self, paths: Iterable[str]) -> dict[str, tuple[Tensor, Tensor, Tensor] | Exception]:
+         loaded: dict[str, tuple[Tensor, Tensor, Tensor] | Exception] = {}
+
+         if self._workers:
+             count = 0
+             for path in paths:
+                 self._submission_queue.put(path)
+                 count += 1
+
+             for _ in range(count):
+                 result = self._completion_queue.get()
+                 assert result is not None
+                 loaded[result[0]] = result[1]
+         else:
+             for path in paths:
+                 try:
+                     loaded[path] = load_image(path, self.patch_size, self.max_seqlen, False)
+                 except Exception as ex:
+                     loaded[path] = ex
+
+         return loaded
+
+     def shutdown(self, wait: bool = True) -> None:
+         for _ in range(len(self._workers)):
+             self._submission_queue.put(None)
+
+         if wait:
+             for _ in range(len(self._workers)):
+                 assert self._completion_queue.get() is None
+
+         self._workers.clear()
+
+ def _worker_fn(
+     submission_queue: SimpleQueue[str | None],
+     completion_queue: SimpleQueue[tuple[str, tuple[Tensor, Tensor, Tensor] | Exception] | None],
+     patch_size: int,
+     max_seqlen: int,
+     share_memory: bool,
+ ):
+     while (path := submission_queue.get()) is not None:
+         try:
+             completion_queue.put((path, load_image(path, patch_size, max_seqlen, share_memory)))
+         except Exception as ex:
+             completion_queue.put((path, ex))
+
+     completion_queue.put(None)
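The `EnvScope` pattern above (override or unset environment variables for a block, then restore them) can be exercised without torch or worker processes. The sketch below restates it compactly for a self-contained demo; the `DEMO_*` variable names are invented for illustration:

```python
import os

# Minimal restatement of the EnvScope pattern from loader.py: set or unset
# environment variables on entry, restore the prior state on exit.
class EnvScope:
    def __init__(self, env):
        self.env = {k: None if v is None else str(v) for k, v in env.items()}

    def __enter__(self):
        self.saved = {k: os.environ.get(k) for k in self.env}
        for k, v in self.env.items():
            if v is None:
                os.environ.pop(k, None)  # unset, tolerating absence
            else:
                os.environ[k] = v
        return self

    def __exit__(self, *exc):
        for k, v in self.saved.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v

os.environ["DEMO_VAR"] = "outer"
with EnvScope({"DEMO_VAR": 1, "DEMO_UNSET": None}):
    inner = os.environ["DEMO_VAR"]   # overridden (and stringified) inside
    assert inner == "1"
assert os.environ["DEMO_VAR"] == "outer"   # restored on exit
assert "DEMO_UNSET" not in os.environ
```

In `Loader.__init__` this is used so spawned worker processes inherit single-threaded BLAS settings and no CUDA devices, without permanently mutating the parent's environment.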
model.py ADDED
@@ -0,0 +1,192 @@
+from math import ceil
+
+import torch
+from torch import Tensor
+from torch.nn import Identity
+
+import timm
+from timm.models import NaFlexVit
+
+from PIL import Image
+
+from safetensors import safe_open
+
+from image import process_srgb, put_srgb_patch
+
+def sdpa_attn_mask(
+    patch_valid: Tensor,
+    num_prefix_tokens: int = 0,
+    symmetric: bool = True,
+    q_len: int | None = None,
+    dtype: torch.dtype | None = None,
+) -> Tensor:
+    mask = patch_valid.unflatten(-1, (1, 1, -1))
+
+    if num_prefix_tokens:
+        mask = torch.cat((
+            torch.ones(
+                *mask.shape[:-1], num_prefix_tokens,
+                device=patch_valid.device, dtype=torch.bool
+            ), mask
+        ), dim=-1)
+
+    return mask
+
+timm.models.naflexvit.create_attention_mask = sdpa_attn_mask
+
+def get_image_size_for_seq(
+    image_hw: tuple[int, int],
+    patch_size: int = 16,
+    max_seq_len: int = 1024,
+    max_ratio: float = 1.0,
+    eps: float = 1e-5,
+) -> tuple[int, int]:
+    """Determine image size for sequence length constraint."""
+
+    assert max_ratio >= 1.0
+    assert eps * 2 < max_ratio
+
+    h, w = image_hw
+    max_py = int(max((h * max_ratio) // patch_size, 1))
+    max_px = int(max((w * max_ratio) // patch_size, 1))
+
+    if (max_py * max_px) <= max_seq_len:
+        return max_py * patch_size, max_px * patch_size
+
+    def patchify(ratio: float) -> tuple[int, int]:
+        return (
+            min(int(ceil((h * ratio) / patch_size)), max_py),
+            min(int(ceil((w * ratio) / patch_size)), max_px)
+        )
+
+    py, px = patchify(eps)
+    if (py * px) > max_seq_len:
+        raise ValueError(f"Image of size {w}x{h} is too large.")
+
+    ratio = eps
+    while (max_ratio - ratio) >= eps:
+        mid = (ratio + max_ratio) / 2.0
+
+        mpy, mpx = patchify(mid)
+        seq_len = mpy * mpx
+
+        if seq_len > max_seq_len:
+            max_ratio = mid
+            continue
+
+        ratio = mid
+        py = mpy
+        px = mpx
+
+        if seq_len == max_seq_len:
+            break
+
+    assert py >= 1 and px >= 1
+    return py * patch_size, px * patch_size
+
+def process_image(img: Image.Image, patch_size: int, max_seq_len: int) -> Image.Image:
+    def compute_resize(wh: tuple[int, int]) -> tuple[int, int]:
+        h, w = get_image_size_for_seq((wh[1], wh[0]), patch_size, max_seq_len)
+        return w, h
+
+    return process_srgb(img, resize=compute_resize)
+
+def patchify_image(img: Image.Image, patch_size: int, max_seq_len: int, share_memory: bool = False) -> tuple[Tensor, Tensor, Tensor]:
+    patches = torch.zeros(max_seq_len, patch_size * patch_size * 3, device="cpu", dtype=torch.uint8)
+    patch_coords = torch.zeros(max_seq_len, 2, device="cpu", dtype=torch.int16)
+    patch_valid = torch.zeros(max_seq_len, device="cpu", dtype=torch.bool)
+
+    if share_memory:
+        patches.share_memory_()
+        patch_coords.share_memory_()
+        patch_valid.share_memory_()
+
+    put_srgb_patch(img, patches, patch_coords, patch_valid, patch_size)
+    return patches, patch_coords, patch_valid
+
+def load_image(
+    path: str,
+    patch_size: int = 16,
+    max_seq_len: int = 1024,
+    share_memory: bool = False
+) -> tuple[Tensor, Tensor, Tensor]:
+    with open(path, "rb", buffering=(1024 * 1024)) as file:
+        img: Image.Image = Image.open(file)
+
+        try:
+            processed = process_image(img, patch_size, max_seq_len)
+        except:
+            img.close()
+            raise
+
+        if img is not processed:
+            img.close()
+
+        return patchify_image(processed, patch_size, max_seq_len, share_memory)
+
+def load_model(path: str, device: torch.device | str | None = None) -> tuple[NaFlexVit, list[str]]:
+    with safe_open(path, framework="pt", device="cpu") as file:
+        metadata = file.metadata()
+
+        state_dict = {
+            key: file.get_tensor(key)
+            for key in file.keys()
+        }
+
+    arch = metadata["modelspec.architecture"]
+    if not arch.startswith("naflexvit_so400m_patch16_siglip"):
+        raise ValueError(f"Unrecognized model architecture: {arch}")
+
+    tags = metadata["classifier.labels"].split("\n")
+
+    model = timm.create_model(
+        'naflexvit_so400m_patch16_siglip',
+        pretrained=False, num_classes=0,
+        pos_embed_interp_mode="bilinear",
+        weight_init="skip", fix_init=False,
+        device="cpu", dtype=torch.bfloat16,
+    )
+
+    match arch[31:]:
+        case "": # vanilla
+            model.reset_classifier(len(tags))
+
+        case "+rr_slim":
+            model.reset_classifier(len(tags))
+
+            if "attn_pool.q.weight" not in state_dict:
+                model.attn_pool.q = Identity()
+
+            if "head.bias" not in state_dict:
+                model.head.bias = None
+
+        case "+rr_chonker":
+            from chonker_pool import ChonkerPool
+
+            model.attn_pool = ChonkerPool(
+                2, 1152, 72,
+                device=device, dtype=torch.bfloat16
+            )
+            model.head = model.attn_pool.create_head(len(tags))
+            model.num_classes = len(tags)
+
+        case "+rr_hydra":
+            from hydra_pool import HydraPool
+
+            model.attn_pool = HydraPool.for_state(
+                state_dict, "attn_pool.",
+                device=device, dtype=torch.bfloat16
+            )
+            model.head = model.attn_pool.create_head()
+            model.num_classes = len(tags)
+
+            state_dict["attn_pool._extra_state"] = { "q_normed": True }
+
+        case _:
+            raise ValueError(f"Unrecognized model architecture: {arch}")
+
+    model.eval().to(dtype=torch.bfloat16)
+    model.load_state_dict(state_dict, strict=True)
+    model.to(device=device)
+
+    return model, tags
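The sizing logic in `get_image_size_for_seq` above is a binary search over the resize ratio: it looks for the largest scale whose patch grid still fits in `max_seq_len` tokens. A dependency-free copy of the same search (pure Python, no torch; `size_for_seq` is a name invented for this sketch) shows the behavior:

```python
from math import ceil

def size_for_seq(h, w, patch_size=16, max_seq_len=1024, max_ratio=1.0, eps=1e-5):
    # Largest patch grid allowed by max_ratio; if it already fits, use it.
    max_py = int(max((h * max_ratio) // patch_size, 1))
    max_px = int(max((w * max_ratio) // patch_size, 1))
    if max_py * max_px <= max_seq_len:
        return max_py * patch_size, max_px * patch_size

    def patchify(ratio):
        # Patch grid produced by scaling the image by `ratio`.
        return (
            min(ceil((h * ratio) / patch_size), max_py),
            min(ceil((w * ratio) / patch_size), max_px),
        )

    py, px = patchify(eps)
    if py * px > max_seq_len:
        raise ValueError(f"Image of size {w}x{h} is too large.")

    # Binary-search the ratio for the largest grid that still fits.
    ratio = eps
    while (max_ratio - ratio) >= eps:
        mid = (ratio + max_ratio) / 2.0
        mpy, mpx = patchify(mid)
        if mpy * mpx > max_seq_len:
            max_ratio = mid
            continue
        ratio, py, px = mid, mpy, mpx
        if mpy * mpx == max_seq_len:
            break
    return py * patch_size, px * patch_size
```

For example, a 100x100 image snaps down to a 6x6 grid (96x96 pixels), while a 4096x4096 image is shrunk until its grid hits the 1024-token budget exactly (32x32 patches, i.e. 512x512 pixels).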
models/jtp-3-hydra.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b91691be739ba07d5d7cfb74296ab19ab016d7eabad03542a0a34c90a3d9969
+size 1002587984
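The checked-in file above is a Git LFS pointer (version, oid, size), not the ~1 GB weights themselves; Git LFS fetches the real blob on checkout. Reading its fields can be sketched as follows — `parse_lfs_pointer` is a hypothetical helper for illustration, not part of this repository:

```python
def parse_lfs_pointer(text):
    # Each pointer line is "key value"; split on the first space.
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    # The oid field combines the hash algorithm and the digest.
    algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "oid_algorithm": algo,
        "oid": digest,
        "size": int(fields["size"]),
    }

pointer = """\
version https://git-lfs.github.com/spec/v1
oid sha256:1b91691be739ba07d5d7cfb74296ab19ab016d7eabad03542a0a34c90a3d9969
size 1002587984
"""
info = parse_lfs_pointer(pointer)
```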
requirements.txt ADDED
@@ -0,0 +1,8 @@
+torch
+timm
+numpy
+pillow
+einops
+safetensors
+gradio
+requests