| Property | Value |
| --- | --- |
| Coupling Layers | 32 |
| Coupling Types | 4:1 channel/spatial |
| Training Data | Downsampled ImageNet |
| Training Duration | 50 epochs |
| Performance | 3.79 bits per dim (validation) |
## Installation
The model in this repo requires the following pip package. See github.com/btrude/jet-pytorch for more details.
```bash
pip install jet-pytorch
```
## Usage
```python
from jet_pytorch import Jet

jet_config = dict(
    patch_size=4,    # side length of each square image patch
    patch_dim=48,    # flattened patch dimension: 4 * 4 * 3 (RGB)
    n_patches=256,   # patches per image: (64 / 4) ** 2 for 64x64 inputs
    coupling_layers=32,
    block_depth=2,
    block_width=512,
    num_heads=8,
    scale_factor=2.0,
    coupling_types=(  # 4:1 channel/spatial mix, per the table above
        "channels", "channels",
        "channels", "channels",
        "spatial",
    ),
    spatial_coupling_projs=(
        "checkerboard", "checkerboard-inv",
        "vstripes", "vstripes-inv",
        "hstripes", "hstripes-inv",
    ),
)
model = Jet(**jet_config)
```
This is the default Jet configuration, and it matches the pretrained weights available on the Hugging Face Hub.
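As a quick sanity check on these defaults (assuming the 64x64 RGB inputs used by the pretrained weights below), the patch parameters are mutually consistent:

```python
# Consistency check for the default config, assuming 64x64 RGB inputs
# (the resolution of the pretrained ImageNet weights below).
image_size, channels = 64, 3

assert (image_size // jet_config["patch_size"]) ** 2 == jet_config["n_patches"]  # 16**2 = 256
assert jet_config["patch_size"] ** 2 * channels == jet_config["patch_dim"]       # 4*4*3 = 48
```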
### Download and/or load pretrained ImageNet 64x64 weights
```python
from jet_pytorch.util import get_pretrained

weights = get_pretrained()
model.load_state_dict(weights)
```
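With weights loaded, the model can score images. Below is a minimal sketch of the standard change-of-variables likelihood computation for a normalizing flow, expressed in bits per dim (the metric reported in the table above). It assumes a forward call `model(x)` that returns `(z, logdet)` as the inverse of `model.inverse`, and a `(B, 3, 64, 64)` input layout; neither is confirmed by this README, so verify against the repo before relying on it.

```python
import math

import torch
from torch.distributions import Normal

# Hypothetical likelihood sketch. ASSUMPTION: model(x) returns (z, logdet),
# mirroring model.inverse below; check the actual signature in the repo.
x = torch.rand(2, 3, 64, 64)  # stand-in batch of 64x64 RGB images
with torch.no_grad():
    z, logdet = model(x)

log_pz = Normal(0.0, 1.0).log_prob(z).sum(dim=(1, 2))  # prior log-density per image
log_px = log_pz + logdet                               # change of variables
bits_per_dim = -log_px / (x[0].numel() * math.log(2))  # negative log2-likelihood per dim
print(bits_per_dim)
```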
### Sample from a Jet
```python
from torch.distributions import Normal

batch_size = 16
n_patches = 256  # must match the model's config
patch_dim = 48

# Draw latents from the standard-normal prior, then invert the flow
pdf = Normal(0.0, 1.0)
z = pdf.sample((batch_size, n_patches, patch_dim))
img, logdet = model.inverse(z)
```
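A common trick with normalizing flows (general practice, not something this README documents for `jet-pytorch`) is temperature sampling: scale the prior draws by a factor below 1 to trade sample diversity for fidelity.

```python
# Temperature sampling: a generic flow technique, not a jet-pytorch API.
# Shrinking latents toward the prior mean typically yields cleaner samples.
temperature = 0.8
z = pdf.sample((batch_size, n_patches, patch_dim)) * temperature
img, logdet = model.inverse(z)
```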
### Training a Jet
```python
from jet_pytorch.train import train

jet_config = dict(...)
train(
    jet_config=jet_config,
    batch_size=64,
    accumulate_steps=16,
    device="cuda:0",
    epochs=50,
    warmup_percentage=0.1,
    max_grad_norm=1.0,
    learning_rate=3e-4,
    weight_decay=1e-5,
    adam_betas=(0.9, 0.95),
    images_path_train="/path/to/train/images",
    images_path_valid="/path/to/validation/images",
    num_workers=8,
    checkpoint_path="jet.pt",
)
```
The training code favors gradient accumulation over alternatives like gradient checkpointing, which allows this script to run on GPUs with less than 8GB of VRAM (using a smaller `batch_size` with proportionally more `accumulate_steps`). The effective batch size is therefore `batch_size * accumulate_steps`. Note that the default configuration above assumes at least 24GB of VRAM. A generic sketch of the accumulation pattern follows.
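For readers unfamiliar with the pattern, here is a minimal, self-contained PyTorch sketch of gradient accumulation. It is an illustration only, not the actual `jet_pytorch.train` loop; the model, optimizer, and data are stand-ins.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Generic gradient-accumulation sketch (illustration only, not jet_pytorch.train):
# gradients from `accumulate_steps` micro-batches are summed before a single
# optimizer step, emulating a batch of batch_size * accumulate_steps.
accumulate_steps = 4
model = nn.Linear(8, 1)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loader = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=4)

optimizer.zero_grad()
for step, (batch,) in enumerate(loader):
    loss = model(batch).pow(2).mean()     # stand-in loss
    (loss / accumulate_steps).backward()  # scale so summed grads match a big batch
    if (step + 1) % accumulate_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
```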
### Create visualizations
```python
from jet_pytorch.sample import sample

# Creates a visualization using the default Jet config/pretrained weights
sample("path/to/your/images")

# Creates a visualization using the default Jet config/a local checkpoint
sample(
    "path/to/your/images",
    checkpoint_path="path/to/your/checkpoint.pt",
)

# Creates a visualization using a custom Jet config/a local checkpoint
jet_config = dict(...)
sample(
    "path/to/your/images",
    jet_config=jet_config,
    checkpoint_path="path/to/your/checkpoint.pt",
)
```
Visualization results are stored at `output/jet.png`.