|
|
|
|
|
"""Quick inference helper for the anime style classifier.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
from PIL import Image |
|
|
from torchvision import models, transforms |
|
|
|
|
|
|
|
|
# Fallback style names, indexed by class id; used when the config's
# ``id2label`` mapping has no entry for an index.
STYLE_DEFAULTS = [
    "flat",
    "grim",
    "modern",
    "moe",
    "painterly",
    "retro",
]
|
|
|
|
|
|
|
|
def load_config(path: Path) -> dict:
    """Read and decode the JSON configuration file at ``path``.

    Args:
        path: Location of the JSON config file.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Be explicit about the encoding: without it, ``open`` falls back to the
    # platform locale (e.g. cp1252 on Windows) and non-ASCII label names in
    # the config would be decoded incorrectly or raise.
    with path.open('r', encoding='utf-8') as fh:
        return json.load(fh)
|
|
|
|
|
|
|
|
def build_model(config: dict, weights_path: Path) -> torch.nn.Module:
    """Build an EfficientNet-B0 classifier and load trained weights into it.

    Args:
        config: Parsed configuration; ``config['num_labels']`` sets the
            output size of the replacement classification head.
        weights_path: File containing the model's ``state_dict``.

    Returns:
        The model with weights loaded, switched to evaluation mode.

    Raises:
        KeyError: If ``config`` has no ``num_labels`` entry.
        RuntimeError: If the checkpoint keys/shapes do not match the model
            (raised by ``load_state_dict``).
    """
    # Start from an untrained backbone; every parameter comes from the
    # checkpoint loaded below.
    model = models.efficientnet_b0(weights=None)

    # Swap the stock classification layer for one sized to our label set.
    in_features = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(in_features, config['num_labels'])

    # SECURITY: ``torch.load`` is pickle-based and can execute arbitrary code
    # from an untrusted file. Only a plain state_dict is needed here, so
    # restrict unpickling to tensors and primitive containers.
    state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)

    model.eval()
    return model
|
|
|
|
|
|
|
|
def build_transform(config: dict) -> transforms.Compose:
    """Assemble the image preprocessing pipeline described by ``config``.

    Args:
        config: Parsed configuration; reads ``image_size``, ``mean`` and
            ``std``.

    Returns:
        A composed transform: resize, center crop, tensor conversion, then
        normalization.
    """
    side = config['image_size']
    steps = [
        transforms.Resize((side, side)),
        # NOTE(review): after an exact (side, side) resize this crop
        # presumably leaves the image unchanged; kept as in the original
        # pipeline -- confirm against the training preprocessing.
        transforms.CenterCrop(side),
        transforms.ToTensor(),
        transforms.Normalize(config['mean'], config['std']),
    ]
    return transforms.Compose(steps)
|
|
|
|
|
|
|
|
def classify_image(model: torch.nn.Module, tf: transforms.Compose, image_path: Path, labels: list[str]) -> list[tuple[str, float]]:
    """Run the classifier on one image and pair each label with its probability.

    Args:
        model: Classifier to call on the preprocessed batch.
        tf: Preprocessing callable applied to the RGB image; its result must
            support ``unsqueeze(0)``.
        image_path: Image file to classify.
        labels: Class names, in output-index order.

    Returns:
        ``(label, probability)`` pairs in label order (not sorted). ``zip``
        stops at the shorter of ``labels`` and the model's outputs.
    """
    # ``Image.open`` is lazy and holds the file open; the context manager
    # releases the handle deterministically. ``convert`` returns a separate,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')

    # Add the leading batch dimension the model is called with.
    batch = tf(rgb).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)[0]

    return list(zip(labels, probs.tolist()))
|
|
|
|
|
|
|
|
def _resolve_labels(config: dict) -> list[str]:
    """Return one display name per class index in ``range(config['num_labels'])``.

    Names come from ``config['id2label']`` (string keys) when present, then
    from ``STYLE_DEFAULTS``, and finally a generic ``class_<i>`` placeholder
    for indices beyond the built-in defaults.
    """
    # The mapping is optional: fall back to the defaults when it is missing.
    id2label = config.get('id2label') or {}
    labels = []
    for idx in range(config['num_labels']):
        key = str(idx)
        if key in id2label:
            name = id2label[key]
        elif idx < len(STYLE_DEFAULTS):
            # Only index the defaults when actually needed and in range, so a
            # config with more labels than defaults no longer raises.
            name = STYLE_DEFAULTS[idx]
        else:
            name = f'class_{idx}'
        labels.append(name)
    return labels


def main() -> None:
    """Command-line entry point: classify one image and print the top styles.

    Parses the image path plus the optional ``--model``, ``--config`` and
    ``--top-k`` arguments, loads the config and model, runs inference, and
    prints the ``top-k`` most probable styles in descending order.
    """
    parser = argparse.ArgumentParser(description='Anime style classifier inference helper')
    parser.add_argument('image', type=Path, help='Path to an image file')
    parser.add_argument('--model', type=Path, default=Path('pytorch_model.bin'))
    parser.add_argument('--config', type=Path, default=Path('config.json'))
    parser.add_argument('--top-k', type=int, default=6)
    args = parser.parse_args()

    # A negative value would make the slice below silently drop the *least*
    # likely entries instead of limiting the output; reject it up front.
    if args.top_k < 0:
        parser.error('--top-k must be non-negative')

    config = load_config(args.config)
    labels = _resolve_labels(config)
    transform = build_transform(config)
    model = build_model(config, args.model)

    preds = classify_image(model, transform, args.image, labels)
    # Most probable style first.
    preds.sort(key=lambda item: item[1], reverse=True)

    print(f"Predictions for {args.image}:")
    for style, prob in preds[:args.top_k]:
        print(f"  {style:<10s} {prob:.4f}")
|
|
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|