Airliner Latent Classifier (Stable Diffusion v1.4)
Latent-space binary classifier trained on Stable Diffusion v1.4 VAE latents (shape 4×64×64) with a simple MLP head and a timestep embedding (from the DDIM scheduler).
Intended for concept probing and classifier guidance in diffusion workflows.
- Concept: airliner
- Input: latent tensor z ∈ ℝ^{4×64×64} and a diffusion timestep t
- Output: logit/probability that z contains the concept at timestep t
- Author/Org: DiffusionConceptErasure
- Date: 2025-11-05
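For reference, inputs are standard SD v1.4 VAE latents. A minimal sketch of encoding an image into this latent space (the 0.18215 scaling factor is the SD v1.4 convention; the file name and preprocessing are illustrative assumptions):

    import numpy as np
    import torch
    from PIL import Image
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
    img = Image.open("airliner.png").convert("RGB").resize((512, 512))  # hypothetical file
    x = torch.from_numpy(np.array(img)).float() / 127.5 - 1.0  # scale pixels to [-1, 1]
    x = x.permute(2, 0, 1).unsqueeze(0)  # -> [1, 3, 512, 512]
    with torch.no_grad():
        z = vae.encode(x).latent_dist.sample() * 0.18215  # -> [1, 4, 64, 64]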
Usage (PyTorch)
import torch
import torch.nn as nn
from diffusers import DDIMScheduler

# ---- model definition (must match training) ----
class FixedTimestepEncoding(nn.Module):
    """Encodes timestep t as [sqrt(alpha_bar_t), sqrt(1 - alpha_bar_t)]."""
    def __init__(self, scheduler):
        super().__init__()
        self.register_buffer("alphas_cumprod", scheduler.alphas_cumprod)

    def forward(self, t):
        alpha_bar = self.alphas_cumprod[t]
        return torch.stack([alpha_bar.sqrt(), (1 - alpha_bar).sqrt()], dim=-1)

class LatentClassifierT(nn.Module):
    def __init__(self, latent_shape=(4, 64, 64), scheduler=None):
        super().__init__()
        c, h, w = latent_shape
        flat_dim = c * h * w
        self.t_embed = FixedTimestepEncoding(scheduler)
        self.fc_t = nn.Linear(2, 1024)        # timestep-embedding branch
        self.fc_x = nn.Linear(flat_dim, 1024)  # latent branch
        self.net = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.SiLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 1),
        )

    def forward(self, z, t):
        z_flat = z.flatten(start_dim=1)
        # fuse latent and timestep features by addition, then classify
        return self.net(self.fc_x(z_flat) + self.fc_t(self.t_embed(t)))

# ---- load weights ----
repo_id = "DiffusionConceptErasure/latent-classifier-airliner"
ckpt_name = "airliner.pt"
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
model = LatentClassifierT(scheduler=scheduler)
state = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{repo_id}/resolve/main/{ckpt_name}",
    map_location="cpu",
)
model.load_state_dict(state["model_state_dict"] if "model_state_dict" in state else state)
model.eval()

# Example inference:
z = torch.randn(1, 4, 64, 64)  # latent
t = torch.randint(0, scheduler.config.num_train_timesteps, (1,))  # timestep
with torch.no_grad():
    logit = model(z, t)  # shape [1, 1]
    prob = torch.sigmoid(logit)
print(prob.item())
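For classifier guidance, the quantity of interest is the gradient of the concept log-probability with respect to the latent. A minimal sketch under assumed settings (the guidance scale is illustrative, and how the gradient is folded into the sampler depends on your workflow):

    import torch.nn.functional as F

    z_t = z.detach().requires_grad_(True)       # noisy latent at the current step
    log_p = F.logsigmoid(model(z_t, t)).sum()   # log p(concept | z_t, t)
    grad = torch.autograd.grad(log_p, z_t)[0]
    guidance_scale = 5.0  # assumed value; tune per workflow
    z_guided = z_t.detach() + guidance_scale * grad  # nudge latent toward the concept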
Notes
- Trained with power-law timestep sampling biased toward noisier latents (larger t under the DDIM schedule).
- For classifier guidance, logits can be averaged across a few sampled noise levels t; see the sketch after this list.
- Expected behavior: discriminability is highest at moderate noise; at extreme noise levels the concept signal largely vanishes.
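A minimal sketch of the averaging suggested above, assuming a clean latent z0 that is forward-diffused at a few moderate timesteps (the sample count and timestep band are illustrative assumptions):

    def average_logit(model, scheduler, z0, n_samples=4):
        T = scheduler.config.num_train_timesteps
        logits = []
        for _ in range(n_samples):
            t = torch.randint(T // 4, 3 * T // 4, (z0.shape[0],))    # moderate-noise band
            z_t = scheduler.add_noise(z0, torch.randn_like(z0), t)   # forward-diffuse z0 to level t
            with torch.no_grad():
                logits.append(model(z_t, t))
        return torch.stack(logits).mean(dim=0)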
Citation
If you use this, please cite:
@inproceedings{lu2025concepts,
title={When Are Concepts Erased From Diffusion Models?},
author={Kevin Lu and Nicky Kriplani and Rohit Gandikota and Minh Pham and David Bau and Chinmay Hegde and Niv Cohen},
booktitle={NeurIPS},
year={2025}
}