#!/usr/bin/env python3
"""Quick inference helper for the anime style classifier."""
from __future__ import annotations

import argparse
import json
from pathlib import Path

import torch
from PIL import Image
from torchvision import models, transforms

# Fallback label names, indexed by class id, used when the config's
# ``id2label`` mapping has no entry for that id.
STYLE_DEFAULTS = ['flat', 'grim', 'modern', 'moe', 'painterly', 'retro']


def load_config(path: Path) -> dict:
    """Read the JSON config at *path* and return the parsed object.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit UTF-8 so non-ASCII label names decode the same way on every
    # platform instead of depending on the locale's default encoding.
    with path.open('r', encoding='utf-8') as fh:
        return json.load(fh)


def build_model(config: dict, weights_path: Path) -> torch.nn.Module:
    """Build an EfficientNet-B0 classifier and load its weights.

    The final classifier layer is replaced by a linear head with
    ``config['num_labels']`` outputs before the state dict stored at
    *weights_path* is loaded. The model is returned in eval mode.
    """
    model = models.efficientnet_b0(weights=None)
    in_features = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(in_features, config['num_labels'])

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point --model at checkpoints you trust; on torch
    # versions that support it, consider passing weights_only=True here.
    checkpoint = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(checkpoint)
    model.eval()
    return model


def build_transform(config: dict) -> transforms.Compose:
    """Return the preprocessing pipeline described by *config*.

    Reads ``config['image_size']``, ``config['mean']`` and ``config['std']``.
    """
    size = config['image_size']
    return transforms.Compose([
        # Resize to an exact size x size square; the CenterCrop that follows
        # therefore leaves the image dimensions unchanged.
        transforms.Resize((size, size)),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(config['mean'], config['std']),
    ])


def classify_image(
    model: torch.nn.Module,
    tf: transforms.Compose,
    image_path: Path,
    labels: list[str],
) -> list[tuple[str, float]]:
    """Classify one image and return ``(label, probability)`` pairs.

    The pairs follow the order of *labels* (class-id order); they are not
    sorted by probability. If *labels* and the model output differ in length,
    ``zip`` truncates to the shorter of the two.
    """
    # Close the underlying file handle deterministically once the pixels
    # have been converted into an independent RGB image.
    with Image.open(image_path) as raw:
        img = raw.convert('RGB')

    batch = tf(img).unsqueeze(0)  # add a leading batch dimension of 1
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)[0]
    return list(zip(labels, probs.tolist()))


def _resolve_labels(config: dict) -> list[str]:
    """Return one label name per class id in ``range(config['num_labels'])``.

    Lookup order for class ``i``: ``config['id2label'][str(i)]``, then
    ``STYLE_DEFAULTS[i]``, then the generic placeholder ``'class_<i>'``.
    The fallbacks are evaluated lazily, so a config with more classes than
    ``STYLE_DEFAULTS`` has entries no longer raises ``IndexError``.
    """
    id2label = config.get('id2label') or {}
    labels: list[str] = []
    for i in range(config['num_labels']):
        name = id2label.get(str(i))
        if name is None:
            name = STYLE_DEFAULTS[i] if i < len(STYLE_DEFAULTS) else f'class_{i}'
        labels.append(name)
    return labels


def main() -> None:
    """Parse CLI arguments, classify the image and print the top-k styles."""
    parser = argparse.ArgumentParser(description='Anime style classifier inference helper')
    parser.add_argument('image', type=Path, help='Path to an image file')
    parser.add_argument('--model', type=Path, default=Path('pytorch_model.bin'))
    parser.add_argument('--config', type=Path, default=Path('config.json'))
    parser.add_argument('--top-k', type=int, default=6)
    args = parser.parse_args()

    # A negative value would be used as a slice bound below and silently drop
    # entries from the END of the ranking instead of limiting the output.
    if args.top_k < 1:
        parser.error('--top-k must be a positive integer')

    config = load_config(args.config)
    labels = _resolve_labels(config)
    transform = build_transform(config)
    model = build_model(config, args.model)

    preds = classify_image(model, transform, args.image, labels)
    preds.sort(key=lambda item: item[1], reverse=True)

    print(f"Predictions for {args.image}:")
    for style, prob in preds[:args.top_k]:
        print(f" {style:<10s} {prob:.4f}")


if __name__ == '__main__':
    main()