Commit ·
2c02f46
1
Parent(s): 162684a
JTP-3 Hydra Release
Browse files- .gitattributes +2 -0
- .gitignore +2 -0
- README.md +101 -3
- app.bat +4 -0
- app.py +428 -0
- data/hydra.jpg +3 -0
- data/jtp-3-hydra-tags.csv +0 -0
- data/jtp-3-hydra-val.csv +3 -0
- glu.py +40 -0
- hydra_pool.py +581 -0
- image.py +271 -0
- inference.bat +4 -0
- inference.py +318 -0
- install.bat +4 -0
- loader.py +150 -0
- model.py +192 -0
- models/jtp-3-hydra.safetensors +3 -0
- requirements.txt +8 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data/jtp-3-hydra-val.csv filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
data/hydra.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
venv
|
README.md
CHANGED
|
@@ -1,3 +1,101 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
tags:
|
| 3 |
+
- furry
|
| 4 |
+
- e621
|
| 5 |
+
- not-for-all-audiences
|
| 6 |
+
pipeline_tag: image-classification
|
| 7 |
+
base_model: google/siglip2-so400m-patch16-naflex
|
| 8 |
+
library_name: timm
|
| 9 |
+
language:
|
| 10 |
+
- en
|
| 11 |
+
license: apache-2.0
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
<div style="text-align: center;">
|
| 15 |
+
<img style="width: 60%; display: inline-block;" src="https://huggingface.co/RedRocket/JTP-3/resolve/main/data/hydra.jpg">
|
| 16 |
+
|
| 17 |
+
<h1 style="text-align: center; margin-bottom: 0;">JTP-3 Hydra</h1>
|
| 18 |
+
<span style="font-size: large;">e621 Image Classifier by <a href="https://huggingface.co/RedRocket/" style="font-size: large;">Project RedRocket</a></span>
|
| 19 |
+
</div>
|
| 20 |
+
|
| 21 |
+
JTP-3 Hydra is a finetune of the SigLIP2 image classifier with a custom classifier head, designed to predict 7,504 popular tags from [e621](https://e621.net).
|
| 22 |
+
|
| 23 |
+
## Downloading
|
| 24 |
+
Follow the Hugging Face instructions to check out the repository using git.
|
| 25 |
+
If you are unable to do this, manually download all the `.py` files, as well as `models/jtp-3-hydra.safetensors` and `requirements.txt`.
|
| 26 |
+
If you are on Windows, also download the `.bat` files and follow the instructions below for easy installation.
|
| 27 |
+
|
| 28 |
+
## Windows Installation and Usage
|
| 29 |
+
For Windows, ensure you have at least Python 3.12 [installed](https://www.python.org/downloads/windows/) and available on your path.
|
| 30 |
+
Then, double-click ``install.bat`` to run installation, which will create a virtual environment for all the requirements and install them.
|
| 31 |
+
|
| 32 |
+
You can run the WebUI by double clicking ``app.bat`` and navigating your browser to the URL it shows. The link is not shared publicly.
|
| 33 |
+
|
| 34 |
+
On the command line, you can use ``inference.bat`` to do bulk operations such as tagging entire directories. Run ``inference.bat --help`` for help using the command line.
|
| 35 |
+
If you provide a path to a file or directory, it will write ``.txt`` caption files beside each image using the default threshold of ``0.5``.
|
| 36 |
+
|
| 37 |
+
## Linux Installation and Usage
|
| 38 |
+
```sh
|
| 39 |
+
python -m venv venv
|
| 40 |
+
source venv/bin/activate
|
| 41 |
+
pip install -r requirements.txt
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
```sh
|
| 45 |
+
source venv/bin/activate
|
| 46 |
+
python app.py
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
```sh
|
| 50 |
+
source venv/bin/activate
|
| 51 |
+
python inference.py --help
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Usage Notes
|
| 55 |
+
The model predicts 7,501 e621 tags, as well as the added rating meta-tags ``safe``, ``questionable``, and ``explicit``.
|
| 56 |
+
The model is trained with implications, but its predictions are not constrained.
|
| 57 |
+
So, for example, it's possible it can say ``tyrannosaurus rex`` is more likely than ``dinosaur``.
|
| 58 |
+
|
| 59 |
+
The model is trained on images on e621 only, and not on photographs of people or real animals.
|
| 60 |
+
While it has retained some ability to classify photos, this is not in any way supported.
|
| 61 |
+
|
| 62 |
+
The interactive interfaces use a threshold convention of -100% to 100%.
|
| 63 |
+
This is different from other classifier models that generally range from 0% to 100%.
|
| 64 |
+
|
| 65 |
+
The model sees all transparency as a black background.
|
| 66 |
+
|
| 67 |
+
## Technical Notes
|
| 68 |
+
The model consists of [SigLIP2 So400m Patch16 NAFlex](https://huggingface.co/google/siglip2-so400m-patch16-naflex) followed by a custom cross-attention transformer block with learned per-tag queries, SwiGLU feedforward, and per-tag SwiGLU output heads. The per-tag cross attention mechanism is the origin of the moniker "hydra".
|
| 69 |
+
|
| 70 |
+
Subject to the preprocessing mentioned below, the initial set of training tags was all <span style="color:#2e76b4">general</span> tags with at least 1,200 examples, all <span style="color:#ed5d1f">species</span> and <span style="color:#00aa00">character</span> tags with at least 500 examples, a semi-automated selection of <span style="color:#dd00dd">copyright</span> and <span style="color:#666666">meta</span> tags, and a handful of manually-selected <span style="color:#228822">lore</span> tags which are sometimes discernible from the image.
|
| 71 |
+
This resulted in 8,067 tags. After training, tags with very poor validation performance were pruned, resulting in the final set of 7,504 tags.
|
| 72 |
+
|
| 73 |
+
Extensive semi-manual dataset curation was used to improve the quality of the training data.
|
| 74 |
+
The dataset preprocessing code consists of over 12,000 lines of code and data files.
|
| 75 |
+
In addition to correcting implications, manually-defined rules are used to detect common scenarios of missing, incomplete, or contradictory tagging and to selectively mask individual tags on a per-dataset-item basis.
|
| 76 |
+
This is responsible for JTP-3's excellent performance in detecting colors and "combo tags" such as `male_feral`.
|
| 77 |
+
|
| 78 |
+
Margin-focal cross entropy loss based on ASL was used to mitigate the effects of inconsistent labeling on e621 and the extreme class imbalance.
|
| 79 |
+
The dataset was sampled in mini-epochs according to a self-entropy metric.
|
| 80 |
+
Loss weight for negative labels was logarithmically redistributed from images with few tags to those with many tags.
|
| 81 |
+
|
| 82 |
+
Raw validation performance metrics and tag lists are available in the ``data`` folder.
|
| 83 |
+
These can be used to create P/R curves, compute CTI or F<sub>1</sub> scores, or select automated thresholds for each tag.
|
| 84 |
+
The list of supported tags is also embedded in the safetensors metadata as ``classifier.labels``.
|
| 85 |
+
|
| 86 |
+
Internally, the model operates on logits as normal and classification thresholds are expressed in the interval from 0.0 to 1.0.
|
| 87 |
+
This is reflected in the ``data`` files and csv output of ``inference.py``.
|
| 88 |
+
|
| 89 |
+
## Credits
|
| 90 |
+
|
| 91 |
+
RedHotTensors — Architecture design, dataset curation, infrastructure and training, testing, and release.<br>
|
| 92 |
+
DrHead — WebUI, multi-layer CAM, testing, and additional code.<br>
|
| 93 |
+
Thessalo — Advice and testing.<br>
|
| 94 |
+
Google Gemini — Hero image.<br>
|
| 95 |
+
[Furry Diffusion Community](https://discord.com/channels/1019133813105905664/1254974507819733017) — Beta feedback and WebUI testing.
|
| 96 |
+
|
| 97 |
+
### Citations
|
| 98 |
+
|
| 99 |
+
Michael Tschannen, et al. [SigLIP 2.](https://arxiv.org/abs/2502.14786)<br>
|
| 100 |
+
Emanuel Ben-Baruch, et al. [Asymmetric Loss For Multi-Label Classification.](https://arxiv.org/abs/2009.14119)<br>
|
| 101 |
+
Noam Shazeer. [GLU Variants Improve Transformer.](https://arxiv.org/abs/2002.05202)
|
app.bat
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
IF NOT EXIST venv call .\install.bat
|
| 2 |
+
|
| 3 |
+
call venv\Scripts\activate.bat
|
| 4 |
+
python app.py
|
app.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from io import BytesIO
|
| 2 |
+
from threading import Lock
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch.nn import Parameter
|
| 9 |
+
from torch.nn.functional import sigmoid
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
|
| 13 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 14 |
+
|
| 15 |
+
import requests
|
| 16 |
+
|
| 17 |
+
from model import load_model, process_image, patchify_image
|
| 18 |
+
from image import unpatchify
|
| 19 |
+
|
| 20 |
+
# Runtime configuration shared by every handler in this module.
device = "cuda"
PATCH_SIZE = 16
MAX_SEQ_LEN = 1024

# Serializes access to the single shared model instance; run_cam temporarily
# swaps head weights while holding this lock.
model_lock = Lock()
model, tag_list = load_model("models/jtp-3-hydra.safetensors", device=device)
model.requires_grad_(False)

# Display name -> model output index. Underscores become spaces and "vulva"
# is substituted inside the name.
# NOTE(review): if two tags ever collide after renaming, the dict collapses
# them, `tag_list` below becomes shorter than the model output, and the
# index lookups in run_classifier shift. Verify the tag file has no collisions.
tags = {
    shown_name: position
    for position, shown_name in enumerate(
        raw.replace("_", " ").replace("vulva", "pussy") for raw in tag_list
    )
}
tag_list = list(tags)

# Font used for the CAM scale caption drawn by cam_composite.
FONT = ImageFont.load_default(24)
|
| 35 |
+
|
| 36 |
+
@torch.no_grad()
def run_classifier(image: Image.Image, cam_depth: int, top_k: int = 250):
    """Run the model on one image and return cached features plus top tags.

    Args:
        image: Image to classify; it is patchified with PATCH_SIZE and
            MAX_SEQ_LEN.
        cam_depth: Forwarded as ``indices`` to ``model.forward_intermediates``.
        top_k: Maximum number of predictions to return. Clamped to the number
            of available scores so small tag sets do not make ``topk`` raise.

    Returns:
        ``(features, predictions)``. ``features`` is the dict returned by
        ``forward_intermediates`` with ``"image_features"`` removed and
        ``"patch_coords"`` / ``"patch_valid"`` added. ``predictions`` maps
        display tag name to a score in [-1, 1], ordered from highest to lowest.
    """
    patches, patch_coords, patch_valid = patchify_image(image, PATCH_SIZE, MAX_SEQ_LEN)
    patches = patches.unsqueeze(0).to(device=device, non_blocking=True)
    patch_coords = patch_coords.unsqueeze(0).to(device=device, non_blocking=True)
    patch_valid = patch_valid.unsqueeze(0).to(device=device, non_blocking=True)

    # x / 127.5 - 1; presumably maps 0..255 pixel values to [-1, 1].
    # TODO confirm against patchify_image's output range.
    patches = patches.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
    patch_coords = patch_coords.to(dtype=torch.int32)

    # Both model calls stay under the lock: run_cam replaces head parameters
    # while holding it, so the head must not run concurrently with that swap.
    with model_lock:
        features = model.forward_intermediates(
            patches,
            patch_coord=patch_coords,
            patch_valid=patch_valid,
            indices=cam_depth,
            output_dict=True,
            output_fmt='NLC'
        )
        logits = model.forward_head(features["image_features"], patch_valid=patch_valid)

    del features["image_features"]

    features["patch_coords"] = patch_coords
    features["patch_valid"] = patch_valid
    del patches, patch_coords, patch_valid

    probits = sigmoid(logits[0].to(dtype=torch.float32))
    probits.mul_(2.0).sub_(1.0)  # scale to -1 to 1

    # topk returns values in descending order by default, so no extra sort is
    # needed; tolist() converts everything to Python scalars in one step.
    k = min(top_k, probits.shape[-1])
    values, indices = probits.cpu().topk(k)
    predictions = {
        tag_list[idx]: val
        for idx, val in zip(indices.tolist(), values.tolist())
    }

    return features, predictions
|
| 77 |
+
|
| 78 |
+
@torch.no_grad()
def run_cam(
    display_image: Image.Image,
    image: Image.Image, features: dict[str, Tensor],
    tag_idx: int, cam_depth: int
):
    """Build a gradient-based attention overlay for one tag.

    Args:
        display_image: Image the overlay is composited onto.
        image: Image passed to run_classifier when more intermediates are
            needed than are cached in ``features``.
        features: Dict produced by run_classifier; must contain
            ``"image_intermediates"``, ``"patch_coords"`` and ``"patch_valid"``.
        tag_idx: Index of the tag whose query/output row is selected.
        cam_depth: Number of trailing intermediates to use.

    Returns:
        ``(overlay_image, features)``; ``features`` is the recomputed dict if
        the classifier had to be re-run, otherwise the one passed in.
    """
    intermediates = features["image_intermediates"]
    if len(intermediates) < cam_depth:
        # Not enough cached layers: recompute with the requested depth.
        features, _ = run_classifier(image, cam_depth)
        intermediates = features["image_intermediates"]
    elif len(intermediates) > cam_depth:
        intermediates = intermediates[-cam_depth:]

    patch_coords = features["patch_coords"]
    patch_valid = features["patch_valid"]

    with model_lock:
        saved_q = model.attn_pool.q
        saved_p = model.attn_pool.out_proj.weight

        try:
            # Temporarily shrink the pooling head to the single requested tag
            # so forward_head yields one output, read at [0, 0] below.
            model.attn_pool.q = Parameter(saved_q[:, [tag_idx], :], requires_grad=False)
            model.attn_pool.out_proj.weight = Parameter(saved_p[[tag_idx], :, :], requires_grad=False)

            with torch.enable_grad():
                for intermediate in intermediates:
                    # backward() accumulates into .grad; drop anything left
                    # behind by an earlier run that failed part-way so stale
                    # gradients cannot leak into this map.
                    intermediate.grad = None
                    intermediate.requires_grad_(True).retain_grad()
                    model.forward_head(intermediate, patch_valid=patch_valid)[0, 0].backward()
        finally:
            # Always restore the full head, even if a pass raised.
            model.attn_pool.q = saved_q
            model.attn_pool.out_proj.weight = saved_p

    cam_1d: Tensor | None = None
    for intermediate in intermediates:
        # Gradient times the sign of the activation, summed over dims 0 and 2,
        # leaving one value per position along dim 1.
        patch_grad = (intermediate.grad.float() * intermediate.sign()).sum(dim=(0, 2))
        intermediate.grad = None

        if cam_1d is None:
            cam_1d = patch_grad
        else:
            cam_1d.add_(patch_grad)

    assert cam_1d is not None

    cam_2d = unpatchify(cam_1d, patch_coords, patch_valid).cpu().numpy()
    return cam_composite(display_image, cam_2d), features
|
| 124 |
+
|
| 125 |
+
def cam_composite(image: Image.Image, cam: np.ndarray):
    """
    Overlays CAM on image and returns a PIL image.

    Negative CAM values are drawn in red and positive values in green, with
    opacity proportional to magnitude (at most 50%). The peak magnitude is
    printed in the bottom-right corner.

    Args:
        image: PIL Image to draw on.
        cam: 2D numpy array (activation map)

    Returns:
        PIL.Image.Image (RGBA) with overlay
    """

    cam_abs = np.abs(cam)
    cam_scale = cam_abs.max()

    # An all-zero map would otherwise divide by zero and push NaN into the
    # alpha channel; draw a fully transparent overlay instead.
    alpha_gain = 0.5 / cam_scale if cam_scale > 0 else 0.0

    cam_rgba = np.dstack((
        (cam < 0).astype(np.float32),
        (cam > 0).astype(np.float32),
        np.zeros_like(cam, dtype=np.float32),
        cam_abs * alpha_gain,
    ))  # Shape: (H, W, 4)

    cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8))
    cam_pil = cam_pil.resize(image.size, resample=Image.Resampling.NEAREST)

    # Partially desaturate the base image (33% toward grayscale) before the
    # overlay is composited on top.
    image = Image.blend(
        image.convert('RGBA'),
        image.convert('L').convert('RGBA'),
        0.33
    )

    image = Image.alpha_composite(image, cam_pil)

    draw = ImageDraw.Draw(image)
    draw.text(
        (image.width - 7, image.height - 7),
        f"{cam_scale.item():.4g}",
        anchor="rd", font=FONT, fill=(32, 32, 255, 255)
    )

    return image
|
| 165 |
+
|
| 166 |
+
def filter_tags(predictions: dict[str, float], threshold: float):
    """Keep the predictions scoring at least ``threshold``.

    Returns:
        ``(tag_str, kept)`` where ``kept`` preserves the input order and
        ``tag_str`` is the kept tag names joined with ``", "``.
    """
    kept: dict[str, float] = {}
    for tag, score in predictions.items():
        if score >= threshold:
            kept[tag] = score

    return ", ".join(kept), kept
|
| 175 |
+
|
| 176 |
+
def resize_image(image: Image.Image, max_side: int = 1080) -> Image.Image:
    """Downscale ``image`` so its longest side is at most ``max_side``.

    Images already within the limit are returned unchanged (the same object),
    which callers use to detect that no resize happened. Larger images are
    resampled with Lanczos, preserving aspect ratio.

    Args:
        image: Source image.
        max_side: Maximum allowed length of the longer side, in pixels.

    Returns:
        ``image`` itself when no resize is needed, otherwise a new image.
    """
    longest_side = max(image.height, image.width)
    # `<=` so an image exactly at the limit is not copied for nothing.
    if longest_side <= max_side:
        return image

    scale = max_side / longest_side
    # Clamp to 1px so extreme aspect ratios cannot round a side down to 0.
    new_size = (
        max(1, int(round(image.width * scale))),
        max(1, int(round(image.height * scale))),
    )
    return image.resize(
        new_size,
        resample=Image.Resampling.LANCZOS,
        reducing_gap=3.0
    )
|
| 190 |
+
|
| 191 |
+
def image_upload(image: Image.Image):
    """Handle a newly uploaded image.

    Produces the display-sized copy and the model-ready copy, closes the
    upload if neither result is the upload itself, and returns the values for
    the ``image.upload`` outputs: cleared tag text, empty tag box, CAM tag
    reset to "None", cleared URL box, the image widget update, the display
    image state and the processed image state.
    """
    shown = resize_image(image)
    prepared = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)

    upload_is_unused = shown is not image and prepared is not image
    if upload_is_unused:
        image.close()

    # Only push a new picture to the widget when it was actually resized.
    widget_update = gr.skip() if shown is image else shown

    return "", {}, "None", "", widget_update, shown, prepared
|
| 203 |
+
|
| 204 |
+
def url_submit(url: str):
    """Download an image from ``url`` and prepare it like an upload.

    Raises whatever ``requests`` raises for connection problems, timeouts
    (10 s) or non-success HTTP status codes.

    Returns the values for the ``url.submit`` outputs: cleared tag text, empty
    tag box, CAM tag reset to "None", the image for the widget, the display
    image state and the processed image state.
    """
    # NOTE(review): the URL is fetched exactly as entered, with no scheme or
    # host restriction and no size cap. Confirm that is acceptable if this UI
    # is ever exposed beyond the local machine.
    response = requests.get(url, timeout=10)
    response.raise_for_status()

    image = Image.open(BytesIO(response.content))
    shown = resize_image(image)
    prepared = process_image(image, PATCH_SIZE, MAX_SEQ_LEN)

    download_is_unused = shown is not image and prepared is not image
    if download_is_unused:
        image.close()

    return "", {}, "None", shown, shown, prepared
|
| 220 |
+
|
| 221 |
+
def image_changed(image: Image.Image, threshold: float, cam_depth: int):
    """Classify ``image`` and return the thresholded and raw results.

    Returns:
        ``(tag_str, filtered_predictions, features, predictions)``.
    """
    features, predictions = run_classifier(image, cam_depth)
    tag_str, filtered = filter_tags(predictions, threshold)
    return tag_str, filtered, features, predictions
|
| 224 |
+
|
| 225 |
+
def image_clear():
    """Return the reset values for every output wired to ``image.clear``."""
    widgets = ("", {}, "None", "")  # tag_string, tag_box, cam_tag, url
    pictures = (None, None)         # image, display_image_state
    cached = (None, None, {})       # image_state, features_state, predictions_state
    return widgets + pictures + cached
|
| 231 |
+
|
| 232 |
+
def cam_changed(
    display_image: Image.Image,
    image: Image.Image, features: dict[str, Tensor],
    tag: str, cam_depth: int
):
    """Refresh the CAM overlay after the tag or depth selection changes.

    Returns:
        ``(image_for_widget, features)``. With tag "None", or when no image
        has been classified yet, the plain display image and the unchanged
        features are returned.
    """
    # The tag dropdown and depth slider are usable before any image is
    # loaded; in that case the states are still None and run_cam would fail
    # when indexing into `features`.
    if tag == "None" or image is None or features is None:
        return display_image, features

    return run_cam(display_image, image, features, tags[tag], cam_depth)
|
| 241 |
+
|
| 242 |
+
def tag_box_select(evt: gr.SelectData):
    """Return the value carried by a tag-box selection event."""
    selected = evt.value
    return selected
|
| 244 |
+
|
| 245 |
+
custom_css = """
|
| 246 |
+
.output-class { display: none; }
|
| 247 |
+
.inferno-slider input[type=range] {
|
| 248 |
+
background: linear-gradient(to right,
|
| 249 |
+
#000004, #1b0c41, #4a0c6b, #781c6d,
|
| 250 |
+
#a52c60, #cf4446, #ed6925, #fb9b06,
|
| 251 |
+
#f7d13d, #fcffa4
|
| 252 |
+
) !important;
|
| 253 |
+
background-size: 100% 100% !important;
|
| 254 |
+
}
|
| 255 |
+
#image_container-image {
|
| 256 |
+
width: 100%;
|
| 257 |
+
aspect-ratio: 1 / 1;
|
| 258 |
+
max-height: 100%;
|
| 259 |
+
}
|
| 260 |
+
#image_container img {
|
| 261 |
+
object-fit: contain !important;
|
| 262 |
+
}
|
| 263 |
+
.show-api, .show-api-divider {
|
| 264 |
+
display: none !important;
|
| 265 |
+
}
|
| 266 |
+
"""
|
| 267 |
+
|
| 268 |
+
with gr.Blocks(
|
| 269 |
+
title="RedRocket JTP-3 Hydra",
|
| 270 |
+
css=custom_css,
|
| 271 |
+
analytics_enabled=False,
|
| 272 |
+
) as demo:
|
| 273 |
+
display_image_state = gr.State()
|
| 274 |
+
image_state = gr.State()
|
| 275 |
+
features_state = gr.State()
|
| 276 |
+
predictions_state = gr.State(value={})
|
| 277 |
+
|
| 278 |
+
gr.HTML(
|
| 279 |
+
"<h1 style='display:flex; flex-flow: row nowrap; align-items: center;'>"
|
| 280 |
+
"<a href='https://huggingface.co/RedRocket' target='_blank'>"
|
| 281 |
+
"<img src='https://huggingface.co/spaces/RedRocket/README/resolve/main/RedRocket.png' style='width: 2em; margin-right: 0.5em;'>"
|
| 282 |
+
"</a>"
|
| 283 |
+
"<span>"
|
| 284 |
+
"<a href='https://huggingface.co/RedRocket' target='_blank'>RedRocket</a> – JTP-3 Hydra"
|
| 285 |
+
"</span>"
|
| 286 |
+
"</h1>"
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
with gr.Row():
|
| 290 |
+
with gr.Column():
|
| 291 |
+
with gr.Column():
|
| 292 |
+
image = gr.Image(
|
| 293 |
+
sources=['upload', 'clipboard'], type='pil',
|
| 294 |
+
show_label=False,
|
| 295 |
+
show_download_button=False,
|
| 296 |
+
show_share_button=False,
|
| 297 |
+
elem_id="image_container"
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
url = gr.Textbox(
|
| 301 |
+
label="Upload Image via Url:",
|
| 302 |
+
placeholder="https://example.com/image.jpg",
|
| 303 |
+
max_lines=1,
|
| 304 |
+
submit_btn="⮝",
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
with gr.Column():
|
| 308 |
+
cam_tag = gr.Dropdown(
|
| 309 |
+
value="None", choices=["None"] + tag_list,
|
| 310 |
+
label="CAM Attention Overlay (You can also click a tag on the right.)", show_label=True
|
| 311 |
+
)
|
| 312 |
+
cam_depth = gr.Slider(
|
| 313 |
+
minimum=1, maximum=27, step=1, value=1,
|
| 314 |
+
label="CAM Depth (1=fastest, more precise; 27=slowest, more general)"
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
with gr.Column():
|
| 318 |
+
threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Tag Threshold")
|
| 319 |
+
tag_string = gr.Textbox(lines=3, label="Tags", show_label=True, show_copy_button=True)
|
| 320 |
+
tag_box = gr.Label(num_top_classes=250, show_label=False, show_heading=False)
|
| 321 |
+
|
| 322 |
+
image.upload(
|
| 323 |
+
fn=image_upload,
|
| 324 |
+
inputs=[image],
|
| 325 |
+
outputs=[
|
| 326 |
+
tag_string, tag_box, cam_tag, url,
|
| 327 |
+
image, display_image_state,
|
| 328 |
+
image_state,
|
| 329 |
+
],
|
| 330 |
+
show_progress='minimal',
|
| 331 |
+
show_progress_on=[image]
|
| 332 |
+
).then(
|
| 333 |
+
fn=image_changed,
|
| 334 |
+
inputs=[image_state, threshold_slider, cam_depth],
|
| 335 |
+
outputs=[
|
| 336 |
+
tag_string, tag_box,
|
| 337 |
+
features_state, predictions_state,
|
| 338 |
+
],
|
| 339 |
+
show_progress='minimal',
|
| 340 |
+
show_progress_on=[tag_box]
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
url.submit(
|
| 344 |
+
fn=url_submit,
|
| 345 |
+
inputs=[url],
|
| 346 |
+
outputs=[
|
| 347 |
+
tag_string, tag_box, cam_tag,
|
| 348 |
+
image, display_image_state,
|
| 349 |
+
image_state,
|
| 350 |
+
],
|
| 351 |
+
show_progress='minimal',
|
| 352 |
+
show_progress_on=[url]
|
| 353 |
+
).then(
|
| 354 |
+
fn=image_changed,
|
| 355 |
+
inputs=[image_state, threshold_slider, cam_depth],
|
| 356 |
+
outputs=[
|
| 357 |
+
tag_string, tag_box,
|
| 358 |
+
features_state, predictions_state,
|
| 359 |
+
],
|
| 360 |
+
show_progress='minimal',
|
| 361 |
+
show_progress_on=[tag_box]
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
image.clear(
|
| 365 |
+
fn=image_clear,
|
| 366 |
+
inputs=[],
|
| 367 |
+
outputs=[
|
| 368 |
+
tag_string, tag_box, cam_tag, url,
|
| 369 |
+
image, display_image_state,
|
| 370 |
+
image_state, features_state, predictions_state,
|
| 371 |
+
],
|
| 372 |
+
show_progress='hidden'
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
threshold_slider.input(
|
| 376 |
+
fn=filter_tags,
|
| 377 |
+
inputs=[predictions_state, threshold_slider],
|
| 378 |
+
outputs=[tag_string, tag_box],
|
| 379 |
+
trigger_mode='always_last',
|
| 380 |
+
show_progress='hidden'
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
cam_tag.input(
|
| 384 |
+
fn=cam_changed,
|
| 385 |
+
inputs=[
|
| 386 |
+
display_image_state,
|
| 387 |
+
image_state, features_state,
|
| 388 |
+
cam_tag, cam_depth,
|
| 389 |
+
],
|
| 390 |
+
outputs=[image, features_state],
|
| 391 |
+
trigger_mode='always_last',
|
| 392 |
+
show_progress='minimal',
|
| 393 |
+
show_progress_on=[cam_tag]
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
cam_depth.input(
|
| 397 |
+
fn=cam_changed,
|
| 398 |
+
inputs=[
|
| 399 |
+
display_image_state,
|
| 400 |
+
image_state, features_state,
|
| 401 |
+
cam_tag, cam_depth,
|
| 402 |
+
],
|
| 403 |
+
outputs=[image, features_state],
|
| 404 |
+
trigger_mode='always_last',
|
| 405 |
+
show_progress='minimal',
|
| 406 |
+
show_progress_on=[cam_depth]
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
tag_box.select(
|
| 410 |
+
fn=tag_box_select,
|
| 411 |
+
inputs=[],
|
| 412 |
+
outputs=[cam_tag],
|
| 413 |
+
trigger_mode='always_last',
|
| 414 |
+
show_progress='hidden',
|
| 415 |
+
).then(
|
| 416 |
+
fn=cam_changed,
|
| 417 |
+
inputs=[
|
| 418 |
+
display_image_state,
|
| 419 |
+
image_state, features_state,
|
| 420 |
+
cam_tag, cam_depth,
|
| 421 |
+
],
|
| 422 |
+
outputs=[image, features_state],
|
| 423 |
+
show_progress='minimal',
|
| 424 |
+
show_progress_on=[cam_tag]
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
if __name__ == "__main__":
|
| 428 |
+
demo.launch()
|
data/hydra.jpg
ADDED
|
Git LFS Details
|
data/jtp-3-hydra-tags.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/jtp-3-hydra-val.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c7e9901ae3dec04942ed5ecb7dda0fbbf01afc4d27b1b5f80d509b664952ff77
|
| 3 |
+
size 42149079
|
glu.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod
|
| 2 |
+
from typing import Literal
|
| 3 |
+
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
from torch.nn import Module
|
| 6 |
+
from torch.nn.functional import silu, gelu
|
| 7 |
+
|
| 8 |
+
class GatedUnit(Module):
    """Base class for gated linear units.

    The input is split into two equal chunks along ``dim``; the activation is
    applied to the first chunk and the result is multiplied elementwise by the
    second. Subclasses supply the activation via ``_activation``.
    """

    def __init__(self, dim: int = -1) -> None:
        super().__init__()

        # Dimension along which the input is split in half.
        self.dim = dim

    @abstractmethod
    def _activation(self, x: Tensor) -> Tensor:
        """Apply the gate activation; must be overridden by subclasses."""
        # Module does not use ABCMeta, so @abstractmethod alone does not stop
        # the base class from being instantiated. Fail with a clear error
        # instead of returning None and breaking later in forward().
        raise NotImplementedError(
            f"{type(self).__name__} must implement _activation()"
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return ``activation(first_half) * second_half`` of ``x``."""
        f, g = x.chunk(2, dim=self.dim)
        return self._activation(f) * g
|
| 21 |
+
|
| 22 |
+
class SwiGLU(GatedUnit):
    """Gated unit whose gate activation is SiLU.

    The constructor is inherited unchanged from GatedUnit (``dim: int = -1``).
    """

    def _activation(self, x: Tensor) -> Tensor:
        """Apply SiLU to the gate half."""
        activated = silu(x)
        return activated
|
| 28 |
+
|
| 29 |
+
class GeGLU(GatedUnit):
    """Gated unit whose gate activation is GELU.

    Args:
        dim: Dimension along which the input is split in half.
        approximate: GELU approximation mode forwarded to
            ``torch.nn.functional.gelu`` (``"tanh"`` or ``"none"``).
    """

    def __init__(
        self,
        dim: int = -1,
        approximate: Literal["tanh", "none"] = "tanh"
    ) -> None:
        super().__init__(dim)

        self.approximate = approximate

    def _activation(self, x: Tensor) -> Tensor:
        """Apply GELU to the gate half."""
        # Pass `approximate` by keyword, as torch.nn.GELU does; positional
        # passing is rejected where the argument is keyword-only.
        return gelu(x, approximate=self.approximate)
|
hydra_pool.py
ADDED
|
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from math import sqrt
|
| 4 |
+
from typing import Any, Iterable, Self, cast
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch.nn import (
|
| 9 |
+
Module, ModuleList, Parameter, Buffer,
|
| 10 |
+
Linear, LayerNorm, RMSNorm, Dropout, Flatten,
|
| 11 |
+
init
|
| 12 |
+
)
|
| 13 |
+
from torch.nn.functional import pad, scaled_dot_product_attention
|
| 14 |
+
|
| 15 |
+
from einops import rearrange
|
| 16 |
+
|
| 17 |
+
from glu import SwiGLU
|
| 18 |
+
|
| 19 |
+
class IndexedAdd(Module):
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
n_indices: int,
|
| 23 |
+
dim: int,
|
| 24 |
+
weight_shape: tuple[int, ...] | None = None,
|
| 25 |
+
*,
|
| 26 |
+
inplace: bool = False,
|
| 27 |
+
device: torch.device | str | None = None,
|
| 28 |
+
dtype: torch.dtype | None = None,
|
| 29 |
+
) -> None:
|
| 30 |
+
super().__init__()
|
| 31 |
+
|
| 32 |
+
self.dim = dim
|
| 33 |
+
self.inplace = inplace
|
| 34 |
+
|
| 35 |
+
self.index = Buffer(torch.empty(
|
| 36 |
+
2, n_indices,
|
| 37 |
+
device=device, dtype=torch.int32
|
| 38 |
+
))
|
| 39 |
+
|
| 40 |
+
self.weight = Parameter(torch.ones(
|
| 41 |
+
*(sz if sz != -1 else n_indices for sz in weight_shape),
|
| 42 |
+
device=device, dtype=dtype
|
| 43 |
+
)) if weight_shape is not None else None
|
| 44 |
+
|
| 45 |
+
def _save_to_state_dict(
|
| 46 |
+
self,
|
| 47 |
+
destination: dict[str, Any],
|
| 48 |
+
prefix: str,
|
| 49 |
+
keep_vars: bool
|
| 50 |
+
) -> None:
|
| 51 |
+
super()._save_to_state_dict(destination, prefix, keep_vars)
|
| 52 |
+
|
| 53 |
+
if keep_vars:
|
| 54 |
+
return
|
| 55 |
+
|
| 56 |
+
with torch.no_grad():
|
| 57 |
+
index_key = f"{prefix}index"
|
| 58 |
+
index = destination[index_key]
|
| 59 |
+
|
| 60 |
+
min_index = index.amin(None).item()
|
| 61 |
+
if min_index >= 0:
|
| 62 |
+
max_index = index.amax(None).item()
|
| 63 |
+
if max_index < (1 << 8):
|
| 64 |
+
destination[index_key] = index.to(dtype=torch.uint8)
|
| 65 |
+
elif max_index < (1 << 16):
|
| 66 |
+
destination[index_key] = index.to(dtype=torch.uint16)
|
| 67 |
+
|
| 68 |
+
@torch.no_grad()
|
| 69 |
+
def load_indices(self, indices: Iterable[tuple[int, int]], *, mean: bool = False) -> None:
|
| 70 |
+
if mean:
|
| 71 |
+
if self.weight is None:
|
| 72 |
+
raise ValueError("No weights to initialize with means.")
|
| 73 |
+
|
| 74 |
+
groups: dict[int, list[int]] = defaultdict(list)
|
| 75 |
+
|
| 76 |
+
idx = -1
|
| 77 |
+
for idx, (src, dst) in enumerate(indices):
|
| 78 |
+
self.index[0, idx] = src
|
| 79 |
+
self.index[1, idx] = dst
|
| 80 |
+
|
| 81 |
+
if mean:
|
| 82 |
+
groups[dst].append(idx)
|
| 83 |
+
|
| 84 |
+
if (idx + 1) != self.index.size(1):
|
| 85 |
+
raise IndexError(f"Expected {self.index.size(1)} indices, but got {idx + 1}.")
|
| 86 |
+
|
| 87 |
+
if not mean:
|
| 88 |
+
return
|
| 89 |
+
|
| 90 |
+
assert self.weight is not None
|
| 91 |
+
|
| 92 |
+
for idxs in groups.values():
|
| 93 |
+
if len(idxs) < 2:
|
| 94 |
+
continue
|
| 95 |
+
|
| 96 |
+
self.weight.index_fill_(
|
| 97 |
+
self.dim,
|
| 98 |
+
torch.tensor(idxs, device=self.weight.device, dtype=torch.int64),
|
| 99 |
+
1.0 / len(idxs)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
def forward(self, dst: Tensor, src: Tensor) -> Tensor:
|
| 103 |
+
src = src.index_select(self.dim, self.index[0])
|
| 104 |
+
|
| 105 |
+
if self.weight is not None:
|
| 106 |
+
src.mul_(self.weight)
|
| 107 |
+
|
| 108 |
+
return (
|
| 109 |
+
dst.index_add_(self.dim, self.index[1], src)
|
| 110 |
+
if self.inplace else
|
| 111 |
+
dst.index_add(self.dim, self.index[1], src)
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
class BatchLinear(Module):
|
| 115 |
+
def __init__(
|
| 116 |
+
self,
|
| 117 |
+
batch_shape: tuple[int, ...] | int,
|
| 118 |
+
in_features: int,
|
| 119 |
+
out_features: int,
|
| 120 |
+
*,
|
| 121 |
+
bias: bool = False,
|
| 122 |
+
flatten: bool = False,
|
| 123 |
+
bias_inplace: bool = True,
|
| 124 |
+
device: torch.device | str | None = None,
|
| 125 |
+
dtype: torch.dtype | None = None,
|
| 126 |
+
) -> None:
|
| 127 |
+
super().__init__()
|
| 128 |
+
|
| 129 |
+
if isinstance(batch_shape, int):
|
| 130 |
+
batch_shape = (batch_shape,)
|
| 131 |
+
elif not batch_shape:
|
| 132 |
+
raise ValueError("At least one batch dimension is required.")
|
| 133 |
+
|
| 134 |
+
self.flatten = -(len(batch_shape) + 1) if flatten else 0
|
| 135 |
+
|
| 136 |
+
self.weight = Parameter(torch.empty(
|
| 137 |
+
*batch_shape, in_features, out_features,
|
| 138 |
+
device=device, dtype=dtype
|
| 139 |
+
))
|
| 140 |
+
|
| 141 |
+
bt = self.weight.flatten(end_dim=-3).mT
|
| 142 |
+
for idx in range(bt.size(0)):
|
| 143 |
+
init.kaiming_uniform_(bt[idx], a=sqrt(5))
|
| 144 |
+
|
| 145 |
+
self.bias = Parameter(torch.zeros(
|
| 146 |
+
*batch_shape, out_features,
|
| 147 |
+
device=device, dtype=dtype
|
| 148 |
+
)) if bias else None
|
| 149 |
+
|
| 150 |
+
self.bias_inplace = bias_inplace
|
| 151 |
+
|
| 152 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 153 |
+
# ... B... 1 I @ B... I O -> ... B... O
|
| 154 |
+
x = torch.matmul(x.unsqueeze(-2), self.weight).squeeze(-2)
|
| 155 |
+
|
| 156 |
+
if self.bias is not None:
|
| 157 |
+
if self.bias_inplace:
|
| 158 |
+
x.add_(self.bias)
|
| 159 |
+
else:
|
| 160 |
+
x = x + self.bias
|
| 161 |
+
|
| 162 |
+
if self.flatten:
|
| 163 |
+
x = x.flatten(self.flatten)
|
| 164 |
+
|
| 165 |
+
return x
|
| 166 |
+
|
| 167 |
+
class Mean(Module):
|
| 168 |
+
def __init__(self, dim: tuple[int, ...] | int = -1, *, keepdim: bool = False) -> None:
|
| 169 |
+
super().__init__()
|
| 170 |
+
|
| 171 |
+
self.dim = dim
|
| 172 |
+
self.keepdim = keepdim
|
| 173 |
+
|
| 174 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 175 |
+
return x.mean(self.dim, self.keepdim)
|
| 176 |
+
|
| 177 |
+
class _MidBlock(Module):
    """Refinement block: cross-attention against shared keys/values, then a feedforward.

    The block projects its input into queries, adds a learned per-class query
    offset (``q_cls``), attends over the ``k``/``v`` tensors supplied by the
    caller, and applies a residual gated feedforward.
    """

    def __init__(
        self,
        attn_dim: int,
        head_dim: int,
        n_classes: int,
        *,
        ff_ratio: float,
        ff_dropout: float,
        q_cls_inplace: bool = True,
        device: torch.device | str | None,
        dtype: torch.dtype | None,
    ) -> None:
        super().__init__()

        self.head_dim = head_dim
        self.q_cls_inplace = q_cls_inplace

        hidden_dim = int(attn_dim * ff_ratio)

        self.q_proj = Linear(
            attn_dim, attn_dim, bias=False,
            device=device, dtype=dtype
        )

        # Per-class additive offset applied to the projected queries.
        self.q_cls = Parameter(torch.zeros(
            n_classes, attn_dim,
            device=device, dtype=dtype
        ))

        # Normalises over a single head (head_dim), so it must run after the head split.
        self.q_norm = RMSNorm(head_dim, eps=1e-5, elementwise_affine=False)

        self.attn_out = Linear(
            attn_dim, attn_dim, bias=False,
            device=device, dtype=dtype
        )

        self.ff_norm = LayerNorm(
            attn_dim,
            device=device, dtype=dtype
        )
        # Doubled width: ff_act is expected to consume two halves — TODO confirm against glu.SwiGLU.
        self.ff_in = Linear(
            attn_dim, hidden_dim * 2, bias=False,
            device=device, dtype=dtype
        )
        self.ff_act = SwiGLU()
        self.ff_drop = Dropout(ff_dropout)
        self.ff_out = Linear(
            hidden_dim, attn_dim, bias=False,
            device=device, dtype=dtype
        )

    def _forward_q(self, x: Tensor) -> Tensor:
        """Build per-head, RMS-normalised queries from ``x`` of shape ``(..., s, attn_dim)``."""
        x = self.q_proj(x)

        if self.q_cls_inplace:
            x.add_(self.q_cls)
        else:
            x = x + self.q_cls

        # Split into heads first: q_norm is defined over head_dim, not attn_dim, so
        # normalising before the split only works when there is a single head.
        x = rearrange(x, "... s (h e) -> ... h s e", e=self.head_dim)
        x = self.q_norm(x)
        return x

    def _forward_attn(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None) -> Tensor:
        """Attend from ``x`` over ``k``/``v`` and add the result to ``x``."""
        a = scaled_dot_product_attention(
            self._forward_q(x), k, v,
            attn_mask=attn_mask
        )
        a = rearrange(a, "... h s e -> ... s (h e)")
        a = self.attn_out(a)
        return x + a

    def _forward_ff(self, x: Tensor) -> Tensor:
        """Residual feedforward: norm -> ff_in -> activation -> dropout -> ff_out."""
        f = self.ff_norm(x)
        f = self.ff_in(f)
        f = self.ff_act(f)
        f = self.ff_drop(f)
        f = self.ff_out(f)
        return x + f

    def forward(self, x: Tensor, k: Tensor, v: Tensor, attn_mask: Tensor | None = None) -> Tensor:
        """Run the attention sub-block followed by the feedforward sub-block."""
        x = self._forward_attn(x, k, v, attn_mask)
        x = self._forward_ff(x)
        return x
|
| 262 |
+
|
| 263 |
+
class HydraPool(Module):
|
| 264 |
+
def __init__(
|
| 265 |
+
self,
|
| 266 |
+
attn_dim: int,
|
| 267 |
+
head_dim: int,
|
| 268 |
+
n_classes: int,
|
| 269 |
+
*,
|
| 270 |
+
mid_blocks: int = 0,
|
| 271 |
+
roots: tuple[int, int, int] = (0, 0, 0),
|
| 272 |
+
ff_ratio: float = 3.0,
|
| 273 |
+
ff_dropout: float = 0.0,
|
| 274 |
+
input_dim: int = -1,
|
| 275 |
+
output_dim: int = 1,
|
| 276 |
+
device: torch.device | str | None = None,
|
| 277 |
+
dtype: torch.dtype | None = None,
|
| 278 |
+
) -> None:
|
| 279 |
+
super().__init__()
|
| 280 |
+
|
| 281 |
+
if input_dim < 0:
|
| 282 |
+
input_dim = attn_dim
|
| 283 |
+
|
| 284 |
+
assert attn_dim % head_dim == 0
|
| 285 |
+
n_heads = attn_dim // head_dim
|
| 286 |
+
|
| 287 |
+
self.n_classes = n_classes
|
| 288 |
+
self.head_dim = head_dim
|
| 289 |
+
self.output_dim = output_dim
|
| 290 |
+
|
| 291 |
+
self._has_roots = False
|
| 292 |
+
self._has_ff = False
|
| 293 |
+
|
| 294 |
+
self.q: Parameter | Buffer
|
| 295 |
+
self._q_normed: bool | None
|
| 296 |
+
|
| 297 |
+
if roots != (0, 0, 0):
|
| 298 |
+
self._has_roots = True
|
| 299 |
+
n_roots, n_classroots, n_subclasses = roots
|
| 300 |
+
|
| 301 |
+
if n_classroots < n_roots:
|
| 302 |
+
raise ValueError("Number of classroots cannot be less than the number of roots.")
|
| 303 |
+
|
| 304 |
+
self.cls = Parameter(torch.randn(
|
| 305 |
+
n_heads, n_classes, head_dim,
|
| 306 |
+
device=device, dtype=dtype
|
| 307 |
+
))
|
| 308 |
+
|
| 309 |
+
self.roots = Parameter(torch.randn(
|
| 310 |
+
n_heads, n_roots, head_dim,
|
| 311 |
+
device=device, dtype=dtype
|
| 312 |
+
)) if n_roots > 0 else None
|
| 313 |
+
|
| 314 |
+
self.clsroots = IndexedAdd(
|
| 315 |
+
n_classroots, dim=-2, weight_shape=(n_heads, -1, 1),
|
| 316 |
+
device=device, dtype=dtype
|
| 317 |
+
) if n_classroots > 0 else None
|
| 318 |
+
|
| 319 |
+
self.clscls = IndexedAdd(
|
| 320 |
+
n_subclasses, dim=-2, weight_shape=(n_heads, -1, 1),
|
| 321 |
+
inplace=True, device=device, dtype=dtype
|
| 322 |
+
) if n_subclasses > 0 else None
|
| 323 |
+
|
| 324 |
+
self.q = Buffer(torch.empty(
|
| 325 |
+
n_heads, n_classes, head_dim,
|
| 326 |
+
device=device, dtype=dtype
|
| 327 |
+
))
|
| 328 |
+
self._q_normed = None
|
| 329 |
+
else:
|
| 330 |
+
self.q = Parameter(torch.randn(
|
| 331 |
+
n_heads, n_classes, head_dim,
|
| 332 |
+
device=device, dtype=dtype
|
| 333 |
+
))
|
| 334 |
+
self._q_normed = False
|
| 335 |
+
|
| 336 |
+
self.kv = Linear(
|
| 337 |
+
input_dim, attn_dim * 2, bias=False,
|
| 338 |
+
device=device, dtype=dtype
|
| 339 |
+
)
|
| 340 |
+
self.qk_norm = RMSNorm(
|
| 341 |
+
head_dim, eps=1e-5, elementwise_affine=False
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
if ff_ratio > 0.0:
|
| 345 |
+
self._has_ff = True
|
| 346 |
+
hidden_dim = int(attn_dim * ff_ratio)
|
| 347 |
+
|
| 348 |
+
self.ff_norm = LayerNorm(
|
| 349 |
+
attn_dim,
|
| 350 |
+
device=device, dtype=dtype
|
| 351 |
+
)
|
| 352 |
+
self.ff_in = Linear(
|
| 353 |
+
attn_dim, hidden_dim * 2, bias=False,
|
| 354 |
+
device=device, dtype=dtype
|
| 355 |
+
)
|
| 356 |
+
self.ff_act = SwiGLU()
|
| 357 |
+
self.ff_drop = Dropout(ff_dropout)
|
| 358 |
+
self.ff_out = Linear(
|
| 359 |
+
hidden_dim, attn_dim, bias=False,
|
| 360 |
+
device=device, dtype=dtype
|
| 361 |
+
)
|
| 362 |
+
elif mid_blocks > 0:
|
| 363 |
+
raise ValueError("Feedforward required with mid blocks.")
|
| 364 |
+
|
| 365 |
+
self.mid_blocks = ModuleList(
|
| 366 |
+
_MidBlock(
|
| 367 |
+
attn_dim, head_dim, n_classes,
|
| 368 |
+
ff_ratio=ff_ratio, ff_dropout=ff_dropout,
|
| 369 |
+
device=device, dtype=dtype
|
| 370 |
+
) for _ in range(mid_blocks)
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
self.out_proj = BatchLinear(
|
| 374 |
+
n_classes, attn_dim, output_dim * 2,
|
| 375 |
+
device=device, dtype=dtype
|
| 376 |
+
)
|
| 377 |
+
self.out_act = SwiGLU()
|
| 378 |
+
|
| 379 |
+
@property
|
| 380 |
+
def has_roots(self) -> bool:
|
| 381 |
+
return self._has_roots
|
| 382 |
+
|
| 383 |
+
def get_extra_state(self) -> dict[str, Any]:
|
| 384 |
+
return { "q_normed": self._q_normed }
|
| 385 |
+
|
| 386 |
+
def set_extra_state(self, state: dict[str, Any]) -> None:
|
| 387 |
+
self._q_normed = state["q_normed"]
|
| 388 |
+
|
| 389 |
+
def create_head(self) -> Module:
|
| 390 |
+
if self.output_dim == 1:
|
| 391 |
+
return Flatten(-2)
|
| 392 |
+
|
| 393 |
+
return Mean(-1)
|
| 394 |
+
|
| 395 |
+
def train(self, mode: bool = True) -> Self:
|
| 396 |
+
super().train(mode)
|
| 397 |
+
|
| 398 |
+
if mode:
|
| 399 |
+
if self._has_roots:
|
| 400 |
+
self._q_normed = None
|
| 401 |
+
else:
|
| 402 |
+
self._q_normed = False
|
| 403 |
+
else:
|
| 404 |
+
if self._has_roots:
|
| 405 |
+
self._cache_query()
|
| 406 |
+
|
| 407 |
+
return self
|
| 408 |
+
|
| 409 |
+
def inference(self) -> Self:
|
| 410 |
+
super().train(False)
|
| 411 |
+
self._cache_query()
|
| 412 |
+
|
| 413 |
+
if self._has_roots:
|
| 414 |
+
self._has_roots = False
|
| 415 |
+
self.q = Parameter(self.q)
|
| 416 |
+
|
| 417 |
+
del self.cls, self.roots, self.clsroots, self.clscls
|
| 418 |
+
|
| 419 |
+
return self
|
| 420 |
+
|
| 421 |
+
def _cache_query(self) -> None:
|
| 422 |
+
assert not self.training
|
| 423 |
+
|
| 424 |
+
if self._q_normed:
|
| 425 |
+
return
|
| 426 |
+
|
| 427 |
+
with torch.no_grad():
|
| 428 |
+
self.q.to(device=self.kv.weight.device)
|
| 429 |
+
self.q.copy_(self._forward_q())
|
| 430 |
+
self._q_normed = True
|
| 431 |
+
|
| 432 |
+
def _forward_q(self) -> Tensor:
|
| 433 |
+
match self._q_normed:
|
| 434 |
+
case None:
|
| 435 |
+
assert self._has_roots
|
| 436 |
+
|
| 437 |
+
if self.roots is not None:
|
| 438 |
+
q = self.qk_norm(self.roots)
|
| 439 |
+
q = self.clsroots(self.cls, q)
|
| 440 |
+
else:
|
| 441 |
+
q = self.cls
|
| 442 |
+
|
| 443 |
+
if self.clscls is not None:
|
| 444 |
+
q = self.clscls(q, q.detach())
|
| 445 |
+
|
| 446 |
+
q = self.qk_norm(q)
|
| 447 |
+
return q
|
| 448 |
+
|
| 449 |
+
case False:
|
| 450 |
+
assert not self._has_roots
|
| 451 |
+
return self.qk_norm(self.q)
|
| 452 |
+
|
| 453 |
+
case True:
|
| 454 |
+
return self.q
|
| 455 |
+
|
| 456 |
+
def _forward_attn(self, x: Tensor, attn_mask: Tensor | None) -> tuple[Tensor, Tensor, Tensor]:
|
| 457 |
+
q = self._forward_q().expand(*x.shape[:-2], -1, -1, -1)
|
| 458 |
+
|
| 459 |
+
x = self.kv(x)
|
| 460 |
+
k, v = rearrange(x, "... s (n h e) -> n ... h s e", n=2, e=self.head_dim).unbind(0)
|
| 461 |
+
k = self.qk_norm(k)
|
| 462 |
+
|
| 463 |
+
x = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
| 464 |
+
return rearrange(x, "... h s e -> ... s (h e)"), k, v
|
| 465 |
+
|
| 466 |
+
def _forward_ff(self, x: Tensor) -> Tensor:
|
| 467 |
+
if not self._has_ff:
|
| 468 |
+
return x
|
| 469 |
+
|
| 470 |
+
f = self.ff_norm(x)
|
| 471 |
+
f = self.ff_in(f)
|
| 472 |
+
f = self.ff_act(f)
|
| 473 |
+
f = self.ff_drop(f)
|
| 474 |
+
f = self.ff_out(f)
|
| 475 |
+
return x + f
|
| 476 |
+
|
| 477 |
+
def _forward_out(self, x: Tensor) -> Tensor:
|
| 478 |
+
x = self.out_proj(x)
|
| 479 |
+
x = self.out_act(x)
|
| 480 |
+
return x
|
| 481 |
+
|
| 482 |
+
def forward(self, x: Tensor, attn_mask: Tensor | None = None) -> Tensor:
|
| 483 |
+
x, k, v = self._forward_attn(x, attn_mask)
|
| 484 |
+
x = self._forward_ff(x)
|
| 485 |
+
|
| 486 |
+
for block in self.mid_blocks:
|
| 487 |
+
x = block(x, k, v, attn_mask)
|
| 488 |
+
|
| 489 |
+
x = self._forward_out(x)
|
| 490 |
+
return x
|
| 491 |
+
|
| 492 |
+
def prune_roots(self, retain_classes: set[int]) -> tuple[list[int], list[int]]:
|
| 493 |
+
if not self._has_roots or self.roots is None:
|
| 494 |
+
raise TypeError("No roots to prune.")
|
| 495 |
+
|
| 496 |
+
if self.clscls is not None:
|
| 497 |
+
raise TypeError("Subclass roots cannot be pruned.")
|
| 498 |
+
|
| 499 |
+
used_roots: set[int] = set()
|
| 500 |
+
used_clsroots: list[int] = []
|
| 501 |
+
|
| 502 |
+
assert self.clsroots is not None
|
| 503 |
+
clsroots = [
|
| 504 |
+
cast(list[int], clsroot.tolist())
|
| 505 |
+
for clsroot in self.clsroots.index.cpu().unbind(1)
|
| 506 |
+
]
|
| 507 |
+
|
| 508 |
+
for idx, (src, dest) in enumerate(clsroots):
|
| 509 |
+
if dest in retain_classes:
|
| 510 |
+
used_roots.add(src)
|
| 511 |
+
used_clsroots.append(idx)
|
| 512 |
+
|
| 513 |
+
sorted_roots = sorted(used_roots)
|
| 514 |
+
del used_roots
|
| 515 |
+
|
| 516 |
+
rootmap = {
|
| 517 |
+
root: idx
|
| 518 |
+
for idx, root in enumerate(sorted_roots)
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
clsmap = {
|
| 522 |
+
cls: idx
|
| 523 |
+
for idx, cls in enumerate(sorted(retain_classes))
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
for idx in used_clsroots:
|
| 527 |
+
src, dest = clsroots[idx]
|
| 528 |
+
self.clsroots.index[0, idx] = rootmap[src]
|
| 529 |
+
self.clsroots.index[1, idx] = clsmap[dest]
|
| 530 |
+
|
| 531 |
+
return sorted_roots, used_clsroots
|
| 532 |
+
|
| 533 |
+
@staticmethod
|
| 534 |
+
def for_state(
|
| 535 |
+
state_dict: dict[str, Any],
|
| 536 |
+
prefix: str = "",
|
| 537 |
+
*,
|
| 538 |
+
ff_dropout: float = 0.0,
|
| 539 |
+
device: torch.device | str | None = None,
|
| 540 |
+
dtype: torch.dtype | None = None,
|
| 541 |
+
) -> "HydraPool":
|
| 542 |
+
n_heads, n_classes, head_dim = state_dict[f"{prefix}q"].shape
|
| 543 |
+
attn_dim = n_heads * head_dim
|
| 544 |
+
|
| 545 |
+
roots_t = state_dict.get(f"{prefix}roots")
|
| 546 |
+
clsroots_t = state_dict.get(f"{prefix}clsroots.index")
|
| 547 |
+
clscls_t = state_dict.get(f"{prefix}clscls.index")
|
| 548 |
+
roots = (
|
| 549 |
+
roots_t.size(1) if roots_t is not None else 0,
|
| 550 |
+
clsroots_t.size(1) if clsroots_t is not None else 0,
|
| 551 |
+
clscls_t.size(1) if clscls_t is not None else 0
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
input_dim = state_dict[f"{prefix}kv.weight"].size(1)
|
| 555 |
+
output_dim = state_dict[f"{prefix}out_proj.weight"].size(2) // 2
|
| 556 |
+
|
| 557 |
+
# avoid off-by-one issue due to truncation
|
| 558 |
+
ffout_t = state_dict.get(f"{prefix}ff_out.weight")
|
| 559 |
+
hidden_dim = ffout_t.size(1) + 0.5 if ffout_t is not None else 0
|
| 560 |
+
ff_ratio = hidden_dim / attn_dim
|
| 561 |
+
|
| 562 |
+
pattern = re.compile(rf"^{re.escape(prefix)}mid_blocks\.([0-9]+)\.")
|
| 563 |
+
mid_blocks = max([-1, *(
|
| 564 |
+
int(match[1])
|
| 565 |
+
for key in state_dict
|
| 566 |
+
if (match := pattern.match(key)) is not None
|
| 567 |
+
)]) + 1
|
| 568 |
+
|
| 569 |
+
return HydraPool(
|
| 570 |
+
attn_dim,
|
| 571 |
+
head_dim,
|
| 572 |
+
n_classes,
|
| 573 |
+
mid_blocks=mid_blocks,
|
| 574 |
+
roots=roots,
|
| 575 |
+
ff_ratio=ff_ratio,
|
| 576 |
+
ff_dropout=ff_dropout,
|
| 577 |
+
input_dim=input_dim,
|
| 578 |
+
output_dim=output_dim,
|
| 579 |
+
device=device,
|
| 580 |
+
dtype=dtype
|
| 581 |
+
)
|
image.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from io import BytesIO
|
| 2 |
+
from typing import Any, Callable, cast
|
| 3 |
+
from warnings import warn, catch_warnings, filterwarnings
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
|
| 10 |
+
import PIL.Image as image
|
| 11 |
+
import PIL.ImageCms as image_cms
|
| 12 |
+
|
| 13 |
+
from PIL.Image import Image, Resampling
|
| 14 |
+
from PIL.ImageCms import (
|
| 15 |
+
Direction, Intent, ImageCmsProfile, PyCMSError,
|
| 16 |
+
createProfile, getDefaultIntent, isIntentSupported, profileToProfile
|
| 17 |
+
)
|
| 18 |
+
from PIL.ImageOps import exif_transpose
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
import pillow_jxl
|
| 22 |
+
except ImportError:
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
image.MAX_IMAGE_PIXELS = None
|
| 26 |
+
|
| 27 |
+
_SRGB = createProfile(colorSpace='sRGB')
|
| 28 |
+
|
| 29 |
+
_INTENT_FLAGS = {
|
| 30 |
+
Intent.PERCEPTUAL: image_cms.FLAGS["HIGHRESPRECALC"],
|
| 31 |
+
Intent.RELATIVE_COLORIMETRIC: (
|
| 32 |
+
image_cms.FLAGS["HIGHRESPRECALC"] |
|
| 33 |
+
image_cms.FLAGS["BLACKPOINTCOMPENSATION"]
|
| 34 |
+
),
|
| 35 |
+
Intent.ABSOLUTE_COLORIMETRIC: image_cms.FLAGS["HIGHRESPRECALC"]
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
class CMSWarning(UserWarning):
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
message: str,
|
| 42 |
+
*,
|
| 43 |
+
path: str | None = None,
|
| 44 |
+
cms_info: dict[str, Any] | None = None,
|
| 45 |
+
cause: Exception | None = None,
|
| 46 |
+
):
|
| 47 |
+
super().__init__(message)
|
| 48 |
+
self.__cause__ = cause
|
| 49 |
+
|
| 50 |
+
self.path = path
|
| 51 |
+
self.cms_info = cms_info
|
| 52 |
+
|
| 53 |
+
self.add_note(f"path: {path}")
|
| 54 |
+
self.add_note(f"info: {cms_info}")
|
| 55 |
+
|
| 56 |
+
def _coalesce_intent(intent: Intent | int) -> Intent:
    """Convert a raw integer rendering intent (0-3) into an ``Intent`` member.

    ``Intent`` values are returned unchanged; any other value raises
    ``ValueError``.
    """
    if isinstance(intent, Intent):
        return intent

    numbered = (
        (0, Intent.PERCEPTUAL),
        (1, Intent.RELATIVE_COLORIMETRIC),
        (2, Intent.SATURATION),
        (3, Intent.ABSOLUTE_COLORIMETRIC),
    )

    for code, member in numbered:
        if intent == code:
            return member

    raise ValueError("invalid intent")
|
| 71 |
+
|
| 72 |
+
def _add_info(info: dict[str, Any], source: object, key: str) -> None:
|
| 73 |
+
try:
|
| 74 |
+
if (value := getattr(source, key, None)) is not None:
|
| 75 |
+
info[key] = value
|
| 76 |
+
except Exception:
|
| 77 |
+
pass
|
| 78 |
+
|
| 79 |
+
def open_srgb(
    path: str,
    *,
    resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None,
    crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None,
    expect: tuple[int, int] | None = None,
) -> Image:
    """Open the image at ``path`` and run it through ``process_srgb``.

    The decoded source image is closed unless it is itself the returned
    image, including when processing raises.
    """
    with open(path, "rb", buffering=(1024 * 1024)) as file:
        source: Image = image.open(file)

        try:
            result = process_srgb(source, resize=resize, crop=crop, expect=expect)
        except BaseException:
            # Release the decoder before propagating whatever went wrong.
            source.close()
            raise

        if result is not source:
            source.close()

        return result
|
| 99 |
+
|
| 100 |
+
def process_srgb(
    img: Image,
    *,
    resize: Callable[[tuple[int, int]], tuple[int, int] | None] | tuple[int, int] | None = None,
    crop: Callable[[tuple[int, int]], tuple[int, int, int, int] | None] | tuple[int, int, int, int] | None = None,
    expect: tuple[int, int] | None = None,
) -> Image:
    """Normalise an image to sRGB ``RGB``/``RGBa`` and optionally crop and resize it.

    Steps, in order: load pixel data, apply the EXIF orientation, check the
    expected size, convert through an embedded ICC profile to sRGB (best
    effort), convert to ``RGB`` (or ``RGBa`` when the image has transparency
    data), then crop and/or resize.

    Args:
        img: Source image; may be modified in place.
        resize: Target ``(width, height)``, or a callable that receives the
            (post-crop) size and returns the target size or ``None``.
        crop: Box ``(left, top, right, bottom)``, or a callable that receives
            the full image size and returns a box or ``None``.
        expect: If given, the required ``(width, height)`` after EXIF
            transposition.

    Returns:
        The processed image, which may or may not be the same object as ``img``.

    Raises:
        RuntimeError: the image size differs from ``expect``.
    """
    img.load()

    try:
        exif_transpose(img, in_place=True)
    except Exception:
        pass # corrupt EXIF metadata is fine

    size = (img.width, img.height)

    if expect is not None and size != expect:
        raise RuntimeError(
            f"Image is {size[0]}x{size[1]}, "
            f"but expected {expect[0]}x{expect[1]}."
        )

    if (icc_raw := img.info.get("icc_profile")) is not None:
        # NOTE(review): cms_info is collected but not reported anywhere in this function.
        cms_info: dict[str, Any] = {
            "native_mode": img.mode,
            "transparency": img.has_transparency_data,
        }

        try:
            profile = ImageCmsProfile(BytesIO(icc_raw))
            _add_info(cms_info, profile.profile, "profile_description")
            _add_info(cms_info, profile.profile, "target")
            _add_info(cms_info, profile.profile, "xcolor_space")
            _add_info(cms_info, profile.profile, "connection_space")
            _add_info(cms_info, profile.profile, "colorimetric_intent")
            _add_info(cms_info, profile.profile, "rendering_intent")

            # Bring exotic pixel layouts to a mode the colour transform is applied to.
            working_mode = img.mode
            if img.mode.startswith(("RGB", "BGR", "P")):
                working_mode = "RGBA" if img.has_transparency_data else "RGB"
            elif img.mode.startswith(("L", "I", "F")) or img.mode == "1":
                working_mode = "LA" if img.has_transparency_data else "L"

            if img.mode != working_mode:
                cms_info["working_mode"] = working_mode
                img = img.convert(working_mode)

            mode = "RGBA" if img.has_transparency_data else "RGB"

            # Prefer relative colorimetric; fall back to the profile's default intent.
            intent = Intent.RELATIVE_COLORIMETRIC
            if isIntentSupported(profile, intent, Direction.INPUT) != 1:
                intent = _coalesce_intent(getDefaultIntent(profile))

            cms_info["conversion_intent"] = intent

            if (flags := _INTENT_FLAGS.get(intent)) is None:
                raise RuntimeError("Unsupported intent")

            if img.mode == mode:
                profileToProfile(
                    img,
                    profile,
                    _SRGB,
                    renderingIntent=intent,
                    inPlace=True,
                    flags=flags
                )
            else:
                img = cast(Image, profileToProfile(
                    img,
                    profile,
                    _SRGB,
                    renderingIntent=intent,
                    outputMode=mode,
                    flags=flags
                ))
        except Exception:
            # Colour management is best effort: on any failure the pixels are
            # used as they are rather than rejecting the image.
            pass

    if img.has_transparency_data:
        if img.mode != "RGBa":
            try:
                img = img.convert("RGBa")
            except ValueError:
                # Not every mode converts to RGBa directly; go through RGBA.
                img = img.convert("RGBA").convert("RGBa")
    elif img.mode != "RGB":
        img = img.convert("RGB")

    if crop is not None and not isinstance(crop, tuple):
        crop = crop(size)

    if crop is not None:
        left, top, right, bottom = crop
        # Size of the cropped region as (width, height); bottom >= top in image coordinates.
        size = (right - left, bottom - top)

    if resize is not None and not isinstance(resize, tuple):
        resize = resize(size)

    if resize is not None and size != resize:
        # Crop and resize in a single resampling pass via the box argument.
        img = img.resize(
            resize,
            Resampling.LANCZOS,
            box=crop,
            reducing_gap=3.0
        )
        crop = None

    if crop is not None:
        img = img.crop(crop)

    return img
|
| 211 |
+
|
| 212 |
+
def put_srgb(img: Image, tensor: Tensor) -> None:
    """Copy the first three channels of an RGB-family image into ``tensor``.

    The copy is done with ``casting="no"``, so the tensor's dtype must match
    the image data exactly.

    Raises:
        ValueError: the image mode is not ``RGB``, ``RGBA`` or ``RGBa``.
    """
    if img.mode not in ("RGB", "RGBA", "RGBa"):
        raise ValueError(f"Image has non-RGB mode {img.mode}.")

    pixels = np.asarray(img)
    target = tensor.numpy()
    np.copyto(target, pixels[:, :, :3], casting="no")
|
| 217 |
+
|
| 218 |
+
def put_srgb_patch(
    img: Image,
    patch_data: Tensor,
    patch_coord: Tensor,
    patch_valid: Tensor,
    patch_size: int
) -> None:
    """Split an RGB-family image into square patches and write them into tensors.

    The first ``n`` rows of ``patch_data`` receive the flattened patches in
    row-major order, the first ``n`` rows of ``patch_coord`` receive each
    patch's ``(row, column)`` grid position as int16, and the first ``n``
    entries of ``patch_valid`` are set to True, where ``n`` is the number of
    patches.

    Raises:
        ValueError: the image mode is not ``RGB``, ``RGBA`` or ``RGBa``.
    """
    if img.mode not in ("RGB", "RGBA", "RGBa"):
        raise ValueError(f"Image has non-RGB mode {img.mode}.")

    rgb = np.asarray(img)[:, :, :3]
    grid = rearrange(
        rgb,
        "(h p1) (w p2) c -> h w (p1 p2 c)",
        p1=patch_size, p2=patch_size
    )

    rows = np.arange(grid.shape[0], dtype=np.int16)
    cols = np.arange(grid.shape[1], dtype=np.int16)
    positions = np.stack(np.meshgrid(rows, cols, indexing="ij"), axis=-1)

    flat_positions = rearrange(positions, "h w c -> (h w) c")
    flat_patches = rearrange(grid, "h w p -> (h w) p")
    count = flat_patches.shape[0]

    np.copyto(patch_data[:count].numpy(), flat_patches, casting="no")
    np.copyto(patch_coord[:count].numpy(), flat_positions, casting="no")
    patch_valid[:count] = True
|
| 247 |
+
|
| 248 |
+
def unpatchify(input: Tensor, coords: Tensor, valid: Tensor) -> Tensor:
    """
    Scatter valid patches from (seqlen, ...) to (H, W, ...), using coords and valid mask.

    Args:
        input: Tensor of shape (seqlen, ...), patch data.
        coords: Tensor of shape (seqlen, 2), spatial coordinates [y, x] for each patch.
            A leading batch dimension, (B, seqlen, 2), is also accepted; only the
            first batch entry is used.
        valid: Tensor of shape (seqlen,), boolean mask for valid patches.
            A leading batch dimension, (B, seqlen), is also accepted; only the
            first batch entry is used.

    Returns:
        Tensor of shape (H, W, ...), with valid patches scattered to their spatial locations.
        Positions without a valid patch are zero. H and W are one more than the
        largest valid y and x coordinate; with no valid patches the result has
        shape (0, 0, ...).
    """

    # Accept both the batched layout and the unbatched one described above.
    if coords.ndim == 3:
        coords = coords[0]
    if valid.ndim == 2:
        valid = valid[0]

    # Index tensors must be integer typed; coordinates may be stored more compactly.
    valid_coords = coords[valid].long()  # (n_valid, 2)
    valid_patches = input[valid]         # (n_valid, ...)

    if valid_coords.size(0) == 0:
        return input.new_zeros((0, 0) + input.shape[1:])

    h = int(valid_coords[:, 0].max().item()) + 1
    w = int(valid_coords[:, 1].max().item()) + 1

    output_shape = (h, w) + input.shape[1:]
    output = input.new_zeros(output_shape)

    output[valid_coords[:, 0], valid_coords[:, 1]] = valid_patches
    return output
|
inference.bat
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
IF NOT EXIST venv call .\install.bat
|
| 2 |
+
|
| 3 |
+
call venv\Scripts\activate.bat
|
| 4 |
+
python inference.py %*
|
inference.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import csv
|
| 3 |
+
import itertools
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
from typing import Any, Iterable
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
|
| 13 |
+
from timm.models import NaFlexVit
|
| 14 |
+
|
| 15 |
+
from loader import Loader
|
| 16 |
+
from model import load_model, load_image
|
| 17 |
+
|
| 18 |
+
# Patch size handed to `load_image` and `Loader` as their `patch_size` argument.
PATCH_SIZE = 16
|
| 19 |
+
|
| 20 |
+
def from_symmetric(threshold: float) -> float:
    """Map a value from the symmetric [-1, 1] scale onto the [0, 1] scale."""
    shifted = threshold + 1.0
    return shifted / 2.0
|
| 22 |
+
|
| 23 |
+
def to_symmetric(threshold: float) -> float:
    """Map a value from the [0, 1] scale onto the symmetric [-1, 1] scale."""
    centered = threshold - 0.5
    return centered * 2.0
|
| 25 |
+
|
| 26 |
+
def classify_output(output: Tensor, tags: list[str], threshold: float = 0.0) -> dict[str, float]:
    """Pair tags with their scores and keep those at or above ``threshold``.

    Args:
        output: Scores, converted with ``tolist()`` and paired with ``tags``
            positionally.
        tags: Tag names, in the same order as the entries of ``output``.
        threshold: Minimum score (inclusive) for a tag to be kept.

    Returns:
        A mapping from each kept tag to its score, in ``tags`` order.
    """
    selected: dict[str, float] = {}

    for tag, prob in zip(tags, output.tolist()):
        if prob >= threshold:
            selected[tag] = prob

    return selected
|
| 32 |
+
|
| 33 |
+
def _run_interactive(
    *,
    model: NaFlexVit, tags: list[str],
    seqlen: int, threshold: float,
    device: str
) -> None:
    """Run a prompt loop that classifies one image path per input line.

    Each line read from standard input is either a command (``q``/``quit``/
    ``exit``, ``h``/``help``/``?``/empty, ``threshold T``, ``seqlen N``) or a
    file path. Paths are loaded with ``load_image``, run through ``model``,
    and every tag whose score meets the current threshold is printed, highest
    score first.

    Args:
        model: Model called as ``model(patches, patch_coords, patch_valid)``.
        tags: Tag names, positionally matching the model outputs.
        seqlen: Initial sequence length passed to ``load_image``.
        threshold: Initial threshold on the symmetric -1.0 to 1.0 scale.
        device: Torch device the input tensors are moved to.
    """
    print(
        "\n"
        "JTP-3 Hydra Interactive Classifier\n"
        " Type 'q' to quit, or 'h' for help.\n"
        " For bulk operations, quit and run again with a path, or '-h' for help.\n"
    )

    while True:
        print("> ", end="")

        try:
            line = input().strip()
        except EOFError:
            # End of input (Ctrl-D / Ctrl-Z, or a closed pipe) is treated like
            # 'quit' rather than ending the session with a traceback.
            break

        if line in ("q", "quit", "exit"):
            break

        if line in ("", "h", "help", "?"):
            print(
                "Provide a file path to classify, or one of the following commands:\n"
                f" threshold T (-1.0 to 1.0, currently {threshold}, 0.2 to 0.8 recommended)\n"
                f" seqlen N (64 to 2048, currently {seqlen}, 1024 recommended)\n"
                " quit (or 'q', 'exit')"
            )
            continue

        if line.startswith("threshold "):
            try:
                parsed = float(line[10:])
            except Exception as ex:
                print(ex)
                continue

            if -1.0 <= parsed <= 1.0:
                threshold = parsed
            else:
                print("Threshold must be between -1.0 and 1.0.")

            continue

        if line.startswith("seqlen "):
            try:
                parsed = int(line[7:])
            except Exception as ex:
                print(ex)
                continue

            if 64 <= parsed <= 2048:
                seqlen = parsed
            else:
                print("Sequence length must be between 64 and 2048.")

            continue

        # Anything else is treated as a path; loading errors are reported and
        # the loop continues.
        try:
            p_t, pc_t, pv_t = load_image(line, PATCH_SIZE, seqlen, False)
        except Exception as ex:
            print(ex)
            continue

        # Add a leading batch dimension of one and move the inputs to the device.
        p_d = p_t.unsqueeze(0).to(device=device, non_blocking=True)
        pc_d = pc_t.unsqueeze(0).to(device=device, non_blocking=True)
        pv_d = pv_t.unsqueeze(0).to(device=device, non_blocking=True)

        # Rescale patch values with x / 127.5 - 1.0 in bfloat16.
        p_d = p_d.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
        pc_d = pc_d.to(dtype=torch.int32)

        o_d = model(p_d, pc_d, pv_d).float().sigmoid()
        del p_d, pc_d, pv_d

        # The threshold is kept on the symmetric scale for display; convert it
        # to the sigmoid's 0..1 scale for filtering and back for printing.
        classes = classify_output(o_d[0], tags, from_symmetric(threshold))
        for cls, prob in sorted(classes.items(), key=lambda item: (-item[1], item[0])):
            print(f" {to_symmetric(prob)*100:6.1f}% {cls}")

        del classes
        del o_d
        del p_t, pc_t, pv_t
|
| 113 |
+
|
| 114 |
+
def _run_batched(
    *,
    model: NaFlexVit, tags: list[str],
    paths: list[str], recursive: bool,
    threshold: float, writer: Any, prefix: str,
    batch_size: int, seqlen: int,
    n_workers: int, share_memory: bool,
    device: str,
) -> None:
    """Classify files and directories in batches and write the results.

    Directories in ``paths`` are expanded (into subdirectories too when
    ``recursive`` is set), skipping dot-entries, ``__pycache__`` and files with
    a ``.txt``/``.csv``/``.json``/``.py``/``.safetensors`` suffix. Other paths
    are used as given. Files that fail to load are reported on standard error
    and skipped.

    Args:
        model: Model called as ``model(patches, patch_coords, patch_valid)``.
        tags: Tag names, positionally matching the model outputs.
        paths: Files and directories to classify.
        recursive: Whether to descend into subdirectories.
        threshold: Minimum score for a tag to appear in a caption file.
        writer: CSV writer receiving one row of scores per file, or ``None``
            to write a ``.txt`` caption file next to each input instead.
        prefix: Text placed first in every caption; a tag equal to it is not
            repeated. Ignored when empty.
        batch_size: Number of paths submitted to the loader at a time.
        seqlen: Maximum sequence length passed to the loader.
        n_workers: Worker count passed to ``Loader``.
        share_memory: Passed to ``Loader``.
        device: Torch device the input tensors are moved to.
    """
    ignored_suffixes = (
        ".txt", ".csv", ".json",
        ".py", ".safetensors",
    )

    loader = Loader(
        n_workers,
        patch_size=PATCH_SIZE,
        max_seqlen=seqlen,
        share_memory=share_memory,
    )

    def walk(directory: str) -> Iterable[str]:
        """Yield classifiable file paths found under ``directory``."""
        for entry in os.scandir(directory):
            if entry.name.startswith(".") or entry.name == "__pycache__":
                continue

            if entry.is_file():
                if not entry.name.endswith(ignored_suffixes):
                    yield entry.path
            elif recursive and entry.is_dir():
                yield from walk(entry.path)

    def candidates() -> Iterable[str]:
        """Yield every path to classify, expanding directories."""
        for root in paths:
            if os.path.isdir(root):
                yield from walk(root)
            else:
                yield root

    for chunk in itertools.batched(candidates(), batch_size):
        good_paths: list[str] = []
        loaded: list[tuple[Tensor, Tensor, Tensor]] = []

        for path, result in loader.load(chunk).items():
            if isinstance(result, Exception):
                print(f"{repr(path)}: {result}", file=sys.stderr)
            else:
                good_paths.append(path)
                loaded.append(result)

        if not loaded:
            continue

        p_d = torch.stack([item[0] for item in loaded]).to(device=device, non_blocking=True)
        pc_d = torch.stack([item[1] for item in loaded]).to(device=device, non_blocking=True)
        pv_d = torch.stack([item[2] for item in loaded]).to(device=device, non_blocking=True)

        # Rescale patch values with x / 127.5 - 1.0 in bfloat16.
        p_d = p_d.to(dtype=torch.bfloat16).div_(127.5).sub_(1.0)
        pc_d = pc_d.to(dtype=torch.int32)

        o_d = model(p_d, pc_d, pv_d).float().sigmoid()
        del p_d, pc_d, pv_d

        for path, output in zip(good_paths, o_d.cpu()):
            if writer is not None:
                # CSV mode: one row per file with every score, unfiltered.
                writer.writerow((path, *(f"{prob.item():.4f}" for prob in output)))
                continue

            # Caption mode: tags above the threshold, shuffled, prefix first.
            caption_path = f"{os.path.splitext(path)[0]}.txt"
            with open(caption_path, "w", encoding="utf-8") as file:
                classes = list(classify_output(output, tags, threshold))
                random.shuffle(classes)

                if prefix:
                    if prefix in classes:
                        classes.remove(prefix)

                    classes.insert(0, prefix)

                file.write(', '.join(classes))

        del o_d

    loader.shutdown()
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@torch.inference_mode()
|
| 207 |
+
def main() -> None:
|
| 208 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 209 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 210 |
+
torch.backends.cudnn.benchmark = True
|
| 211 |
+
|
| 212 |
+
default_device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 213 |
+
|
| 214 |
+
parser = argparse.ArgumentParser(
|
| 215 |
+
description="JTP-3 Hydra",
|
| 216 |
+
epilog="By Project RedRocket. Visit https://huggingface.co/spaces/RedRocket/JTP-3 for more information."
|
| 217 |
+
)
|
| 218 |
+
parser.add_argument("--model", type=str, default="models/jtp-3-hydra.safetensors",
|
| 219 |
+
help="Path to model file.")
|
| 220 |
+
parser.add_argument("-b", "--batch", type=int, default=1,
|
| 221 |
+
help="Batch size.")
|
| 222 |
+
parser.add_argument("-w", "--workers", type=int, default=-1,
|
| 223 |
+
help="Number of dataloader workers. (Default: number of cores)")
|
| 224 |
+
parser.add_argument("--seqlen", type=int, default=1024,
|
| 225 |
+
help="NaFlex sequence length. (Default: 1024)")
|
| 226 |
+
parser.add_argument("-t", "--threshold", type=float, default=0.5,
|
| 227 |
+
help="Classification threshold. (-1.0 to 1.0)")
|
| 228 |
+
parser.add_argument("--no-shm", dest="shm", action="store_false",
|
| 229 |
+
help="Disable shared memory between workers.")
|
| 230 |
+
parser.add_argument("-d", "--device", type=str, default=default_device,
|
| 231 |
+
help=f"Torch device. (Default: {default_device})")
|
| 232 |
+
parser.add_argument("-r", "--recursive", action="store_true",
|
| 233 |
+
help="Classify directories recursively. (Dotfiles will be ignored.)")
|
| 234 |
+
parser.add_argument("-O", "--original-tags", action="store_true",
|
| 235 |
+
help="Do not rewrite tags for compatibility with diffusion models.")
|
| 236 |
+
parser.add_argument("-o", "--output", type=str,
|
| 237 |
+
help="Path for CSV output, or '-' for standard output. If not specified, individual .txt caption files are written.")
|
| 238 |
+
parser.add_argument("-p", "--prefix", type=str, default="",
|
| 239 |
+
help="Prefix all .txt caption files with the specified text. If the prefix matches a tag, the tag will not be repeated.")
|
| 240 |
+
parser.add_argument("paths", nargs="*",
|
| 241 |
+
help="Path to files and directories to classify. If none are specified, run interactively."
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
args = parser.parse_args()
|
| 245 |
+
|
| 246 |
+
if args.batch < 1:
|
| 247 |
+
parser.error("--batch must be at least 1")
|
| 248 |
+
if not 64 <= args.seqlen <= 2048:
|
| 249 |
+
parser.error("--seqlen must be between 64 and 2048 (1024 strongly recommended)")
|
| 250 |
+
if not -1.0 <= args.threshold <= 1.0:
|
| 251 |
+
parser.error("--threshold must be between -1.0 and 1.0")
|
| 252 |
+
|
| 253 |
+
print(f"Loading {repr(args.model)} ...", end="", file=sys.stderr)
|
| 254 |
+
model, tags = load_model(args.model, device=args.device)
|
| 255 |
+
print(f" {len(tags)} tags", file=sys.stderr)
|
| 256 |
+
|
| 257 |
+
def rewrite_tag(tag: str) -> str:
|
| 258 |
+
if not args.original_tags:
|
| 259 |
+
tag = tag.replace("vulva", "pussy")
|
| 260 |
+
|
| 261 |
+
if args.output is None: # caption files
|
| 262 |
+
tag = tag.replace("_", " ")
|
| 263 |
+
tag = tag.replace("(", r"\(")
|
| 264 |
+
tag = tag.replace(")", r"\)")
|
| 265 |
+
|
| 266 |
+
return tag
|
| 267 |
+
|
| 268 |
+
for idx in range(len(tags)):
|
| 269 |
+
tags[idx] = rewrite_tag(tags[idx])
|
| 270 |
+
|
| 271 |
+
if args.paths:
|
| 272 |
+
file: Any = None
|
| 273 |
+
writer: Any = None
|
| 274 |
+
|
| 275 |
+
match args.output:
|
| 276 |
+
case None:
|
| 277 |
+
pass
|
| 278 |
+
|
| 279 |
+
case "-":
|
| 280 |
+
writer = csv.writer(sys.stdout)
|
| 281 |
+
|
| 282 |
+
case _:
|
| 283 |
+
file = open(
|
| 284 |
+
args.file, "w",
|
| 285 |
+
buffering=(1024 * 1024),
|
| 286 |
+
encoding="utf-8",
|
| 287 |
+
newline="",
|
| 288 |
+
)
|
| 289 |
+
writer = csv.writer(file)
|
| 290 |
+
writer.writerow(("filename", *tags))
|
| 291 |
+
try:
|
| 292 |
+
_run_batched(
|
| 293 |
+
model=model,
|
| 294 |
+
tags=tags,
|
| 295 |
+
paths=args.paths,
|
| 296 |
+
recursive=args.recursive,
|
| 297 |
+
threshold=from_symmetric(args.threshold),
|
| 298 |
+
writer=writer, prefix=args.prefix,
|
| 299 |
+
batch_size=args.batch,
|
| 300 |
+
seqlen=args.seqlen,
|
| 301 |
+
n_workers=args.workers,
|
| 302 |
+
share_memory=args.shm,
|
| 303 |
+
device=args.device,
|
| 304 |
+
)
|
| 305 |
+
finally:
|
| 306 |
+
if file is not None:
|
| 307 |
+
file.close()
|
| 308 |
+
else:
|
| 309 |
+
_run_interactive(
|
| 310 |
+
model=model,
|
| 311 |
+
tags=tags,
|
| 312 |
+
seqlen=args.seqlen,
|
| 313 |
+
threshold=args.threshold,
|
| 314 |
+
device=args.device,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
# Start the command-line interface only when this file is run as a script.
if __name__ == "__main__":
    main()
|
install.bat
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
rem Create a virtual environment in the "venv" directory.
python -m venv venv

rem Activate it and install the packages listed in requirements.txt.
call venv\Scripts\activate.bat
pip install -r requirements.txt
|
loader.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os import environ, process_cpu_count
|
| 2 |
+
from typing import Iterable, Self
|
| 3 |
+
|
| 4 |
+
from threading import Thread
|
| 5 |
+
|
| 6 |
+
import multiprocessing
|
| 7 |
+
from multiprocessing.queues import SimpleQueue
|
| 8 |
+
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
from torch.multiprocessing.queue import SimpleQueue as TorchQueue
|
| 11 |
+
|
| 12 |
+
from model import load_image
|
| 13 |
+
|
| 14 |
+
class EnvScope:
|
| 15 |
+
__slots__ = ("env", "saved")
|
| 16 |
+
|
| 17 |
+
def __init__(self, env: dict[str, str | int | float | None]) -> None:
|
| 18 |
+
self.env = {
|
| 19 |
+
env: None if value is None else str(value)
|
| 20 |
+
for env, value in env.items()
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
self.saved: dict[str, str | None]
|
| 24 |
+
|
| 25 |
+
def __enter__(self) -> Self:
|
| 26 |
+
if hasattr(self, "saved"):
|
| 27 |
+
raise RuntimeError("EnvScope is already in use.")
|
| 28 |
+
|
| 29 |
+
self.saved = {}
|
| 30 |
+
for env, value in self.env.items():
|
| 31 |
+
self.saved[env] = environ.get(env, None)
|
| 32 |
+
|
| 33 |
+
if value is None:
|
| 34 |
+
del environ[env]
|
| 35 |
+
else:
|
| 36 |
+
environ[env] = value
|
| 37 |
+
|
| 38 |
+
return self
|
| 39 |
+
|
| 40 |
+
def __exit__(self, exc_type, exc_value, tb) -> None:
|
| 41 |
+
for env, value in self.saved.items():
|
| 42 |
+
if value is None:
|
| 43 |
+
del environ[env]
|
| 44 |
+
else:
|
| 45 |
+
environ[env] = value
|
| 46 |
+
|
| 47 |
+
del self.saved
|
| 48 |
+
|
| 49 |
+
class Loader:
    """Loads images into patch tensors, optionally in worker processes.

    With one or more workers, paths are sent to ``_worker_fn`` processes over
    a submission queue and results come back over a completion queue. With
    zero workers, images are loaded in the calling process.
    """

    def __init__(
        self, n_workers: int = -1, *,
        patch_size: int = 16, max_seqlen: int = 1024,
        share_memory: bool = True
    ) -> None:
        """Create the queues and start the worker processes.

        Args:
            n_workers: Number of worker processes. Negative values use
                ``process_cpu_count()`` (or 1 if that is unavailable); zero
                disables workers.
            patch_size: Patch size passed to ``load_image``.
            max_seqlen: Maximum sequence length passed to ``load_image``.
            share_memory: Passed to the workers and on to ``load_image``.
        """
        ctx = multiprocessing.get_context("spawn")

        self.patch_size = patch_size
        self.max_seqlen = max_seqlen

        if n_workers < 0:
            n_workers = process_cpu_count() or 1

        if n_workers == 0:
            # In-process mode: no queues are created.
            self._workers = []
            return

        self._submission_queue: SimpleQueue[str | None] = SimpleQueue(ctx=ctx)
        self._completion_queue: SimpleQueue[tuple[str, tuple[Tensor, Tensor, Tensor] | Exception] | None] = TorchQueue(ctx=ctx)
        self._workers = [
            ctx.Process(
                target=_worker_fn,
                args=(
                    self._submission_queue,
                    self._completion_queue,
                    patch_size,
                    max_seqlen,
                    share_memory,
                ),
                name=f"loader-{idx}",
                daemon=True
            )
            for idx in range(n_workers)
        ]


        # One starter thread per process; presumably so the process start-up
        # overlaps instead of running one after another.
        threads = [
            Thread(
                target=proc.start,
                name=f"pstart-{proc.name}",
                daemon=True,
            ) for proc in self._workers
        ]

        # The processes are started while these variables are overridden,
        # presumably so the children inherit single-threaded math libraries
        # and see no CUDA devices. The parent's values are restored afterwards.
        with EnvScope({
            "OMP_NUM_THREADS": 1,
            "OPENBLAS_NUM_THREADS": 1,
            "CUDA_VISIBLE_DEVICES": "",
        }):
            for thread in threads:
                thread.start()

            for thread in threads:
                thread.join()

    def load(self, paths: Iterable[str]) -> dict[str, tuple[Tensor, Tensor, Tensor] | Exception]:
        """Load every path and return the results keyed by path.

        Args:
            paths: Image paths to load.

        Returns:
            A dict mapping each path to either the three tensors returned by
            ``load_image`` or the exception raised while loading it. With
            workers, entries are inserted in completion order.
        """
        loaded: dict[str, tuple[Tensor, Tensor, Tensor] | Exception] = {}

        if self._workers:
            # Submit everything first, then collect exactly as many results.
            count = 0
            for path in paths:
                self._submission_queue.put(path)
                count += 1

            for _ in range(count):
                result = self._completion_queue.get()
                assert result is not None
                loaded[result[0]] = result[1]
        else:
            for path in paths:
                try:
                    loaded[path] = load_image(path, self.patch_size, self.max_seqlen, False)
                except Exception as ex:
                    loaded[path] = ex

        return loaded

    def shutdown(self, wait: bool = True) -> None:
        """Tell every worker to stop.

        Args:
            wait: When true, block until each worker has acknowledged by
                putting ``None`` on the completion queue.
        """
        # One None per worker; each worker stops after reading one.
        for _ in range(len(self._workers)):
            self._submission_queue.put(None)

        if wait:
            for _ in range(len(self._workers)):
                assert self._completion_queue.get() is None

        self._workers.clear()
|
| 136 |
+
|
| 137 |
+
def _worker_fn(
    submission_queue: SimpleQueue[str | None],
    completion_queue: SimpleQueue[tuple[str, tuple[Tensor, Tensor, Tensor] | Exception] | None],
    patch_size: int,
    max_seqlen: int,
    share_memory: bool,
):
    """Worker loop: load images for submitted paths until told to stop.

    Reads paths from ``submission_queue`` until a ``None`` arrives. For each
    path, a ``(path, result)`` pair is put on ``completion_queue``, where the
    result is the ``load_image`` output or the exception that was raised. A
    final ``None`` is put on ``completion_queue`` before returning.

    Args:
        submission_queue: Source of paths; ``None`` ends the loop.
        completion_queue: Destination for results and the final ``None``.
        patch_size: Passed to ``load_image``.
        max_seqlen: Passed to ``load_image``.
        share_memory: Passed to ``load_image``.
    """
    for path in iter(submission_queue.get, None):
        try:
            loaded = load_image(path, patch_size, max_seqlen, share_memory)
            completion_queue.put((path, loaded))
        except Exception as ex:
            completion_queue.put((path, ex))

    completion_queue.put(None)
|
model.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from math import ceil
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
from torch.nn import Identity
|
| 6 |
+
|
| 7 |
+
import timm
|
| 8 |
+
from timm.models import NaFlexVit
|
| 9 |
+
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
from safetensors import safe_open
|
| 13 |
+
|
| 14 |
+
from image import process_srgb, put_srgb_patch
|
| 15 |
+
|
| 16 |
+
def sdpa_attn_mask(
|
| 17 |
+
patch_valid: Tensor,
|
| 18 |
+
num_prefix_tokens: int = 0,
|
| 19 |
+
symmetric: bool = True,
|
| 20 |
+
q_len: int | None = None,
|
| 21 |
+
dtype: torch.dtype | None = None,
|
| 22 |
+
) -> Tensor:
|
| 23 |
+
mask = patch_valid.unflatten(-1, (1, 1, -1))
|
| 24 |
+
|
| 25 |
+
if num_prefix_tokens:
|
| 26 |
+
mask = torch.cat((
|
| 27 |
+
torch.ones(
|
| 28 |
+
*mask.shape[:-1], num_prefix_tokens,
|
| 29 |
+
device=patch_valid.device, dtype=torch.bool
|
| 30 |
+
), mask
|
| 31 |
+
), dim=-1)
|
| 32 |
+
|
| 33 |
+
return mask
|
| 34 |
+
|
| 35 |
+
# Replace the `create_attention_mask` attribute of timm's `naflexvit` module
# with `sdpa_attn_mask`, so lookups of that name in the module get this version.
timm.models.naflexvit.create_attention_mask = sdpa_attn_mask
|
| 36 |
+
|
| 37 |
+
def get_image_size_for_seq(
    image_hw: tuple[int, int],
    patch_size: int = 16,
    max_seq_len: int = 1024,
    max_ratio: float = 1.0,
    eps: float = 1e-5,
) -> tuple[int, int]:
    """Determine image size for sequence length constraint.

    Finds a scale factor, by bisection between ``eps`` and ``max_ratio``, at
    which the image covers at most ``max_seq_len`` patches, and returns the
    matching size as whole multiples of ``patch_size``.

    Args:
        image_hw: Original ``(height, width)``.
        patch_size: Edge length of a patch.
        max_seq_len: Maximum number of patches allowed.
        max_ratio: Largest scale factor considered; must be at least 1.0.
        eps: Smallest scale factor and the bisection tolerance.

    Returns:
        ``(height, width)`` of the target size, each a multiple of
        ``patch_size``.

    Raises:
        ValueError: If even the smallest scale exceeds ``max_seq_len``.
    """
    assert max_ratio >= 1.0
    assert eps * 2 < max_ratio

    height, width = image_hw

    # Patch-grid caps at the largest scale (never below one patch per axis).
    cap_rows = int(max((height * max_ratio) // patch_size, 1))
    cap_cols = int(max((width * max_ratio) // patch_size, 1))

    if (cap_rows * cap_cols) <= max_seq_len:
        return cap_rows * patch_size, cap_cols * patch_size

    def grid_at(scale: float) -> tuple[int, int]:
        """Return the (rows, cols) patch grid at ``scale``, clamped to the caps."""
        rows_needed = int(ceil((height * scale) / patch_size))
        cols_needed = int(ceil((width * scale) / patch_size))
        return min(rows_needed, cap_rows), min(cols_needed, cap_cols)

    rows, cols = grid_at(eps)
    if (rows * cols) > max_seq_len:
        raise ValueError(f"Image of size {width}x{height} is too large.")

    # Bisection: `low` always fits within max_seq_len, `high` does not.
    low, high = eps, max_ratio
    while (high - low) >= eps:
        mid = (low + high) / 2.0

        mid_rows, mid_cols = grid_at(mid)
        n_patches = mid_rows * mid_cols

        if n_patches > max_seq_len:
            high = mid
        else:
            low, rows, cols = mid, mid_rows, mid_cols

            # An exact fit cannot be improved on.
            if n_patches == max_seq_len:
                break

    assert rows >= 1 and cols >= 1
    return rows * patch_size, cols * patch_size
|
| 86 |
+
|
| 87 |
+
def process_image(img: Image.Image, patch_size: int, max_seq_len: int) -> Image.Image:
    """Run ``process_srgb`` on ``img`` with a sequence-length-aware resize.

    Args:
        img: Image to process.
        patch_size: Patch size used to compute the target size.
        max_seq_len: Maximum number of patches used to compute the target size.

    Returns:
        Whatever ``process_srgb`` returns for ``img``.
    """
    def compute_resize(wh: tuple[int, int]) -> tuple[int, int]:
        # The callback works in (width, height); the size helper works in
        # (height, width), so swap on the way in and on the way out.
        width = wh[0]
        height = wh[1]
        new_height, new_width = get_image_size_for_seq((height, width), patch_size, max_seq_len)
        return new_width, new_height

    return process_srgb(img, resize=compute_resize)
|
| 93 |
+
|
| 94 |
+
def patchify_image(img: Image.Image, patch_size: int, max_seq_len: int, share_memory: bool = False) -> tuple[Tensor, Tensor, Tensor]:
    """Allocate zeroed patch buffers and fill them from ``img``.

    Args:
        img: Image handed to ``put_srgb_patch``.
        patch_size: Patch size; each patch row holds
            ``patch_size * patch_size * 3`` values.
        max_seq_len: Number of rows allocated in each buffer.
        share_memory: Whether to move the buffers to shared memory before
            they are filled.

    Returns:
        ``(patches, patch_coords, patch_valid)``: uint8 of shape
        ``(max_seq_len, patch_size * patch_size * 3)``, int16 of shape
        ``(max_seq_len, 2)``, and bool of shape ``(max_seq_len,)``.
    """
    values_per_patch = patch_size * patch_size * 3

    patches = torch.zeros((max_seq_len, values_per_patch), device="cpu", dtype=torch.uint8)
    patch_coords = torch.zeros((max_seq_len, 2), device="cpu", dtype=torch.int16)
    patch_valid = torch.zeros((max_seq_len,), device="cpu", dtype=torch.bool)

    if share_memory:
        for buffer in (patches, patch_coords, patch_valid):
            buffer.share_memory_()

    put_srgb_patch(img, patches, patch_coords, patch_valid, patch_size)
    return patches, patch_coords, patch_valid
|
| 106 |
+
|
| 107 |
+
def load_image(
    path: str,
    patch_size: int = 16,
    max_seq_len: int = 1024,
    share_memory: bool = False
) -> tuple[Tensor, Tensor, Tensor]:
    """Open an image file and convert it to patch tensors.

    Args:
        path: Path of the image file.
        patch_size: Passed to ``process_image`` and ``patchify_image``.
        max_seq_len: Passed to ``process_image`` and ``patchify_image``.
        share_memory: Passed to ``patchify_image``.

    Returns:
        The ``(patches, patch_coords, patch_valid)`` tuple from
        ``patchify_image``.
    """
    with open(path, "rb", buffering=(1024 * 1024)) as file:
        img: Image.Image = Image.open(file)

        try:
            processed = process_image(img, patch_size, max_seq_len)
        except BaseException:
            # Release the image on any failure, then let the error propagate.
            img.close()
            raise
        else:
            # Processing may hand back the same object; only close the
            # original when a different image was returned.
            if processed is not img:
                img.close()

        return patchify_image(processed, patch_size, max_seq_len, share_memory)
|
| 126 |
+
|
| 127 |
+
def load_model(path: str, device: torch.device | str | None = None) -> tuple[NaFlexVit, list[str]]:
    """Load a classifier checkpoint and its tag list from a safetensors file.

    The architecture is read from the ``modelspec.architecture`` metadata
    entry and the tags from ``classifier.labels`` (one per line). The text
    after the ``naflexvit_so400m_patch16_siglip`` prefix selects how the
    pooling layer and head are set up before the weights are loaded.

    Args:
        path: Path of the safetensors file.
        device: Device the finished model is moved to.

    Returns:
        ``(model, tags)``: the model in eval mode and bfloat16, and the list
        of tag names.

    Raises:
        ValueError: If the architecture string is not recognized.
    """
    with safe_open(path, framework="pt", device="cpu") as file:
        metadata = file.metadata()

        # Read every tensor while the file is still open.
        state_dict = {
            key: file.get_tensor(key)
            for key in file.keys()
        }

    arch = metadata["modelspec.architecture"]
    if not arch.startswith("naflexvit_so400m_patch16_siglip"):
        raise ValueError(f"Unrecognized model architecture: {arch}")

    tags = metadata["classifier.labels"].split("\n")

    # Build the backbone on the CPU without a classifier and without weight
    # initialization; the weights come from state_dict below.
    model = timm.create_model(
        'naflexvit_so400m_patch16_siglip',
        pretrained=False, num_classes=0,
        pos_embed_interp_mode="bilinear",
        weight_init="skip", fix_init=False,
        device="cpu", dtype=torch.bfloat16,
    )

    # 31 is the length of the base architecture name checked above, so this
    # matches on the variant suffix only.
    match arch[31:]:
        case "": # vanilla
            model.reset_classifier(len(tags))

        case "+rr_slim":
            model.reset_classifier(len(tags))

            # Drop the parts that the checkpoint has no weights for.
            if "attn_pool.q.weight" not in state_dict:
                model.attn_pool.q = Identity()

            if "head.bias" not in state_dict:
                model.head.bias = None

        case "+rr_chonker":
            from chonker_pool import ChonkerPool

            model.attn_pool = ChonkerPool(
                2, 1152, 72,
                device=device, dtype=torch.bfloat16
            )
            model.head = model.attn_pool.create_head(len(tags))
            model.num_classes = len(tags)

        case "+rr_hydra":
            from hydra_pool import HydraPool

            model.attn_pool = HydraPool.for_state(
                state_dict, "attn_pool.",
                device=device, dtype=torch.bfloat16
            )
            model.head = model.attn_pool.create_head()
            model.num_classes = len(tags)

            # Supplied here rather than read from the file; presumably the
            # pool expects this entry under strict loading - TODO confirm.
            state_dict["attn_pool._extra_state"] = { "q_normed": True }

        case _:
            raise ValueError(f"Unrecognized model architecture: {arch}")

    model.eval().to(dtype=torch.bfloat16)
    model.load_state_dict(state_dict, strict=True)
    model.to(device=device)

    return model, tags
|
models/jtp-3-hydra.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1b91691be739ba07d5d7cfb74296ab19ab016d7eabad03542a0a34c90a3d9969
|
| 3 |
+
size 1002587984
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
timm
|
| 3 |
+
numpy
|
| 4 |
+
pillow
|
| 5 |
+
einops
|
| 6 |
+
safetensors
|
| 7 |
+
gradio
|
| 8 |
+
requests
|