Mitchins's picture
Upload folder using huggingface_hub
eb9bf47 verified
#!/usr/bin/env python3
"""Quick inference helper for the anime style classifier."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import torch
from PIL import Image
from torchvision import models, transforms
# Fallback class names indexed by class id; main() uses the entry at index i
# when config['id2label'] has no name for class i.
STYLE_DEFAULTS = [
    'flat',
    'grim',
    'modern',
    'moe',
    'painterly',
    'retro',
]
def load_config(path: Path) -> dict:
    """Read and parse a JSON model configuration file.

    Args:
        path: Location of the JSON config (e.g. ``config.json``).

    Returns:
        The decoded JSON document.
    """
    # JSON is UTF-8 by spec (RFC 8259); pin the encoding so label names with
    # non-ASCII characters decode identically regardless of the platform locale.
    with path.open('r', encoding='utf-8') as fh:
        return json.load(fh)
def build_model(config: dict, weights_path: Path) -> torch.nn.Module:
    """Build an EfficientNet-B0 classifier and load fine-tuned weights.

    The torchvision backbone is created without pretrained weights, and its
    final linear layer (``classifier[1]``) is replaced by one that produces
    ``config['num_labels']`` outputs before the checkpoint is applied.

    Args:
        config: Parsed model config; must contain ``'num_labels'``.
        weights_path: Path to a checkpoint containing a state dict for the
            modified model.

    Returns:
        The model with weights loaded (tensors mapped to CPU), in eval mode.
    """
    model = models.efficientnet_b0(weights=None)
    in_features = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(in_features, config['num_labels'])
    try:
        # weights_only=True restricts unpickling to tensors and plain containers,
        # so a tampered checkpoint cannot execute arbitrary code when loaded.
        state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
    except TypeError:
        # torch < 1.13 does not accept weights_only; fall back to a plain load.
        state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model
def build_transform(config: dict) -> transforms.Compose:
    """Return the preprocessing pipeline applied to every input image.

    The image is resized to a ``config['image_size']`` square, converted to a
    tensor and normalised with ``config['mean']`` / ``config['std']``.
    """
    side = config['image_size']
    steps = [
        transforms.Resize((side, side)),
        # The resize above already yields a side x side image, so this crop
        # leaves it unchanged.
        transforms.CenterCrop(side),
        transforms.ToTensor(),
        transforms.Normalize(config['mean'], config['std']),
    ]
    return transforms.Compose(steps)
def classify_image(model: torch.nn.Module, tf: transforms.Compose, image_path: Path, labels: list[str]) -> list[tuple[str, float]]:
    """Classify a single image and return per-class probabilities.

    Args:
        model: Classifier expected to return a (batch, num_classes) score
            tensor; softmax is taken over dim 1.
        tf: Preprocessing transform mapping an RGB PIL image to a tensor.
        image_path: Path of the image to classify.
        labels: Class names, in the same order as the model's outputs.

    Returns:
        ``(label, probability)`` pairs in model-output order. Like ``zip``,
        the result is truncated to the shorter of ``labels`` and the outputs.
    """
    # Image.open keeps the underlying file open; convert() loads the pixel
    # data into a new image, so the handle can be closed right after.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    batch = tf(rgb).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)[0]
    return list(zip(labels, probs.tolist()))
def _resolve_labels(config: dict) -> list[str]:
    """Return one display name per class id, with fallbacks for unnamed ids.

    Names come from ``config['id2label']`` (keyed by the class id as a
    string). Ids it does not cover use ``STYLE_DEFAULTS``; ids beyond that
    list get a generic ``label_<id>`` name instead of raising IndexError.
    """
    id2label = config.get('id2label') or {}
    names = []
    for idx in range(config['num_labels']):
        fallback = STYLE_DEFAULTS[idx] if idx < len(STYLE_DEFAULTS) else f'label_{idx}'
        names.append(id2label.get(str(idx), fallback))
    return names


def main() -> None:
    """Command-line entry point: classify one image and print the top-k styles."""
    parser = argparse.ArgumentParser(description='Anime style classifier inference helper')
    parser.add_argument('image', type=Path, help='Path to an image file')
    parser.add_argument('--model', type=Path, default=Path('pytorch_model.bin'))
    parser.add_argument('--config', type=Path, default=Path('config.json'))
    parser.add_argument('--top-k', type=int, default=6)
    args = parser.parse_args()
    if args.top_k < 1:
        # A zero or negative slice bound would print nothing or silently drop
        # the lowest-ranked entries instead of showing the top k.
        parser.error('--top-k must be a positive integer')

    config = load_config(args.config)
    labels = _resolve_labels(config)
    transform = build_transform(config)
    model = build_model(config, args.model)

    preds = classify_image(model, transform, args.image, labels)
    preds.sort(key=lambda item: item[1], reverse=True)  # most probable first
    print(f"Predictions for {args.image}:")
    for style, prob in preds[:args.top_k]:
        print(f" {style:<10s} {prob:.4f}")


if __name__ == '__main__':
    main()