Coupling Layers: 32
Coupling Types: 4:1 channel/spatial
Training Data: Downsampled ImageNet
Training Duration: 50 epochs
Performance: 3.79 bits per dim (validation)

Installation

The model in this repo requires the following pip package. See github.com/btrude/jet-pytorch for more details.

pip install jet-pytorch

Usage

from jet_pytorch import Jet

jet_config = dict(
    patch_size=4,
    patch_dim=48,
    n_patches=256,
    coupling_layers=32,
    block_depth=2,
    block_width=512,
    num_heads=8,
    scale_factor=2.0,
    coupling_types=(
        "channels", "channels",
        "channels", "channels",
        "spatial",
    ),
    spatial_coupling_projs=(
        "checkerboard", "checkerboard-inv",
        "vstripes", "vstripes-inv",
        "hstripes", "hstripes-inv",
    )
)
model = Jet(**jet_config)

This is the default Jet configuration, which matches the pretrained weights available on the Hugging Face Hub.
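For intuition, the default values follow directly from the 64x64 ImageNet input geometry. A minimal sanity check of the arithmetic (plain Python, nothing here is jet-pytorch API):

# With patch_size=4, a 64x64 RGB image splits into (64 // 4) ** 2 patches,
# each flattened to 4 * 4 * 3 values.
image_size, channels, patch_size = 64, 3, 4
n_patches = (image_size // patch_size) ** 2      # 16 * 16 = 256
patch_dim = patch_size * patch_size * channels   # 4 * 4 * 3 = 48
assert (n_patches, patch_dim) == (256, 48)

The 5-entry coupling_types pattern (four "channels", one "spatial") matches the 4:1 channel/spatial ratio quoted above; exactly how it tiles across the 32 coupling layers is a jet-pytorch internal.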

Download and/or load the pretrained ImageNet 64x64 weights

from jet_pytorch.util import get_pretrained

weights = get_pretrained()
model.load_state_dict(weights)
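With the weights loaded, the flow can score images. The sketch below is illustrative and assumption-heavy: it presumes the forward pass model(x) returns (z, logdet) for a batch x of patched images shaped (batch, n_patches, patch_dim), and applies the standard change-of-variables identity log p(x) = log p(z) + logdet, converted to bits per dim:

import math

import torch
from torch.distributions import Normal

x = torch.rand(16, 256, 48)  # placeholder batch of patched images
z, logdet = model(x)         # assumed forward signature, not confirmed API
log_pz = Normal(0, 1).log_prob(z).sum(dim=(1, 2))
nll = -(log_pz + logdet)     # negative log-likelihood in nats
bpd = nll / (256 * 48 * math.log(2))  # compare against the 3.79 quoted above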

Sample from a Jet

from torch.distributions import Normal

batch_size = 16
n_patches = 256
patch_dim = 48

# Draw latents from the standard normal prior, then invert the flow to
# map them back to image space.
pdf = Normal(0, 1)
z = pdf.sample((batch_size, n_patches, patch_dim))
img, logdet = model.inverse(z)
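To look at the samples, torchvision's save_image works once the tensor is in (batch, channels, height, width) layout. Whether inverse returns flattened patches or full images is repo-specific; the unpatchify below is an assumption about a (row, col, channel) ordering inside each 48-value patch:

from torchvision.utils import save_image

# Assumed layout: (batch, 16x16 patch grid, 4x4x3 patch values).
grid = img.view(16, 16, 16, 4, 4, 3).permute(0, 5, 1, 3, 2, 4)
save_image(grid.reshape(16, 3, 64, 64).clamp(0, 1), "samples.png", nrow=4)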

Training a Jet

from jet_pytorch.train import train

jet_config = dict(...)  # your Jet configuration (see Usage above)
train(
    jet_config=jet_config,
    batch_size=64,
    accumulate_steps=16,
    device="cuda:0",
    epochs=50,
    warmup_percentage=0.1,
    max_grad_norm=1.0,
    learning_rate=3e-4,
    weight_decay=1e-5,
    adam_betas=(0.9, 0.95),
    images_path_train="/path/to/train/images",
    images_path_valid="/path/to/validation/images",
    num_workers=8,
    checkpoint_path="jet.pt",
)

The training code favors gradient accumulation over alternatives like gradient checkpointing, which allows the script to run on GPUs with less than 8GB of VRAM. The effective batch size is therefore batch_size * accumulate_steps. Note that the default configuration assumes at least 24GB of VRAM.
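For example, halving batch_size while doubling accumulate_steps trades wall-clock time for memory without changing the optimization. A sketch of the arithmetic (the low-VRAM values are an assumption, not a tested configuration):

effective_batch = 64 * 16  # default: 1024 samples per optimizer step
low_vram_batch = 16 * 64   # same 1024 effective, fewer activations resident
assert effective_batch == low_vram_batch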

Create visualizations

from jet_pytorch.sample import sample

# Creates visualization using the default Jet config/pretrained weights
sample("path/to/your/images")

# Creates visualization using default Jet config/a local checkpoint
sample(
    "path/to/your/images",
    checkpoint_path="path/to/your/checkpoint.pt",
)

# Creates visualization using custom Jet config/a local checkpoint
jet_config = dict(...)
sample(
    "path/to/your/images",
    jet_config=jet_config,
    checkpoint_path="path/to/your/checkpoint.pt",
)

Visualization results are stored at output/jet.png.
