Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- .gitattributes +11 -0
- data/extracted_features/subj01/nsd_cliptext_train.npy +3 -0
- data/extracted_features/subj01/nsd_vdvae_features_31l.npz +3 -0
- data/predicted_features/subj01/nsd_clipvision_predtest_nsdgeneral.npy +3 -0
- data/predicted_features/subj01/nsd_vdvae_nsdgeneral_pred_sub1_31l_alpha50k.npy +3 -0
- data/processed_data/subj01/nsd_train_fmriavg_nsdgeneral_sub1.npy +3 -0
- vdvae/header-image.png +3 -0
- vdvae/model/imagenet64-iter-1600000-model-ema.th +3 -0
- vdvae/model/imagenet64-iter-1600000-model.th +3 -0
- vdvae/model/imagenet64-iter-1600000-opt.th +3 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/modeling_utils.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/optimus_bert.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/optimus_gpt2.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/optimus_gpt2.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_bert.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_bert.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_gpt2.cpython-38.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_utils.cpython-310.pyc +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/vocab/bert-base-cased-vocab.txt +0 -0
- versatile_diffusion/lib/model_zoo/optimus_models/vocab/bert_vocab_download_info.json +15 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/model/sd.yaml +68 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/model/vd.yaml +61 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/__init__.py +0 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/cfg_helper.py +664 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/cfg_holder.py +28 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/data_factory/common/ds_estimator.py +85 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/data_factory/common/ds_formatter.py +39 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/evaluator/__init__.py +1 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/evaluator/eva_base.py +293 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/evaluator/eva_null.py +26 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/experiments/__init__.py +0 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/experiments/sd_default.py +441 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/experiments/vd_default.py +549 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/__init__.py +4 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/attention.py +435 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/autoencoder.py +428 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/bert.py +142 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/clip.py +226 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/clip_justin/__init__.py +1 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/clip_justin/clip.py +237 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/clip_justin/model.py +436 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/clip_justin/simple_tokenizer.py +132 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/common/get_model.py +120 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/common/get_optimizer.py +47 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/common/get_scheduler.py +262 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/common/utils.py +292 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/ddim.py +341 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/ddim_dualcontext.py +144 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/ddim_dualmodel.py +244 -0
- versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/ddim_vd.py +293 -0
.gitattributes
CHANGED
@@ -2985,3 +2985,14 @@ results/versatile_diffusion/subj01/roi/12.png filter=lfs diff=lfs merge=lfs -text
  results/versatile_diffusion/subj01/roi/2.png filter=lfs diff=lfs merge=lfs -text
  results/versatile_diffusion/subj01/roi/3.png filter=lfs diff=lfs merge=lfs -text
  results/versatile_diffusion/subj01/roi/4.png filter=lfs diff=lfs merge=lfs -text
+ vdvae/header-image.png filter=lfs diff=lfs merge=lfs -text
+ data/predicted_features/subj01/nsd_vdvae_nsdgeneral_pred_sub1_31l_alpha50k.npy filter=lfs diff=lfs merge=lfs -text
+ versatile_diffusion/pretrained/kl-f8.pth filter=lfs diff=lfs merge=lfs -text
+ data/processed_data/subj01/nsd_train_fmriavg_nsdgeneral_sub1.npy filter=lfs diff=lfs merge=lfs -text
+ vdvae/model/imagenet64-iter-1600000-model-ema.th filter=lfs diff=lfs merge=lfs -text
+ vdvae/model/imagenet64-iter-1600000-model.th filter=lfs diff=lfs merge=lfs -text
+ data/predicted_features/subj01/nsd_clipvision_predtest_nsdgeneral.npy filter=lfs diff=lfs merge=lfs -text
+ versatile_diffusion/pretrained/optimus-vae.pth filter=lfs diff=lfs merge=lfs -text
+ vdvae/model/imagenet64-iter-1600000-opt.th filter=lfs diff=lfs merge=lfs -text
+ data/extracted_features/subj01/nsd_vdvae_features_31l.npz filter=lfs diff=lfs merge=lfs -text
+ data/extracted_features/subj01/nsd_cliptext_train.npy filter=lfs diff=lfs merge=lfs -text
data/extracted_features/subj01/nsd_cliptext_train.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fe33228e2c7334c50c6b09d790c17c188787c46ed89addd37efe2c8f0dfc9e44
size 4191086720
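Note: each of the large arrays and checkpoints added in this commit is stored as a Git LFS pointer like the one above, so only the spec version, sha256 object id, and byte size live in the repository. As a minimal illustrative sketch (not part of this commit), such a pointer can be parsed and a downloaded blob checked against it:

import hashlib
from pathlib import Path

def read_lfs_pointer(path):
    """Parse a Git LFS pointer file into its version, oid, and size fields."""
    fields = {}
    for line in Path(path).read_text().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

def verify_lfs_object(pointer, blob_path):
    """Check that a downloaded blob matches the sha256 and size recorded in the pointer."""
    data = Path(blob_path).read_bytes()
    expected_oid = pointer["oid"].split(":", 1)[1]  # strip the "sha256:" prefix
    return hashlib.sha256(data).hexdigest() == expected_oid and len(data) == int(pointer["size"])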
data/extracted_features/subj01/nsd_vdvae_features_31l.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a919a9bd590e71398e45fff63895051f50230151547f398b584e83aa434fc38f
size 3588737796
data/predicted_features/subj01/nsd_clipvision_predtest_nsdgeneral.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba7dc9d374d06dc67d8b3b0fe88e2961bce2e44b549081c702023d94183d53a0
size 1550585984
data/predicted_features/subj01/nsd_vdvae_nsdgeneral_pred_sub1_31l_alpha50k.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5db240ec4b4ad976b2f5ef847ae6277342565de10dadb581ae7748ba0b4a91a6
size 716215936
data/processed_data/subj01/nsd_train_fmriavg_nsdgeneral_sub1.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:26b874870b202e820f290617dcb15fe88af5908cdbf23b25c5a743bcf404aa12
size 1114391456
vdvae/header-image.png
ADDED
Git LFS Details
vdvae/model/imagenet64-iter-1600000-model-ema.th
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0d3b9e07177f5c5082e31685cc47f4e53f8557e4a5eefb72bca514445d3bc48e
size 500977841
vdvae/model/imagenet64-iter-1600000-model.th
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f1701a7e831b370ceaa5f3b57a9d8e7b610b9d10cffd3083dae11986eea95bbb
size 501006513
vdvae/model/imagenet64-iter-1600000-opt.th
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:855bc24ebac829a7ae3cdb6fd8c40a91bde7faab4f5ed857a94a1480d03de6f8
size 1001616051
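Note: the three imagenet64-iter-1600000-*.th pointers above track the VDVAE checkpoints (EMA weights, raw weights, and optimizer state). As a hedged sketch of how they would typically be inspected once the real blobs are fetched through LFS, assuming they are standard torch-serialized objects as the .th extension suggests:

import torch

# Assumes the LFS blobs have been pulled, so this path holds the actual checkpoint
# rather than a pointer file; map_location keeps the load on the CPU.
ema_state = torch.load("vdvae/model/imagenet64-iter-1600000-model-ema.th", map_location="cpu")
print(type(ema_state))  # inspect what the checkpoint actually contains before using it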
versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/modeling_utils.cpython-310.pyc
ADDED
Binary file (31.6 kB)

versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/optimus_bert.cpython-310.pyc
ADDED
Binary file (52.5 kB)

versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/optimus_gpt2.cpython-310.pyc
ADDED
Binary file (32.8 kB)

versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/optimus_gpt2.cpython-38.pyc
ADDED
Binary file (34.4 kB)

versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_bert.cpython-310.pyc
ADDED
Binary file (15.6 kB)

versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_bert.cpython-38.pyc
ADDED
Binary file (15.5 kB)

versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_gpt2.cpython-38.pyc
ADDED
Binary file (9.04 kB)

versatile_diffusion/lib/model_zoo/optimus_models/__pycache__/tokenization_utils.cpython-310.pyc
ADDED
Binary file (33 kB)

versatile_diffusion/lib/model_zoo/optimus_models/vocab/bert-base-cased-vocab.txt
ADDED
The diff for this file is too large to render.
versatile_diffusion/lib/model_zoo/optimus_models/vocab/bert_vocab_download_info.json
ADDED
@@ -0,0 +1,15 @@
{
    "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
    "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
    "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
    "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
    "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
    "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
    "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
    "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
    "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
    "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
    "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
    "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
    "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt"
}
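Note: bert_vocab_download_info.json only records where each BERT vocabulary file originally came from; the checked-in bert-base-cased-vocab.txt corresponds to the bert-base-cased entry. A small hedged sketch of using the mapping (the legacy s3 URLs may no longer resolve, so treat this as provenance documentation rather than a guaranteed download path):

import json
import urllib.request

with open("versatile_diffusion/lib/model_zoo/optimus_models/vocab/bert_vocab_download_info.json") as f:
    vocab_urls = json.load(f)

url = vocab_urls["bert-base-cased"]
urllib.request.urlretrieve(url, "bert-base-cased-vocab.txt")  # may fail if the legacy bucket is gone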
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/model/sd.yaml
ADDED
@@ -0,0 +1,68 @@
sd_base:
  symbol: sd
  find_unused_parameters: true

sd_autoencoder:
  type: autoencoderkl
  args:
    embed_dim: 4
    monitor: val/rec_loss
    ddconfig:
      double_z: true
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [1, 2, 4, 4]
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
    lossconfig:
      target: torch.nn.Identity
  pth: pretrained/kl-f8.pth

sd_t2i:
  super_cfg: sd_base
  type: sd_t2i
  args:
    first_stage_config: MODEL(sd_autoencoder)
    cond_stage_config: MODEL(clip_text_frozen)
    unet_config: MODEL(openai_unet_sd)
    beta_linear_start: 0.00085
    beta_linear_end: 0.012
    num_timesteps_cond: 1
    timesteps: 1000
    scale_factor: 0.18215
    use_ema: true

sd_t2i_noema:
  super_cfg: sd
  args:
    use_ema: false

#####################
# sd with full clip #
#####################

sd_t2i_fullclip_backward_compatible:
  super_cfg: sd_t2i
  args:
    cond_stage_config: MODEL(clip_frozen_encode_text_noproj)

sd_t2i_fullclip_backward_compatible_noema:
  super_cfg: sd_t2i_noema
  args:
    cond_stage_config: MODEL(clip_frozen_encode_text_noproj)

sd_t2i_fullclip:
  super_cfg: sd_t2i
  args:
    cond_stage_config: MODEL(clip_frozen_encode_text)

sd_variation:
  super_cfg: sd_t2i
  type: sd_variation
  args:
    cond_stage_config: MODEL(clip_vision_frozen_justin)
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/configs/model/vd.yaml
ADDED
@@ -0,0 +1,61 @@
# vd_base:
#     symbol: vd
#     find_unused_parameters: true

############
# vd basic #
############

vd_basic:
  super_cfg: sd_t2i
  type: vd_basic
  symbol: vd
  find_unused_parameters: true
  args:
    cond_stage_config: MODEL(clip_frozen_encode_vision)

vd_basic_noema:
  super_cfg: vd_basic
  args:
    use_ema: false

###################
# vd dual-context #
###################

vd_dc:
  super_cfg: sd_t2i_fullclip
  type: vd_dc
  symbol: vd
  find_unused_parameters: true
  args:
    unet_config: MODEL(openai_unet_dual_context)

vd_dc_noema:
  super_cfg: vd_dc
  args:
    use_ema: false

######
# vd #
######

vd:
  type: vd
  symbol: vd
  find_unused_parameters: true
  args:
    autokl_cfg: MODEL(sd_autoencoder)
    optimus_cfg: MODEL(optimus_vae)
    clip_cfg: MODEL(clip_frozen)
    unet_config: MODEL(openai_unet_vd)
    beta_linear_start: 0.00085
    beta_linear_end: 0.012
    timesteps: 1000
    scale_factor: 0.18215
    use_ema: true

vd_noema:
  super_cfg: vd
  args:
    use_ema: false
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/__init__.py
ADDED
File without changes
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/cfg_helper.py
ADDED
@@ -0,0 +1,664 @@
import os
import os.path as osp
import shutil
import copy
import time
import pprint
import numpy as np
import torch
import matplotlib
import argparse
import json
import yaml
from easydict import EasyDict as edict

from .model_zoo import get_model

############
# cfg_bank #
############

def cfg_solvef(cmd, root):
    if not isinstance(cmd, str):
        return cmd

    if cmd.find('SAME')==0:
        zoom = root
        p = cmd[len('SAME'):].strip('()').split('.')
        p = [pi.strip() for pi in p]
        for pi in p:
            try:
                pi = int(pi)
            except:
                pass

            try:
                zoom = zoom[pi]
            except:
                return cmd
        return cfg_solvef(zoom, root)

    if cmd.find('SEARCH')==0:
        zoom = root
        p = cmd[len('SEARCH'):].strip('()').split('.')
        p = [pi.strip() for pi in p]
        find = True
        # Depth first search
        for pi in p:
            try:
                pi = int(pi)
            except:
                pass

            try:
                zoom = zoom[pi]
            except:
                find = False
                break

        if find:
            return cfg_solvef(zoom, root)
        else:
            if isinstance(root, dict):
                for ri in root:
                    rv = cfg_solvef(cmd, root[ri])
                    if rv != cmd:
                        return rv
            if isinstance(root, list):
                for ri in root:
                    rv = cfg_solvef(cmd, ri)
                    if rv != cmd:
                        return rv
            return cmd

    if cmd.find('MODEL')==0:
        goto = cmd[len('MODEL'):].strip('()')
        return model_cfg_bank()(goto)

    if cmd.find('DATASET')==0:
        goto = cmd[len('DATASET'):].strip('()')
        return dataset_cfg_bank()(goto)

    return cmd

def cfg_solve(cfg, cfg_root):
    # The function solve cfg element such that
    # all sorrogate input are settled.
    # (i.e. SAME(***) )
    if isinstance(cfg, list):
        for i in range(len(cfg)):
            if isinstance(cfg[i], (list, dict)):
                cfg[i] = cfg_solve(cfg[i], cfg_root)
            else:
                cfg[i] = cfg_solvef(cfg[i], cfg_root)
    if isinstance(cfg, dict):
        for k in cfg:
            if isinstance(cfg[k], (list, dict)):
                cfg[k] = cfg_solve(cfg[k], cfg_root)
            else:
                cfg[k] = cfg_solvef(cfg[k], cfg_root)
    return cfg

class model_cfg_bank(object):
    def __init__(self):
        self.cfg_dir = osp.join('configs', 'model')
        self.cfg_bank = edict()

    def __call__(self, name):
        if name not in self.cfg_bank:
            cfg_path = self.get_yaml_path(name)
            with open(cfg_path, 'r') as f:
                cfg_new = yaml.load(
                    f, Loader=yaml.FullLoader)
            cfg_new = edict(cfg_new)
            self.cfg_bank.update(cfg_new)

        cfg = self.cfg_bank[name]
        cfg.name = name
        if 'super_cfg' not in cfg:
            cfg = cfg_solve(cfg, cfg)
            self.cfg_bank[name] = cfg
            return copy.deepcopy(cfg)

        super_cfg = self.__call__(cfg.super_cfg)
        # unlike other field,
        # args will not be replaced but update.
        if 'args' in cfg:
            if 'args' in super_cfg:
                super_cfg.args.update(cfg.args)
            else:
                super_cfg.args = cfg.args
            cfg.pop('args')

        super_cfg.update(cfg)
        super_cfg.pop('super_cfg')
        cfg = super_cfg
        try:
            delete_args = cfg.pop('delete_args')
        except:
            delete_args = []

        for dargs in delete_args:
            cfg.args.pop(dargs)

        cfg = cfg_solve(cfg, cfg)
        self.cfg_bank[name] = cfg
        return copy.deepcopy(cfg)

    def get_yaml_path(self, name):
        if name.find('ldm')==0:
            return osp.join(
                self.cfg_dir, 'ldm.yaml')
        elif name.find('comodgan')==0:
            return osp.join(
                self.cfg_dir, 'comodgan.yaml')
        elif name.find('stylegan')==0:
            return osp.join(
                self.cfg_dir, 'stylegan.yaml')
        elif name.find('absgan')==0:
            return osp.join(
                self.cfg_dir, 'absgan.yaml')
        elif name.find('ashgan')==0:
            return osp.join(
                self.cfg_dir, 'ashgan.yaml')
        elif name.find('sr3')==0:
            return osp.join(
                self.cfg_dir, 'sr3.yaml')
        elif name.find('specdiffsr')==0:
            return osp.join(
                self.cfg_dir, 'specdiffsr.yaml')
        elif name.find('openai_unet')==0:
            return osp.join(
                self.cfg_dir, 'openai_unet.yaml')
        elif name.find('clip')==0:
            return osp.join(
                self.cfg_dir, 'clip.yaml')
        elif name.find('sd')==0:
            return osp.join(
                self.cfg_dir, 'sd.yaml')
        elif name.find('vd')==0:
            return osp.join(
                self.cfg_dir, 'vd.yaml')
        elif name.find('optimus')==0:
            return osp.join(
                self.cfg_dir, 'optimus.yaml')
        else:
            raise ValueError

class dataset_cfg_bank(object):
    def __init__(self):
        self.cfg_dir = osp.join('configs', 'dataset')
        self.cfg_bank = edict()

    def __call__(self, name):
        if name not in self.cfg_bank:
            cfg_path = self.get_yaml_path(name)
            with open(cfg_path, 'r') as f:
                cfg_new = yaml.load(
                    f, Loader=yaml.FullLoader)
            cfg_new = edict(cfg_new)
            self.cfg_bank.update(cfg_new)

        cfg = self.cfg_bank[name]
        cfg.name = name
        if cfg.get('super_cfg', None) is None:
            cfg = cfg_solve(cfg, cfg)
            self.cfg_bank[name] = cfg
            return copy.deepcopy(cfg)

        super_cfg = self.__call__(cfg.super_cfg)
        super_cfg.update(cfg)
        cfg = super_cfg
        cfg.super_cfg = None
        try:
            delete = cfg.pop('delete')
        except:
            delete = []

        for dargs in delete:
            cfg.pop(dargs)

        cfg = cfg_solve(cfg, cfg)
        self.cfg_bank[name] = cfg
        return copy.deepcopy(cfg)

    def get_yaml_path(self, name):
        if name.find('cityscapes')==0:
            return osp.join(
                self.cfg_dir, 'cityscapes.yaml')
        elif name.find('div2k')==0:
            return osp.join(
                self.cfg_dir, 'div2k.yaml')
        elif name.find('gandiv2k')==0:
            return osp.join(
                self.cfg_dir, 'gandiv2k.yaml')
        elif name.find('srbenchmark')==0:
            return osp.join(
                self.cfg_dir, 'srbenchmark.yaml')
        elif name.find('imagedir')==0:
            return osp.join(
                self.cfg_dir, 'imagedir.yaml')
        elif name.find('places2')==0:
            return osp.join(
                self.cfg_dir, 'places2.yaml')
        elif name.find('ffhq')==0:
            return osp.join(
                self.cfg_dir, 'ffhq.yaml')
        elif name.find('imcpt')==0:
            return osp.join(
                self.cfg_dir, 'imcpt.yaml')
        elif name.find('texture')==0:
            return osp.join(
                self.cfg_dir, 'texture.yaml')
        elif name.find('openimages')==0:
            return osp.join(
                self.cfg_dir, 'openimages.yaml')
        elif name.find('laion2b')==0:
            return osp.join(
                self.cfg_dir, 'laion2b.yaml')
        elif name.find('laionart')==0:
            return osp.join(
                self.cfg_dir, 'laionart.yaml')
        elif name.find('celeba')==0:
            return osp.join(
                self.cfg_dir, 'celeba.yaml')
        elif name.find('coyo')==0:
            return osp.join(
                self.cfg_dir, 'coyo.yaml')
        elif name.find('pafc')==0:
            return osp.join(
                self.cfg_dir, 'pafc.yaml')
        elif name.find('coco')==0:
            return osp.join(
                self.cfg_dir, 'coco.yaml')
        else:
            raise ValueError

class experiment_cfg_bank(object):
    def __init__(self):
        self.cfg_dir = osp.join('configs', 'experiment')
        self.cfg_bank = edict()

    def __call__(self, name):
        if name not in self.cfg_bank:
            cfg_path = self.get_yaml_path(name)
            with open(cfg_path, 'r') as f:
                cfg = yaml.load(
                    f, Loader=yaml.FullLoader)
            cfg = edict(cfg)

        cfg = cfg_solve(cfg, cfg)
        cfg = cfg_solve(cfg, cfg)
        # twice for SEARCH
        self.cfg_bank[name] = cfg
        return copy.deepcopy(cfg)

    def get_yaml_path(self, name):
        return osp.join(
            self.cfg_dir, name+'.yaml')

def load_cfg_yaml(path):
    if osp.isfile(path):
        cfg_path = path
    elif osp.isfile(osp.join('configs', 'experiment', path)):
        cfg_path = osp.join('configs', 'experiment', path)
    elif osp.isfile(osp.join('configs', 'experiment', path+'.yaml')):
        cfg_path = osp.join('configs', 'experiment', path+'.yaml')
    else:
        assert False, 'No such config!'

    with open(cfg_path, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    cfg = edict(cfg)
    cfg = cfg_solve(cfg, cfg)
    cfg = cfg_solve(cfg, cfg)
    return cfg

##############
# cfg_helper #
##############

def get_experiment_id(ref=None):
    if ref is None:
        time.sleep(0.5)
        return int(time.time()*100)
    else:
        try:
            return int(ref)
        except:
            pass

        _, ref = osp.split(ref)
        ref = ref.split('_')[0]
        try:
            return int(ref)
        except:
            assert False, 'Invalid experiment ID!'

def record_resume_cfg(path):
    cnt = 0
    while True:
        if osp.exists(path+'.{:04d}'.format(cnt)):
            cnt += 1
            continue
        shutil.copyfile(path, path+'.{:04d}'.format(cnt))
        break

def get_command_line_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--debug', action='store_true', default=False)
    parser.add_argument('--config', type=str)
    parser.add_argument('--gpu', nargs='+', type=int)

    parser.add_argument('--node_rank', type=int, default=0)
    parser.add_argument('--nodes', type=int, default=1)
    parser.add_argument('--addr', type=str, default='127.0.0.1')
    parser.add_argument('--port', type=int, default=11233)

    parser.add_argument('--signature', nargs='+', type=str)
    parser.add_argument('--seed', type=int)

    parser.add_argument('--eval', type=str)
    parser.add_argument('--eval_subdir', type=str)
    parser.add_argument('--pretrained', type=str)

    parser.add_argument('--resume_dir', type=str)
    parser.add_argument('--resume_step', type=int)
    parser.add_argument('--resume_weight', type=str)

    args = parser.parse_args()

    # Special handling the resume
    if args.resume_dir is not None:
        cfg = edict()
        cfg.env = edict()
        cfg.env.debug = args.debug
        cfg.env.resume = edict()
        cfg.env.resume.dir = args.resume_dir
        cfg.env.resume.step = args.resume_step
        cfg.env.resume.weight = args.resume_weight
        return cfg

    cfg = load_cfg_yaml(args.config)
    cfg.env.debug = args.debug
    cfg.env.gpu_device = [0] if args.gpu is None else list(args.gpu)
    cfg.env.master_addr = args.addr
    cfg.env.master_port = args.port
    cfg.env.dist_url = 'tcp://{}:{}'.format(args.addr, args.port)
    cfg.env.node_rank = args.node_rank
    cfg.env.nodes = args.nodes

    istrain = False if args.eval is not None else True
    isdebug = cfg.env.debug

    if istrain:
        if isdebug:
            cfg.env.experiment_id = 999999999999
            cfg.train.signature = ['debug']
        else:
            cfg.env.experiment_id = get_experiment_id()
            if args.signature is not None:
                cfg.train.signature = args.signature
    else:
        if 'train' in cfg:
            cfg.pop('train')
        cfg.env.experiment_id = get_experiment_id(args.eval)
        if args.signature is not None:
            cfg.eval.signature = args.signature

    if isdebug and (args.eval is None):
        cfg.env.experiment_id = 999999999999
        cfg.eval.signature = ['debug']

    if args.eval_subdir is not None:
        if isdebug:
            cfg.eval.eval_subdir = 'debug'
        else:
            cfg.eval.eval_subdir = args.eval_subdir
    if args.pretrained is not None:
        cfg.eval.pretrained = args.pretrained
        # The override pretrained over the setting in cfg.model

    if args.seed is not None:
        cfg.env.rnd_seed = args.seed

    return cfg

def cfg_initiates(cfg):
    cfge = cfg.env
    isdebug = cfge.debug
    isresume = 'resume' in cfge
    istrain = 'train' in cfg
    haseval = 'eval' in cfg
    cfgt = cfg.train if istrain else None
    cfgv = cfg.eval if haseval else None

    ###############################
    # get some environment params #
    ###############################

    cfge.computer = os.uname()
    cfge.torch_version = str(torch.__version__)

    ##########
    # resume #
    ##########

    if isresume:
        resume_cfg_path = osp.join(cfge.resume.dir, 'config.yaml')
        record_resume_cfg(resume_cfg_path)
        with open(resume_cfg_path, 'r') as f:
            cfg_resume = yaml.load(f, Loader=yaml.FullLoader)
        cfg_resume = edict(cfg_resume)
        cfg_resume.env.update(cfge)
        cfg = cfg_resume
        cfge = cfg.env
        log_file = cfg.train.log_file

        print('')
        print('##########')
        print('# resume #')
        print('##########')
        print('')
        with open(log_file, 'a') as f:
            print('', file=f)
            print('##########', file=f)
            print('# resume #', file=f)
            print('##########', file=f)
            print('', file=f)

        pprint.pprint(cfg)
        with open(log_file, 'a') as f:
            pprint.pprint(cfg, f)

    ####################
    # node distributed #
    ####################

    if cfg.env.master_addr!='127.0.0.1':
        os.environ['MASTER_ADDR'] = cfge.master_addr
        os.environ['MASTER_PORT'] = '{}'.format(cfge.master_port)
        if cfg.env.dist_backend=='nccl':
            os.environ['NCCL_SOCKET_FAMILY'] = 'AF_INET'
        if cfg.env.dist_backend=='gloo':
            os.environ['GLOO_SOCKET_FAMILY'] = 'AF_INET'

    #######################
    # cuda visible device #
    #######################

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
        [str(gid) for gid in cfge.gpu_device])

    #####################
    # return resume cfg #
    #####################

    if isresume:
        return cfg

    #############################################
    # some misc setting that not need in resume #
    #############################################

    cfgm = cfg.model
    cfge.gpu_count = len(cfge.gpu_device)

    ##########################################
    # align batch size and num worker config #
    ##########################################

    gpu_n = cfge.gpu_count * cfge.nodes
    def align_batch_size(bs, bs_per_gpu):
        assert (bs is not None) or (bs_per_gpu is not None)
        bs = bs_per_gpu * gpu_n if bs is None else bs
        bs_per_gpu = bs // gpu_n if bs_per_gpu is None else bs_per_gpu
        assert (bs == bs_per_gpu * gpu_n)
        return bs, bs_per_gpu

    if istrain:
        cfgt.batch_size, cfgt.batch_size_per_gpu = \
            align_batch_size(cfgt.batch_size, cfgt.batch_size_per_gpu)
        cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu = \
            align_batch_size(cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu)
    if haseval:
        cfgv.batch_size, cfgv.batch_size_per_gpu = \
            align_batch_size(cfgv.batch_size, cfgv.batch_size_per_gpu)
        cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu = \
            align_batch_size(cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu)

    ##################
    # create log dir #
    ##################

    if istrain:
        if not isdebug:
            sig = cfgt.get('signature', [])
            version = get_model().get_version(cfgm.type)
            sig = sig + ['v{}'.format(version), 's{}'.format(cfge.rnd_seed)]
        else:
            sig = ['debug']

        log_dir = [
            cfge.log_root_dir,
            '{}_{}'.format(cfgm.symbol, cfgt.dataset.symbol),
            '_'.join([str(cfge.experiment_id)] + sig)
        ]
        log_dir = osp.join(*log_dir)
        log_file = osp.join(log_dir, 'train.log')
        if not osp.exists(log_file):
            os.makedirs(osp.dirname(log_file))
        cfgt.log_dir = log_dir
        cfgt.log_file = log_file

        if haseval:
            cfgv.log_dir = log_dir
            cfgv.log_file = log_file
    else:
        model_symbol = cfgm.symbol
        if cfgv.get('dataset', None) is None:
            dataset_symbol = 'nodataset'
        else:
            dataset_symbol = cfgv.dataset.symbol

        log_dir = osp.join(cfge.log_root_dir, '{}_{}'.format(model_symbol, dataset_symbol))
        exp_dir = search_experiment_folder(log_dir, cfge.experiment_id)
        if exp_dir is None:
            if not isdebug:
                sig = cfgv.get('signature', []) + ['evalonly']
            else:
                sig = ['debug']
            exp_dir = '_'.join([str(cfge.experiment_id)] + sig)

        eval_subdir = cfgv.get('eval_subdir', None)
        # override subdir in debug mode (if eval_subdir is set)
        eval_subdir = 'debug' if (eval_subdir is not None) and isdebug else eval_subdir

        if eval_subdir is not None:
            log_dir = osp.join(log_dir, exp_dir, eval_subdir)
        else:
            log_dir = osp.join(log_dir, exp_dir)

        disable_log_override = cfgv.get('disable_log_override', False)
        if osp.isdir(log_dir):
            if disable_log_override:
                assert False, 'Override an exsited log_dir is disabled at [{}]'.format(log_dir)
        else:
            os.makedirs(log_dir)

        log_file = osp.join(log_dir, 'eval.log')
        cfgv.log_dir = log_dir
        cfgv.log_file = log_file

    ######################
    # print and save cfg #
    ######################

    pprint.pprint(cfg)
    with open(log_file, 'w') as f:
        pprint.pprint(cfg, f)
    with open(osp.join(log_dir, 'config.yaml'), 'w') as f:
        yaml.dump(edict_2_dict(cfg), f)

    #############
    # save code #
    #############

    save_code = False
    if istrain:
        save_code = cfgt.get('save_code', False)
    elif haseval:
        save_code = cfgv.get('save_code', False)

    if save_code:
        codedir = osp.join(log_dir, 'code')
        if osp.exists(codedir):
            shutil.rmtree(codedir)
        for d in ['configs', 'lib']:
            fromcodedir = d
            tocodedir = osp.join(codedir, d)
            shutil.copytree(
                fromcodedir, tocodedir,
                ignore=shutil.ignore_patterns(
                    '*__pycache__*', '*build*'))
        for codei in os.listdir('.'):
            if osp.splitext(codei)[1] == 'py':
                shutil.copy(codei, codedir)

    #######################
    # set matplotlib mode #
    #######################

    if 'matplotlib_mode' in cfge:
        try:
            matplotlib.use(cfge.matplotlib_mode)
        except:
            print('Warning: matplotlib mode [{}] failed to be set!'.format(cfge.matplotlib_mode))

    return cfg

def edict_2_dict(x):
    if isinstance(x, dict):
        xnew = {}
        for k in x:
            xnew[k] = edict_2_dict(x[k])
        return xnew
    elif isinstance(x, list):
        xnew = []
        for i in range(len(x)):
            xnew.append( edict_2_dict(x[i]) )
        return xnew
    else:
        return x

def search_experiment_folder(root, exid):
    target = None
    for fi in os.listdir(root):
        if not osp.isdir(osp.join(root, fi)):
            continue
        if int(fi.split('_')[0]) == exid:
            if target is not None:
                return None # duplicated
            elif target is None:
                target = fi
    return target
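Note: model_cfg_bank is what expands the MODEL(...) placeholders in sd.yaml and vd.yaml above. It loads the YAML for the requested name, walks up super_cfg parents (merging args instead of replacing them), then runs cfg_solve so SAME(...), SEARCH(...), MODEL(...), and DATASET(...) strings are substituted. A hedged usage sketch, assuming the code snapshot's working directory (so configs/model is resolvable) and that the referenced clip.yaml and openai_unet.yaml exist in the full repository:

from lib.cfg_helper import model_cfg_bank

bank = model_cfg_bank()            # caches parsed YAML per instance, keyed by config name
cfg = bank('sd_variation')         # resolves super_cfg: sd_t2i, which in turn builds on sd_base
print(cfg.type)                    # 'sd_variation'
print(cfg.args.beta_linear_start)  # 0.00085, inherited from sd_t2i's args
# cfg.args.cond_stage_config is the expanded MODEL(clip_vision_frozen_justin) sub-config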
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/cfg_holder.py
ADDED
@@ -0,0 +1,28 @@
import copy

def singleton(class_):
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance

##############
# cfg_holder #
##############

@singleton
class cfg_unique_holder(object):
    def __init__(self):
        self.cfg = None
        # this is use to track the main codes.
        self.code = set()
    def save_cfg(self, cfg):
        self.cfg = copy.deepcopy(cfg)
    def add_code(self, code):
        """
        A new main code is reached and
        its name is added.
        """
        self.code.add(code)
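Note: the @singleton decorator above makes cfg_unique_holder behave as a process-wide registry: every call to the class returns the same instance, so a config saved in one module is visible everywhere. A minimal usage sketch under that assumption:

from lib.cfg_holder import cfg_unique_holder

holder_a = cfg_unique_holder()
holder_b = cfg_unique_holder()
assert holder_a is holder_b          # same object: singleton caches the first constructed instance

holder_a.save_cfg({'env': {'debug': True}})
print(holder_b.cfg)                  # {'env': {'debug': True}} -- a deep copy of what was saved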
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/data_factory/common/ds_estimator.py
ADDED
@@ -0,0 +1,85 @@
import os.path as osp
import numpy as np
import numpy.random as npr
import PIL
import cv2

import torch
import torchvision
import xml.etree.ElementTree as ET
import json
import copy
import math

def singleton(class_):
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance

@singleton
class get_estimator(object):
    def __init__(self):
        self.estimator = {}

    def register(self, estimf):
        self.estimator[estimf.__name__] = estimf

    def __call__(self, cfg):
        if cfg is None:
            return None
        t = cfg.type
        return self.estimator[t](**cfg.args)

def register():
    def wrapper(class_):
        get_estimator().register(class_)
        return class_
    return wrapper

@register()
class PickFileEstimator(object):
    """
    This is an estimator that filter load_info
    using the provided filelist
    """
    def __init__(self,
                 filelist = None,
                 repeat_n = 1):
        """
        Args:
            filelist: a list of string gives the name of images
                we would like to visualize, evaluate or train.
            repeat_n: int, times these images will be repeated
        """
        self.filelist = filelist
        self.repeat_n = repeat_n

    def __call__(self, load_info):
        load_info_new = []
        for info in load_info:
            # osp.basename used here: only os.path is imported (as osp) in this module
            if osp.basename(info['image_path']).split('.')[0] in self.filelist:
                load_info_new.append(info)
        return load_info_new * self.repeat_n

@register()
class PickIndexEstimator(object):
    """
    This is an estimator that filter load_info
    using the provided indices
    """
    def __init__(self,
                 indexlist = None,
                 **kwargs):
        """
        Args:
            indexlist: [] of int.
                the indices to be filtered out.
        """
        self.indexlist = indexlist

    def __call__(self, load_info):
        load_info_new = [load_info[i] for i in self.indexlist]
        return load_info_new
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/data_factory/common/ds_formatter.py
ADDED
@@ -0,0 +1,39 @@
import os
import os.path as osp
import numpy as np
import numpy.random as npr
import torch
import cv2
import scipy.ndimage
from PIL import Image
import copy
import gc
import itertools

def singleton(class_):
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance

@singleton
class get_formatter(object):
    def __init__(self):
        self.formatter = {}

    def register(self, formatf):
        self.formatter[formatf.__name__] = formatf

    def __call__(self, cfg):
        if cfg is None:
            return None
        t = cfg.type
        return self.formatter[t](**cfg.args)

def register():
    def wrapper(class_):
        get_formatter().register(class_)
        return class_
    return wrapper
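Note: ds_formatter.py only sets up the registry; concrete formatters register themselves with the @register() decorator and are then instantiated from a cfg carrying type and args fields. A hedged sketch of how a formatter would plug in (the class below is illustrative, not one from the repository):

from easydict import EasyDict as edict
from lib.data_factory.common.ds_formatter import get_formatter, register

@register()
class IdentityFormatter(object):
    """Illustrative formatter: returns its input unchanged."""
    def __init__(self, scale=1.0):
        self.scale = scale
    def __call__(self, x):
        return x

# get_formatter() looks the class up by name and instantiates it with cfg.args.
fmt = get_formatter()(edict({'type': 'IdentityFormatter', 'args': {'scale': 1.0}}))
print(fmt(42))   # 42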
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/evaluator/__init__.py
ADDED
@@ -0,0 +1 @@
from .eva_base import get_evaluator
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/evaluator/eva_base.py
ADDED
@@ -0,0 +1,293 @@
import torch
import torch.distributed as dist

import os
import os.path as osp
import numpy as np
import cv2
import copy
import json

from ..log_service import print_log

def singleton(class_):
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance

@singleton
class get_evaluator(object):
    def __init__(self):
        self.evaluator = {}

    def register(self, evaf, name):
        self.evaluator[name] = evaf

    def __call__(self, pipeline_cfg=None):
        if pipeline_cfg is None:
            from . import eva_null
            return self.evaluator['null']()

        if not isinstance(pipeline_cfg, list):
            t = pipeline_cfg.type
            if t == 'miou':
                from . import eva_miou
            if t == 'psnr':
                from . import eva_psnr
            if t == 'ssim':
                from . import eva_ssim
            if t == 'lpips':
                from . import eva_lpips
            if t == 'fid':
                from . import eva_fid
            return self.evaluator[t](**pipeline_cfg.args)

        evaluator = []
        for ci in pipeline_cfg:
            t = ci.type
            if t == 'miou':
                from . import eva_miou
            if t == 'psnr':
                from . import eva_psnr
            if t == 'ssim':
                from . import eva_ssim
            if t == 'lpips':
                from . import eva_lpips
            if t == 'fid':
                from . import eva_fid
            evaluator.append(
                self.evaluator[t](**ci.args))
        if len(evaluator) == 0:
            return None
        else:
            return compose(evaluator)

def register(name):
    def wrapper(class_):
        get_evaluator().register(class_, name)
        return class_
    return wrapper

class base_evaluator(object):
    def __init__(self,
                 **args):
        '''
        Args:
            sample_n, int,
                the total number of sample. used in
                distributed sync
        '''
        if not dist.is_available():
            raise ValueError
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.sample_n = None
        self.final = {}

    def sync(self, data):
        """
        Args:
            data: any,
                the data needs to be broadcasted
        """
        if data is None:
            return None

        if isinstance(data, tuple):
            data = list(data)

        if isinstance(data, list):
            data_list = []
            for datai in data:
                data_list.append(self.sync(datai))
            data = [[*i] for i in zip(*data_list)]
            return data

        data = [
            self.sync_(data, ranki)
            for ranki in range(self.world_size)
        ]
        return data

    def sync_(self, data, rank):

        t = type(data)
        is_broadcast = rank == self.rank

        if t is np.ndarray:
            dtrans = data
            dt = data.dtype
            if dt in [
                    int,
                    np.bool,
                    np.uint8,
                    np.int8,
                    np.int16,
                    np.int32,
                    np.int64,]:
                dtt = torch.int64
            elif dt in [
                    float,
                    np.float16,
                    np.float32,
                    np.float64,]:
                dtt = torch.float64

        elif t is str:
            dtrans = np.array(
                [ord(c) for c in data],
                dtype = np.int64
            )
            dt = np.int64
            dtt = torch.int64
        else:
            raise ValueError

        if is_broadcast:
            n = len(dtrans.shape)
            n = torch.tensor(n).long()
            n = n.to(self.rank)
            dist.broadcast(n, src=rank)

            n = list(dtrans.shape)
            n = torch.tensor(n).long()
            n = n.to(self.rank)
            dist.broadcast(n, src=rank)

            n = torch.tensor(dtrans, dtype=dtt)
            n = n.to(self.rank)
            dist.broadcast(n, src=rank)
            return data

        n = torch.tensor(0).long()
        n = n.to(self.rank)
        dist.broadcast(n, src=rank)
        n = n.item()

        n = torch.zeros(n).long()
        n = n.to(self.rank)
        dist.broadcast(n, src=rank)
        n = list(n.to('cpu').numpy())

        n = torch.zeros(n, dtype=dtt)
        n = n.to(self.rank)
        dist.broadcast(n, src=rank)
        n = n.to('cpu').numpy().astype(dt)

        if t is np.ndarray:
            return n
        elif t is str:
            n = ''.join([chr(c) for c in n])
            return n

    def zipzap_arrange(self, data):
        '''
        Order the data so it range like this:
        input [[0, 2, 4, 6], [1, 3, 5, 7]] -> output [0, 1, 2, 3, 4, 5, ...]
        '''
        if isinstance(data[0], list):
            data_new = []
            maxlen = max([len(i) for i in data])
            totlen = sum([len(i) for i in data])
            cnt = 0
            for idx in range(maxlen):
                for datai in data:
                    data_new += [datai[idx]]
                    cnt += 1
                    if cnt >= totlen:
                        break
            return data_new

        elif isinstance(data[0], np.ndarray):
            maxlen = max([i.shape[0] for i in data])
            totlen = sum([i.shape[0] for i in data])
            datai_shape = data[0].shape[1:]
            data = [
                np.concatenate(datai, np.zeros(maxlen-datai.shape[0], *datai_shape), axis=0)
                if datai.shape[0] < maxlen else datai
                for datai in data
            ] # even the array
            data = np.stack(data, axis=1).reshape(-1, *datai_shape)
            data = data[:totlen]
            return data

        else:
            raise NotImplementedError

    def add_batch(self, **args):
        raise NotImplementedError

    def set_sample_n(self, sample_n):
        self.sample_n = sample_n

    def compute(self):
        raise NotImplementedError

    # Function needed in training to judge which
    # evaluated number is better
    def isbetter(self, old, new):
        return new>old

    def one_line_summary(self):
        print_log('Evaluator display')

    def save(self, path):
        if not osp.exists(path):
            os.makedirs(path)
        ofile = osp.join(path, 'result.json')
        with open(ofile, 'w') as f:
            json.dump(self.final, f, indent=4)

    def clear_data(self):
        raise NotImplementedError

class compose(object):
    def __init__(self, pipeline):
        self.pipeline = pipeline
        self.sample_n = None
        self.final = {}

    def add_batch(self, *args, **kwargs):
        for pi in self.pipeline:
            pi.add_batch(*args, **kwargs)

    def set_sample_n(self, sample_n):
        self.sample_n = sample_n
        for pi in self.pipeline:
            pi.set_sample_n(sample_n)

    def compute(self):
        rv = {}
        for pi in self.pipeline:
            rv[pi.symbol] = pi.compute()
            self.final[pi.symbol] = pi.final
|
| 268 |
+
return rv
|
| 269 |
+
|
| 270 |
+
def isbetter(self, old, new):
|
| 271 |
+
check = 0
|
| 272 |
+
for pi in self.pipeline:
|
| 273 |
+
if pi.isbetter(old, new):
|
| 274 |
+
check+=1
|
| 275 |
+
if check/len(self.pipeline)>0.5:
|
| 276 |
+
return True
|
| 277 |
+
else:
|
| 278 |
+
return False
|
| 279 |
+
|
| 280 |
+
def one_line_summary(self):
|
| 281 |
+
for pi in self.pipeline:
|
| 282 |
+
pi.one_line_summary()
|
| 283 |
+
|
| 284 |
+
def save(self, path):
|
| 285 |
+
if not osp.exists(path):
|
| 286 |
+
os.makedirs(path)
|
| 287 |
+
ofile = osp.join(path, 'result.json')
|
| 288 |
+
with open(ofile, 'w') as f:
|
| 289 |
+
json.dump(self.final, f, indent=4)
|
| 290 |
+
|
| 291 |
+
def clear_data(self):
|
| 292 |
+
for pi in self.pipeline:
|
| 293 |
+
pi.clear_data()
|
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/evaluator/eva_null.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import lpips
|
| 4 |
+
|
| 5 |
+
from .. import nputils
|
| 6 |
+
from ..log_service import print_log
|
| 7 |
+
|
| 8 |
+
from .eva_base import base_evaluator, register
|
| 9 |
+
|
| 10 |
+
@register('null')
|
| 11 |
+
class null_evaluator(base_evaluator):
|
| 12 |
+
def __init__(self, **dummy):
|
| 13 |
+
super().__init__()
|
| 14 |
+
|
| 15 |
+
def add_batch(self,
|
| 16 |
+
**dummy):
|
| 17 |
+
pass
|
| 18 |
+
|
| 19 |
+
def compute(self):
|
| 20 |
+
return None
|
| 21 |
+
|
| 22 |
+
def one_line_summary(self):
|
| 23 |
+
print_log('Evaluator null')
|
| 24 |
+
|
| 25 |
+
def clear_data(self):
|
| 26 |
+
pass
|
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/experiments/__init__.py
ADDED
|
File without changes
|
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/experiments/sd_default.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
from torchvision import transforms as tvtrans
|
| 4 |
+
import os
|
| 5 |
+
import os.path as osp
|
| 6 |
+
import time
|
| 7 |
+
import timeit
|
| 8 |
+
import copy
|
| 9 |
+
import json
|
| 10 |
+
import pickle
|
| 11 |
+
import PIL.Image
|
| 12 |
+
import numpy as np
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from easydict import EasyDict as edict
|
| 15 |
+
from collections import OrderedDict
|
| 16 |
+
|
| 17 |
+
from lib.cfg_holder import cfg_unique_holder as cfguh
|
| 18 |
+
from lib.data_factory import get_dataset, get_sampler, collate
|
| 19 |
+
from lib.model_zoo import \
|
| 20 |
+
get_model, get_optimizer, get_scheduler
|
| 21 |
+
from lib.log_service import print_log
|
| 22 |
+
|
| 23 |
+
from ..utils import train as train_base
|
| 24 |
+
from ..utils import eval as eval_base
|
| 25 |
+
from ..utils import train_stage as tsbase
|
| 26 |
+
from ..utils import eval_stage as esbase
|
| 27 |
+
from .. import sync
|
| 28 |
+
|
| 29 |
+
###############
|
| 30 |
+
# some helper #
|
| 31 |
+
###############
|
| 32 |
+
|
| 33 |
+
def atomic_save(cfg, net, opt, step, path):
|
| 34 |
+
if isinstance(net, (torch.nn.DataParallel,
|
| 35 |
+
torch.nn.parallel.DistributedDataParallel)):
|
| 36 |
+
netm = net.module
|
| 37 |
+
else:
|
| 38 |
+
netm = net
|
| 39 |
+
sd = netm.state_dict()
|
| 40 |
+
slimmed_sd = [(ki, vi) for ki, vi in sd.items()
|
| 41 |
+
if ki.find('first_stage_model')!=0 and ki.find('cond_stage_model')!=0]
|
| 42 |
+
|
| 43 |
+
checkpoint = {
|
| 44 |
+
"config" : cfg,
|
| 45 |
+
"state_dict" : OrderedDict(slimmed_sd),
|
| 46 |
+
"step" : step}
|
| 47 |
+
if opt is not None:
|
| 48 |
+
checkpoint['optimizer_states'] = opt.state_dict()
|
| 49 |
+
import io
|
| 50 |
+
import fsspec
|
| 51 |
+
bytesbuffer = io.BytesIO()
|
| 52 |
+
torch.save(checkpoint, bytesbuffer)
|
| 53 |
+
with fsspec.open(path, "wb") as f:
|
| 54 |
+
f.write(bytesbuffer.getvalue())
|
| 55 |
+
|
| 56 |
+
def load_state_dict(net, cfg):
|
| 57 |
+
pretrained_pth_full = cfg.get('pretrained_pth_full' , None)
|
| 58 |
+
pretrained_ckpt_full = cfg.get('pretrained_ckpt_full', None)
|
| 59 |
+
pretrained_pth = cfg.get('pretrained_pth' , None)
|
| 60 |
+
pretrained_ckpt = cfg.get('pretrained_ckpt' , None)
|
| 61 |
+
pretrained_pth_dm = cfg.get('pretrained_pth_dm' , None)
|
| 62 |
+
pretrained_pth_ema = cfg.get('pretrained_pth_ema' , None)
|
| 63 |
+
strict_sd = cfg.get('strict_sd', False)
|
| 64 |
+
errmsg = "Overlapped model state_dict! This is undesired behavior!"
|
| 65 |
+
|
| 66 |
+
if pretrained_pth_full is not None or pretrained_ckpt_full is not None:
|
| 67 |
+
assert (pretrained_pth is None) and \
|
| 68 |
+
(pretrained_ckpt is None) and \
|
| 69 |
+
(pretrained_pth_dm is None) and \
|
| 70 |
+
(pretrained_pth_ema is None), errmsg
|
| 71 |
+
if pretrained_pth_full is not None:
|
| 72 |
+
target_file = pretrained_pth_full
|
| 73 |
+
sd = torch.load(target_file, map_location='cpu')
|
| 74 |
+
assert pretrained_ckpt is None, errmsg
|
| 75 |
+
else:
|
| 76 |
+
target_file = pretrained_ckpt_full
|
| 77 |
+
sd = torch.load(target_file, map_location='cpu')['state_dict']
|
| 78 |
+
print_log('Load full model from [{}] strict [{}].'.format(
|
| 79 |
+
target_file, strict_sd))
|
| 80 |
+
net.load_state_dict(sd, strict=strict_sd)
|
| 81 |
+
|
| 82 |
+
if pretrained_pth is not None or pretrained_ckpt is not None:
|
| 83 |
+
assert (pretrained_ckpt_full is None) and \
|
| 84 |
+
(pretrained_pth_full is None) and \
|
| 85 |
+
(pretrained_pth_dm is None) and \
|
| 86 |
+
(pretrained_pth_ema is None), errmsg
|
| 87 |
+
if pretrained_pth is not None:
|
| 88 |
+
target_file = pretrained_pth
|
| 89 |
+
sd = torch.load(target_file, map_location='cpu')
|
| 90 |
+
assert pretrained_ckpt is None, errmsg
|
| 91 |
+
else:
|
| 92 |
+
target_file = pretrained_ckpt
|
| 93 |
+
sd = torch.load(target_file, map_location='cpu')['state_dict']
|
| 94 |
+
print_log('Load model from [{}] strict [{}].'.format(
|
| 95 |
+
target_file, strict_sd))
|
| 96 |
+
sd_extra = [(ki, vi) for ki, vi in net.state_dict().items() \
|
| 97 |
+
if ki.find('first_stage_model')==0 or ki.find('cond_stage_model')==0]
|
| 98 |
+
sd.update(OrderedDict(sd_extra))
|
| 99 |
+
net.load_state_dict(sd, strict=strict_sd)
|
| 100 |
+
|
| 101 |
+
if pretrained_pth_dm is not None:
|
| 102 |
+
assert (pretrained_ckpt_full is None) and \
|
| 103 |
+
(pretrained_pth_full is None) and \
|
| 104 |
+
(pretrained_pth is None) and \
|
| 105 |
+
(pretrained_ckpt is None), errmsg
|
| 106 |
+
print_log('Load diffusion model from [{}] strict [{}].'.format(
|
| 107 |
+
pretrained_pth_dm, strict_sd))
|
| 108 |
+
sd = torch.load(pretrained_pth_dm, map_location='cpu')
|
| 109 |
+
net.model.diffusion_model.load_state_dict(sd, strict=strict_sd)
|
| 110 |
+
|
| 111 |
+
if pretrained_pth_ema is not None:
|
| 112 |
+
assert (pretrained_ckpt_full is None) and \
|
| 113 |
+
(pretrained_pth_full is None) and \
|
| 114 |
+
(pretrained_pth is None) and \
|
| 115 |
+
(pretrained_ckpt is None), errmsg
|
| 116 |
+
print_log('Load unet ema model from [{}] strict [{}].'.format(
|
| 117 |
+
pretrained_pth_ema, strict_sd))
|
| 118 |
+
sd = torch.load(pretrained_pth_ema, map_location='cpu')
|
| 119 |
+
net.model_ema.load_state_dict(sd, strict=strict_sd)
|
| 120 |
+
|
| 121 |
+
def auto_merge_imlist(imlist, max=64):
|
| 122 |
+
imlist = imlist[0:max]
|
| 123 |
+
h, w = imlist[0].shape[0:2]
|
| 124 |
+
num_images = len(imlist)
|
| 125 |
+
num_row = int(np.sqrt(num_images))
|
| 126 |
+
num_col = num_images//num_row + 1 if num_images%num_row!=0 else num_images//num_row
|
| 127 |
+
canvas = np.zeros([num_row*h, num_col*w, 3], dtype=np.uint8)
|
| 128 |
+
for idx, im in enumerate(imlist):
|
| 129 |
+
hi = (idx // num_col) * h
|
| 130 |
+
wi = (idx % num_col) * w
|
| 131 |
+
canvas[hi:hi+h, wi:wi+w, :] = im
|
| 132 |
+
return canvas
|
| 133 |
+
|
| 134 |
+
def latent2im(net, latent):
|
| 135 |
+
single_input = len(latent.shape) == 3
|
| 136 |
+
if single_input:
|
| 137 |
+
latent = latent[None]
|
| 138 |
+
im = net.decode_image(latent.to(net.device))
|
| 139 |
+
im = torch.clamp((im+1.0)/2.0, min=0.0, max=1.0)
|
| 140 |
+
im = [tvtrans.ToPILImage()(i) for i in im]
|
| 141 |
+
if single_input:
|
| 142 |
+
im = im[0]
|
| 143 |
+
return im
|
| 144 |
+
|
| 145 |
+
def im2latent(net, im):
|
| 146 |
+
single_input = not isinstance(im, list)
|
| 147 |
+
if single_input:
|
| 148 |
+
im = [im]
|
| 149 |
+
im = torch.stack([tvtrans.ToTensor()(i) for i in im], dim=0)
|
| 150 |
+
im = (im*2-1).to(net.device)
|
| 151 |
+
z = net.encode_image(im)
|
| 152 |
+
if single_input:
|
| 153 |
+
z = z[0]
|
| 154 |
+
return z
|
| 155 |
+
|
| 156 |
+
class color_adjust(object):
|
| 157 |
+
def __init__(self, ref_from, ref_to):
|
| 158 |
+
x0, m0, std0 = self.get_data_and_stat(ref_from)
|
| 159 |
+
x1, m1, std1 = self.get_data_and_stat(ref_to)
|
| 160 |
+
self.ref_from_stat = (m0, std0)
|
| 161 |
+
self.ref_to_stat = (m1, std1)
|
| 162 |
+
self.ref_from = self.preprocess(x0).reshape(-1, 3)
|
| 163 |
+
self.ref_to = x1.reshape(-1, 3)
|
| 164 |
+
|
| 165 |
+
def get_data_and_stat(self, x):
|
| 166 |
+
if isinstance(x, str):
|
| 167 |
+
x = np.array(PIL.Image.open(x))
|
| 168 |
+
elif isinstance(x, PIL.Image.Image):
|
| 169 |
+
x = np.array(x)
|
| 170 |
+
elif isinstance(x, torch.Tensor):
|
| 171 |
+
x = torch.clamp(x, min=0.0, max=1.0)
|
| 172 |
+
x = np.array(tvtrans.ToPILImage()(x))
|
| 173 |
+
elif isinstance(x, np.ndarray):
|
| 174 |
+
pass
|
| 175 |
+
else:
|
| 176 |
+
raise ValueError
|
| 177 |
+
x = x.astype(float)
|
| 178 |
+
m = np.reshape(x, (-1, 3)).mean(0)
|
| 179 |
+
s = np.reshape(x, (-1, 3)).std(0)
|
| 180 |
+
return x, m, s
|
| 181 |
+
|
| 182 |
+
def preprocess(self, x):
|
| 183 |
+
m0, s0 = self.ref_from_stat
|
| 184 |
+
m1, s1 = self.ref_to_stat
|
| 185 |
+
y = ((x-m0)/s0)*s1 + m1
|
| 186 |
+
return y
|
| 187 |
+
|
| 188 |
+
def __call__(self, xin, keep=0, simple=False):
|
| 189 |
+
xin, _, _ = self.get_data_and_stat(xin)
|
| 190 |
+
x = self.preprocess(xin)
|
| 191 |
+
if simple:
|
| 192 |
+
y = (x*(1-keep) + xin*keep)
|
| 193 |
+
y = np.clip(y, 0, 255).astype(np.uint8)
|
| 194 |
+
return y
|
| 195 |
+
|
| 196 |
+
h, w = x.shape[:2]
|
| 197 |
+
x = x.reshape(-1, 3)
|
| 198 |
+
y = []
|
| 199 |
+
for chi in range(3):
|
| 200 |
+
yi = self.pdf_transfer_1d(self.ref_from[:, chi], self.ref_to[:, chi], x[:, chi])
|
| 201 |
+
y.append(yi)
|
| 202 |
+
|
| 203 |
+
y = np.stack(y, axis=1)
|
| 204 |
+
y = y.reshape(h, w, 3)
|
| 205 |
+
y = (y.astype(float)*(1-keep) + xin.astype(float)*keep)
|
| 206 |
+
y = np.clip(y, 0, 255).astype(np.uint8)
|
| 207 |
+
return y
|
| 208 |
+
|
| 209 |
+
def pdf_transfer_1d(self, arr_fo, arr_to, arr_in, n=600):
|
| 210 |
+
arr = np.concatenate((arr_fo, arr_to))
|
| 211 |
+
min_v = arr.min() - 1e-6
|
| 212 |
+
max_v = arr.max() + 1e-6
|
| 213 |
+
min_vto = arr_to.min() - 1e-6
|
| 214 |
+
max_vto = arr_to.max() + 1e-6
|
| 215 |
+
xs = np.array(
|
| 216 |
+
[min_v + (max_v - min_v) * i / n for i in range(n + 1)])
|
| 217 |
+
hist_fo, _ = np.histogram(arr_fo, xs)
|
| 218 |
+
hist_to, _ = np.histogram(arr_to, xs)
|
| 219 |
+
xs = xs[:-1]
|
| 220 |
+
# compute probability distribution
|
| 221 |
+
cum_fo = np.cumsum(hist_fo)
|
| 222 |
+
cum_to = np.cumsum(hist_to)
|
| 223 |
+
d_fo = cum_fo / cum_fo[-1]
|
| 224 |
+
d_to = cum_to / cum_to[-1]
|
| 225 |
+
# transfer
|
| 226 |
+
t_d = np.interp(d_fo, d_to, xs)
|
| 227 |
+
t_d[d_fo <= d_to[ 0]] = min_vto
|
| 228 |
+
t_d[d_fo >= d_to[-1]] = max_vto
|
| 229 |
+
arr_out = np.interp(arr_in, xs, t_d)
|
| 230 |
+
return arr_out
|
| 231 |
+
|
| 232 |
+
########
|
| 233 |
+
# main #
|
| 234 |
+
########
|
| 235 |
+
|
| 236 |
+
class eval(eval_base):
|
| 237 |
+
def prepare_model(self):
|
| 238 |
+
cfg = cfguh().cfg
|
| 239 |
+
net = get_model()(cfg.model)
|
| 240 |
+
if cfg.env.cuda:
|
| 241 |
+
net.to(self.local_rank)
|
| 242 |
+
load_state_dict(net, cfg.eval) #<--- added
|
| 243 |
+
net = torch.nn.parallel.DistributedDataParallel(
|
| 244 |
+
net, device_ids=[self.local_rank],
|
| 245 |
+
find_unused_parameters=True)
|
| 246 |
+
net.eval()
|
| 247 |
+
return {'net' : net,}
|
| 248 |
+
|
| 249 |
+
class eval_stage(esbase):
|
| 250 |
+
"""
|
| 251 |
+
This is eval stage that can check comprehensive results
|
| 252 |
+
"""
|
| 253 |
+
def __init__(self):
|
| 254 |
+
from ..model_zoo.ddim import DDIMSampler
|
| 255 |
+
self.sampler = DDIMSampler
|
| 256 |
+
|
| 257 |
+
def get_net(self, paras):
|
| 258 |
+
return paras['net']
|
| 259 |
+
|
| 260 |
+
def get_image_path(self):
|
| 261 |
+
if 'train' in cfguh().cfg:
|
| 262 |
+
log_dir = cfguh().cfg.train.log_dir
|
| 263 |
+
else:
|
| 264 |
+
log_dir = cfguh().cfg.eval.log_dir
|
| 265 |
+
return os.path.join(log_dir, "udemo")
|
| 266 |
+
|
| 267 |
+
@torch.no_grad()
|
| 268 |
+
def sample(self, net, sampler, prompt, output_dim, scale, n_samples, ddim_steps, ddim_eta):
|
| 269 |
+
h, w = output_dim
|
| 270 |
+
uc = None
|
| 271 |
+
if scale != 1.0:
|
| 272 |
+
uc = net.get_learned_conditioning(n_samples * [""])
|
| 273 |
+
c = net.get_learned_conditioning(n_samples * [prompt])
|
| 274 |
+
shape = [4, h//8, w//8]
|
| 275 |
+
rv = sampler.sample(
|
| 276 |
+
S=ddim_steps,
|
| 277 |
+
conditioning=c,
|
| 278 |
+
batch_size=n_samples,
|
| 279 |
+
shape=shape,
|
| 280 |
+
verbose=False,
|
| 281 |
+
unconditional_guidance_scale=scale,
|
| 282 |
+
unconditional_conditioning=uc,
|
| 283 |
+
eta=ddim_eta)
|
| 284 |
+
return rv
|
| 285 |
+
|
| 286 |
+
def save_images(self, pil_list, name, path, suffix=''):
|
| 287 |
+
canvas = auto_merge_imlist([np.array(i) for i in pil_list])
|
| 288 |
+
image_name = '{}{}.png'.format(name, suffix)
|
| 289 |
+
PIL.Image.fromarray(canvas).save(osp.join(path, image_name))
|
| 290 |
+
|
| 291 |
+
def __call__(self, **paras):
|
| 292 |
+
cfg = cfguh().cfg
|
| 293 |
+
cfgv = cfg.eval
|
| 294 |
+
|
| 295 |
+
net = paras['net']
|
| 296 |
+
eval_cnt = paras.get('eval_cnt', None)
|
| 297 |
+
fix_seed = cfgv.get('fix_seed', False)
|
| 298 |
+
|
| 299 |
+
LRANK = sync.get_rank('local')
|
| 300 |
+
LWSIZE = sync.get_world_size('local')
|
| 301 |
+
|
| 302 |
+
image_path = self.get_image_path()
|
| 303 |
+
self.create_dir(image_path)
|
| 304 |
+
eval_cnt = paras.get('eval_cnt', None)
|
| 305 |
+
suffix='' if eval_cnt is None else '_itern'+str(eval_cnt)
|
| 306 |
+
|
| 307 |
+
if isinstance(net, (torch.nn.DataParallel,
|
| 308 |
+
torch.nn.parallel.DistributedDataParallel)):
|
| 309 |
+
netm = net.module
|
| 310 |
+
else:
|
| 311 |
+
netm = net
|
| 312 |
+
|
| 313 |
+
with_ema = getattr(netm, 'model_ema', None) is not None
|
| 314 |
+
sampler = self.sampler(netm)
|
| 315 |
+
setattr(netm, 'device', LRANK) # Trick
|
| 316 |
+
|
| 317 |
+
replicate = cfgv.get('replicate', 1)
|
| 318 |
+
conditioning = cfgv.conditioning * replicate
|
| 319 |
+
conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
|
| 320 |
+
seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]
|
| 321 |
+
|
| 322 |
+
for prompti, seedi in zip(conditioning_local, seed_increment):
|
| 323 |
+
if prompti == 'SKIP':
|
| 324 |
+
continue
|
| 325 |
+
draw_filename = prompti.strip().replace(' ', '-')
|
| 326 |
+
if fix_seed:
|
| 327 |
+
np.random.seed(cfg.env.rnd_seed + seedi)
|
| 328 |
+
torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
|
| 329 |
+
suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
|
| 330 |
+
else:
|
| 331 |
+
suffixi = suffix
|
| 332 |
+
|
| 333 |
+
if with_ema:
|
| 334 |
+
with netm.ema_scope():
|
| 335 |
+
x, _ = self.sample(netm, sampler, prompti, **cfgv.sample)
|
| 336 |
+
else:
|
| 337 |
+
x, _ = self.sample(netm, sampler, prompti, **cfgv.sample)
|
| 338 |
+
|
| 339 |
+
demo_image = latent2im(netm, x)
|
| 340 |
+
self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)
|
| 341 |
+
|
| 342 |
+
if eval_cnt is not None:
|
| 343 |
+
print_log('Demo printed for {}'.format(eval_cnt))
|
| 344 |
+
return {}
|
| 345 |
+
|
| 346 |
+
##################
|
| 347 |
+
# eval variation #
|
| 348 |
+
##################
|
| 349 |
+
|
| 350 |
+
class eval_stage_variation(eval_stage):
|
| 351 |
+
@torch.no_grad()
|
| 352 |
+
def sample(self, net, sampler, visual_hint, output_dim, scale, n_samples, ddim_steps, ddim_eta):
|
| 353 |
+
h, w = output_dim
|
| 354 |
+
vh = tvtrans.ToTensor()(PIL.Image.open(visual_hint))[None].to(net.device)
|
| 355 |
+
c = net.get_learned_conditioning(vh)
|
| 356 |
+
c = c.repeat(n_samples, 1, 1)
|
| 357 |
+
uc = None
|
| 358 |
+
if scale != 1.0:
|
| 359 |
+
dummy = torch.zeros_like(vh)
|
| 360 |
+
uc = net.get_learned_conditioning(dummy)
|
| 361 |
+
uc = uc.repeat(n_samples, 1, 1)
|
| 362 |
+
|
| 363 |
+
shape = [4, h//8, w//8]
|
| 364 |
+
rv = sampler.sample(
|
| 365 |
+
S=ddim_steps,
|
| 366 |
+
conditioning=c,
|
| 367 |
+
batch_size=n_samples,
|
| 368 |
+
shape=shape,
|
| 369 |
+
verbose=False,
|
| 370 |
+
unconditional_guidance_scale=scale,
|
| 371 |
+
unconditional_conditioning=uc,
|
| 372 |
+
eta=ddim_eta)
|
| 373 |
+
return rv
|
| 374 |
+
|
| 375 |
+
def __call__(self, **paras):
|
| 376 |
+
cfg = cfguh().cfg
|
| 377 |
+
cfgv = cfg.eval
|
| 378 |
+
|
| 379 |
+
net = paras['net']
|
| 380 |
+
eval_cnt = paras.get('eval_cnt', None)
|
| 381 |
+
fix_seed = cfgv.get('fix_seed', False)
|
| 382 |
+
|
| 383 |
+
LRANK = sync.get_rank('local')
|
| 384 |
+
LWSIZE = sync.get_world_size('local')
|
| 385 |
+
|
| 386 |
+
image_path = self.get_image_path()
|
| 387 |
+
self.create_dir(image_path)
|
| 388 |
+
eval_cnt = paras.get('eval_cnt', None)
|
| 389 |
+
suffix='' if eval_cnt is None else '_'+str(eval_cnt)
|
| 390 |
+
|
| 391 |
+
if isinstance(net, (torch.nn.DataParallel,
|
| 392 |
+
torch.nn.parallel.DistributedDataParallel)):
|
| 393 |
+
netm = net.module
|
| 394 |
+
else:
|
| 395 |
+
netm = net
|
| 396 |
+
|
| 397 |
+
with_ema = getattr(netm, 'model_ema', None) is not None
|
| 398 |
+
sampler = self.sampler(netm)
|
| 399 |
+
setattr(netm, 'device', LRANK) # Trick
|
| 400 |
+
|
| 401 |
+
color_adj = cfguh().cfg.eval.get('color_adj', False)
|
| 402 |
+
color_adj_keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
|
| 403 |
+
color_adj_simple = cfguh().cfg.eval.get('color_adj_simple', True)
|
| 404 |
+
|
| 405 |
+
replicate = cfgv.get('replicate', 1)
|
| 406 |
+
conditioning = cfgv.conditioning * replicate
|
| 407 |
+
conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
|
| 408 |
+
seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]
|
| 409 |
+
|
| 410 |
+
for ci, seedi in zip(conditioning_local, seed_increment):
|
| 411 |
+
if ci == 'SKIP':
|
| 412 |
+
continue
|
| 413 |
+
|
| 414 |
+
draw_filename = osp.splitext(osp.basename(ci))[0]
|
| 415 |
+
|
| 416 |
+
if fix_seed:
|
| 417 |
+
np.random.seed(cfg.env.rnd_seed + seedi)
|
| 418 |
+
torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
|
| 419 |
+
suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
|
| 420 |
+
else:
|
| 421 |
+
suffixi = suffix
|
| 422 |
+
|
| 423 |
+
if with_ema:
|
| 424 |
+
with netm.ema_scope():
|
| 425 |
+
x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
|
| 426 |
+
else:
|
| 427 |
+
x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
|
| 428 |
+
|
| 429 |
+
demo_image = latent2im(netm, x)
|
| 430 |
+
if color_adj:
|
| 431 |
+
x_adj = []
|
| 432 |
+
for demoi in demo_image:
|
| 433 |
+
color_adj_f = color_adjust(ref_from=demoi, ref_to=ci)
|
| 434 |
+
xi_adj = color_adj_f(demoi, keep=color_adj_keep_ratio, simple=color_adj_simple)
|
| 435 |
+
x_adj.append(xi_adj)
|
| 436 |
+
demo_image = x_adj
|
| 437 |
+
self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)
|
| 438 |
+
|
| 439 |
+
if eval_cnt is not None:
|
| 440 |
+
print_log('Demo printed for {}'.format(eval_cnt))
|
| 441 |
+
return {}
|
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/experiments/vd_default.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
from torchvision import transforms as tvtrans
|
| 4 |
+
import os
|
| 5 |
+
import os.path as osp
|
| 6 |
+
import time
|
| 7 |
+
import timeit
|
| 8 |
+
import copy
|
| 9 |
+
import json
|
| 10 |
+
import pickle
|
| 11 |
+
import PIL.Image
|
| 12 |
+
import numpy as np
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from easydict import EasyDict as edict
|
| 15 |
+
from collections import OrderedDict
|
| 16 |
+
|
| 17 |
+
from lib.cfg_holder import cfg_unique_holder as cfguh
|
| 18 |
+
from lib.data_factory import get_dataset, get_sampler, collate
|
| 19 |
+
from lib.model_zoo import \
|
| 20 |
+
get_model, get_optimizer, get_scheduler
|
| 21 |
+
from lib.log_service import print_log
|
| 22 |
+
|
| 23 |
+
from ..utils import train as train_base
|
| 24 |
+
from ..utils import eval as eval_base
|
| 25 |
+
from ..utils import train_stage as tsbase
|
| 26 |
+
from ..utils import eval_stage as esbase
|
| 27 |
+
from .. import sync
|
| 28 |
+
|
| 29 |
+
from .sd_default import auto_merge_imlist, latent2im, color_adjust
|
| 30 |
+
from .sd_default import eval as eval_base
|
| 31 |
+
from .sd_default import eval_stage as eval_stage_base
|
| 32 |
+
|
| 33 |
+
###############
|
| 34 |
+
# some helper #
|
| 35 |
+
###############
|
| 36 |
+
|
| 37 |
+
def atomic_save(cfg, net, opt, step, path):
|
| 38 |
+
if isinstance(net, (torch.nn.DataParallel,
|
| 39 |
+
torch.nn.parallel.DistributedDataParallel)):
|
| 40 |
+
netm = net.module
|
| 41 |
+
else:
|
| 42 |
+
netm = net
|
| 43 |
+
sd = netm.state_dict()
|
| 44 |
+
slimmed_sd = [(ki, vi) for ki, vi in sd.items()
|
| 45 |
+
if ki.find('autokl')!=0 and ki.find('optimus')!=0 and ki.find('clip')!=0]
|
| 46 |
+
|
| 47 |
+
checkpoint = {
|
| 48 |
+
"config" : cfg,
|
| 49 |
+
"state_dict" : OrderedDict(slimmed_sd),
|
| 50 |
+
"step" : step}
|
| 51 |
+
if opt is not None:
|
| 52 |
+
checkpoint['optimizer_states'] = opt.state_dict()
|
| 53 |
+
import io
|
| 54 |
+
import fsspec
|
| 55 |
+
bytesbuffer = io.BytesIO()
|
| 56 |
+
torch.save(checkpoint, bytesbuffer)
|
| 57 |
+
with fsspec.open(path, "wb") as f:
|
| 58 |
+
f.write(bytesbuffer.getvalue())
|
| 59 |
+
|
| 60 |
+
def load_state_dict(net, cfg):
|
| 61 |
+
pretrained_pth_full = cfg.get('pretrained_pth_full' , None)
|
| 62 |
+
pretrained_ckpt_full = cfg.get('pretrained_ckpt_full', None)
|
| 63 |
+
pretrained_pth = cfg.get('pretrained_pth' , None)
|
| 64 |
+
pretrained_ckpt = cfg.get('pretrained_ckpt' , None)
|
| 65 |
+
pretrained_pth_dm = cfg.get('pretrained_pth_dm' , None)
|
| 66 |
+
pretrained_pth_ema = cfg.get('pretrained_pth_ema' , None)
|
| 67 |
+
strict_sd = cfg.get('strict_sd', False)
|
| 68 |
+
errmsg = "Overlapped model state_dict! This is undesired behavior!"
|
| 69 |
+
|
| 70 |
+
if pretrained_pth_full is not None or pretrained_ckpt_full is not None:
|
| 71 |
+
assert (pretrained_pth is None) and \
|
| 72 |
+
(pretrained_ckpt is None) and \
|
| 73 |
+
(pretrained_pth_dm is None) and \
|
| 74 |
+
(pretrained_pth_ema is None), errmsg
|
| 75 |
+
if pretrained_pth_full is not None:
|
| 76 |
+
target_file = pretrained_pth_full
|
| 77 |
+
sd = torch.load(target_file, map_location='cpu')
|
| 78 |
+
assert pretrained_ckpt is None, errmsg
|
| 79 |
+
else:
|
| 80 |
+
target_file = pretrained_ckpt_full
|
| 81 |
+
sd = torch.load(target_file, map_location='cpu')['state_dict']
|
| 82 |
+
print_log('Load full model from [{}] strict [{}].'.format(
|
| 83 |
+
target_file, strict_sd))
|
| 84 |
+
net.load_state_dict(sd, strict=strict_sd)
|
| 85 |
+
|
| 86 |
+
if pretrained_pth is not None or pretrained_ckpt is not None:
|
| 87 |
+
assert (pretrained_ckpt_full is None) and \
|
| 88 |
+
(pretrained_pth_full is None) and \
|
| 89 |
+
(pretrained_pth_dm is None) and \
|
| 90 |
+
(pretrained_pth_ema is None), errmsg
|
| 91 |
+
if pretrained_pth is not None:
|
| 92 |
+
target_file = pretrained_pth
|
| 93 |
+
sd = torch.load(target_file, map_location='cpu')
|
| 94 |
+
assert pretrained_ckpt is None, errmsg
|
| 95 |
+
else:
|
| 96 |
+
target_file = pretrained_ckpt
|
| 97 |
+
sd = torch.load(target_file, map_location='cpu')['state_dict']
|
| 98 |
+
print_log('Load model from [{}] strict [{}].'.format(
|
| 99 |
+
target_file, strict_sd))
|
| 100 |
+
sd_extra = [(ki, vi) for ki, vi in net.state_dict().items() \
|
| 101 |
+
if ki.find('autokl')==0 or ki.find('optimus')==0 or ki.find('clip')==0]
|
| 102 |
+
sd.update(OrderedDict(sd_extra))
|
| 103 |
+
net.load_state_dict(sd, strict=strict_sd)
|
| 104 |
+
|
| 105 |
+
if pretrained_pth_dm is not None:
|
| 106 |
+
assert (pretrained_ckpt_full is None) and \
|
| 107 |
+
(pretrained_pth_full is None) and \
|
| 108 |
+
(pretrained_pth is None) and \
|
| 109 |
+
(pretrained_ckpt is None), errmsg
|
| 110 |
+
print_log('Load diffusion model from [{}] strict [{}].'.format(
|
| 111 |
+
pretrained_pth_dm, strict_sd))
|
| 112 |
+
sd = torch.load(pretrained_pth_dm, map_location='cpu')
|
| 113 |
+
net.model.diffusion_model.load_state_dict(sd, strict=strict_sd)
|
| 114 |
+
|
| 115 |
+
if pretrained_pth_ema is not None:
|
| 116 |
+
assert (pretrained_ckpt_full is None) and \
|
| 117 |
+
(pretrained_pth_full is None) and \
|
| 118 |
+
(pretrained_pth is None) and \
|
| 119 |
+
(pretrained_ckpt is None), errmsg
|
| 120 |
+
print_log('Load unet ema model from [{}] strict [{}].'.format(
|
| 121 |
+
pretrained_pth_ema, strict_sd))
|
| 122 |
+
sd = torch.load(pretrained_pth_ema, map_location='cpu')
|
| 123 |
+
net.model_ema.load_state_dict(sd, strict=strict_sd)
|
| 124 |
+
|
| 125 |
+
###################
|
| 126 |
+
# official stages #
|
| 127 |
+
###################
|
| 128 |
+
|
| 129 |
+
class eval(eval_base):
|
| 130 |
+
pass
|
| 131 |
+
|
| 132 |
+
class eval_stage(eval_stage_base):
|
| 133 |
+
"""
|
| 134 |
+
Evaluation of both prompt and vision
|
| 135 |
+
"""
|
| 136 |
+
def __init__(self):
|
| 137 |
+
from ..model_zoo.ddim_vd import DDIMSampler_VD
|
| 138 |
+
self.sampler = DDIMSampler_VD
|
| 139 |
+
|
| 140 |
+
@torch.no_grad()
|
| 141 |
+
def sample(
|
| 142 |
+
self, net, sampler, context, otype, ctype, image_output_dim, text_latent_dim,
|
| 143 |
+
scale, n_samples, ddim_steps, ddim_eta):
|
| 144 |
+
if ctype == 'prompt':
|
| 145 |
+
c = net.clip_encode_text(n_samples * [context])
|
| 146 |
+
uc = None
|
| 147 |
+
if scale != 1.0:
|
| 148 |
+
uc = net.clip_encode_text(n_samples * [""])
|
| 149 |
+
elif ctype == 'vision':
|
| 150 |
+
context = context[None].repeat(n_samples, 1, 1, 1)
|
| 151 |
+
c = net.clip_encode_vision(context)
|
| 152 |
+
uc = None
|
| 153 |
+
if scale != 1.0:
|
| 154 |
+
dummy = torch.zeros_like(context)
|
| 155 |
+
uc = net.clip_encode_vision(dummy)
|
| 156 |
+
|
| 157 |
+
if otype == 'image':
|
| 158 |
+
h, w = image_output_dim
|
| 159 |
+
shape = [n_samples, 4, h//8, w//8]
|
| 160 |
+
rv = sampler.sample(
|
| 161 |
+
steps=ddim_steps,
|
| 162 |
+
shape=shape,
|
| 163 |
+
conditioning=c,
|
| 164 |
+
unconditional_guidance_scale=scale,
|
| 165 |
+
unconditional_conditioning=uc,
|
| 166 |
+
xtype=otype, ctype=ctype,
|
| 167 |
+
eta=ddim_eta,
|
| 168 |
+
verbose=False,)
|
| 169 |
+
elif otype == 'text':
|
| 170 |
+
n = text_latent_dim
|
| 171 |
+
shape = [n_samples, n]
|
| 172 |
+
rv = sampler.sample(
|
| 173 |
+
steps=ddim_steps,
|
| 174 |
+
shape=shape,
|
| 175 |
+
conditioning=c,
|
| 176 |
+
unconditional_guidance_scale=scale,
|
| 177 |
+
unconditional_conditioning=uc,
|
| 178 |
+
xtype=otype, ctype=ctype,
|
| 179 |
+
eta=ddim_eta,
|
| 180 |
+
verbose=False,)
|
| 181 |
+
|
| 182 |
+
return rv
|
| 183 |
+
|
| 184 |
+
def decode_and_save(
|
| 185 |
+
self, netm, z, xtype, ctype, path, name, suffix,
|
| 186 |
+
color_adj=False, color_adj_to=None):
|
| 187 |
+
if xtype == 'image':
|
| 188 |
+
x = netm.autokl_decode(z)
|
| 189 |
+
name = 't2i_'+name if ctype == 'prompt' else 'v2i_'+name
|
| 190 |
+
if color_adj and (ctype=='vision'):
|
| 191 |
+
keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
|
| 192 |
+
simple = cfguh().cfg.eval.get('color_adj_simple', True)
|
| 193 |
+
x_adj = []
|
| 194 |
+
for xi in x:
|
| 195 |
+
color_adj_f = color_adjust(ref_from=(xi+1)/2, ref_to=color_adj_to)
|
| 196 |
+
xi_adj = color_adj_f((xi+1)/2, keep=keep_ratio, simple=simple)
|
| 197 |
+
x_adj.append(xi_adj)
|
| 198 |
+
x = x_adj
|
| 199 |
+
self.save_images(x, name, path, suffix=suffix)
|
| 200 |
+
elif xtype == 'text':
|
| 201 |
+
prompt_temperature = cfguh().cfg.eval.get('prompt_temperature', 1.0)
|
| 202 |
+
x = netm.optimus_decode(z, temperature=prompt_temperature)
|
| 203 |
+
name = 't2t_'+name if ctype == 'prompt' else 'v2t_'+name
|
| 204 |
+
prompt_merge_same_adj_word = cfguh().cfg.eval.get('prompt_merge_same_adj_word', False)
|
| 205 |
+
if prompt_merge_same_adj_word:
|
| 206 |
+
xnew = []
|
| 207 |
+
for xi in x:
|
| 208 |
+
xi_split = xi.split()
|
| 209 |
+
xinew = []
|
| 210 |
+
for idxi, wi in enumerate(xi_split):
|
| 211 |
+
if idxi!=0 and wi==xi_split[idxi-1]:
|
| 212 |
+
continue
|
| 213 |
+
xinew.append(wi)
|
| 214 |
+
xnew.append(' '.join(xinew))
|
| 215 |
+
x = xnew
|
| 216 |
+
self.save_text(x, name, path, suffix=suffix)
|
| 217 |
+
|
| 218 |
+
def save_images(self, x, name, path, suffix=''):
|
| 219 |
+
if isinstance(x, torch.Tensor):
|
| 220 |
+
single_input = len(x.shape) == 3
|
| 221 |
+
if single_input:
|
| 222 |
+
x = x[None]
|
| 223 |
+
x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0)
|
| 224 |
+
x = [tvtrans.ToPILImage()(xi) for xi in x]
|
| 225 |
+
xlist = [np.array(xi) for xi in x]
|
| 226 |
+
elif isinstance(x, list):
|
| 227 |
+
xlist = x
|
| 228 |
+
canvas = auto_merge_imlist(xlist)
|
| 229 |
+
image_name = '{}{}.png'.format(name, suffix)
|
| 230 |
+
PIL.Image.fromarray(canvas).save(osp.join(path, image_name))
|
| 231 |
+
|
| 232 |
+
def save_text(self, x, name, path, suffix=''):
|
| 233 |
+
file_name = '{}{}.txt'.format(name, suffix)
|
| 234 |
+
with open(osp.join(path, file_name) ,'w') as f:
|
| 235 |
+
for xi in x:
|
| 236 |
+
f.write(xi+'\n')
|
| 237 |
+
|
| 238 |
+
def __call__(self, **paras):
|
| 239 |
+
cfg = cfguh().cfg
|
| 240 |
+
cfgv = cfg.eval
|
| 241 |
+
|
| 242 |
+
net = self.get_net(paras)
|
| 243 |
+
eval_cnt = paras.get('eval_cnt', None)
|
| 244 |
+
fix_seed = cfgv.get('fix_seed', False)
|
| 245 |
+
|
| 246 |
+
LRANK = sync.get_rank('local')
|
| 247 |
+
LWSIZE = sync.get_world_size('local')
|
| 248 |
+
|
| 249 |
+
output_path = self.get_image_path()
|
| 250 |
+
self.create_dir(output_path)
|
| 251 |
+
eval_cnt = paras.get('eval_cnt', None)
|
| 252 |
+
suffix='' if eval_cnt is None else '_'+str(eval_cnt)
|
| 253 |
+
|
| 254 |
+
if isinstance(net, (torch.nn.DataParallel,
|
| 255 |
+
torch.nn.parallel.DistributedDataParallel)):
|
| 256 |
+
netm = net.module
|
| 257 |
+
else:
|
| 258 |
+
netm = net
|
| 259 |
+
|
| 260 |
+
with_ema = getattr(netm, 'model_ema', None) is not None
|
| 261 |
+
sampler = self.sampler(netm)
|
| 262 |
+
setattr(netm, 'device', LRANK) # Trick
|
| 263 |
+
|
| 264 |
+
color_adj = cfguh().cfg.eval.get('color_adj', False)
|
| 265 |
+
|
| 266 |
+
replicate = cfgv.get('replicate', 1)
|
| 267 |
+
conditioning = cfgv.conditioning * replicate
|
| 268 |
+
conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
|
| 269 |
+
seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]
|
| 270 |
+
|
| 271 |
+
for conditioningi, seedi in zip(conditioning_local, seed_increment):
|
| 272 |
+
if conditioningi == 'SKIP':
|
| 273 |
+
continue
|
| 274 |
+
|
| 275 |
+
ci, otypei = conditioningi
|
| 276 |
+
|
| 277 |
+
if osp.isfile(ci):
|
| 278 |
+
# is vision
|
| 279 |
+
output_name = osp.splitext(osp.basename(ci))[0]
|
| 280 |
+
ci = tvtrans.ToTensor()(PIL.Image.open(ci))
|
| 281 |
+
ci = ci*2 - 1
|
| 282 |
+
ctypei = 'vision'
|
| 283 |
+
else:
|
| 284 |
+
# is prompt
|
| 285 |
+
output_name = ci.strip().replace(' ', '-')
|
| 286 |
+
ctypei = 'prompt'
|
| 287 |
+
|
| 288 |
+
if fix_seed:
|
| 289 |
+
np.random.seed(cfg.env.rnd_seed + seedi)
|
| 290 |
+
torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
|
| 291 |
+
suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
|
| 292 |
+
else:
|
| 293 |
+
suffixi = suffix
|
| 294 |
+
|
| 295 |
+
if with_ema:
|
| 296 |
+
with netm.ema_scope():
|
| 297 |
+
z, _ = self.sample(netm, sampler, ci, otypei, ctypei, **cfgv.sample)
|
| 298 |
+
else:
|
| 299 |
+
z, _ = self.sample(netm, sampler, ci, otypei, ctypei, **cfgv.sample)
|
| 300 |
+
|
| 301 |
+
self.decode_and_save(
|
| 302 |
+
netm, z, otypei, ctypei, output_path, output_name, suffixi,
|
| 303 |
+
color_adj=color_adj, color_adj_to=conditioningi[0],)
|
| 304 |
+
|
| 305 |
+
if eval_cnt is not None:
|
| 306 |
+
print_log('Demo printed for {}'.format(eval_cnt))
|
| 307 |
+
return {}
|
| 308 |
+
|
| 309 |
+
################
|
| 310 |
+
# basic stages #
|
| 311 |
+
################
|
| 312 |
+
|
| 313 |
+
class eval_stage_basic(eval_stage_base):
|
| 314 |
+
@torch.no_grad()
|
| 315 |
+
def sample(self, net, sampler, visual_hint, output_dim, scale, n_samples, ddim_steps, ddim_eta):
|
| 316 |
+
h, w = output_dim
|
| 317 |
+
vh = PIL.Image.open(visual_hint)
|
| 318 |
+
c = net.clip_encode_vision(n_samples * [vh])
|
| 319 |
+
uc = None
|
| 320 |
+
if scale != 1.0:
|
| 321 |
+
dummy = torch.zeros_like(tvtrans.ToTensor()(vh))
|
| 322 |
+
uc = net.clip_encode_vision(n_samples * [dummy])
|
| 323 |
+
|
| 324 |
+
shape = [4, h//8, w//8]
|
| 325 |
+
rv = sampler.sample(
|
| 326 |
+
S=ddim_steps,
|
| 327 |
+
conditioning=c,
|
| 328 |
+
batch_size=n_samples,
|
| 329 |
+
shape=shape,
|
| 330 |
+
verbose=False,
|
| 331 |
+
unconditional_guidance_scale=scale,
|
| 332 |
+
unconditional_conditioning=uc,
|
| 333 |
+
eta=ddim_eta)
|
| 334 |
+
return rv
|
| 335 |
+
|
| 336 |
+
def __call__(self, **paras):
|
| 337 |
+
cfg = cfguh().cfg
|
| 338 |
+
cfgv = cfg.eval
|
| 339 |
+
|
| 340 |
+
net = paras['net']
|
| 341 |
+
eval_cnt = paras.get('eval_cnt', None)
|
| 342 |
+
fix_seed = cfgv.get('fix_seed', False)
|
| 343 |
+
|
| 344 |
+
LRANK = sync.get_rank('local')
|
| 345 |
+
LWSIZE = sync.get_world_size('local')
|
| 346 |
+
|
| 347 |
+
image_path = self.get_image_path()
|
| 348 |
+
self.create_dir(image_path)
|
| 349 |
+
eval_cnt = paras.get('eval_cnt', None)
|
| 350 |
+
suffix='' if eval_cnt is None else '_'+str(eval_cnt)
|
| 351 |
+
|
| 352 |
+
if isinstance(net, (torch.nn.DataParallel,
|
| 353 |
+
torch.nn.parallel.DistributedDataParallel)):
|
| 354 |
+
netm = net.module
|
| 355 |
+
else:
|
| 356 |
+
netm = net
|
| 357 |
+
|
| 358 |
+
with_ema = getattr(netm, 'model_ema', None) is not None
|
| 359 |
+
sampler = self.sampler(netm)
|
| 360 |
+
setattr(netm, 'device', LRANK) # Trick
|
| 361 |
+
|
| 362 |
+
color_adj = cfguh().cfg.eval.get('color_adj', False)
|
| 363 |
+
color_adj_keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
|
| 364 |
+
color_adj_simple = cfguh().cfg.eval.get('color_adj_simple', True)
|
| 365 |
+
|
| 366 |
+
replicate = cfgv.get('replicate', 1)
|
| 367 |
+
conditioning = cfgv.conditioning * replicate
|
| 368 |
+
conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
|
| 369 |
+
seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]
|
| 370 |
+
|
| 371 |
+
for ci, seedi in zip(conditioning_local, seed_increment):
|
| 372 |
+
if ci == 'SKIP':
|
| 373 |
+
continue
|
| 374 |
+
draw_filename = osp.splitext(osp.basename(ci))[0]
|
| 375 |
+
if fix_seed:
|
| 376 |
+
np.random.seed(cfg.env.rnd_seed + seedi)
|
| 377 |
+
torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
|
| 378 |
+
suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
|
| 379 |
+
else:
|
| 380 |
+
suffixi = suffix
|
| 381 |
+
|
| 382 |
+
if with_ema:
|
| 383 |
+
with netm.ema_scope():
|
| 384 |
+
x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
|
| 385 |
+
else:
|
| 386 |
+
x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
|
| 387 |
+
|
| 388 |
+
demo_image = latent2im(netm, x)
|
| 389 |
+
if color_adj:
|
| 390 |
+
x_adj = []
|
| 391 |
+
for demoi in demo_image:
|
| 392 |
+
color_adj_f = color_adjust(ref_from=demoi, ref_to=ci)
|
| 393 |
+
xi_adj = color_adj_f(demoi, keep=color_adj_keep_ratio, simple=color_adj_simple)
|
| 394 |
+
x_adj.append(xi_adj)
|
| 395 |
+
demo_image = x_adj
|
| 396 |
+
self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)
|
| 397 |
+
|
| 398 |
+
if eval_cnt is not None:
|
| 399 |
+
print_log('Demo printed for {}'.format(eval_cnt))
|
| 400 |
+
return {}
|
| 401 |
+
|
| 402 |
+
#######################
|
| 403 |
+
# dual context stages #
|
| 404 |
+
#######################
|
| 405 |
+
|
| 406 |
+
class eval_stage_dc(eval_stage_base):
|
| 407 |
+
def __init__(self):
|
| 408 |
+
from ..model_zoo.ddim_dualcontext import DDIMSampler_DualContext
|
| 409 |
+
self.sampler = DDIMSampler_DualContext
|
| 410 |
+
|
| 411 |
+
@torch.no_grad()
|
| 412 |
+
def sample(
|
| 413 |
+
self, net, sampler, conditioning, output_dim,
|
| 414 |
+
scale, n_samples, ddim_steps, ddim_eta):
|
| 415 |
+
ctype, cvalue =conditioning
|
| 416 |
+
if ctype == 'prompt':
|
| 417 |
+
return self.sample_text(
|
| 418 |
+
net, sampler, cvalue, output_dim,
|
| 419 |
+
scale, n_samples, ddim_steps, ddim_eta)
|
| 420 |
+
elif ctype == 'vision':
|
| 421 |
+
return self.sample_vision(
|
| 422 |
+
net, sampler, cvalue, output_dim,
|
| 423 |
+
scale, n_samples, ddim_steps, ddim_eta)
|
| 424 |
+
else:
|
| 425 |
+
raise ValueError
|
| 426 |
+
|
| 427 |
+
@torch.no_grad()
|
| 428 |
+
def sample_text(
|
| 429 |
+
self, net, sampler, prompt, output_dim,
|
| 430 |
+
scale, n_samples, ddim_steps, ddim_eta):
|
| 431 |
+
h, w = output_dim
|
| 432 |
+
uc = None
|
| 433 |
+
if scale != 1.0:
|
| 434 |
+
uc = net.clip_encode_text(n_samples * [""])
|
| 435 |
+
c = net.clip_encode_text(n_samples * [prompt])
|
| 436 |
+
shape = [n_samples, 4, h//8, w//8]
|
| 437 |
+
rv = sampler.sample_text(
|
| 438 |
+
steps=ddim_steps,
|
| 439 |
+
shape=shape,
|
| 440 |
+
conditioning=c,
|
| 441 |
+
unconditional_guidance_scale=scale,
|
| 442 |
+
unconditional_conditioning=uc,
|
| 443 |
+
eta=ddim_eta,
|
| 444 |
+
verbose=False,)
|
| 445 |
+
return rv
|
| 446 |
+
|
| 447 |
+
@torch.no_grad()
|
| 448 |
+
def sample_vision(
|
| 449 |
+
self, net, sampler, visual_hint, output_dim,
|
| 450 |
+
scale, n_samples, ddim_steps, ddim_eta):
|
| 451 |
+
h, w = output_dim
|
| 452 |
+
if len(visual_hint.shape) == 3:
|
| 453 |
+
visual_hint=visual_hint[None].repeat(n_samples, 1, 1, 1)
|
| 454 |
+
else:
|
| 455 |
+
raise ValueError
|
| 456 |
+
|
| 457 |
+
c = net.clip_encode_vision(visual_hint)
|
| 458 |
+
uc = None
|
| 459 |
+
if scale != 1.0:
|
| 460 |
+
visual_hint_blank = torch.zeros_like(visual_hint)
|
| 461 |
+
uc = net.clip_encode_vision(visual_hint_blank)
|
| 462 |
+
|
| 463 |
+
shape = [n_samples, 4, h//8, w//8]
|
| 464 |
+
rv = sampler.sample_vision(
|
| 465 |
+
steps=ddim_steps,
|
| 466 |
+
shape=shape,
|
| 467 |
+
conditioning=c,
|
| 468 |
+
unconditional_guidance_scale=scale,
|
| 469 |
+
unconditional_conditioning=uc,
|
| 470 |
+
eta=ddim_eta,
|
| 471 |
+
verbose=False,)
|
| 472 |
+
return rv
|
| 473 |
+
|
| 474 |
+
def __call__(self, **paras):
|
| 475 |
+
cfg = cfguh().cfg
|
| 476 |
+
cfgv = cfg.eval
|
| 477 |
+
|
| 478 |
+
net = self.get_net(paras)
|
| 479 |
+
eval_cnt = paras.get('eval_cnt', None)
|
| 480 |
+
fix_seed = cfgv.get('fix_seed', False)
|
| 481 |
+
|
| 482 |
+
LRANK = sync.get_rank('local')
|
| 483 |
+
LWSIZE = sync.get_world_size('local')
|
| 484 |
+
|
| 485 |
+
image_path = self.get_image_path()
|
| 486 |
+
self.create_dir(image_path)
|
| 487 |
+
suffix='' if eval_cnt is None else '_'+str(eval_cnt)
|
| 488 |
+
|
| 489 |
+
if isinstance(net, (torch.nn.DataParallel,
|
| 490 |
+
torch.nn.parallel.DistributedDataParallel)):
|
| 491 |
+
netm = net.module
|
| 492 |
+
else:
|
| 493 |
+
netm = net
|
| 494 |
+
|
| 495 |
+
with_ema = getattr(netm, 'model_ema', None) is not None
|
| 496 |
+
sampler = self.sampler(netm)
|
| 497 |
+
setattr(netm, 'device', LRANK) # Trick
|
| 498 |
+
|
| 499 |
+
color_adj = cfguh().cfg.eval.get('color_adj', False)
|
| 500 |
+
color_adj_keep_ratio = cfguh().cfg.eval.get('color_adj_keep_ratio', 0.5)
|
| 501 |
+
color_adj_simple = cfguh().cfg.eval.get('color_adj_simple', True)
|
| 502 |
+
|
| 503 |
+
replicate = cfgv.get('replicate', 1)
|
| 504 |
+
conditioning = cfgv.conditioning * replicate
|
| 505 |
+
conditioning_local = conditioning[LRANK : len(conditioning) : LWSIZE]
|
| 506 |
+
seed_increment = [i for i in range(len(conditioning))][LRANK : len(conditioning) : LWSIZE]
|
| 507 |
+
|
| 508 |
+
for ci, seedi in zip(conditioning_local, seed_increment):
|
| 509 |
+
if ci == 'SKIP':
|
| 510 |
+
continue
|
| 511 |
+
|
| 512 |
+
if osp.isfile(ci):
|
| 513 |
+
# is vision
|
| 514 |
+
draw_filename = 'v2i_' + osp.splitext(osp.basename(ci))[0]
|
| 515 |
+
ci = tvtrans.ToTensor()(PIL.Image.open(ci))
|
| 516 |
+
ci = ci*2 - 1
|
| 517 |
+
ci = ('vision', ci)
|
| 518 |
+
else:
|
| 519 |
+
# is prompt
|
| 520 |
+
draw_filename = 't2i_' + ci.strip().replace(' ', '-')
|
| 521 |
+
ci = ('prompt', ci)
|
| 522 |
+
|
| 523 |
+
if fix_seed:
|
| 524 |
+
np.random.seed(cfg.env.rnd_seed + seedi)
|
| 525 |
+
torch.manual_seed(cfg.env.rnd_seed + seedi + 100)
|
| 526 |
+
suffixi = suffix + "_seed{}".format(cfg.env.rnd_seed + seedi + 100)
|
| 527 |
+
else:
|
| 528 |
+
suffixi = suffix
|
| 529 |
+
|
| 530 |
+
if with_ema:
|
| 531 |
+
with netm.ema_scope():
|
| 532 |
+
x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
|
| 533 |
+
else:
|
| 534 |
+
x, _ = self.sample(netm, sampler, ci, **cfgv.sample)
|
| 535 |
+
|
| 536 |
+
demo_image = latent2im(netm, x)
|
| 537 |
+
if color_adj and ci[0] == 'vision':
|
| 538 |
+
x_adj = []
|
| 539 |
+
for demoi in demo_image:
|
| 540 |
+
color_adj_f = color_adjust(ref_from=demoi, ref_to=ci[1])
|
| 541 |
+
xi_adj = color_adj_f(demoi, keep=color_adj_keep_ratio, simple=color_adj_simple)
|
| 542 |
+
x_adj.append(xi_adj)
|
| 543 |
+
demo_image = x_adj
|
| 544 |
+
self.save_images(demo_image, draw_filename, image_path, suffix=suffixi)
|
| 545 |
+
|
| 546 |
+
if eval_cnt is not None:
|
| 547 |
+
print_log('Demo printed for {}'.format(eval_cnt))
|
| 548 |
+
return {}
|
| 549 |
+
|
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .common.get_model import get_model
|
| 2 |
+
from .common.get_optimizer import get_optimizer
|
| 3 |
+
from .common.get_scheduler import get_scheduler
|
| 4 |
+
from .common.utils import get_unit
|
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/attention.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat

from .diffusion_utils import checkpoint


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)


class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        w_ = torch.einsum('bij,bjk->bik', q, k)

        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        w_ = rearrange(w_, 'b i j -> b j i')
        h_ = torch.einsum('bij,bjk->bik', v, w_)
        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
        h_ = self.proj_out(h_)

        return x + h_


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                                    context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
                                   disable_self_attn=disable_self_attn)
             for d in range(depth)]
        )

        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        x = self.proj_out(x)
        return x + x_in


##########################
# transformer no context #
##########################

class BasicTransformerBlockNoContext(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
                                    dropout=dropout, context_dim=None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
                                    dropout=dropout, context_dim=None)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)

    def _forward(self, x):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x)) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformerNoContext(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0.,):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlockNoContext(inner_dim, n_heads, d_head, dropout=dropout)
             for d in range(depth)]
        )

        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))

    def forward(self, x):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        for block in self.transformer_blocks:
            x = block(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        x = self.proj_out(x)
        return x + x_in


#######################################
# Spatial Transformer with Two Branch #
#######################################

class DualSpatialTransformer(nn.Module):
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head

        # First crossattn
        self.norm_0 = Normalize(in_channels)
        self.proj_in_0 = nn.Conv2d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        self.transformer_blocks_0 = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
                                   disable_self_attn=disable_self_attn)
             for d in range(depth)]
        )
        self.proj_out_0 = zero_module(nn.Conv2d(
            inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

        # Second crossattn
        self.norm_1 = Normalize(in_channels)
        self.proj_in_1 = nn.Conv2d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        self.transformer_blocks_1 = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
                                   disable_self_attn=disable_self_attn)
             for d in range(depth)]
        )
        self.proj_out_1 = zero_module(nn.Conv2d(
            inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

    def forward(self, x, context=None, which=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        if which == 0:
            norm, proj_in, blocks, proj_out = \
                self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
        elif which == 1:
            norm, proj_in, blocks, proj_out = \
                self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
        else:
            # assert False, 'DualSpatialTransformer forward with a invalid which branch!'
            # import numpy.random as npr
            # rwhich = 0 if npr.rand() < which else 1
            # context = context[rwhich]
            # if rwhich==0:
            #     norm, proj_in, blocks, proj_out = \
            #         self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
            # elif rwhich==1:
            #     norm, proj_in, blocks, proj_out = \
            #         self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1

            # import numpy.random as npr
            # rwhich = 0 if npr.rand() < 0.33 else 1
            # if rwhich==0:
            #     context = context[rwhich]
            #     norm, proj_in, blocks, proj_out = \
            #         self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
            # else:

            norm, proj_in, blocks, proj_out = \
                self.norm_0, self.proj_in_0, self.transformer_blocks_0, self.proj_out_0
            x0 = norm(x)
            x0 = proj_in(x0)
            x0 = rearrange(x0, 'b c h w -> b (h w) c').contiguous()
            for block in blocks:
                x0 = block(x0, context=context[0])
            x0 = rearrange(x0, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
            x0 = proj_out(x0)

            norm, proj_in, blocks, proj_out = \
                self.norm_1, self.proj_in_1, self.transformer_blocks_1, self.proj_out_1
            x1 = norm(x)
            x1 = proj_in(x1)
            x1 = rearrange(x1, 'b c h w -> b (h w) c').contiguous()
            for block in blocks:
                x1 = block(x1, context=context[1])
            x1 = rearrange(x1, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
            x1 = proj_out(x1)
            return x0*which + x1*(1-which) + x_in

        x = norm(x)
        x = proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        for block in blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        x = proj_out(x)
        return x + x_in
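A minimal shape-check sketch for the attention modules above (not part of the repository files; it assumes the repo's lib package is importable and all tensor sizes are illustrative):

    import torch
    from lib.model_zoo.attention import CrossAttention, SpatialTransformer

    # Cross-attention: queries come from x, keys/values from a text-like context.
    attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
    x = torch.randn(2, 64, 320)        # (batch, tokens, query_dim)
    ctx = torch.randn(2, 77, 768)      # (batch, context tokens, context_dim)
    print(attn(x, context=ctx).shape)  # torch.Size([2, 64, 320])

    # SpatialTransformer wraps the same block for (b, c, h, w) feature maps.
    st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, context_dim=768)
    feat = torch.randn(2, 320, 8, 8)
    print(st(feat, context=ctx).shape)  # torch.Size([2, 320, 8, 8])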
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/autoencoder.py
ADDED
@@ -0,0 +1,428 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from contextlib import contextmanager
from lib.model_zoo.common.get_model import get_model, register

from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

from .diffusion_modules import Encoder, Decoder
from .distributions import DiagonalGaussianDistribution


class VQModel(nn.Module):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class VQModelInterface(VQModel):
    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec


@register('autoencoderkl')
class AutoencoderKL(nn.Module):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class IdentityFirstStage(nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
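A rough encode/decode sketch for AutoencoderKL (not part of the repository files). The ddconfig values below mirror the commonly used KL-f8 autoencoder settings and are assumptions here, as is passing lossconfig=None, which this evaluation-only class never uses:

    import torch
    from lib.model_zoo.autoencoder import AutoencoderKL

    # assumed KL-f8 style configuration
    ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
                    ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0)
    ae = AutoencoderKL(ddconfig=ddconfig, lossconfig=None, embed_dim=4).eval()

    x = torch.randn(1, 3, 256, 256)       # image scaled to [-1, 1]
    with torch.no_grad():
        posterior = ae.encode(x)          # DiagonalGaussianDistribution
        z = posterior.mode()              # latent, roughly (1, 4, 32, 32)
        rec = ae.decode(z)                # reconstruction, (1, 3, 256, 256)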
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/bert.py
ADDED
@@ -0,0 +1,142 @@
import torch
import torch.nn as nn
from functools import partial

# from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        c = self.embedding(c)
        return c


class TransformerEmbedder(AbstractEncoder):
    """Some transformer encoder layers"""
    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77):
        super().__init__()
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer))

    def forward(self, tokens):
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, x):
        return self(x)


class BERTTokenizer(AbstractEncoder):
    """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to reuquirements
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"]
        return tokens

    @torch.no_grad()
    def encode(self, text):
        tokens = self(text)
        if not self.vq_interface:
            return tokens
        return None, None, [None, None, tokens]

    def decode(self, text):
        return text


class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizr model and add some transformer encoder layers"""
    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 ckpt_path=None, ignore_keys=[], device="cuda", use_tokenizer=True, embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),
                                              emb_dropout=embedding_dropout)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")

    def forward(self, text):
        if self.use_tknz_fn:
            tokens = self.tknz_fn(text)
        else:
            tokens = text
        device = self.transformer.token_emb.weight.device  # a trick to get device
        tokens = tokens.to(device)
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, text):
        # output of length 77
        return self(text)


class SpatialRescaler(nn.Module):
    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area']
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if self.remap_output:
            print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
            self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        for stage in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        if self.remap_output:
            x = self.channel_mapper(x)
        return x

    def encode(self, x):
        return self(x)
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/clip.py
ADDED
@@ -0,0 +1,226 @@
import torch
import torch.nn as nn
import numpy as np
from functools import partial
from lib.model_zoo.common.get_model import register

version = '0'
symbol = 'clip'

class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError

from transformers import CLIPTokenizer, CLIPTextModel

def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self

@register('clip_text_frozen', version)
class FrozenCLIPTextEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)

from transformers import CLIPProcessor, CLIPVisionModel

@register('clip_vision_frozen', version)
class FrozenCLIPVisionEmbedder(AbstractEncoder):
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.processor = CLIPProcessor.from_pretrained(version)
        self.transformer = CLIPVisionModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, images):
        inputs = self.processor(images=images, return_tensors="pt")
        pixels = inputs['pixel_values'].to(self.device)
        outputs = self.transformer(pixel_values=pixels)
        z = outputs.last_hidden_state
        return z

    def encode(self, image):
        return self(image)

from transformers import CLIPModel

@register('clip_frozen', version)
class FrozenCLIP(AbstractEncoder):
    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 max_length=77,
                 encode_type='encode_text',):  # clip-vit-base-patch32
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.processor = CLIPProcessor.from_pretrained(version)
        self.model = CLIPModel.from_pretrained(version)
        self.max_length = max_length  # TODO: typical value?
        self.encode_type = encode_type
        self.pinv_text_projection = None
        self.freeze()

    def get_device(self):
        # A trick to get device
        return self.model.text_projection.weight.device

    def freeze(self):
        self.model = self.model.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def encode_text_pooled(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        return self.model.get_text_features(input_ids=tokens)

    def encode_vision_pooled(self, images):
        inputs = self.processor(images=images, return_tensors="pt")
        pixels = inputs['pixel_values'].to(self.get_device())
        return self.model.get_image_features(pixel_values=pixels)

    def encode_text_noproj(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        outputs = self.model.text_model(input_ids=tokens)
        return outputs.last_hidden_state

    def encode_vision_noproj(self, images):
        inputs = self.processor(images=images, return_tensors="pt")
        pixels = inputs['pixel_values'].to(self.get_device())
        outputs = self.model.vision_model(pixel_values=pixels)
        return outputs.last_hidden_state

    def encode_text_bug(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        outputs = self.model.text_model(input_ids=tokens)
        z = outputs.last_hidden_state
        z_pooled = outputs.pooler_output
        z = z / torch.norm(z_pooled.unsqueeze(1), dim=-1, keepdim=True)
        return self.model.text_projection(z)

    def encode_text(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.get_device())
        outputs = self.model.text_model(input_ids=tokens)
        z = self.model.text_projection(outputs.last_hidden_state)
        z_pooled = self.model.text_projection(outputs.pooler_output)
        z = z / torch.norm(z_pooled.unsqueeze(1), dim=-1, keepdim=True)
        return z

    def encode_vision(self, images):
        z = self.encode_vision_noproj(images)
        z = self.model.vision_model.post_layernorm(z)
        z = self.model.visual_projection(z)
        z_pooled = z[:, 0:1]
        # z_pooled_normed = z_pooled / z_pooled.norm(dim=-1, keepdim=True)
        z = z / torch.norm(z_pooled, dim=-1, keepdim=True)
        return z

    def encode_vision_pinvtext(self, images):
        blank_text_encode_norm_avg = 28.9096
        z = self.encode_vision(images)
        if self.pinv_text_projection is None:
            self.pinv_text_projection = torch.linalg.pinv(self.model.text_projection.weight).T
        z = torch.matmul(z, self.pinv_text_projection)
        # z = z / torch.norm(z[:, 0:1], dim=-1, keepdim=True)
        z = z / torch.norm(z, dim=-1, keepdim=True)
        z = z*blank_text_encode_norm_avg
        # return z[:, 1:2].repeat(1, 77, 1)
        z2 = self.encode_text_noproj('')
        # z2[:, 1:77] = z[:, 0:76]
        return torch.flip(z, dims=(1,))[:, 0:77]

    def encode(self, *args, **kwargs):
        return getattr(self, self.encode_type)(*args, **kwargs)

#############################
# copyed from justin's code #
#############################

@register('clip_vision_frozen_justin', version)
class FrozenCLIPVisionEmbedder_Justin(AbstractEncoder):
    """
    Uses the CLIP image encoder.
    """
    def __init__(
            self,
            model='ViT-L/14',
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=False,
        ):
        super().__init__()
        from . import clip_justin
        self.model, _ = clip_justin.load(name=model, device=device, jit=jit)
        self.device = device
        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)

        # I didn't call this originally, but seems like it was frozen anyway
        self.freeze()

    def freeze(self):
        self.transformer = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def preprocess(self, x):
        import kornia
        # Expects inputs in the range -1, 1
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic', align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        return self.model.encode_image(self.preprocess(x)).float()

    def encode(self, im):
        return self(im).unsqueeze(1)
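A rough sketch of how FrozenCLIP from the file above might be called (not part of the repository files; it assumes the repo's lib package is importable and that the openai/clip-vit-large-patch14 weights can be downloaded from Hugging Face):

    import torch
    from lib.model_zoo.clip import FrozenCLIP

    clip = FrozenCLIP(encode_type='encode_text')
    with torch.no_grad():
        z = clip.encode(['a photo of a dog'])                    # (1, 77, 768) token-wise projected text features
        pooled = clip.encode_text_pooled(['a photo of a dog'])   # (1, 768) pooled CLIP text embedding

The encode_type argument selects which of the encode_* methods a generic encode() call dispatches to, so the same registered module can serve either the text or the vision branch.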
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/clip_justin/__init__.py
ADDED
@@ -0,0 +1 @@
from .clip import load
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/clip_justin/clip.py
ADDED
@@ -0,0 +1,237 @@
| 1 |
+
import hashlib
|
| 2 |
+
import os
|
| 3 |
+
import urllib
|
| 4 |
+
import warnings
|
| 5 |
+
from typing import Any, Union, List
|
| 6 |
+
from pkg_resources import packaging
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from .model import build_model
|
| 14 |
+
# from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from torchvision.transforms import InterpolationMode
|
| 18 |
+
BICUBIC = InterpolationMode.BICUBIC
|
| 19 |
+
except ImportError:
|
| 20 |
+
BICUBIC = Image.BICUBIC
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
|
| 24 |
+
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
__all__ = ["available_models", "load", "tokenize"]
|
| 28 |
+
# _tokenizer = _Tokenizer()
|
| 29 |
+
|
| 30 |
+
_MODELS = {
|
| 31 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
| 32 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
| 33 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
| 34 |
+
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
| 35 |
+
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
|
| 36 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
| 37 |
+
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
| 38 |
+
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
| 39 |
+
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _download(url: str, root: str):
|
| 44 |
+
os.makedirs(root, exist_ok=True)
|
| 45 |
+
filename = os.path.basename(url)
|
| 46 |
+
|
| 47 |
+
expected_sha256 = url.split("/")[-2]
|
| 48 |
+
download_target = os.path.join(root, filename)
|
| 49 |
+
|
| 50 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
| 51 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
| 52 |
+
|
| 53 |
+
if os.path.isfile(download_target):
|
| 54 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
| 55 |
+
return download_target
|
| 56 |
+
else:
|
| 57 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
| 58 |
+
|
| 59 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
| 60 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
| 61 |
+
while True:
|
| 62 |
+
buffer = source.read(8192)
|
| 63 |
+
if not buffer:
|
| 64 |
+
break
|
| 65 |
+
|
| 66 |
+
output.write(buffer)
|
| 67 |
+
loop.update(len(buffer))
|
| 68 |
+
|
| 69 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
| 70 |
+
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
|
| 71 |
+
|
| 72 |
+
return download_target
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _convert_image_to_rgb(image):
|
| 76 |
+
return image.convert("RGB")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _transform(n_px):
|
| 80 |
+
return Compose([
|
| 81 |
+
Resize(n_px, interpolation=BICUBIC),
|
| 82 |
+
CenterCrop(n_px),
|
| 83 |
+
_convert_image_to_rgb,
|
| 84 |
+
ToTensor(),
|
| 85 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
| 86 |
+
])
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def available_models() -> List[str]:
|
| 90 |
+
"""Returns the names of available CLIP models"""
|
| 91 |
+
return list(_MODELS.keys())
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
| 95 |
+
"""Load a CLIP model
|
| 96 |
+
|
| 97 |
+
Parameters
|
| 98 |
+
----------
|
| 99 |
+
name : str
|
| 100 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
| 101 |
+
|
| 102 |
+
device : Union[str, torch.device]
|
| 103 |
+
The device to put the loaded model
|
| 104 |
+
|
| 105 |
+
jit : bool
|
| 106 |
+
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
| 107 |
+
|
| 108 |
+
download_root: str
|
| 109 |
+
path to download the model files; by default, it uses "~/.cache/clip"
|
| 110 |
+
|
| 111 |
+
Returns
|
| 112 |
+
-------
|
| 113 |
+
model : torch.nn.Module
|
| 114 |
+
The CLIP model
|
| 115 |
+
|
| 116 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
| 117 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
| 118 |
+
"""
|
| 119 |
+
if name in _MODELS:
|
| 120 |
+
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
| 121 |
+
elif os.path.isfile(name):
|
| 122 |
+
model_path = name
|
| 123 |
+
else:
|
| 124 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
| 125 |
+
|
| 126 |
+
with open(model_path, 'rb') as opened_file:
|
| 127 |
+
try:
|
| 128 |
+
# loading JIT archive
|
| 129 |
+
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
| 130 |
+
state_dict = None
|
| 131 |
+
except RuntimeError:
|
| 132 |
+
# loading saved state dict
|
| 133 |
+
if jit:
|
| 134 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
| 135 |
+
jit = False
|
| 136 |
+
state_dict = torch.load(opened_file, map_location="cpu")
|
| 137 |
+
|
| 138 |
+
if not jit:
|
| 139 |
+
model = build_model(state_dict or model.state_dict()).to(device)
|
| 140 |
+
if str(device) == "cpu":
|
| 141 |
+
model.float()
|
| 142 |
+
return model, _transform(model.visual.input_resolution)
|
| 143 |
+
|
| 144 |
+
# patch the device names
|
| 145 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
| 146 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
| 147 |
+
|
| 148 |
+
def patch_device(module):
|
| 149 |
+
try:
|
| 150 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
| 151 |
+
except RuntimeError:
|
| 152 |
+
graphs = []
|
| 153 |
+
|
| 154 |
+
if hasattr(module, "forward1"):
|
| 155 |
+
graphs.append(module.forward1.graph)
|
| 156 |
+
|
| 157 |
+
for graph in graphs:
|
| 158 |
+
for node in graph.findAllNodes("prim::Constant"):
|
| 159 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
| 160 |
+
node.copyAttributes(device_node)
|
| 161 |
+
|
| 162 |
+
model.apply(patch_device)
|
| 163 |
+
patch_device(model.encode_image)
|
| 164 |
+
patch_device(model.encode_text)
|
| 165 |
+
|
| 166 |
+
# patch dtype to float32 on CPU
|
| 167 |
+
if str(device) == "cpu":
|
| 168 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
| 169 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
| 170 |
+
float_node = float_input.node()
|
| 171 |
+
|
| 172 |
+
def patch_float(module):
|
| 173 |
+
try:
|
| 174 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
| 175 |
+
except RuntimeError:
|
| 176 |
+
graphs = []
|
| 177 |
+
|
| 178 |
+
if hasattr(module, "forward1"):
|
| 179 |
+
graphs.append(module.forward1.graph)
|
| 180 |
+
|
| 181 |
+
for graph in graphs:
|
| 182 |
+
for node in graph.findAllNodes("aten::to"):
|
| 183 |
+
inputs = list(node.inputs())
|
| 184 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
| 185 |
+
if inputs[i].node()["value"] == 5:
|
| 186 |
+
inputs[i].node().copyAttributes(float_node)
|
| 187 |
+
|
| 188 |
+
model.apply(patch_float)
|
| 189 |
+
patch_float(model.encode_image)
|
| 190 |
+
patch_float(model.encode_text)
|
| 191 |
+
|
| 192 |
+
model.float()
|
| 193 |
+
|
| 194 |
+
return model, _transform(model.input_resolution.item())
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
|
| 198 |
+
# """
|
| 199 |
+
# Returns the tokenized representation of given input string(s)
|
| 200 |
+
|
| 201 |
+
# Parameters
|
| 202 |
+
# ----------
|
| 203 |
+
# texts : Union[str, List[str]]
|
| 204 |
+
# An input string or a list of input strings to tokenize
|
| 205 |
+
|
| 206 |
+
# context_length : int
|
| 207 |
+
# The context length to use; all CLIP models use 77 as the context length
|
| 208 |
+
|
| 209 |
+
# truncate: bool
|
| 210 |
+
# Whether to truncate the text in case its encoding is longer than the context length
|
| 211 |
+
|
| 212 |
+
# Returns
|
| 213 |
+
# -------
|
| 214 |
+
# A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
|
| 215 |
+
# We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
|
| 216 |
+
# """
|
| 217 |
+
# if isinstance(texts, str):
|
| 218 |
+
# texts = [texts]
|
| 219 |
+
|
| 220 |
+
# sot_token = _tokenizer.encoder["<|startoftext|>"]
|
| 221 |
+
# eot_token = _tokenizer.encoder["<|endoftext|>"]
|
| 222 |
+
# all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
| 223 |
+
# if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
|
| 224 |
+
# result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
| 225 |
+
# else:
|
| 226 |
+
# result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
|
| 227 |
+
|
| 228 |
+
# for i, tokens in enumerate(all_tokens):
|
| 229 |
+
# if len(tokens) > context_length:
|
| 230 |
+
# if truncate:
|
| 231 |
+
# tokens = tokens[:context_length]
|
| 232 |
+
# tokens[-1] = eot_token
|
| 233 |
+
# else:
|
| 234 |
+
# raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
| 235 |
+
# result[i, :len(tokens)] = torch.tensor(tokens)
|
| 236 |
+
|
| 237 |
+
# return result
|
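A minimal usage sketch for the loader added above. The model name, device handling, and image path are illustrative assumptions, not taken from the upload; it presumes the code/ tree is importable as a package and that the OpenAI weights are reachable. Since tokenize() and the tokenizer import are commented out in this copy, only the image tower is exercised.

import torch
from PIL import Image
from lib.model_zoo.clip_justin import load

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load("ViT-L/14", device=device, jit=False)  # downloads to ~/.cache/clip by default

# preprocess resizes, center-crops, converts to RGB, and normalizes with the CLIP statistics.
image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # hypothetical local image
with torch.no_grad():
    features = model.encode_image(image)                               # [1, embed_dim]
features = features / features.norm(dim=-1, keepdim=True)
print(features.shape)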
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/clip_justin/model.py
ADDED
|
@@ -0,0 +1,436 @@
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from typing import Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Bottleneck(nn.Module):
|
| 11 |
+
expansion = 4
|
| 12 |
+
|
| 13 |
+
def __init__(self, inplanes, planes, stride=1):
|
| 14 |
+
super().__init__()
|
| 15 |
+
|
| 16 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
| 17 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
| 18 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 19 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 20 |
+
|
| 21 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
| 22 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 23 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 24 |
+
|
| 25 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
| 26 |
+
|
| 27 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
| 28 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
| 29 |
+
self.relu3 = nn.ReLU(inplace=True)
|
| 30 |
+
|
| 31 |
+
self.downsample = None
|
| 32 |
+
self.stride = stride
|
| 33 |
+
|
| 34 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
| 35 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
| 36 |
+
self.downsample = nn.Sequential(OrderedDict([
|
| 37 |
+
("-1", nn.AvgPool2d(stride)),
|
| 38 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
| 39 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
| 40 |
+
]))
|
| 41 |
+
|
| 42 |
+
def forward(self, x: torch.Tensor):
|
| 43 |
+
identity = x
|
| 44 |
+
|
| 45 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
| 46 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
| 47 |
+
out = self.avgpool(out)
|
| 48 |
+
out = self.bn3(self.conv3(out))
|
| 49 |
+
|
| 50 |
+
if self.downsample is not None:
|
| 51 |
+
identity = self.downsample(x)
|
| 52 |
+
|
| 53 |
+
out += identity
|
| 54 |
+
out = self.relu3(out)
|
| 55 |
+
return out
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class AttentionPool2d(nn.Module):
|
| 59 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
| 62 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
| 63 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
| 64 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
| 65 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
| 66 |
+
self.num_heads = num_heads
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
|
| 70 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
| 71 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
| 72 |
+
x, _ = F.multi_head_attention_forward(
|
| 73 |
+
query=x[:1], key=x, value=x,
|
| 74 |
+
embed_dim_to_check=x.shape[-1],
|
| 75 |
+
num_heads=self.num_heads,
|
| 76 |
+
q_proj_weight=self.q_proj.weight,
|
| 77 |
+
k_proj_weight=self.k_proj.weight,
|
| 78 |
+
v_proj_weight=self.v_proj.weight,
|
| 79 |
+
in_proj_weight=None,
|
| 80 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
| 81 |
+
bias_k=None,
|
| 82 |
+
bias_v=None,
|
| 83 |
+
add_zero_attn=False,
|
| 84 |
+
dropout_p=0,
|
| 85 |
+
out_proj_weight=self.c_proj.weight,
|
| 86 |
+
out_proj_bias=self.c_proj.bias,
|
| 87 |
+
use_separate_proj_weight=True,
|
| 88 |
+
training=self.training,
|
| 89 |
+
need_weights=False
|
| 90 |
+
)
|
| 91 |
+
return x.squeeze(0)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class ModifiedResNet(nn.Module):
|
| 95 |
+
"""
|
| 96 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
| 97 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
| 98 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
| 99 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
| 103 |
+
super().__init__()
|
| 104 |
+
self.output_dim = output_dim
|
| 105 |
+
self.input_resolution = input_resolution
|
| 106 |
+
|
| 107 |
+
# the 3-layer stem
|
| 108 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
| 109 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
| 110 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 111 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
| 112 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
| 113 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 114 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
| 115 |
+
self.bn3 = nn.BatchNorm2d(width)
|
| 116 |
+
self.relu3 = nn.ReLU(inplace=True)
|
| 117 |
+
self.avgpool = nn.AvgPool2d(2)
|
| 118 |
+
|
| 119 |
+
# residual layers
|
| 120 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
| 121 |
+
self.layer1 = self._make_layer(width, layers[0])
|
| 122 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
| 123 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
| 124 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
| 125 |
+
|
| 126 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
| 127 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
| 128 |
+
|
| 129 |
+
def _make_layer(self, planes, blocks, stride=1):
|
| 130 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
| 131 |
+
|
| 132 |
+
self._inplanes = planes * Bottleneck.expansion
|
| 133 |
+
for _ in range(1, blocks):
|
| 134 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
| 135 |
+
|
| 136 |
+
return nn.Sequential(*layers)
|
| 137 |
+
|
| 138 |
+
def forward(self, x):
|
| 139 |
+
def stem(x):
|
| 140 |
+
x = self.relu1(self.bn1(self.conv1(x)))
|
| 141 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
| 142 |
+
x = self.relu3(self.bn3(self.conv3(x)))
|
| 143 |
+
x = self.avgpool(x)
|
| 144 |
+
return x
|
| 145 |
+
|
| 146 |
+
x = x.type(self.conv1.weight.dtype)
|
| 147 |
+
x = stem(x)
|
| 148 |
+
x = self.layer1(x)
|
| 149 |
+
x = self.layer2(x)
|
| 150 |
+
x = self.layer3(x)
|
| 151 |
+
x = self.layer4(x)
|
| 152 |
+
x = self.attnpool(x)
|
| 153 |
+
|
| 154 |
+
return x
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class LayerNorm(nn.LayerNorm):
|
| 158 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
| 159 |
+
|
| 160 |
+
def forward(self, x: torch.Tensor):
|
| 161 |
+
orig_type = x.dtype
|
| 162 |
+
ret = super().forward(x.type(torch.float32))
|
| 163 |
+
return ret.type(orig_type)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class QuickGELU(nn.Module):
|
| 167 |
+
def forward(self, x: torch.Tensor):
|
| 168 |
+
return x * torch.sigmoid(1.702 * x)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class ResidualAttentionBlock(nn.Module):
|
| 172 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
| 173 |
+
super().__init__()
|
| 174 |
+
|
| 175 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
| 176 |
+
self.ln_1 = LayerNorm(d_model)
|
| 177 |
+
self.mlp = nn.Sequential(OrderedDict([
|
| 178 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
| 179 |
+
("gelu", QuickGELU()),
|
| 180 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
| 181 |
+
]))
|
| 182 |
+
self.ln_2 = LayerNorm(d_model)
|
| 183 |
+
self.attn_mask = attn_mask
|
| 184 |
+
|
| 185 |
+
def attention(self, x: torch.Tensor):
|
| 186 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
| 187 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
| 188 |
+
|
| 189 |
+
def forward(self, x: torch.Tensor):
|
| 190 |
+
x = x + self.attention(self.ln_1(x))
|
| 191 |
+
x = x + self.mlp(self.ln_2(x))
|
| 192 |
+
return x
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class Transformer(nn.Module):
|
| 196 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
| 197 |
+
super().__init__()
|
| 198 |
+
self.width = width
|
| 199 |
+
self.layers = layers
|
| 200 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
| 201 |
+
|
| 202 |
+
def forward(self, x: torch.Tensor):
|
| 203 |
+
return self.resblocks(x)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class VisionTransformer(nn.Module):
|
| 207 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
| 208 |
+
super().__init__()
|
| 209 |
+
self.input_resolution = input_resolution
|
| 210 |
+
self.output_dim = output_dim
|
| 211 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
| 212 |
+
|
| 213 |
+
scale = width ** -0.5
|
| 214 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
| 215 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
| 216 |
+
self.ln_pre = LayerNorm(width)
|
| 217 |
+
|
| 218 |
+
self.transformer = Transformer(width, layers, heads)
|
| 219 |
+
|
| 220 |
+
self.ln_post = LayerNorm(width)
|
| 221 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
| 222 |
+
|
| 223 |
+
def forward(self, x: torch.Tensor):
|
| 224 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
| 225 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
| 226 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
| 227 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
| 228 |
+
x = x + self.positional_embedding.to(x.dtype)
|
| 229 |
+
x = self.ln_pre(x)
|
| 230 |
+
|
| 231 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 232 |
+
x = self.transformer(x)
|
| 233 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 234 |
+
|
| 235 |
+
x = self.ln_post(x[:, 0, :])
|
| 236 |
+
|
| 237 |
+
if self.proj is not None:
|
| 238 |
+
x = x @ self.proj
|
| 239 |
+
|
| 240 |
+
return x
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class CLIP(nn.Module):
|
| 244 |
+
def __init__(self,
|
| 245 |
+
embed_dim: int,
|
| 246 |
+
# vision
|
| 247 |
+
image_resolution: int,
|
| 248 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
| 249 |
+
vision_width: int,
|
| 250 |
+
vision_patch_size: int,
|
| 251 |
+
# text
|
| 252 |
+
context_length: int,
|
| 253 |
+
vocab_size: int,
|
| 254 |
+
transformer_width: int,
|
| 255 |
+
transformer_heads: int,
|
| 256 |
+
transformer_layers: int
|
| 257 |
+
):
|
| 258 |
+
super().__init__()
|
| 259 |
+
|
| 260 |
+
self.context_length = context_length
|
| 261 |
+
|
| 262 |
+
if isinstance(vision_layers, (tuple, list)):
|
| 263 |
+
vision_heads = vision_width * 32 // 64
|
| 264 |
+
self.visual = ModifiedResNet(
|
| 265 |
+
layers=vision_layers,
|
| 266 |
+
output_dim=embed_dim,
|
| 267 |
+
heads=vision_heads,
|
| 268 |
+
input_resolution=image_resolution,
|
| 269 |
+
width=vision_width
|
| 270 |
+
)
|
| 271 |
+
else:
|
| 272 |
+
vision_heads = vision_width // 64
|
| 273 |
+
self.visual = VisionTransformer(
|
| 274 |
+
input_resolution=image_resolution,
|
| 275 |
+
patch_size=vision_patch_size,
|
| 276 |
+
width=vision_width,
|
| 277 |
+
layers=vision_layers,
|
| 278 |
+
heads=vision_heads,
|
| 279 |
+
output_dim=embed_dim
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
self.transformer = Transformer(
|
| 283 |
+
width=transformer_width,
|
| 284 |
+
layers=transformer_layers,
|
| 285 |
+
heads=transformer_heads,
|
| 286 |
+
attn_mask=self.build_attention_mask()
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
self.vocab_size = vocab_size
|
| 290 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
| 291 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
| 292 |
+
self.ln_final = LayerNorm(transformer_width)
|
| 293 |
+
|
| 294 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
| 295 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 296 |
+
|
| 297 |
+
self.initialize_parameters()
|
| 298 |
+
|
| 299 |
+
def initialize_parameters(self):
|
| 300 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
| 301 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
| 302 |
+
|
| 303 |
+
if isinstance(self.visual, ModifiedResNet):
|
| 304 |
+
if self.visual.attnpool is not None:
|
| 305 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
| 306 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
| 307 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
| 308 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
| 309 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
| 310 |
+
|
| 311 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
| 312 |
+
for name, param in resnet_block.named_parameters():
|
| 313 |
+
if name.endswith("bn3.weight"):
|
| 314 |
+
nn.init.zeros_(param)
|
| 315 |
+
|
| 316 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
| 317 |
+
attn_std = self.transformer.width ** -0.5
|
| 318 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
| 319 |
+
for block in self.transformer.resblocks:
|
| 320 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
| 321 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
| 322 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
| 323 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
| 324 |
+
|
| 325 |
+
if self.text_projection is not None:
|
| 326 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
| 327 |
+
|
| 328 |
+
def build_attention_mask(self):
|
| 329 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
| 330 |
+
# pytorch uses additive attention mask; fill with -inf
|
| 331 |
+
mask = torch.empty(self.context_length, self.context_length)
|
| 332 |
+
mask.fill_(float("-inf"))
|
| 333 |
+
mask.triu_(1) # zero out the lower diagonal
|
| 334 |
+
return mask
|
| 335 |
+
|
| 336 |
+
@property
|
| 337 |
+
def dtype(self):
|
| 338 |
+
return self.visual.conv1.weight.dtype
|
| 339 |
+
|
| 340 |
+
def encode_image(self, image):
|
| 341 |
+
return self.visual(image.type(self.dtype))
|
| 342 |
+
|
| 343 |
+
def encode_text(self, text):
|
| 344 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
| 345 |
+
|
| 346 |
+
x = x + self.positional_embedding.type(self.dtype)
|
| 347 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 348 |
+
x = self.transformer(x)
|
| 349 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 350 |
+
x = self.ln_final(x).type(self.dtype)
|
| 351 |
+
|
| 352 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
| 353 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
| 354 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
| 355 |
+
|
| 356 |
+
return x
|
| 357 |
+
|
| 358 |
+
def forward(self, image, text):
|
| 359 |
+
image_features = self.encode_image(image)
|
| 360 |
+
text_features = self.encode_text(text)
|
| 361 |
+
|
| 362 |
+
# normalized features
|
| 363 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
| 364 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
| 365 |
+
|
| 366 |
+
# cosine similarity as logits
|
| 367 |
+
logit_scale = self.logit_scale.exp()
|
| 368 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
| 369 |
+
logits_per_text = logits_per_image.t()
|
| 370 |
+
|
| 371 |
+
# shape = [global_batch_size, global_batch_size]
|
| 372 |
+
return logits_per_image, logits_per_text
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def convert_weights(model: nn.Module):
|
| 376 |
+
"""Convert applicable model parameters to fp16"""
|
| 377 |
+
|
| 378 |
+
def _convert_weights_to_fp16(l):
|
| 379 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
| 380 |
+
l.weight.data = l.weight.data.half()
|
| 381 |
+
if l.bias is not None:
|
| 382 |
+
l.bias.data = l.bias.data.half()
|
| 383 |
+
|
| 384 |
+
if isinstance(l, nn.MultiheadAttention):
|
| 385 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
| 386 |
+
tensor = getattr(l, attr)
|
| 387 |
+
if tensor is not None:
|
| 388 |
+
tensor.data = tensor.data.half()
|
| 389 |
+
|
| 390 |
+
for name in ["text_projection", "proj"]:
|
| 391 |
+
if hasattr(l, name):
|
| 392 |
+
attr = getattr(l, name)
|
| 393 |
+
if attr is not None:
|
| 394 |
+
attr.data = attr.data.half()
|
| 395 |
+
|
| 396 |
+
model.apply(_convert_weights_to_fp16)
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def build_model(state_dict: dict):
|
| 400 |
+
vit = "visual.proj" in state_dict
|
| 401 |
+
|
| 402 |
+
if vit:
|
| 403 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
| 404 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
| 405 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
| 406 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
| 407 |
+
image_resolution = vision_patch_size * grid_size
|
| 408 |
+
else:
|
| 409 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
| 410 |
+
vision_layers = tuple(counts)
|
| 411 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
| 412 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
| 413 |
+
vision_patch_size = None
|
| 414 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
| 415 |
+
image_resolution = output_width * 32
|
| 416 |
+
|
| 417 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
| 418 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
| 419 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
| 420 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
| 421 |
+
transformer_heads = transformer_width // 64
|
| 422 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
| 423 |
+
|
| 424 |
+
model = CLIP(
|
| 425 |
+
embed_dim,
|
| 426 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
| 427 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
| 431 |
+
if key in state_dict:
|
| 432 |
+
del state_dict[key]
|
| 433 |
+
|
| 434 |
+
convert_weights(model)
|
| 435 |
+
model.load_state_dict(state_dict)
|
| 436 |
+
return model.eval()
|
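A short sketch of how build_model above is meant to be driven, assuming a plain (non-JIT) state dict saved locally and the code/ tree on the import path; the checkpoint file name is hypothetical.

import torch
from lib.model_zoo.clip_justin.model import build_model

sd = torch.load("clip_vit_l14_state_dict.pt", map_location="cpu")  # hypothetical checkpoint
model = build_model(sd)   # infers ViT vs. ModifiedResNet layout from the key names
model.float()             # convert_weights() inside build_model casts most tensors to fp16
print(model.visual.input_resolution, model.context_length, model.vocab_size)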
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/clip_justin/simple_tokenizer.py
ADDED
|
@@ -0,0 +1,132 @@
| 1 |
+
import gzip
|
| 2 |
+
import html
|
| 3 |
+
import os
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
|
| 6 |
+
import ftfy
|
| 7 |
+
import regex as re
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@lru_cache()
|
| 11 |
+
def default_bpe():
|
| 12 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@lru_cache()
|
| 16 |
+
def bytes_to_unicode():
|
| 17 |
+
"""
|
| 18 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
| 19 |
+
The reversible bpe codes work on unicode strings.
|
| 20 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
| 21 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
| 22 |
+
This is a significant percentage of your normal, say, 32K bpe vocab.
|
| 23 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
| 24 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
| 25 |
+
"""
|
| 26 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
| 27 |
+
cs = bs[:]
|
| 28 |
+
n = 0
|
| 29 |
+
for b in range(2**8):
|
| 30 |
+
if b not in bs:
|
| 31 |
+
bs.append(b)
|
| 32 |
+
cs.append(2**8+n)
|
| 33 |
+
n += 1
|
| 34 |
+
cs = [chr(n) for n in cs]
|
| 35 |
+
return dict(zip(bs, cs))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_pairs(word):
|
| 39 |
+
"""Return set of symbol pairs in a word.
|
| 40 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
| 41 |
+
"""
|
| 42 |
+
pairs = set()
|
| 43 |
+
prev_char = word[0]
|
| 44 |
+
for char in word[1:]:
|
| 45 |
+
pairs.add((prev_char, char))
|
| 46 |
+
prev_char = char
|
| 47 |
+
return pairs
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def basic_clean(text):
|
| 51 |
+
text = ftfy.fix_text(text)
|
| 52 |
+
text = html.unescape(html.unescape(text))
|
| 53 |
+
return text.strip()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def whitespace_clean(text):
|
| 57 |
+
text = re.sub(r'\s+', ' ', text)
|
| 58 |
+
text = text.strip()
|
| 59 |
+
return text
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class SimpleTokenizer(object):
|
| 63 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
| 64 |
+
self.byte_encoder = bytes_to_unicode()
|
| 65 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
| 66 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
| 67 |
+
merges = merges[1:49152-256-2+1]
|
| 68 |
+
merges = [tuple(merge.split()) for merge in merges]
|
| 69 |
+
vocab = list(bytes_to_unicode().values())
|
| 70 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
| 71 |
+
for merge in merges:
|
| 72 |
+
vocab.append(''.join(merge))
|
| 73 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
| 74 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
| 75 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
| 76 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
| 77 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
| 78 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
| 79 |
+
|
| 80 |
+
def bpe(self, token):
|
| 81 |
+
if token in self.cache:
|
| 82 |
+
return self.cache[token]
|
| 83 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
| 84 |
+
pairs = get_pairs(word)
|
| 85 |
+
|
| 86 |
+
if not pairs:
|
| 87 |
+
return token+'</w>'
|
| 88 |
+
|
| 89 |
+
while True:
|
| 90 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
| 91 |
+
if bigram not in self.bpe_ranks:
|
| 92 |
+
break
|
| 93 |
+
first, second = bigram
|
| 94 |
+
new_word = []
|
| 95 |
+
i = 0
|
| 96 |
+
while i < len(word):
|
| 97 |
+
try:
|
| 98 |
+
j = word.index(first, i)
|
| 99 |
+
new_word.extend(word[i:j])
|
| 100 |
+
i = j
|
| 101 |
+
except:
|
| 102 |
+
new_word.extend(word[i:])
|
| 103 |
+
break
|
| 104 |
+
|
| 105 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
| 106 |
+
new_word.append(first+second)
|
| 107 |
+
i += 2
|
| 108 |
+
else:
|
| 109 |
+
new_word.append(word[i])
|
| 110 |
+
i += 1
|
| 111 |
+
new_word = tuple(new_word)
|
| 112 |
+
word = new_word
|
| 113 |
+
if len(word) == 1:
|
| 114 |
+
break
|
| 115 |
+
else:
|
| 116 |
+
pairs = get_pairs(word)
|
| 117 |
+
word = ' '.join(word)
|
| 118 |
+
self.cache[token] = word
|
| 119 |
+
return word
|
| 120 |
+
|
| 121 |
+
def encode(self, text):
|
| 122 |
+
bpe_tokens = []
|
| 123 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
| 124 |
+
for token in re.findall(self.pat, text):
|
| 125 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
| 126 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
| 127 |
+
return bpe_tokens
|
| 128 |
+
|
| 129 |
+
def decode(self, tokens):
|
| 130 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
| 131 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
| 132 |
+
return text
|
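A round-trip sketch for SimpleTokenizer above. It assumes the bpe_simple_vocab_16e6.txt.gz merges file sits next to simple_tokenizer.py, as default_bpe() expects, and that the code/ tree is importable.

from lib.model_zoo.clip_justin.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("a photograph of an astronaut riding a horse")
print(ids)                # BPE ids; encode() does not add the <|startoftext|>/<|endoftext|> markers
print(tok.decode(ids))    # recovers the lower-cased, whitespace-normalized text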
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/common/get_model.py
ADDED
|
@@ -0,0 +1,120 @@
| 1 |
+
from email.policy import strict
|
| 2 |
+
import torch
|
| 3 |
+
import torchvision.models
|
| 4 |
+
import os.path as osp
|
| 5 |
+
import copy
|
| 6 |
+
from ...log_service import print_log
|
| 7 |
+
from .utils import \
|
| 8 |
+
get_total_param, get_total_param_sum, \
|
| 9 |
+
get_unit
|
| 10 |
+
|
| 11 |
+
# def load_state_dict(net, model_path):
|
| 12 |
+
# if isinstance(net, dict):
|
| 13 |
+
# for ni, neti in net.items():
|
| 14 |
+
# paras = torch.load(model_path[ni], map_location=torch.device('cpu'))
|
| 15 |
+
# new_paras = neti.state_dict()
|
| 16 |
+
# new_paras.update(paras)
|
| 17 |
+
# neti.load_state_dict(new_paras)
|
| 18 |
+
# else:
|
| 19 |
+
# paras = torch.load(model_path, map_location=torch.device('cpu'))
|
| 20 |
+
# new_paras = net.state_dict()
|
| 21 |
+
# new_paras.update(paras)
|
| 22 |
+
# net.load_state_dict(new_paras)
|
| 23 |
+
# return
|
| 24 |
+
|
| 25 |
+
# def save_state_dict(net, path):
|
| 26 |
+
# if isinstance(net, (torch.nn.DataParallel,
|
| 27 |
+
# torch.nn.parallel.DistributedDataParallel)):
|
| 28 |
+
# torch.save(net.module.state_dict(), path)
|
| 29 |
+
# else:
|
| 30 |
+
# torch.save(net.state_dict(), path)
|
| 31 |
+
|
| 32 |
+
def singleton(class_):
|
| 33 |
+
instances = {}
|
| 34 |
+
def getinstance(*args, **kwargs):
|
| 35 |
+
if class_ not in instances:
|
| 36 |
+
instances[class_] = class_(*args, **kwargs)
|
| 37 |
+
return instances[class_]
|
| 38 |
+
return getinstance
|
| 39 |
+
|
| 40 |
+
def preprocess_model_args(args):
|
| 41 |
+
# If args has layer_units, get the corresponding
|
| 42 |
+
# units.
|
| 43 |
+
# If args has backbone, get the backbone model.
|
| 44 |
+
args = copy.deepcopy(args)
|
| 45 |
+
if 'layer_units' in args:
|
| 46 |
+
layer_units = [
|
| 47 |
+
get_unit()(i) for i in args.layer_units
|
| 48 |
+
]
|
| 49 |
+
args.layer_units = layer_units
|
| 50 |
+
if 'backbone' in args:
|
| 51 |
+
args.backbone = get_model()(args.backbone)
|
| 52 |
+
return args
|
| 53 |
+
|
| 54 |
+
@singleton
|
| 55 |
+
class get_model(object):
|
| 56 |
+
def __init__(self):
|
| 57 |
+
self.model = {}
|
| 58 |
+
self.version = {}
|
| 59 |
+
|
| 60 |
+
def register(self, model, name, version='x'):
|
| 61 |
+
self.model[name] = model
|
| 62 |
+
self.version[name] = version
|
| 63 |
+
|
| 64 |
+
def __call__(self, cfg, verbose=True):
|
| 65 |
+
"""
|
| 66 |
+
Construct model based on the config.
|
| 67 |
+
"""
|
| 68 |
+
t = cfg.type
|
| 69 |
+
|
| 70 |
+
# the register is in each file
|
| 71 |
+
if t.find('ldm')==0:
|
| 72 |
+
from .. import ldm
|
| 73 |
+
elif t=='autoencoderkl':
|
| 74 |
+
from .. import autoencoder
|
| 75 |
+
elif t.find('clip')==0:
|
| 76 |
+
from .. import clip
|
| 77 |
+
elif t.find('sd')==0:
|
| 78 |
+
from .. import sd
|
| 79 |
+
elif t.find('vd')==0:
|
| 80 |
+
from .. import vd
|
| 81 |
+
elif t.find('openai_unet')==0:
|
| 82 |
+
from .. import openaimodel
|
| 83 |
+
elif t.find('optimus')==0:
|
| 84 |
+
from .. import optimus
|
| 85 |
+
|
| 86 |
+
args = preprocess_model_args(cfg.args)
|
| 87 |
+
net = self.model[t](**args)
|
| 88 |
+
|
| 89 |
+
if 'ckpt' in cfg:
|
| 90 |
+
checkpoint = torch.load(cfg.ckpt, map_location='cpu')
|
| 91 |
+
strict_sd = cfg.get('strict_sd', True)
|
| 92 |
+
net.load_state_dict(checkpoint['state_dict'], strict=strict_sd)
|
| 93 |
+
if verbose:
|
| 94 |
+
print_log('Load ckpt from {}'.format(cfg.ckpt))
|
| 95 |
+
elif 'pth' in cfg:
|
| 96 |
+
sd = torch.load(cfg.pth, map_location='cpu')
|
| 97 |
+
strict_sd = cfg.get('strict_sd', True)
|
| 98 |
+
net.load_state_dict(sd, strict=strict_sd)
|
| 99 |
+
if verbose:
|
| 100 |
+
print_log('Load pth from {}'.format(cfg.pth))
|
| 101 |
+
|
| 102 |
+
# display param_num & param_sum
|
| 103 |
+
if verbose:
|
| 104 |
+
print_log(
|
| 105 |
+
'Load {} with total {} parameters,'
|
| 106 |
+
' {:.3f} parameter sum.'.format(
|
| 107 |
+
t,
|
| 108 |
+
get_total_param(net),
|
| 109 |
+
get_total_param_sum(net) ))
|
| 110 |
+
|
| 111 |
+
return net
|
| 112 |
+
|
| 113 |
+
def get_version(self, name):
|
| 114 |
+
return self.version[name]
|
| 115 |
+
|
| 116 |
+
def register(name, version='x'):
|
| 117 |
+
def wrapper(class_):
|
| 118 |
+
get_model().register(class_, name, version)
|
| 119 |
+
return class_
|
| 120 |
+
return wrapper
|
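A sketch of how the registry above is intended to be used. The model name and module are hypothetical; in the codebase, registration happens inside the model files, and the cfg passed to get_model() comes from the yaml configs (an attribute-accessible dict with type, args, and optional ckpt/pth).

import torch.nn as nn
from lib.model_zoo.common.get_model import get_model, register

@register('toy_mlp', version='1')        # hypothetical name, for illustration only
class ToyMLP(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.net = nn.Linear(dim, dim)

    def forward(self, x):
        return self.net(x)

# get_model() is a singleton registry; get_model()(cfg) would instantiate the class
# registered under cfg.type with **cfg.args and optionally load cfg.ckpt / cfg.pth.
print(get_model().get_version('toy_mlp'))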
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/common/get_optimizer.py
ADDED
|
@@ -0,0 +1,47 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.optim as optim
|
| 3 |
+
import numpy as np
|
| 4 |
+
import itertools
|
| 5 |
+
|
| 6 |
+
def singleton(class_):
|
| 7 |
+
instances = {}
|
| 8 |
+
def getinstance(*args, **kwargs):
|
| 9 |
+
if class_ not in instances:
|
| 10 |
+
instances[class_] = class_(*args, **kwargs)
|
| 11 |
+
return instances[class_]
|
| 12 |
+
return getinstance
|
| 13 |
+
|
| 14 |
+
class get_optimizer(object):
|
| 15 |
+
def __init__(self):
|
| 16 |
+
self.optimizer = {}
|
| 17 |
+
self.register(optim.SGD, 'sgd')
|
| 18 |
+
self.register(optim.Adam, 'adam')
|
| 19 |
+
self.register(optim.AdamW, 'adamw')
|
| 20 |
+
|
| 21 |
+
def register(self, optim, name):
|
| 22 |
+
self.optimizer[name] = optim
|
| 23 |
+
|
| 24 |
+
def __call__(self, net, cfg):
|
| 25 |
+
if cfg is None:
|
| 26 |
+
return None
|
| 27 |
+
t = cfg.type
|
| 28 |
+
if isinstance(net, (torch.nn.DataParallel,
|
| 29 |
+
torch.nn.parallel.DistributedDataParallel)):
|
| 30 |
+
netm = net.module
|
| 31 |
+
else:
|
| 32 |
+
netm = net
|
| 33 |
+
pg = getattr(netm, 'parameter_group', None)
|
| 34 |
+
|
| 35 |
+
if pg is not None:
|
| 36 |
+
params = []
|
| 37 |
+
for group_name, module_or_para in pg.items():
|
| 38 |
+
if not isinstance(module_or_para, list):
|
| 39 |
+
module_or_para = [module_or_para]
|
| 40 |
+
|
| 41 |
+
grouped_params = [mi.parameters() if isinstance(mi, torch.nn.Module) else [mi] for mi in module_or_para]
|
| 42 |
+
grouped_params = itertools.chain(*grouped_params)
|
| 43 |
+
pg_dict = {'params':grouped_params, 'name':group_name}
|
| 44 |
+
params.append(pg_dict)
|
| 45 |
+
else:
|
| 46 |
+
params = net.parameters()
|
| 47 |
+
return self.optimizer[t](params, lr=0, **cfg.args)
|
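A sketch of building an optimizer through the factory above. The SimpleNamespace stands in for the yaml-derived config and is an assumption; only .type and .args are read by this factory.

import torch.nn as nn
from types import SimpleNamespace
from lib.model_zoo.common.get_optimizer import get_optimizer

net = nn.Linear(8, 8)
cfg = SimpleNamespace(type='adamw', args={'weight_decay': 0.01})   # hypothetical stand-in config
opt = get_optimizer()(net, cfg)   # lr is initialized to 0 and driven per step by the scheduler
print(type(opt).__name__, [g['lr'] for g in opt.param_groups])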
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/common/get_scheduler.py
ADDED
|
@@ -0,0 +1,262 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.optim as optim
|
| 3 |
+
import numpy as np
|
| 4 |
+
import copy
|
| 5 |
+
from ... import sync
|
| 6 |
+
from ...cfg_holder import cfg_unique_holder as cfguh
|
| 7 |
+
|
| 8 |
+
def singleton(class_):
|
| 9 |
+
instances = {}
|
| 10 |
+
def getinstance(*args, **kwargs):
|
| 11 |
+
if class_ not in instances:
|
| 12 |
+
instances[class_] = class_(*args, **kwargs)
|
| 13 |
+
return instances[class_]
|
| 14 |
+
return getinstance
|
| 15 |
+
|
| 16 |
+
@singleton
|
| 17 |
+
class get_scheduler(object):
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self.lr_scheduler = {}
|
| 20 |
+
|
| 21 |
+
def register(self, lrsf, name):
|
| 22 |
+
self.lr_scheduler[name] = lrsf
|
| 23 |
+
|
| 24 |
+
def __call__(self, cfg):
|
| 25 |
+
if cfg is None:
|
| 26 |
+
return None
|
| 27 |
+
if isinstance(cfg, list):
|
| 28 |
+
schedulers = []
|
| 29 |
+
for ci in cfg:
|
| 30 |
+
t = ci.type
|
| 31 |
+
schedulers.append(
|
| 32 |
+
self.lr_scheduler[t](**ci.args))
|
| 33 |
+
if len(schedulers) == 0:
|
| 34 |
+
raise ValueError
|
| 35 |
+
else:
|
| 36 |
+
return compose_scheduler(schedulers)
|
| 37 |
+
t = cfg.type
|
| 38 |
+
return self.lr_scheduler[t](**cfg.args)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def register(name):
|
| 42 |
+
def wrapper(class_):
|
| 43 |
+
get_scheduler().register(class_, name)
|
| 44 |
+
return class_
|
| 45 |
+
return wrapper
|
| 46 |
+
|
| 47 |
+
class template_scheduler(object):
|
| 48 |
+
def __init__(self, step):
|
| 49 |
+
self.step = step
|
| 50 |
+
|
| 51 |
+
def __getitem__(self, idx):
|
| 52 |
+
raise ValueError
|
| 53 |
+
|
| 54 |
+
def set_lr(self, optim, new_lr, pg_lrscale=None):
|
| 55 |
+
"""
|
| 56 |
+
Set each parameter group in optim to new_lr.
|
| 57 |
+
new_lr can be found according to the idx.
|
| 58 |
+
pg_lrscale tells how to scale each pg.
|
| 59 |
+
"""
|
| 60 |
+
# new_lr = self.__getitem__(idx)
|
| 61 |
+
pg_lrscale = copy.deepcopy(pg_lrscale)
|
| 62 |
+
for pg in optim.param_groups:
|
| 63 |
+
if pg_lrscale is None:
|
| 64 |
+
pg['lr'] = new_lr
|
| 65 |
+
else:
|
| 66 |
+
pg['lr'] = new_lr * pg_lrscale.pop(pg['name'])
|
| 67 |
+
assert (pg_lrscale is None) or (len(pg_lrscale)==0), \
|
| 68 |
+
"pg_lrscale doesn't match pg"
|
| 69 |
+
|
| 70 |
+
@register('constant')
|
| 71 |
+
class constant_scheduler(template_scheduler):
|
| 72 |
+
def __init__(self, lr, step):
|
| 73 |
+
super().__init__(step)
|
| 74 |
+
self.lr = lr
|
| 75 |
+
|
| 76 |
+
def __getitem__(self, idx):
|
| 77 |
+
if idx >= self.step:
|
| 78 |
+
raise ValueError
|
| 79 |
+
return self.lr
|
| 80 |
+
|
| 81 |
+
@register('poly')
|
| 82 |
+
class poly_scheduler(template_scheduler):
|
| 83 |
+
def __init__(self, start_lr, end_lr, power, step):
|
| 84 |
+
super().__init__(step)
|
| 85 |
+
self.start_lr = start_lr
|
| 86 |
+
self.end_lr = end_lr
|
| 87 |
+
self.power = power
|
| 88 |
+
|
| 89 |
+
def __getitem__(self, idx):
|
| 90 |
+
if idx >= self.step:
|
| 91 |
+
raise ValueError
|
| 92 |
+
a, b = self.start_lr, self.end_lr
|
| 93 |
+
p, n = self.power, self.step
|
| 94 |
+
return b + (a-b)*((1-idx/n)**p)
|
| 95 |
+
|
| 96 |
+
@register('linear')
|
| 97 |
+
class linear_scheduler(template_scheduler):
|
| 98 |
+
def __init__(self, start_lr, end_lr, step):
|
| 99 |
+
super().__init__(step)
|
| 100 |
+
self.start_lr = start_lr
|
| 101 |
+
self.end_lr = end_lr
|
| 102 |
+
|
| 103 |
+
def __getitem__(self, idx):
|
| 104 |
+
if idx >= self.step:
|
| 105 |
+
raise ValueError
|
| 106 |
+
a, b, n = self.start_lr, self.end_lr, self.step
|
| 107 |
+
return b + (a-b)*(1-idx/n)
|
| 108 |
+
|
| 109 |
+
@register('multistage')
|
| 110 |
+
class multistage_scheduler(template_scheduler):
|
| 111 |
+
def __init__(self, start_lr, milestones, gamma, step):
|
| 112 |
+
super().__init__(step)
|
| 113 |
+
self.start_lr = start_lr
|
| 114 |
+
m = [0] + milestones + [step]
|
| 115 |
+
lr_iter = start_lr
|
| 116 |
+
self.lr = []
|
| 117 |
+
for ms, me in zip(m[0:-1], m[1:]):
|
| 118 |
+
for _ in range(ms, me):
|
| 119 |
+
self.lr.append(lr_iter)
|
| 120 |
+
lr_iter *= gamma
|
| 121 |
+
|
| 122 |
+
def __getitem__(self, idx):
|
| 123 |
+
if idx >= self.step:
|
| 124 |
+
raise ValueError
|
| 125 |
+
return self.lr[idx]
|
| 126 |
+
|
| 127 |
+
class compose_scheduler(template_scheduler):
|
| 128 |
+
def __init__(self, schedulers):
|
| 129 |
+
self.schedulers = schedulers
|
| 130 |
+
self.step = [si.step for si in schedulers]
|
| 131 |
+
self.step_milestone = []
|
| 132 |
+
acc = 0
|
| 133 |
+
for i in self.step:
|
| 134 |
+
acc += i
|
| 135 |
+
self.step_milestone.append(acc)
|
| 136 |
+
self.step = sum(self.step)
|
| 137 |
+
|
| 138 |
+
def __getitem__(self, idx):
|
| 139 |
+
if idx >= self.step:
|
| 140 |
+
raise ValueError
|
| 141 |
+
ms = self.step_milestone
|
| 142 |
+
for i, (mi, mj) in enumerate(zip([0] + ms[:-1], ms)):
|
| 143 |
+
if mi <= idx < mj:
|
| 144 |
+
return self.schedulers[i][idx - mi]
|
| 145 |
+
raise ValueError
|
| 146 |
+
|
| 147 |
+
####################
|
| 148 |
+
# lambda scheduler #
|
| 149 |
+
####################
|
| 150 |
+
|
| 151 |
+
class LambdaWarmUpCosineScheduler(template_scheduler):
|
| 152 |
+
"""
|
| 153 |
+
note: use with a base_lr of 1.0
|
| 154 |
+
"""
|
| 155 |
+
def __init__(self,
|
| 156 |
+
base_lr,
|
| 157 |
+
warm_up_steps,
|
| 158 |
+
lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
|
| 159 |
+
cfgt = cfguh().cfg.train
|
| 160 |
+
bs = cfgt.batch_size
|
| 161 |
+
if 'gradacc_every' not in cfgt:
|
| 162 |
+
print('Warning, gradacc_every is not found in xml, use 1 as default.')
|
| 163 |
+
acc = cfgt.get('gradacc_every', 1)
|
| 164 |
+
self.lr_multi = base_lr * bs * acc
|
| 165 |
+
self.lr_warm_up_steps = warm_up_steps
|
| 166 |
+
self.lr_start = lr_start
|
| 167 |
+
self.lr_min = lr_min
|
| 168 |
+
self.lr_max = lr_max
|
| 169 |
+
self.lr_max_decay_steps = max_decay_steps
|
| 170 |
+
self.last_lr = 0.
|
| 171 |
+
self.verbosity_interval = verbosity_interval
|
| 172 |
+
|
| 173 |
+
def schedule(self, n):
|
| 174 |
+
if self.verbosity_interval > 0:
|
| 175 |
+
if n % self.verbosity_interval == 0:
|
| 176 |
+
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
| 177 |
+
if n < self.lr_warm_up_steps:
|
| 178 |
+
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
| 179 |
+
self.last_lr = lr
|
| 180 |
+
return lr
|
| 181 |
+
else:
|
| 182 |
+
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
| 183 |
+
t = min(t, 1.0)
|
| 184 |
+
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
| 185 |
+
1 + np.cos(t * np.pi))
|
| 186 |
+
self.last_lr = lr
|
| 187 |
+
return lr
|
| 188 |
+
|
| 189 |
+
def __getitem__(self, idx):
|
| 190 |
+
return self.schedule(idx) * self.lr_multi
|
| 191 |
+
|
| 192 |
+
class LambdaWarmUpCosineScheduler2(template_scheduler):
|
| 193 |
+
"""
|
| 194 |
+
supports repeated iterations, configurable via lists
|
| 195 |
+
note: use with a base_lr of 1.0.
|
| 196 |
+
"""
|
| 197 |
+
def __init__(self,
|
| 198 |
+
base_lr,
|
| 199 |
+
warm_up_steps,
|
| 200 |
+
f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
|
| 201 |
+
cfgt = cfguh().cfg.train
|
| 202 |
+
# bs = cfgt.batch_size
|
| 203 |
+
# if 'gradacc_every' not in cfgt:
|
| 204 |
+
# print('Warning, gradacc_every is not found in xml, use 1 as default.')
|
| 205 |
+
# acc = cfgt.get('gradacc_every', 1)
|
| 206 |
+
# self.lr_multi = base_lr * bs * acc
|
| 207 |
+
self.lr_multi = base_lr
|
| 208 |
+
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
| 209 |
+
self.lr_warm_up_steps = warm_up_steps
|
| 210 |
+
self.f_start = f_start
|
| 211 |
+
self.f_min = f_min
|
| 212 |
+
self.f_max = f_max
|
| 213 |
+
self.cycle_lengths = cycle_lengths
|
| 214 |
+
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
| 215 |
+
self.last_f = 0.
|
| 216 |
+
self.verbosity_interval = verbosity_interval
|
| 217 |
+
|
| 218 |
+
def find_in_interval(self, n):
|
| 219 |
+
interval = 0
|
| 220 |
+
for cl in self.cum_cycles[1:]:
|
| 221 |
+
if n <= cl:
|
| 222 |
+
return interval
|
| 223 |
+
interval += 1
|
| 224 |
+
|
| 225 |
+
def schedule(self, n):
|
| 226 |
+
cycle = self.find_in_interval(n)
|
| 227 |
+
n = n - self.cum_cycles[cycle]
|
| 228 |
+
if self.verbosity_interval > 0:
|
| 229 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
| 230 |
+
f"current cycle {cycle}")
|
| 231 |
+
if n < self.lr_warm_up_steps[cycle]:
|
| 232 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
| 233 |
+
self.last_f = f
|
| 234 |
+
return f
|
| 235 |
+
else:
|
| 236 |
+
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
| 237 |
+
t = min(t, 1.0)
|
| 238 |
+
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
| 239 |
+
1 + np.cos(t * np.pi))
|
| 240 |
+
self.last_f = f
|
| 241 |
+
return f
|
| 242 |
+
|
| 243 |
+
def __getitem__(self, idx):
|
| 244 |
+
return self.schedule(idx) * self.lr_multi
|
| 245 |
+
|
| 246 |
+
@register('stable_diffusion_linear')
|
| 247 |
+
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
| 248 |
+
def schedule(self, n):
|
| 249 |
+
cycle = self.find_in_interval(n)
|
| 250 |
+
n = n - self.cum_cycles[cycle]
|
| 251 |
+
if self.verbosity_interval > 0:
|
| 252 |
+
if n % self.verbosity_interval == 0:
|
| 253 |
+
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
| 254 |
+
f"current cycle {cycle}")
|
| 255 |
+
if n < self.lr_warm_up_steps[cycle]:
|
| 256 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
| 257 |
+
self.last_f = f
|
| 258 |
+
return f
|
| 259 |
+
else:
|
| 260 |
+
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
|
| 261 |
+
self.last_f = f
|
| 262 |
+
return f
|
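A sketch of querying a per-step learning rate from the scheduler registry above, using the simple 'linear' entry; the cfg stub is hypothetical, and a list of such configs would be composed sequentially by compose_scheduler.

from types import SimpleNamespace
from lib.model_zoo.common.get_scheduler import get_scheduler

cfg = SimpleNamespace(type='linear', args={'start_lr': 1e-4, 'end_lr': 1e-6, 'step': 1000})
sched = get_scheduler()(cfg)
print(sched[0], sched[500], sched[999])   # lr at selected steps; set_lr() pushes it into the optimizer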
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/common/utils.py
ADDED
|
@@ -0,0 +1,292 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
import copy
|
| 6 |
+
import functools
|
| 7 |
+
import itertools
|
| 8 |
+
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
########
|
| 12 |
+
# unit #
|
| 13 |
+
########
|
| 14 |
+
|
| 15 |
+
def singleton(class_):
|
| 16 |
+
instances = {}
|
| 17 |
+
def getinstance(*args, **kwargs):
|
| 18 |
+
if class_ not in instances:
|
| 19 |
+
instances[class_] = class_(*args, **kwargs)
|
| 20 |
+
return instances[class_]
|
| 21 |
+
return getinstance
|
| 22 |
+
|
| 23 |
+
def str2value(v):
|
| 24 |
+
v = v.strip()
|
| 25 |
+
try:
|
| 26 |
+
return int(v)
|
| 27 |
+
except:
|
| 28 |
+
pass
|
| 29 |
+
try:
|
| 30 |
+
return float(v)
|
| 31 |
+
except:
|
| 32 |
+
pass
|
| 33 |
+
if v in ('True', 'true'):
|
| 34 |
+
return True
|
| 35 |
+
elif v in ('False', 'false'):
|
| 36 |
+
return False
|
| 37 |
+
else:
|
| 38 |
+
return v
|
| 39 |
+
|
| 40 |
+
@singleton
|
| 41 |
+
class get_unit(object):
|
| 42 |
+
def __init__(self):
|
| 43 |
+
self.unit = {}
|
| 44 |
+
self.register('none', None)
|
| 45 |
+
|
| 46 |
+
# general convolution
|
| 47 |
+
self.register('conv' , nn.Conv2d)
|
| 48 |
+
self.register('bn' , nn.BatchNorm2d)
|
| 49 |
+
self.register('relu' , nn.ReLU)
|
| 50 |
+
self.register('relu6' , nn.ReLU6)
|
| 51 |
+
self.register('lrelu' , nn.LeakyReLU)
|
| 52 |
+
self.register('dropout' , nn.Dropout)
|
| 53 |
+
self.register('dropout2d', nn.Dropout2d)
|
| 54 |
+
self.register('sine', Sine)
|
| 55 |
+
self.register('relusine', ReLUSine)
|
| 56 |
+
|
| 57 |
+
def register(self,
|
| 58 |
+
name,
|
| 59 |
+
unitf,):
|
| 60 |
+
|
| 61 |
+
self.unit[name] = unitf
|
| 62 |
+
|
| 63 |
+
def __call__(self, name):
|
| 64 |
+
if name is None:
|
| 65 |
+
return None
|
| 66 |
+
i = name.find('(')
|
| 67 |
+
i = len(name) if i==-1 else i
|
| 68 |
+
t = name[:i]
|
| 69 |
+
f = self.unit[t]
|
| 70 |
+
args = name[i:].strip('()')
|
| 71 |
+
if len(args) == 0:
|
| 72 |
+
args = {}
|
| 73 |
+
return f
|
| 74 |
+
else:
|
| 75 |
+
args = args.split('=')
|
| 76 |
+
args = [[','.join(i.split(',')[:-1]), i.split(',')[-1]] for i in args]
|
| 77 |
+
args = list(itertools.chain.from_iterable(args))
|
| 78 |
+
args = [i.strip() for i in args if len(i)>0]
|
| 79 |
+
kwargs = {}
|
| 80 |
+
for k, v in zip(args[::2], args[1::2]):
|
| 81 |
+
if v[0]=='(' and v[-1]==')':
|
| 82 |
+
kwargs[k] = tuple([str2value(i) for i in v.strip('()').split(',')])
|
| 83 |
+
elif v[0]=='[' and v[-1]==']':
|
| 84 |
+
kwargs[k] = [str2value(i) for i in v.strip('[]').split(',')]
|
| 85 |
+
else:
|
| 86 |
+
kwargs[k] = str2value(v)
|
| 87 |
+
return functools.partial(f, **kwargs)
|
| 88 |
+
|
| 89 |
+
def register(name):
|
| 90 |
+
def wrapper(class_):
|
| 91 |
+
get_unit().register(name, class_)
|
| 92 |
+
return class_
|
| 93 |
+
return wrapper
|
| 94 |
+
|
| 95 |
+
class Sine(object):
|
| 96 |
+
def __init__(self, freq, gain=1):
|
| 97 |
+
self.freq = freq
|
| 98 |
+
self.gain = gain
|
| 99 |
+
self.repr = 'sine(freq={}, gain={})'.format(freq, gain)
|
| 100 |
+
|
| 101 |
+
def __call__(self, x, gain=1):
|
| 102 |
+
act_gain = self.gain * gain
|
| 103 |
+
return torch.sin(self.freq * x) * act_gain
|
| 104 |
+
|
| 105 |
+
def __repr__(self,):
|
| 106 |
+
return self.repr
|
| 107 |
+
|
| 108 |
+
class ReLUSine(nn.Module):
|
| 109 |
+
def __init__(self):
|
| 110 |
+
super().__init__()
|
| 111 |
+
|
| 112 |
+
def forward(self, input):
|
| 113 |
+
a = torch.sin(30 * input)
|
| 114 |
+
b = nn.ReLU(inplace=False)(input)
|
| 115 |
+
return a+b
|
| 116 |
+
|
| 117 |
+
@register('lrelu_agc')
|
| 118 |
+
# class lrelu_agc(nn.Module):
|
| 119 |
+
class lrelu_agc(object):
|
| 120 |
+
"""
|
| 121 |
+
The lrelu layer with alpha, gain and clamp
|
| 122 |
+
"""
|
| 123 |
+
def __init__(self, alpha=0.1, gain=1, clamp=None):
|
| 124 |
+
# super().__init__()
|
| 125 |
+
self.alpha = alpha
|
| 126 |
+
if gain == 'sqrt_2':
|
| 127 |
+
self.gain = np.sqrt(2)
|
| 128 |
+
else:
|
| 129 |
+
self.gain = gain
|
| 130 |
+
self.clamp = clamp
|
| 131 |
+
self.repr = 'lrelu_agc(alpha={}, gain={}, clamp={})'.format(
|
| 132 |
+
alpha, gain, clamp)
|
| 133 |
+
|
| 134 |
+
# def forward(self, x, gain=1):
|
| 135 |
+
def __call__(self, x, gain=1):
|
| 136 |
+
x = F.leaky_relu(x, negative_slope=self.alpha, inplace=True)
|
| 137 |
+
act_gain = self.gain * gain
|
| 138 |
+
act_clamp = self.clamp * gain if self.clamp is not None else None
|
| 139 |
+
if act_gain != 1:
|
| 140 |
+
x = x * act_gain
|
| 141 |
+
if act_clamp is not None:
|
| 142 |
+
x = x.clamp(-act_clamp, act_clamp)
|
| 143 |
+
return x
|
| 144 |
+
|
| 145 |
+
def __repr__(self,):
|
| 146 |
+
return self.repr
|
| 147 |
+
|
| 148 |
+
####################
|
| 149 |
+
# spatial encoding #
|
| 150 |
+
####################
|
| 151 |
+
|
| 152 |
+
@register('se')
|
| 153 |
+
class SpatialEncoding(nn.Module):
|
| 154 |
+
def __init__(self,
|
| 155 |
+
in_dim,
|
| 156 |
+
out_dim,
|
| 157 |
+
sigma = 6,
|
| 158 |
+
cat_input=True,
|
| 159 |
+
require_grad=False,):
|
| 160 |
+
|
| 161 |
+
super().__init__()
|
| 162 |
+
assert out_dim % (2*in_dim) == 0, "dimension must be divisible"
|
| 163 |
+
|
| 164 |
+
n = out_dim // 2 // in_dim
|
| 165 |
+
m = 2**np.linspace(0, sigma, n)
|
| 166 |
+
m = np.stack([m] + [np.zeros_like(m)]*(in_dim-1), axis=-1)
|
| 167 |
+
m = np.concatenate([np.roll(m, i, axis=-1) for i in range(in_dim)], axis=0)
|
| 168 |
+
self.emb = torch.FloatTensor(m)
|
| 169 |
+
if require_grad:
|
| 170 |
+
self.emb = nn.Parameter(self.emb, requires_grad=True)
|
| 171 |
+
self.in_dim = in_dim
|
| 172 |
+
self.out_dim = out_dim
|
| 173 |
+
self.sigma = sigma
|
| 174 |
+
self.cat_input = cat_input
|
| 175 |
+
self.require_grad = require_grad
|
| 176 |
+
|
| 177 |
+
def forward(self, x, format='[n x c]'):
|
| 178 |
+
"""
|
| 179 |
+
Args:
|
| 180 |
+
x: [n x m1],
|
| 181 |
+
m1 usually is 2
|
| 182 |
+
Outputs:
|
| 183 |
+
y: [n x m2]
|
| 184 |
+
m2 is the output dimension number
|
| 185 |
+
"""
|
| 186 |
+
if format == '[bs x c x 2D]':
|
| 187 |
+
xshape = x.shape
|
| 188 |
+
x = x.permute(0, 2, 3, 1).contiguous()
|
| 189 |
+
x = x.view(-1, x.size(-1))
|
| 190 |
+
elif format == '[n x c]':
|
| 191 |
+
pass
|
| 192 |
+
else:
|
| 193 |
+
raise ValueError
|
| 194 |
+
|
| 195 |
+
if not self.require_grad:
|
| 196 |
+
self.emb = self.emb.to(x.device)
|
| 197 |
+
y = torch.mm(x, self.emb.T)
|
| 198 |
+
if self.cat_input:
|
| 199 |
+
z = torch.cat([x, torch.sin(y), torch.cos(y)], dim=-1)
|
| 200 |
+
else:
|
| 201 |
+
z = torch.cat([torch.sin(y), torch.cos(y)], dim=-1)
|
| 202 |
+
|
| 203 |
+
if format == '[bs x c x 2D]':
|
| 204 |
+
z = z.view(xshape[0], xshape[2], xshape[3], -1)
|
| 205 |
+
z = z.permute(0, 3, 1, 2).contiguous()
|
| 206 |
+
return z
|
| 207 |
+
|
| 208 |
+
def extra_repr(self):
|
| 209 |
+
outstr = 'SpatialEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
|
| 210 |
+
self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
|
| 211 |
+
return outstr
|
| 212 |
+
|
| 213 |
+
@register('rffe')
|
| 214 |
+
class RFFEncoding(SpatialEncoding):
|
| 215 |
+
"""
|
| 216 |
+
Random Fourier Features
|
| 217 |
+
"""
|
| 218 |
+
def __init__(self,
|
| 219 |
+
in_dim,
|
| 220 |
+
out_dim,
|
| 221 |
+
sigma = 6,
|
| 222 |
+
cat_input=True,
|
| 223 |
+
require_grad=False,):
|
| 224 |
+
|
| 225 |
+
super().__init__(in_dim, out_dim, sigma, cat_input, require_grad)
|
| 226 |
+
n = out_dim // 2
|
| 227 |
+
m = np.random.normal(0, sigma, size=(n, in_dim))
|
| 228 |
+
self.emb = torch.FloatTensor(m)
|
| 229 |
+
if require_grad:
|
| 230 |
+
self.emb = nn.Parameter(self.emb, requires_grad=True)
|
| 231 |
+
|
| 232 |
+
def extra_repr(self):
|
| 233 |
+
outstr = 'RFFEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
|
| 234 |
+
self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
|
| 235 |
+
return outstr
|
| 236 |
+
|
| 237 |
+
##########
|
| 238 |
+
# helper #
|
| 239 |
+
##########
|
| 240 |
+
|
| 241 |
+
def freeze(net):
|
| 242 |
+
for m in net.modules():
|
| 243 |
+
if isinstance(m, (
|
| 244 |
+
nn.BatchNorm2d,
|
| 245 |
+
nn.SyncBatchNorm,)):
|
| 246 |
+
# inplace_abn not supported
|
| 247 |
+
m.eval()
|
| 248 |
+
for pi in net.parameters():
|
| 249 |
+
pi.requires_grad = False
|
| 250 |
+
return net
|
| 251 |
+
|
| 252 |
+
def common_init(m):
|
| 253 |
+
if isinstance(m, (
|
| 254 |
+
nn.Conv2d,
|
| 255 |
+
nn.ConvTranspose2d,)):
|
| 256 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 257 |
+
if m.bias is not None:
|
| 258 |
+
nn.init.constant_(m.bias, 0)
|
| 259 |
+
elif isinstance(m, (
|
| 260 |
+
nn.BatchNorm2d,
|
| 261 |
+
nn.SyncBatchNorm,)):
|
| 262 |
+
nn.init.constant_(m.weight, 1)
|
| 263 |
+
nn.init.constant_(m.bias, 0)
|
| 264 |
+
else:
|
| 265 |
+
pass
|
| 266 |
+
|
| 267 |
+
def init_module(module):
|
| 268 |
+
"""
|
| 269 |
+
Args:
|
| 270 |
+
module: [nn.module] list or nn.module
|
| 271 |
+
a list of module to be initialized.
|
| 272 |
+
"""
|
| 273 |
+
if isinstance(module, (list, tuple)):
|
| 274 |
+
module = list(module)
|
| 275 |
+
else:
|
| 276 |
+
module = [module]
|
| 277 |
+
|
| 278 |
+
for mi in module:
|
| 279 |
+
for mii in mi.modules():
|
| 280 |
+
common_init(mii)
|
| 281 |
+
|
| 282 |
+
def get_total_param(net):
|
| 283 |
+
if getattr(net, 'parameters', None) is None:
|
| 284 |
+
return 0
|
| 285 |
+
return sum(p.numel() for p in net.parameters())
|
| 286 |
+
|
| 287 |
+
def get_total_param_sum(net):
|
| 288 |
+
if getattr(net, 'parameters', None) is None:
|
| 289 |
+
return 0
|
| 290 |
+
with torch.no_grad():
|
| 291 |
+
s = sum(p.cpu().detach().numpy().sum().item() for p in net.parameters())
|
| 292 |
+
return s
|
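get_unit above is a singleton registry that maps spec strings such as 'lrelu(negative_slope=0.2)' to partially-applied layer constructors, so configs can name activations and norms as strings. The sketch below is a loose, self-contained re-implementation of that pattern for illustration only; it does not import the repository module, and the registry contents and spec strings are made up for the example.

import functools
import torch.nn as nn

REGISTRY = {'conv': nn.Conv2d, 'bn': nn.BatchNorm2d, 'relu': nn.ReLU, 'lrelu': nn.LeakyReLU}

def build(spec):
    # 'name(k1=v1, k2=v2)' -> functools.partial(REGISTRY[name], k1=v1, k2=v2)
    name, _, args = spec.partition('(')
    fn = REGISTRY[name]
    kwargs = {}
    for pair in filter(None, args.rstrip(')').split(',')):
        k, v = (s.strip() for s in pair.split('='))
        try:
            kwargs[k] = int(v)
        except ValueError:
            try:
                kwargs[k] = float(v)
            except ValueError:
                kwargs[k] = v
    return functools.partial(fn, **kwargs) if kwargs else fn

act_factory = build('lrelu(negative_slope=0.2)')
print(act_factory())   # LeakyReLU(negative_slope=0.2)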
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/ddim.py
ADDED
|
@@ -0,0 +1,341 @@
| 1 |
+
"""SAMPLING ONLY."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from functools import partial
|
| 7 |
+
|
| 8 |
+
from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DDIMSampler(object):
|
| 12 |
+
def __init__(self, model, schedule="linear", **kwargs):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.model = model
|
| 15 |
+
self.ddpm_num_timesteps = model.num_timesteps
|
| 16 |
+
self.schedule = schedule
|
| 17 |
+
|
| 18 |
+
def register_buffer(self, name, attr):
|
| 19 |
+
if type(attr) == torch.Tensor:
|
| 20 |
+
if attr.device != torch.device("cuda"):
|
| 21 |
+
attr = attr.to(torch.device("cuda"))
|
| 22 |
+
setattr(self, name, attr)
|
| 23 |
+
|
| 24 |
+
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
| 25 |
+
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize,
|
| 26 |
+
num_ddim_timesteps=ddim_num_steps,
|
| 27 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
| 28 |
+
verbose=verbose)
|
| 29 |
+
alphas_cumprod = self.model.alphas_cumprod
|
| 30 |
+
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
| 31 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
| 32 |
+
|
| 33 |
+
self.register_buffer('betas', to_torch(self.model.betas))
|
| 34 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
| 35 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
| 36 |
+
|
| 37 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 38 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
| 39 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
| 40 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
| 41 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
| 42 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
| 43 |
+
|
| 44 |
+
# ddim sampling parameters
|
| 45 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
| 46 |
+
alphacums=alphas_cumprod.cpu(),
|
| 47 |
+
ddim_timesteps=self.ddim_timesteps,
|
| 48 |
+
eta=ddim_eta,verbose=verbose)
|
| 49 |
+
|
| 50 |
+
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
| 51 |
+
self.register_buffer('ddim_alphas', ddim_alphas)
|
| 52 |
+
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
| 53 |
+
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
| 54 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
| 55 |
+
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
| 56 |
+
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
| 57 |
+
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
| 58 |
+
|
| 59 |
+
@torch.no_grad()
|
| 60 |
+
def sample(self,
|
| 61 |
+
S,
|
| 62 |
+
batch_size,
|
| 63 |
+
shape,
|
| 64 |
+
conditioning=None,
|
| 65 |
+
callback=None,
|
| 66 |
+
normals_sequence=None,
|
| 67 |
+
img_callback=None,
|
| 68 |
+
quantize_x0=False,
|
| 69 |
+
eta=0.,
|
| 70 |
+
mask=None,
|
| 71 |
+
x0=None,
|
| 72 |
+
temperature=1.,
|
| 73 |
+
noise_dropout=0.,
|
| 74 |
+
score_corrector=None,
|
| 75 |
+
corrector_kwargs=None,
|
| 76 |
+
verbose=True,
|
| 77 |
+
x_T=None,
|
| 78 |
+
log_every_t=100,
|
| 79 |
+
unconditional_guidance_scale=1.,
|
| 80 |
+
unconditional_conditioning=None,
|
| 81 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
| 82 |
+
**kwargs
|
| 83 |
+
):
|
| 84 |
+
if conditioning is not None:
|
| 85 |
+
if isinstance(conditioning, dict):
|
| 86 |
+
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
| 87 |
+
if cbs != batch_size:
|
| 88 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
| 89 |
+
else:
|
| 90 |
+
if conditioning.shape[0] != batch_size:
|
| 91 |
+
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
| 92 |
+
|
| 93 |
+
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
| 94 |
+
# sampling
|
| 95 |
+
C, H, W = shape
|
| 96 |
+
size = (batch_size, C, H, W)
|
| 97 |
+
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
| 98 |
+
|
| 99 |
+
samples, intermediates = self.ddim_sampling(conditioning, size,
|
| 100 |
+
callback=callback,
|
| 101 |
+
img_callback=img_callback,
|
| 102 |
+
quantize_denoised=quantize_x0,
|
| 103 |
+
mask=mask, x0=x0,
|
| 104 |
+
ddim_use_original_steps=False,
|
| 105 |
+
noise_dropout=noise_dropout,
|
| 106 |
+
temperature=temperature,
|
| 107 |
+
score_corrector=score_corrector,
|
| 108 |
+
corrector_kwargs=corrector_kwargs,
|
| 109 |
+
x_T=x_T,
|
| 110 |
+
log_every_t=log_every_t,
|
| 111 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 112 |
+
unconditional_conditioning=unconditional_conditioning,
|
| 113 |
+
)
|
| 114 |
+
return samples, intermediates
|
| 115 |
+
|
| 116 |
+
@torch.no_grad()
|
| 117 |
+
def ddim_sampling(self,
|
| 118 |
+
cond, shape,
|
| 119 |
+
x_T=None,
|
| 120 |
+
ddim_use_original_steps=False,
|
| 121 |
+
callback=None,
|
| 122 |
+
timesteps=None,
|
| 123 |
+
quantize_denoised=False,
|
| 124 |
+
mask=None, x0=None,
|
| 125 |
+
img_callback=None, log_every_t=100,
|
| 126 |
+
temperature=1.,
|
| 127 |
+
noise_dropout=0.,
|
| 128 |
+
score_corrector=None,
|
| 129 |
+
corrector_kwargs=None,
|
| 130 |
+
unconditional_guidance_scale=1.,
|
| 131 |
+
unconditional_conditioning=None,):
|
| 132 |
+
device = torch.device('cuda:1')
|
| 133 |
+
b = shape[0]
|
| 134 |
+
if x_T is None:
|
| 135 |
+
img = torch.randn(shape, device=device)
|
| 136 |
+
else:
|
| 137 |
+
img = x_T.cuda(1)
|
| 138 |
+
|
| 139 |
+
if timesteps is None:
|
| 140 |
+
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
| 141 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
| 142 |
+
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
| 143 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
| 144 |
+
|
| 145 |
+
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
| 146 |
+
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
| 147 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
| 148 |
+
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
| 149 |
+
|
| 150 |
+
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
| 151 |
+
|
| 152 |
+
for i, step in enumerate(iterator):
|
| 153 |
+
index = total_steps - i - 1
|
| 154 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
| 155 |
+
|
| 156 |
+
if mask is not None:
|
| 157 |
+
assert x0 is not None
|
| 158 |
+
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
| 159 |
+
img = img_orig * mask + (1. - mask) * img
|
| 160 |
+
|
| 161 |
+
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
| 162 |
+
quantize_denoised=quantize_denoised, temperature=temperature,
|
| 163 |
+
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
| 164 |
+
corrector_kwargs=corrector_kwargs,
|
| 165 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 166 |
+
unconditional_conditioning=unconditional_conditioning)
|
| 167 |
+
img, pred_x0 = outs
|
| 168 |
+
if callback: callback(i)
|
| 169 |
+
if img_callback: img_callback(pred_x0, i)
|
| 170 |
+
|
| 171 |
+
if index % log_every_t == 0 or index == total_steps - 1:
|
| 172 |
+
intermediates['x_inter'].append(img)
|
| 173 |
+
intermediates['pred_x0'].append(pred_x0)
|
| 174 |
+
|
| 175 |
+
return img, intermediates
|
| 176 |
+
|
| 177 |
+
@torch.no_grad()
|
| 178 |
+
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
| 179 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
| 180 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None):
|
| 181 |
+
b, *_, device = *x.shape, 'cuda:1'
|
| 182 |
+
device = torch.device('cuda:1')
|
| 183 |
+
|
| 184 |
+
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
| 185 |
+
e_t = self.model.apply_model(x, t, c)
|
| 186 |
+
else:
|
| 187 |
+
x_in = torch.cat([x] * 2)
|
| 188 |
+
t_in = torch.cat([t] * 2)
|
| 189 |
+
c_in = torch.cat([unconditional_conditioning, c])
|
| 190 |
+
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
| 191 |
+
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
| 192 |
+
|
| 193 |
+
if score_corrector is not None:
|
| 194 |
+
assert self.model.parameterization == "eps"
|
| 195 |
+
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
| 196 |
+
|
| 197 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
| 198 |
+
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
| 199 |
+
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
| 200 |
+
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
| 201 |
+
# select parameters corresponding to the currently considered timestep
|
| 202 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
| 203 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
| 204 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
| 205 |
+
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
| 206 |
+
print(sqrt_one_minus_at.device)
|
| 207 |
+
# current prediction for x_0
|
| 208 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
| 209 |
+
if quantize_denoised:
|
| 210 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
| 211 |
+
# direction pointing to x_t
|
| 212 |
+
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
| 213 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
| 214 |
+
if noise_dropout > 0.:
|
| 215 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
| 216 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
| 217 |
+
return x_prev, pred_x0
|
| 218 |
+
|
| 219 |
+
# XX-added for forward then backward
|
| 220 |
+
|
| 221 |
+
@torch.no_grad()
|
| 222 |
+
def sample_fwdbwd(self,
|
| 223 |
+
S,
|
| 224 |
+
x0,
|
| 225 |
+
S_ref,
|
| 226 |
+
batch_size,
|
| 227 |
+
shape,
|
| 228 |
+
conditioning=None,
|
| 229 |
+
callback=None,
|
| 230 |
+
img_callback=None,
|
| 231 |
+
quantize_x0=False,
|
| 232 |
+
eta=0.,
|
| 233 |
+
temperature=1.,
|
| 234 |
+
noise_dropout=0.,
|
| 235 |
+
score_corrector=None,
|
| 236 |
+
corrector_kwargs=None,
|
| 237 |
+
verbose=True,
|
| 238 |
+
x_T=None,
|
| 239 |
+
log_every_t=100,
|
| 240 |
+
unconditional_guidance_scale=1.,
|
| 241 |
+
unconditional_conditioning=None,
|
| 242 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
| 243 |
+
**kwargs):
|
| 244 |
+
if conditioning is not None:
|
| 245 |
+
if isinstance(conditioning, dict):
|
| 246 |
+
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
| 247 |
+
if cbs != batch_size:
|
| 248 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
| 249 |
+
else:
|
| 250 |
+
if conditioning.shape[0] != batch_size:
|
| 251 |
+
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
| 252 |
+
|
| 253 |
+
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
| 254 |
+
# sampling
|
| 255 |
+
C, H, W = shape
|
| 256 |
+
size = (batch_size, C, H, W)
|
| 257 |
+
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
| 258 |
+
|
| 259 |
+
samples, intermediates = self.ddim_sampling_fwdbwd(
|
| 260 |
+
conditioning, size, x_0=x0, refstep=S_ref,
|
| 261 |
+
callback=callback,
|
| 262 |
+
img_callback=img_callback,
|
| 263 |
+
quantize_denoised=quantize_x0,
|
| 264 |
+
ddim_use_original_steps=False,
|
| 265 |
+
noise_dropout=noise_dropout,
|
| 266 |
+
temperature=temperature,
|
| 267 |
+
score_corrector=score_corrector,
|
| 268 |
+
corrector_kwargs=corrector_kwargs,
|
| 269 |
+
log_every_t=log_every_t,
|
| 270 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 271 |
+
unconditional_conditioning=unconditional_conditioning,)
|
| 272 |
+
|
| 273 |
+
return samples, intermediates
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
@torch.no_grad()
|
| 277 |
+
def ddim_sampling_fwdbwd(self,
|
| 278 |
+
cond,
|
| 279 |
+
shape,
|
| 280 |
+
x_0,
|
| 281 |
+
refstep,
|
| 282 |
+
ddim_use_original_steps=False,
|
| 283 |
+
callback=None,
|
| 284 |
+
timesteps=None,
|
| 285 |
+
quantize_denoised=False,
|
| 286 |
+
img_callback=None,
|
| 287 |
+
log_every_t=100,
|
| 288 |
+
temperature=1.,
|
| 289 |
+
noise_dropout=0.,
|
| 290 |
+
score_corrector=None,
|
| 291 |
+
corrector_kwargs=None,
|
| 292 |
+
unconditional_guidance_scale=1.,
|
| 293 |
+
unconditional_conditioning=None, ):
|
| 294 |
+
'''
|
| 295 |
+
A function that forward-diffuses x_0 to a reference step
|
| 296 |
+
(reference step number is correlated with the timesteps
|
| 297 |
+
so the real time steps are timesteps[0:refstep] )
|
| 298 |
+
i.e. x_0 -> x_k -> x_0'; k in {0 .. len(timesteps)}
|
| 299 |
+
'''
|
| 300 |
+
|
| 301 |
+
device = self.model.betas.device
|
| 302 |
+
b = shape[0]
|
| 303 |
+
|
| 304 |
+
if timesteps is None:
|
| 305 |
+
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
| 306 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
| 307 |
+
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
| 308 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
| 309 |
+
|
| 310 |
+
timesteps = timesteps[0:refstep]
|
| 311 |
+
|
| 312 |
+
# forward diffusion
|
| 313 |
+
t = torch.full((b,), timesteps[-1], device=device, dtype=torch.long)
|
| 314 |
+
x_noisy = self.model.q_sample(x_start=x_0, t=t)
|
| 315 |
+
img = x_noisy
|
| 316 |
+
|
| 317 |
+
intermediates = {'x_inter': [x_noisy], 'pred_x0': [x_0]}
|
| 318 |
+
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
| 319 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
| 320 |
+
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
| 321 |
+
|
| 322 |
+
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
| 323 |
+
|
| 324 |
+
for i, step in enumerate(iterator):
|
| 325 |
+
index = total_steps - i - 1
|
| 326 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
| 327 |
+
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
| 328 |
+
quantize_denoised=quantize_denoised, temperature=temperature,
|
| 329 |
+
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
| 330 |
+
corrector_kwargs=corrector_kwargs,
|
| 331 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 332 |
+
unconditional_conditioning=unconditional_conditioning)
|
| 333 |
+
img, pred_x0 = outs
|
| 334 |
+
if callback: callback(i)
|
| 335 |
+
if img_callback: img_callback(pred_x0, i)
|
| 336 |
+
|
| 337 |
+
if index % log_every_t == 0 or index == total_steps - 1:
|
| 338 |
+
intermediates['x_inter'].append(img)
|
| 339 |
+
intermediates['pred_x0'].append(pred_x0)
|
| 340 |
+
|
| 341 |
+
return img, intermediates
|
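Two pieces of p_sample_ddim above carry the whole method: classifier-free guidance mixes the unconditional and conditional noise predictions, and the DDIM update rebuilds x_{t-1} from the predicted x_0, a deterministic direction term, and (for eta > 0) fresh noise. A minimal sketch with dummy tensors follows; shapes and coefficients are illustrative rather than taken from a real schedule or model.

import torch

b, c, h, w = 2, 4, 8, 8
x = torch.randn(b, c, h, w)            # x_t
e_t_uncond = torch.randn(b, c, h, w)   # eps(x_t, t, unconditional context)
e_t_cond = torch.randn(b, c, h, w)     # eps(x_t, t, conditional context)
scale = 7.5                            # unconditional_guidance_scale

# classifier-free guidance
e_t = e_t_uncond + scale * (e_t_cond - e_t_uncond)

# illustrative per-step coefficients (normally indexed from the DDIM schedule)
a_t, a_prev, sigma_t = torch.tensor(0.5), torch.tensor(0.6), torch.tensor(0.0)
sqrt_one_minus_at = (1 - a_t).sqrt()

pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()   # current prediction for x_0
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t     # direction pointing to x_t
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + sigma_t * torch.randn_like(x)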
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/ddim_dualcontext.py
ADDED
|
@@ -0,0 +1,144 @@
import torch
import numpy as np
from tqdm import tqdm
from functools import partial

from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like

from .ddim import DDIMSampler

class DDIMSampler_DualContext(DDIMSampler):
    @torch.no_grad()
    def sample_text(self, *args, **kwargs):
        self.cond_type = 'prompt'
        return self.sample(*args, **kwargs)

    @torch.no_grad()
    def sample_vision(self, *args, **kwargs):
        self.cond_type = 'vision'
        return self.sample(*args, **kwargs)

    @torch.no_grad()
    def sample_mixed(self, *args, **kwargs):
        self.cond_type = kwargs.pop('cond_mixed_p')
        return self.sample(*args, **kwargs)

    @torch.no_grad()
    def sample(self,
               steps,
               shape,
               xt=None,
               conditioning=None,
               eta=0.,
               temperature=1.,
               noise_dropout=0.,
               verbose=True,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,):

        self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
        # sampling
        print(f'Data shape for DDIM sampling is {shape}, eta {eta}')

        samples, intermediates = self.ddim_sampling(
            conditioning,
            shape,
            xt=xt,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,)
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self,
                      conditioning,
                      shape,
                      xt=None,
                      ddim_use_original_steps=False,
                      timesteps=None,
                      log_every_t=100,
                      temperature=1.,
                      noise_dropout=0.,
                      unconditional_guidance_scale=1.,
                      unconditional_conditioning=None,):
        device = self.model.betas.device
        bs = shape[0]
        if xt is None:
            img = torch.randn(shape, device=device)
        else:
            img = xt

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((bs,), step, device=device, dtype=torch.long)

            outs = self.p_sample_ddim(img, conditioning, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      temperature=temperature,
                                      noise_dropout=noise_dropout,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
            img, pred_x0 = outs

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, conditioning, t, index, repeat_noise=False, use_original_steps=False,
                      temperature=1., noise_dropout=0.,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, conditioning, cond_type=self.cond_type)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            # c_in = torch.cat([unconditional_conditioning, conditioning])

            # Added for vd-dc dual guidance
            if isinstance(unconditional_conditioning, list):
                c_in = [torch.cat([ui, ci]) for ui, ci in zip(unconditional_conditioning, conditioning)]
            else:
                c_in = torch.cat([unconditional_conditioning, conditioning])

            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, cond_type=self.cond_type).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
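Relative to the base sampler, the only structural change above is how conditioning is batched for guidance: when the context is a list (for example a text embedding and an image embedding used together), each element is concatenated with its unconditional counterpart separately before the single batched apply_model call, and the result is split back with chunk(2). A small sketch of that batching with dummy tensors; shapes are illustrative and the fake prediction stands in for the model output.

import torch

b = 2
cond = [torch.randn(b, 77, 768), torch.randn(b, 257, 1024)]
uncond = [torch.zeros_like(ci) for ci in cond]

if isinstance(uncond, list):
    c_in = [torch.cat([ui, ci]) for ui, ci in zip(uncond, cond)]
else:
    c_in = torch.cat([uncond, cond])

fake_eps = torch.randn(2 * b, 4, 8, 8)            # stand-in for apply_model(x_in, t_in, c_in)
e_t_uncond, e_t = fake_eps.chunk(2)
e_t = e_t_uncond + 7.5 * (e_t - e_t_uncond)       # guidance as in p_sample_ddim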
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/ddim_dualmodel.py
ADDED
|
@@ -0,0 +1,244 @@
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
| 7 |
+
|
| 8 |
+
from .ddim import DDIMSampler
|
| 9 |
+
|
| 10 |
+
class DDIMSampler_DualModel(DDIMSampler):
|
| 11 |
+
def __init__(self, model_t2i, model_v2i, schedule="linear", **kwargs):
|
| 12 |
+
self.model = model_t2i
|
| 13 |
+
self.model_t2i = model_t2i
|
| 14 |
+
self.model_v2i = model_v2i
|
| 15 |
+
self.device = self.model_t2i.device
|
| 16 |
+
self.ddpm_num_timesteps = model_t2i.num_timesteps
|
| 17 |
+
self.schedule = schedule
|
| 18 |
+
|
| 19 |
+
@torch.no_grad()
|
| 20 |
+
def sample_text(self, *args, **kwargs):
|
| 21 |
+
self.cond_type = 'prompt'
|
| 22 |
+
self.p_sample_model_type = 't2i'
|
| 23 |
+
return self.sample(*args, **kwargs)
|
| 24 |
+
|
| 25 |
+
@torch.no_grad()
|
| 26 |
+
def sample_vision(self, *args, **kwargs):
|
| 27 |
+
self.cond_type = 'vision'
|
| 28 |
+
self.p_sample_model_type = 'v2i'
|
| 29 |
+
return self.sample(*args, **kwargs)
|
| 30 |
+
|
| 31 |
+
@torch.no_grad()
|
| 32 |
+
def sample(self,
|
| 33 |
+
steps,
|
| 34 |
+
shape,
|
| 35 |
+
xt=None,
|
| 36 |
+
conditioning=None,
|
| 37 |
+
eta=0.,
|
| 38 |
+
temperature=1.,
|
| 39 |
+
noise_dropout=0.,
|
| 40 |
+
verbose=True,
|
| 41 |
+
log_every_t=100,
|
| 42 |
+
unconditional_guidance_scale=1.,
|
| 43 |
+
unconditional_conditioning=None,):
|
| 44 |
+
|
| 45 |
+
self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
|
| 46 |
+
# sampling
|
| 47 |
+
print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
|
| 48 |
+
|
| 49 |
+
samples, intermediates = self.ddim_sampling(
|
| 50 |
+
conditioning,
|
| 51 |
+
shape,
|
| 52 |
+
xt=xt,
|
| 53 |
+
ddim_use_original_steps=False,
|
| 54 |
+
noise_dropout=noise_dropout,
|
| 55 |
+
temperature=temperature,
|
| 56 |
+
log_every_t=log_every_t,
|
| 57 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 58 |
+
unconditional_conditioning=unconditional_conditioning,)
|
| 59 |
+
return samples, intermediates
|
| 60 |
+
|
| 61 |
+
@torch.no_grad()
|
| 62 |
+
def ddim_sampling(self,
|
| 63 |
+
conditioning,
|
| 64 |
+
shape,
|
| 65 |
+
xt=None,
|
| 66 |
+
ddim_use_original_steps=False,
|
| 67 |
+
timesteps=None,
|
| 68 |
+
log_every_t=100,
|
| 69 |
+
temperature=1.,
|
| 70 |
+
noise_dropout=0.,
|
| 71 |
+
unconditional_guidance_scale=1.,
|
| 72 |
+
unconditional_conditioning=None,):
|
| 73 |
+
device = self.model.betas.device
|
| 74 |
+
bs = shape[0]
|
| 75 |
+
if xt is None:
|
| 76 |
+
img = torch.randn(shape, device=device)
|
| 77 |
+
else:
|
| 78 |
+
img = xt
|
| 79 |
+
|
| 80 |
+
if timesteps is None:
|
| 81 |
+
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
| 82 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
| 83 |
+
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
| 84 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
| 85 |
+
|
| 86 |
+
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
| 87 |
+
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
| 88 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
| 89 |
+
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
| 90 |
+
|
| 91 |
+
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
| 92 |
+
|
| 93 |
+
for i, step in enumerate(iterator):
|
| 94 |
+
index = total_steps - i - 1
|
| 95 |
+
ts = torch.full((bs,), step, device=device, dtype=torch.long)
|
| 96 |
+
|
| 97 |
+
outs = self.p_sample_ddim(img, conditioning, ts, index=index, use_original_steps=ddim_use_original_steps,
|
| 98 |
+
temperature=temperature,
|
| 99 |
+
noise_dropout=noise_dropout,
|
| 100 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 101 |
+
unconditional_conditioning=unconditional_conditioning)
|
| 102 |
+
img, pred_x0 = outs
|
| 103 |
+
|
| 104 |
+
if index % log_every_t == 0 or index == total_steps - 1:
|
| 105 |
+
intermediates['x_inter'].append(img)
|
| 106 |
+
intermediates['pred_x0'].append(pred_x0)
|
| 107 |
+
|
| 108 |
+
return img, intermediates
|
| 109 |
+
|
| 110 |
+
@torch.no_grad()
|
| 111 |
+
def p_sample_ddim(self, x, conditioning, t, index, repeat_noise=False, use_original_steps=False,
|
| 112 |
+
temperature=1., noise_dropout=0.,
|
| 113 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None):
|
| 114 |
+
b, *_, device = *x.shape, x.device
|
| 115 |
+
|
| 116 |
+
if self.p_sample_model_type == 't2i':
|
| 117 |
+
apply_model = self.model_t2i.apply_model
|
| 118 |
+
elif self.p_sample_model_type == 'v2i':
|
| 119 |
+
apply_model = self.model_v2i.apply_model
|
| 120 |
+
|
| 121 |
+
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
| 122 |
+
e_t = apply_model(x, t, conditioning)
|
| 123 |
+
else:
|
| 124 |
+
x_in = torch.cat([x] * 2)
|
| 125 |
+
t_in = torch.cat([t] * 2)
|
| 126 |
+
c_in = torch.cat([unconditional_conditioning, conditioning])
|
| 127 |
+
e_t_uncond, e_t = apply_model(x_in, t_in, c_in).chunk(2)
|
| 128 |
+
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
| 129 |
+
|
| 130 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
| 131 |
+
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
| 132 |
+
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
| 133 |
+
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
| 134 |
+
# select parameters corresponding to the currently considered timestep
|
| 135 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
| 136 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
| 137 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
| 138 |
+
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
| 139 |
+
|
| 140 |
+
# current prediction for x_0
|
| 141 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
| 142 |
+
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
| 143 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
| 144 |
+
if noise_dropout > 0.:
|
| 145 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
| 146 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
| 147 |
+
return x_prev, pred_x0
|
| 148 |
+
|
| 149 |
+
@torch.no_grad()
|
| 150 |
+
def sample_mixed(self,
|
| 151 |
+
steps,
|
| 152 |
+
steps_t2i,
|
| 153 |
+
steps_v2i,
|
| 154 |
+
shape,
|
| 155 |
+
xt=None,
|
| 156 |
+
c_prompt=None,
|
| 157 |
+
c_vision=None,
|
| 158 |
+
eta=0.,
|
| 159 |
+
temperature=1.,
|
| 160 |
+
noise_dropout=0.,
|
| 161 |
+
verbose=True,
|
| 162 |
+
log_every_t=100,
|
| 163 |
+
uc_scale=1.,
|
| 164 |
+
uc_prompt=None,
|
| 165 |
+
uc_vision=None,):
|
| 166 |
+
|
| 167 |
+
print(f'DDIM mixed sampling with shape {shape}, eta {eta}')
|
| 168 |
+
print(f'steps_t2i {steps_t2i}')
|
| 169 |
+
print(f'steps_v2i {steps_v2i}')
|
| 170 |
+
|
| 171 |
+
self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
|
| 172 |
+
self.ddim_timesteps_t2i = self.ddim_timesteps[steps_t2i]
|
| 173 |
+
self.ddim_timesteps_v2i = self.ddim_timesteps[steps_v2i]
|
| 174 |
+
|
| 175 |
+
samples, intermediates = self.ddim_sampling_mixed(
|
| 176 |
+
c_prompt,
|
| 177 |
+
c_vision,
|
| 178 |
+
shape,
|
| 179 |
+
xt=xt,
|
| 180 |
+
noise_dropout=noise_dropout,
|
| 181 |
+
temperature=temperature,
|
| 182 |
+
log_every_t=log_every_t,
|
| 183 |
+
uc_scale=uc_scale,
|
| 184 |
+
uc_prompt=uc_prompt,
|
| 185 |
+
uc_vision=uc_vision, )
|
| 186 |
+
return samples, intermediates
|
| 187 |
+
|
| 188 |
+
@torch.no_grad()
|
| 189 |
+
def ddim_sampling_mixed(self,
|
| 190 |
+
c_prompt,
|
| 191 |
+
c_vision,
|
| 192 |
+
shape,
|
| 193 |
+
xt=None,
|
| 194 |
+
log_every_t=100,
|
| 195 |
+
temperature=1.,
|
| 196 |
+
noise_dropout=0.,
|
| 197 |
+
uc_scale=1.,
|
| 198 |
+
uc_prompt=None,
|
| 199 |
+
uc_vision=None, ):
|
| 200 |
+
device = self.device
|
| 201 |
+
bs = shape[0]
|
| 202 |
+
if xt is None:
|
| 203 |
+
img = torch.randn(shape, device=device)
|
| 204 |
+
else:
|
| 205 |
+
img = xt
|
| 206 |
+
|
| 207 |
+
timesteps = self.ddim_timesteps
|
| 208 |
+
intermediates = {'x_inter': [], 'pred_x0': []}
|
| 209 |
+
time_range = np.flip(timesteps)
|
| 210 |
+
total_steps = timesteps.shape[0]
|
| 211 |
+
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
| 212 |
+
|
| 213 |
+
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
| 214 |
+
|
| 215 |
+
for i, step in enumerate(iterator):
|
| 216 |
+
if step in self.ddim_timesteps_t2i:
|
| 217 |
+
self.p_sample_model_type = 't2i'
|
| 218 |
+
conditioning = c_prompt
|
| 219 |
+
unconditional_conditioning = uc_prompt
|
| 220 |
+
elif step in self.ddim_timesteps_v2i:
|
| 221 |
+
self.p_sample_model_type = 'v2i'
|
| 222 |
+
conditioning = c_vision
|
| 223 |
+
unconditional_conditioning = uc_vision
|
| 224 |
+
else:
|
| 225 |
+
raise ValueError  # should not be reached
|
| 226 |
+
|
| 227 |
+
index = total_steps - i - 1
|
| 228 |
+
ts = torch.full((bs,), step, device=device, dtype=torch.long)
|
| 229 |
+
outs = self.p_sample_ddim(
|
| 230 |
+
img, conditioning, ts,
|
| 231 |
+
index=index,
|
| 232 |
+
temperature=temperature,
|
| 233 |
+
noise_dropout=noise_dropout,
|
| 234 |
+
unconditional_guidance_scale=uc_scale,
|
| 235 |
+
unconditional_conditioning=unconditional_conditioning)
|
| 236 |
+
img, pred_x0 = outs
|
| 237 |
+
|
| 238 |
+
if index % log_every_t == 0 or index == total_steps - 1:
|
| 239 |
+
intermediates['x_inter'].append(img)
|
| 240 |
+
intermediates['pred_x0'].append(pred_x0)
|
| 241 |
+
|
| 242 |
+
return img, intermediates
|
| 243 |
+
|
| 244 |
+
|
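sample_mixed above expects steps_t2i and steps_v2i as index lists into the DDIM timestep array, so together they decide which of the two models denoises each step. How to partition the steps is left to the caller; the snippet below builds a hypothetical alternating split purely for illustration (any partition of range(steps) would do), mirroring the ddim_timesteps indexing used in sample_mixed.

import numpy as np

steps = 50
steps_t2i = [i for i in range(steps) if i % 2 == 0]   # even indices -> text-to-image model
steps_v2i = [i for i in range(steps) if i % 2 == 1]   # odd indices  -> vision-to-image model

ddim_timesteps = np.arange(1, 1000, 1000 // steps)    # illustrative schedule
ddim_timesteps_t2i = ddim_timesteps[steps_t2i]        # mirrors self.ddim_timesteps[steps_t2i]
ddim_timesteps_v2i = ddim_timesteps[steps_v2i]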
versatile_diffusion/log/sd_nodataset/99999_evalonly/sd_variation/code/lib/model_zoo/ddim_vd.py
ADDED
|
@@ -0,0 +1,293 @@
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
| 7 |
+
|
| 8 |
+
from .ddim import DDIMSampler
|
| 9 |
+
|
| 10 |
+
class DDIMSampler_VD(DDIMSampler):
|
| 11 |
+
@torch.no_grad()
|
| 12 |
+
def sample(self,
|
| 13 |
+
steps,
|
| 14 |
+
shape,
|
| 15 |
+
xt=None,
|
| 16 |
+
conditioning=None,
|
| 17 |
+
unconditional_guidance_scale=1.,
|
| 18 |
+
unconditional_conditioning=None,
|
| 19 |
+
xtype='image',
|
| 20 |
+
ctype='prompt',
|
| 21 |
+
eta=0.,
|
| 22 |
+
temperature=1.,
|
| 23 |
+
noise_dropout=0.,
|
| 24 |
+
verbose=True,
|
| 25 |
+
log_every_t=100,):
|
| 26 |
+
|
| 27 |
+
self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
|
| 28 |
+
print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
|
| 29 |
+
samples, intermediates = self.ddim_sampling(
|
| 30 |
+
shape,
|
| 31 |
+
xt=xt,
|
| 32 |
+
conditioning=conditioning,
|
| 33 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 34 |
+
unconditional_conditioning=unconditional_conditioning,
|
| 35 |
+
xtype=xtype,
|
| 36 |
+
ctype=ctype,
|
| 37 |
+
ddim_use_original_steps=False,
|
| 38 |
+
noise_dropout=noise_dropout,
|
| 39 |
+
temperature=temperature,
|
| 40 |
+
log_every_t=log_every_t,)
|
| 41 |
+
return samples, intermediates
|
| 42 |
+
|
| 43 |
+
@torch.no_grad()
|
| 44 |
+
def ddim_sampling(self,
|
| 45 |
+
shape,
|
| 46 |
+
xt=None,
|
| 47 |
+
conditioning=None,
|
| 48 |
+
unconditional_guidance_scale=1.,
|
| 49 |
+
unconditional_conditioning=None,
|
| 50 |
+
xtype='image',
|
| 51 |
+
ctype='prompt',
|
| 52 |
+
ddim_use_original_steps=False,
|
| 53 |
+
timesteps=None,
|
| 54 |
+
noise_dropout=0.,
|
| 55 |
+
temperature=1.,
|
| 56 |
+
log_every_t=100,):
|
| 57 |
+
|
| 58 |
+
device = 1
|
| 59 |
+
bs = shape[0]
|
| 60 |
+
if xt is None:
|
| 61 |
+
xt = torch.randn(shape, device=device)
|
| 62 |
+
|
| 63 |
+
if timesteps is None:
|
| 64 |
+
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
| 65 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
| 66 |
+
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
| 67 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
| 68 |
+
|
| 69 |
+
intermediates = {'pred_xt': [], 'pred_x0': []}
|
| 70 |
+
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
| 71 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
| 72 |
+
# print(f"Running DDIM Sampling with {total_steps} timesteps")
|
| 73 |
+
|
| 74 |
+
pred_xt = xt
|
| 75 |
+
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
| 76 |
+
for i, step in enumerate(iterator):
|
| 77 |
+
index = total_steps - i - 1
|
| 78 |
+
ts = torch.full((bs,), step, device=device, dtype=torch.long)
|
| 79 |
+
|
| 80 |
+
outs = self.p_sample_ddim(
|
| 81 |
+
pred_xt, conditioning, ts, index,
|
| 82 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 83 |
+
unconditional_conditioning=unconditional_conditioning,
|
| 84 |
+
xtype=xtype,
|
| 85 |
+
ctype=ctype,
|
| 86 |
+
use_original_steps=ddim_use_original_steps,
|
| 87 |
+
noise_dropout=noise_dropout,
|
| 88 |
+
temperature=temperature,)
|
| 89 |
+
pred_xt, pred_x0 = outs
|
| 90 |
+
|
| 91 |
+
if index % log_every_t == 0 or index == total_steps - 1:
|
| 92 |
+
intermediates['pred_xt'].append(pred_xt)
|
| 93 |
+
intermediates['pred_x0'].append(pred_x0)
|
| 94 |
+
|
| 95 |
+
return pred_xt, intermediates
|
| 96 |
+
|
| 97 |
+
@torch.no_grad()
|
| 98 |
+
def p_sample_ddim(self, x, conditioning, t, index,
|
| 99 |
+
unconditional_guidance_scale=1.,
|
| 100 |
+
unconditional_conditioning=None,
|
| 101 |
+
xtype='image',
|
| 102 |
+
ctype='prompt',
|
| 103 |
+
repeat_noise=False,
|
| 104 |
+
use_original_steps=False,
|
| 105 |
+
noise_dropout=0.,
|
| 106 |
+
temperature=1.,):
|
| 107 |
+
|
| 108 |
+
b, *_, device = *x.shape, x.device
|
| 109 |
+
|
| 110 |
+
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
| 111 |
+
e_t = self.model.apply_model(x, t, conditioning, xtype=xtype, ctype=ctype)
|
| 112 |
+
else:
|
| 113 |
+
x_in = torch.cat([x] * 2)
|
| 114 |
+
t_in = torch.cat([t] * 2)
|
| 115 |
+
c_in = torch.cat([unconditional_conditioning, conditioning])
|
| 116 |
+
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, xtype=xtype, ctype=ctype).chunk(2)
|
| 117 |
+
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
| 118 |
+
|
| 119 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
| 120 |
+
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
| 121 |
+
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
| 122 |
+
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
| 123 |
+
# select parameters corresponding to the currently considered timestep
|
| 124 |
+
|
| 125 |
+
if xtype == 'image':
|
| 126 |
+
extended_shape = (b, 1, 1, 1)
|
| 127 |
+
elif xtype == 'text':
|
| 128 |
+
extended_shape = (b, 1)
|
| 129 |
+
|
| 130 |
+
a_t = torch.full(extended_shape, alphas[index], device=device)
|
| 131 |
+
a_prev = torch.full(extended_shape, alphas_prev[index], device=device)
|
| 132 |
+
sigma_t = torch.full(extended_shape, sigmas[index], device=device)
|
| 133 |
+
sqrt_one_minus_at = torch.full(extended_shape, sqrt_one_minus_alphas[index],device=device)
|
| 134 |
+
|
| 135 |
+
# current prediction for x_0
|
| 136 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
| 137 |
+
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
| 138 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
| 139 |
+
if noise_dropout > 0.:
|
| 140 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
| 141 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
| 142 |
+
return x_prev, pred_x0
|
| 143 |
+
|
| 144 |
+
class DDIMSampler_VD_DualContext(DDIMSampler_VD):
    @torch.no_grad()
    def sample_dc(self,
                  steps,
                  shape,
                  xt=None,
                  first_conditioning=None,
                  second_conditioning=None,
                  unconditional_guidance_scale=1.,
                  xtype='image',
                  first_ctype='prompt',
                  second_ctype='prompt',
                  eta=0.,
                  temperature=1.,
                  mixed_ratio=0.5,
                  noise_dropout=0.,
                  verbose=True,
                  log_every_t=100,):

        self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
        print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
        samples, intermediates = self.ddim_sampling_dc(
            shape,
            xt=xt,
            first_conditioning=first_conditioning,
            second_conditioning=second_conditioning,
            unconditional_guidance_scale=unconditional_guidance_scale,
            xtype=xtype,
            first_ctype=first_ctype,
            second_ctype=second_ctype,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            log_every_t=log_every_t,
            mixed_ratio=mixed_ratio, )
        return samples, intermediates

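    # Minimal usage sketch (hypothetical variable names and latent shape): each
    # conditioning list is expected to hold [unconditional, conditional] tensors,
    # since p_sample_ddim_dc concatenates the pair and splits the model output
    # with .chunk(2).
    #   sampler = DDIMSampler_VD_DualContext(vd_model)
    #   z, _ = sampler.sample_dc(steps=50, shape=(1, 4, 64, 64),
    #                            first_conditioning=[uc_vision, c_vision],
    #                            second_conditioning=[uc_text, c_text],
    #                            unconditional_guidance_scale=7.5,
    #                            xtype='image', first_ctype='vision',
    #                            second_ctype='prompt', mixed_ratio=0.5)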
    @torch.no_grad()
    def ddim_sampling_dc(self,
                         shape,
                         xt=None,
                         first_conditioning=None,
                         second_conditioning=None,
                         unconditional_guidance_scale=1.,
                         xtype='image',
                         first_ctype='prompt',
                         second_ctype='prompt',
                         ddim_use_original_steps=False,
                         timesteps=None,
                         noise_dropout=0.,
                         temperature=1.,
                         mixed_ratio=0.5,
                         log_every_t=100,):

        device = self.model.device
        bs = shape[0]
        if xt is None:
            xt = torch.randn(shape, device=device)

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

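        # The loop below walks the DDIM timesteps in reverse; `index` counts down from
        # total_steps - 1 and selects the matching entries of the precomputed
        # alpha/sigma schedules inside p_sample_ddim_dc.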
        intermediates = {'pred_xt': [], 'pred_x0': []}
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        # print(f"Running DDIM Sampling with {total_steps} timesteps")

        pred_xt = xt
        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((bs,), step, device=device, dtype=torch.long)

            outs = self.p_sample_ddim_dc(
                pred_xt,
                first_conditioning,
                second_conditioning,
                ts, index,
                unconditional_guidance_scale=unconditional_guidance_scale,
                xtype=xtype,
                first_ctype=first_ctype,
                second_ctype=second_ctype,
                use_original_steps=ddim_use_original_steps,
                noise_dropout=noise_dropout,
                temperature=temperature,
                mixed_ratio=mixed_ratio,)
            pred_xt, pred_x0 = outs

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['pred_xt'].append(pred_xt)
                intermediates['pred_x0'].append(pred_x0)

        return pred_xt, intermediates

    @torch.no_grad()
    def p_sample_ddim_dc(self, x,
                         first_conditioning,
                         second_conditioning,
                         t, index,
                         unconditional_guidance_scale=1.,
                         xtype='image',
                         first_ctype='prompt',
                         second_ctype='prompt',
                         repeat_noise=False,
                         use_original_steps=False,
                         noise_dropout=0.,
                         temperature=1.,
                         mixed_ratio=0.5,):

        b, *_, device = *x.shape, x.device

        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        first_c = torch.cat(first_conditioning)
        second_c = torch.cat(second_conditioning)

        e_t_uncond, e_t = self.model.apply_model_dc(
            x_in, t_in, first_c, second_c, xtype=xtype, first_ctype=first_ctype, second_ctype=second_ctype, mixed_ratio=mixed_ratio).chunk(2)
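        # Dual-context classifier-free guidance: each conditioning list carries an
        # [unconditional, conditional] pair, so the doubled batch yields the
        # unconditional and conditional noise estimates in one forward pass;
        # mixed_ratio is presumably the weight apply_model_dc gives the first
        # context relative to the second.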

        # e_t_uncond, e_t = self.model.apply_model(x_in, t_in, first_c, xtype='image', ctype='vision').chunk(2)

        e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep

        if xtype == 'image':
            extended_shape = (b, 1, 1, 1)
        elif xtype == 'text':
            extended_shape = (b, 1)

        a_t = torch.full(extended_shape, alphas[index], device=device)
        a_prev = torch.full(extended_shape, alphas_prev[index], device=device)
        sigma_t = torch.full(extended_shape, sigmas[index], device=device)
        sqrt_one_minus_at = torch.full(extended_shape, sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0