jiaxi2002 committed on
Commit feb33a0 · verified · 1 Parent(s): 4130807

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +3 -0
  2. .github/workflows/logo.gif +3 -0
  3. .github/workflows/publish.yaml +29 -0
  4. .gitignore +8 -0
  5. .vscode/launch.json +36 -0
  6. FLUX.1-Kontext-dev.py +68 -0
  7. LICENSE +201 -0
  8. README.md +522 -0
  9. README_zh.md +538 -0
  10. _temp_.py +184 -0
  11. apps/gradio/DiffSynth_Studio.py +252 -0
  12. apps/gradio/entity_level_control.py +390 -0
  13. apps/gradio/qwen_image_eligen.py +382 -0
  14. apps/streamlit/DiffSynth_Studio.py +15 -0
  15. apps/streamlit/pages/1_Image_Creator.py +362 -0
  16. apps/streamlit/pages/2_Video_Creator.py +197 -0
  17. deal1.py +82 -0
  18. deal2.py +127 -0
  19. diffsynth.egg-info/PKG-INFO +31 -0
  20. diffsynth.egg-info/SOURCES.txt +247 -0
  21. diffsynth.egg-info/dependency_links.txt +1 -0
  22. diffsynth.egg-info/requires.txt +15 -0
  23. diffsynth.egg-info/top_level.txt +1 -0
  24. diffsynth/__init__.py +6 -0
  25. diffsynth/configs/__init__.py +0 -0
  26. diffsynth/configs/model_config.py +859 -0
  27. diffsynth/controlnets/__init__.py +2 -0
  28. diffsynth/controlnets/controlnet_unit.py +91 -0
  29. diffsynth/controlnets/processors.py +62 -0
  30. diffsynth/data/__init__.py +1 -0
  31. diffsynth/data/simple_text_image.py +41 -0
  32. diffsynth/data/video.py +217 -0
  33. diffsynth/distributed/__init__.py +0 -0
  34. diffsynth/distributed/xdit_context_parallel.py +131 -0
  35. diffsynth/extensions/ESRGAN/__init__.py +137 -0
  36. diffsynth/extensions/FastBlend/__init__.py +63 -0
  37. diffsynth/extensions/FastBlend/api.py +397 -0
  38. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  39. diffsynth/extensions/FastBlend/data.py +146 -0
  40. diffsynth/extensions/FastBlend/patch_match.py +299 -0
  41. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  42. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  43. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  44. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  45. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  46. diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py +1 -0
  47. diffsynth/extensions/ImageQualityMetric/BLIP/blip.py +77 -0
  48. diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py +44 -0
  49. diffsynth/extensions/ImageQualityMetric/BLIP/med.py +947 -0
  50. diffsynth/extensions/ImageQualityMetric/BLIP/vit.py +301 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ .github/workflows/logo.gif filter=lfs diff=lfs merge=lfs -text
+ diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
.github/workflows/logo.gif ADDED

Git LFS Details

  • SHA256: 36a7627b7f0f0a508ec64aba72e5d95d38dfe7958bd8cf42d2a63f6ac2641529
  • Pointer size: 131 Bytes
  • Size of remote file: 149 kB
.github/workflows/publish.yaml ADDED
@@ -0,0 +1,29 @@
+ name: release
+
+ on:
+   push:
+     tags:
+       - 'v**'
+
+ concurrency:
+   group: ${{ github.workflow }}-${{ github.ref }}-publish
+   cancel-in-progress: true
+
+ jobs:
+   build-n-publish:
+     runs-on: ubuntu-20.04
+     #if: startsWith(github.event.ref, 'refs/tags')
+     steps:
+     - uses: actions/checkout@v2
+     - name: Set up Python 3.10
+       uses: actions/setup-python@v2
+       with:
+         python-version: '3.10'
+     - name: Install wheel
+       run: pip install wheel==0.44.0 && pip install -r requirements.txt
+     - name: Build DiffSynth
+       run: python setup.py sdist bdist_wheel
+     - name: Publish package to PyPI
+       run: |
+         pip install twine
+         twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}
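The workflow above builds sdist and wheel artifacts with `setup.py` and uploads everything in `dist/` to PyPI whenever a `v*` tag is pushed. The sketch below is not part of this commit; it mirrors those steps locally so the built artifact version can be checked against the intended tag before pushing it (the version string and helper name are hypothetical):

```python
# Hypothetical pre-release check mirroring the build/upload steps of the workflow above.
import pathlib
import subprocess

def check_dist(expected_version: str) -> None:
    # Same build command the workflow runs.
    subprocess.run(["python", "setup.py", "sdist", "bdist_wheel"], check=True)
    artifacts = sorted(pathlib.Path("dist").glob("*"))
    if not artifacts:
        raise SystemExit("dist/ is empty; the build produced no artifacts")
    for artifact in artifacts:
        if expected_version not in artifact.name:
            raise SystemExit(f"{artifact.name} does not match version {expected_version}")
    print(f"dist/ artifacts match {expected_version}; safe to push tag v{expected_version}")

if __name__ == "__main__":
    check_dist("1.2.0")  # hypothetical version; the corresponding tag would be v1.2.0
```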
.gitignore ADDED
@@ -0,0 +1,8 @@
+ *.png
+ *pycache*
+ *.safetensors
+ *.ckpt
+ models/
+ *.log
+ *.html
+ *.jpg
.vscode/launch.json ADDED
@@ -0,0 +1,36 @@
+ {
+     "version": "0.2.0",
+     "configurations": [
+         {
+             "name": "KontextT2I 1201 debug",
+             "type": "python",
+             "request": "launch",
+             "program": "examples/flux/model_training/train.py", // key setting: path of the script to launch
+             "args": [
+                 "--dataset_base_path", "/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/multi_frame",
+                 "--dataset_metadata_path", "/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/multi_frame/pairs.txt",
+                 "--data_file_keys", "image,prompt",
+                 "--max_pixels", "1048576",
+                 "--dataset_repeat", "400",
+                 "--model_id_with_origin_paths", "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-Kontext-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-Kontext-dev:text_encoder_2/,black-forest-labs/FLUX.1-Kontext-dev:ae.safetensors",
+                 "--learning_rate", "1e-5",
+                 "--num_epochs", "5",
+                 "--remove_prefix_in_ckpt", "pipe.dit.",
+                 "--output_path", "./models/train/FLUX.1_lora_1127_mbti",
+                 "--lora_base_model", "dit",
+                 "--lora_target_modules", "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
+                 "--lora_rank", "32",
+                 "--align_to_opensource_format",
+                 "--use_gradient_checkpointing",
+                 "--default_caption", "Convert this real photo into a mbti style."
+             ],
+             "console": "integratedTerminal", // send output to the VS Code integrated terminal (easier to inspect)
+             "justMyCode": false, // allow debugging third-party libraries (e.g. accelerate, transformers)
+             "cwd": "${workspaceFolder}", // set the working directory to the project root so relative paths resolve correctly
+             "env": {
+                 "PYTHONUNBUFFERED": "1", // disable output buffering so logs appear in real time
+                 "CUDA_VISIBLE_DEVICES": "7"
+             }
+         }
+     ]
+ }
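The debug configuration above launches `examples/flux/model_training/train.py` directly, training a rank-32 LoRA over the listed `dit` target modules and writing per-epoch checkpoints to `./models/train/FLUX.1_lora_1127_mbti` with the `pipe.dit.` prefix stripped. As a minimal sketch (not part of this commit), such a checkpoint can be inspected as a plain safetensors state dict, which is the form `pipe.load_lora` consumes in `FLUX.1-Kontext-dev.py` below; the epoch file name is hypothetical:

```python
# Sketch: inspect a LoRA checkpoint written by the training configuration above.
# Assumes the checkpoint is a plain safetensors state dict.
from safetensors.torch import load_file

state_dict = load_file("models/train/FLUX.1_lora_1127_mbti/epoch-4.safetensors")  # hypothetical epoch file
total = sum(t.numel() for t in state_dict.values())
print(f"{len(state_dict)} tensors, {total / 1e6:.1f}M LoRA parameters")

# Count tensors per target module to confirm --lora_target_modules took effect.
targets = "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp".split(",")
for target in targets:
    print(target, sum(target in key for key in state_dict))
```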
FLUX.1-Kontext-dev.py ADDED
@@ -0,0 +1,68 @@
+ import torch
+ from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
+ from PIL import Image
+ import os
+ import json
+
+
+
+ # for i in range(2):
+ #     pipe.load_lora(pipe.dit, f"models/train/FLUX.1_lora_1126/epoch-{i}.safetensors", alpha=1)
+ #     step = 25
+ #     input_path = "dataset/multi_frame"
+ #     base_path = f"validate_result/multi_frame{step}"
+ #     save_path = f"{base_path}/epoch{i}"
+ #     save_path_GT = f"{base_path}/GT"
+ #     os.makedirs(save_path, exist_ok=True)
+ #     os.makedirs(save_path_GT, exist_ok=True)
+ #     for img in os.listdir(input_path):
+ #         image = Image.open(os.path.join(input_path, img))
+ #         image.save(os.path.join(save_path_GT, img))
+ #     prompt = "Convert this image into a line art style: retain the original scenes and characters unchanged, present it as a black-and-white sketch effect, and make it suitable for storyboard design. Requirements: use bold and powerful lines, highlight structures and textures with concise strokes, adopt a style close to comic sketching, roughly outline the scenes and character movements with simple lines, prohibit the depiction of details, and represent the characters' facial features with the simplest lines."
+ #     # prompt = "Convert this image into a mbti style"
+ #     for fig in os.listdir(input_path):
+ #         if not fig.endswith(".png"):
+ #             continue
+ #         image = pipe(
+ #             prompt=prompt,
+ #             kontext_images=Image.open(os.path.join(input_path, fig)).resize((768, 768)),
+ #             height=768, width=768,
+ #             seed=0,
+ #             num_inference_steps=step
+ #         )
+ #         image.save(os.path.join(save_path, fig))
+
+ for i in range(2):
+     pipe = FluxImagePipeline.from_pretrained(
+         torch_dtype=torch.bfloat16,
+         device="cuda",
+         model_configs=[
+             ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
+             ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="text_encoder/model.safetensors"),
+             ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="text_encoder_2/"),
+             ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="ae.safetensors"),
+         ],
+     )
+     pipe.load_lora(pipe.dit, f"models/train/FLUX.1_lora_1126/epoch-{i}.safetensors", alpha=1)
+     step = 25
+     base_path = f"/fi-lib/workspace/sjx/DiffSynth-Studio/validate_result/t2i_1201{step}"
+     save_path = f"{base_path}/epoch{i}"
+     os.makedirs(save_path, exist_ok=True)
+     with open("nano_comprehension_1201.txt", "r") as f:
+         prompts = f.readlines()
+     for prompt in prompts:
+         prompt = prompt.strip()
+         if prompt == "":
+             continue
+         prompt_dict = json.loads(prompt)
+         fig = f"{prompt_dict['Image_Name']}.png"
+         del prompt_dict["Image_Name"]
+         prompt = json.dumps(prompt_dict, ensure_ascii=False)
+         image = pipe(
+             prompt=prompt,
+             height=768, width=768,
+             seed=0,
+             num_inference_steps=step
+         )
+         image.save(os.path.join(save_path, fig))
+
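The validation loop above expects `nano_comprehension_1201.txt` to contain one JSON object per line, with an `Image_Name` field that determines the output filename and the remaining fields serialized back into the prompt. A small sketch (not part of this commit) that checks a prompt file against that format before running the pipeline:

```python
# Sketch: validate the JSON-lines prompt file consumed by the script above.
# Each non-empty line must parse as JSON and contain an "Image_Name" field.
import json

def check_prompt_file(path: str) -> None:
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)  # raises ValueError if the line is not valid JSON
            if "Image_Name" not in record:
                raise ValueError(f"line {line_no}: missing 'Image_Name' field")
    print(f"{path} looks valid")

if __name__ == "__main__":
    check_prompt_file("nano_comprehension_1201.txt")
```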
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2023] [Zhongjie Duan]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,522 @@
+ # DiffSynth-Studio
+
+ <a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
+
+ [![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/)
+ [![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
+ [![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
+ [![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
+ [![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
+
+ [Switch to Chinese](./README_zh.md)
+
+ ## Introduction
+
+ Welcome to the magic world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope](https://www.modelscope.cn/) team. We aim to foster technical innovation through framework development, bring together the power of the open-source community, and explore the limits of generative models!
+
+ DiffSynth currently includes two open-source projects:
+ * [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): focused on aggressive technical exploration, aimed at academia, and providing support for more cutting-edge model capabilities.
+ * [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): focused on stable model deployment, aimed at industry, and offering higher computing performance and more stable features.
+
+ [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) and [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) are the core projects behind the ModelScope [AIGC zone](https://modelscope.cn/aigc/home), offering powerful AI content generation capabilities. Come and try our carefully designed features and start your AI creation journey!
+
+ ## Installation
+
+ Install from source (recommended):
+
+ ```
+ git clone https://github.com/modelscope/DiffSynth-Studio.git
+ cd DiffSynth-Studio
+ pip install -e .
+ ```
+
+ <details>
+ <summary>Other installation methods</summary>
+
+ Install from PyPI (version updates may be delayed; for the latest features, install from source):
+
+ ```
+ pip install diffsynth
+ ```
+
+ If you encounter problems during installation, they might be caused by upstream dependencies. Please check the documentation of these packages:
+
+ * [torch](https://pytorch.org/get-started/locally/)
+ * [sentencepiece](https://github.com/google/sentencepiece)
+ * [cmake](https://cmake.org)
+ * [cupy](https://docs.cupy.dev/en/stable/install.html)
+
+ </details>
+
+ ## Basic Framework
+
+ DiffSynth-Studio redesigns the inference and training pipelines for mainstream Diffusion models (including FLUX, Wan, etc.), enabling efficient memory management and flexible model training.
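Memory management in these pipelines is driven by two knobs that appear in the Quick Start examples below: `ModelConfig(..., offload_device="cpu")`, which parks weights on the CPU, and `pipe.enable_vram_management()`, which pages them onto the GPU on demand; the "Low VRAM Inference" columns in the model tables link to per-model scripts. A minimal sketch of the pattern applied to the FLUX pipeline (an illustration only; it assumes `FluxImagePipeline` accepts the same options shown for the Wan pipeline further down):

```python
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig

# Sketch of the low-VRAM pattern: keep weights on the CPU and let the pipeline
# move them to the GPU on demand. Assumes FluxImagePipeline supports the same
# offload_device / enable_vram_management options used in the Wan Quick Start below.
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()  # enable on-demand offloading during inference
image = pipe(prompt="a cat", seed=0)
image.save("image.jpg")
```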
+
+ ### Qwen-Image Series (🔥New Model)
+
+ Details: [./examples/qwen_image/](./examples/qwen_image/)
+
+ ![Image](https://github.com/user-attachments/assets/738078d8-8749-4a53-a046-571861541924)
+
+ <details>
+
+ <summary>Quick Start</summary>
+
+ ```python
+ from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+ from PIL import Image
+ import torch
+
+ pipe = QwenImagePipeline.from_pretrained(
+     torch_dtype=torch.bfloat16,
+     device="cuda",
+     model_configs=[
+         ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
+         ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
+         ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+     ],
+     tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
+ )
+ prompt = "A detailed portrait of a girl underwater, wearing a blue flowing dress, hair gently floating, clear light and shadow, surrounded by bubbles, calm expression, fine details, dreamy and beautiful."
+ image = pipe(
+     prompt, seed=0, num_inference_steps=40,
+     # edit_image=Image.open("xxx.jpg").resize((1328, 1328))  # For Qwen-Image-Edit
+ )
+ image.save("image.jpg")
+ ```
+
+ </details>
+
+ <details>
+
+ <summary>Model Overview</summary>
+
+ |Model ID|Inference|Low VRAM Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
95
+ |-|-|-|-|-|-|-|
96
+ |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
97
+ |[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
98
+ |[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
99
+ |[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
100
+ |[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
101
+ |[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
102
+ |[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
103
+ |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
104
+ |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
105
+ |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
106
+ |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
107
+ |[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
108
+ |[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
+
+ </details>
+
+ ### FLUX Series
+
+ Detail page: [./examples/flux/](./examples/flux/)
+
+ ![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)
+
+ <details>
+
+ <summary>Quick Start</summary>
+
+ ```python
+ import torch
+ from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
+
+ pipe = FluxImagePipeline.from_pretrained(
+     torch_dtype=torch.bfloat16,
+     device="cuda",
+     model_configs=[
+         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
+         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
+         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
+         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
+     ],
+ )
+
+ image = pipe(prompt="a cat", seed=0)
+ image.save("image.jpg")
+ ```
+
+ </details>
+
143
+ <details>
144
+
145
+ <summary>Model Overview</summary>
146
+
147
+ | Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
148
+ |-|-|-|-|-|-|-|-|
149
+ |[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
150
+ |[FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](./examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
151
+ |[FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
152
+ |[FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
153
+ |[FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
154
+ |[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
155
+ |[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
156
+ |[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
157
+ |[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
158
+ |[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
159
+ |[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
160
+ |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)|
161
+ |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
162
+ |[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](./examples/flux/model_training/lora/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen.py)|
163
+
164
+ </details>
165
+
166
+
167
+
168
+ ### Wan Series
169
+
170
+ Detail page: [./examples/wanvideo/](./examples/wanvideo/)
171
+
172
+ https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
173
+
174
+ <details>
175
+
176
+ <summary>Quick Start</summary>
177
+
178
+ ```python
179
+ import torch
180
+ from diffsynth import save_video
181
+ from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
182
+
183
+ pipe = WanVideoPipeline.from_pretrained(
184
+ torch_dtype=torch.bfloat16,
185
+ device="cuda",
186
+ model_configs=[
187
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
188
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
189
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
190
+ ],
191
+ )
192
+ pipe.enable_vram_management()
193
+
194
+ video = pipe(
195
+ prompt="A documentary photography style scene: a lively puppy rapidly running on green grass. The puppy has brown-yellow fur, upright ears, and looks focused and joyful. Sunlight shines on its body, making the fur appear soft and shiny. The background is an open field with occasional wildflowers, and faint blue sky and clouds in the distance. Strong sense of perspective captures the motion of the puppy and the vitality of the surrounding grass. Mid-shot side-moving view.",
196
+ negative_prompt="Bright colors, overexposed, static, blurry details, subtitles, style, artwork, image, still, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, deformed, extra fingers, poorly drawn hands, poorly drawn face, malformed limbs, fused fingers, still frame, messy background, three legs, crowded background people, walking backwards",
197
+ seed=0, tiled=True,
198
+ )
199
+ save_video(video, "video1.mp4", fps=15, quality=5)
200
+ ```
201
+
202
+ </details>
203
+
204
+ <details>
205
+
206
+ <summary>Model Overview</summary>
207
+
208
+ | Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
209
+ |-|-|-|-|-|-|-|
210
+ |[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
211
+ |[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
212
+ |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
213
+ |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
214
+ |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
215
+ |[Wan-AI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
216
+ |[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
217
+ |[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
218
+ |[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
219
+ |[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
220
+ |[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
221
+ |[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
222
+ |[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
223
+ |[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
224
+ |[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
225
+ |[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
226
+ |[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
227
+ |[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
228
+ |[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
229
+ |[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
230
+ |[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
231
+ |[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
232
+ |[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
233
+ |[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
234
+ |[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
235
+ |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
236
+ |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
237
+ |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
238
+ |[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./examples/wanvideo/model_inference/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
239
+ |[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./examples/wanvideo/model_inference/LongCat-Video.py)|[code](./examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](./examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
240
+ |[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](./examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
241
+
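+ The extra inputs listed in the table are supplied alongside the prompt at inference time. Below is a minimal sketch, assuming the quick-start checkpoint (Wan-AI/Wan2.1-T2V-1.3B) and assuming that each extra input is passed as a keyword argument of the pipeline call; swap in the `model_configs` of the checkpoint you actually use, and refer to the linked inference scripts for the exact arguments.
+
+ ```python
+ import torch
+ from PIL import Image
+ from diffsynth import save_video
+ from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+ # Load the text-to-video quick-start checkpoint; other rows of the table use their own model_configs.
+ pipe = WanVideoPipeline.from_pretrained(
+     torch_dtype=torch.bfloat16,
+     device="cuda",
+     model_configs=[
+         ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+         ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+         ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+     ],
+ )
+ pipe.enable_vram_management()
+
+ video = pipe(
+     prompt="documentary photography style, a lively puppy running on a green lawn, sunny day, shallow depth of field",
+     negative_prompt="static, blurry details, low quality, deformed limbs, messy background",
+     seed=0, tiled=True,
+     # Assumed pattern for the extra inputs above, e.g. for image- or control-conditioned checkpoints:
+     # input_image=Image.open("first_frame.jpg"),
+     # control_video=..., reference_image=..., vace_control_video=..., vace_reference_image=...,
+ )
+ save_video(video, "video.mp4", fps=15, quality=5)
+ ```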
242
+ </details>
243
+
244
+ ### More Models
245
+
246
+
247
+
248
+ <details>
249
+ <summary>Image Generation Models</summary>
250
+
251
+ Detail page: [./examples/image_synthesis/](./examples/image_synthesis/)
252
+
253
+ |FLUX|Stable Diffusion 3|
254
+ |-|-|
255
+ |![image_1024_cfg](https://github.com/user-attachments/assets/984561e9-553d-4952-9443-79ce144f379f)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|
256
+
257
+ |Kolors|Hunyuan-DiT|
258
+ |-|-|
259
+ |![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
260
+
261
+ |Stable Diffusion|Stable Diffusion XL|
262
+ |-|-|
263
+ |![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|
264
+
265
+ </details>
266
+
267
+
268
+
269
+ <details>
270
+ <summary>Video Generation Models</summary>
271
+
272
+ - HunyuanVideo: [./examples/HunyuanVideo/](./examples/HunyuanVideo/)
273
+
274
+ https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9
275
+
276
+ - StepVideo: [./examples/stepvideo/](./examples/stepvideo/)
277
+
278
+ https://github.com/user-attachments/assets/5954fdaa-a3cf-45a3-bd35-886e3cc4581b
279
+
280
+ - CogVideoX: [./examples/CogVideoX/](./examples/CogVideoX/)
281
+
282
+ https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
283
+
284
+ </details>
285
+
286
+
287
+
288
+ <details>
289
+ <summary>Image Quality Assessment Models</summary>
290
+
291
+ We have integrated a series of image quality assessment models. They can be used for tasks such as evaluating image generation models and scoring samples for alignment training; a usage sketch follows the list below.
292
+
293
+ Detail page: [./examples/image_quality_metric/](./examples/image_quality_metric/)
294
+
295
+ * [ImageReward](https://github.com/THUDM/ImageReward)
296
+ * [Aesthetic](https://github.com/christophschuhmann/improved-aesthetic-predictor)
297
+ * [PickScore](https://github.com/yuvalkirstain/pickscore)
298
+ * [CLIP](https://github.com/openai/CLIP)
299
+ * [HPSv2](https://github.com/tgxs002/HPSv2)
300
+ * [HPSv2.1](https://github.com/tgxs002/HPSv2)
301
+ * [MPS](https://github.com/Kwai-Kolors/MPS)
302
+
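+ Below is a hypothetical usage sketch (the scorer interface is illustrative only, not the actual DiffSynth-Studio API; see the detail page above for the real entry points). Ranking candidate generations by a preference or aesthetic score is the common pattern behind both model evaluation and reward-based alignment training.
+
+ ```python
+ from typing import Callable, List
+ from PIL import Image
+
+ # score_fn stands in for any metric listed above (ImageReward, PickScore, HPSv2, ...),
+ # i.e. a callable mapping (image, prompt) -> scalar score, where higher is better.
+ def pick_best(images: List[Image.Image], prompt: str,
+               score_fn: Callable[[Image.Image, str], float]) -> Image.Image:
+     scores = [score_fn(image, prompt) for image in images]
+     best_index = max(range(len(images)), key=lambda i: scores[i])
+     return images[best_index]
+ ```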
303
+ </details>
304
+
305
+
306
+
307
+ ## Innovative Achievements
308
+
309
+ DiffSynth-Studio is not just an engineering-oriented model framework, but also a platform for incubating innovative research results.
310
+
311
+ <details>
312
+ <summary>Nexus-Gen: Unified Architecture for Image Understanding, Generation, and Editing</summary>
313
+
314
+ - Detail page: https://github.com/modelscope/Nexus-Gen
315
+ - Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
316
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
317
+ - Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
318
+ - Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
319
+
320
+ ![](https://github.com/modelscope/Nexus-Gen/raw/main/assets/illustrations/gen_edit.jpg)
321
+
322
+ </details>
323
+
324
+ <details>
325
+ <summary>ArtAug: Aesthetic Enhancement for Image Generation Models</summary>
326
+
327
+ - Detail page: [./examples/ArtAug/](./examples/ArtAug/)
328
+ - Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
329
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
330
+ - Online Demo: [ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
331
+
332
+ |FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
333
+ |-|-|
334
+ |![image_1_base](https://github.com/user-attachments/assets/e1d5c505-b423-45fe-be01-25c2758f5417)|![image_1_enhance](https://github.com/user-attachments/assets/335908e3-d0bd-41c2-9d99-d10528a2d719)|
335
+
336
+ </details>
337
+
338
+ <details>
339
+ <summary>EliGen: Precise Image Region Control</summary>
340
+
341
+ - Detail page: [./examples/EntityControl/](./examples/EntityControl/)
342
+ - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
343
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
344
+ - Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
345
+ - Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
346
+
347
+ |Entity Control Mask|Generated Image|
348
+ |-|-|
349
+ |![eligen_example_2_mask_0](https://github.com/user-attachments/assets/1c6d9445-5022-4d91-ad2e-dc05321883d1)|![eligen_example_2_0](https://github.com/user-attachments/assets/86739945-cb07-4a49-b3b3-3bb65c90d14f)|
350
+
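+ A minimal data-preparation sketch (the mask format and sizes here are assumptions; see the linked example code and model pages for the exact expected inputs). Entity-level control takes one text prompt plus one mask per entity, which are then passed to the supported pipelines through the `eligen_entity_prompts` / `eligen_entity_masks` arguments listed in the model overview tables, together with the EliGen LoRA weights.
+
+ ```python
+ from PIL import Image, ImageDraw
+
+ def rect_mask(width: int, height: int, box: tuple) -> Image.Image:
+     """Build a white-on-black rectangular entity mask; box = (x0, y0, x1, y1)."""
+     mask = Image.new("L", (width, height), 0)
+     ImageDraw.Draw(mask).rectangle(box, fill=255)
+     return mask
+
+ # Two entities, each with its own prompt and spatial mask (hypothetical layout).
+ entity_prompts = ["a red vintage car", "a golden retriever"]
+ entity_masks = [
+     rect_mask(1024, 1024, (50, 400, 600, 900)),
+     rect_mask(1024, 1024, (650, 500, 1000, 950)),
+ ]
+ ```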
351
+ </details>
352
+
353
+ <details>
354
+ <summary>ExVideo: Extended Training for Video Generation Models</summary>
355
+
356
+ - Project Page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
357
+ - Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
358
+ - Code Example: [./examples/ExVideo/](./examples/ExVideo/)
359
+ - Model: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
360
+
361
+ https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
362
+
363
+ </details>
364
+
365
+ <details>
366
+ <summary>Diffutoon: High-Resolution Anime-Style Video Rendering</summary>
367
+
368
+ - Project Page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
369
+ - Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
370
+ - Code Example: [./examples/Diffutoon/](./examples/Diffutoon/)
371
+
372
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
373
+
374
+ </details>
375
+
376
+ <details>
377
+ <summary>DiffSynth: The Initial Version of This Project</summary>
378
+
379
+ - Project Page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
380
+ - Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
381
+ - Code Example: [./examples/diffsynth/](./examples/diffsynth/)
382
+
383
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
384
+
385
+ </details>
386
+
387
+
388
+
389
+ ## Update History
390
+
391
+ - **November 4, 2025**: We support [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained on Wan 2.1 and enables motion generation conditioned on reference videos.
392
+
393
+ - **October 30, 2025**: We support [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which enables text-to-video, image-to-video, and video continuation capabilities. This model adopts Wan's framework for both inference and training in this project.
394
+
395
+ - **October 27, 2025**: We support [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, further expanding Wan's ecosystem.
396
+
397
+ - **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) is released! This model is jointly developed and open-sourced by us and the Taobao Design Team. The model is built upon Qwen-Image, specifically designed for e-commerce poster scenarios, and supports precise partition layout control. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).
398
+
399
+ - **September 9, 2025**: Our training framework now supports multiple training modes and has been adapted for Qwen-Image. In addition to the standard SFT training mode, Direct Distill is now also supported; please refer to [our example code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support comprehensive model training capabilities.
400
+
401
+ - **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
402
+
403
+ - **August 21, 2025**: [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) is released! Compared to the V1 version, the training dataset has been updated to the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), enabling generated images to better align with the inherent image distribution and style of Qwen-Image. Please refer to [our sample code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).
404
+
405
+ - **August 21, 2025**: We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structure control LoRA model. Following the "In Context" paradigm, it supports various types of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).
406
+
407
+ - **August 20, 2025** We open-sourced [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix), which improves the editing performance of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our example code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py).
408
+
409
+ - **August 19, 2025** 🔥 Qwen-Image-Edit is now open source. Welcome the new member to the image editing model family!
410
+
411
+ - **August 18, 2025** We trained and open-sourced the Inpaint ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).
412
+
413
+ - **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset). This is an image dataset generated using the Qwen-Image model, with a total of 160,000 `1024 x 1024` images. It includes general, English text rendering, and Chinese text rendering subsets. We provide caption, entity, and control-image annotations for each image. Developers can use this dataset to train models such as ControlNet and EliGen for the Qwen-Image model. We aim to promote technological development through open-source contributions!
414
+
415
+ - **August 13, 2025** We trained and open-sourced the ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py).
416
+
417
+ - **August 12, 2025** We trained and open-sourced the ControlNet model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny), which adopts a lightweight architectural design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py).
418
+
419
+ - **August 11, 2025** We released another distilled acceleration model for Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA). It uses the same training process as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), but the model structure is changed to LoRA. This makes it work better with other open-source models.
420
+
421
+ - **August 7, 2025** We open-sourced the entity control LoRA of Qwen-Image, [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen). Qwen-Image-EliGen is able to achieve entity-level controlled text-to-image generation. See the [paper](https://arxiv.org/abs/2501.01097) for technical details. Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).
422
+
423
+ - **August 5, 2025** We open-sourced the distilled acceleration model of Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), achieving approximately 5x speedup.
424
+
425
+ - **August 4, 2025** 🔥 Qwen-Image is now open source. Welcome the new member to the image generation model family!
426
+
427
+ - **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev), a model focused on aesthetic photography, is comprehensively supported, including low-GPU-memory layer-by-layer offload, LoRA training, and full training. See [./examples/flux/](./examples/flux/).
428
+
429
+ - **July 28, 2025** With the open-sourcing of Wan 2.2, we immediately provided comprehensive support, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, and full training. See [./examples/wanvideo/](./examples/wanvideo/).
430
+
431
+ - **July 11, 2025** We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks.
432
+ - Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
433
+ - Github Repo: https://github.com/modelscope/Nexus-Gen
434
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
435
+ - Training Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
436
+ - Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
437
+
438
+ <details>
439
+ <summary>More</summary>
440
+
441
+ - **June 15, 2025** ModelScope's official evaluation framework, [EvalScope](https://github.com/modelscope/evalscope), now supports text-to-image generation evaluation. Try it with the [Best Practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide.
442
+
443
+ - **March 31, 2025** We support InfiniteYou, an identity-preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
444
+
445
+ - **March 25, 2025** Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is released! It focuses on stable model deployment, is geared towards industry, and offers better engineering support, higher computational performance, and more stable functionality.
446
+
447
+ - **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
448
+
449
+ - **February 25, 2025** We support Wan-Video, a collection of SOTA video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
450
+
451
+ - **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary), a state-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
452
+
453
+ - **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
454
+ - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
455
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
456
+ - Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
457
+ - Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
458
+
459
+ - **December 19, 2024** We implement advanced VRAM management for HunyuanVideo, making it possible to generate videos at a resolution of 129x720x1280 using 24GB of VRAM, or at 129x512x384 resolution with just 6GB of VRAM. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
460
+
461
+ - **December 18, 2024** We propose ArtAug, an approach designed to improve text-to-image synthesis models through synthesis-understanding interactions. We have trained an ArtAug enhancement module for FLUX.1-dev in the format of LoRA. This model integrates the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, leading to an improvement in the quality of generated images.
462
+ - Paper: https://arxiv.org/abs/2412.12888
463
+ - Examples: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
464
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
465
+ - Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (Coming soon)
466
+
467
+ - **October 25, 2024** We provide extensive FLUX ControlNet support. This project supports many different ControlNet models that can be freely combined, even if their structures differ. Additionally, ControlNet models are compatible with high-resolution refinement and partition control techniques, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/).
468
+
469
+ - **October 8, 2024.** We release the extended LoRA based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).
470
+
471
+ - **August 22, 2024.** CogVideoX-5B is supported in this project. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including
472
+ - Text to video
473
+ - Video editing
474
+ - Self-upscaling
475
+ - Video interpolation
476
+
477
+ - **August 22, 2024.** We have implemented an interesting painter that supports all text-to-image models. Now you can create stunning images using the painter, with assistance from AI!
478
+ - Use it in our [WebUI](#usage-in-webui).
479
+
480
+ - **August 21, 2024.** FLUX is supported in DiffSynth-Studio.
481
+ - Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
482
+ - LoRA, ControlNet, and additional models will be available soon.
483
+
484
+ - **June 21, 2024.** We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
485
+ - [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
486
+ - Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
487
+ - Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
488
+ - Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
489
+ - You can try ExVideo in this [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
490
+
491
+ - **June 13, 2024.** DiffSynth-Studio has been transferred to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
492
+
493
+ - **Jan 29, 2024.** We propose Diffutoon, a fantastic solution for toon shading.
494
+ - [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
495
+ - The source codes are released in this project.
496
+ - The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
497
+
498
+ - **Dec 8, 2023.** We decided to develop a new project aiming to unlock the potential of diffusion models, especially in video synthesis. Development of this project has started.
499
+
500
+ - **Nov 15, 2023.** We propose FastBlend, a powerful video deflickering algorithm.
501
+ - The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
502
+ - Demo videos are shown on Bilibili, including three tasks.
503
+ - [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
504
+ - [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
505
+ - [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
506
+ - The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
507
+ - An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
508
+
509
+ - **Oct 1, 2023.** We release an early version of this project, namely FastSDXL, an attempt at building a diffusion engine.
510
+ - The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
511
+ - FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
512
+ - The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
513
+ - The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
514
+ - A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
515
+ - Since OLSS requires additional training, we don't implement it in this project.
516
+
517
+ - **Aug 29, 2023.** We propose DiffSynth, a video synthesis framework.
518
+ - [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
519
+ - The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
520
+ - The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
521
+
522
+ </details>
README_zh.md ADDED
@@ -0,0 +1,538 @@
1
+ # DiffSynth-Studio
2
+
3
+ <a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
4
+
5
+ [![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/)
6
+ [![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
7
+ [![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
8
+ [![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
9
+ [![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
10
+
11
+ [Switch to English](./README.md)
12
+
13
+ ## 简介
14
+
15
+ 欢迎来到 Diffusion 模型的魔法世界!DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界!
16
+
17
+ DiffSynth 目前包括两个开源项目:
18
+ * [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): 聚焦于激进的技术探索,面向学术界,提供更前沿的模型能力支持。
19
+ * [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): 聚焦于稳定的模型部署,面向工业界,提供更高的计算性能与更稳定的功能。
20
+
21
+ [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) 与 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 作为魔搭社区 [AIGC 专区](https://modelscope.cn/aigc/home) 的核心技术支撑,提供了强大的AI生成内容能力。欢迎体验我们精心打造的产品化功能,开启您的AI创作之旅!
22
+
23
+ ## 安装
24
+
25
+ 从源码安装(推荐):
26
+
27
+ ```
28
+ git clone https://github.com/modelscope/DiffSynth-Studio.git
29
+ cd DiffSynth-Studio
30
+ pip install -e .
31
+ ```
32
+
33
+ <details>
34
+ <summary>其他安装方式</summary>
35
+
36
+ 从 pypi 安装(存在版本更新延迟,如需使用最新功能,请从源码安装)
37
+
38
+ ```
39
+ pip install diffsynth
40
+ ```
41
+
42
+ 如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档:
43
+
44
+ * [torch](https://pytorch.org/get-started/locally/)
45
+ * [sentencepiece](https://github.com/google/sentencepiece)
46
+ * [cmake](https://cmake.org)
47
+ * [cupy](https://docs.cupy.dev/en/stable/install.html)
48
+
49
+ </details>
50
+
51
+
52
+
53
+ ## 基础框架
54
+
55
+ DiffSynth-Studio 为主流 Diffusion 模型(包括 FLUX、Wan 等)重新设计了推理和训练流水线,能够实现高效的显存管理、灵活的模型训练。
56
+
57
+ ### Qwen-Image 系列 (🔥新模型)
58
+
59
+ 详细页面:[./examples/qwen_image/](./examples/qwen_image/)
60
+
61
+ ![Image](https://github.com/user-attachments/assets/738078d8-8749-4a53-a046-571861541924)
62
+
63
+ <details>
64
+
65
+ <summary>快速开始</summary>
66
+
67
+ ```python
68
+ from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
69
+ from PIL import Image
70
+ import torch
71
+
72
+ pipe = QwenImagePipeline.from_pretrained(
73
+ torch_dtype=torch.bfloat16,
74
+ device="cuda",
75
+ model_configs=[
76
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
77
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
78
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
79
+ ],
80
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
81
+ )
82
+ prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
83
+ image = pipe(
84
+ prompt, seed=0, num_inference_steps=40,
85
+ # edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit
86
+ )
87
+ image.save("image.jpg")
88
+ ```
89
+
90
+ </details>
91
+
92
+ <details>
93
+
94
+ <summary>模型总览</summary>
95
+
96
+ |模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
97
+ |-|-|-|-|-|-|-|
98
+ |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
99
+ |[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
100
+ |[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
101
+ |[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
102
+ |[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
103
+ |[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
104
+ |[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
105
+ |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](./examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
106
+ |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
107
+ |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
108
+ |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
109
+ |[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](./examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
110
+ |[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
111
+
112
+ </details>
113
+
114
+ ### FLUX 系列
115
+
116
+ 详细页面:[./examples/flux/](./examples/flux/)
117
+
118
+ ![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)
119
+
120
+ <details>
121
+
122
+ <summary>快速开始</summary>
123
+
124
+ ```python
125
+ import torch
126
+ from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
127
+
128
+ pipe = FluxImagePipeline.from_pretrained(
129
+ torch_dtype=torch.bfloat16,
130
+ device="cuda",
131
+ model_configs=[
132
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
133
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
134
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
135
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
136
+ ],
137
+ )
138
+
139
+ image = pipe(prompt="a cat", seed=0)
140
+ image.save("image.jpg")
141
+ ```
142
+
143
+ </details>
144
+
145
+ <details>
146
+
147
+ <summary>模型总览</summary>
148
+
149
+ |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
150
+ |-|-|-|-|-|-|-|-|
151
+ |[FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
152
+ |[FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](./examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
153
+ |[FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
154
+ |[FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
155
+ |[FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
156
+ |[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
157
+ |[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
158
+ |[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
159
+ |[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
160
+ |[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
161
+ |[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
162
+ |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)|
163
+ |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
164
+ |[Nexus-Gen](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](./examples/flux/model_training/lora/Nexus-Gen.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen.py)|
165
+
166
+ </details>
167
+
168
+ ### Wan 系列
169
+
170
+ 详细页面:[./examples/wanvideo/](./examples/wanvideo/)
171
+
172
+ https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
173
+
174
+ <details>
175
+
176
+ <summary>快速开始</summary>
177
+
178
+ ```python
179
+ import torch
180
+ from diffsynth import save_video
181
+ from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
182
+
183
+ pipe = WanVideoPipeline.from_pretrained(
184
+ torch_dtype=torch.bfloat16,
185
+ device="cuda",
186
+ model_configs=[
187
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
188
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
189
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
190
+ ],
191
+ )
192
+ pipe.enable_vram_management()
193
+
194
+ video = pipe(
195
+ prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
196
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
197
+ seed=0, tiled=True,
198
+ )
199
+ save_video(video, "video1.mp4", fps=15, quality=5)
200
+ ```
201
+
202
+ </details>
203
+
204
+ <details>
205
+
206
+ <summary>模型总览</summary>
207
+
208
+ |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
209
+ |-|-|-|-|-|-|-|
210
+ |[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
211
+ |[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
212
+ |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
213
+ |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
214
+ |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
215
+ |[Wan-AI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
216
+ |[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
217
+ |[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
218
+ |[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
219
+ |[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
220
+ |[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
221
+ |[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
222
+ |[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
223
+ |[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
224
+ |[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
225
+ |[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
226
+ |[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
227
+ |[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
228
+ |[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
229
+ |[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
230
+ |[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
231
+ |[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
232
+ |[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
233
+ |[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
234
+ |[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
235
+ |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
236
+ |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
237
+ |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
238
+ |[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](./examples/wanvideo/model_inference/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](./examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](./examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
239
+ |[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](./examples/wanvideo/model_inference/LongCat-Video.py)|[code](./examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](./examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](./examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
240
+ |[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](./examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](./examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
241
+
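+ 上表中「额外输入」一栏对应推理时传给 pipeline 调用的同名关键字参数(如 `input_image`、`control_video` 等)。下面给出一个文生视频的极简示意,其中 `WanVideoPipeline`/`ModelConfig` 的导入路径与加载方式仅为假设,实际接口请以上表链接的推理示例代码为准:
+
+ ```python
+ # 极简示意(假设的加载方式,请以 examples/wanvideo/model_inference/ 中的脚本为准)
+ import torch
+ from diffsynth import save_video
+ from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig  # 假设的导入路径
+
+ pipe = WanVideoPipeline.from_pretrained(
+     torch_dtype=torch.bfloat16,
+     device="cuda",
+     model_configs=[ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B")],  # 假设按 model_id 自动下载权重
+ )
+ # 表中的「额外输入」以同名参数传入,例如图生视频模型可传 input_image=PIL.Image.open("first_frame.jpg")
+ video = pipe(prompt="一只小猫在草地上奔跑", seed=0)
+ save_video(video, "video.mp4", fps=15, quality=5)
+ ```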
242
+ </details>
243
+
244
+
245
+
246
+ ### 更多模型
247
+
248
+
249
+
250
+ <details>
251
+ <summary>图像生成模型</summary>
252
+
253
+ 详细页面:[./examples/image_synthesis/](./examples/image_synthesis/)
254
+
255
+ |FLUX|Stable Diffusion 3|
256
+ |-|-|
257
+ |![image_1024_cfg](https://github.com/user-attachments/assets/984561e9-553d-4952-9443-79ce144f379f)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|
258
+
259
+ |Kolors|Hunyuan-DiT|
260
+ |-|-|
261
+ |![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
262
+
263
+ |Stable Diffusion|Stable Diffusion XL|
264
+ |-|-|
265
+ |![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|
266
+
267
+ </details>
268
+
269
+
270
+
271
+ <details>
272
+ <summary>视频生成模型</summary>
273
+
274
+ - HunyuanVideo:[./examples/HunyuanVideo/](./examples/HunyuanVideo/)
275
+
276
+ https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9
277
+
278
+ - StepVideo:[./examples/stepvideo/](./examples/stepvideo/)
279
+
280
+ https://github.com/user-attachments/assets/5954fdaa-a3cf-45a3-bd35-886e3cc4581b
281
+
282
+ - CogVideoX:[./examples/CogVideoX/](./examples/CogVideoX/)
283
+
284
+ https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
285
+
286
+ </details>
287
+
288
+
289
+
290
+ <details>
291
+ <summary>图像质量评估模型</summary>
292
+
293
+ 我们集成了一系列图像质量评估模型,这些模型可以用于图像生成模型的评测、对齐训练等场景中。
294
+
295
+ 详细页面:[./examples/image_quality_metric/](./examples/image_quality_metric/)
296
+
297
+ * [ImageReward](https://github.com/THUDM/ImageReward)
298
+ * [Aesthetic](https://github.com/christophschuhmann/improved-aesthetic-predictor)
299
+ * [PickScore](https://github.com/yuvalkirstain/pickscore)
300
+ * [CLIP](https://github.com/openai/CLIP)
301
+ * [HPSv2](https://github.com/tgxs002/HPSv2)
302
+ * [HPSv2.1](https://github.com/tgxs002/HPSv2)
303
+ * [MPS](https://github.com/Kwai-Kolors/MPS)
304
+
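+ 以 ImageReward 为例,下面是一个基于其官方 Python 包的极简打分示意(假设已执行 `pip install image-reward`,`cat.png` 为假设的本地图片;DiffSynth-Studio 内部封装的统一评测接口请以上述详细页面中的示例为准):
+
+ ```python
+ # 极简示意:用 ImageReward 官方包为生成结果打分
+ import ImageReward as RM
+
+ model = RM.load("ImageReward-v1.0")  # 首次调用会自动下载权重
+ # score 接收文本提示词与图片路径列表,返回人类偏好得分,越高越好
+ score = model.score("a painting of a cat", ["cat.png"])
+ print(score)
+ ```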
305
+ </details>
306
+
307
+
308
+
309
+ ## 创新成果
310
+
311
+ DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
312
+
313
+ <details>
314
+ <summary>Nexus-Gen: 统一架构的图像理解、生成、编辑</summary>
315
+
316
+ - 详细页面:https://github.com/modelscope/Nexus-Gen
317
+ - 论文:[Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
318
+ - 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
319
+ - 数据集:[ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
320
+ - 在线体验:[ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
321
+
322
+ ![](https://github.com/modelscope/Nexus-Gen/raw/main/assets/illustrations/gen_edit.jpg)
323
+
324
+ </details>
325
+
326
+
327
+
328
+ <details>
329
+ <summary>ArtAug: 图像生成模型的美学提升</summary>
330
+
331
+ - 详细页面:[./examples/ArtAug/](./examples/ArtAug/)
332
+ - 论文:[ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
333
+ - 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
334
+ - 在线体验:[ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
335
+
336
+ |FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
337
+ |-|-|
338
+ |![image_1_base](https://github.com/user-attachments/assets/e1d5c505-b423-45fe-be01-25c2758f5417)|![image_1_enhance](https://github.com/user-attachments/assets/335908e3-d0bd-41c2-9d99-d10528a2d719)|
339
+
340
+ </details>
341
+
342
+
343
+
344
+ <details>
345
+
346
+ <summary>EliGen: 精准的图像分区控制</summary>
347
+
348
+ - 详细页面:[./examples/EntityControl/](./examples/EntityControl/)
349
+ - 论文:[EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
350
+ - 模型:[ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
351
+ - 在线体验:[ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
352
+ - 数据集:[EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
353
+
354
+ |实体控制区域|生成图像|
355
+ |-|-|
356
+ |![eligen_example_2_mask_0](https://github.com/user-attachments/assets/1c6d9445-5022-4d91-ad2e-dc05321883d1)|![eligen_example_2_0](https://github.com/user-attachments/assets/86739945-cb07-4a49-b3b3-3bb65c90d14f)|
357
+
358
+ </details>
359
+
360
+
361
+
362
+ <details>
363
+
364
+ <summary>ExVideo: 视频生成模型的扩展训练</summary>
365
+
366
+ - 项目页面:[Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
367
+ - 论文:[ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
368
+ - 代码样例:[./examples/ExVideo/](./examples/ExVideo/)
369
+ - 模型:[ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
370
+
371
+ https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
372
+
373
+ </details>
374
+
375
+
376
+
377
+ <details>
378
+
379
+ <summary>Diffutoon: 高分辨率动漫风格视频渲染</summary>
380
+
381
+ - 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
382
+ - 论文:[Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
383
+ - 代码样例:[./examples/Diffutoon/](./examples/Diffutoon/)
384
+
385
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
386
+
387
+ </details>
388
+
389
+
390
+
391
+ <details>
392
+
393
+ <summary>DiffSynth: 本项目的初代版本</summary>
394
+
395
+ - 项目页面:[Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
396
+ - 论文:[DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
397
+ - 代码样例:[./examples/diffsynth/](./examples/diffsynth/)
398
+
399
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
400
+
401
+ </details>
402
+
403
+
404
+
405
+ ## 更新历史
406
+
407
+ - **2025年11月4日** 支持了 [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) 模型,该模型基于 Wan 2.1 训练,支持根据参考视频生成相应的动作。
408
+
409
+ - **2025年10月30日** 支持了 [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) 模型,该模型支持文生视频、图生视频、视频续写。这个模型在本项目中沿用 Wan 的框架进行推理和训练。
410
+
411
+ - **2025年10月27日** 支持了 [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) 模型,Wan 模型生态再添一员。
412
+
413
+ - **2025年9月23日** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) 发布!本模型由我们与淘天体验设计团队联合研发并开源。模型基于 Qwen-Image 构建,专为电商海报场景设计,支持精确的分区布局控制。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)。
414
+
415
+ - **2025年9月9日** 我们的训练框架支持了多种训练模式,目前已适配 Qwen-Image:除标准 SFT 训练模式外,还支持 Direct Distill,请参考[我们的示例代码](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)。这项功能是实验性的,我们将会继续完善,以支持更全面的模型训练功能。
416
+
417
+ - **2025年8月28日** 我们支持了Wan2.2-S2V,一个音频驱动的电影级视频生成模型。请参见[./examples/wanvideo/](./examples/wanvideo/)。
418
+
419
+ - **2025年8月21日** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) 发布!相比于 V1 版本,训练数据集变为 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset),因此,生成的图像更符合 Qwen-Image 本身的图像分布和风格。 请参考[我们的示例代码](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)。
420
+
421
+ - **2025年8月21日** 我们开源了 [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) 结构控制 LoRA 模型,采用 In Context 的技术路线,支持多种类别的结构控制条件,包括 canny, depth, lineart, softedge, normal, openpose。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)。
422
+
423
+ - **2025年8月20日** 我们开源了 [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) 模型,提升了 Qwen-Image-Edit 对低分辨率图像输入的编辑效果。请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)。
424
+
425
+ - **2025年8月19日** 🔥 Qwen-Image-Edit 开源,欢迎图像编辑模型新成员!
426
+
427
+ - **2025年8月18日** 我们训练并开源了 Qwen-Image 的图像重绘 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)。
428
+
429
+ - **2025年8月15日** 我们开源了 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) 数据集。这是一个使用 Qwen-Image 模型生成的图像数据集,共包含 160,000 张`1024 x 1024`图像。它包括通用、英文文本渲染和中文文本渲染子集。我们为每张图像提供了图像描述、实体和结构控制图像的标注。开发者可以使用这个数据集来训练 Qwen-Image 模型的 ControlNet 和 EliGen 等模型,我们旨在通过开源推动技术发展!
430
+
431
+ - **2025年8月13日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)。
432
+
433
+ - **2025年8月12日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)。
434
+
435
+ - **2025年8月11日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA),沿用了与 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) 相同的训练流程,但模型结构修改为了 LoRA,因此能够更好地与其他开源生态模型兼容。
436
+
437
+ - **2025年8月7日** 我们开源了 Qwen-Image 的实体控制 LoRA 模型 [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)。Qwen-Image-EliGen 能够实现实体级可控的文生图。技术细节请参见[论文](https://arxiv.org/abs/2501.01097)。训练数据集:[EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)。
438
+
439
+ - **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。
440
+
441
+ - **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员!
442
+
443
+ - **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。
444
+
445
+ - **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。
446
+
447
+ - **2025年7月11日** 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。
448
+ - 论文: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
449
+ - Github 仓库: https://github.com/modelscope/Nexus-Gen
450
+ - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
451
+ - 训练数据集: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
452
+ - 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
453
+
454
+ <details>
455
+ <summary>更多</summary>
456
+
457
+ - **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。
458
+
459
+ - **2025年3月31日** 我们支持 InfiniteYou,一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。
460
+
461
+ - **2025年3月25日** 我们的新项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源!该项目专注于稳定的模型部署,面向工业界,提供更好的工程支持、更高的计算性能和更稳定的功能。
462
+
463
+ - **2025年3月13日** 我们支持 HunyuanVideo-I2V,即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
464
+
465
+ - **2025年2月25日** 我们支持 Wan-Video,这是阿里巴巴开源的一系列最先进的视频合成模型。详见 [./examples/wanvideo/](./examples/wanvideo/)。
466
+
467
+ - **2025年2月17日** 我们支持 [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary),一个先进的视频合成模型!详见 [./examples/stepvideo](./examples/stepvideo/)。
468
+
469
+ - **2024年12月31日** 我们提出 EliGen,一种用于精确实体级别控制的文本到图像生成的新框架,并辅以修复融合管道,将其能力扩展到图像修复任务。EliGen 可以无缝集成现有的社区模型,如 IP-Adapter 和 In-Context LoRA,提升其通用性。更多详情,请见 [./examples/EntityControl](./examples/EntityControl/)。
470
+ - 论文: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
471
+ - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
472
+ - 在线体验: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
473
+ - 训练数据集: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
474
+
475
+ - **2024年12月19日** 我们为 HunyuanVideo 实现了高级显存管理,使得在 24GB 显存下可以生成分辨率为 129x720x1280 的视频,或在仅 6GB 显存下生成分辨率为 129x512x384 的视频。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
476
+
477
+ - **2024年12月18日** 我们提出 ArtAug,一种通过合成-理解交互来改进文生图模型的方法。我们以 LoRA 格式为 FLUX.1-dev 训练了一个 ArtAug 增强模块。该模型将 Qwen2-VL-72B 的美学理解融入 FLUX.1-dev,从而提升了生成图像的质量。
478
+ - 论文: https://arxiv.org/abs/2412.12888
479
+ - 示例: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
480
+ - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
481
+ - 演示: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (即将上线)
482
+
483
+ - **2024年10月25日** 我们提供了广泛的 FLUX ControlNet 支持。该项目支持许多不同的 ControlNet 模型,并且可以自由组合,即使它们的结构不同。此外,ControlNet 模型兼容高分辨率优化和分区控制技术,能够实现非常强大的可控图像生成。详见 [`./examples/ControlNet/`](./examples/ControlNet/)。
484
+
485
+ - **2024年10月8日** 我们发布了基于 CogVideoX-5B 和 ExVideo 的扩展 LoRA。您可以从 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 或 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 下载此模型。
486
+
487
+ - **2024年8月22日** 本项目现已支持 CogVideoX-5B。详见 [此处](/examples/video_synthesis/)。我们为这个文生视频模型提供了几个有趣的功能,包括:
488
+ - 文本到视频
489
+ - 视频编辑
490
+ - 自我超分
491
+ - 视频插帧
492
+
493
+ - **2024年8月22日** 我们实现了一个有趣的画笔功能,支持所有文生图模型。现在,您可以在 AI 的辅助下使用画笔创作惊艳的图像了!
494
+ - 在我们的 [WebUI](#usage-in-webui) 中使用它。
495
+
496
+ - **2024年8月21日** DiffSynth-Studio 现已支持 FLUX。
497
+ - 启用 CFG 和高分辨率修复以提升视觉质量。详见 [此处](/examples/image_synthesis/README.md)
498
+ - LoRA、ControlNet 和其他附加模型将很快推出。
499
+
500
+ - **2024年6月21日** 我们提出 ExVideo,一种旨在增强视频生成模型能力的后训练微调技术。我们对 Stable Video Diffusion 进行了扩展,实现了长达 128 帧的长视频生成。
501
+ - [项目页面](https://ecnu-cilab.github.io/ExVideoProjectPage/)
502
+ - 源代码已在此仓库中发布。详见 [`examples/ExVideo`](./examples/ExVideo/)。
503
+ - 模型已发布于 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) 和 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)。
504
+ - 技术报告已发布于 [arXiv](https://arxiv.org/abs/2406.14130)。
505
+ - 您可以在此 [演示](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) 中试用 ExVideo!
506
+
507
+ - **2024年6月13日** DiffSynth Studio 已迁移至 ModelScope。开发团队也从“我”转变为“我们”。当然,我仍会参与后续的开发和维护工作。
508
+
509
+ - **2024年1月29日** 我们提出 Diffutoon,这是一个出色的卡通着色解决方案。
510
+ - [项目页面](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
511
+ - 源代码已在此项目中发布。
512
+ - 技术报告(IJCAI 2024)已发布于 [arXiv](https://arxiv.org/abs/2401.16224)。
513
+
514
+ - **2023年12月8日** 我们决定启动一个新项目,旨在释放扩散模型的潜力,尤其是在视频合成方面。该项目的开发工作正式开始。
515
+
516
+ - **2023年11月15日** 我们提出 FastBlend,一种强大的视频去闪烁算法。
517
+ - sd-webui 扩展已发布于 [GitHub](https://github.com/Artiprocher/sd-webui-fastblend)。
518
+ - 演示视频已在 Bilibili 上展示,包含三个任务:
519
+ - [视频去闪烁](https://www.bilibili.com/video/BV1d94y1W7PE)
520
+ - [视频插帧](https://www.bilibili.com/video/BV1Lw411m71p)
521
+ - [图像驱动的视频渲染](https://www.bilibili.com/video/BV1RB4y1Z7LF)
522
+ - 技术报告已发布于 [arXiv](https://arxiv.org/abs/2311.09265)。
523
+ - 其他用户开发的非官方 ComfyUI 扩展已发布于 [GitHub](https://github.com/AInseven/ComfyUI-fastblend)。
524
+
525
+ - **2023年10月1日** 我们发布了该项目的早期版本,名为 FastSDXL。这是构建一个扩散引擎的初步尝试。
526
+ - 源代码已发布于 [GitHub](https://github.com/Artiprocher/FastSDXL)。
527
+ - FastSDXL 包含一个可训练的 OLSS 调度器,以提高效率。
528
+ - OLSS 的原始仓库位于 [此处](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler)。
529
+ - 技术报告(CIKM 2023)已发布于 [arXiv](https://arxiv.org/abs/2305.14677)。
530
+ - 演示视频已发布于 [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj)。
531
+ - 由于 OLSS 需要额外训练,我们未在本项目中实现它。
532
+
533
+ - **2023年8月29日** 我们提出 DiffSynth,一个视频合成框架。
534
+ - [项目页面](https://ecnu-cilab.github.io/DiffSynth.github.io/)。
535
+ - 源代码已发布在 [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth)。
536
+ - 技术报告(ECML PKDD 2024)已发布于 [arXiv](https://arxiv.org/abs/2308.03463)。
537
+
538
+ </details>
_temp_.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # import os
3
+ # render2real_path = "/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/spotlight_sketch/epoch0"
4
+ # sketch_enhance_body_path = "/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/spotlight_sketch/GT"
5
+ # render2real_files = set(os.listdir(render2real_path))
6
+ # sketch_enhance_body_files = set(os.listdir(sketch_enhance_body_path))
7
+ # for file in render2real_files:
8
+ # if file not in sketch_enhance_body_files:
9
+ # print(f"Removing {file} from render2real_path")
10
+ # os.remove(os.path.join(render2real_path, file))
11
+ # for file in sketch_enhance_body_files:
12
+ # if file not in render2real_files:
13
+ # print(f"Removing {file} from sketch_enhance_body_path")
14
+ # os.remove(os.path.join(sketch_enhance_body_path, file))
15
+
16
+ # import os
17
+ # input_path = "/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/mbti/Realistic"
18
+ # for file in os.listdir(input_path):
19
+ # # strip the trailing "_Realistic" suffix from each file name
20
+ # new_name = file.replace("_Realistic", "")
21
+ # os.rename(os.path.join(input_path, file), os.path.join(input_path, new_name))
22
+
23
+ # import os
24
+ # for file in os.listdir("dataset/spotlight_sketch_cat/GT"):
25
+ # with open("dataset/spotlight_sketch_cat/pairs_t2t.txt", "a") as f:
26
+ # # columns: target image, source image
27
+ # f.write(f"GT/{file}\tepoch0/{file}\n")
28
+
29
+
30
+ # import os
31
+ # import json
32
+ # from tqdm import tqdm
33
+ # input_txt = "/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/spotlight_sketch_cat/spotlight_nano_comprehension_1203.txt"
34
+ # with open(input_txt, "r") as f:
35
+ # lines = f.readlines()
36
+ # for i in tqdm(range(len(lines))):
37
+ # data = json.loads(lines[i])
38
+ # fig_id = f"{data['Image_Name']}.png"
39
+ # del data["Image_Name"]
40
+ # input_dir = "dataset/spotlight_sketch_cat/epoch0"
41
+ # GT_dir = "dataset/spotlight_sketch_cat/GT"
42
+ # for file in os.listdir(input_dir):
43
+ # if fig_id in file:
44
+ # with open("dataset/spotlight_sketch_cat/pairs_i2i.txt", "a") as f:
45
+ # # columns: target image, source image, prompt
46
+ # f.write(f"{GT_dir}/{file}\t{input_dir}/{file}\t{data}\n")
47
+
48
+ # Tile every six images in a folder into one 3-row x 2-column composite and save it to another folder; paste the original frames directly instead of cropped screenshots
49
+ import os
50
+ from PIL import Image
51
+ from tqdm import tqdm
52
+ import numpy as np
53
+ base_dirs = ["/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/the roses","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/nouvelle","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/legs","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/frankenstein"]
54
+ # Core configuration parameters
55
+ crop_size = (1920, 800) # target center-crop size (width, height)
56
+ resize_size = (477, 188) # per-image size after downsampling (width, height)
57
+ line_width = 6 # width of the black separator lines (6 px)
58
+ target_merge_size = (960, 576) # final composite size (width, height)
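+ # Note: 2 * 477 + 6 = 960 and 3 * 188 + 2 * 6 = 576, so the 3x2 grid plus its separator lines exactly fills target_merge_size (the centering offsets computed below are therefore 0).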
59
+
60
+
61
+ def center_crop_to_size(img, target_size):
62
+ """
63
+ Center-crop an image to the given size, padding any missing area with black pixels.
64
+ :param img: PIL Image object
65
+ :param target_size: (target_w, target_h) target crop size
66
+ :return: cropped (and black-padded, if needed) PIL Image
67
+ """
68
+ target_w, target_h = target_size
69
+ img_w, img_h = img.size
70
+
71
+ # Step 1: compute the center-crop region (centered on the image)
72
+ # horizontal crop
73
+ if img_w >= target_w:
74
+ left = (img_w - target_w) // 2
75
+ right = left + target_w
76
+ else:
77
+ left = 0
78
+ right = img_w
79
+ # vertical crop
80
+ if img_h >= target_h:
81
+ top = (img_h - target_h) // 2
82
+ bottom = top + target_h
83
+ else:
84
+ top = 0
85
+ bottom = img_h
86
+
87
+ # Step 2: perform the center crop
88
+ cropped = img.crop((left, top, right, bottom))
89
+
90
+ # Step 3: pad with black if the crop is smaller than the target size
91
+ if cropped.size != (target_w, target_h):
92
+ new_img = Image.new("RGB", (target_w, target_h), (0, 0, 0)) # black background
93
+ new_img.paste(cropped, ((target_w - cropped.width) // 2, (target_h - cropped.height) // 2))
94
+ cropped = new_img
95
+
96
+ return cropped
97
+ for k in range(len(base_dirs)):
98
+ save_path = f"{base_dirs[k]}_dedup_cat"
99
+ os.makedirs(save_path, exist_ok=True)
100
+ input_path = f"{base_dirs[k]}_dedup"
101
+ # collect and sort the image file list
102
+ files = [f for f in os.listdir(input_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
103
+ files.sort()
104
+
105
+ # iterate over the files, merging six images at a time
106
+ for i in tqdm(range(0, len(files), 6), desc="拼接图片"):
107
+ # initialize the final composite canvas (target size 960x576, black background)
108
+ merged_image = np.zeros((target_merge_size[1], target_merge_size[0], 3), dtype=np.uint8)
109
+
110
+ # process the six images one by one
111
+ valid_imgs = [] # processed images (or black placeholders) for this grid
112
+ for j in range(6):
113
+ if i + j >= len(files):
114
+ # fewer than six images remain: fill this slot with a black placeholder and continue
115
+ img_np = np.zeros((resize_size[1], resize_size[0], 3), dtype=np.uint8)
116
+ valid_imgs.append(img_np)
117
+ continue
118
+
119
+ img_path = os.path.join(input_path, files[i + j])
120
+ try:
121
+ # load the image (force RGB mode)
122
+ img = Image.open(img_path).convert("RGB")
123
+ img_w, img_h = img.size
124
+
125
+ # filter: skip (and report) images whose original width is below 1800
126
+ if img_w < 1800:
127
+ print(f"跳过文件 {files[i+j]}: 原始宽度 {img_w} < 1800")
128
+ # fill this slot with a black placeholder
129
+ img_np = np.zeros((resize_size[1], resize_size[0], 3), dtype=np.uint8)
130
+ valid_imgs.append(img_np)
131
+ continue
132
+
133
+ # Step 1: center-crop to 1920x800, padding with black if needed
134
+ cropped_img = center_crop_to_size(img, crop_size)
135
+
136
+ # Step 2: downsample to 477x188 (LANCZOS resampling to preserve quality)
137
+ resized_img = cropped_img.resize(resize_size, resample=Image.LANCZOS)
138
+
139
+ # convert to a numpy array
140
+ img_np = np.array(resized_img)
141
+ valid_imgs.append(img_np)
142
+
143
+ except Exception as e:
144
+ print(f"处理文件 {files[i+j]} 出错: {str(e)}")
145
+ # on error, fill this slot with a black placeholder
146
+ img_np = np.zeros((resize_size[1], resize_size[0], 3), dtype=np.uint8)
147
+ valid_imgs.append(img_np)
148
+
149
+ # Step 3: compute each image's position on the composite canvas (3 rows x 2 columns + 6-px black lines)
150
+ # sanity-check the layout configuration
151
+ assert len(valid_imgs) == 6, "有效图片数量必须为6张"
152
+ # total footprint of images plus black lines; must fit within 960x576
153
+ total_col = 2 * resize_size[0] + 1 * line_width # 2 columns + 1 vertical line
154
+ total_row = 3 * resize_size[1] + 2 * line_width # 3 rows + 2 horizontal lines
155
+ # offsets that center the grid on the canvas (keeping the final size at 960x576)
156
+ offset_x = (target_merge_size[0] - total_col) // 2
157
+ offset_y = (target_merge_size[1] - total_row) // 2
158
+
159
+ # place each image onto the composite canvas
160
+ for idx, img_np in enumerate(valid_imgs):
161
+ row = idx // 2 # row 0/1/2
162
+ col = idx % 2 # column 0/1
163
+
164
+ # starting position of the current image (separator lines + global offset included)
165
+ x_start = offset_x + col * (resize_size[0] + line_width)
166
+ y_start = offset_y + row * (resize_size[1] + line_width)
167
+ x_end = x_start + resize_size[0]
168
+ y_end = y_start + resize_size[1]
169
+
170
+ # clamp to the canvas boundaries
171
+ x_end = min(x_end, target_merge_size[0])
172
+ y_end = min(y_end, target_merge_size[1])
173
+ x_start = max(x_start, 0)
174
+ y_start = max(y_start, 0)
175
+
176
+ # paste the image onto the canvas
177
+ merged_image[y_start:y_end, x_start:x_end, :] = img_np[:y_end-y_start, :x_end-x_start, :]
178
+
179
+ # Step 4: save the final composite image
180
+ save_name = f'merged_{i//6}.png'
181
+ save_full_path = os.path.join(save_path, save_name)
182
+ Image.fromarray(merged_image).save(save_full_path)
183
+
184
+ print(f"所有图片处理完成!结果保存至: {save_path}")
apps/gradio/DiffSynth_Studio.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
3
+ import os, torch
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+
8
+ config = {
9
+ "model_config": {
10
+ "Stable Diffusion": {
11
+ "model_folder": "models/stable_diffusion",
12
+ "pipeline_class": SDImagePipeline,
13
+ "default_parameters": {
14
+ "cfg_scale": 7.0,
15
+ "height": 512,
16
+ "width": 512,
17
+ }
18
+ },
19
+ "Stable Diffusion XL": {
20
+ "model_folder": "models/stable_diffusion_xl",
21
+ "pipeline_class": SDXLImagePipeline,
22
+ "default_parameters": {
23
+ "cfg_scale": 7.0,
24
+ }
25
+ },
26
+ "Stable Diffusion 3": {
27
+ "model_folder": "models/stable_diffusion_3",
28
+ "pipeline_class": SD3ImagePipeline,
29
+ "default_parameters": {
30
+ "cfg_scale": 7.0,
31
+ }
32
+ },
33
+ "Stable Diffusion XL Turbo": {
34
+ "model_folder": "models/stable_diffusion_xl_turbo",
35
+ "pipeline_class": SDXLImagePipeline,
36
+ "default_parameters": {
37
+ "negative_prompt": "",
38
+ "cfg_scale": 1.0,
39
+ "num_inference_steps": 1,
40
+ "height": 512,
41
+ "width": 512,
42
+ }
43
+ },
44
+ "Kolors": {
45
+ "model_folder": "models/kolors",
46
+ "pipeline_class": SDXLImagePipeline,
47
+ "default_parameters": {
48
+ "cfg_scale": 7.0,
49
+ }
50
+ },
51
+ "HunyuanDiT": {
52
+ "model_folder": "models/HunyuanDiT",
53
+ "pipeline_class": HunyuanDiTImagePipeline,
54
+ "default_parameters": {
55
+ "cfg_scale": 7.0,
56
+ }
57
+ },
58
+ "FLUX": {
59
+ "model_folder": "models/FLUX",
60
+ "pipeline_class": FluxImagePipeline,
61
+ "default_parameters": {
62
+ "cfg_scale": 1.0,
63
+ }
64
+ }
65
+ },
66
+ "max_num_painter_layers": 8,
67
+ "max_num_model_cache": 1,
68
+ }
69
+
70
+
71
+ def load_model_list(model_type):
72
+ if model_type is None:
73
+ return []
74
+ folder = config["model_config"][model_type]["model_folder"]
75
+ file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
76
+ if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
77
+ file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
78
+ file_list = sorted(file_list)
79
+ return file_list
80
+
81
+
82
+ def load_model(model_type, model_path):
83
+ global model_dict
84
+ model_key = f"{model_type}:{model_path}"
85
+ if model_key in model_dict:
86
+ return model_dict[model_key]
87
+ model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
88
+ model_manager = ModelManager()
89
+ if model_type == "HunyuanDiT":
90
+ model_manager.load_models([
91
+ os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
92
+ os.path.join(model_path, "mt5/pytorch_model.bin"),
93
+ os.path.join(model_path, "model/pytorch_model_ema.pt"),
94
+ os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
95
+ ])
96
+ elif model_type == "Kolors":
97
+ model_manager.load_models([
98
+ os.path.join(model_path, "text_encoder"),
99
+ os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
100
+ os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
101
+ ])
102
+ elif model_type == "FLUX":
103
+ model_manager.torch_dtype = torch.bfloat16
104
+ file_list = [
105
+ os.path.join(model_path, "text_encoder/model.safetensors"),
106
+ os.path.join(model_path, "text_encoder_2"),
107
+ ]
108
+ for file_name in os.listdir(model_path):
109
+ if file_name.endswith(".safetensors"):
110
+ file_list.append(os.path.join(model_path, file_name))
111
+ model_manager.load_models(file_list)
112
+ else:
113
+ model_manager.load_model(model_path)
114
+ pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
115
+ while len(model_dict) + 1 > config["max_num_model_cache"]:
116
+ key = next(iter(model_dict.keys()))
117
+ model_manager_to_release, _ = model_dict[key]
118
+ model_manager_to_release.to("cpu")
119
+ del model_dict[key]
120
+ torch.cuda.empty_cache()
121
+ model_dict[model_key] = model_manager, pipe
122
+ return model_manager, pipe
123
+
124
+
125
+ model_dict = {}
126
+
127
+ with gr.Blocks() as app:
128
+ gr.Markdown("# DiffSynth-Studio Painter")
129
+ with gr.Row():
130
+ with gr.Column(scale=382, min_width=100):
131
+
132
+ with gr.Accordion(label="Model"):
133
+ model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
134
+ model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
135
+
136
+ @gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
137
+ def model_type_to_model_path(model_type):
138
+ return gr.Dropdown(choices=load_model_list(model_type))
139
+
140
+ with gr.Accordion(label="Prompt"):
141
+ prompt = gr.Textbox(label="Prompt", lines=3)
142
+ negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
143
+ cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
144
+ embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")
145
+
146
+ with gr.Accordion(label="Image"):
147
+ num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
148
+ height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
149
+ width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
150
+ with gr.Column():
151
+ use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
152
+ seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
153
+
154
+ @gr.on(
155
+ inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
156
+ outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
157
+ triggers=model_path.change
158
+ )
159
+ def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
160
+ load_model(model_type, model_path)
161
+ cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
162
+ embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
163
+ num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
164
+ height = config["model_config"][model_type]["default_parameters"].get("height", height)
165
+ width = config["model_config"][model_type]["default_parameters"].get("width", width)
166
+ return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
167
+
168
+
169
+ with gr.Column(scale=618, min_width=100):
170
+ with gr.Accordion(label="Painter"):
171
+ enable_local_prompt_list = []
172
+ local_prompt_list = []
173
+ mask_scale_list = []
174
+ canvas_list = []
175
+ for painter_layer_id in range(config["max_num_painter_layers"]):
176
+ with gr.Tab(label=f"Layer {painter_layer_id}"):
177
+ enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
178
+ local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
179
+ mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
180
+ canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
181
+ brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
182
+ label="Painter", key=f"canvas_{painter_layer_id}")
183
+ @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
184
+ def resize_canvas(height, width, canvas):
185
+ h, w = canvas["background"].shape[:2]
186
+ if h != height or width != w:
187
+ return np.ones((height, width, 3), dtype=np.uint8) * 255
188
+ else:
189
+ return canvas
190
+
191
+ enable_local_prompt_list.append(enable_local_prompt)
192
+ local_prompt_list.append(local_prompt)
193
+ mask_scale_list.append(mask_scale)
194
+ canvas_list.append(canvas)
195
+ with gr.Accordion(label="Results"):
196
+ run_button = gr.Button(value="Generate", variant="primary")
197
+ output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
198
+ with gr.Row():
199
+ with gr.Column():
200
+ output_to_painter_button = gr.Button(value="Set as painter's background")
201
+ with gr.Column():
202
+ output_to_input_button = gr.Button(value="Set as input image")
203
+ painter_background = gr.State(None)
204
+ input_background = gr.State(None)
205
+ @gr.on(
206
+ inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
207
+ outputs=[output_image],
208
+ triggers=run_button.click
209
+ )
210
+ def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
211
+ _, pipe = load_model(model_type, model_path)
212
+ input_params = {
213
+ "prompt": prompt,
214
+ "negative_prompt": negative_prompt,
215
+ "cfg_scale": cfg_scale,
216
+ "num_inference_steps": num_inference_steps,
217
+ "height": height,
218
+ "width": width,
219
+ "progress_bar_cmd": progress.tqdm,
220
+ }
221
+ if isinstance(pipe, FluxImagePipeline):
222
+ input_params["embedded_guidance"] = embedded_guidance
223
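+ # *args packs the painter-layer widgets as four consecutive groups of max_num_painter_layers values each: enable flags, local prompts, mask scales, and canvases.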
+ enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
224
+ args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
225
+ args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
226
+ args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
227
+ args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
228
+ )
229
+ local_prompts, masks, mask_scales = [], [], []
230
+ for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
231
+ enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
232
+ ):
233
+ if enable_local_prompt:
234
+ local_prompts.append(local_prompt)
235
+ masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
236
+ mask_scales.append(mask_scale)
237
+ input_params.update({
238
+ "local_prompts": local_prompts,
239
+ "masks": masks,
240
+ "mask_scales": mask_scales,
241
+ })
242
+ torch.manual_seed(seed)
243
+ image = pipe(**input_params)
244
+ return image
245
+
246
+ @gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
247
+ def send_output_to_painter_background(output_image, *canvas_list):
248
+ for canvas in canvas_list:
249
+ h, w = canvas["background"].shape[:2]
250
+ canvas["background"] = output_image.resize((w, h))
251
+ return tuple(canvas_list)
252
+ app.launch()
apps/gradio/entity_level_control.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import random
6
+ import json
7
+ import gradio as gr
8
+ from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
9
+ from modelscope import dataset_snapshot_download
10
+
11
+
12
+ dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/*")
13
+ example_json = 'data/examples/eligen/entity_control/ui_examples.json'
14
+ with open(example_json, 'r') as f:
15
+ examples = json.load(f)['examples']
16
+
17
+ for idx in range(len(examples)):
18
+ example_id = examples[idx]['example_id']
19
+ entity_prompts = examples[idx]['local_prompt_list']
20
+ examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
21
+
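+ # Build the dict structure expected by gr.ImageEditor: an RGBA background, one RGBA layer per mask (the mask is stored in the layer's alpha channel), and a composite preview.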
22
+ def create_canvas_data(background, masks):
23
+ if background.shape[-1] == 3:
24
+ background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
25
+ layers = []
26
+ for mask in masks:
27
+ if mask is not None:
28
+ mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
29
+ layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
30
+ layer[..., -1] = mask_single_channel
31
+ layers.append(layer)
32
+ else:
33
+ layers.append(np.zeros_like(background))
34
+
35
+ composite = background.copy()
36
+ for layer in layers:
37
+ if layer.size > 0:
38
+ composite = np.where(layer[..., -1:] > 0, layer, composite)
39
+ return {
40
+ "background": background,
41
+ "layers": layers,
42
+ "composite": composite,
43
+ }
44
+
45
+ def load_example(load_example_button):
46
+ example_idx = int(load_example_button.split()[-1]) - 1
47
+ example = examples[example_idx]
48
+ result = [
49
+ 50,
50
+ example["global_prompt"],
51
+ example["negative_prompt"],
52
+ example["seed"],
53
+ *example["local_prompt_list"],
54
+ ]
55
+ num_entities = len(example["local_prompt_list"])
56
+ result += [""] * (config["max_num_painter_layers"] - num_entities)
57
+ masks = []
58
+ for mask in example["mask_lists"]:
59
+ mask_single_channel = np.array(mask.convert("L"))
60
+ masks.append(mask_single_channel)
61
+ for _ in range(config["max_num_painter_layers"] - len(masks)):
62
+ blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
63
+ masks.append(blank_mask)
64
+ background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
65
+ canvas_data_list = []
66
+ for mask in masks:
67
+ canvas_data = create_canvas_data(background, [mask])
68
+ canvas_data_list.append(canvas_data)
69
+ result.extend(canvas_data_list)
70
+ return result
71
+
72
+ def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
73
+ save_dir = os.path.join('workdirs/tmp_mask', random_dir)
74
+ print(f'save to {save_dir}')
75
+ os.makedirs(save_dir, exist_ok=True)
76
+ for i, mask in enumerate(masks):
77
+ save_path = os.path.join(save_dir, f'{i}.png')
78
+ mask.save(save_path)
79
+ sample = {
80
+ "global_prompt": global_prompt,
81
+ "mask_prompts": mask_prompts,
82
+ "seed": seed,
83
+ }
84
+ with open(os.path.join(save_dir, f"prompts.json"), 'w') as f:
85
+ json.dump(sample, f, indent=4)
86
+
87
+ def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
88
+ # Create a blank image for overlays
89
+ overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
90
+ colors = [
91
+ (165, 238, 173, 80),
92
+ (76, 102, 221, 80),
93
+ (221, 160, 77, 80),
94
+ (204, 93, 71, 80),
95
+ (145, 187, 149, 80),
96
+ (134, 141, 172, 80),
97
+ (157, 137, 109, 80),
98
+ (153, 104, 95, 80),
99
+ (165, 238, 173, 80),
100
+ (76, 102, 221, 80),
101
+ (221, 160, 77, 80),
102
+ (204, 93, 71, 80),
103
+ (145, 187, 149, 80),
104
+ (134, 141, 172, 80),
105
+ (157, 137, 109, 80),
106
+ (153, 104, 95, 80),
107
+ ]
108
+ # Generate random colors for each mask
109
+ if use_random_colors:
110
+ colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
111
+ # Font settings
112
+ try:
113
+ font = ImageFont.truetype("arial", font_size) # Adjust as needed
114
+ except IOError:
115
+ font = ImageFont.load_default(font_size)
116
+ # Overlay each mask onto the overlay image
117
+ for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
118
+ if mask is None:
119
+ continue
120
+ # Convert mask to RGBA mode
121
+ mask_rgba = mask.convert('RGBA')
122
+ mask_data = mask_rgba.getdata()
123
+ new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
124
+ mask_rgba.putdata(new_data)
125
+ # Draw the mask prompt text on the mask
126
+ draw = ImageDraw.Draw(mask_rgba)
127
+ mask_bbox = mask.getbbox() # Get the bounding box of the mask
128
+ if mask_bbox is None:
129
+ continue
130
+ text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
131
+ draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
132
+ # Alpha composite the overlay with this mask
133
+ overlay = Image.alpha_composite(overlay, mask_rgba)
134
+ # Composite the overlay onto the original image
135
+ result = Image.alpha_composite(image.convert('RGBA'), overlay)
136
+ return result
137
+
138
+ config = {
139
+ "model_config": {
140
+ "FLUX": {
141
+ "model_folder": "models/FLUX",
142
+ "pipeline_class": FluxImagePipeline,
143
+ "default_parameters": {
144
+ "cfg_scale": 3.0,
145
+ "embedded_guidance": 3.5,
146
+ "num_inference_steps": 30,
147
+ }
148
+ },
149
+ },
150
+ "max_num_painter_layers": 8,
151
+ "max_num_model_cache": 1,
152
+ }
153
+
154
+ model_dict = {}
155
+
156
+ def load_model(model_type='FLUX', model_path='FLUX.1-dev'):
157
+ global model_dict
158
+ model_key = f"{model_type}:{model_path}"
159
+ if model_key in model_dict:
160
+ return model_dict[model_key]
161
+ model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
162
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
163
+ model_manager.load_lora(
164
+ download_customized_models(
165
+ model_id="DiffSynth-Studio/Eligen",
166
+ origin_file_path="model_bf16.safetensors",
167
+ local_dir="models/lora/entity_control",
168
+ ),
169
+ lora_alpha=1,
170
+ )
171
+ pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
172
+ model_dict[model_key] = model_manager, pipe
173
+ return model_manager, pipe
174
+
175
+
176
+ with gr.Blocks() as app:
177
+ gr.Markdown(
178
+ """## EliGen: Entity-Level Controllable Text-to-Image Model
179
+ 1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
180
+ 2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
181
+ 3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
182
+ 4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
183
+ """
184
+ )
185
+
186
+ loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
187
+ main_interface = gr.Column(visible=False)
188
+
189
+ def initialize_model():
190
+ try:
191
+ load_model()
192
+ return {
193
+ loading_status: gr.update(value="Model loaded successfully!", visible=False),
194
+ main_interface: gr.update(visible=True),
195
+ }
196
+ except Exception as e:
197
+ print(f'Failed to load model with error: {e}')
198
+ return {
199
+ loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
200
+ main_interface: gr.update(visible=True),
201
+ }
202
+
203
+ app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
204
+
205
+ with main_interface:
206
+ with gr.Row():
207
+ local_prompt_list = []
208
+ canvas_list = []
209
+ random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
210
+ with gr.Column(scale=382, min_width=100):
211
+ model_type = gr.State('FLUX')
212
+ model_path = gr.State('FLUX.1-dev')
213
+ with gr.Accordion(label="Global prompt"):
214
+ prompt = gr.Textbox(label="Global Prompt", lines=3)
215
+ negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur,", lines=3)
216
+ with gr.Accordion(label="Inference Options", open=True):
217
+ seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
218
+ num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
219
+ cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=3.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
220
+ embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale")
221
+ height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
222
+ width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
223
+ with gr.Accordion(label="Inpaint Input Image", open=False):
224
+ input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
225
+ background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
226
+
227
+ with gr.Column():
228
+ reset_input_button = gr.Button(value="Reset Inpaint Input")
229
+ send_input_to_painter = gr.Button(value="Set as painter's background")
230
+ @gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
231
+ def reset_input_image(input_image):
232
+ return None
233
+
234
+ with gr.Column(scale=618, min_width=100):
235
+ with gr.Accordion(label="Entity Painter"):
236
+ for painter_layer_id in range(config["max_num_painter_layers"]):
237
+ with gr.Tab(label=f"Entity {painter_layer_id}"):
238
+ local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
239
+ canvas = gr.ImageEditor(
240
+ canvas_size=(512, 512),
241
+ sources=None,
242
+ layers=False,
243
+ interactive=True,
244
+ image_mode="RGBA",
245
+ brush=gr.Brush(
246
+ default_size=50,
247
+ default_color="#000000",
248
+ colors=["#000000"],
249
+ ),
250
+ label="Entity Mask Painter",
251
+ key=f"canvas_{painter_layer_id}",
252
+ width=width,
253
+ height=height,
254
+ )
255
+ @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
256
+ def resize_canvas(height, width, canvas):
257
+ h, w = canvas["background"].shape[:2]
258
+ if h != height or width != w:
259
+ return np.ones((height, width, 3), dtype=np.uint8) * 255
260
+ else:
261
+ return canvas
262
+ local_prompt_list.append(local_prompt)
263
+ canvas_list.append(canvas)
264
+ with gr.Accordion(label="Results"):
265
+ run_button = gr.Button(value="Generate", variant="primary")
266
+ output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
267
+ with gr.Row():
268
+ with gr.Column():
269
+ output_to_painter_button = gr.Button(value="Set as painter's background")
270
+ with gr.Column():
271
+ return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
272
+ output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
273
+ real_output = gr.State(None)
274
+ mask_out = gr.State(None)
275
+
276
+ @gr.on(
277
+ inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
278
+ outputs=[output_image, real_output, mask_out],
279
+ triggers=run_button.click
280
+ )
281
+ def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
282
+ _, pipe = load_model(model_type, model_path)
283
+ input_params = {
284
+ "prompt": prompt,
285
+ "negative_prompt": negative_prompt,
286
+ "cfg_scale": cfg_scale,
287
+ "num_inference_steps": num_inference_steps,
288
+ "height": height,
289
+ "width": width,
290
+ "progress_bar_cmd": progress.tqdm,
291
+ }
292
+ if isinstance(pipe, FluxImagePipeline):
293
+ input_params["embedded_guidance"] = embedded_guidance
294
+ if input_image is not None:
295
+ input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
296
+ input_params["enable_eligen_inpaint"] = True
297
+
298
+ local_prompt_list, canvas_list = (
299
+ args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
300
+ args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
301
+ )
302
+ local_prompts, masks = [], []
303
+ for local_prompt, canvas in zip(local_prompt_list, canvas_list):
304
+ if isinstance(local_prompt, str) and len(local_prompt) > 0:
305
+ local_prompts.append(local_prompt)
306
+ masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
307
+ entity_masks = None if len(masks) == 0 else masks
308
+ entity_prompts = None if len(local_prompts) == 0 else local_prompts
309
+ input_params.update({
310
+ "eligen_entity_prompts": entity_prompts,
311
+ "eligen_entity_masks": entity_masks,
312
+ })
313
+ torch.manual_seed(seed)
314
+ # save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
315
+ image = pipe(**input_params)
316
+ masks = [mask.resize(image.size) for mask in masks]
317
+ image_with_mask = visualize_masks(image, masks, local_prompts)
318
+
319
+ real_output = gr.State(image)
320
+ mask_out = gr.State(image_with_mask)
321
+
322
+ if return_with_mask:
323
+ return image_with_mask, real_output, mask_out
324
+ return image, real_output, mask_out
325
+
326
+ @gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
327
+ def send_input_to_painter_background(input_image, *canvas_list):
328
+ if input_image is None:
329
+ return tuple(canvas_list)
330
+ for canvas in canvas_list:
331
+ h, w = canvas["background"].shape[:2]
332
+ canvas["background"] = input_image.resize((w, h))
333
+ return tuple(canvas_list)
334
+ @gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
335
+ def send_output_to_painter_background(real_output, *canvas_list):
336
+ if real_output is None:
337
+ return tuple(canvas_list)
338
+ for canvas in canvas_list:
339
+ h, w = canvas["background"].shape[:2]
340
+ canvas["background"] = real_output.value.resize((w, h))
341
+ return tuple(canvas_list)
342
+ @gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
343
+ def show_output(return_with_mask, real_output, mask_out):
344
+ if return_with_mask:
345
+ return mask_out.value
346
+ else:
347
+ return real_output.value
348
+ @gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
349
+ def send_output_to_pipe_input(real_output):
350
+ return real_output.value
351
+
352
+ with gr.Column():
353
+ gr.Markdown("## Examples")
354
+ for i in range(0, len(examples), 2):
355
+ with gr.Row():
356
+ if i < len(examples):
357
+ example = examples[i]
358
+ with gr.Column():
359
+ example_image = gr.Image(
360
+ value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
361
+ label=example["description"],
362
+ interactive=False,
363
+ width=1024,
364
+ height=512
365
+ )
366
+ load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
367
+ load_example_button.click(
368
+ load_example,
369
+ inputs=[load_example_button],
370
+ outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
371
+ )
372
+
373
+ if i + 1 < len(examples):
374
+ example = examples[i + 1]
375
+ with gr.Column():
376
+ example_image = gr.Image(
377
+ value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
378
+ label=example["description"],
379
+ interactive=False,
380
+ width=1024,
381
+ height=512
382
+ )
383
+ load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
384
+ load_example_button.click(
385
+ load_example,
386
+ inputs=[load_example_button],
387
+ outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
388
+ )
389
+ app.config["show_progress"] = "hidden"
390
+ app.launch()
apps/gradio/qwen_image_eligen.py ADDED
@@ -0,0 +1,382 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import random
6
+ import json
7
+ import gradio as gr
8
+ from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
9
+ from modelscope import dataset_snapshot_download, snapshot_download
10
+
11
+ # pip install pydantic==2.10.6
12
+ # pip install gradio==5.4.0
13
+
14
+ snapshot_download("DiffSynth-Studio/Qwen-Image-EliGen", local_dir="models/DiffSynth-Studio/Qwen-Image-EliGen", allow_file_pattern="model.safetensors")
15
+
16
+ dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/qwen-image/*")
17
+ example_json = 'data/examples/eligen/qwen-image/ui_examples.json'
18
+ with open(example_json, 'r') as f:
19
+ examples = json.load(f)['examples']
20
+
21
+ for idx in range(len(examples)):
22
+ example_id = examples[idx]['example_id']
23
+ entity_prompts = examples[idx]['local_prompt_list']
24
+ examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/qwen-image/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
25
+
26
+ def create_canvas_data(background, masks):
27
+ if background.shape[-1] == 3:
28
+ background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
29
+ layers = []
30
+ for mask in masks:
31
+ if mask is not None:
32
+ mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
33
+ layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
34
+ layer[..., -1] = mask_single_channel
35
+ layers.append(layer)
36
+ else:
37
+ layers.append(np.zeros_like(background))
38
+
39
+ composite = background.copy()
40
+ for layer in layers:
41
+ if layer.size > 0:
42
+ composite = np.where(layer[..., -1:] > 0, layer, composite)
43
+ return {
44
+ "background": background,
45
+ "layers": layers,
46
+ "composite": composite,
47
+ }
48
+
49
+ def load_example(load_example_button):
50
+ example_idx = int(load_example_button.split()[-1]) - 1
51
+ example = examples[example_idx]
52
+ result = [
53
+ 50,
54
+ example["global_prompt"],
55
+ example["negative_prompt"],
56
+ example["seed"],
57
+ *example["local_prompt_list"],
58
+ ]
59
+ num_entities = len(example["local_prompt_list"])
60
+ result += [""] * (config["max_num_painter_layers"] - num_entities)
61
+ masks = []
62
+ for mask in example["mask_lists"]:
63
+ mask_single_channel = np.array(mask.convert("L"))
64
+ masks.append(mask_single_channel)
65
+ for _ in range(config["max_num_painter_layers"] - len(masks)):
66
+ blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
67
+ masks.append(blank_mask)
68
+ background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
69
+ canvas_data_list = []
70
+ for mask in masks:
71
+ canvas_data = create_canvas_data(background, [mask])
72
+ canvas_data_list.append(canvas_data)
73
+ result.extend(canvas_data_list)
74
+ return result
75
+
76
+ def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
77
+ save_dir = os.path.join('workdirs/tmp_mask', random_dir)
78
+ print(f'save to {save_dir}')
79
+ os.makedirs(save_dir, exist_ok=True)
80
+ for i, mask in enumerate(masks):
81
+ save_path = os.path.join(save_dir, f'{i}.png')
82
+ mask.save(save_path)
83
+ sample = {
84
+ "global_prompt": global_prompt,
85
+ "mask_prompts": mask_prompts,
86
+ "seed": seed,
87
+ }
88
+ with open(os.path.join(save_dir, f"prompts.json"), 'w', encoding='utf-8') as f:
89
+ json.dump(sample, f, ensure_ascii=False, indent=4)
90
+
91
+ def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
92
+ # Create a blank image for overlays
93
+ overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
94
+ colors = [
95
+ (165, 238, 173, 80),
96
+ (76, 102, 221, 80),
97
+ (221, 160, 77, 80),
98
+ (204, 93, 71, 80),
99
+ (145, 187, 149, 80),
100
+ (134, 141, 172, 80),
101
+ (157, 137, 109, 80),
102
+ (153, 104, 95, 80),
103
+ (165, 238, 173, 80),
104
+ (76, 102, 221, 80),
105
+ (221, 160, 77, 80),
106
+ (204, 93, 71, 80),
107
+ (145, 187, 149, 80),
108
+ (134, 141, 172, 80),
109
+ (157, 137, 109, 80),
110
+ (153, 104, 95, 80),
111
+ ]
112
+ # Generate random colors for each mask
113
+ if use_random_colors:
114
+ colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
115
+ # Font settings
116
+ try:
117
+ font = ImageFont.truetype("wqy-zenhei.ttc", font_size) # Adjust as needed
118
+ except IOError:
119
+ font = ImageFont.load_default(font_size)
120
+ # Overlay each mask onto the overlay image
121
+ for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
122
+ if mask is None:
123
+ continue
124
+ # Convert mask to RGBA mode
125
+ mask_rgba = mask.convert('RGBA')
126
+ mask_data = mask_rgba.getdata()
127
+ new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
128
+ mask_rgba.putdata(new_data)
129
+ # Draw the mask prompt text on the mask
130
+ draw = ImageDraw.Draw(mask_rgba)
131
+ mask_bbox = mask.getbbox() # Get the bounding box of the mask
132
+ if mask_bbox is None:
133
+ continue
134
+ text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
135
+ draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
136
+ # Alpha composite the overlay with this mask
137
+ overlay = Image.alpha_composite(overlay, mask_rgba)
138
+ # Composite the overlay onto the original image
139
+ result = Image.alpha_composite(image.convert('RGBA'), overlay)
140
+ return result
141
+
142
+ config = {
143
+ "max_num_painter_layers": 8,
144
+ "max_num_model_cache": 1,
145
+ }
146
+
147
+ model_dict = {}
148
+
149
+ def load_model(model_type='qwen-image'):
150
+ global model_dict
151
+ model_key = f"{model_type}"
152
+ if model_key in model_dict:
153
+ return model_dict[model_key]
154
+ pipe = QwenImagePipeline.from_pretrained(
155
+ torch_dtype=torch.bfloat16,
156
+ device="cuda",
157
+ model_configs=[
158
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
159
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
160
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
161
+ ],
162
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
163
+ )
164
+ pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors")
165
+ model_dict[model_key] = pipe
166
+ return pipe
167
+
168
+ load_model('qwen-image')
169
+
170
+ with gr.Blocks() as app:
171
+ gr.Markdown(
172
+ """## EliGen: Entity-Level Controllable Text-to-Image Model
173
+ 1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
174
+ 2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
175
+ 3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
176
+ 4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
177
+ """
178
+ )
179
+
180
+ loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
181
+ main_interface = gr.Column(visible=False)
182
+
183
+ def initialize_model():
184
+ try:
185
+ load_model('qwen-image')
186
+ return {
187
+ loading_status: gr.update(value="Model loaded successfully!", visible=False),
188
+ main_interface: gr.update(visible=True),
189
+ }
190
+ except Exception as e:
191
+ print(f'Failed to load model with error: {e}')
192
+ return {
193
+ loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
194
+ main_interface: gr.update(visible=True),
195
+ }
196
+
197
+ app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
198
+
199
+ with main_interface:
200
+ with gr.Row():
201
+ local_prompt_list = []
202
+ canvas_list = []
203
+ random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
204
+ with gr.Column(scale=382, min_width=100):
205
+ model_type = gr.State('qwen-image')
206
+ with gr.Accordion(label="Global prompt"):
207
+ prompt = gr.Textbox(label="Global Prompt", lines=3)
208
+ negative_prompt = gr.Textbox(label="Negative prompt", value="", lines=3)
209
+ with gr.Accordion(label="Inference Options", open=True):
210
+ seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
211
+ num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
212
+ cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=4.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
213
+ height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
214
+ width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
215
+ with gr.Accordion(label="Inpaint Input Image", open=False, visible=False):
216
+ input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
217
+ background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
218
+
219
+ with gr.Column():
220
+ reset_input_button = gr.Button(value="Reset Inpaint Input")
221
+ send_input_to_painter = gr.Button(value="Set as painter's background")
222
+ @gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
223
+ def reset_input_image(input_image):
224
+ return None
225
+
226
+ with gr.Column(scale=618, min_width=100):
227
+ with gr.Accordion(label="Entity Painter"):
228
+ for painter_layer_id in range(config["max_num_painter_layers"]):
229
+ with gr.Tab(label=f"Entity {painter_layer_id}"):
230
+ local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
231
+ canvas = gr.ImageEditor(
232
+ canvas_size=(1024, 1024),
233
+ sources=None,
234
+ layers=False,
235
+ interactive=True,
236
+ image_mode="RGBA",
237
+ brush=gr.Brush(
238
+ default_size=50,
239
+ default_color="#000000",
240
+ colors=["#000000"],
241
+ ),
242
+ label="Entity Mask Painter",
243
+ key=f"canvas_{painter_layer_id}",
244
+ width=width,
245
+ height=height,
246
+ )
247
+ @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
248
+ def resize_canvas(height, width, canvas):
249
+ if canvas is None or canvas["background"] is None:
250
+ return np.ones((height, width, 3), dtype=np.uint8) * 255
251
+ h, w = canvas["background"].shape[:2]
252
+ if h != height or width != w:
253
+ return np.ones((height, width, 3), dtype=np.uint8) * 255
254
+ else:
255
+ return canvas
256
+ local_prompt_list.append(local_prompt)
257
+ canvas_list.append(canvas)
258
+ with gr.Accordion(label="Results"):
259
+ run_button = gr.Button(value="Generate", variant="primary")
260
+ output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
261
+ with gr.Row():
262
+ with gr.Column():
263
+ output_to_painter_button = gr.Button(value="Set as painter's background")
264
+ with gr.Column():
265
+ return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
266
+ output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
267
+ real_output = gr.State(None)
268
+ mask_out = gr.State(None)
269
+
270
+ @gr.on(
271
+ inputs=[model_type, prompt, negative_prompt, cfg_scale, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
272
+ outputs=[output_image, real_output, mask_out],
273
+ triggers=run_button.click
274
+ )
275
+ def generate_image(model_type, prompt, negative_prompt, cfg_scale, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
276
+ pipe = load_model(model_type)
277
+ input_params = {
278
+ "prompt": prompt,
279
+ "negative_prompt": negative_prompt,
280
+ "cfg_scale": cfg_scale,
281
+ "num_inference_steps": num_inference_steps,
282
+ "height": height,
283
+ "width": width,
284
+ "progress_bar_cmd": progress.tqdm,
285
+ }
286
+ # if input_image is not None:
287
+ # input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
288
+ # input_params["enable_eligen_inpaint"] = True
289
+
290
+ local_prompt_list, canvas_list = (
291
+ args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
292
+ args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
293
+ )
294
+ local_prompts, masks = [], []
295
+ for local_prompt, canvas in zip(local_prompt_list, canvas_list):
296
+ if isinstance(local_prompt, str) and len(local_prompt) > 0:
297
+ local_prompts.append(local_prompt)
298
+ masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
299
+ entity_prompts = None if len(local_prompts) == 0 else local_prompts
300
+ entity_masks = None if len(masks) == 0 or entity_prompts is None else masks
301
+ input_params.update({
302
+ "eligen_entity_prompts": entity_prompts,
303
+ "eligen_entity_masks": entity_masks,
304
+ })
305
+ torch.manual_seed(seed)
306
+ save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
307
+ image = pipe(**input_params)
308
+ masks = [mask.resize(image.size) for mask in masks]
309
+ image_with_mask = visualize_masks(image, masks, local_prompts)
310
+
311
+ real_output = gr.State(image)
312
+ mask_out = gr.State(image_with_mask)
313
+
314
+ if return_with_mask:
315
+ return image_with_mask, real_output, mask_out
316
+ return image, real_output, mask_out
317
+
318
+ @gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
319
+ def send_input_to_painter_background(input_image, *canvas_list):
320
+ if input_image is None:
321
+ return tuple(canvas_list)
322
+ for canvas in canvas_list:
323
+ h, w = canvas["background"].shape[:2]
324
+ canvas["background"] = input_image.resize((w, h))
325
+ return tuple(canvas_list)
326
+ @gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
327
+ def send_output_to_painter_background(real_output, *canvas_list):
328
+ if real_output is None:
329
+ return tuple(canvas_list)
330
+ for canvas in canvas_list:
331
+ h, w = canvas["background"].shape[:2]
332
+ canvas["background"] = real_output.value.resize((w, h))
333
+ return tuple(canvas_list)
334
+ @gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
335
+ def show_output(return_with_mask, real_output, mask_out):
336
+ if return_with_mask:
337
+ return mask_out.value
338
+ else:
339
+ return real_output.value
340
+ @gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
341
+ def send_output_to_pipe_input(real_output):
342
+ return real_output.value
343
+
344
+ with gr.Column():
345
+ gr.Markdown("## Examples")
346
+ for i in range(0, len(examples), 2):
347
+ with gr.Row():
348
+ if i < len(examples):
349
+ example = examples[i]
350
+ with gr.Column():
351
+ example_image = gr.Image(
352
+ value=f"data/examples/eligen/qwen-image/example_{example['example_id']}/example_image.png",
353
+ label=example["description"],
354
+ interactive=False,
355
+ width=1024,
356
+ height=512
357
+ )
358
+ load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
359
+ load_example_button.click(
360
+ load_example,
361
+ inputs=[load_example_button],
362
+ outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
363
+ )
364
+
365
+ if i + 1 < len(examples):
366
+ example = examples[i + 1]
367
+ with gr.Column():
368
+ example_image = gr.Image(
369
+ value=f"data/examples/eligen/qwen-image/example_{example['example_id']}/example_image.png",
370
+ label=example["description"],
371
+ interactive=False,
372
+ width=1024,
373
+ height=512
374
+ )
375
+ load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
376
+ load_example_button.click(
377
+ load_example,
378
+ inputs=[load_example_button],
379
+ outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
380
+ )
381
+ app.config["show_progress"] = "hidden"
382
+ app.launch(share=False)
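Note: the Gradio app above is a thin wrapper around `QwenImagePipeline` plus the EliGen LoRA. A minimal script-level sketch of the same call path is shown below; the model IDs, the `load_lora` call, and the `eligen_entity_prompts` / `eligen_entity_masks` arguments mirror the app code above, while the prompt strings and the mask file path are illustrative placeholders only.

import torch
from PIL import Image
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig

# Build the pipeline the same way the app does.
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/DiffSynth-Studio/Qwen-Image-EliGen/model.safetensors")

# One global prompt plus one entity prompt with its mask (placeholder inputs).
mask = Image.open("entity_mask.png").convert("RGB")
torch.manual_seed(42)
image = pipe(
    prompt="a person stands by the river",
    negative_prompt="",
    cfg_scale=4.0,
    num_inference_steps=30,
    height=1024,
    width=1024,
    eligen_entity_prompts=["person"],
    eligen_entity_masks=[mask],
)
image.save("eligen_output.png")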
apps/streamlit/DiffSynth_Studio.py ADDED
@@ -0,0 +1,15 @@
1
+ # Set web page format
2
+ import streamlit as st
3
+ st.set_page_config(layout="wide")
4
+ # Disable virtual VRAM on Windows systems
5
+ import torch
6
+ torch.cuda.set_per_process_memory_fraction(0.999, 0)
7
+
8
+
9
+ st.markdown("""
10
+ # DiffSynth Studio
11
+
12
+ [Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
13
+
14
+ Welcome to DiffSynth Studio.
15
+ """)
apps/streamlit/pages/1_Image_Creator.py ADDED
@@ -0,0 +1,362 @@
1
+ import torch, os, io, json, time
2
+ import numpy as np
3
+ from PIL import Image
4
+ import streamlit as st
5
+ st.set_page_config(layout="wide")
6
+ from streamlit_drawable_canvas import st_canvas
7
+ from diffsynth.models import ModelManager
8
+ from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
9
+ from diffsynth.data.video import crop_and_resize
10
+
11
+
12
+ config = {
13
+ "Stable Diffusion": {
14
+ "model_folder": "models/stable_diffusion",
15
+ "pipeline_class": SDImagePipeline,
16
+ "fixed_parameters": {}
17
+ },
18
+ "Stable Diffusion XL": {
19
+ "model_folder": "models/stable_diffusion_xl",
20
+ "pipeline_class": SDXLImagePipeline,
21
+ "fixed_parameters": {}
22
+ },
23
+ "Stable Diffusion 3": {
24
+ "model_folder": "models/stable_diffusion_3",
25
+ "pipeline_class": SD3ImagePipeline,
26
+ "fixed_parameters": {}
27
+ },
28
+ "Stable Diffusion XL Turbo": {
29
+ "model_folder": "models/stable_diffusion_xl_turbo",
30
+ "pipeline_class": SDXLImagePipeline,
31
+ "fixed_parameters": {
32
+ "negative_prompt": "",
33
+ "cfg_scale": 1.0,
34
+ "num_inference_steps": 1,
35
+ "height": 512,
36
+ "width": 512,
37
+ }
38
+ },
39
+ "Kolors": {
40
+ "model_folder": "models/kolors",
41
+ "pipeline_class": SDXLImagePipeline,
42
+ "fixed_parameters": {}
43
+ },
44
+ "HunyuanDiT": {
45
+ "model_folder": "models/HunyuanDiT",
46
+ "pipeline_class": HunyuanDiTImagePipeline,
47
+ "fixed_parameters": {
48
+ "height": 1024,
49
+ "width": 1024,
50
+ }
51
+ },
52
+ "FLUX": {
53
+ "model_folder": "models/FLUX",
54
+ "pipeline_class": FluxImagePipeline,
55
+ "fixed_parameters": {
56
+ "cfg_scale": 1.0,
57
+ }
58
+ }
59
+ }
60
+
61
+
62
+ def load_model_list(model_type):
63
+ folder = config[model_type]["model_folder"]
64
+ file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
65
+ if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
66
+ file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
67
+ file_list = sorted(file_list)
68
+ return file_list
69
+
70
+
71
+ def release_model():
72
+ if "model_manager" in st.session_state:
73
+ st.session_state["model_manager"].to("cpu")
74
+ del st.session_state["loaded_model_path"]
75
+ del st.session_state["model_manager"]
76
+ del st.session_state["pipeline"]
77
+ torch.cuda.empty_cache()
78
+
79
+
80
+ def load_model(model_type, model_path):
81
+ model_manager = ModelManager()
82
+ if model_type == "HunyuanDiT":
83
+ model_manager.load_models([
84
+ os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
85
+ os.path.join(model_path, "mt5/pytorch_model.bin"),
86
+ os.path.join(model_path, "model/pytorch_model_ema.pt"),
87
+ os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
88
+ ])
89
+ elif model_type == "Kolors":
90
+ model_manager.load_models([
91
+ os.path.join(model_path, "text_encoder"),
92
+ os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
93
+ os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
94
+ ])
95
+ elif model_type == "FLUX":
96
+ model_manager.torch_dtype = torch.bfloat16
97
+ file_list = [
98
+ os.path.join(model_path, "text_encoder/model.safetensors"),
99
+ os.path.join(model_path, "text_encoder_2"),
100
+ ]
101
+ for file_name in os.listdir(model_path):
102
+ if file_name.endswith(".safetensors"):
103
+ file_list.append(os.path.join(model_path, file_name))
104
+ model_manager.load_models(file_list)
105
+ else:
106
+ model_manager.load_model(model_path)
107
+ pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
108
+ st.session_state.loaded_model_path = model_path
109
+ st.session_state.model_manager = model_manager
110
+ st.session_state.pipeline = pipeline
111
+ return model_manager, pipeline
112
+
113
+
114
+ def use_output_image_as_input(update=True):
115
+ # Search for input image
116
+ output_image_id = 0
117
+ selected_output_image = None
118
+ while True:
119
+ if f"use_output_as_input_{output_image_id}" not in st.session_state:
120
+ break
121
+ if st.session_state[f"use_output_as_input_{output_image_id}"]:
122
+ selected_output_image = st.session_state["output_images"][output_image_id]
123
+ break
124
+ output_image_id += 1
125
+ if update and selected_output_image is not None:
126
+ st.session_state["input_image"] = selected_output_image
127
+ return selected_output_image is not None
128
+
129
+
130
+ def apply_stroke_to_image(stroke_image, image):
131
+ image = np.array(image.convert("RGB")).astype(np.float32)
132
+ height, width, _ = image.shape
133
+
134
+ stroke_image = np.array(Image.fromarray(stroke_image).resize((width, height))).astype(np.float32)
135
+ weight = stroke_image[:, :, -1:] / 255
136
+ stroke_image = stroke_image[:, :, :-1]
137
+
138
+ image = stroke_image * weight + image * (1 - weight)
139
+ image = np.clip(image, 0, 255).astype(np.uint8)
140
+ image = Image.fromarray(image)
141
+ return image
142
+
143
+
144
+ @st.cache_data
145
+ def image2bits(image):
146
+ image_byte = io.BytesIO()
147
+ image.save(image_byte, format="PNG")
148
+ image_byte = image_byte.getvalue()
149
+ return image_byte
150
+
151
+
152
+ def show_output_image(image):
153
+ st.image(image, use_column_width="always")
154
+ st.button("Use it as input image", key=f"use_output_as_input_{image_id}")
155
+ st.download_button("Download", data=image2bits(image), file_name="image.png", mime="image/png", key=f"download_output_{image_id}")
156
+
157
+
158
+ column_input, column_output = st.columns(2)
159
+ with st.sidebar:
160
+ # Select a model
161
+ with st.expander("Model", expanded=True):
162
+ model_type = st.selectbox("Model type", [model_type_ for model_type_ in config])
163
+ fixed_parameters = config[model_type]["fixed_parameters"]
164
+ model_path_list = ["None"] + load_model_list(model_type)
165
+ model_path = st.selectbox("Model path", model_path_list)
166
+
167
+ # Load the model
168
+ if model_path == "None":
169
+ # No models are selected. Release VRAM.
170
+ st.markdown("No models are selected.")
171
+ release_model()
172
+ else:
173
+ # A model is selected.
174
+ model_path = os.path.join(config[model_type]["model_folder"], model_path)
175
+ if st.session_state.get("loaded_model_path", "") != model_path:
176
+ # The loaded model is not the selected model. Reload it.
177
+ st.markdown(f"Loading model at {model_path}.")
178
+ st.markdown("Please wait a moment...")
179
+ release_model()
180
+ model_manager, pipeline = load_model(model_type, model_path)
181
+ st.markdown("Done.")
182
+ else:
183
+ # The loaded model is the selected model. Fetch it from `st.session_state`.
184
+ st.markdown(f"Loading model at {model_path}.")
185
+ st.markdown("Please wait a moment...")
186
+ model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
187
+ st.markdown("Done.")
188
+
189
+ # Show parameters
190
+ with st.expander("Prompt", expanded=True):
191
+ prompt = st.text_area("Positive prompt")
192
+ if "negative_prompt" in fixed_parameters:
193
+ negative_prompt = fixed_parameters["negative_prompt"]
194
+ else:
195
+ negative_prompt = st.text_area("Negative prompt")
196
+ if "cfg_scale" in fixed_parameters:
197
+ cfg_scale = fixed_parameters["cfg_scale"]
198
+ else:
199
+ cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.5)
200
+ with st.expander("Image", expanded=True):
201
+ if "num_inference_steps" in fixed_parameters:
202
+ num_inference_steps = fixed_parameters["num_inference_steps"]
203
+ else:
204
+ num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20)
205
+ if "height" in fixed_parameters:
206
+ height = fixed_parameters["height"]
207
+ else:
208
+ height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
209
+ if "width" in fixed_parameters:
210
+ width = fixed_parameters["width"]
211
+ else:
212
+ width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
213
+ num_images = st.number_input("Number of images", value=2)
214
+ use_fixed_seed = st.checkbox("Use fixed seed", value=False)
215
+ if use_fixed_seed:
216
+ seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
217
+
218
+ # Other fixed parameters
219
+ denoising_strength = 1.0
220
+ repetition = 1
221
+
222
+
223
+ # Show input image
224
+ with column_input:
225
+ with st.expander("Input image (Optional)", expanded=True):
226
+ with st.container(border=True):
227
+ column_white_board, column_upload_image = st.columns([1, 2])
228
+ with column_white_board:
229
+ create_white_board = st.button("Create white board")
230
+ delete_input_image = st.button("Delete input image")
231
+ with column_upload_image:
232
+ upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image")
233
+
234
+ if upload_image is not None:
235
+ st.session_state["input_image"] = crop_and_resize(Image.open(upload_image), height, width)
236
+ elif create_white_board:
237
+ st.session_state["input_image"] = Image.fromarray(np.ones((height, width, 3), dtype=np.uint8) * 255)
238
+ else:
239
+ use_output_image_as_input()
240
+
241
+ if delete_input_image and "input_image" in st.session_state:
242
+ del st.session_state.input_image
243
+ if delete_input_image and "upload_image" in st.session_state:
244
+ del st.session_state.upload_image
245
+
246
+ input_image = st.session_state.get("input_image", None)
247
+ if input_image is not None:
248
+ with st.container(border=True):
249
+ column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1])
250
+ with column_drawing_mode:
251
+ drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1)
252
+ with column_color_1:
253
+ stroke_color = st.color_picker("Stroke color")
254
+ with column_color_2:
255
+ fill_color = st.color_picker("Fill color")
256
+ stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
257
+ with st.container(border=True):
258
+ denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
259
+ repetition = st.slider("Repetition", min_value=1, max_value=8, value=1)
260
+ with st.container(border=True):
261
+ input_width, input_height = input_image.size
262
+ canvas_result = st_canvas(
263
+ fill_color=fill_color,
264
+ stroke_width=stroke_width,
265
+ stroke_color=stroke_color,
266
+ background_color="rgba(255, 255, 255, 0)",
267
+ background_image=input_image,
268
+ update_streamlit=True,
269
+ height=int(512 / input_width * input_height),
270
+ width=512,
271
+ drawing_mode=drawing_mode,
272
+ key="canvas"
273
+ )
274
+
275
+ num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0)
276
+ local_prompts, masks, mask_scales = [], [], []
277
+ white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
278
+ painter_layers_json_data = []
279
+ for painter_tab_id in range(num_painter_layer):
280
+ with st.expander(f"Painter layer {painter_tab_id}", expanded=True):
281
+ enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True)
282
+ local_prompt = st.text_area(f"Prompt {painter_tab_id}")
283
+ mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0)
284
+ stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100)
285
+ canvas_result_local = st_canvas(
286
+ fill_color="#000000",
287
+ stroke_width=stroke_width,
288
+ stroke_color="#000000",
289
+ background_color="rgba(255, 255, 255, 0)",
290
+ background_image=white_board,
291
+ update_streamlit=True,
292
+ height=512,
293
+ width=512,
294
+ drawing_mode="freedraw",
295
+ key=f"canvas_{painter_tab_id}"
296
+ )
297
+ if canvas_result_local.json_data is not None:
298
+ painter_layers_json_data.append(canvas_result_local.json_data.copy())
299
+ painter_layers_json_data[-1]["prompt"] = local_prompt
300
+ if enable_local_prompt:
301
+ local_prompts.append(local_prompt)
302
+ if canvas_result_local.image_data is not None:
303
+ mask = apply_stroke_to_image(canvas_result_local.image_data, white_board)
304
+ else:
305
+ mask = white_board
306
+ mask = Image.fromarray(255 - np.array(mask))
307
+ masks.append(mask)
308
+ mask_scales.append(mask_scale)
309
+ save_painter_layers = st.button("Save painter layers")
310
+ if save_painter_layers:
311
+ os.makedirs("data/painter_layers", exist_ok=True)
312
+ json_file_path = f"data/painter_layers/{time.time_ns()}.json"
313
+ with open(json_file_path, "w") as f:
314
+ json.dump(painter_layers_json_data, f, indent=4)
315
+ st.markdown(f"Painter layers are saved in {json_file_path}.")
316
+
317
+
318
+ with column_output:
319
+ run_button = st.button("Generate image", type="primary")
320
+ auto_update = st.checkbox("Auto update", value=False)
321
+ num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2)
322
+ image_columns = st.columns(num_image_columns)
323
+
324
+ # Run
325
+ if (run_button or auto_update) and model_path != "None":
326
+
327
+ if input_image is not None:
328
+ input_image = input_image.resize((width, height))
329
+ if canvas_result.image_data is not None:
330
+ input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
331
+
332
+ output_images = []
333
+ for image_id in range(num_images * repetition):
334
+ if use_fixed_seed:
335
+ torch.manual_seed(seed + image_id)
336
+ else:
337
+ torch.manual_seed(np.random.randint(0, 10**9))
338
+ if image_id >= num_images:
339
+ input_image = output_images[image_id - num_images]
340
+ with image_columns[image_id % num_image_columns]:
341
+ progress_bar_st = st.progress(0.0)
342
+ image = pipeline(
343
+ prompt, negative_prompt=negative_prompt,
344
+ local_prompts=local_prompts, masks=masks, mask_scales=mask_scales,
345
+ cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
346
+ height=height, width=width,
347
+ input_image=input_image, denoising_strength=denoising_strength,
348
+ progress_bar_st=progress_bar_st
349
+ )
350
+ output_images.append(image)
351
+ progress_bar_st.progress(1.0)
352
+ show_output_image(image)
353
+ st.session_state["output_images"] = output_images
354
+
355
+ elif "output_images" in st.session_state:
356
+ for image_id in range(len(st.session_state.output_images)):
357
+ with image_columns[image_id % num_image_columns]:
358
+ image = st.session_state.output_images[image_id]
359
+ progress_bar = st.progress(1.0)
360
+ show_output_image(image)
361
+ if "upload_image" in st.session_state and use_output_image_as_input(update=False):
362
+ st.markdown("If you want to use an output image as input image, please delete the uploaded image manually.")
apps/streamlit/pages/2_Video_Creator.py ADDED
@@ -0,0 +1,197 @@
1
+ import streamlit as st
2
+ st.set_page_config(layout="wide")
3
+ from diffsynth import SDVideoPipelineRunner
4
+ import os
5
+ import numpy as np
6
+
7
+
8
+ def load_model_list(folder):
9
+ file_list = os.listdir(folder)
10
+ file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
11
+ file_list = sorted(file_list)
12
+ return file_list
13
+
14
+
15
+ def match_processor_id(model_name, supported_processor_id_list):
16
+ sorted_processor_id = [i[1] for i in sorted([(-len(i), i) for i in supported_processor_id_list])]
17
+ for processor_id in sorted_processor_id:
18
+ if processor_id in model_name:
19
+ return supported_processor_id_list.index(processor_id) + 1
20
+ return 0
21
+
22
+
23
+ config = {
24
+ "models": {
25
+ "model_list": [],
26
+ "textual_inversion_folder": "models/textual_inversion",
27
+ "device": "cuda",
28
+ "lora_alphas": [],
29
+ "controlnet_units": []
30
+ },
31
+ "data": {
32
+ "input_frames": None,
33
+ "controlnet_frames": [],
34
+ "output_folder": "output",
35
+ "fps": 60
36
+ },
37
+ "pipeline": {
38
+ "seed": 0,
39
+ "pipeline_inputs": {}
40
+ }
41
+ }
42
+
43
+
44
+ with st.expander("Model", expanded=True):
45
+ stable_diffusion_ckpt = st.selectbox("Stable Diffusion", ["None"] + load_model_list("models/stable_diffusion"))
46
+ if stable_diffusion_ckpt != "None":
47
+ config["models"]["model_list"].append(os.path.join("models/stable_diffusion", stable_diffusion_ckpt))
48
+ animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
49
+ if animatediff_ckpt != "None":
50
+ config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
51
+ column_lora, column_lora_alpha = st.columns([2, 1])
52
+ with column_lora:
53
+ sd_lora_ckpt = st.selectbox("LoRA", ["None"] + load_model_list("models/lora"))
54
+ with column_lora_alpha:
55
+ lora_alpha = st.slider("LoRA Alpha", min_value=-4.0, max_value=4.0, value=1.0, step=0.1)
56
+ if sd_lora_ckpt != "None":
57
+ config["models"]["model_list"].append(os.path.join("models/lora", sd_lora_ckpt))
58
+ config["models"]["lora_alphas"].append(lora_alpha)
59
+
60
+
61
+ with st.expander("Data", expanded=True):
62
+ with st.container(border=True):
63
+ input_video = st.text_input("Input Video File Path (e.g., data/your_video.mp4)", value="")
64
+ column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
65
+ with column_height:
66
+ height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
67
+ with column_width:
68
+ width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
69
+ with column_start_frame_index:
70
+ start_frame_id = st.number_input("Start Frame id", value=0)
71
+ with column_end_frame_index:
72
+ end_frame_id = st.number_input("End Frame id", value=16)
73
+ if input_video != "":
74
+ config["data"]["input_frames"] = {
75
+ "video_file": input_video,
76
+ "image_folder": None,
77
+ "height": height,
78
+ "width": width,
79
+ "start_frame_id": start_frame_id,
80
+ "end_frame_id": end_frame_id
81
+ }
82
+ with st.container(border=True):
83
+ output_video = st.text_input("Output Video File Path (e.g., data/a_folder_to_save_something)", value="output")
84
+ fps = st.number_input("FPS", value=60)
85
+ config["data"]["output_folder"] = output_video
86
+ config["data"]["fps"] = fps
87
+
88
+
89
+ with st.expander("ControlNet Units", expanded=True):
90
+ supported_processor_id_list = ["canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"]
91
+ controlnet_units = st.tabs(["ControlNet Unit 0", "ControlNet Unit 1", "ControlNet Unit 2"])
92
+ for controlnet_id in range(len(controlnet_units)):
93
+ with controlnet_units[controlnet_id]:
94
+ controlnet_ckpt = st.selectbox("ControlNet", ["None"] + load_model_list("models/ControlNet"),
95
+ key=f"controlnet_ckpt_{controlnet_id}")
96
+ processor_id = st.selectbox("Processor", ["None"] + supported_processor_id_list,
97
+ index=match_processor_id(controlnet_ckpt, supported_processor_id_list),
98
+ disabled=controlnet_ckpt == "None", key=f"processor_id_{controlnet_id}")
99
+ controlnet_scale = st.slider("Scale", min_value=0.0, max_value=1.0, step=0.01, value=0.5,
100
+ disabled=controlnet_ckpt == "None", key=f"controlnet_scale_{controlnet_id}")
101
+ use_input_video_as_controlnet_input = st.checkbox("Use input video as ControlNet input", value=True,
102
+ disabled=controlnet_ckpt == "None",
103
+ key=f"use_input_video_as_controlnet_input_{controlnet_id}")
104
+ if not use_input_video_as_controlnet_input:
105
+ controlnet_input_video = st.text_input("ControlNet Input Video File Path", value="",
106
+ disabled=controlnet_ckpt == "None", key=f"controlnet_input_video_{controlnet_id}")
107
+ column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
108
+ with column_height:
109
+ height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
110
+ disabled=controlnet_ckpt == "None", key=f"controlnet_height_{controlnet_id}")
111
+ with column_width:
112
+ width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
113
+ disabled=controlnet_ckpt == "None", key=f"controlnet_width_{controlnet_id}")
114
+ with column_start_frame_index:
115
+ start_frame_id = st.number_input("Start Frame id", value=0,
116
+ disabled=controlnet_ckpt == "None", key=f"controlnet_start_frame_id_{controlnet_id}")
117
+ with column_end_frame_index:
118
+ end_frame_id = st.number_input("End Frame id", value=16,
119
+ disabled=controlnet_ckpt == "None", key=f"controlnet_end_frame_id_{controlnet_id}")
120
+ if input_video != "":
121
+ config["data"]["input_video"] = {
122
+ "video_file": input_video,
123
+ "image_folder": None,
124
+ "height": height,
125
+ "width": width,
126
+ "start_frame_id": start_frame_id,
127
+ "end_frame_id": end_frame_id
128
+ }
129
+ if controlnet_ckpt != "None":
130
+ config["models"]["model_list"].append(os.path.join("models/ControlNet", controlnet_ckpt))
131
+ config["models"]["controlnet_units"].append({
132
+ "processor_id": processor_id,
133
+ "model_path": os.path.join("models/ControlNet", controlnet_ckpt),
134
+ "scale": controlnet_scale,
135
+ })
136
+ if use_input_video_as_controlnet_input:
137
+ config["data"]["controlnet_frames"].append(config["data"]["input_frames"])
138
+ else:
139
+ config["data"]["controlnet_frames"].append({
140
+ "video_file": controlnet_input_video,
141
+ "image_folder": None,
142
+ "height": height,
143
+ "width": width,
144
+ "start_frame_id": start_frame_id,
145
+ "end_frame_id": end_frame_id
146
+ })
147
+
148
+
149
+ with st.container(border=True):
150
+ with st.expander("Seed", expanded=True):
151
+ use_fixed_seed = st.checkbox("Use fixed seed", value=False)
152
+ if use_fixed_seed:
153
+ seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
154
+ else:
155
+ seed = np.random.randint(0, 10**9)
156
+ with st.expander("Textual Guidance", expanded=True):
157
+ prompt = st.text_area("Positive prompt")
158
+ negative_prompt = st.text_area("Negative prompt")
159
+ column_cfg_scale, column_clip_skip = st.columns(2)
160
+ with column_cfg_scale:
161
+ cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.0)
162
+ with column_clip_skip:
163
+ clip_skip = st.slider("Clip Skip", min_value=1, max_value=4, value=1)
164
+ with st.expander("Denoising", expanded=True):
165
+ column_num_inference_steps, column_denoising_strength = st.columns(2)
166
+ with column_num_inference_steps:
167
+ num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=10)
168
+ with column_denoising_strength:
169
+ denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=1.0)
170
+ with st.expander("Efficiency", expanded=False):
171
+ animatediff_batch_size = st.slider("Animatediff batch size (sliding window size)", min_value=1, max_value=32, value=16, step=1)
172
+ animatediff_stride = st.slider("Animatediff stride",
173
+ min_value=1,
174
+ max_value=max(2, animatediff_batch_size),
175
+ value=max(1, animatediff_batch_size // 2),
176
+ step=1)
177
+ unet_batch_size = st.slider("UNet batch size", min_value=1, max_value=32, value=1, step=1)
178
+ controlnet_batch_size = st.slider("ControlNet batch size", min_value=1, max_value=32, value=1, step=1)
179
+ cross_frame_attention = st.checkbox("Enable Cross-Frame Attention", value=False)
180
+ config["pipeline"]["seed"] = seed
181
+ config["pipeline"]["pipeline_inputs"] = {
182
+ "prompt": prompt,
183
+ "negative_prompt": negative_prompt,
184
+ "cfg_scale": cfg_scale,
185
+ "clip_skip": clip_skip,
186
+ "denoising_strength": denoising_strength,
187
+ "num_inference_steps": num_inference_steps,
188
+ "animatediff_batch_size": animatediff_batch_size,
189
+ "animatediff_stride": animatediff_stride,
190
+ "unet_batch_size": unet_batch_size,
191
+ "controlnet_batch_size": controlnet_batch_size,
192
+ "cross_frame_attention": cross_frame_attention,
193
+ }
194
+
195
+ run_button = st.button("☢️Run☢️", type="primary")
196
+ if run_button:
197
+ SDVideoPipelineRunner(in_streamlit=True).run(config)
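For reference, the page above only assembles the `config` dictionary and hands it to `SDVideoPipelineRunner`; a minimal sketch of the same structure built without the UI is shown below. The keys mirror the ones filled in above, the file paths and prompt are placeholders, and `in_streamlit=False` is assumed to run the same pipeline outside Streamlit.

from diffsynth import SDVideoPipelineRunner

config = {
    "models": {
        "model_list": [
            "models/stable_diffusion/your_checkpoint.safetensors",  # placeholder
            "models/AnimateDiff/your_motion_module.ckpt",           # placeholder
        ],
        "textual_inversion_folder": "models/textual_inversion",
        "device": "cuda",
        "lora_alphas": [],
        "controlnet_units": [],
    },
    "data": {
        "input_frames": {
            "video_file": "data/your_video.mp4",  # placeholder
            "image_folder": None,
            "height": 1024, "width": 1024,
            "start_frame_id": 0, "end_frame_id": 16,
        },
        "controlnet_frames": [],
        "output_folder": "output",
        "fps": 60,
    },
    "pipeline": {
        "seed": 0,
        "pipeline_inputs": {
            "prompt": "best quality",  # placeholder
            "negative_prompt": "",
            "cfg_scale": 7.0,
            "clip_skip": 1,
            "denoising_strength": 1.0,
            "num_inference_steps": 10,
            "animatediff_batch_size": 16,
            "animatediff_stride": 8,
            "unet_batch_size": 1,
            "controlnet_batch_size": 1,
            "cross_frame_attention": False,
        },
    },
}

SDVideoPipelineRunner(in_streamlit=False).run(config)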
deal1.py ADDED
@@ -0,0 +1,82 @@
1
+ import cv2
2
+ import os
3
+ import sys
4
+ from datetime import timedelta
5
+
6
+ def extract_frames_per_second(video_path, output_dir, interval_seconds=3):
7
+ """
8
+ Extract one frame every `interval_seconds` seconds from a video and save it to the target directory (frames are located by timestamp, so timestamps increase monotonically).
9
+ :param video_path: path to the video file
10
+ :param output_dir: directory in which frame images are saved
11
+ :param interval_seconds: save one frame every this many seconds (default: 3)
12
+ """
13
+ os.makedirs(output_dir, exist_ok=True)
14
+ # 1. If the output directory already exists and is non-empty, skip it to avoid reprocessing
15
+ if os.path.exists(output_dir) and os.listdir(output_dir):
16
+ print(f"Output directory already exists and is not empty, skipping extraction: {os.path.abspath(output_dir)}")
17
+ return
18
+ os.makedirs(output_dir, exist_ok=True)
19
+ print(f"Frames will be saved to: {os.path.abspath(output_dir)}")
20
+
21
+ # 2. Open the video file
22
+ cap = cv2.VideoCapture(video_path)
23
+ if not cap.isOpened():
24
+ raise ValueError(f"Cannot open video file: {video_path}")
25
+
26
+ # 3. Read basic video properties
27
+ fps = cap.get(cv2.CAP_PROP_FPS) # video frame rate (frames per second)
28
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # total number of frames
29
+ duration = total_frames / fps if fps > 0 else 0 # total duration in seconds
30
+ print(f"Video info: frame rate={fps:.2f} FPS | total frames={total_frames} | duration={timedelta(seconds=duration)}")
31
+
32
+ if fps <= 0:
33
+ raise ValueError("Cannot read the video frame rate; the file may be corrupted or in an unsupported format")
34
+
35
+ saved_count = 0 # number of frames saved so far
36
+
37
+ try:
38
+ t = 0.0 # current position in the video, in seconds
39
+ # Read frames by seeking to a timestamp; this avoids frame-count rounding or skipped reads scrambling the timestamps
40
+ while t <= duration:
41
+ # Seek to the target position in milliseconds (a more reliable way to grab the frame at a given time)
42
+ cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
43
+ ret, frame = cap.read()
44
+ if not ret:
45
+ # Seeking or reading failed; stop the loop
46
+ break
47
+
48
+ # Use the timestamp as part of the file name so frames are saved in chronological order
49
+ frame_filename = f"{saved_count:06d}_{t:.2f}s.jpg"
50
+ frame_path = os.path.join(output_dir, frame_filename)
51
+
52
+ # Save the frame image
53
+ cv2.imwrite(frame_path, frame)
54
+ saved_count += 1
55
+
56
+ # Print progress every 10 saved frames to avoid flooding the console
57
+ if saved_count % 10 == 0:
58
+ progress = (t / duration) * 100 if duration > 0 else 0
59
+ print(f"Progress: {progress:.1f}% | saved {saved_count} frames | time: {t:.2f}s")
60
+
61
+ t += interval_seconds
62
+
63
+ except Exception as e:
64
+ raise RuntimeError(f"Error while extracting frames: {str(e)}")
65
+ finally:
66
+ # Release video resources
67
+ cap.release()
68
+ cv2.destroyAllWindows()
69
+
70
+ # Print the final result
71
+ print(f"\nExtraction finished! Saved {saved_count} frames to: {os.path.abspath(output_dir)}")
72
+
73
+ output_dirs = ["/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/no other choice","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/the roses","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/nouvelle","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/legs","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/frankenstein"]
74
+ input_dirs = ["/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/films/어쩔수가없다 NO OTHER CHOICE, 2025.1080p.WEB-DL.H264.AAC.mp4","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/films/The.Roses.2025.2160p.WEB-DL.DDP5.1.Atmos.SDR.H265-AOC/The.Roses.2025.2160p.WEB-DL.DDP5.1.Atmos.SDR.H265-AOC.mkv","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/films/NOUVELLE.VAGUE.2025.2160p.NF.WEB-DL.DDP.5.1.H.265-CHDWEB[PianYuan]/NOUVELLE.VAGUE.2025.2160p.NF.WEB-DL.DDP.5.1.H.265-CHDWEB.mkv","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/films/If.I.Had.Legs.Id.Kick.You.2025.1080p.iT.WEB-DL.DDP5.1.Atmos.H264-BTM/If.I.Had.Legs.Id.Kick.You.2025.1080p.iT.WEB-DL[Ben The Men].mkv","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/films/Frankenstein.2025.1080p.NF.WEB-DL.DDP5.1.Atmos.H.264-FLUX/Frankenstein.2025.1080p.NF.WEB-DL.DDP5.1.Atmos.H.264-FLUX.mkv"]
75
+ if __name__ == "__main__":
76
+ # Run frame extraction (save one frame every 3 seconds)
77
+ for i in range(len(output_dirs)):
78
+ try:
79
+ extract_frames_per_second(video_path=input_dirs[i],output_dir=output_dirs[i],interval_seconds=3)
80
+ except Exception as e:
81
+ print(f"Program failed: {str(e)}", file=sys.stderr)
82
+ sys.exit(1)
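The script above deliberately seeks by timestamp (`cv2.CAP_PROP_POS_MSEC`) instead of counting frames, which keeps the saved timestamps monotonic even when reads skip or fail. The core of that loop, reduced to a minimal sketch with a placeholder video path:

import cv2

cap = cv2.VideoCapture("input.mp4")  # placeholder path
fps = cap.get(cv2.CAP_PROP_FPS)
duration = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) / fps if fps > 0 else 0

t, saved = 0.0, 0
while t <= duration:
    cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)  # jump to the target time in milliseconds
    ret, frame = cap.read()
    if not ret:
        break
    cv2.imwrite(f"{saved:06d}_{t:.2f}s.jpg", frame)  # timestamp in the file name keeps frames ordered
    saved += 1
    t += 3.0  # one frame every 3 seconds
cap.release()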
deal2.py ADDED
@@ -0,0 +1,127 @@
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ from pathlib import Path
5
+ import imagehash
6
+ from PIL import Image
7
+
8
+ def calculate_phash(image_path):
9
+ """
10
+ Compute the perceptual hash (pHash) of an image.
11
+ :param image_path: path to the image
12
+ :return: perceptual hash (an imagehash.ImageHash object)
13
+ """
14
+ try:
15
+ # Read the image with PIL (broader format support) and convert it to grayscale
16
+ img = Image.open(image_path).convert("L")
17
+ # Compute the pHash; a smaller hash_size gives a shorter hash and faster computation (the default 8 produces a 64-bit hash)
18
+ phash = imagehash.phash(img, hash_size=8)
19
+ return phash
20
+ except Exception as e:
21
+ print(f"Failed to compute hash for {image_path}, error: {e}")
22
+ return None
23
+
24
+ def calculate_clarity(image_path):
25
+ """
26
+ Score image sharpness with the variance-of-Laplacian method.
27
+ :param image_path: path to the image
28
+ :return: sharpness score (the variance); returns 0 if the image cannot be read
29
+ """
30
+ img = cv2.imread(image_path)
31
+ if img is None:
32
+ return 0
33
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
34
+ laplacian = cv2.Laplacian(gray, cv2.CV_64F)
35
+ clarity_score = np.var(laplacian)
36
+ return clarity_score
37
+
38
+ def calculate_hamming_distance(hash1, hash2):
39
+ """
40
+ Compute the Hamming distance between two hashes.
41
+ :param hash1, hash2: perceptual hashes (imagehash.ImageHash objects)
42
+ :return: Hamming distance (smaller means more similar)
43
+ """
44
+ if hash1 is None or hash2 is None:
45
+ return float("inf") # hash computation failed, treat the pair as dissimilar
46
+ return hash1 - hash2
47
+
48
+ def process_duplicate_frames(input_dir, output_dir, similarity_threshold=10):
49
+ """
50
+ Deduplicate video frame captures: remove similar frames and keep the sharper image.
51
+ :param input_dir: folder containing the original images
52
+ :param output_dir: folder in which the deduplicated images are saved
53
+ :param similarity_threshold: similarity threshold (frames with Hamming distance <= this value are treated as similar)
54
+ """
55
+ # Create the output directory
56
+ Path(output_dir).mkdir(parents=True, exist_ok=True)
57
+
58
+ # 1. Collect all images in the folder, sorted by file name (to preserve the capture order)
59
+ img_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"]
60
+ img_paths = [
61
+ os.path.join(input_dir, f) for f in os.listdir(input_dir)
62
+ if Path(f).suffix.lower() in img_extensions
63
+ ]
64
+ # Sort by file name (important: keeps the frames in chronological order)
65
+ img_paths.sort(key=lambda x: os.path.basename(x))
66
+
67
+ if len(img_paths) == 0:
68
+ print("No images found in the folder!")
69
+ return
70
+
71
+ # 2. Initialization: keep the first image as the reference, then walk through the rest
72
+ saved_img_path = img_paths[0] # path of the currently kept reference image
73
+ saved_phash = calculate_phash(saved_img_path) # hash of the reference image
74
+ saved_clarity = calculate_clarity(saved_img_path) # sharpness of the reference image
75
+
76
+ # Save the first image
77
+ save_name = os.path.basename(saved_img_path)
78
+ cv2.imwrite(os.path.join(output_dir, save_name), cv2.imread(saved_img_path))
79
+ print(f"Initial save: {save_name}, sharpness: {saved_clarity:.2f}")
80
+
81
+ # 3. Walk through the remaining images, comparing each one to the reference
82
+ for current_img_path in img_paths[1:]:
83
+ current_name = os.path.basename(current_img_path)
84
+ current_phash = calculate_phash(current_img_path)
85
+ current_clarity = calculate_clarity(current_img_path)
86
+
87
+ # Compute the Hamming distance to the reference image
88
+ hamming_dist = calculate_hamming_distance(saved_phash, current_phash)
89
+ print(f"\nComparing: {saved_img_path.split('/')[-1]} vs {current_name}")
90
+ print(f"Hamming distance: {hamming_dist}, current image sharpness: {current_clarity:.2f}")
91
+
92
+ if hamming_dist <= similarity_threshold:
93
+ # Similar frames: keep the sharper image
94
+ if current_clarity > saved_clarity:
95
+ # The current image is sharper: delete the previous reference and save the current image as the new reference
96
+ os.remove(os.path.join(output_dir, os.path.basename(saved_img_path)))
97
+ cv2.imwrite(os.path.join(output_dir, current_name), cv2.imread(current_img_path))
98
+ print(f"Replaced: {current_name} is sharper and has replaced the previous reference image")
99
+ # Update the reference
100
+ saved_img_path = current_img_path
101
+ saved_phash = current_phash
102
+ saved_clarity = current_clarity
103
+ else:
104
+ # The current image is blurrier: skip it and keep the existing reference
105
+ print(f"Skipped: {current_name} is blurrier, keeping the current reference image")
106
+ else:
107
+ # Not a similar frame: save the current image and make it the new reference
108
+ cv2.imwrite(os.path.join(output_dir, current_name), cv2.imread(current_img_path))
109
+ print(f"Saved: {current_name} as the new reference (not similar to the previous reference)")
110
+ # Update the reference
111
+ saved_img_path = current_img_path
112
+ saved_phash = current_phash
113
+ saved_clarity = current_clarity
114
+
115
+ print(f"\n处理完成!去重后图片保存在:{output_dir}")
116
+ print(f"原始图片数量:{len(img_paths)},去重后数量:{len(os.listdir(output_dir))}")
117
+
118
+ # Entry point
119
+ if __name__ == "__main__":
120
+ input_dirs = ["/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/no other choice","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/the roses","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/nouvelle","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/legs","/fi-lib/workspace/sjx/DiffSynth-Studio/dataset/frankenstein"]
121
+ # Configuration (adjust as needed)
122
+ for i in range(len(input_dirs)):
123
+
124
+ SIMILARITY_THRESHOLD = 10 # similarity threshold (Hamming distance), adjustable
125
+
126
+ # Run the deduplication
127
+ process_duplicate_frames(input_dirs[i], f"{input_dirs[i]}_dedup", SIMILARITY_THRESHOLD)
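Note on the method above: the script combines a perceptual hash (pHash) to decide whether two frames are near-duplicates and the variance of the Laplacian to decide which of the two is sharper. A minimal standalone sketch of that comparison, assuming imagehash, opencv-python, numpy, and Pillow are installed; the frame file names below are hypothetical:

import cv2
import imagehash
import numpy as np
from PIL import Image

frame_a, frame_b = "frame_0001.jpg", "frame_0002.jpg"  # hypothetical frame files

# Perceptual hashes; subtracting two ImageHash objects yields their Hamming distance.
hash_a = imagehash.phash(Image.open(frame_a).convert("L"), hash_size=8)
hash_b = imagehash.phash(Image.open(frame_b).convert("L"), hash_size=8)
hamming = hash_a - hash_b

def sharpness(path):
    # Variance of the Laplacian: higher values indicate a sharper image.
    gray = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2GRAY)
    return float(np.var(cv2.Laplacian(gray, cv2.CV_64F)))

if hamming <= 10:  # same threshold as SIMILARITY_THRESHOLD above
    keep = frame_a if sharpness(frame_a) >= sharpness(frame_b) else frame_b
    print(f"Similar frames (distance {hamming}); keeping the sharper one: {keep}")
else:
    print(f"Distinct frames (distance {hamming}); keeping both")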
diffsynth.egg-info/PKG-INFO ADDED
@@ -0,0 +1,31 @@
1
+ Metadata-Version: 2.4
2
+ Name: diffsynth
3
+ Version: 1.1.9
4
+ Summary: Enjoy the magic of Diffusion models!
5
+ Author: Artiprocher
6
+ Classifier: Programming Language :: Python :: 3
7
+ Classifier: License :: OSI Approved :: Apache Software License
8
+ Classifier: Operating System :: OS Independent
9
+ Requires-Python: >=3.6
10
+ License-File: LICENSE
11
+ Requires-Dist: torch>=2.0.0
12
+ Requires-Dist: torchvision
13
+ Requires-Dist: transformers
14
+ Requires-Dist: imageio
15
+ Requires-Dist: imageio[ffmpeg]
16
+ Requires-Dist: safetensors
17
+ Requires-Dist: einops
18
+ Requires-Dist: sentencepiece
19
+ Requires-Dist: protobuf
20
+ Requires-Dist: modelscope
21
+ Requires-Dist: ftfy
22
+ Requires-Dist: pynvml
23
+ Requires-Dist: pandas
24
+ Requires-Dist: accelerate
25
+ Requires-Dist: peft
26
+ Dynamic: author
27
+ Dynamic: classifier
28
+ Dynamic: license-file
29
+ Dynamic: requires-dist
30
+ Dynamic: requires-python
31
+ Dynamic: summary
diffsynth.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,247 @@
1
+ LICENSE
2
+ README.md
3
+ setup.py
4
+ diffsynth/__init__.py
5
+ diffsynth.egg-info/PKG-INFO
6
+ diffsynth.egg-info/SOURCES.txt
7
+ diffsynth.egg-info/dependency_links.txt
8
+ diffsynth.egg-info/requires.txt
9
+ diffsynth.egg-info/top_level.txt
10
+ diffsynth/configs/__init__.py
11
+ diffsynth/configs/model_config.py
12
+ diffsynth/controlnets/__init__.py
13
+ diffsynth/controlnets/controlnet_unit.py
14
+ diffsynth/controlnets/processors.py
15
+ diffsynth/data/__init__.py
16
+ diffsynth/data/simple_text_image.py
17
+ diffsynth/data/video.py
18
+ diffsynth/distributed/__init__.py
19
+ diffsynth/distributed/xdit_context_parallel.py
20
+ diffsynth/extensions/__init__.py
21
+ diffsynth/extensions/ESRGAN/__init__.py
22
+ diffsynth/extensions/FastBlend/__init__.py
23
+ diffsynth/extensions/FastBlend/api.py
24
+ diffsynth/extensions/FastBlend/cupy_kernels.py
25
+ diffsynth/extensions/FastBlend/data.py
26
+ diffsynth/extensions/FastBlend/patch_match.py
27
+ diffsynth/extensions/FastBlend/runners/__init__.py
28
+ diffsynth/extensions/FastBlend/runners/accurate.py
29
+ diffsynth/extensions/FastBlend/runners/balanced.py
30
+ diffsynth/extensions/FastBlend/runners/fast.py
31
+ diffsynth/extensions/FastBlend/runners/interpolation.py
32
+ diffsynth/extensions/ImageQualityMetric/__init__.py
33
+ diffsynth/extensions/ImageQualityMetric/aesthetic.py
34
+ diffsynth/extensions/ImageQualityMetric/clip.py
35
+ diffsynth/extensions/ImageQualityMetric/config.py
36
+ diffsynth/extensions/ImageQualityMetric/hps.py
37
+ diffsynth/extensions/ImageQualityMetric/imagereward.py
38
+ diffsynth/extensions/ImageQualityMetric/mps.py
39
+ diffsynth/extensions/ImageQualityMetric/pickscore.py
40
+ diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
41
+ diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
42
+ diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py
43
+ diffsynth/extensions/ImageQualityMetric/BLIP/med.py
44
+ diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
45
+ diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py
46
+ diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
47
+ diffsynth/extensions/ImageQualityMetric/open_clip/constants.py
48
+ diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
49
+ diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py
50
+ diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py
51
+ diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
52
+ diffsynth/extensions/ImageQualityMetric/open_clip/loss.py
53
+ diffsynth/extensions/ImageQualityMetric/open_clip/model.py
54
+ diffsynth/extensions/ImageQualityMetric/open_clip/modified_resnet.py
55
+ diffsynth/extensions/ImageQualityMetric/open_clip/openai.py
56
+ diffsynth/extensions/ImageQualityMetric/open_clip/pretrained.py
57
+ diffsynth/extensions/ImageQualityMetric/open_clip/push_to_hf_hub.py
58
+ diffsynth/extensions/ImageQualityMetric/open_clip/timm_model.py
59
+ diffsynth/extensions/ImageQualityMetric/open_clip/tokenizer.py
60
+ diffsynth/extensions/ImageQualityMetric/open_clip/transform.py
61
+ diffsynth/extensions/ImageQualityMetric/open_clip/transformer.py
62
+ diffsynth/extensions/ImageQualityMetric/open_clip/utils.py
63
+ diffsynth/extensions/ImageQualityMetric/open_clip/version.py
64
+ diffsynth/extensions/ImageQualityMetric/trainer/__init__.py
65
+ diffsynth/extensions/ImageQualityMetric/trainer/models/__init__.py
66
+ diffsynth/extensions/ImageQualityMetric/trainer/models/base_model.py
67
+ diffsynth/extensions/ImageQualityMetric/trainer/models/clip_model.py
68
+ diffsynth/extensions/ImageQualityMetric/trainer/models/cross_modeling.py
69
+ diffsynth/extensions/RIFE/__init__.py
70
+ diffsynth/lora/__init__.py
71
+ diffsynth/lora/flux_lora.py
72
+ diffsynth/models/__init__.py
73
+ diffsynth/models/attention.py
74
+ diffsynth/models/cog_dit.py
75
+ diffsynth/models/cog_vae.py
76
+ diffsynth/models/downloader.py
77
+ diffsynth/models/flux_controlnet.py
78
+ diffsynth/models/flux_dit.py
79
+ diffsynth/models/flux_infiniteyou.py
80
+ diffsynth/models/flux_ipadapter.py
81
+ diffsynth/models/flux_lora_encoder.py
82
+ diffsynth/models/flux_text_encoder.py
83
+ diffsynth/models/flux_vae.py
84
+ diffsynth/models/flux_value_control.py
85
+ diffsynth/models/hunyuan_dit.py
86
+ diffsynth/models/hunyuan_dit_text_encoder.py
87
+ diffsynth/models/hunyuan_video_dit.py
88
+ diffsynth/models/hunyuan_video_text_encoder.py
89
+ diffsynth/models/hunyuan_video_vae_decoder.py
90
+ diffsynth/models/hunyuan_video_vae_encoder.py
91
+ diffsynth/models/kolors_text_encoder.py
92
+ diffsynth/models/longcat_video_dit.py
93
+ diffsynth/models/lora.py
94
+ diffsynth/models/model_manager.py
95
+ diffsynth/models/nexus_gen.py
96
+ diffsynth/models/nexus_gen_ar_model.py
97
+ diffsynth/models/nexus_gen_projector.py
98
+ diffsynth/models/omnigen.py
99
+ diffsynth/models/qwen_image_controlnet.py
100
+ diffsynth/models/qwen_image_dit.py
101
+ diffsynth/models/qwen_image_text_encoder.py
102
+ diffsynth/models/qwen_image_vae.py
103
+ diffsynth/models/qwenvl.py
104
+ diffsynth/models/sd3_dit.py
105
+ diffsynth/models/sd3_text_encoder.py
106
+ diffsynth/models/sd3_vae_decoder.py
107
+ diffsynth/models/sd3_vae_encoder.py
108
+ diffsynth/models/sd_controlnet.py
109
+ diffsynth/models/sd_ipadapter.py
110
+ diffsynth/models/sd_motion.py
111
+ diffsynth/models/sd_text_encoder.py
112
+ diffsynth/models/sd_unet.py
113
+ diffsynth/models/sd_vae_decoder.py
114
+ diffsynth/models/sd_vae_encoder.py
115
+ diffsynth/models/sdxl_controlnet.py
116
+ diffsynth/models/sdxl_ipadapter.py
117
+ diffsynth/models/sdxl_motion.py
118
+ diffsynth/models/sdxl_text_encoder.py
119
+ diffsynth/models/sdxl_unet.py
120
+ diffsynth/models/sdxl_vae_decoder.py
121
+ diffsynth/models/sdxl_vae_encoder.py
122
+ diffsynth/models/step1x_connector.py
123
+ diffsynth/models/stepvideo_dit.py
124
+ diffsynth/models/stepvideo_text_encoder.py
125
+ diffsynth/models/stepvideo_vae.py
126
+ diffsynth/models/svd_image_encoder.py
127
+ diffsynth/models/svd_unet.py
128
+ diffsynth/models/svd_vae_decoder.py
129
+ diffsynth/models/svd_vae_encoder.py
130
+ diffsynth/models/tiler.py
131
+ diffsynth/models/utils.py
132
+ diffsynth/models/wan_video_animate_adapter.py
133
+ diffsynth/models/wan_video_camera_controller.py
134
+ diffsynth/models/wan_video_dit.py
135
+ diffsynth/models/wan_video_dit_s2v.py
136
+ diffsynth/models/wan_video_image_encoder.py
137
+ diffsynth/models/wan_video_mot.py
138
+ diffsynth/models/wan_video_motion_controller.py
139
+ diffsynth/models/wan_video_text_encoder.py
140
+ diffsynth/models/wan_video_vace.py
141
+ diffsynth/models/wan_video_vae.py
142
+ diffsynth/models/wav2vec.py
143
+ diffsynth/pipelines/__init__.py
144
+ diffsynth/pipelines/base.py
145
+ diffsynth/pipelines/cog_video.py
146
+ diffsynth/pipelines/dancer.py
147
+ diffsynth/pipelines/flux_image.py
148
+ diffsynth/pipelines/flux_image_new.py
149
+ diffsynth/pipelines/hunyuan_image.py
150
+ diffsynth/pipelines/hunyuan_video.py
151
+ diffsynth/pipelines/omnigen_image.py
152
+ diffsynth/pipelines/pipeline_runner.py
153
+ diffsynth/pipelines/qwen_image.py
154
+ diffsynth/pipelines/sd3_image.py
155
+ diffsynth/pipelines/sd_image.py
156
+ diffsynth/pipelines/sd_video.py
157
+ diffsynth/pipelines/sdxl_image.py
158
+ diffsynth/pipelines/sdxl_video.py
159
+ diffsynth/pipelines/step_video.py
160
+ diffsynth/pipelines/svd_video.py
161
+ diffsynth/pipelines/wan_video.py
162
+ diffsynth/pipelines/wan_video_new.py
163
+ diffsynth/processors/FastBlend.py
164
+ diffsynth/processors/PILEditor.py
165
+ diffsynth/processors/RIFE.py
166
+ diffsynth/processors/__init__.py
167
+ diffsynth/processors/base.py
168
+ diffsynth/processors/sequencial_processor.py
169
+ diffsynth/prompters/__init__.py
170
+ diffsynth/prompters/base_prompter.py
171
+ diffsynth/prompters/cog_prompter.py
172
+ diffsynth/prompters/flux_prompter.py
173
+ diffsynth/prompters/hunyuan_dit_prompter.py
174
+ diffsynth/prompters/hunyuan_video_prompter.py
175
+ diffsynth/prompters/kolors_prompter.py
176
+ diffsynth/prompters/omnigen_prompter.py
177
+ diffsynth/prompters/omost.py
178
+ diffsynth/prompters/prompt_refiners.py
179
+ diffsynth/prompters/sd3_prompter.py
180
+ diffsynth/prompters/sd_prompter.py
181
+ diffsynth/prompters/sdxl_prompter.py
182
+ diffsynth/prompters/stepvideo_prompter.py
183
+ diffsynth/prompters/wan_prompter.py
184
+ diffsynth/schedulers/__init__.py
185
+ diffsynth/schedulers/continuous_ode.py
186
+ diffsynth/schedulers/ddim.py
187
+ diffsynth/schedulers/flow_match.py
188
+ diffsynth/tokenizer_configs/__init__.py
189
+ diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json
190
+ diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json
191
+ diffsynth/tokenizer_configs/cog/tokenizer/spiece.model
192
+ diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json
193
+ diffsynth/tokenizer_configs/flux/tokenizer_1/merges.txt
194
+ diffsynth/tokenizer_configs/flux/tokenizer_1/special_tokens_map.json
195
+ diffsynth/tokenizer_configs/flux/tokenizer_1/tokenizer_config.json
196
+ diffsynth/tokenizer_configs/flux/tokenizer_1/vocab.json
197
+ diffsynth/tokenizer_configs/flux/tokenizer_2/special_tokens_map.json
198
+ diffsynth/tokenizer_configs/flux/tokenizer_2/spiece.model
199
+ diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer.json
200
+ diffsynth/tokenizer_configs/flux/tokenizer_2/tokenizer_config.json
201
+ diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json
202
+ diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json
203
+ diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt
204
+ diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt
205
+ diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json
206
+ diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json
207
+ diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model
208
+ diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json
209
+ diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/merges.txt
210
+ diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/special_tokens_map.json
211
+ diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/tokenizer_config.json
212
+ diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/vocab.json
213
+ diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/preprocessor_config.json
214
+ diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/special_tokens_map.json
215
+ diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json
216
+ diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer_config.json
217
+ diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model
218
+ diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json
219
+ diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt
220
+ diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt
221
+ diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json
222
+ diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json
223
+ diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json
224
+ diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt
225
+ diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json
226
+ diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json
227
+ diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json
228
+ diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt
229
+ diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json
230
+ diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json
231
+ diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json
232
+ diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json
233
+ diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model
234
+ diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json
235
+ diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json
236
+ diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt
237
+ diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json
238
+ diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json
239
+ diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json
240
+ diffsynth/trainers/__init__.py
241
+ diffsynth/trainers/text_to_image.py
242
+ diffsynth/trainers/unified_dataset.py
243
+ diffsynth/trainers/utils.py
244
+ diffsynth/utils/__init__.py
245
+ diffsynth/vram_management/__init__.py
246
+ diffsynth/vram_management/gradient_checkpointing.py
247
+ diffsynth/vram_management/layers.py
diffsynth.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
diffsynth.egg-info/requires.txt ADDED
@@ -0,0 +1,15 @@
1
+ torch>=2.0.0
2
+ torchvision
3
+ transformers
4
+ imageio
5
+ imageio[ffmpeg]
6
+ safetensors
7
+ einops
8
+ sentencepiece
9
+ protobuf
10
+ modelscope
11
+ ftfy
12
+ pynvml
13
+ pandas
14
+ accelerate
15
+ peft
diffsynth.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ diffsynth
diffsynth/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .data import *
2
+ from .models import *
3
+ from .prompters import *
4
+ from .schedulers import *
5
+ from .pipelines import *
6
+ from .controlnets import *
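These wildcard imports re-export the data, model, prompter, scheduler, pipeline, and ControlNet submodules at the package top level. As a hedged illustration only (ModelManager is defined in diffsynth/models/model_manager.py per SOURCES.txt above; that it is re-exported through "from .models import *" is an assumption, not something this diff confirms):

# Hypothetical usage sketch under the assumption stated above.
from diffsynth import ModelManager

manager = ModelManager()  # default construction; real workflows typically pass dtype/device and load model files
print(type(manager).__name__)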
diffsynth/configs/__init__.py ADDED
File without changes
diffsynth/configs/model_config.py ADDED
@@ -0,0 +1,859 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+
3
+ from ..models.sd_text_encoder import SDTextEncoder
4
+ from ..models.sd_unet import SDUNet
5
+ from ..models.sd_vae_encoder import SDVAEEncoder
6
+ from ..models.sd_vae_decoder import SDVAEDecoder
7
+
8
+ from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
9
+ from ..models.sdxl_unet import SDXLUNet
10
+ from ..models.sdxl_vae_decoder import SDXLVAEDecoder
11
+ from ..models.sdxl_vae_encoder import SDXLVAEEncoder
12
+
13
+ from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
14
+ from ..models.sd3_dit import SD3DiT
15
+ from ..models.sd3_vae_decoder import SD3VAEDecoder
16
+ from ..models.sd3_vae_encoder import SD3VAEEncoder
17
+
18
+ from ..models.sd_controlnet import SDControlNet
19
+ from ..models.sdxl_controlnet import SDXLControlNetUnion
20
+
21
+ from ..models.sd_motion import SDMotionModel
22
+ from ..models.sdxl_motion import SDXLMotionModel
23
+
24
+ from ..models.svd_image_encoder import SVDImageEncoder
25
+ from ..models.svd_unet import SVDUNet
26
+ from ..models.svd_vae_decoder import SVDVAEDecoder
27
+ from ..models.svd_vae_encoder import SVDVAEEncoder
28
+
29
+ from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
30
+ from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
31
+
32
+ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
33
+ from ..models.hunyuan_dit import HunyuanDiT
34
+
35
+ from ..models.flux_dit import FluxDiT
36
+ from ..models.flux_text_encoder import FluxTextEncoder2
37
+ from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
38
+ from ..models.flux_controlnet import FluxControlNet
39
+ from ..models.flux_ipadapter import FluxIpAdapter
40
+ from ..models.flux_infiniteyou import InfiniteYouImageProjector
41
+
42
+ from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
43
+ from ..models.cog_dit import CogDiT
44
+
45
+ from ..models.omnigen import OmniGenTransformer
46
+
47
+ from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
48
+ from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
49
+
50
+ from ..extensions.RIFE import IFNet
51
+ from ..extensions.ESRGAN import RRDBNet
52
+
53
+ from ..models.hunyuan_video_dit import HunyuanVideoDiT
54
+
55
+ from ..models.stepvideo_vae import StepVideoVAE
56
+ from ..models.stepvideo_dit import StepVideoModel
57
+
58
+ from ..models.wan_video_dit import WanModel
59
+ from ..models.wan_video_dit_s2v import WanS2VModel
60
+ from ..models.wan_video_text_encoder import WanTextEncoder
61
+ from ..models.wan_video_image_encoder import WanImageEncoder
62
+ from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38
63
+ from ..models.wan_video_motion_controller import WanMotionControllerModel
64
+ from ..models.wan_video_vace import VaceWanModel
65
+ from ..models.wav2vec import WanS2VAudioEncoder
66
+ from ..models.wan_video_animate_adapter import WanAnimateAdapter
67
+ from ..models.wan_video_mot import MotWanModel
68
+
69
+ from ..models.step1x_connector import Qwen2Connector
70
+
71
+ from ..models.flux_value_control import SingleValueEncoder
72
+
73
+ from ..lora.flux_lora import FluxLoraPatcher
74
+ from ..models.flux_lora_encoder import FluxLoRAEncoder
75
+
76
+ from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger
77
+ from ..models.nexus_gen import NexusGenAutoregressiveModel
78
+
79
+ from ..models.qwen_image_dit import QwenImageDiT
80
+ from ..models.qwen_image_text_encoder import QwenImageTextEncoder
81
+ from ..models.qwen_image_vae import QwenImageVAE
82
+ from ..models.qwen_image_controlnet import QwenImageBlockWiseControlNet
83
+
84
+ from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
85
+
86
+ model_loader_configs = [
87
+ # These configs are provided for detecting model type automatically.
88
+ # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
89
+ (None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
90
+ (None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
91
+ (None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
92
+ (None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
93
+ (None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
94
+ (None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
95
+ (None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
96
+ (None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
97
+ (None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
98
+ (None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
99
+ (None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
100
+ (None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
101
+ (None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
102
+ (None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
103
+ (None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
104
+ (None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
105
+ (None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
106
+ (None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
107
+ (None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
108
+ (None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
109
+ (None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
110
+ (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
111
+ (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
112
+ (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
113
+ (None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
114
+ (None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
115
+ (None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
116
+ (None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
117
+ (None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
118
+ (None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
119
+ (None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
120
+ (None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
121
+ (None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
122
+ (None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
123
+ (None, "0629116fce1472503a66992f96f3eb1a", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
124
+ (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
125
+ (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
126
+ (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
127
+ (None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
128
+ (None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
129
+ (None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
130
+ (None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
131
+ (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
132
+ (None, "7f9583eb8ba86642abb9a21a4b2c9e16", ["flux_controlnet"], [FluxControlNet], "diffusers"),
133
+ (None, "43ad5aaa27dd4ee01b832ed16773fa52", ["flux_controlnet"], [FluxControlNet], "diffusers"),
134
+ (None, "c07c0f04f5ff55e86b4e937c7a40d481", ["infiniteyou_image_projector"], [InfiniteYouImageProjector], "diffusers"),
135
+ (None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
136
+ (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
137
+ (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
138
+ (None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
139
+ (None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
140
+ (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
141
+ (None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
142
+ (None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
143
+ (None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
144
+ (None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
145
+ (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
146
+ (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
147
+ (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
148
+ (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
149
+ (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
150
+ (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
151
+ (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
152
+ (None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"),
153
+ (None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
154
+ (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
155
+ (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
156
+ (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
157
+ (None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
158
+ (None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"),
159
+ (None, "5ec04e02b42d2580483ad69f4e76346a", ["wan_video_dit"], [WanModel], "civitai"),
160
+ (None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"),
161
+ (None, "5f90e66a0672219f12d9a626c8c21f61", ["wan_video_dit", "wan_video_vap"], [WanModel,MotWanModel], "diffusers"),
162
+ (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
163
+ (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
164
+ (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
165
+ (None, "966cffdcc52f9c46c391768b27637614", ["wan_video_dit"], [WanS2VModel], "civitai"),
166
+ (None, "8b27900f680d7251ce44e2dc8ae1ffef", ["wan_video_dit"], [LongCatVideoTransformer3DModel], "civitai"),
167
+ (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
168
+ (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
169
+ (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
170
+ (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
171
+ (None, "e1de6c02cdac79f8b739f4d3698cd216", ["wan_video_vae"], [WanVideoVAE38], "civitai"),
172
+ (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
173
+ (None, "d30fb9e02b1dbf4e509142f05cf7dd50", ["flux_dit", "step1x_connector"], [FluxDiT, Qwen2Connector], "civitai"),
174
+ (None, "30143afb2dea73d1ac580e0787628f8c", ["flux_lora_patcher"], [FluxLoraPatcher], "civitai"),
175
+ (None, "77c2e4dd2440269eb33bfaa0d004f6ab", ["flux_lora_encoder"], [FluxLoRAEncoder], "civitai"),
176
+ (None, "3e6c61b0f9471135fc9c6d6a98e98b6d", ["flux_dit", "nexus_gen_generation_adapter"], [FluxDiT, NexusGenAdapter], "civitai"),
177
+ (None, "63c969fd37cce769a90aa781fbff5f81", ["flux_dit", "nexus_gen_editing_adapter"], [FluxDiT, NexusGenImageEmbeddingMerger], "civitai"),
178
+ (None, "2bd19e845116e4f875a0a048e27fc219", ["nexus_gen_llm"], [NexusGenAutoregressiveModel], "civitai"),
179
+ (None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
180
+ (None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
181
+ (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
182
+ (None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
183
+ (None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
184
+ (None, "06be60f3a4526586d8431cd038a71486", ["wans2v_audio_encoder"], [WanS2VAudioEncoder], "civitai"),
185
+ (None, "31fa352acb8a1b1d33cd8764273d80a2", ["wan_video_dit", "wan_video_animate_adapter"], [WanModel, WanAnimateAdapter], "civitai"),
186
+ ]
187
+ huggingface_model_loader_configs = [
188
+ # These configs are provided for detecting model type automatically.
189
+ # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
190
+ ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
191
+ ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
192
+ ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
193
+ ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
194
+ # ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
195
+ ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
196
+ ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
197
+ ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
198
+ ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
199
+ ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
200
+ ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
201
+ ("Qwen2_5_VLForConditionalGeneration", "diffsynth.models.qwenvl", "qwenvl", "Qwen25VL_7b_Embedder"),
202
+ ]
203
+ patch_model_loader_configs = [
204
+ # These configs are provided for detecting model type automatically.
205
+ # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
206
+ ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
207
+ ]
208
+
209
+ preset_models_on_huggingface = {
210
+ "HunyuanDiT": [
211
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
212
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
213
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
214
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
215
+ ],
216
+ "stable-video-diffusion-img2vid-xt": [
217
+ ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
218
+ ],
219
+ "ExVideo-SVD-128f-v1": [
220
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
221
+ ],
222
+ # Stable Diffusion
223
+ "StableDiffusion_v15": [
224
+ ("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
225
+ ],
226
+ "DreamShaper_8": [
227
+ ("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
228
+ ],
229
+ # Textual Inversion
230
+ "TextualInversion_VeryBadImageNegative_v1.3": [
231
+ ("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
232
+ ],
233
+ # Stable Diffusion XL
234
+ "StableDiffusionXL_v1": [
235
+ ("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
236
+ ],
237
+ "BluePencilXL_v200": [
238
+ ("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
239
+ ],
240
+ "StableDiffusionXL_Turbo": [
241
+ ("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
242
+ ],
243
+ # Stable Diffusion 3
244
+ "StableDiffusion3": [
245
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
246
+ ],
247
+ "StableDiffusion3_without_T5": [
248
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
249
+ ],
250
+ # ControlNet
251
+ "ControlNet_v11f1p_sd15_depth": [
252
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
253
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
254
+ ],
255
+ "ControlNet_v11p_sd15_softedge": [
256
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
257
+ ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
258
+ ],
259
+ "ControlNet_v11f1e_sd15_tile": [
260
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
261
+ ],
262
+ "ControlNet_v11p_sd15_lineart": [
263
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
264
+ ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
265
+ ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
266
+ ],
267
+ "ControlNet_union_sdxl_promax": [
268
+ ("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
269
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
270
+ ],
271
+ # AnimateDiff
272
+ "AnimateDiff_v2": [
273
+ ("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
274
+ ],
275
+ "AnimateDiff_xl_beta": [
276
+ ("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
277
+ ],
278
+
279
+ # Qwen Prompt
280
+ "QwenPrompt": [
281
+ ("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
282
+ ("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
283
+ ("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
284
+ ("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
285
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
286
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
287
+ ("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
288
+ ("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
289
+ ],
290
+ # Beautiful Prompt
291
+ "BeautifulPrompt": [
292
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
293
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
294
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
295
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
296
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
297
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
298
+ ],
299
+ # Omost prompt
300
+ "OmostPrompt":[
301
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
302
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
303
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
304
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
305
+ ("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
306
+ ("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
307
+ ("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
308
+ ("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
309
+ ],
310
+ # Translator
311
+ "opus-mt-zh-en": [
312
+ ("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
313
+ ("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
314
+ ("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
315
+ ("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
316
+ ("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
317
+ ("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
318
+ ("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
319
+ ("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
320
+ ],
321
+ # IP-Adapter
322
+ "IP-Adapter-SD": [
323
+ ("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
324
+ ("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
325
+ ],
326
+ "IP-Adapter-SDXL": [
327
+ ("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
328
+ ("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
329
+ ],
330
+ "SDXL-vae-fp16-fix": [
331
+ ("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
332
+ ],
333
+ # Kolors
334
+ "Kolors": [
335
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
336
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
337
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
338
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
339
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
340
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
341
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
342
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
343
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
344
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
345
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
346
+ ],
347
+ # FLUX
348
+ "FLUX.1-dev": [
349
+ ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
350
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
351
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
352
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
353
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
354
+ ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
355
+ ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
356
+ ],
357
+ "InstantX/FLUX.1-dev-IP-Adapter": {
358
+ "file_list": [
359
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
360
+ ("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
361
+ ("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
362
+ ],
363
+ "load_path": [
364
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
365
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
366
+ ],
367
+ },
368
+ # RIFE
369
+ "RIFE": [
370
+ ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
371
+ ],
372
+ # CogVideo
373
+ "CogVideoX-5B": [
374
+ ("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
375
+ ("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
376
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
377
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
378
+ ("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
379
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
380
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
381
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
382
+ ("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
383
+ ],
384
+ # Stable Diffusion 3.5
385
+ "StableDiffusion3.5-large": [
386
+ ("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
387
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
388
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
389
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
390
+ ],
391
+ }
392
+ preset_models_on_modelscope = {
393
+ # Hunyuan DiT
394
+ "HunyuanDiT": [
395
+ ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
396
+ ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
397
+ ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
398
+ ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
399
+ ],
400
+ # Stable Video Diffusion
401
+ "stable-video-diffusion-img2vid-xt": [
402
+ ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
403
+ ],
404
+ # ExVideo
405
+ "ExVideo-SVD-128f-v1": [
406
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
407
+ ],
408
+ "ExVideo-CogVideoX-LoRA-129f-v1": [
409
+ ("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
410
+ ],
411
+ # Stable Diffusion
412
+ "StableDiffusion_v15": [
413
+ ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
414
+ ],
415
+ "DreamShaper_8": [
416
+ ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
417
+ ],
418
+ "AingDiffusion_v12": [
419
+ ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
420
+ ],
421
+ "Flat2DAnimerge_v45Sharp": [
422
+ ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
423
+ ],
424
+ # Textual Inversion
425
+ "TextualInversion_VeryBadImageNegative_v1.3": [
426
+ ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
427
+ ],
428
+ # Stable Diffusion XL
429
+ "StableDiffusionXL_v1": [
430
+ ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
431
+ ],
432
+ "BluePencilXL_v200": [
433
+ ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
434
+ ],
435
+ "StableDiffusionXL_Turbo": [
436
+ ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
437
+ ],
438
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
439
+ ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
440
+ ],
441
+ # Stable Diffusion 3
442
+ "StableDiffusion3": [
443
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
444
+ ],
445
+ "StableDiffusion3_without_T5": [
446
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
447
+ ],
448
+ # ControlNet
449
+ "ControlNet_v11f1p_sd15_depth": [
450
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
451
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
452
+ ],
453
+ "ControlNet_v11p_sd15_softedge": [
454
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
455
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
456
+ ],
457
+ "ControlNet_v11f1e_sd15_tile": [
458
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
459
+ ],
460
+ "ControlNet_v11p_sd15_lineart": [
461
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
462
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
463
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
464
+ ],
465
+ "ControlNet_union_sdxl_promax": [
466
+ ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
467
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
468
+ ],
469
+ "Annotators:Depth": [
470
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
471
+ ],
472
+ "Annotators:Softedge": [
473
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
474
+ ],
475
+ "Annotators:Lineart": [
476
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
477
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
478
+ ],
479
+ "Annotators:Normal": [
480
+ ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
481
+ ],
482
+ "Annotators:Openpose": [
483
+ ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
484
+ ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
485
+ ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
486
+ ],
487
+ # AnimateDiff
488
+ "AnimateDiff_v2": [
489
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
490
+ ],
491
+ "AnimateDiff_xl_beta": [
492
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
493
+ ],
494
+ # RIFE
495
+ "RIFE": [
496
+ ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
497
+ ],
498
+ # Qwen Prompt
499
+ "QwenPrompt": {
500
+ "file_list": [
501
+ ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
502
+ ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
503
+ ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
504
+ ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
505
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
506
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
507
+ ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
508
+ ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
509
+ ],
510
+ "load_path": [
511
+ "models/QwenPrompt/qwen2-1.5b-instruct",
512
+ ],
513
+ },
514
+ # Beautiful Prompt
515
+ "BeautifulPrompt": {
516
+ "file_list": [
517
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
518
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
519
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
520
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
521
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
522
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
523
+ ],
524
+ "load_path": [
525
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
526
+ ],
527
+ },
528
+ # Omost prompt
529
+ "OmostPrompt": {
530
+ "file_list": [
531
+ ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
532
+ ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
533
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
534
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
535
+ ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
536
+ ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
537
+ ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
538
+ ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
539
+ ],
540
+ "load_path": [
541
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
542
+ ],
543
+ },
544
+ # Translator
545
+ "opus-mt-zh-en": {
546
+ "file_list": [
547
+ ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
548
+ ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
549
+ ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
550
+ ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
551
+ ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
552
+ ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
553
+ ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
554
+ ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
555
+ ],
556
+ "load_path": [
557
+ "models/translator/opus-mt-zh-en",
558
+ ],
559
+ },
560
+ # IP-Adapter
561
+ "IP-Adapter-SD": [
562
+ ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
563
+ ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
564
+ ],
565
+ "IP-Adapter-SDXL": [
566
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
567
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
568
+ ],
569
+ # Kolors
570
+ "Kolors": {
571
+ "file_list": [
572
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
573
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
574
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
575
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
576
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
577
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
578
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
579
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
580
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
581
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
582
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
583
+ ],
584
+ "load_path": [
585
+ "models/kolors/Kolors/text_encoder",
586
+ "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
587
+ "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
588
+ ],
589
+ },
590
+ "SDXL-vae-fp16-fix": [
591
+ ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
592
+ ],
593
+ # FLUX
594
+ "FLUX.1-dev": {
595
+ "file_list": [
596
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
597
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
598
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
599
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
600
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
601
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
602
+ ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
603
+ ],
604
+ "load_path": [
605
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
606
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
607
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
608
+ "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
609
+ ],
610
+ },
611
+ "FLUX.1-schnell": {
612
+ "file_list": [
613
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
614
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
615
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
616
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
617
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
618
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
619
+ ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
620
+ ],
621
+ "load_path": [
622
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
623
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
624
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
625
+ "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
626
+ ],
627
+ },
628
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
629
+ ("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
630
+ ],
631
+ "jasperai/Flux.1-dev-Controlnet-Depth": [
632
+ ("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
633
+ ],
634
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
635
+ ("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
636
+ ],
637
+ "jasperai/Flux.1-dev-Controlnet-Upscaler": [
638
+ ("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
639
+ ],
640
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
641
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
642
+ ],
643
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
644
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
645
+ ],
646
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
647
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
648
+ ],
649
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
650
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
651
+ ],
652
+ "InstantX/FLUX.1-dev-IP-Adapter": {
653
+ "file_list": [
654
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
655
+ ("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
656
+ ("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
657
+ ],
658
+ "load_path": [
659
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
660
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
661
+ ],
662
+ },
663
+ "InfiniteYou":{
664
+ "file_list":[
665
+ ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
666
+ ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
667
+ ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
668
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
669
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
670
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
671
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
672
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
673
+ ],
674
+ "load_path":[
675
+ [
676
+ "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
677
+ "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
678
+ ],
679
+ "models/InfiniteYou/image_proj_model.bin",
680
+ ],
681
+ },
682
+ # ESRGAN
683
+ "ESRGAN_x4": [
684
+ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
685
+ ],
686
+ # RIFE
687
+ "RIFE": [
688
+ ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
689
+ ],
690
+ # Omnigen
691
+ "OmniGen-v1": {
692
+ "file_list": [
693
+ ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
694
+ ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
695
+ ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
696
+ ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
697
+ ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
698
+ ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
699
+ ],
700
+ "load_path": [
701
+ "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
702
+ "models/OmniGen/OmniGen-v1/model.safetensors",
703
+ ]
704
+ },
705
+ # CogVideo
706
+ "CogVideoX-5B": {
707
+ "file_list": [
708
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
709
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
710
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
711
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
712
+ ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
713
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
714
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
715
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
716
+ ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
717
+ ],
718
+ "load_path": [
719
+ "models/CogVideo/CogVideoX-5b/text_encoder",
720
+ "models/CogVideo/CogVideoX-5b/transformer",
721
+ "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
722
+ ],
723
+ },
724
+ # Stable Diffusion 3.5
725
+ "StableDiffusion3.5-large": [
726
+ ("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
727
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
728
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
729
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
730
+ ],
731
+ "StableDiffusion3.5-medium": [
732
+ ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
733
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
734
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
735
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
736
+ ],
737
+ "StableDiffusion3.5-large-turbo": [
738
+ ("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
739
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
740
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
741
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
742
+ ],
743
+ "HunyuanVideo":{
744
+ "file_list": [
745
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
746
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
747
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
748
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
749
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
750
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
751
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
752
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
753
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
754
+ ],
755
+ "load_path": [
756
+ "models/HunyuanVideo/text_encoder/model.safetensors",
757
+ "models/HunyuanVideo/text_encoder_2",
758
+ "models/HunyuanVideo/vae/pytorch_model.pt",
759
+ "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
760
+ ],
761
+ },
762
+ "HunyuanVideoI2V":{
763
+ "file_list": [
764
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
765
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
766
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
767
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
768
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
769
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
770
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
771
+ ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
772
+ ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
773
+ ],
774
+ "load_path": [
775
+ "models/HunyuanVideoI2V/text_encoder/model.safetensors",
776
+ "models/HunyuanVideoI2V/text_encoder_2",
777
+ "models/HunyuanVideoI2V/vae/pytorch_model.pt",
778
+ "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
779
+ ],
780
+ },
781
+ "HunyuanVideo-fp8":{
782
+ "file_list": [
783
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
784
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
785
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
786
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
787
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
788
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
789
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
790
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
791
+ ("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
792
+ ],
793
+ "load_path": [
794
+ "models/HunyuanVideo/text_encoder/model.safetensors",
795
+ "models/HunyuanVideo/text_encoder_2",
796
+ "models/HunyuanVideo/vae/pytorch_model.pt",
797
+ "models/HunyuanVideo/transformers/model.fp8.safetensors"
798
+ ],
799
+ },
800
+ }
801
+ Preset_model_id: TypeAlias = Literal[
802
+ "HunyuanDiT",
803
+ "stable-video-diffusion-img2vid-xt",
804
+ "ExVideo-SVD-128f-v1",
805
+ "ExVideo-CogVideoX-LoRA-129f-v1",
806
+ "StableDiffusion_v15",
807
+ "DreamShaper_8",
808
+ "AingDiffusion_v12",
809
+ "Flat2DAnimerge_v45Sharp",
810
+ "TextualInversion_VeryBadImageNegative_v1.3",
811
+ "StableDiffusionXL_v1",
812
+ "BluePencilXL_v200",
813
+ "StableDiffusionXL_Turbo",
814
+ "ControlNet_v11f1p_sd15_depth",
815
+ "ControlNet_v11p_sd15_softedge",
816
+ "ControlNet_v11f1e_sd15_tile",
817
+ "ControlNet_v11p_sd15_lineart",
818
+ "AnimateDiff_v2",
819
+ "AnimateDiff_xl_beta",
820
+ "RIFE",
821
+ "BeautifulPrompt",
822
+ "opus-mt-zh-en",
823
+ "IP-Adapter-SD",
824
+ "IP-Adapter-SDXL",
825
+ "StableDiffusion3",
826
+ "StableDiffusion3_without_T5",
827
+ "Kolors",
828
+ "SDXL-vae-fp16-fix",
829
+ "ControlNet_union_sdxl_promax",
830
+ "FLUX.1-dev",
831
+ "FLUX.1-schnell",
832
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
833
+ "jasperai/Flux.1-dev-Controlnet-Depth",
834
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
835
+ "jasperai/Flux.1-dev-Controlnet-Upscaler",
836
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
837
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
838
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
839
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
840
+ "InstantX/FLUX.1-dev-IP-Adapter",
841
+ "InfiniteYou",
842
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
843
+ "QwenPrompt",
844
+ "OmostPrompt",
845
+ "ESRGAN_x4",
846
+ "OmniGen-v1",
848
+ "CogVideoX-5B",
849
+ "Annotators:Depth",
850
+ "Annotators:Softedge",
851
+ "Annotators:Lineart",
852
+ "Annotators:Normal",
853
+ "Annotators:Openpose",
854
+ "StableDiffusion3.5-large",
855
+ "StableDiffusion3.5-medium",
+ "StableDiffusion3.5-large-turbo",
856
+ "HunyuanVideo",
857
+ "HunyuanVideo-fp8",
858
+ "HunyuanVideoI2V",
859
+ ]
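To make the preset table above concrete, here is a minimal sketch (not the library's own downloader) of how one preset's file tuples could be resolved into local paths. The `download_file` callback is hypothetical, standing in for whatever hub client is used; the tuple is copied from the "ESRGAN_x4" entry above, and dict-style presets additionally carry a "load_path" list naming what the loader should read afterwards.

import os

# The "ESRGAN_x4" entry above, in the plain list form: (model_id, file_in_repo, local_dir).
file_list = [
    ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
]

def resolve_file_list(file_list, download_file=None):
    # Create each target directory, optionally fetch the file into it,
    # and return the local paths a loader would read afterwards.
    local_paths = []
    for model_id, origin_path, local_dir in file_list:
        os.makedirs(local_dir, exist_ok=True)
        if download_file is not None:
            download_file(model_id, origin_path, local_dir)
        local_paths.append(os.path.join(local_dir, os.path.basename(origin_path)))
    return local_paths

print(resolve_file_list(file_list))  # ['models/ESRGAN/RealESRGAN_x4.pth']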
diffsynth/controlnets/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
2
+ from .processors import Annotator
diffsynth/controlnets/controlnet_unit.py ADDED
@@ -0,0 +1,91 @@
1
+ import torch
2
+ import numpy as np
3
+ from .processors import Processor_id
4
+
5
+
6
+ class ControlNetConfigUnit:
7
+ def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
8
+ self.processor_id = processor_id
9
+ self.model_path = model_path
10
+ self.scale = scale
11
+ self.skip_processor = skip_processor
12
+
13
+
14
+ class ControlNetUnit:
15
+ def __init__(self, processor, model, scale=1.0):
16
+ self.processor = processor
17
+ self.model = model
18
+ self.scale = scale
19
+
20
+
21
+ class MultiControlNetManager:
22
+ def __init__(self, controlnet_units=[]):
23
+ self.processors = [unit.processor for unit in controlnet_units]
24
+ self.models = [unit.model for unit in controlnet_units]
25
+ self.scales = [unit.scale for unit in controlnet_units]
26
+
27
+ def cpu(self):
28
+ for model in self.models:
29
+ model.cpu()
30
+
31
+ def to(self, device):
32
+ for model in self.models:
33
+ model.to(device)
34
+ for processor in self.processors:
35
+ processor.to(device)
36
+
37
+ def process_image(self, image, processor_id=None):
38
+ if processor_id is None:
39
+ processed_image = [processor(image) for processor in self.processors]
40
+ else:
41
+ processed_image = [self.processors[processor_id](image)]
42
+ processed_image = torch.concat([
43
+ torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
44
+ for image_ in processed_image
45
+ ], dim=0)
46
+ return processed_image
47
+
48
+ def __call__(
49
+ self,
50
+ sample, timestep, encoder_hidden_states, conditionings,
51
+ tiled=False, tile_size=64, tile_stride=32, **kwargs
52
+ ):
53
+ res_stack = None
54
+ for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
55
+ res_stack_ = model(
56
+ sample, timestep, encoder_hidden_states, conditioning, **kwargs,
57
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
58
+ processor_id=processor.processor_id
59
+ )
60
+ res_stack_ = [res * scale for res in res_stack_]
61
+ if res_stack is None:
62
+ res_stack = res_stack_
63
+ else:
64
+ res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
65
+ return res_stack
66
+
67
+
68
+ class FluxMultiControlNetManager(MultiControlNetManager):
69
+ def __init__(self, controlnet_units=[]):
70
+ super().__init__(controlnet_units=controlnet_units)
71
+
72
+ def process_image(self, image, processor_id=None):
73
+ if processor_id is None:
74
+ processed_image = [processor(image) for processor in self.processors]
75
+ else:
76
+ processed_image = [self.processors[processor_id](image)]
77
+ return processed_image
78
+
79
+ def __call__(self, conditionings, **kwargs):
80
+ res_stack, single_res_stack = None, None
81
+ for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
82
+ res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
83
+ res_stack_ = [res * scale for res in res_stack_]
84
+ single_res_stack_ = [res * scale for res in single_res_stack_]
85
+ if res_stack is None:
86
+ res_stack = res_stack_
87
+ single_res_stack = single_res_stack_
88
+ else:
89
+ res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
90
+ single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
91
+ return res_stack, single_res_stack
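A toy illustration of the accumulation pattern both managers above rely on: each unit's residual stack is multiplied by its scale and then summed element-wise across units. The tensors here are random stand-ins, not real ControlNet outputs.

import torch

scales = [1.0, 0.6]                                   # one scale per ControlNet unit
unit_outputs = [
    [torch.randn(1, 4, 8, 8) for _ in range(3)],      # residual stack from unit 0
    [torch.randn(1, 4, 8, 8) for _ in range(3)],      # residual stack from unit 1
]

res_stack = None
for res_stack_, scale in zip(unit_outputs, scales):
    res_stack_ = [res * scale for res in res_stack_]  # per-unit weighting
    if res_stack is None:
        res_stack = res_stack_
    else:
        res_stack = [i + j for i, j in zip(res_stack, res_stack_)]

print([tuple(r.shape) for r in res_stack])            # three (1, 4, 8, 8) tensors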
diffsynth/controlnets/processors.py ADDED
@@ -0,0 +1,62 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+
3
+
4
+ Processor_id: TypeAlias = Literal[
5
+ "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
6
+ ]
7
+
8
+ class Annotator:
9
+ def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
10
+ if not skip_processor:
11
+ if processor_id == "canny":
12
+ from controlnet_aux.processor import CannyDetector
13
+ self.processor = CannyDetector()
14
+ elif processor_id == "depth":
15
+ from controlnet_aux.processor import MidasDetector
16
+ self.processor = MidasDetector.from_pretrained(model_path).to(device)
17
+ elif processor_id == "softedge":
18
+ from controlnet_aux.processor import HEDdetector
19
+ self.processor = HEDdetector.from_pretrained(model_path).to(device)
20
+ elif processor_id == "lineart":
21
+ from controlnet_aux.processor import LineartDetector
22
+ self.processor = LineartDetector.from_pretrained(model_path).to(device)
23
+ elif processor_id == "lineart_anime":
24
+ from controlnet_aux.processor import LineartAnimeDetector
25
+ self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
26
+ elif processor_id == "openpose":
27
+ from controlnet_aux.processor import OpenposeDetector
28
+ self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
29
+ elif processor_id == "normal":
30
+ from controlnet_aux.processor import NormalBaeDetector
31
+ self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
32
+ elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
33
+ self.processor = None
34
+ else:
35
+ raise ValueError(f"Unsupported processor_id: {processor_id}")
36
+ else:
37
+ self.processor = None
38
+
39
+ self.processor_id = processor_id
40
+ self.detect_resolution = detect_resolution
41
+
42
+ def to(self, device):
43
+ if hasattr(self.processor, "model") and hasattr(self.processor.model, "to"):
44
+
45
+ self.processor.model.to(device)
46
+
47
+ def __call__(self, image, mask=None):
48
+ width, height = image.size
49
+ if self.processor_id == "openpose":
50
+ kwargs = {
51
+ "include_body": True,
52
+ "include_hand": True,
53
+ "include_face": True
54
+ }
55
+ else:
56
+ kwargs = {}
57
+ if self.processor is not None:
58
+ detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
59
+ image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
60
+ image = image.resize((width, height))
61
+ return image
62
+
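A usage sketch for the Annotator class, assuming both this package and controlnet_aux are installed. The "canny" processor needs no pretrained weights, so model_path is unused and no GPU is required; the input here is a synthetic image.

import numpy as np
from PIL import Image
from diffsynth.controlnets import Annotator

image = Image.fromarray((np.random.rand(512, 512, 3) * 255).astype("uint8"))
annotator = Annotator("canny")     # detectors with weights ("depth", "openpose", ...) load from model_path
control_image = annotator(image)   # PIL image, resized back to the input resolution
print(control_image.size)          # (512, 512)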
diffsynth/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .video import VideoData, save_video, save_frames, merge_video_audio, save_video_with_audio
diffsynth/data/simple_text_image.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch, os, torchvision
2
+ from torchvision import transforms
3
+ import pandas as pd
4
+ from PIL import Image
5
+
6
+
7
+
8
+ class TextImageDataset(torch.utils.data.Dataset):
9
+ def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
10
+ self.steps_per_epoch = steps_per_epoch
11
+ metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
12
+ self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
13
+ self.text = metadata["text"].to_list()
14
+ self.height = height
15
+ self.width = width
16
+ self.image_processor = transforms.Compose(
17
+ [
18
+ transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
19
+ transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize([0.5], [0.5]),
22
+ ]
23
+ )
24
+
25
+
26
+ def __getitem__(self, index):
27
+ data_id = torch.randint(0, len(self.path), (1,))[0]
28
+ data_id = (data_id + index) % len(self.path) # For fixed seed.
29
+ text = self.text[data_id]
30
+ image = Image.open(self.path[data_id]).convert("RGB")
31
+ target_height, target_width = self.height, self.width
32
+ width, height = image.size
33
+ scale = max(target_width / width, target_height / height)
34
+ shape = [round(height * scale), round(width * scale)]
35
+ image = torchvision.transforms.functional.resize(image, shape, interpolation=transforms.InterpolationMode.BILINEAR)
36
+ image = self.image_processor(image)
37
+ return {"text": text, "image": image}
38
+
39
+
40
+ def __len__(self):
41
+ return self.steps_per_epoch
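A usage sketch for TextImageDataset, assuming a dataset folder laid out as <dataset_path>/train/metadata.csv with file_name and text columns plus the referenced images; the path below is a placeholder.

import torch
from diffsynth.data.simple_text_image import TextImageDataset

dataset = TextImageDataset("data/example_dataset", steps_per_epoch=500, height=1024, width=1024)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

batch = next(iter(loader))
print(batch["text"])          # list with one caption
print(batch["image"].shape)   # torch.Size([1, 3, 1024, 1024]), values normalized to [-1, 1]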
diffsynth/data/video.py ADDED
@@ -0,0 +1,217 @@
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from tqdm import tqdm
5
+ import subprocess
6
+ import shutil
7
+
8
+
9
+ class LowMemoryVideo:
10
+ def __init__(self, file_name):
11
+ self.reader = imageio.get_reader(file_name)
12
+
13
+ def __len__(self):
14
+ return self.reader.count_frames()
15
+
16
+ def __getitem__(self, item):
17
+ return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
18
+
19
+ def __del__(self):
20
+ self.reader.close()
21
+
22
+
23
+ def split_file_name(file_name):
24
+ result = []
25
+ number = -1
26
+ for i in file_name:
27
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
28
+ if number == -1:
29
+ number = 0
30
+ number = number*10 + ord(i) - ord("0")
31
+ else:
32
+ if number != -1:
33
+ result.append(number)
34
+ number = -1
35
+ result.append(i)
36
+ if number != -1:
37
+ result.append(number)
38
+ result = tuple(result)
39
+ return result
40
+
41
+
42
+ def search_for_images(folder):
43
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
44
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
45
+ file_list = [i[1] for i in sorted(file_list)]
46
+ file_list = [os.path.join(folder, i) for i in file_list]
47
+ return file_list
48
+
49
+
50
+ class LowMemoryImageFolder:
51
+ def __init__(self, folder, file_list=None):
52
+ if file_list is None:
53
+ self.file_list = search_for_images(folder)
54
+ else:
55
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
56
+
57
+ def __len__(self):
58
+ return len(self.file_list)
59
+
60
+ def __getitem__(self, item):
61
+ return Image.open(self.file_list[item]).convert("RGB")
62
+
63
+ def __del__(self):
64
+ pass
65
+
66
+
67
+ def crop_and_resize(image, height, width):
68
+ image = np.array(image)
69
+ image_height, image_width, _ = image.shape
70
+ if image_height / image_width < height / width:
71
+ croped_width = int(image_height / height * width)
72
+ left = (image_width - croped_width) // 2
73
+ image = image[:, left: left+croped_width]
74
+ image = Image.fromarray(image).resize((width, height))
75
+ else:
76
+ croped_height = int(image_width / width * height)
77
+ left = (image_height - croped_height) // 2
78
+ image = image[left: left+croped_height, :]
79
+ image = Image.fromarray(image).resize((width, height))
80
+ return image
81
+
82
+
83
+ class VideoData:
84
+ def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
85
+ if video_file is not None:
86
+ self.data_type = "video"
87
+ self.data = LowMemoryVideo(video_file, **kwargs)
88
+ elif image_folder is not None:
89
+ self.data_type = "images"
90
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
91
+ else:
92
+ raise ValueError("Cannot open video or image folder")
93
+ self.length = None
94
+ self.set_shape(height, width)
95
+
96
+ def raw_data(self):
97
+ frames = []
98
+ for i in range(self.__len__()):
99
+ frames.append(self.__getitem__(i))
100
+ return frames
101
+
102
+ def set_length(self, length):
103
+ self.length = length
104
+
105
+ def set_shape(self, height, width):
106
+ self.height = height
107
+ self.width = width
108
+
109
+ def __len__(self):
110
+ if self.length is None:
111
+ return len(self.data)
112
+ else:
113
+ return self.length
114
+
115
+ def shape(self):
116
+ if self.height is not None and self.width is not None:
117
+ return self.height, self.width
118
+ else:
119
+ height, width, _ = self.__getitem__(0).shape
120
+ return height, width
121
+
122
+ def __getitem__(self, item):
123
+ frame = self.data.__getitem__(item)
124
+ width, height = frame.size
125
+ if self.height is not None and self.width is not None:
126
+ if self.height != height or self.width != width:
127
+ frame = crop_and_resize(frame, self.height, self.width)
128
+ return frame
129
+
130
+ def __del__(self):
131
+ pass
132
+
133
+ def save_images(self, folder):
134
+ os.makedirs(folder, exist_ok=True)
135
+ for i in tqdm(range(self.__len__()), desc="Saving images"):
136
+ frame = self.__getitem__(i)
137
+ frame.save(os.path.join(folder, f"{i}.png"))
138
+
139
+
140
+ def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
141
+ writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
142
+ for frame in tqdm(frames, desc="Saving video"):
143
+ frame = np.array(frame)
144
+ writer.append_data(frame)
145
+ writer.close()
146
+
147
+ def save_frames(frames, save_path):
148
+ os.makedirs(save_path, exist_ok=True)
149
+ for i, frame in enumerate(tqdm(frames, desc="Saving images")):
150
+ frame.save(os.path.join(save_path, f"{i}.png"))
151
+
152
+
153
+ def merge_video_audio(video_path: str, audio_path: str):
154
+ # TODO: may need an in-Python implementation to avoid the subprocess dependency
155
+ """
156
+ Merge the video and audio into a new video, with the duration set to the shorter of the two,
157
+ and overwrite the original video file.
158
+
159
+ Parameters:
160
+ video_path (str): Path to the original video file
161
+ audio_path (str): Path to the audio file
162
+ """
163
+
164
+ # check
165
+ if not os.path.exists(video_path):
166
+ raise FileNotFoundError(f"video file {video_path} does not exist")
167
+ if not os.path.exists(audio_path):
168
+ raise FileNotFoundError(f"audio file {audio_path} does not exist")
169
+
170
+ base, ext = os.path.splitext(video_path)
171
+ temp_output = f"{base}_temp{ext}"
172
+
173
+ try:
174
+ # create ffmpeg command
175
+ command = [
176
+ 'ffmpeg',
177
+ '-y', # overwrite
178
+ '-i',
179
+ video_path,
180
+ '-i',
181
+ audio_path,
182
+ '-c:v',
183
+ 'copy', # copy video stream
184
+ '-c:a',
185
+ 'aac', # use AAC audio encoder
186
+ '-b:a',
187
+ '192k', # set audio bitrate (optional)
188
+ '-map',
189
+ '0:v:0', # select the first video stream
190
+ '-map',
191
+ '1:a:0', # select the first audio stream
192
+ '-shortest', # choose the shortest duration
193
+ temp_output
194
+ ]
195
+
196
+ # execute the command
197
+ result = subprocess.run(
198
+ command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
199
+
200
+ # check result
201
+ if result.returncode != 0:
202
+ error_msg = f"FFmpeg execution failed: {result.stderr}"
203
+ print(error_msg)
204
+ raise RuntimeError(error_msg)
205
+
206
+ shutil.move(temp_output, video_path)
207
+ print(f"Merge completed, saved to {video_path}")
208
+
209
+ except Exception as e:
210
+ if os.path.exists(temp_output):
211
+ os.remove(temp_output)
212
+ print(f"merge_video_audio failed with error: {e}")
213
+
214
+
215
+ def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None):
216
+ save_video(frames, save_path, fps, quality, ffmpeg_params)
217
+ merge_video_audio(save_path, audio_path)
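A minimal round trip with the helpers above; "input.mp4" is a placeholder path, and frames are center-cropped and resized to 512x512 as they are read.

from diffsynth.data.video import VideoData, save_video

video = VideoData(video_file="input.mp4", height=512, width=512)
frames = [video[i] for i in range(min(16, len(video)))]   # first 16 frames as PIL images
save_video(frames, "output.mp4", fps=16, quality=7)       # re-encode through imageio/ffmpeg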
diffsynth/distributed/__init__.py ADDED
File without changes
diffsynth/distributed/xdit_context_parallel.py ADDED
@@ -0,0 +1,131 @@
1
+ import torch
2
+ from typing import Optional
3
+ from einops import rearrange
4
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
5
+ get_sequence_parallel_world_size,
6
+ get_sp_group)
7
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
8
+
9
+ def sinusoidal_embedding_1d(dim, position):
10
+ sinusoid = torch.outer(position.type(torch.float64), torch.pow(
11
+ 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
12
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
13
+ return x.to(position.dtype)
14
+
15
+ def pad_freqs(original_tensor, target_len):
16
+ seq_len, s1, s2 = original_tensor.shape
17
+ pad_size = target_len - seq_len
18
+ padding_tensor = torch.ones(
19
+ pad_size,
20
+ s1,
21
+ s2,
22
+ dtype=original_tensor.dtype,
23
+ device=original_tensor.device)
24
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
25
+ return padded_tensor
26
+
27
+ def rope_apply(x, freqs, num_heads):
28
+ x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
29
+ s_per_rank = x.shape[1]
30
+
31
+ x_out = torch.view_as_complex(x.to(torch.float64).reshape(
32
+ x.shape[0], x.shape[1], x.shape[2], -1, 2))
33
+
34
+ sp_size = get_sequence_parallel_world_size()
35
+ sp_rank = get_sequence_parallel_rank()
36
+ freqs = pad_freqs(freqs, s_per_rank * sp_size)
37
+ freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
38
+
39
+ x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
40
+ return x_out.to(x.dtype)
41
+
42
+ def usp_dit_forward(self,
43
+ x: torch.Tensor,
44
+ timestep: torch.Tensor,
45
+ context: torch.Tensor,
46
+ clip_feature: Optional[torch.Tensor] = None,
47
+ y: Optional[torch.Tensor] = None,
48
+ use_gradient_checkpointing: bool = False,
49
+ use_gradient_checkpointing_offload: bool = False,
50
+ **kwargs,
51
+ ):
52
+ t = self.time_embedding(
53
+ sinusoidal_embedding_1d(self.freq_dim, timestep))
54
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
55
+ context = self.text_embedding(context)
56
+
57
+ if self.has_image_input:
58
+ x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
59
+ clip_embedding = self.img_emb(clip_feature)
60
+ context = torch.cat([clip_embedding, context], dim=1)
61
+
62
+ x, (f, h, w) = self.patchify(x)
63
+
64
+ freqs = torch.cat([
65
+ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
66
+ self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
67
+ self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
68
+ ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
69
+
70
+ def create_custom_forward(module):
71
+ def custom_forward(*inputs):
72
+ return module(*inputs)
73
+ return custom_forward
74
+
75
+ # Context Parallel
76
+ chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
77
+ pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
78
+ chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
79
+ x = chunks[get_sequence_parallel_rank()]
80
+
81
+ for block in self.blocks:
82
+ if self.training and use_gradient_checkpointing:
83
+ if use_gradient_checkpointing_offload:
84
+ with torch.autograd.graph.save_on_cpu():
85
+ x = torch.utils.checkpoint.checkpoint(
86
+ create_custom_forward(block),
87
+ x, context, t_mod, freqs,
88
+ use_reentrant=False,
89
+ )
90
+ else:
91
+ x = torch.utils.checkpoint.checkpoint(
92
+ create_custom_forward(block),
93
+ x, context, t_mod, freqs,
94
+ use_reentrant=False,
95
+ )
96
+ else:
97
+ x = block(x, context, t_mod, freqs)
98
+
99
+ x = self.head(x, t)
100
+
101
+ # Context Parallel
102
+ x = get_sp_group().all_gather(x, dim=1)
103
+ x = x[:, :-pad_shape] if pad_shape > 0 else x
104
+
105
+ # unpatchify
106
+ x = self.unpatchify(x, (f, h, w))
107
+ return x
108
+
109
+
110
+ def usp_attn_forward(self, x, freqs):
111
+ q = self.norm_q(self.q(x))
112
+ k = self.norm_k(self.k(x))
113
+ v = self.v(x)
114
+
115
+ q = rope_apply(q, freqs, self.num_heads)
116
+ k = rope_apply(k, freqs, self.num_heads)
117
+ q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
118
+ k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
119
+ v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
120
+
121
+ x = xFuserLongContextAttention()(
122
+ None,
123
+ query=q,
124
+ key=k,
125
+ value=v,
126
+ )
127
+ x = x.flatten(2)
128
+
129
+ del q, k, v
130
+ torch.cuda.empty_cache()
131
+ return self.o(x)
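The sequence-parallel split used above can be checked in isolation: the token axis is chunked across ranks, the last chunk is zero-padded to the common length, and the padding is dropped again after gathering. `world_size` and `rank` below are plain integers standing in for xfuser's get_sequence_parallel_* helpers, so no distributed setup is needed for this sketch.

import torch

world_size, rank = 4, 3
x = torch.randn(1, 10, 8)                      # (batch, tokens, dim) with 10 tokens

chunks = torch.chunk(x, world_size, dim=1)     # chunk sizes 3, 3, 3, 1
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
chunks = [torch.nn.functional.pad(c, (0, 0, 0, chunks[0].shape[1] - c.shape[1]), value=0) for c in chunks]
local_x = chunks[rank]                         # every rank now holds 3 padded tokens

gathered = torch.cat(chunks, dim=1)            # what all_gather reassembles across ranks
gathered = gathered[:, :-pad_shape] if pad_shape > 0 else gathered
assert torch.equal(gathered, x)                # padding removed, original sequence restored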
diffsynth/extensions/ESRGAN/__init__.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ from einops import repeat
3
+ from PIL import Image
4
+ import numpy as np
5
+
6
+
7
+ class ResidualDenseBlock(torch.nn.Module):
8
+
9
+ def __init__(self, num_feat=64, num_grow_ch=32):
10
+ super(ResidualDenseBlock, self).__init__()
11
+ self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
12
+ self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
13
+ self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
14
+ self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
15
+ self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
16
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
17
+
18
+ def forward(self, x):
19
+ x1 = self.lrelu(self.conv1(x))
20
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
21
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
22
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
23
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
24
+ return x5 * 0.2 + x
25
+
26
+
27
+ class RRDB(torch.nn.Module):
28
+
29
+ def __init__(self, num_feat, num_grow_ch=32):
30
+ super(RRDB, self).__init__()
31
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
32
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
33
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
34
+
35
+ def forward(self, x):
36
+ out = self.rdb1(x)
37
+ out = self.rdb2(out)
38
+ out = self.rdb3(out)
39
+ return out * 0.2 + x
40
+
41
+
42
+ class RRDBNet(torch.nn.Module):
43
+
44
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
45
+ super(RRDBNet, self).__init__()
46
+ self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
47
+ self.body = torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
48
+ self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
49
+ # upsample
50
+ self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
51
+ self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
52
+ self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
53
+ self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
54
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
55
+
56
+ def forward(self, x):
57
+ feat = x
58
+ feat = self.conv_first(feat)
59
+ body_feat = self.conv_body(self.body(feat))
60
+ feat = feat + body_feat
61
+ # upsample
62
+ feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
63
+ feat = self.lrelu(self.conv_up1(feat))
64
+ feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
65
+ feat = self.lrelu(self.conv_up2(feat))
66
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
67
+ return out
68
+
69
+ @staticmethod
70
+ def state_dict_converter():
71
+ return RRDBNetStateDictConverter()
72
+
73
+
74
+ class RRDBNetStateDictConverter:
75
+ def __init__(self):
76
+ pass
77
+
78
+ def from_diffusers(self, state_dict):
79
+ return state_dict, {"upcast_to_float32": True}
80
+
81
+ def from_civitai(self, state_dict):
82
+ return state_dict, {"upcast_to_float32": True}
83
+
84
+
85
+ class ESRGAN(torch.nn.Module):
86
+ def __init__(self, model):
87
+ super().__init__()
88
+ self.model = model
89
+
90
+ @staticmethod
91
+ def from_model_manager(model_manager):
92
+ return ESRGAN(model_manager.fetch_model("esrgan"))
93
+
94
+ def process_image(self, image):
95
+ image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
96
+ return image
97
+
98
+ def process_images(self, images):
99
+ images = [self.process_image(image) for image in images]
100
+ images = torch.stack(images)
101
+ return images
102
+
103
+ def decode_images(self, images):
104
+ images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
105
+ images = [Image.fromarray(image) for image in images]
106
+ return images
107
+
108
+ @torch.no_grad()
109
+ def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
110
+ if not isinstance(images, list):
111
+ images = [images]
112
+ is_single_image = True
113
+ else:
114
+ is_single_image = False
115
+
116
+ # Preprocess
117
+ input_tensor = self.process_images(images)
118
+
119
+ # Interpolate
120
+ output_tensor = []
121
+ for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
122
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
123
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
124
+ batch_input_tensor = batch_input_tensor.to(
125
+ device=self.model.conv_first.weight.device,
126
+ dtype=self.model.conv_first.weight.dtype)
127
+ batch_output_tensor = self.model(batch_input_tensor)
128
+ output_tensor.append(batch_output_tensor.cpu())
129
+
130
+ # Output
131
+ output_tensor = torch.concat(output_tensor, dim=0)
132
+
133
+ # To images
134
+ output_images = self.decode_images(output_tensor)
135
+ if is_single_image:
136
+ output_images = output_images[0]
137
+ return output_images
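A shape-level smoke test of the wrapper above. A randomly initialised RRDBNet (trimmed to 2 blocks to keep it light; the released RealESRGAN_x4 weights use 23) stands in for real weights, so the pixels are meaningless, but the two 2x stages (4x total) show up in the output size.

import numpy as np
from PIL import Image
from diffsynth.extensions.ESRGAN import RRDBNet, ESRGAN

net = RRDBNet(num_block=2)       # random weights; load RealESRGAN_x4.pth for real results
upscaler = ESRGAN(net)
image = Image.fromarray((np.random.rand(64, 64, 3) * 255).astype("uint8"))
upscaled = upscaler.upscale(image)
print(image.size, "->", upscaled.size)   # (64, 64) -> (256, 256)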
diffsynth/extensions/FastBlend/__init__.py ADDED
@@ -0,0 +1,63 @@
1
+ from .runners.fast import TableManager, PyramidPatchMatcher
2
+ from PIL import Image
3
+ import numpy as np
4
+ import cupy as cp
5
+
6
+
7
+ class FastBlendSmoother:
8
+ def __init__(self):
9
+ self.batch_size = 8
10
+ self.window_size = 64
11
+ self.ebsynth_config = {
12
+ "minimum_patch_size": 5,
13
+ "threads_per_block": 8,
14
+ "num_iter": 5,
15
+ "gpu_id": 0,
16
+ "guide_weight": 10.0,
17
+ "initialize": "identity",
18
+ "tracking_window_size": 0,
19
+ }
20
+
21
+ @staticmethod
22
+ def from_model_manager(model_manager):
23
+ # TODO: fetch GPU ID from model_manager
24
+ return FastBlendSmoother()
25
+
26
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
27
+ frames_guide = [np.array(frame) for frame in frames_guide]
28
+ frames_style = [np.array(frame) for frame in frames_style]
29
+ table_manager = TableManager()
30
+ patch_match_engine = PyramidPatchMatcher(
31
+ image_height=frames_style[0].shape[0],
32
+ image_width=frames_style[0].shape[1],
33
+ channel=3,
34
+ **ebsynth_config
35
+ )
36
+ # left part
37
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
38
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
39
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
40
+ # right part
41
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
42
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
43
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
44
+ # merge
45
+ frames = []
46
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
47
+ weight_m = -1
48
+ weight = weight_l + weight_m + weight_r
49
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
50
+ frames.append(frame)
51
+ frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
52
+ return frames
53
+
54
+ def __call__(self, rendered_frames, original_frames=None, **kwargs):
55
+ frames = self.run(
56
+ original_frames, rendered_frames,
57
+ self.batch_size, self.window_size, self.ebsynth_config
58
+ )
59
+ mempool = cp.get_default_memory_pool()
60
+ pinned_mempool = cp.get_default_pinned_memory_pool()
61
+ mempool.free_all_blocks()
62
+ pinned_mempool.free_all_blocks()
63
+ return frames
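Call pattern for the smoother, which needs a CUDA GPU plus cupy since the patch matcher runs as CuPy kernels. The frames below are random placeholders used for both guide and style, just to show the interface; in practice original_frames are the guide video frames and rendered_frames the stylised ones, at the same length and resolution.

import numpy as np
from PIL import Image
from diffsynth.extensions.FastBlend import FastBlendSmoother

frames = [Image.fromarray((np.random.rand(256, 256, 3) * 255).astype("uint8")) for _ in range(8)]

smoother = FastBlendSmoother()
smoother.window_size = 4          # smaller sliding window: weaker smoothing, lower memory use
blended = smoother(frames, original_frames=frames)
print(len(blended), blended[0].size)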
diffsynth/extensions/FastBlend/api.py ADDED
@@ -0,0 +1,397 @@
1
+ from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
2
+ from .data import VideoData, get_video_fps, save_video, search_for_images
3
+ import os
4
+ import gradio as gr
5
+
6
+
7
+ def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
8
+ frames_guide = VideoData(video_guide, video_guide_folder)
9
+ frames_style = VideoData(video_style, video_style_folder)
10
+ message = ""
11
+ if len(frames_guide) < len(frames_style):
12
+ message += f"The frame counts do not match. Only the first {len(frames_guide)} frames of the style video will be used.\n"
13
+ frames_style.set_length(len(frames_guide))
14
+ elif len(frames_guide) > len(frames_style):
15
+ message += f"The frame counts do not match. Only the first {len(frames_style)} frames of the guide video will be used.\n"
16
+ frames_guide.set_length(len(frames_style))
17
+ height_guide, width_guide = frames_guide.shape()
18
+ height_style, width_style = frames_style.shape()
19
+ if height_guide != height_style or width_guide != width_style:
20
+ message += f"The frame shapes do not match. The style video frames will be resized to (height: {height_guide}, width: {width_guide}).\n"
21
+ frames_style.set_shape(height_guide, width_guide)
22
+ return frames_guide, frames_style, message
23
+
24
+
25
+ def smooth_video(
26
+ video_guide,
27
+ video_guide_folder,
28
+ video_style,
29
+ video_style_folder,
30
+ mode,
31
+ window_size,
32
+ batch_size,
33
+ tracking_window_size,
34
+ output_path,
35
+ fps,
36
+ minimum_patch_size,
37
+ num_iter,
38
+ guide_weight,
39
+ initialize,
40
+ progress = None,
41
+ ):
42
+ # input
43
+ frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
44
+ if len(message) > 0:
45
+ print(message)
46
+ # output
47
+ if output_path == "":
48
+ if video_style is None:
49
+ output_path = os.path.join(video_style_folder, "output")
50
+ else:
51
+ output_path = os.path.join(os.path.split(video_style)[0], "output")
52
+ os.makedirs(output_path, exist_ok=True)
53
+ print("No valid output_path. Your video will be saved here:", output_path)
54
+ elif not os.path.exists(output_path):
55
+ os.makedirs(output_path, exist_ok=True)
56
+ print("Your video will be saved here:", output_path)
57
+ frames_path = os.path.join(output_path, "frames")
58
+ video_path = os.path.join(output_path, "video.mp4")
59
+ os.makedirs(frames_path, exist_ok=True)
60
+ # process
61
+ if mode == "Fast" or mode == "Balanced":
62
+ tracking_window_size = 0
63
+ ebsynth_config = {
64
+ "minimum_patch_size": minimum_patch_size,
65
+ "threads_per_block": 8,
66
+ "num_iter": num_iter,
67
+ "gpu_id": 0,
68
+ "guide_weight": guide_weight,
69
+ "initialize": initialize,
70
+ "tracking_window_size": tracking_window_size,
71
+ }
72
+ if mode == "Fast":
73
+ FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
74
+ elif mode == "Balanced":
75
+ BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
76
+ elif mode == "Accurate":
77
+ AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
78
+ # output
79
+ try:
80
+ fps = int(fps)
81
+ except:
82
+ fps = get_video_fps(video_style) if video_style is not None else 30
83
+ print("Fps:", fps)
84
+ print("Saving video...")
85
+ video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
86
+ print("Success!")
87
+ print("Your frames are here:", frames_path)
88
+ print("Your video is here:", video_path)
89
+ return output_path, fps, video_path
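For reference, a hypothetical direct call to smooth_video (the Gradio tab further below wires the same arguments to UI widgets). The file paths are placeholders, a CUDA GPU with cupy is assumed, and "Fast" mode ignores tracking_window_size.

from diffsynth.extensions.FastBlend.api import smooth_video  # only needed when calling from outside this module

output_dir, out_fps, out_video = smooth_video(
    video_guide="guide.mp4", video_guide_folder=None,
    video_style="style.mp4", video_style_folder=None,
    mode="Fast", window_size=15, batch_size=8, tracking_window_size=0,
    output_path="", fps="", minimum_patch_size=5, num_iter=5,
    guide_weight=10.0, initialize="identity",
)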
90
+
91
+
92
+ class KeyFrameMatcher:
93
+ def __init__(self):
94
+ pass
95
+
96
+ def extract_number_from_filename(self, file_name):
97
+ result = []
98
+ number = -1
99
+ for i in file_name:
100
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
101
+ if number == -1:
102
+ number = 0
103
+ number = number*10 + ord(i) - ord("0")
104
+ else:
105
+ if number != -1:
106
+ result.append(number)
107
+ number = -1
108
+ if number != -1:
109
+ result.append(number)
110
+ result = tuple(result)
111
+ return result
112
+
113
+ def extract_number_from_filenames(self, file_names):
114
+ numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
115
+ min_length = min(len(i) for i in numbers)
116
+ for i in range(min_length-1, -1, -1):
117
+ if len(set(number[i] for number in numbers))==len(file_names):
118
+ return [number[i] for number in numbers]
119
+ return list(range(len(file_names)))
120
+
121
+ def match_using_filename(self, file_names_a, file_names_b):
122
+ file_names_b_set = set(file_names_b)
123
+ matched_file_name = []
124
+ for file_name in file_names_a:
125
+ if file_name not in file_names_b_set:
126
+ matched_file_name.append(None)
127
+ else:
128
+ matched_file_name.append(file_name)
129
+ return matched_file_name
130
+
131
+ def match_using_numbers(self, file_names_a, file_names_b):
132
+ numbers_a = self.extract_number_from_filenames(file_names_a)
133
+ numbers_b = self.extract_number_from_filenames(file_names_b)
134
+ numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
135
+ matched_file_name = []
136
+ for number in numbers_a:
137
+ if number in numbers_b_dict:
138
+ matched_file_name.append(numbers_b_dict[number])
139
+ else:
140
+ matched_file_name.append(None)
141
+ return matched_file_name
142
+
143
+ def match_filenames(self, file_names_a, file_names_b):
144
+ matched_file_name = self.match_using_filename(file_names_a, file_names_b)
145
+ if sum([i is not None for i in matched_file_name]) > 0:
146
+ return matched_file_name
147
+ matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
148
+ return matched_file_name
149
+
150
+
151
+ def detect_frames(frames_path, keyframes_path):
152
+ if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
153
+ return "Please input the directory of guide video and rendered frames"
154
+ elif not os.path.exists(frames_path):
155
+ return "Please input the directory of guide video"
156
+ elif not os.path.exists(keyframes_path):
157
+ return "Please input the directory of rendered frames"
158
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
159
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
160
+ if len(frames)==0:
161
+ return f"No images detected in {frames_path}"
162
+ if len(keyframes)==0:
163
+ return f"No images detected in {keyframes_path}"
164
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
165
+ max_filename_length = max([len(i) for i in frames])
166
+ if sum([i is not None for i in matched_keyframes])==0:
167
+ message = ""
168
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
169
+ message += frame + " " * (max_filename_length - len(frame) + 1)
170
+ message += "--> No matched keyframes\n"
171
+ else:
172
+ message = ""
173
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
174
+ message += frame + " " * (max_filename_length - len(frame) + 1)
175
+ if matched_keyframe is None:
176
+ message += "--> [to be rendered]\n"
177
+ else:
178
+ message += f"--> {matched_keyframe}\n"
179
+ return message
180
+
181
+
182
+ def check_input_for_interpolating(frames_path, keyframes_path):
183
+ # search for images
184
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
185
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
186
+ # match frames
187
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
188
+ file_list = [file_name for file_name in matched_keyframes if file_name is not None]
189
+ index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
190
+ frames_guide = VideoData(None, frames_path)
191
+ frames_style = VideoData(None, keyframes_path, file_list=file_list)
192
+ # match shape
193
+ message = ""
194
+ height_guide, width_guide = frames_guide.shape()
195
+ height_style, width_style = frames_style.shape()
196
+ if height_guide != height_style or width_guide != width_style:
197
+ message += f"The frame shapes do not match. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide}).\n"
198
+ frames_style.set_shape(height_guide, width_guide)
199
+ return frames_guide, frames_style, index_style, message
200
+
201
+
202
+ def interpolate_video(
203
+ frames_path,
204
+ keyframes_path,
205
+ output_path,
206
+ fps,
207
+ batch_size,
208
+ tracking_window_size,
209
+ minimum_patch_size,
210
+ num_iter,
211
+ guide_weight,
212
+ initialize,
213
+ progress = None,
214
+ ):
215
+ # input
216
+ frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
217
+ if len(message) > 0:
218
+ print(message)
219
+ # output
220
+ if output_path == "":
221
+ output_path = os.path.join(keyframes_path, "output")
222
+ os.makedirs(output_path, exist_ok=True)
223
+ print("No valid output_path. Your video will be saved here:", output_path)
224
+ elif not os.path.exists(output_path):
225
+ os.makedirs(output_path, exist_ok=True)
226
+ print("Your video will be saved here:", output_path)
227
+ output_frames_path = os.path.join(output_path, "frames")
228
+ output_video_path = os.path.join(output_path, "video.mp4")
229
+ os.makedirs(output_frames_path, exist_ok=True)
230
+ # process
231
+ ebsynth_config = {
232
+ "minimum_patch_size": minimum_patch_size,
233
+ "threads_per_block": 8,
234
+ "num_iter": num_iter,
235
+ "gpu_id": 0,
236
+ "guide_weight": guide_weight,
237
+ "initialize": initialize,
238
+ "tracking_window_size": tracking_window_size
239
+ }
240
+ if len(index_style)==1:
241
+ InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
242
+ else:
243
+ InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
244
+ try:
245
+ fps = int(fps)
246
+ except:
247
+ fps = 30
248
+ print("Fps:", fps)
249
+ print("Saving video...")
250
+ video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
251
+ print("Success!")
252
+ print("Your frames are here:", output_frames_path)
253
+ print("Your video is here:", video_path)
254
+ return output_path, fps, video_path
255
+
256
+
257
+ def on_ui_tabs():
258
+ with gr.Blocks(analytics_enabled=False) as ui_component:
259
+ with gr.Tab("Blend"):
260
+ gr.Markdown("""
261
+ # Blend
262
+
263
+ Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
264
+ """)
265
+ with gr.Row():
266
+ with gr.Column():
267
+ with gr.Tab("Guide video"):
268
+ video_guide = gr.Video(label="Guide video")
269
+ with gr.Tab("Guide video (images format)"):
270
+ video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
271
+ with gr.Column():
272
+ with gr.Tab("Style video"):
273
+ video_style = gr.Video(label="Style video")
274
+ with gr.Tab("Style video (images format)"):
275
+ video_style_folder = gr.Textbox(label="Style video (images format)", value="")
276
+ with gr.Column():
277
+ output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
278
+ fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
279
+ video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
280
+ btn = gr.Button(value="Blend")
281
+ with gr.Row():
282
+ with gr.Column():
283
+ gr.Markdown("# Settings")
284
+ mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
285
+ window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
286
+ batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
287
+ tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
288
+ gr.Markdown("## Advanced Settings")
289
+ minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
290
+ num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
291
+ guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
292
+ initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
293
+ with gr.Column():
294
+ gr.Markdown("""
295
+ # Reference
296
+
297
+ * Output directory: the directory to save the video.
298
+ * Inference mode
299
+
300
+ |Mode|Time|Memory|Quality|Frame by frame output|Description|
301
+ |-|-|-|-|-|-|
302
+ |Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
303
+ |Balanced|■■|■|■■|Yes|Blend the frames naively.|
304
+ |Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. Performance is best when [batch size] >= [sliding window size] * 2 + 1.|
305
+
306
+ * Sliding window size: our algorithm blends the frames within a sliding window. If the size is n, each frame is blended with the previous n frames and the next n frames. A large sliding window makes the video more fluent but can introduce blurring.
307
+ * Batch size: a larger batch size makes the program faster but requires more VRAM.
308
+ * Tracking window size (only for accurate mode): the size of the window in which our algorithm tracks moving objects. Empirically, 1 is enough.
309
+ * Advanced settings
310
+ * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
311
+ * Number of iterations: the number of iterations of patch matching. (Default: 5)
312
+ * Guide weight: a parameter that determines how strongly the motion features of the guide video are applied to the style video. (Default: 10)
313
+ * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
314
+ """)
315
+ btn.click(
316
+ smooth_video,
317
+ inputs=[
318
+ video_guide,
319
+ video_guide_folder,
320
+ video_style,
321
+ video_style_folder,
322
+ mode,
323
+ window_size,
324
+ batch_size,
325
+ tracking_window_size,
326
+ output_path,
327
+ fps,
328
+ minimum_patch_size,
329
+ num_iter,
330
+ guide_weight,
331
+ initialize
332
+ ],
333
+ outputs=[output_path, fps, video_output]
334
+ )
335
+ with gr.Tab("Interpolate"):
336
+ gr.Markdown("""
337
+ # Interpolate
338
+
339
+ Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
340
+ """)
341
+ with gr.Row():
342
+ with gr.Column():
343
+ with gr.Row():
344
+ with gr.Column():
345
+ video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
346
+ with gr.Column():
347
+ rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
348
+ with gr.Row():
349
+ detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
350
+ video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
351
+ rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
352
+ with gr.Column():
353
+ output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
354
+ fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
355
+ video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
356
+ btn_ = gr.Button(value="Interpolate")
357
+ with gr.Row():
358
+ with gr.Column():
359
+ gr.Markdown("# Settings")
360
+ batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
361
+ tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
362
+ gr.Markdown("## Advanced Settings")
363
+ minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
364
+ num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
365
+ guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
366
+ initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
367
+ with gr.Column():
368
+ gr.Markdown("""
369
+ # Reference
370
+
371
+ * Output directory: the directory to save the video.
372
+ * Batch size: a larger batch size makes the program faster but requires more VRAM.
373
+ * Tracking window size: the size of the window in which our algorithm tracks moving objects. Empirically, 1 is enough.
374
+ * Advanced settings
375
+ * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
376
+ * Number of iterations: the number of iterations of patch matching. (Default: 5)
377
+ * Guide weight: a parameter that determines how strongly the motion features of the guide video are applied to the style video. (Default: 10)
378
+ * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
379
+ """)
380
+ btn_.click(
381
+ interpolate_video,
382
+ inputs=[
383
+ video_guide_folder_,
384
+ rendered_keyframes_,
385
+ output_path_,
386
+ fps_,
387
+ batch_size_,
388
+ tracking_window_size_,
389
+ minimum_patch_size_,
390
+ num_iter_,
391
+ guide_weight_,
392
+ initialize_,
393
+ ],
394
+ outputs=[output_path_, fps_, video_output_]
395
+ )
396
+
397
+ return [(ui_component, "FastBlend", "FastBlend_ui")]
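The UI above is a thin wrapper: the `Interpolate` button simply forwards its widget values to `interpolate_video`. A minimal sketch of calling that function directly, assuming two hypothetical directories of numbered .png/.jpg frames and the default slider settings shown in the tab:

    output_dir, fps, video_path = interpolate_video(
        frames_path="data/guide_frames",        # hypothetical: every frame of the guide video
        keyframes_path="data/rendered_frames",  # hypothetical: rendered keyframes with matching file names
        output_path="",                         # empty -> saved next to the rendered keyframes
        fps=30,
        batch_size=8,
        tracking_window_size=0,
        minimum_patch_size=15,                  # interpolation wants a larger patch size than blending
        num_iter=5,
        guide_weight=10.0,
        initialize="identity",
    )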
diffsynth/extensions/FastBlend/cupy_kernels.py ADDED
@@ -0,0 +1,119 @@
1
+ import cupy as cp
2
+
3
+ remapping_kernel = cp.RawKernel(r'''
4
+ extern "C" __global__
5
+ void remap(
6
+ const int height,
7
+ const int width,
8
+ const int channel,
9
+ const int patch_size,
10
+ const int pad_size,
11
+ const float* source_style,
12
+ const int* nnf,
13
+ float* target_style
14
+ ) {
15
+ const int r = (patch_size - 1) / 2;
16
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
17
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
18
+ if (x >= height or y >= width) return;
19
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
20
+ const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
21
+ const int min_px = x < r ? -x : -r;
22
+ const int max_px = x + r > height - 1 ? height - 1 - x : r;
23
+ const int min_py = y < r ? -y : -r;
24
+ const int max_py = y + r > width - 1 ? width - 1 - y : r;
25
+ int num = 0;
26
+ for (int px = min_px; px <= max_px; px++){
27
+ for (int py = min_py; py <= max_py; py++){
28
+ const int nid = (x + px) * width + y + py;
29
+ const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
30
+ const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
31
+ if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
32
+ const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
33
+ num++;
34
+ for (int c = 0; c < channel; c++){
35
+ target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
36
+ }
37
+ }
38
+ }
39
+ for (int c = 0; c < channel; c++){
40
+ target_style[z + pid * channel + c] /= num;
41
+ }
42
+ }
43
+ ''', 'remap')
44
+
45
+
46
+ patch_error_kernel = cp.RawKernel(r'''
47
+ extern "C" __global__
48
+ void patch_error(
49
+ const int height,
50
+ const int width,
51
+ const int channel,
52
+ const int patch_size,
53
+ const int pad_size,
54
+ const float* source,
55
+ const int* nnf,
56
+ const float* target,
57
+ float* error
58
+ ) {
59
+ const int r = (patch_size - 1) / 2;
60
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
61
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
62
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
63
+ if (x >= height or y >= width) return;
64
+ const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
65
+ const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
66
+ float e = 0;
67
+ for (int px = -r; px <= r; px++){
68
+ for (int py = -r; py <= r; py++){
69
+ const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
70
+ const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
71
+ for (int c = 0; c < channel; c++){
72
+ const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
73
+ e += diff * diff;
74
+ }
75
+ }
76
+ }
77
+ error[blockIdx.z * height * width + x * width + y] = e;
78
+ }
79
+ ''', 'patch_error')
80
+
81
+
82
+ pairwise_patch_error_kernel = cp.RawKernel(r'''
83
+ extern "C" __global__
84
+ void pairwise_patch_error(
85
+ const int height,
86
+ const int width,
87
+ const int channel,
88
+ const int patch_size,
89
+ const int pad_size,
90
+ const float* source_a,
91
+ const int* nnf_a,
92
+ const float* source_b,
93
+ const int* nnf_b,
94
+ float* error
95
+ ) {
96
+ const int r = (patch_size - 1) / 2;
97
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
98
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
99
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
100
+ if (x >= height or y >= width) return;
101
+ const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
102
+ const int x_a = nnf_a[z_nnf + 0];
103
+ const int y_a = nnf_a[z_nnf + 1];
104
+ const int x_b = nnf_b[z_nnf + 0];
105
+ const int y_b = nnf_b[z_nnf + 1];
106
+ float e = 0;
107
+ for (int px = -r; px <= r; px++){
108
+ for (int py = -r; py <= r; py++){
109
+ const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
110
+ const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
111
+ for (int c = 0; c < channel; c++){
112
+ const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
113
+ e += diff * diff;
114
+ }
115
+ }
116
+ }
117
+ error[blockIdx.z * height * width + x * width + y] = e;
118
+ }
119
+ ''', 'pairwise_patch_error')
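These RawKernels are launched by `PatchMatcher` in `patch_match.py` (added below) with a 3-D grid of (height blocks, width blocks, batch) and an argument tuple matching the kernel signature. A minimal sketch of invoking `remapping_kernel` on one padded 64x64 RGB image with an identity NNF, assuming a CUDA device with CuPy installed:

    import cupy as cp

    h, w, c, patch_size = 64, 64, 3, 5
    pad = (patch_size - 1) // 2
    threads = 8
    grid = ((h + threads - 1) // threads, (w + threads - 1) // threads, 1)
    block = (threads, threads)

    source_style = cp.random.rand(1, h + 2 * pad, w + 2 * pad, c).astype(cp.float32)
    target_style = cp.zeros_like(source_style)
    # identity NNF: every pixel maps back to itself
    nnf = cp.stack(cp.meshgrid(cp.arange(h), cp.arange(w), indexing="ij"), axis=2)
    nnf = nnf.reshape(1, h, w, 2).astype(cp.int32)

    remapping_kernel(grid, block, (h, w, c, patch_size, pad, source_style, nnf, target_style))
    # target_style now holds the patch-averaged remapping of source_style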
diffsynth/extensions/FastBlend/data.py ADDED
@@ -0,0 +1,146 @@
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+
6
+ def read_video(file_name):
7
+ reader = imageio.get_reader(file_name)
8
+ video = []
9
+ for frame in reader:
10
+ frame = np.array(frame)
11
+ video.append(frame)
12
+ reader.close()
13
+ return video
14
+
15
+
16
+ def get_video_fps(file_name):
17
+ reader = imageio.get_reader(file_name)
18
+ fps = reader.get_meta_data()["fps"]
19
+ reader.close()
20
+ return fps
21
+
22
+
23
+ def save_video(frames_path, video_path, num_frames, fps):
24
+ writer = imageio.get_writer(video_path, fps=fps, quality=9)
25
+ for i in range(num_frames):
26
+ frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
27
+ writer.append_data(frame)
28
+ writer.close()
29
+ return video_path
30
+
31
+
32
+ class LowMemoryVideo:
33
+ def __init__(self, file_name):
34
+ self.reader = imageio.get_reader(file_name)
35
+
36
+ def __len__(self):
37
+ return self.reader.count_frames()
38
+
39
+ def __getitem__(self, item):
40
+ return np.array(self.reader.get_data(item))
41
+
42
+ def __del__(self):
43
+ self.reader.close()
44
+
45
+
46
+ def split_file_name(file_name):
47
+ result = []
48
+ number = -1
49
+ for i in file_name:
50
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
51
+ if number == -1:
52
+ number = 0
53
+ number = number*10 + ord(i) - ord("0")
54
+ else:
55
+ if number != -1:
56
+ result.append(number)
57
+ number = -1
58
+ result.append(i)
59
+ if number != -1:
60
+ result.append(number)
61
+ result = tuple(result)
62
+ return result
63
+
64
+
65
+ def search_for_images(folder):
66
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
67
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
68
+ file_list = [i[1] for i in sorted(file_list)]
69
+ file_list = [os.path.join(folder, i) for i in file_list]
70
+ return file_list
71
+
72
+
73
+ def read_images(folder):
74
+ file_list = search_for_images(folder)
75
+ frames = [np.array(Image.open(i)) for i in file_list]
76
+ return frames
77
+
78
+
79
+ class LowMemoryImageFolder:
80
+ def __init__(self, folder, file_list=None):
81
+ if file_list is None:
82
+ self.file_list = search_for_images(folder)
83
+ else:
84
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
85
+
86
+ def __len__(self):
87
+ return len(self.file_list)
88
+
89
+ def __getitem__(self, item):
90
+ return np.array(Image.open(self.file_list[item]))
91
+
92
+ def __del__(self):
93
+ pass
94
+
95
+
96
+ class VideoData:
97
+ def __init__(self, video_file, image_folder, **kwargs):
98
+ if video_file is not None:
99
+ self.data_type = "video"
100
+ self.data = LowMemoryVideo(video_file, **kwargs)
101
+ elif image_folder is not None:
102
+ self.data_type = "images"
103
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
104
+ else:
105
+ raise ValueError("Cannot open video or image folder")
106
+ self.length = None
107
+ self.height = None
108
+ self.width = None
109
+
110
+ def raw_data(self):
111
+ frames = []
112
+ for i in range(self.__len__()):
113
+ frames.append(self.__getitem__(i))
114
+ return frames
115
+
116
+ def set_length(self, length):
117
+ self.length = length
118
+
119
+ def set_shape(self, height, width):
120
+ self.height = height
121
+ self.width = width
122
+
123
+ def __len__(self):
124
+ if self.length is None:
125
+ return len(self.data)
126
+ else:
127
+ return self.length
128
+
129
+ def shape(self):
130
+ if self.height is not None and self.width is not None:
131
+ return self.height, self.width
132
+ else:
133
+ height, width, _ = self.__getitem__(0).shape
134
+ return height, width
135
+
136
+ def __getitem__(self, item):
137
+ frame = self.data.__getitem__(item)
138
+ height, width, _ = frame.shape
139
+ if self.height is not None and self.width is not None:
140
+ if self.height != height or self.width != width:
141
+ frame = Image.fromarray(frame).resize((self.width, self.height))
142
+ frame = np.array(frame)
143
+ return frame
144
+
145
+ def __del__(self):
146
+ pass
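A minimal sketch of how these helpers are used together, assuming a hypothetical folder of numbered frames: open it with `VideoData`, query its resolution, and read frames back as NumPy arrays (frames are resized on the fly once `set_shape` is called):

    video = VideoData(None, "data/guide_frames")   # hypothetical folder of .png/.jpg frames
    height, width = video.shape()
    video.set_shape(height, width)                 # frames returned below are resized to this shape if needed
    first_frame = video[0]                         # numpy array of shape (height, width, 3)
    print(len(video), first_frame.shape)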
diffsynth/extensions/FastBlend/patch_match.py ADDED
@@ -0,0 +1,299 @@
1
+ from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
2
+ import numpy as np
3
+ import cupy as cp
4
+ import cv2
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ class PatchMatcher:
9
+ def __init__(
10
+ self, height, width, channel, minimum_patch_size,
11
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
12
+ random_search_steps=3, random_search_range=4,
13
+ use_mean_target_style=False, use_pairwise_patch_error=False,
14
+ tracking_window_size=0
15
+ ):
16
+ self.height = height
17
+ self.width = width
18
+ self.channel = channel
19
+ self.minimum_patch_size = minimum_patch_size
20
+ self.threads_per_block = threads_per_block
21
+ self.num_iter = num_iter
22
+ self.gpu_id = gpu_id
23
+ self.guide_weight = guide_weight
24
+ self.random_search_steps = random_search_steps
25
+ self.random_search_range = random_search_range
26
+ self.use_mean_target_style = use_mean_target_style
27
+ self.use_pairwise_patch_error = use_pairwise_patch_error
28
+ self.tracking_window_size = tracking_window_size
29
+
30
+ self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
31
+ self.pad_size = self.patch_size_list[0] // 2
32
+ self.grid = (
33
+ (height + threads_per_block - 1) // threads_per_block,
34
+ (width + threads_per_block - 1) // threads_per_block
35
+ )
36
+ self.block = (threads_per_block, threads_per_block)
37
+
38
+ def pad_image(self, image):
39
+ return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
40
+
41
+ def unpad_image(self, image):
42
+ return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
43
+
44
+ def apply_nnf_to_image(self, nnf, source):
45
+ batch_size = source.shape[0]
46
+ target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
47
+ remapping_kernel(
48
+ self.grid + (batch_size,),
49
+ self.block,
50
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
51
+ )
52
+ return target
53
+
54
+ def get_patch_error(self, source, nnf, target):
55
+ batch_size = source.shape[0]
56
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
57
+ patch_error_kernel(
58
+ self.grid + (batch_size,),
59
+ self.block,
60
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
61
+ )
62
+ return error
63
+
64
+ def get_pairwise_patch_error(self, source, nnf):
65
+ batch_size = source.shape[0]//2
66
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
67
+ source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
68
+ source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
69
+ pairwise_patch_error_kernel(
70
+ self.grid + (batch_size,),
71
+ self.block,
72
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
73
+ )
74
+ error = error.repeat(2, axis=0)
75
+ return error
76
+
77
+ def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
78
+ error_guide = self.get_patch_error(source_guide, nnf, target_guide)
79
+ if self.use_mean_target_style:
80
+ target_style = self.apply_nnf_to_image(nnf, source_style)
81
+ target_style = target_style.mean(axis=0, keepdims=True)
82
+ target_style = target_style.repeat(source_guide.shape[0], axis=0)
83
+ if self.use_pairwise_patch_error:
84
+ error_style = self.get_pairwise_patch_error(source_style, nnf)
85
+ else:
86
+ error_style = self.get_patch_error(source_style, nnf, target_style)
87
+ error = error_guide * self.guide_weight + error_style
88
+ return error
89
+
90
+ def clamp_bound(self, nnf):
91
+ nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
92
+ nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
93
+ return nnf
94
+
95
+ def random_step(self, nnf, r):
96
+ batch_size = nnf.shape[0]
97
+ step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
98
+ upd_nnf = self.clamp_bound(nnf + step)
99
+ return upd_nnf
100
+
101
+ def neighboor_step(self, nnf, d):
102
+ if d==0:
103
+ upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
104
+ upd_nnf[:, :, :, 0] += 1
105
+ elif d==1:
106
+ upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
107
+ upd_nnf[:, :, :, 1] += 1
108
+ elif d==2:
109
+ upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
110
+ upd_nnf[:, :, :, 0] -= 1
111
+ elif d==3:
112
+ upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
113
+ upd_nnf[:, :, :, 1] -= 1
114
+ upd_nnf = self.clamp_bound(upd_nnf)
115
+ return upd_nnf
116
+
117
+ def shift_nnf(self, nnf, d):
118
+ if d>0:
119
+ d = min(nnf.shape[0], d)
120
+ upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
121
+ else:
122
+ d = max(-nnf.shape[0], d)
123
+ upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
124
+ return upd_nnf
125
+
126
+ def track_step(self, nnf, d):
127
+ if self.use_pairwise_patch_error:
128
+ upd_nnf = cp.zeros_like(nnf)
129
+ upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
130
+ upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
131
+ else:
132
+ upd_nnf = self.shift_nnf(nnf, d)
133
+ return upd_nnf
134
+
135
+ def C(self, n, m):
136
+ # not used
137
+ c = 1
138
+ for i in range(1, n+1):
139
+ c *= i
140
+ for i in range(1, m+1):
141
+ c //= i
142
+ for i in range(1, n-m+1):
143
+ c //= i
144
+ return c
145
+
146
+ def bezier_step(self, nnf, r):
147
+ # not used
148
+ n = r * 2 - 1
149
+ upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
150
+ for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
151
+ if d>0:
152
+ ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
153
+ elif d<0:
154
+ ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
155
+ upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
156
+ upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
157
+ return upd_nnf
158
+
159
+ def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
160
+ upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
161
+ upd_idx = (upd_err < err)
162
+ nnf[upd_idx] = upd_nnf[upd_idx]
163
+ err[upd_idx] = upd_err[upd_idx]
164
+ return nnf, err
165
+
166
+ def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
167
+ for d in cp.random.permutation(4):
168
+ upd_nnf = self.neighboor_step(nnf, d)
169
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
170
+ return nnf, err
171
+
172
+ def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
173
+ for i in range(self.random_search_steps):
174
+ upd_nnf = self.random_step(nnf, self.random_search_range)
175
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
176
+ return nnf, err
177
+
178
+ def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
179
+ for d in range(1, self.tracking_window_size + 1):
180
+ upd_nnf = self.track_step(nnf, d)
181
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
182
+ upd_nnf = self.track_step(nnf, -d)
183
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
184
+ return nnf, err
185
+
186
+ def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
187
+ nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
188
+ nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
189
+ nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
190
+ return nnf, err
191
+
192
+ def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
193
+ with cp.cuda.Device(self.gpu_id):
194
+ source_guide = self.pad_image(source_guide)
195
+ target_guide = self.pad_image(target_guide)
196
+ source_style = self.pad_image(source_style)
197
+ for it in range(self.num_iter):
198
+ self.patch_size = self.patch_size_list[it]
199
+ target_style = self.apply_nnf_to_image(nnf, source_style)
200
+ err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
201
+ nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
202
+ target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
203
+ return nnf, target_style
204
+
205
+
206
+ class PyramidPatchMatcher:
207
+ def __init__(
208
+ self, image_height, image_width, channel, minimum_patch_size,
209
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
210
+ use_mean_target_style=False, use_pairwise_patch_error=False,
211
+ tracking_window_size=0,
212
+ initialize="identity"
213
+ ):
214
+ maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
215
+ self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
216
+ self.pyramid_heights = []
217
+ self.pyramid_widths = []
218
+ self.patch_matchers = []
219
+ self.minimum_patch_size = minimum_patch_size
220
+ self.num_iter = num_iter
221
+ self.gpu_id = gpu_id
222
+ self.initialize = initialize
223
+ for level in range(self.pyramid_level):
224
+ height = image_height//(2**(self.pyramid_level - 1 - level))
225
+ width = image_width//(2**(self.pyramid_level - 1 - level))
226
+ self.pyramid_heights.append(height)
227
+ self.pyramid_widths.append(width)
228
+ self.patch_matchers.append(PatchMatcher(
229
+ height, width, channel, minimum_patch_size=minimum_patch_size,
230
+ threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
231
+ use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
232
+ tracking_window_size=tracking_window_size
233
+ ))
234
+
235
+ def resample_image(self, images, level):
236
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
237
+ images_torch = torch.as_tensor(images, device='cuda', dtype=torch.float32)
238
+ images_torch = images_torch.permute(0, 3, 1, 2)
239
+ images_resample = F.interpolate(images_torch, size=(height, width), mode='area', align_corners=None)
240
+ images_resample = images_resample.permute(0, 2, 3, 1).contiguous()
241
+ return cp.asarray(images_resample)
242
+
243
+ def initialize_nnf(self, batch_size):
244
+ if self.initialize == "random":
245
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
246
+ nnf = cp.stack([
247
+ cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
248
+ cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
249
+ ], axis=3)
250
+ elif self.initialize == "identity":
251
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
252
+ nnf = cp.stack([
253
+ cp.repeat(cp.arange(height), width).reshape(height, width),
254
+ cp.tile(cp.arange(width), height).reshape(height, width)
255
+ ], axis=2)
256
+ nnf = cp.stack([nnf] * batch_size)
257
+ else:
258
+ raise NotImplementedError()
259
+ return nnf
260
+
261
+ def update_nnf(self, nnf, level):
262
+ # upscale
263
+ nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
264
+ nnf[:, 1::2, :, 0] += 1
265
+ nnf[:, :, 1::2, 1] += 1
266
+ # check if scale is 2
267
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
268
+ if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
269
+ nnf_torch = torch.as_tensor(nnf, device='cuda', dtype=torch.float32)
270
+ nnf_torch = nnf_torch.permute(0, 3, 1, 2)
271
+ nnf_resized = F.interpolate(nnf_torch, size=(height, width), mode='bilinear', align_corners=False)
272
+ nnf_resized = nnf_resized.permute(0, 2, 3, 1)
273
+ nnf = cp.asarray(nnf_resized).astype(cp.int32)
274
+ nnf = self.patch_matchers[level].clamp_bound(nnf)
275
+ return nnf
276
+
277
+ def apply_nnf_to_image(self, nnf, image):
278
+ with cp.cuda.Device(self.gpu_id):
279
+ image = self.patch_matchers[-1].pad_image(image)
280
+ image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
281
+ return image
282
+
283
+ def estimate_nnf(self, source_guide, target_guide, source_style):
284
+ with cp.cuda.Device(self.gpu_id):
285
+ if not isinstance(source_guide, cp.ndarray):
286
+ source_guide = cp.array(source_guide, dtype=cp.float32)
287
+ if not isinstance(target_guide, cp.ndarray):
288
+ target_guide = cp.array(target_guide, dtype=cp.float32)
289
+ if not isinstance(source_style, cp.ndarray):
290
+ source_style = cp.array(source_style, dtype=cp.float32)
291
+ for level in range(self.pyramid_level):
292
+ nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
293
+ source_guide_ = self.resample_image(source_guide, level)
294
+ target_guide_ = self.resample_image(target_guide, level)
295
+ source_style_ = self.resample_image(source_style, level)
296
+ nnf, target_style = self.patch_matchers[level].estimate_nnf(
297
+ source_guide_, target_guide_, source_style_, nnf
298
+ )
299
+ return nnf.get(), target_style.get()
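A minimal sketch of estimating an NNF with `PyramidPatchMatcher` (random arrays stand in for real guide/style frames; a CUDA device with CuPy and PyTorch is assumed):

    import numpy as np

    matcher = PyramidPatchMatcher(
        image_height=512, image_width=512, channel=3,
        minimum_patch_size=5, num_iter=5, guide_weight=10.0,
    )
    source_guide = (np.random.rand(1, 512, 512, 3) * 255).astype(np.float32)
    target_guide = (np.random.rand(1, 512, 512, 3) * 255).astype(np.float32)
    source_style = source_guide.copy()
    nnf, target_style = matcher.estimate_nnf(source_guide, target_guide, source_style)
    # nnf: (1, 512, 512, 2) integer field; target_style: (1, 512, 512, 3) remapped style frame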
diffsynth/extensions/FastBlend/runners/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .accurate import AccurateModeRunner
2
+ from .fast import FastModeRunner
3
+ from .balanced import BalancedModeRunner
4
+ from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
diffsynth/extensions/FastBlend/runners/accurate.py ADDED
@@ -0,0 +1,35 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class AccurateModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
13
+ patch_match_engine = PyramidPatchMatcher(
14
+ image_height=frames_style[0].shape[0],
15
+ image_width=frames_style[0].shape[1],
16
+ channel=3,
17
+ use_mean_target_style=True,
18
+ **ebsynth_config
19
+ )
20
+ # run
21
+ n = len(frames_style)
22
+ for target in tqdm(range(n), desc=desc):
23
+ l, r = max(target - window_size, 0), min(target + window_size + 1, n)
24
+ remapped_frames = []
25
+ for i in range(l, r, batch_size):
26
+ j = min(i + batch_size, r)
27
+ source_guide = np.stack([frames_guide[source] for source in range(i, j)])
28
+ target_guide = np.stack([frames_guide[target]] * (j - i))
29
+ source_style = np.stack([frames_style[source] for source in range(i, j)])
30
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
31
+ remapped_frames.append(target_style)
32
+ frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
33
+ frame = frame.clip(0, 255).astype("uint8")
34
+ if save_path is not None:
35
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
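A minimal sketch of driving `AccurateModeRunner` directly, using the same kind of `ebsynth_config` dictionary the UI code assembles (the frames here are synthetic placeholders, and the save path is hypothetical and must already exist):

    import numpy as np

    frames_guide = [np.zeros((512, 512, 3), dtype=np.float32) for _ in range(8)]
    frames_style = [np.zeros((512, 512, 3), dtype=np.float32) for _ in range(8)]
    ebsynth_config = {
        "minimum_patch_size": 5, "threads_per_block": 8, "num_iter": 5,
        "gpu_id": 0, "guide_weight": 10.0, "initialize": "identity",
        "tracking_window_size": 0,
    }
    AccurateModeRunner().run(
        frames_guide, frames_style,
        batch_size=8, window_size=15,
        ebsynth_config=ebsynth_config,
        save_path="output/frames",   # hypothetical directory
    )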
diffsynth/extensions/FastBlend/runners/balanced.py ADDED
@@ -0,0 +1,46 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class BalancedModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
13
+ patch_match_engine = PyramidPatchMatcher(
14
+ image_height=frames_style[0].shape[0],
15
+ image_width=frames_style[0].shape[1],
16
+ channel=3,
17
+ **ebsynth_config
18
+ )
19
+ # tasks
20
+ n = len(frames_style)
21
+ tasks = []
22
+ for target in range(n):
23
+ for source in range(target - window_size, target + window_size + 1):
24
+ if source >= 0 and source < n and source != target:
25
+ tasks.append((source, target))
26
+ # run
27
+ frames = [(None, 1) for i in range(n)]
28
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
29
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
30
+ source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
31
+ target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
32
+ source_style = np.stack([frames_style[source] for source, target in tasks_batch])
33
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
34
+ for (source, target), result in zip(tasks_batch, target_style):
35
+ frame, weight = frames[target]
36
+ if frame is None:
37
+ frame = frames_style[target]
38
+ frames[target] = (
39
+ frame * (weight / (weight + 1)) + result / (weight + 1),
40
+ weight + 1
41
+ )
42
+ if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
43
+ frame = frame.clip(0, 255).astype("uint8")
44
+ if save_path is not None:
45
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
46
+ frames[target] = (None, 1)
diffsynth/extensions/FastBlend/runners/fast.py ADDED
@@ -0,0 +1,141 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import functools, os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class TableManager:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def task_list(self, n):
13
+ tasks = []
14
+ max_level = 1
15
+ while (1<<max_level)<=n:
16
+ max_level += 1
17
+ for i in range(n):
18
+ j = i
19
+ for level in range(max_level):
20
+ if i&(1<<level):
21
+ continue
22
+ j |= 1<<level
23
+ if j>=n:
24
+ break
25
+ meta_data = {
26
+ "source": i,
27
+ "target": j,
28
+ "level": level + 1
29
+ }
30
+ tasks.append(meta_data)
31
+ tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
32
+ return tasks
33
+
34
+ def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
35
+ n = len(frames_guide)
36
+ tasks = self.task_list(n)
37
+ remapping_table = [[(frames_style[i], 1)] for i in range(n)]
38
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
39
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
40
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
41
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
42
+ source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
43
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
44
+ for task, result in zip(tasks_batch, target_style):
45
+ target, level = task["target"], task["level"]
46
+ if len(remapping_table[target])==level:
47
+ remapping_table[target].append((result, 1))
48
+ else:
49
+ frame, weight = remapping_table[target][level]
50
+ remapping_table[target][level] = (
51
+ frame * (weight / (weight + 1)) + result / (weight + 1),
52
+ weight + 1
53
+ )
54
+ return remapping_table
55
+
56
+ def remapping_table_to_blending_table(self, table):
57
+ for i in range(len(table)):
58
+ for j in range(1, len(table[i])):
59
+ frame_1, weight_1 = table[i][j-1]
60
+ frame_2, weight_2 = table[i][j]
61
+ frame = (frame_1 + frame_2) / 2
62
+ weight = weight_1 + weight_2
63
+ table[i][j] = (frame, weight)
64
+ return table
65
+
66
+ def tree_query(self, leftbound, rightbound):
67
+ node_list = []
68
+ node_index = rightbound
69
+ while node_index>=leftbound:
70
+ node_level = 0
71
+ while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
72
+ node_level += 1
73
+ node_list.append((node_index, node_level))
74
+ node_index -= 1<<node_level
75
+ return node_list
76
+
77
+ def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
78
+ n = len(blending_table)
79
+ tasks = []
80
+ frames_result = []
81
+ for target in range(n):
82
+ node_list = self.tree_query(max(target-window_size, 0), target)
83
+ for source, level in node_list:
84
+ if source!=target:
85
+ meta_data = {
86
+ "source": source,
87
+ "target": target,
88
+ "level": level
89
+ }
90
+ tasks.append(meta_data)
91
+ else:
92
+ frames_result.append(blending_table[target][level])
93
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
94
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
95
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
96
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
97
+ source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
98
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
99
+ for task, frame_2 in zip(tasks_batch, target_style):
100
+ source, target, level = task["source"], task["target"], task["level"]
101
+ frame_1, weight_1 = frames_result[target]
102
+ weight_2 = blending_table[source][level][1]
103
+ weight = weight_1 + weight_2
104
+ frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
105
+ frames_result[target] = (frame, weight)
106
+ return frames_result
107
+
108
+
109
+ class FastModeRunner:
110
+ def __init__(self):
111
+ pass
112
+
113
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
114
+ frames_guide = frames_guide.raw_data()
115
+ frames_style = frames_style.raw_data()
116
+ table_manager = TableManager()
117
+ patch_match_engine = PyramidPatchMatcher(
118
+ image_height=frames_style[0].shape[0],
119
+ image_width=frames_style[0].shape[1],
120
+ channel=3,
121
+ **ebsynth_config
122
+ )
123
+ # left part
124
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
125
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
126
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
127
+ # right part
128
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
129
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
130
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
131
+ # merge
132
+ frames = []
133
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
134
+ weight_m = -1
135
+ weight = weight_l + weight_m + weight_r
136
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
137
+ frames.append(frame)
138
+ frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
139
+ if save_path is not None:
140
+ for target, frame in enumerate(frames):
141
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
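Fast mode blends over a binary-indexed tree of frame pairs: `task_list` enumerates the (source, target) remappings to precompute at each level, and `tree_query` later decomposes any sliding window into a handful of those precomputed nodes. A quick sketch of inspecting the schedule for 8 frames (pure Python, no GPU required):

    tm = TableManager()
    for task in tm.task_list(8):
        print(task)                  # e.g. {'source': 0, 'target': 1, 'level': 1}
    print(tm.tree_query(0, 7))       # nodes whose blended results cover the window [0, 7]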
diffsynth/extensions/FastBlend/runners/interpolation.py ADDED
@@ -0,0 +1,121 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class InterpolationModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def get_index_dict(self, index_style):
13
+ index_dict = {}
14
+ for i, index in enumerate(index_style):
15
+ index_dict[index] = i
16
+ return index_dict
17
+
18
+ def get_weight(self, l, m, r):
19
+ weight_l, weight_r = abs(m - r), abs(m - l)
20
+ if weight_l + weight_r == 0:
21
+ weight_l, weight_r = 0.5, 0.5
22
+ else:
23
+ weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
24
+ return weight_l, weight_r
25
+
26
+ def get_task_group(self, index_style, n):
27
+ task_group = []
28
+ index_style = sorted(index_style)
29
+ # first frame
30
+ if index_style[0]>0:
31
+ tasks = []
32
+ for m in range(index_style[0]):
33
+ tasks.append((index_style[0], m, index_style[0]))
34
+ task_group.append(tasks)
35
+ # middle frames
36
+ for l, r in zip(index_style[:-1], index_style[1:]):
37
+ tasks = []
38
+ for m in range(l, r):
39
+ tasks.append((l, m, r))
40
+ task_group.append(tasks)
41
+ # last frame
42
+ tasks = []
43
+ for m in range(index_style[-1], n):
44
+ tasks.append((index_style[-1], m, index_style[-1]))
45
+ task_group.append(tasks)
46
+ return task_group
47
+
48
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
49
+ patch_match_engine = PyramidPatchMatcher(
50
+ image_height=frames_style[0].shape[0],
51
+ image_width=frames_style[0].shape[1],
52
+ channel=3,
53
+ use_mean_target_style=False,
54
+ use_pairwise_patch_error=True,
55
+ **ebsynth_config
56
+ )
57
+ # task
58
+ index_dict = self.get_index_dict(index_style)
59
+ task_group = self.get_task_group(index_style, len(frames_guide))
60
+ # run
61
+ for tasks in task_group:
62
+ index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
63
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
64
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
65
+ source_guide, target_guide, source_style = [], [], []
66
+ for l, m, r in tasks_batch:
67
+ # l -> m
68
+ source_guide.append(frames_guide[l])
69
+ target_guide.append(frames_guide[m])
70
+ source_style.append(frames_style[index_dict[l]])
71
+ # r -> m
72
+ source_guide.append(frames_guide[r])
73
+ target_guide.append(frames_guide[m])
74
+ source_style.append(frames_style[index_dict[r]])
75
+ source_guide = np.stack(source_guide)
76
+ target_guide = np.stack(target_guide)
77
+ source_style = np.stack(source_style)
78
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
79
+ if save_path is not None:
80
+ for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
81
+ weight_l, weight_r = self.get_weight(l, m, r)
82
+ frame = frame_l * weight_l + frame_r * weight_r
83
+ frame = frame.clip(0, 255).astype("uint8")
84
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
85
+
86
+
87
+ class InterpolationModeSingleFrameRunner:
88
+ def __init__(self):
89
+ pass
90
+
91
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
92
+ # check input
93
+ tracking_window_size = ebsynth_config["tracking_window_size"]
94
+ if tracking_window_size * 2 >= batch_size:
95
+ raise ValueError("batch_size should be larger than tracking_window_size * 2")
96
+ frame_style = frames_style[0]
97
+ frame_guide = frames_guide[index_style[0]]
98
+ patch_match_engine = PyramidPatchMatcher(
99
+ image_height=frame_style.shape[0],
100
+ image_width=frame_style.shape[1],
101
+ channel=3,
102
+ **ebsynth_config
103
+ )
104
+ # run
105
+ frame_id, n = 0, len(frames_guide)
106
+ for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
107
+ if i + batch_size > n:
108
+ l, r = max(n - batch_size, 0), n
109
+ else:
110
+ l, r = i, i + batch_size
111
+ source_guide = np.stack([frame_guide] * (r-l))
112
+ target_guide = np.stack([frames_guide[i] for i in range(l, r)])
113
+ source_style = np.stack([frame_style] * (r-l))
114
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
115
+ for i, frame in zip(range(l, r), target_style):
116
+ if i==frame_id:
117
+ frame = frame.clip(0, 255).astype("uint8")
118
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
119
+ frame_id += 1
120
+ if r < n and r-frame_id <= tracking_window_size:
121
+ break
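`get_task_group` splits the timeline at the rendered keyframes so that every in-between frame is synthesized from its left and right keyframes, which are then mixed by temporal distance via `get_weight`. A quick sketch (pure Python, no GPU required), assuming keyframes were rendered for frames 0, 5 and 9 of a 10-frame guide video:

    runner = InterpolationModeRunner()
    for tasks in runner.get_task_group(index_style=[0, 5, 9], n=10):
        print(tasks)                      # each task is (left_keyframe, target_frame, right_keyframe)
    print(runner.get_weight(0, 2, 5))     # (0.6, 0.4): frame 2 leans toward keyframe 0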
diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .blip_pretrain import *
diffsynth/extensions/ImageQualityMetric/BLIP/blip.py ADDED
@@ -0,0 +1,77 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ '''
4
+
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
+
8
+ import torch
9
+ import os
10
+ from urllib.parse import urlparse
11
+ from timm.models.hub import download_cached_file
12
+ from transformers import BertTokenizer
13
+ from .vit import VisionTransformer, interpolate_pos_embed
14
+
15
+
16
+ def default_bert():
17
+ current_dir = os.path.dirname(os.path.abspath(__file__))
18
+ project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
19
+ model_path = os.path.join(project_root, 'models', 'QualityMetric')
20
+ return os.path.join(model_path, "bert-base-uncased")
21
+
22
+
23
+ def init_tokenizer(bert_model_path):
24
+ tokenizer = BertTokenizer.from_pretrained(bert_model_path)
25
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
26
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
27
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
28
+ return tokenizer
29
+
30
+
31
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
32
+
33
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
34
+ if vit=='base':
35
+ vision_width = 768
36
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
37
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
38
+ drop_path_rate=0 or drop_path_rate
39
+ )
40
+ elif vit=='large':
41
+ vision_width = 1024
42
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
43
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
44
+ drop_path_rate=0.1 or drop_path_rate
45
+ )
46
+ return visual_encoder, vision_width
47
+
48
+
49
+ def is_url(url_or_filename):
50
+ parsed = urlparse(url_or_filename)
51
+ return parsed.scheme in ("http", "https")
52
+
53
+ def load_checkpoint(model,url_or_filename):
54
+ if is_url(url_or_filename):
55
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
56
+ checkpoint = torch.load(cached_file, map_location='cpu')
57
+ elif os.path.isfile(url_or_filename):
58
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
59
+ else:
60
+ raise RuntimeError('checkpoint url or path is invalid')
61
+
62
+ state_dict = checkpoint['model']
63
+
64
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
65
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
66
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
67
+ model.visual_encoder_m)
68
+ for key in model.state_dict().keys():
69
+ if key in state_dict.keys():
70
+ if state_dict[key].shape!=model.state_dict()[key].shape:
71
+ print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
72
+ del state_dict[key]
73
+
74
+ msg = model.load_state_dict(state_dict,strict=False)
75
+ print('load checkpoint from %s'%url_or_filename)
76
+ return model,msg
77
+
diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py ADDED
@@ -0,0 +1,44 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ '''
4
+
5
+ import transformers
6
+ transformers.logging.set_verbosity_error()
7
+
8
+ from torch import nn
9
+ import os
10
+ from .med import BertConfig, BertModel
11
+ from .blip import create_vit, init_tokenizer
12
+
13
+ class BLIP_Pretrain(nn.Module):
14
+ def __init__(self,
15
+ med_config = "med_config.json",
16
+ image_size = 224,
17
+ vit = 'base',
18
+ vit_grad_ckpt = False,
19
+ vit_ckpt_layer = 0,
20
+ embed_dim = 256,
21
+ queue_size = 57600,
22
+ momentum = 0.995,
23
+ bert_model_path = ""
24
+ ):
25
+ """
26
+ Args:
27
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
28
+ image_size (int): input image size
29
+ vit (str): model size of vision transformer
30
+ """
31
+ super().__init__()
32
+
33
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
34
+
35
+ self.tokenizer = init_tokenizer(bert_model_path)
36
+ encoder_config = BertConfig.from_json_file(med_config)
37
+ encoder_config.encoder_width = vision_width
38
+ self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
39
+
40
+ text_width = self.text_encoder.config.hidden_size
41
+
42
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
43
+ self.text_proj = nn.Linear(text_width, embed_dim)
44
+
diffsynth/extensions/ImageQualityMetric/BLIP/med.py ADDED
@@ -0,0 +1,947 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ * Based on huggingface code base
4
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
5
+ '''
6
+
7
+ import math
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ from torch import Tensor, device, nn
12
+ import torch.utils.checkpoint
13
+ from torch import nn
14
+ from torch.nn import CrossEntropyLoss
15
+
16
+ from transformers.activations import ACT2FN
17
+ from transformers.file_utils import (
18
+ ModelOutput,
19
+ )
20
+ from transformers.modeling_outputs import (
21
+ BaseModelOutputWithPastAndCrossAttentions,
22
+ BaseModelOutputWithPoolingAndCrossAttentions,
23
+ CausalLMOutputWithCrossAttentions,
24
+ MaskedLMOutput,
25
+ MultipleChoiceModelOutput,
26
+ NextSentencePredictorOutput,
27
+ QuestionAnsweringModelOutput,
28
+ SequenceClassifierOutput,
29
+ TokenClassifierOutput,
30
+ )
31
+ from transformers.modeling_utils import (
32
+ PreTrainedModel,
33
+ apply_chunking_to_forward,
34
+ find_pruneable_heads_and_indices,
35
+ prune_linear_layer,
36
+ )
37
+ from transformers.utils import logging
38
+ from transformers.models.bert.configuration_bert import BertConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ class BertEmbeddings(nn.Module):
45
+ """Construct the embeddings from word and position embeddings."""
46
+
47
+ def __init__(self, config):
48
+ super().__init__()
49
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
50
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
51
+
52
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
53
+ # any TensorFlow checkpoint file
54
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
55
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
56
+
57
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
58
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
59
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
60
+
61
+ self.config = config
62
+
63
+ def forward(
64
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
65
+ ):
66
+ if input_ids is not None:
67
+ input_shape = input_ids.size()
68
+ else:
69
+ input_shape = inputs_embeds.size()[:-1]
70
+
71
+ seq_length = input_shape[1]
72
+
73
+ if position_ids is None:
74
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
75
+
76
+ if inputs_embeds is None:
77
+ inputs_embeds = self.word_embeddings(input_ids)
78
+
79
+ embeddings = inputs_embeds
80
+
81
+ if self.position_embedding_type == "absolute":
82
+ position_embeddings = self.position_embeddings(position_ids)
83
+ embeddings += position_embeddings
84
+ embeddings = self.LayerNorm(embeddings)
85
+ embeddings = self.dropout(embeddings)
86
+ return embeddings
87
+
88
+
89
+ class BertSelfAttention(nn.Module):
90
+ def __init__(self, config, is_cross_attention):
91
+ super().__init__()
92
+ self.config = config
93
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
94
+ raise ValueError(
95
+ "The hidden size (%d) is not a multiple of the number of attention "
96
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
97
+ )
98
+
99
+ self.num_attention_heads = config.num_attention_heads
100
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
101
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
102
+
103
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
104
+ if is_cross_attention:
105
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
106
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
107
+ else:
108
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
109
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
110
+
111
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
112
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
113
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
114
+ self.max_position_embeddings = config.max_position_embeddings
115
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
116
+ self.save_attention = False
117
+
118
+ def save_attn_gradients(self, attn_gradients):
119
+ self.attn_gradients = attn_gradients
120
+
121
+ def get_attn_gradients(self):
122
+ return self.attn_gradients
123
+
124
+ def save_attention_map(self, attention_map):
125
+ self.attention_map = attention_map
126
+
127
+ def get_attention_map(self):
128
+ return self.attention_map
129
+
130
+ def transpose_for_scores(self, x):
131
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
132
+ x = x.view(*new_x_shape)
133
+ return x.permute(0, 2, 1, 3)
134
+
135
+ def forward(
136
+ self,
137
+ hidden_states,
138
+ attention_mask=None,
139
+ head_mask=None,
140
+ encoder_hidden_states=None,
141
+ encoder_attention_mask=None,
142
+ past_key_value=None,
143
+ output_attentions=False,
144
+ ):
145
+ mixed_query_layer = self.query(hidden_states)
146
+
147
+ # If this is instantiated as a cross-attention module, the keys
148
+ # and values come from an encoder; the attention mask needs to be
149
+ # such that the encoder's padding tokens are not attended to.
150
+ is_cross_attention = encoder_hidden_states is not None
151
+
152
+ if is_cross_attention:
153
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
154
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
155
+ attention_mask = encoder_attention_mask
156
+ elif past_key_value is not None:
157
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
158
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
159
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
160
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
161
+ else:
162
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
163
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
164
+
165
+ query_layer = self.transpose_for_scores(mixed_query_layer)
166
+
167
+ past_key_value = (key_layer, value_layer)
168
+
169
+ # Take the dot product between "query" and "key" to get the raw attention scores.
170
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
171
+
172
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
173
+ seq_length = hidden_states.size()[1]
174
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
175
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
176
+ distance = position_ids_l - position_ids_r
177
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
178
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
179
+
180
+ if self.position_embedding_type == "relative_key":
181
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
182
+ attention_scores = attention_scores + relative_position_scores
183
+ elif self.position_embedding_type == "relative_key_query":
184
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
185
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
186
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
187
+
188
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
189
+ if attention_mask is not None:
190
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
191
+ attention_scores = attention_scores + attention_mask
192
+
193
+ # Normalize the attention scores to probabilities.
194
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
195
+
196
+ if is_cross_attention and self.save_attention:
197
+ self.save_attention_map(attention_probs)
198
+ attention_probs.register_hook(self.save_attn_gradients)
199
+
200
+ # This is actually dropping out entire tokens to attend to, which might
201
+ # seem a bit unusual, but is taken from the original Transformer paper.
202
+ attention_probs_dropped = self.dropout(attention_probs)
203
+
204
+ # Mask heads if we want to
205
+ if head_mask is not None:
206
+ attention_probs_dropped = attention_probs_dropped * head_mask
207
+
208
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
209
+
210
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
211
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
212
+ context_layer = context_layer.view(*new_context_layer_shape)
213
+
214
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
215
+
216
+ outputs = outputs + (past_key_value,)
217
+ return outputs
218
+
219
+
220
+ class BertSelfOutput(nn.Module):
221
+ def __init__(self, config):
222
+ super().__init__()
223
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
224
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
225
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
226
+
227
+ def forward(self, hidden_states, input_tensor):
228
+ hidden_states = self.dense(hidden_states)
229
+ hidden_states = self.dropout(hidden_states)
230
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
231
+ return hidden_states
232
+
233
+
234
+ class BertAttention(nn.Module):
235
+ def __init__(self, config, is_cross_attention=False):
236
+ super().__init__()
237
+ self.self = BertSelfAttention(config, is_cross_attention)
238
+ self.output = BertSelfOutput(config)
239
+ self.pruned_heads = set()
240
+
241
+ def prune_heads(self, heads):
242
+ if len(heads) == 0:
243
+ return
244
+ heads, index = find_pruneable_heads_and_indices(
245
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
246
+ )
247
+
248
+ # Prune linear layers
249
+ self.self.query = prune_linear_layer(self.self.query, index)
250
+ self.self.key = prune_linear_layer(self.self.key, index)
251
+ self.self.value = prune_linear_layer(self.self.value, index)
252
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
253
+
254
+ # Update hyper params and store pruned heads
255
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
256
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
257
+ self.pruned_heads = self.pruned_heads.union(heads)
258
+
259
+ def forward(
260
+ self,
261
+ hidden_states,
262
+ attention_mask=None,
263
+ head_mask=None,
264
+ encoder_hidden_states=None,
265
+ encoder_attention_mask=None,
266
+ past_key_value=None,
267
+ output_attentions=False,
268
+ ):
269
+ self_outputs = self.self(
270
+ hidden_states,
271
+ attention_mask,
272
+ head_mask,
273
+ encoder_hidden_states,
274
+ encoder_attention_mask,
275
+ past_key_value,
276
+ output_attentions,
277
+ )
278
+ attention_output = self.output(self_outputs[0], hidden_states)
279
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
280
+ return outputs
281
+
282
+
283
+ class BertIntermediate(nn.Module):
284
+ def __init__(self, config):
285
+ super().__init__()
286
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
287
+ if isinstance(config.hidden_act, str):
288
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
289
+ else:
290
+ self.intermediate_act_fn = config.hidden_act
291
+
292
+ def forward(self, hidden_states):
293
+ hidden_states = self.dense(hidden_states)
294
+ hidden_states = self.intermediate_act_fn(hidden_states)
295
+ return hidden_states
296
+
297
+
298
+ class BertOutput(nn.Module):
299
+ def __init__(self, config):
300
+ super().__init__()
301
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
302
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
303
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
304
+
305
+ def forward(self, hidden_states, input_tensor):
306
+ hidden_states = self.dense(hidden_states)
307
+ hidden_states = self.dropout(hidden_states)
308
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
309
+ return hidden_states
310
+
311
+
312
+ class BertLayer(nn.Module):
313
+ def __init__(self, config, layer_num):
314
+ super().__init__()
315
+ self.config = config
316
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
317
+ self.seq_len_dim = 1
318
+ self.attention = BertAttention(config)
319
+ self.layer_num = layer_num
320
+ if self.config.add_cross_attention:
321
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
322
+ self.intermediate = BertIntermediate(config)
323
+ self.output = BertOutput(config)
324
+
325
+ def forward(
326
+ self,
327
+ hidden_states,
328
+ attention_mask=None,
329
+ head_mask=None,
330
+ encoder_hidden_states=None,
331
+ encoder_attention_mask=None,
332
+ past_key_value=None,
333
+ output_attentions=False,
334
+ mode=None,
335
+ ):
336
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
337
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
338
+ self_attention_outputs = self.attention(
339
+ hidden_states,
340
+ attention_mask,
341
+ head_mask,
342
+ output_attentions=output_attentions,
343
+ past_key_value=self_attn_past_key_value,
344
+ )
345
+ attention_output = self_attention_outputs[0]
346
+
347
+ outputs = self_attention_outputs[1:-1]
348
+ present_key_value = self_attention_outputs[-1]
349
+
350
+ if mode=='multimodal':
351
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
352
+
353
+ cross_attention_outputs = self.crossattention(
354
+ attention_output,
355
+ attention_mask,
356
+ head_mask,
357
+ encoder_hidden_states,
358
+ encoder_attention_mask,
359
+ output_attentions=output_attentions,
360
+ )
361
+ attention_output = cross_attention_outputs[0]
362
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
363
+ layer_output = apply_chunking_to_forward(
364
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
365
+ )
366
+ outputs = (layer_output,) + outputs
367
+
368
+ outputs = outputs + (present_key_value,)
369
+
370
+ return outputs
371
+
372
+ def feed_forward_chunk(self, attention_output):
373
+ intermediate_output = self.intermediate(attention_output)
374
+ layer_output = self.output(intermediate_output, attention_output)
375
+ return layer_output
376
+
377
+
378
+ class BertEncoder(nn.Module):
379
+ def __init__(self, config):
380
+ super().__init__()
381
+ self.config = config
382
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
383
+ self.gradient_checkpointing = False
384
+
385
+ def forward(
386
+ self,
387
+ hidden_states,
388
+ attention_mask=None,
389
+ head_mask=None,
390
+ encoder_hidden_states=None,
391
+ encoder_attention_mask=None,
392
+ past_key_values=None,
393
+ use_cache=None,
394
+ output_attentions=False,
395
+ output_hidden_states=False,
396
+ return_dict=True,
397
+ mode='multimodal',
398
+ ):
399
+ all_hidden_states = () if output_hidden_states else None
400
+ all_self_attentions = () if output_attentions else None
401
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
402
+
403
+ next_decoder_cache = () if use_cache else None
404
+
405
+ for i in range(self.config.num_hidden_layers):
406
+ layer_module = self.layer[i]
407
+ if output_hidden_states:
408
+ all_hidden_states = all_hidden_states + (hidden_states,)
409
+
410
+ layer_head_mask = head_mask[i] if head_mask is not None else None
411
+ past_key_value = past_key_values[i] if past_key_values is not None else None
412
+
413
+ if self.gradient_checkpointing and self.training:
414
+
415
+ if use_cache:
416
+ logger.warning(
417
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
418
+ )
419
+ use_cache = False
420
+
421
+ def create_custom_forward(module):
422
+ def custom_forward(*inputs):
423
+ return module(*inputs, past_key_value, output_attentions)
424
+
425
+ return custom_forward
426
+
427
+ layer_outputs = torch.utils.checkpoint.checkpoint(
428
+ create_custom_forward(layer_module),
429
+ hidden_states,
430
+ attention_mask,
431
+ layer_head_mask,
432
+ encoder_hidden_states,
433
+ encoder_attention_mask,
434
+ mode=mode,
435
+ )
436
+ else:
437
+ layer_outputs = layer_module(
438
+ hidden_states,
439
+ attention_mask,
440
+ layer_head_mask,
441
+ encoder_hidden_states,
442
+ encoder_attention_mask,
443
+ past_key_value,
444
+ output_attentions,
445
+ mode=mode,
446
+ )
447
+
448
+ hidden_states = layer_outputs[0]
449
+ if use_cache:
450
+ next_decoder_cache += (layer_outputs[-1],)
451
+ if output_attentions:
452
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
453
+
454
+ if output_hidden_states:
455
+ all_hidden_states = all_hidden_states + (hidden_states,)
456
+
457
+ if not return_dict:
458
+ return tuple(
459
+ v
460
+ for v in [
461
+ hidden_states,
462
+ next_decoder_cache,
463
+ all_hidden_states,
464
+ all_self_attentions,
465
+ all_cross_attentions,
466
+ ]
467
+ if v is not None
468
+ )
469
+ return BaseModelOutputWithPastAndCrossAttentions(
470
+ last_hidden_state=hidden_states,
471
+ past_key_values=next_decoder_cache,
472
+ hidden_states=all_hidden_states,
473
+ attentions=all_self_attentions,
474
+ cross_attentions=all_cross_attentions,
475
+ )
476
+
477
+
478
+ class BertPooler(nn.Module):
479
+ def __init__(self, config):
480
+ super().__init__()
481
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
482
+ self.activation = nn.Tanh()
483
+
484
+ def forward(self, hidden_states):
485
+ # We "pool" the model by simply taking the hidden state corresponding
486
+ # to the first token.
487
+ first_token_tensor = hidden_states[:, 0]
488
+ pooled_output = self.dense(first_token_tensor)
489
+ pooled_output = self.activation(pooled_output)
490
+ return pooled_output
491
+
492
+
493
+ class BertPredictionHeadTransform(nn.Module):
494
+ def __init__(self, config):
495
+ super().__init__()
496
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
497
+ if isinstance(config.hidden_act, str):
498
+ self.transform_act_fn = ACT2FN[config.hidden_act]
499
+ else:
500
+ self.transform_act_fn = config.hidden_act
501
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
502
+
503
+ def forward(self, hidden_states):
504
+ hidden_states = self.dense(hidden_states)
505
+ hidden_states = self.transform_act_fn(hidden_states)
506
+ hidden_states = self.LayerNorm(hidden_states)
507
+ return hidden_states
508
+
509
+
510
+ class BertLMPredictionHead(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.transform = BertPredictionHeadTransform(config)
514
+
515
+ # The output weights are the same as the input embeddings, but there is
516
+ # an output-only bias for each token.
517
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
518
+
519
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
520
+
521
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
522
+ self.decoder.bias = self.bias
523
+
524
+ def forward(self, hidden_states):
525
+ hidden_states = self.transform(hidden_states)
526
+ hidden_states = self.decoder(hidden_states)
527
+ return hidden_states
528
+
529
+
530
+ class BertOnlyMLMHead(nn.Module):
531
+ def __init__(self, config):
532
+ super().__init__()
533
+ self.predictions = BertLMPredictionHead(config)
534
+
535
+ def forward(self, sequence_output):
536
+ prediction_scores = self.predictions(sequence_output)
537
+ return prediction_scores
538
+
539
+
540
+ class BertPreTrainedModel(PreTrainedModel):
541
+ """
542
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
543
+ models.
544
+ """
545
+
546
+ config_class = BertConfig
547
+ base_model_prefix = "bert"
548
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
549
+
550
+ def _init_weights(self, module):
551
+ """ Initialize the weights """
552
+ if isinstance(module, (nn.Linear, nn.Embedding)):
553
+ # Slightly different from the TF version which uses truncated_normal for initialization
554
+ # cf https://github.com/pytorch/pytorch/pull/5617
555
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
556
+ elif isinstance(module, nn.LayerNorm):
557
+ module.bias.data.zero_()
558
+ module.weight.data.fill_(1.0)
559
+ if isinstance(module, nn.Linear) and module.bias is not None:
560
+ module.bias.data.zero_()
561
+
562
+
563
+ class BertModel(BertPreTrainedModel):
564
+ """
565
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
566
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
567
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
568
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
569
+ To be used as a decoder, the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
570
+ input to the forward pass.
571
+ """
572
+
573
+ def __init__(self, config, add_pooling_layer=True):
574
+ super().__init__(config)
575
+ self.config = config
576
+
577
+ self.embeddings = BertEmbeddings(config)
578
+
579
+ self.encoder = BertEncoder(config)
580
+
581
+ self.pooler = BertPooler(config) if add_pooling_layer else None
582
+
583
+ self.init_weights()
584
+
585
+
586
+ def get_input_embeddings(self):
587
+ return self.embeddings.word_embeddings
588
+
589
+ def set_input_embeddings(self, value):
590
+ self.embeddings.word_embeddings = value
591
+
592
+ def _prune_heads(self, heads_to_prune):
593
+ """
594
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
595
+ class PreTrainedModel
596
+ """
597
+ for layer, heads in heads_to_prune.items():
598
+ self.encoder.layer[layer].attention.prune_heads(heads)
599
+
600
+
601
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
602
+ """
603
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
604
+
605
+ Arguments:
606
+ attention_mask (:obj:`torch.Tensor`):
607
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
608
+ input_shape (:obj:`Tuple[int]`):
609
+ The shape of the input to the model.
610
+ device: (:obj:`torch.device`):
611
+ The device of the input to the model.
612
+
613
+ Returns:
614
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
615
+ """
616
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
617
+ # ourselves in which case we just need to make it broadcastable to all heads.
618
+ if attention_mask.dim() == 3:
619
+ extended_attention_mask = attention_mask[:, None, :, :]
620
+ elif attention_mask.dim() == 2:
621
+ # Provided a padding mask of dimensions [batch_size, seq_length]
622
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
623
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
624
+ if is_decoder:
625
+ batch_size, seq_length = input_shape
626
+
627
+ seq_ids = torch.arange(seq_length, device=device)
628
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
629
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
630
+ # causal and attention masks must have same type with pytorch version < 1.3
631
+ causal_mask = causal_mask.to(attention_mask.dtype)
632
+
633
+ if causal_mask.shape[1] < attention_mask.shape[1]:
634
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
635
+ causal_mask = torch.cat(
636
+ [
637
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
638
+ causal_mask,
639
+ ],
640
+ axis=-1,
641
+ )
642
+
643
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
644
+ else:
645
+ extended_attention_mask = attention_mask[:, None, None, :]
646
+ else:
647
+ raise ValueError(
648
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
649
+ input_shape, attention_mask.shape
650
+ )
651
+ )
652
+
653
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
654
+ # masked positions, this operation will create a tensor which is 0.0 for
655
+ # positions we want to attend and -10000.0 for masked positions.
656
+ # Since we are adding it to the raw scores before the softmax, this is
657
+ # effectively the same as removing these entirely.
658
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
659
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
660
+ return extended_attention_mask
661
+
662
+ def forward(
663
+ self,
664
+ input_ids=None,
665
+ attention_mask=None,
666
+ position_ids=None,
667
+ head_mask=None,
668
+ inputs_embeds=None,
669
+ encoder_embeds=None,
670
+ encoder_hidden_states=None,
671
+ encoder_attention_mask=None,
672
+ past_key_values=None,
673
+ use_cache=None,
674
+ output_attentions=None,
675
+ output_hidden_states=None,
676
+ return_dict=None,
677
+ is_decoder=False,
678
+ mode='multimodal',
679
+ ):
680
+ r"""
681
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
682
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
683
+ the model is configured as a decoder.
684
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
685
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
686
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
687
+ - 1 for tokens that are **not masked**,
688
+ - 0 for tokens that are **masked**.
689
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
690
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
691
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
692
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
693
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
694
+ use_cache (:obj:`bool`, `optional`):
695
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
696
+ decoding (see :obj:`past_key_values`).
697
+ """
698
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
699
+ output_hidden_states = (
700
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
701
+ )
702
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
703
+
704
+ if is_decoder:
705
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
706
+ else:
707
+ use_cache = False
708
+
709
+ if input_ids is not None and inputs_embeds is not None:
710
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
711
+ elif input_ids is not None:
712
+ input_shape = input_ids.size()
713
+ batch_size, seq_length = input_shape
714
+ device = input_ids.device
715
+ elif inputs_embeds is not None:
716
+ input_shape = inputs_embeds.size()[:-1]
717
+ batch_size, seq_length = input_shape
718
+ device = inputs_embeds.device
719
+ elif encoder_embeds is not None:
720
+ input_shape = encoder_embeds.size()[:-1]
721
+ batch_size, seq_length = input_shape
722
+ device = encoder_embeds.device
723
+ else:
724
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
725
+
726
+ # past_key_values_length
727
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
728
+
729
+ if attention_mask is None:
730
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
731
+
732
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
733
+ # ourselves in which case we just need to make it broadcastable to all heads.
734
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
735
+ device, is_decoder)
736
+
737
+ # If a 2D or 3D attention mask is provided for the cross-attention
738
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
739
+ if encoder_hidden_states is not None:
740
+ if type(encoder_hidden_states) == list:
741
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
742
+ else:
743
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
744
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
745
+
746
+ if type(encoder_attention_mask) == list:
747
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
748
+ elif encoder_attention_mask is None:
749
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
750
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
751
+ else:
752
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
753
+ else:
754
+ encoder_extended_attention_mask = None
755
+
756
+ # Prepare head mask if needed
757
+ # 1.0 in head_mask indicate we keep the head
758
+ # attention_probs has shape bsz x n_heads x N x N
759
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
760
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
761
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
762
+
763
+ if encoder_embeds is None:
764
+ embedding_output = self.embeddings(
765
+ input_ids=input_ids,
766
+ position_ids=position_ids,
767
+ inputs_embeds=inputs_embeds,
768
+ past_key_values_length=past_key_values_length,
769
+ )
770
+ else:
771
+ embedding_output = encoder_embeds
772
+
773
+ encoder_outputs = self.encoder(
774
+ embedding_output,
775
+ attention_mask=extended_attention_mask,
776
+ head_mask=head_mask,
777
+ encoder_hidden_states=encoder_hidden_states,
778
+ encoder_attention_mask=encoder_extended_attention_mask,
779
+ past_key_values=past_key_values,
780
+ use_cache=use_cache,
781
+ output_attentions=output_attentions,
782
+ output_hidden_states=output_hidden_states,
783
+ return_dict=return_dict,
784
+ mode=mode,
785
+ )
786
+ sequence_output = encoder_outputs[0]
787
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
788
+
789
+ if not return_dict:
790
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
791
+
792
+ return BaseModelOutputWithPoolingAndCrossAttentions(
793
+ last_hidden_state=sequence_output,
794
+ pooler_output=pooled_output,
795
+ past_key_values=encoder_outputs.past_key_values,
796
+ hidden_states=encoder_outputs.hidden_states,
797
+ attentions=encoder_outputs.attentions,
798
+ cross_attentions=encoder_outputs.cross_attentions,
799
+ )
800
+
801
+
802
+
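For orientation (illustration only, not part of the uploaded file): the additive mask built by get_extended_attention_mask above can be sanity-checked in isolation. The lines below mirror the (1.0 - mask) * -10000.0 arithmetic for a 2D padding mask; the tensor values are made up for the example.

import torch

attention_mask = torch.tensor([[1, 1, 0]])           # 1 = attend to, 0 = padding
extended = attention_mask[:, None, None, :].float()  # broadcastable to [batch, heads, from_seq, to_seq]
extended = (1.0 - extended) * -10000.0               # kept positions -> 0.0, masked positions -> -10000.0
# Adding `extended` to the raw attention scores before softmax suppresses the padded token.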
803
+ class BertLMHeadModel(BertPreTrainedModel):
804
+
805
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
806
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
807
+
808
+ def __init__(self, config):
809
+ super().__init__(config)
810
+
811
+ self.bert = BertModel(config, add_pooling_layer=False)
812
+ self.cls = BertOnlyMLMHead(config)
813
+
814
+ self.init_weights()
815
+
816
+ def get_output_embeddings(self):
817
+ return self.cls.predictions.decoder
818
+
819
+ def set_output_embeddings(self, new_embeddings):
820
+ self.cls.predictions.decoder = new_embeddings
821
+
822
+ def forward(
823
+ self,
824
+ input_ids=None,
825
+ attention_mask=None,
826
+ position_ids=None,
827
+ head_mask=None,
828
+ inputs_embeds=None,
829
+ encoder_hidden_states=None,
830
+ encoder_attention_mask=None,
831
+ labels=None,
832
+ past_key_values=None,
833
+ use_cache=None,
834
+ output_attentions=None,
835
+ output_hidden_states=None,
836
+ return_dict=None,
837
+ return_logits=False,
838
+ is_decoder=True,
839
+ reduction='mean',
840
+ mode='multimodal',
841
+ ):
842
+ r"""
843
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
844
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
845
+ the model is configured as a decoder.
846
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
847
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
848
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
849
+ - 1 for tokens that are **not masked**,
850
+ - 0 for tokens that are **masked**.
851
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
852
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
853
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
854
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
855
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
856
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
857
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
858
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
859
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
860
+ use_cache (:obj:`bool`, `optional`):
861
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
862
+ decoding (see :obj:`past_key_values`).
863
+ Returns:
864
+ Example::
865
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
866
+ >>> import torch
867
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
868
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
869
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
870
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
871
+ >>> outputs = model(**inputs)
872
+ >>> prediction_logits = outputs.logits
873
+ """
874
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
875
+ if labels is not None:
876
+ use_cache = False
877
+
878
+ outputs = self.bert(
879
+ input_ids,
880
+ attention_mask=attention_mask,
881
+ position_ids=position_ids,
882
+ head_mask=head_mask,
883
+ inputs_embeds=inputs_embeds,
884
+ encoder_hidden_states=encoder_hidden_states,
885
+ encoder_attention_mask=encoder_attention_mask,
886
+ past_key_values=past_key_values,
887
+ use_cache=use_cache,
888
+ output_attentions=output_attentions,
889
+ output_hidden_states=output_hidden_states,
890
+ return_dict=return_dict,
891
+ is_decoder=is_decoder,
892
+ mode=mode,
893
+ )
894
+
895
+ sequence_output = outputs[0]
896
+ prediction_scores = self.cls(sequence_output)
897
+
898
+ if return_logits:
899
+ return prediction_scores[:, :-1, :].contiguous()
900
+
901
+ lm_loss = None
902
+ if labels is not None:
903
+ # we are doing next-token prediction; shift prediction scores and input ids by one
904
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
905
+ labels = labels[:, 1:].contiguous()
906
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
907
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
908
+ if reduction=='none':
909
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
910
+
911
+ if not return_dict:
912
+ output = (prediction_scores,) + outputs[2:]
913
+ return ((lm_loss,) + output) if lm_loss is not None else output
914
+
915
+ return CausalLMOutputWithCrossAttentions(
916
+ loss=lm_loss,
917
+ logits=prediction_scores,
918
+ past_key_values=outputs.past_key_values,
919
+ hidden_states=outputs.hidden_states,
920
+ attentions=outputs.attentions,
921
+ cross_attentions=outputs.cross_attentions,
922
+ )
923
+
924
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
925
+ input_shape = input_ids.shape
926
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
927
+ if attention_mask is None:
928
+ attention_mask = input_ids.new_ones(input_shape)
929
+
930
+ # cut decoder_input_ids if past is used
931
+ if past is not None:
932
+ input_ids = input_ids[:, -1:]
933
+
934
+ return {
935
+ "input_ids": input_ids,
936
+ "attention_mask": attention_mask,
937
+ "past_key_values": past,
938
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
939
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
940
+ "is_decoder": True,
941
+ }
942
+
943
+ def _reorder_cache(self, past, beam_idx):
944
+ reordered_past = ()
945
+ for layer_past in past:
946
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
947
+ return reordered_past
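End of med.py. A minimal sketch of how this module is typically wired up in BLIP-style models: a BertConfig extended with encoder_width and add_cross_attention=True, and image features passed through encoder_hidden_states. The dummy shapes and config values below are assumptions for illustration, not taken from this commit.

import torch
from transformers import BertConfig
from diffsynth.extensions.ImageQualityMetric.BLIP.med import BertModel

# Assumed config: BERT-base defaults plus the two BLIP-specific fields used by this file.
config = BertConfig()                # hidden_size=768, 12 layers, 12 heads, vocab_size=30522
config.encoder_width = 768           # width of the vision features consumed by cross-attention
config.add_cross_attention = True    # enables the crossattention sub-layer in each BertLayer

text_encoder = BertModel(config, add_pooling_layer=False)

# Dummy text tokens and dummy image features (e.g. 197 ViT tokens of width 768).
input_ids = torch.randint(0, config.vocab_size, (1, 12))
attention_mask = torch.ones(1, 12, dtype=torch.long)
image_embeds = torch.randn(1, 197, 768)
image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long)

out = text_encoder(
    input_ids=input_ids,
    attention_mask=attention_mask,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
    mode="multimodal",               # routes through the cross-attention branch of BertLayer
)
print(out.last_hidden_state.shape)   # torch.Size([1, 12, 768])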
diffsynth/extensions/ImageQualityMetric/BLIP/vit.py ADDED
@@ -0,0 +1,301 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ * Based on timm code base
4
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
5
+ '''
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from functools import partial
11
+
12
+ from timm.models.vision_transformer import _cfg, PatchEmbed
13
+ from timm.models.registry import register_model
14
+ from timm.models.layers import trunc_normal_, DropPath
15
+ from timm.models.helpers import named_apply, adapt_input_conv
16
+
17
+ # from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
18
+
19
+ class Mlp(nn.Module):
20
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
21
+ """
22
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.fc1 = nn.Linear(in_features, hidden_features)
27
+ self.act = act_layer()
28
+ self.fc2 = nn.Linear(hidden_features, out_features)
29
+ self.drop = nn.Dropout(drop)
30
+
31
+ def forward(self, x):
32
+ x = self.fc1(x)
33
+ x = self.act(x)
34
+ x = self.drop(x)
35
+ x = self.fc2(x)
36
+ x = self.drop(x)
37
+ return x
38
+
39
+
40
+ class Attention(nn.Module):
41
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
42
+ super().__init__()
43
+ self.num_heads = num_heads
44
+ head_dim = dim // num_heads
45
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
46
+ self.scale = qk_scale or head_dim ** -0.5
47
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
48
+ self.attn_drop = nn.Dropout(attn_drop)
49
+ self.proj = nn.Linear(dim, dim)
50
+ self.proj_drop = nn.Dropout(proj_drop)
51
+ self.attn_gradients = None
52
+ self.attention_map = None
53
+
54
+ def save_attn_gradients(self, attn_gradients):
55
+ self.attn_gradients = attn_gradients
56
+
57
+ def get_attn_gradients(self):
58
+ return self.attn_gradients
59
+
60
+ def save_attention_map(self, attention_map):
61
+ self.attention_map = attention_map
62
+
63
+ def get_attention_map(self):
64
+ return self.attention_map
65
+
66
+ def forward(self, x, register_hook=False):
67
+ B, N, C = x.shape
68
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
69
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
70
+
71
+ attn = (q @ k.transpose(-2, -1)) * self.scale
72
+ attn = attn.softmax(dim=-1)
73
+ attn = self.attn_drop(attn)
74
+
75
+ if register_hook:
76
+ self.save_attention_map(attn)
77
+ attn.register_hook(self.save_attn_gradients)
78
+
79
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
80
+ x = self.proj(x)
81
+ x = self.proj_drop(x)
82
+ return x
83
+
84
+
85
+ class Block(nn.Module):
86
+
87
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
88
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
89
+ super().__init__()
90
+ self.norm1 = norm_layer(dim)
91
+ self.attn = Attention(
92
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
93
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
94
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
95
+ self.norm2 = norm_layer(dim)
96
+ mlp_hidden_dim = int(dim * mlp_ratio)
97
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
98
+
99
+ # if use_grad_checkpointing:
100
+ # self.attn = checkpoint_wrapper(self.attn)
101
+ # self.mlp = checkpoint_wrapper(self.mlp)
102
+
103
+ def forward(self, x, register_hook=False):
104
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
105
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
106
+ return x
107
+
108
+
109
+ class VisionTransformer(nn.Module):
110
+ """ Vision Transformer
111
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
112
+ https://arxiv.org/abs/2010.11929
113
+ """
114
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
115
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
116
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
117
+ use_grad_checkpointing=False, ckpt_layer=0):
118
+ """
119
+ Args:
120
+ img_size (int, tuple): input image size
121
+ patch_size (int, tuple): patch size
122
+ in_chans (int): number of input channels
123
+ num_classes (int): number of classes for classification head
124
+ embed_dim (int): embedding dimension
125
+ depth (int): depth of transformer
126
+ num_heads (int): number of attention heads
127
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
128
+ qkv_bias (bool): enable bias for qkv if True
129
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
130
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
131
+ drop_rate (float): dropout rate
132
+ attn_drop_rate (float): attention dropout rate
133
+ drop_path_rate (float): stochastic depth rate
134
+ norm_layer: (nn.Module): normalization layer
135
+ """
136
+ super().__init__()
137
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
138
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
139
+
140
+ self.patch_embed = PatchEmbed(
141
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
142
+
143
+ num_patches = self.patch_embed.num_patches
144
+
145
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
146
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
147
+ self.pos_drop = nn.Dropout(p=drop_rate)
148
+
149
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
150
+ self.blocks = nn.ModuleList([
151
+ Block(
152
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
153
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
154
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
155
+ )
156
+ for i in range(depth)])
157
+ self.norm = norm_layer(embed_dim)
158
+
159
+ trunc_normal_(self.pos_embed, std=.02)
160
+ trunc_normal_(self.cls_token, std=.02)
161
+ self.apply(self._init_weights)
162
+
163
+ def _init_weights(self, m):
164
+ if isinstance(m, nn.Linear):
165
+ trunc_normal_(m.weight, std=.02)
166
+ if isinstance(m, nn.Linear) and m.bias is not None:
167
+ nn.init.constant_(m.bias, 0)
168
+ elif isinstance(m, nn.LayerNorm):
169
+ nn.init.constant_(m.bias, 0)
170
+ nn.init.constant_(m.weight, 1.0)
171
+
172
+ @torch.jit.ignore
173
+ def no_weight_decay(self):
174
+ return {'pos_embed', 'cls_token'}
175
+
176
+ def forward(self, x, register_blk=-1):
177
+ B = x.shape[0]
178
+ x = self.patch_embed(x)
179
+
180
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
181
+ x = torch.cat((cls_tokens, x), dim=1)
182
+
183
+ x = x + self.pos_embed[:,:x.size(1),:]
184
+ x = self.pos_drop(x)
185
+
186
+ for i,blk in enumerate(self.blocks):
187
+ x = blk(x, register_blk==i)
188
+ x = self.norm(x)
189
+
190
+ return x
191
+
192
+ @torch.jit.ignore()
193
+ def load_pretrained(self, checkpoint_path, prefix=''):
194
+ _load_weights(self, checkpoint_path, prefix)
195
+
196
+
197
+ @torch.no_grad()
198
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
199
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
200
+ """
201
+ import numpy as np
202
+
203
+ def _n2p(w, t=True):
204
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
205
+ w = w.flatten()
206
+ if t:
207
+ if w.ndim == 4:
208
+ w = w.transpose([3, 2, 0, 1])
209
+ elif w.ndim == 3:
210
+ w = w.transpose([2, 0, 1])
211
+ elif w.ndim == 2:
212
+ w = w.transpose([1, 0])
213
+ return torch.from_numpy(w)
214
+
215
+ w = np.load(checkpoint_path)
216
+ if not prefix and 'opt/target/embedding/kernel' in w:
217
+ prefix = 'opt/target/'
218
+
219
+ if hasattr(model.patch_embed, 'backbone'):
220
+ # hybrid
221
+ backbone = model.patch_embed.backbone
222
+ stem_only = not hasattr(backbone, 'stem')
223
+ stem = backbone if stem_only else backbone.stem
224
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
225
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
226
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
227
+ if not stem_only:
228
+ for i, stage in enumerate(backbone.stages):
229
+ for j, block in enumerate(stage.blocks):
230
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
231
+ for r in range(3):
232
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
233
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
234
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
235
+ if block.downsample is not None:
236
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
237
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
238
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
239
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
240
+ else:
241
+ embed_conv_w = adapt_input_conv(
242
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
243
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
244
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
245
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
246
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
247
+ if pos_embed_w.shape != model.pos_embed.shape:
248
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
249
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
250
+ model.pos_embed.copy_(pos_embed_w)
251
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
252
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
253
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
254
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
255
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
256
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
257
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
258
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
259
+ for i, block in enumerate(model.blocks.children()):
260
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
261
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
262
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
263
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
264
+ block.attn.qkv.weight.copy_(torch.cat([
265
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
266
+ block.attn.qkv.bias.copy_(torch.cat([
267
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
268
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
269
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
270
+ for r in range(2):
271
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
272
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
273
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
274
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
275
+
276
+
277
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
278
+ # interpolate position embedding
279
+ embedding_size = pos_embed_checkpoint.shape[-1]
280
+ num_patches = visual_encoder.patch_embed.num_patches
281
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
282
+ # height (== width) for the checkpoint position embedding
283
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
284
+ # height (== width) for the new position embedding
285
+ new_size = int(num_patches ** 0.5)
286
+
287
+ if orig_size!=new_size:
288
+ # class_token and dist_token are kept unchanged
289
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
290
+ # only the position tokens are interpolated
291
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
292
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
293
+ pos_tokens = torch.nn.functional.interpolate(
294
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
295
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
296
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
297
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
298
+
299
+ return new_pos_embed
300
+ else:
301
+ return pos_embed_checkpoint
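End of vit.py. A minimal usage sketch, assuming timm is installed (the file imports PatchEmbed, DropPath and related helpers from it) and assuming the standard ViT-B/16 hyperparameters; interpolate_pos_embed is only needed when loading a checkpoint trained at a different input resolution.

import torch
import torch.nn as nn
from functools import partial
from diffsynth.extensions.ImageQualityMetric.BLIP.vit import VisionTransformer, interpolate_pos_embed

# Assumed ViT-B/16 configuration (224x224 input, 16x16 patches, 12 layers, 12 heads).
visual_encoder = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
)

images = torch.randn(2, 3, 224, 224)
features = visual_encoder(images)    # (2, 1 + 14*14, 768): CLS token followed by patch tokens
print(features.shape)

# When loading weights saved for a different resolution, resize the position embedding first
# (illustrative; the state_dict key depends on the checkpoint being loaded):
# state_dict["pos_embed"] = interpolate_pos_embed(state_dict["pos_embed"], visual_encoder)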