itishalogicgo committed
Commit 7e328d3 · 1 Parent(s): ef6f7fd
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .DS_Store +0 -0
  2. LICENSE +201 -0
  3. app.py +183 -0
  4. configs/hair_transfer.yaml +22 -0
  5. default_config.yaml +16 -0
  6. diffusers/.DS_Store +0 -0
  7. diffusers/__init__.py +734 -0
  8. diffusers/commands/__init__.py +27 -0
  9. diffusers/commands/diffusers_cli.py +43 -0
  10. diffusers/commands/env.py +84 -0
  11. diffusers/commands/fp16_safetensors.py +133 -0
  12. diffusers/configuration_utils.py +694 -0
  13. diffusers/dependency_versions_check.py +35 -0
  14. diffusers/dependency_versions_table.py +46 -0
  15. diffusers/experimental/README.md +5 -0
  16. diffusers/experimental/__init__.py +1 -0
  17. diffusers/experimental/rl/__init__.py +1 -0
  18. diffusers/experimental/rl/value_guided_sampling.py +154 -0
  19. diffusers/image_processor.py +476 -0
  20. diffusers/loaders.py +0 -0
  21. diffusers/models/README.md +3 -0
  22. diffusers/models/__init__.py +77 -0
  23. diffusers/models/__pycache__/__init__.cpython-310.pyc +0 -0
  24. diffusers/models/__pycache__/__init__.cpython-38.pyc +0 -0
  25. diffusers/models/__pycache__/__init__.cpython-39.pyc +0 -0
  26. diffusers/models/__pycache__/activations.cpython-310.pyc +0 -0
  27. diffusers/models/__pycache__/activations.cpython-38.pyc +0 -0
  28. diffusers/models/__pycache__/activations.cpython-39.pyc +0 -0
  29. diffusers/models/__pycache__/attention.cpython-310.pyc +0 -0
  30. diffusers/models/__pycache__/attention.cpython-38.pyc +0 -0
  31. diffusers/models/__pycache__/attention.cpython-39.pyc +0 -0
  32. diffusers/models/__pycache__/attention_processor.cpython-310.pyc +0 -0
  33. diffusers/models/__pycache__/attention_processor.cpython-38.pyc +0 -0
  34. diffusers/models/__pycache__/attention_processor.cpython-39.pyc +0 -0
  35. diffusers/models/__pycache__/autoencoder_asym_kl.cpython-310.pyc +0 -0
  36. diffusers/models/__pycache__/autoencoder_asym_kl.cpython-38.pyc +0 -0
  37. diffusers/models/__pycache__/autoencoder_asym_kl.cpython-39.pyc +0 -0
  38. diffusers/models/__pycache__/autoencoder_kl.cpython-310.pyc +0 -0
  39. diffusers/models/__pycache__/autoencoder_kl.cpython-38.pyc +0 -0
  40. diffusers/models/__pycache__/autoencoder_kl.cpython-39.pyc +0 -0
  41. diffusers/models/__pycache__/controlnet.cpython-310.pyc +0 -0
  42. diffusers/models/__pycache__/controlnet.cpython-38.pyc +0 -0
  43. diffusers/models/__pycache__/controlnet.cpython-39.pyc +0 -0
  44. diffusers/models/__pycache__/dual_transformer_2d.cpython-310.pyc +0 -0
  45. diffusers/models/__pycache__/dual_transformer_2d.cpython-38.pyc +0 -0
  46. diffusers/models/__pycache__/dual_transformer_2d.cpython-39.pyc +0 -0
  47. diffusers/models/__pycache__/embeddings.cpython-310.pyc +0 -0
  48. diffusers/models/__pycache__/embeddings.cpython-38.pyc +0 -0
  49. diffusers/models/__pycache__/embeddings.cpython-39.pyc +0 -0
  50. diffusers/models/__pycache__/lora.cpython-310.pyc +0 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,183 @@
import gradio as gr
import torch
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
import os
import cv2
from diffusers import DDIMScheduler, UniPCMultistepScheduler
from diffusers.models import UNet2DConditionModel
from ref_encoder.latent_controlnet import ControlNetModel
from ref_encoder.adapter import *
from ref_encoder.reference_unet import ref_unet
from utils.pipeline import StableHairPipeline
from utils.pipeline_cn import StableDiffusionControlNetPipeline
from huggingface_hub import hf_hub_download


class StableHair:
    def __init__(self, config="./configs/hair_transfer.yaml", device="cuda", weight_dtype=torch.float32) -> None:
        print("Initializing Stable Hair Pipeline...")
        self.config = OmegaConf.load(config)
        self.device = device

        # Hugging Face repo with weights
        repo_id = "LogicGoInfotechSpaces/new_weights"

        # Map config paths to the Hugging Face repo structure.
        # Based on config: pretrained_folder: "./models/stage2"
        #   encoder_path: "pytorch_model.bin"                        -> stage2/pytorch_model.bin
        #   adapter_path: "pytorch_model_1.bin"                      -> stage2/pytorch_model_1.bin
        #   controlnet_path: "pytorch_model_2.bin"                   -> stage2/pytorch_model_2.bin
        #   bald_converter_path: "./models/stage1/pytorch_model.bin" -> stage1/pytorch_model.bin

        # Download weights from Hugging Face
        encoder_hf_path = hf_hub_download(repo_id=repo_id, filename="stage2/pytorch_model.bin")
        adapter_hf_path = hf_hub_download(repo_id=repo_id, filename="stage2/pytorch_model_1.bin")
        controlnet_hf_path = hf_hub_download(repo_id=repo_id, filename="stage2/pytorch_model_2.bin")
        bald_converter_hf_path = hf_hub_download(repo_id=repo_id, filename="stage1/pytorch_model.bin")

        ### Load base UNet and latent ControlNet
        unet = UNet2DConditionModel.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
        controlnet = ControlNetModel.from_unet(unet).to(device)
        _state_dict = torch.load(controlnet_hf_path, map_location="cpu")
        controlnet.load_state_dict(_state_dict, strict=False)
        controlnet.to(weight_dtype)

        ### >>> create hair-transfer pipeline >>> ###
        self.pipeline = StableHairPipeline.from_pretrained(
            self.config.pretrained_model_path,
            controlnet=controlnet,
            safety_checker=None,
            torch_dtype=weight_dtype,
        ).to(device)
        self.pipeline.scheduler = DDIMScheduler.from_config(self.pipeline.scheduler.config)

        ### Load hair encoder/adapter
        self.hair_encoder = ref_unet.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
        _state_dict = torch.load(encoder_hf_path, map_location="cpu")
        self.hair_encoder.load_state_dict(_state_dict, strict=False)
        self.hair_adapter = adapter_injection(self.pipeline.unet, device=self.device, dtype=torch.float16, use_resampler=False)
        _state_dict = torch.load(adapter_hf_path, map_location="cpu")
        self.hair_adapter.load_state_dict(_state_dict, strict=False)

        ### Load bald converter
        bald_converter = ControlNetModel.from_unet(unet).to(device)
        _state_dict = torch.load(bald_converter_hf_path, map_location="cpu")
        bald_converter.load_state_dict(_state_dict, strict=False)
        bald_converter.to(dtype=weight_dtype)
        del unet

        ### Create pipeline for hair removal
        self.remove_hair_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            self.config.pretrained_model_path,
            controlnet=bald_converter,
            safety_checker=None,
            torch_dtype=weight_dtype,
        )
        self.remove_hair_pipeline.scheduler = UniPCMultistepScheduler.from_config(self.remove_hair_pipeline.scheduler.config)
        self.remove_hair_pipeline = self.remove_hair_pipeline.to(device)

        ### Cast hair encoder/adapter to the working dtype
        self.hair_encoder.to(weight_dtype)
        self.hair_adapter.to(weight_dtype)

        print("Initialization Done!")

    def Hair_Transfer(self, source_image, reference_image, random_seed, step, guidance_scale, scale, controlnet_conditioning_scale):
        prompt = ""
        n_prompt = ""
        random_seed = int(random_seed)
        step = int(step)
        guidance_scale = float(guidance_scale)
        scale = float(scale)
        controlnet_conditioning_scale = float(controlnet_conditioning_scale)

        # Source image dimensions
        H, W, C = source_image.shape

        # Generate images
        set_scale(self.pipeline.unet, scale)
        generator = torch.Generator(device="cuda")
        generator.manual_seed(random_seed)
        sample = self.pipeline(
            prompt,
            negative_prompt=n_prompt,
            num_inference_steps=step,
            guidance_scale=guidance_scale,
            width=W,
            height=H,
            controlnet_condition=source_image,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=generator,
            reference_encoder=self.hair_encoder,
            ref_image=reference_image,
        ).samples
        return sample, source_image, reference_image

    def get_bald(self, id_image, scale):
        W, H = id_image.size  # PIL returns (width, height)
        scale = float(scale)
        image = self.remove_hair_pipeline(
            prompt="",
            negative_prompt="",
            num_inference_steps=30,
            guidance_scale=1.5,
            width=W,
            height=H,
            image=id_image,
            controlnet_conditioning_scale=scale,
            generator=None,
        ).images[0]

        return image


model = StableHair(config="./configs/hair_transfer.yaml", weight_dtype=torch.float32)


# Gradio callback: remove the hair from the ID image, then transfer the reference hair onto it
def model_call(id_image, ref_hair, converter_scale, scale, guidance_scale, controlnet_conditioning_scale):
    id_image = Image.fromarray(id_image.astype('uint8'), 'RGB')
    ref_hair = Image.fromarray(ref_hair.astype('uint8'), 'RGB')
    id_image = id_image.resize((512, 512))
    ref_hair = ref_hair.resize((512, 512))
    id_image_bald = model.get_bald(id_image, converter_scale)

    id_image_bald = np.array(id_image_bald)
    ref_hair = np.array(ref_hair)

    image, source_image, reference_image = model.Hair_Transfer(
        source_image=id_image_bald,
        reference_image=ref_hair,
        random_seed=-1,
        step=30,
        guidance_scale=guidance_scale,
        scale=scale,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
    )

    image = Image.fromarray((image * 255.).astype(np.uint8))
    return id_image_bald, image


# Create a Gradio interface
iface = gr.Interface(
    fn=model_call,
    inputs=[
        gr.Image(label="ID Image"),
        gr.Image(label="Reference Hair"),
        gr.Slider(minimum=0.5, maximum=1.5, value=1, label="Converter Scale"),
        gr.Slider(minimum=0.0, maximum=3.0, value=1.0, label="Hair Encoder Scale"),
        gr.Slider(minimum=1.1, maximum=3.0, value=1.5, label="CFG"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1, label="Latent IdentityNet Scale"),
    ],
    outputs=[
        gr.Image(type="pil", label="Bald Result"),
        gr.Image(type="pil", label="Transfer Result"),
    ],
    title="Hair Transfer Demo",
    description="Aligned faces generally work best, but non-aligned faces can also be used; inputs are resized to 512 x 512.",
)

# Launch the Gradio interface
iface.queue().launch(server_name='0.0.0.0', server_port=7860, share=True)
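Once the app above is running, it can also be queried programmatically. The snippet below is a minimal sketch using the gradio_client package (not part of this commit); it assumes the demo is reachable on the configured port 7860 and reuses the sample image paths listed in configs/hair_transfer.yaml. The exact predict signature depends on the installed gradio/gradio_client versions, so treat it as illustrative rather than exact.

# Hypothetical client-side call to the demo above (requires `pip install gradio_client`)
from gradio_client import Client, handle_file

client = Client("http://localhost:7860")
bald_path, transfer_path = client.predict(
    handle_file("./test_imgs/ID/0.jpg"),   # ID Image
    handle_file("./test_imgs/Ref/0.jpg"),  # Reference Hair
    1.0,   # Converter Scale
    1.0,   # Hair Encoder Scale
    1.5,   # CFG
    1.0,   # Latent IdentityNet Scale
    api_name="/predict",
)
print(bald_path, transfer_path)  # local paths to the bald and transfer results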
configs/hair_transfer.yaml ADDED
@@ -0,0 +1,22 @@
pretrained_model_path: "runwayml/stable-diffusion-v1-5" # your sd path

pretrained_folder: "./models/stage2"
encoder_path: "pytorch_model.bin"
adapter_path: "pytorch_model_1.bin"
controlnet_path: "pytorch_model_2.bin"
bald_converter_path: "./models/stage1/pytorch_model.bin"

fusion_blocks: "full"

inference_kwargs:
  source_image: "./test_imgs/ID/0.jpg"
  reference_image: "./test_imgs/Ref/0.jpg"
  random_seed: -1
  step: 30
  guidance_scale: 1.5
  controlnet_conditioning_scale: 1
  scale: 1
  size: 512

output_path: "./output"
save_name: "0.jpg"
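For orientation, this is how app.py consumes the file: it is loaded with OmegaConf and the top-level keys are read as attributes. A minimal sketch (the printed values simply echo the entries above):

# Sketch: reading this config the same way app.py does
from omegaconf import OmegaConf

config = OmegaConf.load("./configs/hair_transfer.yaml")
print(config.pretrained_model_path)   # "runwayml/stable-diffusion-v1-5"
inference_kwargs = OmegaConf.to_container(config.inference_kwargs)
print(inference_kwargs["step"], inference_kwargs["guidance_scale"])  # 30, 1.5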
default_config.yaml ADDED
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
main_process_port: 17362
downcast_bf16: 'no'
gpu_ids: 0,1,2,3
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
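This file follows the schema written by `accelerate config`: a local multi-GPU run with 4 processes on GPUs 0-3 and fp16 mixed precision. It would typically be consumed as `accelerate launch --config_file default_config.yaml <training_script.py>`; the training script itself is not part of this commit, so that name is only a placeholder.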
diffusers/.DS_Store ADDED
Binary file (8.2 kB).
 
diffusers/__init__.py ADDED
@@ -0,0 +1,734 @@
1
+ __version__ = "0.23.1"
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from .utils import (
6
+ DIFFUSERS_SLOW_IMPORT,
7
+ OptionalDependencyNotAvailable,
8
+ _LazyModule,
9
+ is_flax_available,
10
+ is_k_diffusion_available,
11
+ is_librosa_available,
12
+ is_note_seq_available,
13
+ is_onnx_available,
14
+ is_scipy_available,
15
+ is_torch_available,
16
+ is_torchsde_available,
17
+ is_transformers_available,
18
+ )
19
+
20
+
21
+ # Lazy Import based on
22
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py
23
+
24
+ # When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names,
25
+ # and is used to defer the actual importing for when the objects are requested.
26
+ # This way `import diffusers` provides the names in the namespace without actually importing anything (and especially none of the backends).
27
+
28
+ _import_structure = {
29
+ "configuration_utils": ["ConfigMixin"],
30
+ "models": [],
31
+ "pipelines": [],
32
+ "schedulers": [],
33
+ "utils": [
34
+ "OptionalDependencyNotAvailable",
35
+ "is_flax_available",
36
+ "is_inflect_available",
37
+ "is_invisible_watermark_available",
38
+ "is_k_diffusion_available",
39
+ "is_k_diffusion_version",
40
+ "is_librosa_available",
41
+ "is_note_seq_available",
42
+ "is_onnx_available",
43
+ "is_scipy_available",
44
+ "is_torch_available",
45
+ "is_torchsde_available",
46
+ "is_transformers_available",
47
+ "is_transformers_version",
48
+ "is_unidecode_available",
49
+ "logging",
50
+ ],
51
+ }
52
+
53
+ try:
54
+ if not is_onnx_available():
55
+ raise OptionalDependencyNotAvailable()
56
+ except OptionalDependencyNotAvailable:
57
+ from .utils import dummy_onnx_objects # noqa F403
58
+
59
+ _import_structure["utils.dummy_onnx_objects"] = [
60
+ name for name in dir(dummy_onnx_objects) if not name.startswith("_")
61
+ ]
62
+
63
+ else:
64
+ _import_structure["pipelines"].extend(["OnnxRuntimeModel"])
65
+
66
+ try:
67
+ if not is_torch_available():
68
+ raise OptionalDependencyNotAvailable()
69
+ except OptionalDependencyNotAvailable:
70
+ from .utils import dummy_pt_objects # noqa F403
71
+
72
+ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
73
+
74
+ else:
75
+ _import_structure["models"].extend(
76
+ [
77
+ "AsymmetricAutoencoderKL",
78
+ "AutoencoderKL",
79
+ "AutoencoderTiny",
80
+ "ConsistencyDecoderVAE",
81
+ "ControlNetModel",
82
+ "ModelMixin",
83
+ "MotionAdapter",
84
+ "MultiAdapter",
85
+ "PriorTransformer",
86
+ "T2IAdapter",
87
+ "T5FilmDecoder",
88
+ "Transformer2DModel",
89
+ "UNet1DModel",
90
+ "UNet2DConditionModel",
91
+ "UNet2DModel",
92
+ "UNet3DConditionModel",
93
+ "UNetMotionModel",
94
+ "VQModel",
95
+ ]
96
+ )
97
+ _import_structure["optimization"] = [
98
+ "get_constant_schedule",
99
+ "get_constant_schedule_with_warmup",
100
+ "get_cosine_schedule_with_warmup",
101
+ "get_cosine_with_hard_restarts_schedule_with_warmup",
102
+ "get_linear_schedule_with_warmup",
103
+ "get_polynomial_decay_schedule_with_warmup",
104
+ "get_scheduler",
105
+ ]
106
+
107
+ _import_structure["pipelines"].extend(
108
+ [
109
+ "AudioPipelineOutput",
110
+ "AutoPipelineForImage2Image",
111
+ "AutoPipelineForInpainting",
112
+ "AutoPipelineForText2Image",
113
+ "ConsistencyModelPipeline",
114
+ "DanceDiffusionPipeline",
115
+ "DDIMPipeline",
116
+ "DDPMPipeline",
117
+ "DiffusionPipeline",
118
+ "DiTPipeline",
119
+ "ImagePipelineOutput",
120
+ "KarrasVePipeline",
121
+ "LDMPipeline",
122
+ "LDMSuperResolutionPipeline",
123
+ "PNDMPipeline",
124
+ "RePaintPipeline",
125
+ "ScoreSdeVePipeline",
126
+ ]
127
+ )
128
+ _import_structure["schedulers"].extend(
129
+ [
130
+ "CMStochasticIterativeScheduler",
131
+ "DDIMInverseScheduler",
132
+ "DDIMParallelScheduler",
133
+ "DDIMScheduler",
134
+ "DDPMParallelScheduler",
135
+ "DDPMScheduler",
136
+ "DDPMWuerstchenScheduler",
137
+ "DEISMultistepScheduler",
138
+ "DPMSolverMultistepInverseScheduler",
139
+ "DPMSolverMultistepScheduler",
140
+ "DPMSolverSinglestepScheduler",
141
+ "EulerAncestralDiscreteScheduler",
142
+ "EulerDiscreteScheduler",
143
+ "HeunDiscreteScheduler",
144
+ "IPNDMScheduler",
145
+ "KarrasVeScheduler",
146
+ "KDPM2AncestralDiscreteScheduler",
147
+ "KDPM2DiscreteScheduler",
148
+ "LCMScheduler",
149
+ "PNDMScheduler",
150
+ "RePaintScheduler",
151
+ "SchedulerMixin",
152
+ "ScoreSdeVeScheduler",
153
+ "UnCLIPScheduler",
154
+ "UniPCMultistepScheduler",
155
+ "VQDiffusionScheduler",
156
+ ]
157
+ )
158
+ _import_structure["training_utils"] = ["EMAModel"]
159
+
160
+ try:
161
+ if not (is_torch_available() and is_scipy_available()):
162
+ raise OptionalDependencyNotAvailable()
163
+ except OptionalDependencyNotAvailable:
164
+ from .utils import dummy_torch_and_scipy_objects # noqa F403
165
+
166
+ _import_structure["utils.dummy_torch_and_scipy_objects"] = [
167
+ name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_")
168
+ ]
169
+
170
+ else:
171
+ _import_structure["schedulers"].extend(["LMSDiscreteScheduler"])
172
+
173
+ try:
174
+ if not (is_torch_available() and is_torchsde_available()):
175
+ raise OptionalDependencyNotAvailable()
176
+ except OptionalDependencyNotAvailable:
177
+ from .utils import dummy_torch_and_torchsde_objects # noqa F403
178
+
179
+ _import_structure["utils.dummy_torch_and_torchsde_objects"] = [
180
+ name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_")
181
+ ]
182
+
183
+ else:
184
+ _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"])
185
+
186
+ try:
187
+ if not (is_torch_available() and is_transformers_available()):
188
+ raise OptionalDependencyNotAvailable()
189
+ except OptionalDependencyNotAvailable:
190
+ from .utils import dummy_torch_and_transformers_objects # noqa F403
191
+
192
+ _import_structure["utils.dummy_torch_and_transformers_objects"] = [
193
+ name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
194
+ ]
195
+
196
+ else:
197
+ _import_structure["pipelines"].extend(
198
+ [
199
+ "AltDiffusionImg2ImgPipeline",
200
+ "AltDiffusionPipeline",
201
+ "AnimateDiffPipeline",
202
+ "AudioLDM2Pipeline",
203
+ "AudioLDM2ProjectionModel",
204
+ "AudioLDM2UNet2DConditionModel",
205
+ "AudioLDMPipeline",
206
+ "BlipDiffusionControlNetPipeline",
207
+ "BlipDiffusionPipeline",
208
+ "CLIPImageProjection",
209
+ "CycleDiffusionPipeline",
210
+ "IFImg2ImgPipeline",
211
+ "IFImg2ImgSuperResolutionPipeline",
212
+ "IFInpaintingPipeline",
213
+ "IFInpaintingSuperResolutionPipeline",
214
+ "IFPipeline",
215
+ "IFSuperResolutionPipeline",
216
+ "ImageTextPipelineOutput",
217
+ "KandinskyCombinedPipeline",
218
+ "KandinskyImg2ImgCombinedPipeline",
219
+ "KandinskyImg2ImgPipeline",
220
+ "KandinskyInpaintCombinedPipeline",
221
+ "KandinskyInpaintPipeline",
222
+ "KandinskyPipeline",
223
+ "KandinskyPriorPipeline",
224
+ "KandinskyV22CombinedPipeline",
225
+ "KandinskyV22ControlnetImg2ImgPipeline",
226
+ "KandinskyV22ControlnetPipeline",
227
+ "KandinskyV22Img2ImgCombinedPipeline",
228
+ "KandinskyV22Img2ImgPipeline",
229
+ "KandinskyV22InpaintCombinedPipeline",
230
+ "KandinskyV22InpaintPipeline",
231
+ "KandinskyV22Pipeline",
232
+ "KandinskyV22PriorEmb2EmbPipeline",
233
+ "KandinskyV22PriorPipeline",
234
+ "LatentConsistencyModelImg2ImgPipeline",
235
+ "LatentConsistencyModelPipeline",
236
+ "LDMTextToImagePipeline",
237
+ "MusicLDMPipeline",
238
+ "PaintByExamplePipeline",
239
+ "PixArtAlphaPipeline",
240
+ "SemanticStableDiffusionPipeline",
241
+ "ShapEImg2ImgPipeline",
242
+ "ShapEPipeline",
243
+ "StableDiffusionAdapterPipeline",
244
+ "StableDiffusionAttendAndExcitePipeline",
245
+ "StableDiffusionControlNetImg2ImgPipeline",
246
+ "StableDiffusionControlNetInpaintPipeline",
247
+ "StableDiffusionControlNetPipeline",
248
+ "StableDiffusionDepth2ImgPipeline",
249
+ "StableDiffusionDiffEditPipeline",
250
+ "StableDiffusionGLIGENPipeline",
251
+ "StableDiffusionGLIGENTextImagePipeline",
252
+ "StableDiffusionImageVariationPipeline",
253
+ "StableDiffusionImg2ImgPipeline",
254
+ "StableDiffusionInpaintPipeline",
255
+ "StableDiffusionInpaintPipelineLegacy",
256
+ "StableDiffusionInstructPix2PixPipeline",
257
+ "StableDiffusionLatentUpscalePipeline",
258
+ "StableDiffusionLDM3DPipeline",
259
+ "StableDiffusionModelEditingPipeline",
260
+ "StableDiffusionPanoramaPipeline",
261
+ "StableDiffusionParadigmsPipeline",
262
+ "StableDiffusionPipeline",
263
+ "StableDiffusionPipelineSafe",
264
+ "StableDiffusionPix2PixZeroPipeline",
265
+ "StableDiffusionSAGPipeline",
266
+ "StableDiffusionUpscalePipeline",
267
+ "StableDiffusionXLAdapterPipeline",
268
+ "StableDiffusionXLControlNetImg2ImgPipeline",
269
+ "StableDiffusionXLControlNetInpaintPipeline",
270
+ "StableDiffusionXLControlNetPipeline",
271
+ "StableDiffusionXLImg2ImgPipeline",
272
+ "StableDiffusionXLInpaintPipeline",
273
+ "StableDiffusionXLInstructPix2PixPipeline",
274
+ "StableDiffusionXLPipeline",
275
+ "StableUnCLIPImg2ImgPipeline",
276
+ "StableUnCLIPPipeline",
277
+ "TextToVideoSDPipeline",
278
+ "TextToVideoZeroPipeline",
279
+ "UnCLIPImageVariationPipeline",
280
+ "UnCLIPPipeline",
281
+ "UniDiffuserModel",
282
+ "UniDiffuserPipeline",
283
+ "UniDiffuserTextDecoder",
284
+ "VersatileDiffusionDualGuidedPipeline",
285
+ "VersatileDiffusionImageVariationPipeline",
286
+ "VersatileDiffusionPipeline",
287
+ "VersatileDiffusionTextToImagePipeline",
288
+ "VideoToVideoSDPipeline",
289
+ "VQDiffusionPipeline",
290
+ "WuerstchenCombinedPipeline",
291
+ "WuerstchenDecoderPipeline",
292
+ "WuerstchenPriorPipeline",
293
+ ]
294
+ )
295
+
296
+ try:
297
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
298
+ raise OptionalDependencyNotAvailable()
299
+ except OptionalDependencyNotAvailable:
300
+ from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
301
+
302
+ _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
303
+ name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
304
+ ]
305
+
306
+ else:
307
+ _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"])
308
+
309
+ try:
310
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
311
+ raise OptionalDependencyNotAvailable()
312
+ except OptionalDependencyNotAvailable:
313
+ from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
314
+
315
+ _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
316
+ name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
317
+ ]
318
+
319
+ else:
320
+ _import_structure["pipelines"].extend(
321
+ [
322
+ "OnnxStableDiffusionImg2ImgPipeline",
323
+ "OnnxStableDiffusionInpaintPipeline",
324
+ "OnnxStableDiffusionInpaintPipelineLegacy",
325
+ "OnnxStableDiffusionPipeline",
326
+ "OnnxStableDiffusionUpscalePipeline",
327
+ "StableDiffusionOnnxPipeline",
328
+ ]
329
+ )
330
+
331
+ try:
332
+ if not (is_torch_available() and is_librosa_available()):
333
+ raise OptionalDependencyNotAvailable()
334
+ except OptionalDependencyNotAvailable:
335
+ from .utils import dummy_torch_and_librosa_objects # noqa F403
336
+
337
+ _import_structure["utils.dummy_torch_and_librosa_objects"] = [
338
+ name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_")
339
+ ]
340
+
341
+ else:
342
+ _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"])
343
+
344
+ try:
345
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
346
+ raise OptionalDependencyNotAvailable()
347
+ except OptionalDependencyNotAvailable:
348
+ from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
349
+
350
+ _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
351
+ name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
352
+ ]
353
+
354
+
355
+ else:
356
+ _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"])
357
+
358
+ try:
359
+ if not is_flax_available():
360
+ raise OptionalDependencyNotAvailable()
361
+ except OptionalDependencyNotAvailable:
362
+ from .utils import dummy_flax_objects # noqa F403
363
+
364
+ _import_structure["utils.dummy_flax_objects"] = [
365
+ name for name in dir(dummy_flax_objects) if not name.startswith("_")
366
+ ]
367
+
368
+
369
+ else:
370
+ _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
371
+ _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
372
+ _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
373
+ _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
374
+ _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
375
+ _import_structure["schedulers"].extend(
376
+ [
377
+ "FlaxDDIMScheduler",
378
+ "FlaxDDPMScheduler",
379
+ "FlaxDPMSolverMultistepScheduler",
380
+ "FlaxEulerDiscreteScheduler",
381
+ "FlaxKarrasVeScheduler",
382
+ "FlaxLMSDiscreteScheduler",
383
+ "FlaxPNDMScheduler",
384
+ "FlaxSchedulerMixin",
385
+ "FlaxScoreSdeVeScheduler",
386
+ ]
387
+ )
388
+
389
+
390
+ try:
391
+ if not (is_flax_available() and is_transformers_available()):
392
+ raise OptionalDependencyNotAvailable()
393
+ except OptionalDependencyNotAvailable:
394
+ from .utils import dummy_flax_and_transformers_objects # noqa F403
395
+
396
+ _import_structure["utils.dummy_flax_and_transformers_objects"] = [
397
+ name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_")
398
+ ]
399
+
400
+
401
+ else:
402
+ _import_structure["pipelines"].extend(
403
+ [
404
+ "FlaxStableDiffusionControlNetPipeline",
405
+ "FlaxStableDiffusionImg2ImgPipeline",
406
+ "FlaxStableDiffusionInpaintPipeline",
407
+ "FlaxStableDiffusionPipeline",
408
+ "FlaxStableDiffusionXLPipeline",
409
+ ]
410
+ )
411
+
412
+ try:
413
+ if not (is_note_seq_available()):
414
+ raise OptionalDependencyNotAvailable()
415
+ except OptionalDependencyNotAvailable:
416
+ from .utils import dummy_note_seq_objects # noqa F403
417
+
418
+ _import_structure["utils.dummy_note_seq_objects"] = [
419
+ name for name in dir(dummy_note_seq_objects) if not name.startswith("_")
420
+ ]
421
+
422
+
423
+ else:
424
+ _import_structure["pipelines"].extend(["MidiProcessor"])
425
+
426
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
427
+ from .configuration_utils import ConfigMixin
428
+
429
+ try:
430
+ if not is_onnx_available():
431
+ raise OptionalDependencyNotAvailable()
432
+ except OptionalDependencyNotAvailable:
433
+ from .utils.dummy_onnx_objects import * # noqa F403
434
+ else:
435
+ from .pipelines import OnnxRuntimeModel
436
+
437
+ try:
438
+ if not is_torch_available():
439
+ raise OptionalDependencyNotAvailable()
440
+ except OptionalDependencyNotAvailable:
441
+ from .utils.dummy_pt_objects import * # noqa F403
442
+ else:
443
+ from .models import (
444
+ AsymmetricAutoencoderKL,
445
+ AutoencoderKL,
446
+ AutoencoderTiny,
447
+ ConsistencyDecoderVAE,
448
+ ControlNetModel,
449
+ ModelMixin,
450
+ MotionAdapter,
451
+ MultiAdapter,
452
+ PriorTransformer,
453
+ T2IAdapter,
454
+ T5FilmDecoder,
455
+ Transformer2DModel,
456
+ UNet1DModel,
457
+ UNet2DConditionModel,
458
+ UNet2DModel,
459
+ UNet3DConditionModel,
460
+ UNetMotionModel,
461
+ VQModel,
462
+ )
463
+ from .optimization import (
464
+ get_constant_schedule,
465
+ get_constant_schedule_with_warmup,
466
+ get_cosine_schedule_with_warmup,
467
+ get_cosine_with_hard_restarts_schedule_with_warmup,
468
+ get_linear_schedule_with_warmup,
469
+ get_polynomial_decay_schedule_with_warmup,
470
+ get_scheduler,
471
+ )
472
+ from .pipelines import (
473
+ AudioPipelineOutput,
474
+ AutoPipelineForImage2Image,
475
+ AutoPipelineForInpainting,
476
+ AutoPipelineForText2Image,
477
+ BlipDiffusionControlNetPipeline,
478
+ BlipDiffusionPipeline,
479
+ CLIPImageProjection,
480
+ ConsistencyModelPipeline,
481
+ DanceDiffusionPipeline,
482
+ DDIMPipeline,
483
+ DDPMPipeline,
484
+ DiffusionPipeline,
485
+ DiTPipeline,
486
+ ImagePipelineOutput,
487
+ KarrasVePipeline,
488
+ LDMPipeline,
489
+ LDMSuperResolutionPipeline,
490
+ PNDMPipeline,
491
+ RePaintPipeline,
492
+ ScoreSdeVePipeline,
493
+ )
494
+ from .schedulers import (
495
+ CMStochasticIterativeScheduler,
496
+ DDIMInverseScheduler,
497
+ DDIMParallelScheduler,
498
+ DDIMScheduler,
499
+ DDPMParallelScheduler,
500
+ DDPMScheduler,
501
+ DDPMWuerstchenScheduler,
502
+ DEISMultistepScheduler,
503
+ DPMSolverMultistepInverseScheduler,
504
+ DPMSolverMultistepScheduler,
505
+ DPMSolverSinglestepScheduler,
506
+ EulerAncestralDiscreteScheduler,
507
+ EulerDiscreteScheduler,
508
+ HeunDiscreteScheduler,
509
+ IPNDMScheduler,
510
+ KarrasVeScheduler,
511
+ KDPM2AncestralDiscreteScheduler,
512
+ KDPM2DiscreteScheduler,
513
+ LCMScheduler,
514
+ PNDMScheduler,
515
+ RePaintScheduler,
516
+ SchedulerMixin,
517
+ ScoreSdeVeScheduler,
518
+ UnCLIPScheduler,
519
+ UniPCMultistepScheduler,
520
+ VQDiffusionScheduler,
521
+ )
522
+ from .training_utils import EMAModel
523
+
524
+ try:
525
+ if not (is_torch_available() and is_scipy_available()):
526
+ raise OptionalDependencyNotAvailable()
527
+ except OptionalDependencyNotAvailable:
528
+ from .utils.dummy_torch_and_scipy_objects import * # noqa F403
529
+ else:
530
+ from .schedulers import LMSDiscreteScheduler
531
+
532
+ try:
533
+ if not (is_torch_available() and is_torchsde_available()):
534
+ raise OptionalDependencyNotAvailable()
535
+ except OptionalDependencyNotAvailable:
536
+ from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
537
+ else:
538
+ from .schedulers import DPMSolverSDEScheduler
539
+
540
+ try:
541
+ if not (is_torch_available() and is_transformers_available()):
542
+ raise OptionalDependencyNotAvailable()
543
+ except OptionalDependencyNotAvailable:
544
+ from .utils.dummy_torch_and_transformers_objects import * # noqa F403
545
+ else:
546
+ from .pipelines import (
547
+ AltDiffusionImg2ImgPipeline,
548
+ AltDiffusionPipeline,
549
+ AnimateDiffPipeline,
550
+ AudioLDM2Pipeline,
551
+ AudioLDM2ProjectionModel,
552
+ AudioLDM2UNet2DConditionModel,
553
+ AudioLDMPipeline,
554
+ CLIPImageProjection,
555
+ CycleDiffusionPipeline,
556
+ IFImg2ImgPipeline,
557
+ IFImg2ImgSuperResolutionPipeline,
558
+ IFInpaintingPipeline,
559
+ IFInpaintingSuperResolutionPipeline,
560
+ IFPipeline,
561
+ IFSuperResolutionPipeline,
562
+ ImageTextPipelineOutput,
563
+ KandinskyCombinedPipeline,
564
+ KandinskyImg2ImgCombinedPipeline,
565
+ KandinskyImg2ImgPipeline,
566
+ KandinskyInpaintCombinedPipeline,
567
+ KandinskyInpaintPipeline,
568
+ KandinskyPipeline,
569
+ KandinskyPriorPipeline,
570
+ KandinskyV22CombinedPipeline,
571
+ KandinskyV22ControlnetImg2ImgPipeline,
572
+ KandinskyV22ControlnetPipeline,
573
+ KandinskyV22Img2ImgCombinedPipeline,
574
+ KandinskyV22Img2ImgPipeline,
575
+ KandinskyV22InpaintCombinedPipeline,
576
+ KandinskyV22InpaintPipeline,
577
+ KandinskyV22Pipeline,
578
+ KandinskyV22PriorEmb2EmbPipeline,
579
+ KandinskyV22PriorPipeline,
580
+ LatentConsistencyModelImg2ImgPipeline,
581
+ LatentConsistencyModelPipeline,
582
+ LDMTextToImagePipeline,
583
+ MusicLDMPipeline,
584
+ PaintByExamplePipeline,
585
+ PixArtAlphaPipeline,
586
+ SemanticStableDiffusionPipeline,
587
+ ShapEImg2ImgPipeline,
588
+ ShapEPipeline,
589
+ StableDiffusionAdapterPipeline,
590
+ StableDiffusionAttendAndExcitePipeline,
591
+ StableDiffusionControlNetImg2ImgPipeline,
592
+ StableDiffusionControlNetInpaintPipeline,
593
+ StableDiffusionControlNetPipeline,
594
+ StableDiffusionDepth2ImgPipeline,
595
+ StableDiffusionDiffEditPipeline,
596
+ StableDiffusionGLIGENPipeline,
597
+ StableDiffusionGLIGENTextImagePipeline,
598
+ StableDiffusionImageVariationPipeline,
599
+ StableDiffusionImg2ImgPipeline,
600
+ StableDiffusionInpaintPipeline,
601
+ StableDiffusionInpaintPipelineLegacy,
602
+ StableDiffusionInstructPix2PixPipeline,
603
+ StableDiffusionLatentUpscalePipeline,
604
+ StableDiffusionLDM3DPipeline,
605
+ StableDiffusionModelEditingPipeline,
606
+ StableDiffusionPanoramaPipeline,
607
+ StableDiffusionParadigmsPipeline,
608
+ StableDiffusionPipeline,
609
+ StableDiffusionPipelineSafe,
610
+ StableDiffusionPix2PixZeroPipeline,
611
+ StableDiffusionSAGPipeline,
612
+ StableDiffusionUpscalePipeline,
613
+ StableDiffusionXLAdapterPipeline,
614
+ StableDiffusionXLControlNetImg2ImgPipeline,
615
+ StableDiffusionXLControlNetInpaintPipeline,
616
+ StableDiffusionXLControlNetPipeline,
617
+ StableDiffusionXLImg2ImgPipeline,
618
+ StableDiffusionXLInpaintPipeline,
619
+ StableDiffusionXLInstructPix2PixPipeline,
620
+ StableDiffusionXLPipeline,
621
+ StableUnCLIPImg2ImgPipeline,
622
+ StableUnCLIPPipeline,
623
+ TextToVideoSDPipeline,
624
+ TextToVideoZeroPipeline,
625
+ UnCLIPImageVariationPipeline,
626
+ UnCLIPPipeline,
627
+ UniDiffuserModel,
628
+ UniDiffuserPipeline,
629
+ UniDiffuserTextDecoder,
630
+ VersatileDiffusionDualGuidedPipeline,
631
+ VersatileDiffusionImageVariationPipeline,
632
+ VersatileDiffusionPipeline,
633
+ VersatileDiffusionTextToImagePipeline,
634
+ VideoToVideoSDPipeline,
635
+ VQDiffusionPipeline,
636
+ WuerstchenCombinedPipeline,
637
+ WuerstchenDecoderPipeline,
638
+ WuerstchenPriorPipeline,
639
+ )
640
+
641
+ try:
642
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
643
+ raise OptionalDependencyNotAvailable()
644
+ except OptionalDependencyNotAvailable:
645
+ from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
646
+ else:
647
+ from .pipelines import StableDiffusionKDiffusionPipeline
648
+
649
+ try:
650
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
651
+ raise OptionalDependencyNotAvailable()
652
+ except OptionalDependencyNotAvailable:
653
+ from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
654
+ else:
655
+ from .pipelines import (
656
+ OnnxStableDiffusionImg2ImgPipeline,
657
+ OnnxStableDiffusionInpaintPipeline,
658
+ OnnxStableDiffusionInpaintPipelineLegacy,
659
+ OnnxStableDiffusionPipeline,
660
+ OnnxStableDiffusionUpscalePipeline,
661
+ StableDiffusionOnnxPipeline,
662
+ )
663
+
664
+ try:
665
+ if not (is_torch_available() and is_librosa_available()):
666
+ raise OptionalDependencyNotAvailable()
667
+ except OptionalDependencyNotAvailable:
668
+ from .utils.dummy_torch_and_librosa_objects import * # noqa F403
669
+ else:
670
+ from .pipelines import AudioDiffusionPipeline, Mel
671
+
672
+ try:
673
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
674
+ raise OptionalDependencyNotAvailable()
675
+ except OptionalDependencyNotAvailable:
676
+ from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
677
+ else:
678
+ from .pipelines import SpectrogramDiffusionPipeline
679
+
680
+ try:
681
+ if not is_flax_available():
682
+ raise OptionalDependencyNotAvailable()
683
+ except OptionalDependencyNotAvailable:
684
+ from .utils.dummy_flax_objects import * # noqa F403
685
+ else:
686
+ from .models.controlnet_flax import FlaxControlNetModel
687
+ from .models.modeling_flax_utils import FlaxModelMixin
688
+ from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
689
+ from .models.vae_flax import FlaxAutoencoderKL
690
+ from .pipelines import FlaxDiffusionPipeline
691
+ from .schedulers import (
692
+ FlaxDDIMScheduler,
693
+ FlaxDDPMScheduler,
694
+ FlaxDPMSolverMultistepScheduler,
695
+ FlaxEulerDiscreteScheduler,
696
+ FlaxKarrasVeScheduler,
697
+ FlaxLMSDiscreteScheduler,
698
+ FlaxPNDMScheduler,
699
+ FlaxSchedulerMixin,
700
+ FlaxScoreSdeVeScheduler,
701
+ )
702
+
703
+ try:
704
+ if not (is_flax_available() and is_transformers_available()):
705
+ raise OptionalDependencyNotAvailable()
706
+ except OptionalDependencyNotAvailable:
707
+ from .utils.dummy_flax_and_transformers_objects import * # noqa F403
708
+ else:
709
+ from .pipelines import (
710
+ FlaxStableDiffusionControlNetPipeline,
711
+ FlaxStableDiffusionImg2ImgPipeline,
712
+ FlaxStableDiffusionInpaintPipeline,
713
+ FlaxStableDiffusionPipeline,
714
+ FlaxStableDiffusionXLPipeline,
715
+ )
716
+
717
+ try:
718
+ if not (is_note_seq_available()):
719
+ raise OptionalDependencyNotAvailable()
720
+ except OptionalDependencyNotAvailable:
721
+ from .utils.dummy_note_seq_objects import * # noqa F403
722
+ else:
723
+ from .pipelines import MidiProcessor
724
+
725
+ else:
726
+ import sys
727
+
728
+ sys.modules[__name__] = _LazyModule(
729
+ __name__,
730
+ globals()["__file__"],
731
+ _import_structure,
732
+ module_spec=__spec__,
733
+ extra_objects={"__version__": __version__},
734
+ )
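The practical effect of the _import_structure / _LazyModule machinery above is that `import diffusers` stays cheap: no heavy backend is imported until a name is first touched. A small illustrative sketch, assuming torch is installed so the real classes (rather than the dummy placeholder objects) are resolved:

# Sketch: names listed in _import_structure resolve lazily on first attribute access
import diffusers

print(diffusers.__version__)              # "0.23.1", as pinned at the top of this file
scheduler_cls = diffusers.DDIMScheduler   # the schedulers submodule is imported here, on demand
print(scheduler_cls.__name__)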
diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from argparse import ArgumentParser


class BaseDiffusersCLICommand(ABC):
    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        raise NotImplementedError()
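Concrete commands plug into this base class by implementing both abstract methods; the env and fp16_safetensors modules below do exactly that. A hypothetical extra subcommand would look roughly like this (the "hello" command does not exist in diffusers and is only illustrative):

# Illustrative only: a made-up subcommand following the same pattern as env.py
from argparse import ArgumentParser

from diffusers.commands import BaseDiffusersCLICommand


class HelloCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        hello_parser = parser.add_parser("hello")
        hello_parser.set_defaults(func=lambda args: HelloCommand())

    def run(self):
        print("hello from a custom diffusers-cli subcommand")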
diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,43 @@
#!/usr/bin/env python
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser

from .env import EnvironmentCommand
from .fp16_safetensors import FP16SafetensorsCommand


def main():
    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")

    # Register commands
    EnvironmentCommand.register_subcommand(commands_parser)
    FP16SafetensorsCommand.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)

    # Run
    service = args.func(args)
    service.run()


if __name__ == "__main__":
    main()
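As registered here, the console entry point exposes two subcommands: `diffusers-cli env` prints the environment report implemented in env.py below, and `diffusers-cli fp16_safetensors --ckpt_id=<repo> --fp16 --use_safetensors` runs the checkpoint conversion whose usage string is quoted in the fp16_safetensors.py docstring further down.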
diffusers/commands/env.py ADDED
@@ -0,0 +1,84 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import platform
from argparse import ArgumentParser

import huggingface_hub

from .. import __version__ as version
from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
from . import BaseDiffusersCLICommand


def info_command_factory(_):
    return EnvironmentCommand()


class EnvironmentCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        download_parser = parser.add_parser("env")
        download_parser.set_defaults(func=info_command_factory)

    def run(self):
        hub_version = huggingface_hub.__version__

        pt_version = "not installed"
        pt_cuda_available = "NA"
        if is_torch_available():
            import torch

            pt_version = torch.__version__
            pt_cuda_available = torch.cuda.is_available()

        transformers_version = "not installed"
        if is_transformers_available():
            import transformers

            transformers_version = transformers.__version__

        accelerate_version = "not installed"
        if is_accelerate_available():
            import accelerate

            accelerate_version = accelerate.__version__

        xformers_version = "not installed"
        if is_xformers_available():
            import xformers

            xformers_version = xformers.__version__

        info = {
            "`diffusers` version": version,
            "Platform": platform.platform(),
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
            "Huggingface_hub version": hub_version,
            "Transformers version": transformers_version,
            "Accelerate version": accelerate_version,
            "xFormers version": xformers_version,
            "Using GPU in script?": "<fill in>",
            "Using distributed or parallel set-up in script?": "<fill in>",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def format_dict(d):
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
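format_dict simply renders the collected info as a Markdown-style bullet list, one "- key: value" line per entry, which is the text users paste into GitHub issues. For example (values illustrative):

# Illustrative output of EnvironmentCommand.format_dict
print(EnvironmentCommand.format_dict({"`diffusers` version": "0.23.1", "Platform": "Linux"}))
# - `diffusers` version: 0.23.1
# - Platform: Linux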
diffusers/commands/fp16_safetensors.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Usage example:
17
+ diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
18
+ """
19
+
20
+ import glob
21
+ import json
22
+ from argparse import ArgumentParser, Namespace
23
+ from importlib import import_module
24
+
25
+ import huggingface_hub
26
+ import torch
27
+ from huggingface_hub import hf_hub_download
28
+ from packaging import version
29
+
30
+ from ..utils import logging
31
+ from . import BaseDiffusersCLICommand
32
+
33
+
34
+ def conversion_command_factory(args: Namespace):
35
+ return FP16SafetensorsCommand(
36
+ args.ckpt_id,
37
+ args.fp16,
38
+ args.use_safetensors,
39
+ args.use_auth_token,
40
+ )
41
+
42
+
43
+ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
44
+ @staticmethod
45
+ def register_subcommand(parser: ArgumentParser):
46
+ conversion_parser = parser.add_parser("fp16_safetensors")
47
+ conversion_parser.add_argument(
48
+ "--ckpt_id",
49
+ type=str,
50
+ help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
51
+ )
52
+ conversion_parser.add_argument(
53
+ "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
54
+ )
55
+ conversion_parser.add_argument(
56
+ "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
57
+ )
58
+ conversion_parser.add_argument(
59
+ "--use_auth_token",
60
+ action="store_true",
61
+ help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
62
+ )
63
+ conversion_parser.set_defaults(func=conversion_command_factory)
64
+
65
+ def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool, use_auth_token: bool):
66
+ self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
67
+ self.ckpt_id = ckpt_id
68
+ self.local_ckpt_dir = f"/tmp/{ckpt_id}"
69
+ self.fp16 = fp16
70
+
71
+ self.use_safetensors = use_safetensors
72
+
73
+ if not self.use_safetensors and not self.fp16:
74
+ raise NotImplementedError(
75
+ "When `use_safetensors` and `fp16` both are False, then this command is of no use."
76
+ )
77
+
78
+ self.use_auth_token = use_auth_token
79
+
80
+ def run(self):
81
+ if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
82
+ raise ImportError(
83
+ "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
84
+ " installation."
85
+ )
86
+ else:
87
+ from huggingface_hub import create_commit
88
+ from huggingface_hub._commit_api import CommitOperationAdd
89
+
90
+ model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json", token=self.use_auth_token)
91
+ with open(model_index, "r") as f:
92
+ pipeline_class_name = json.load(f)["_class_name"]
93
+ pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
94
+ self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")
95
+
96
+ # Load the appropriate pipeline. We could have used `DiffusionPipeline`
97
+ # here, but just to avoid any rough edge cases.
98
+ pipeline = pipeline_class.from_pretrained(
99
+ self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32, use_auth_token=self.use_auth_token
100
+ )
101
+ pipeline.save_pretrained(
102
+ self.local_ckpt_dir,
103
+ safe_serialization=True if self.use_safetensors else False,
104
+ variant="fp16" if self.fp16 else None,
105
+ )
106
+ self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")
107
+
108
+ # Fetch all the paths.
109
+ if self.fp16:
110
+ modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
111
+ elif self.use_safetensors:
112
+ modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")
113
+
114
+ # Prepare for the PR.
115
+ commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
116
+ operations = []
117
+ for path in modified_paths:
118
+ operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))
119
+
120
+ # Open the PR.
121
+ commit_description = (
122
+ "Variables converted by the [`diffusers`' `fp16_safetensors`"
123
+ " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
124
+ )
125
+ hub_pr_url = create_commit(
126
+ repo_id=self.ckpt_id,
127
+ operations=operations,
128
+ commit_message=commit_message,
129
+ commit_description=commit_description,
130
+ repo_type="model",
131
+ create_pr=True,
132
+ ).pr_url
133
+ self.logger.info(f"PR created here: {hub_pr_url}.")
diffusers/configuration_utils.py ADDED
@@ -0,0 +1,694 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ ConfigMixin base class and utilities."""
17
+ import dataclasses
18
+ import functools
19
+ import importlib
20
+ import inspect
21
+ import json
22
+ import os
23
+ import re
24
+ from collections import OrderedDict
25
+ from pathlib import PosixPath
26
+ from typing import Any, Dict, Tuple, Union
27
+
28
+ import numpy as np
29
+ from huggingface_hub import create_repo, hf_hub_download
30
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
31
+ from requests import HTTPError
32
+
33
+ from . import __version__
34
+ from .utils import (
35
+ DIFFUSERS_CACHE,
36
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
37
+ DummyObject,
38
+ deprecate,
39
+ extract_commit_hash,
40
+ http_user_agent,
41
+ logging,
42
+ )
43
+
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
48
+
49
+
50
+ class FrozenDict(OrderedDict):
51
+ def __init__(self, *args, **kwargs):
52
+ super().__init__(*args, **kwargs)
53
+
54
+ for key, value in self.items():
55
+ setattr(self, key, value)
56
+
57
+ self.__frozen = True
58
+
59
+ def __delitem__(self, *args, **kwargs):
60
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
61
+
62
+ def setdefault(self, *args, **kwargs):
63
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
64
+
65
+ def pop(self, *args, **kwargs):
66
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
67
+
68
+ def update(self, *args, **kwargs):
69
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
70
+
71
+ def __setattr__(self, name, value):
72
+ if hasattr(self, "__frozen") and self.__frozen:
73
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
74
+ super().__setattr__(name, value)
75
+
76
+ def __setitem__(self, name, value):
77
+ if hasattr(self, "__frozen") and self.__frozen:
78
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
79
+ super().__setitem__(name, value)
80
+
81
+
82
+ class ConfigMixin:
83
+ r"""
84
+ Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
85
+ provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
86
+ saving classes that inherit from [`ConfigMixin`].
87
+
88
+ Class attributes:
89
+ - **config_name** (`str`) -- A filename under which the config should stored when calling
90
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
91
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
92
+ overridden by subclass).
93
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
94
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
95
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
96
+ subclass).
97
+ """
98
+ config_name = None
99
+ ignore_for_config = []
100
+ has_compatibles = False
101
+
102
+ _deprecated_kwargs = []
103
+
104
+ def register_to_config(self, **kwargs):
105
+ if self.config_name is None:
106
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
107
+ # Special case for `kwargs` used in deprecation warning added to schedulers
108
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
109
+ # or solve in a more general way.
110
+ kwargs.pop("kwargs", None)
111
+
112
+ if not hasattr(self, "_internal_dict"):
113
+ internal_dict = kwargs
114
+ else:
115
+ previous_dict = dict(self._internal_dict)
116
+ internal_dict = {**self._internal_dict, **kwargs}
117
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
118
+
119
+ self._internal_dict = FrozenDict(internal_dict)
120
+
121
+ def __getattr__(self, name: str) -> Any:
122
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
123
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
124
+
125
+ This function is mostly copied from PyTorch's `__getattr__` override:
126
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
127
+ """
128
+
129
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
130
+ is_attribute = name in self.__dict__
131
+
132
+ if is_in_config and not is_attribute:
133
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
134
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
135
+ return self._internal_dict[name]
136
+
137
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
138
+
139
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
140
+ """
141
+ Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
142
+ [`~ConfigMixin.from_config`] class method.
143
+
144
+ Args:
145
+ save_directory (`str` or `os.PathLike`):
146
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
147
+ push_to_hub (`bool`, *optional*, defaults to `False`):
148
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
149
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
150
+ namespace).
151
+ kwargs (`Dict[str, Any]`, *optional*):
152
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
153
+ """
154
+ if os.path.isfile(save_directory):
155
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
156
+
157
+ os.makedirs(save_directory, exist_ok=True)
158
+
159
+ # If we save using the predefined names, we can load using `from_config`
160
+ output_config_file = os.path.join(save_directory, self.config_name)
161
+
162
+ self.to_json_file(output_config_file)
163
+ logger.info(f"Configuration saved in {output_config_file}")
164
+
165
+ if push_to_hub:
166
+ commit_message = kwargs.pop("commit_message", None)
167
+ private = kwargs.pop("private", False)
168
+ create_pr = kwargs.pop("create_pr", False)
169
+ token = kwargs.pop("token", None)
170
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
171
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
172
+
173
+ self._upload_folder(
174
+ save_directory,
175
+ repo_id,
176
+ token=token,
177
+ commit_message=commit_message,
178
+ create_pr=create_pr,
179
+ )
180
+
181
+ @classmethod
182
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
183
+ r"""
184
+ Instantiate a Python class from a config dictionary.
185
+
186
+ Parameters:
187
+ config (`Dict[str, Any]`):
188
+ A config dictionary from which the Python class is instantiated. Make sure to only load configuration
189
+ files of compatible classes.
190
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
191
+ Whether kwargs that are not consumed by the Python class should be returned or not.
192
+ kwargs (remaining dictionary of keyword arguments, *optional*):
193
+ Can be used to update the configuration object (after it is loaded) and initiate the Python class.
194
+ `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
195
+ overwrite the same named arguments in `config`.
196
+
197
+ Returns:
198
+ [`ModelMixin`] or [`SchedulerMixin`]:
199
+ A model or scheduler object instantiated from a config dictionary.
200
+
201
+ Examples:
202
+
203
+ ```python
204
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
205
+
206
+ >>> # Download scheduler from huggingface.co and cache.
207
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
208
+
209
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
210
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
211
+
212
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
213
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
214
+ ```
215
+ """
216
+ # <===== TO BE REMOVED WITH DEPRECATION
217
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
218
+ if "pretrained_model_name_or_path" in kwargs:
219
+ config = kwargs.pop("pretrained_model_name_or_path")
220
+
221
+ if config is None:
222
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
223
+ # ======>
224
+
225
+ if not isinstance(config, dict):
226
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
227
+ if "Scheduler" in cls.__name__:
228
+ deprecation_message += (
229
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
230
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
231
+ " be removed in v1.0.0."
232
+ )
233
+ elif "Model" in cls.__name__:
234
+ deprecation_message += (
235
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
236
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
237
+ " instead. This functionality will be removed in v1.0.0."
238
+ )
239
+ deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
240
+ config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
241
+
242
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
243
+
244
+ # Allow dtype to be specified on initialization
245
+ if "dtype" in unused_kwargs:
246
+ init_dict["dtype"] = unused_kwargs.pop("dtype")
247
+
248
+ # add possible deprecated kwargs
249
+ for deprecated_kwarg in cls._deprecated_kwargs:
250
+ if deprecated_kwarg in unused_kwargs:
251
+ init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
252
+
253
+ # Return model and optionally state and/or unused_kwargs
254
+ model = cls(**init_dict)
255
+
256
+ # make sure to also save config parameters that might be used for compatible classes
257
+ model.register_to_config(**hidden_dict)
258
+
259
+ # add hidden kwargs of compatible classes to unused_kwargs
260
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
261
+
262
+ if return_unused_kwargs:
263
+ return (model, unused_kwargs)
264
+ else:
265
+ return model
266
+
267
+ @classmethod
268
+ def get_config_dict(cls, *args, **kwargs):
269
+ deprecation_message = (
270
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
271
+ " removed in version v1.0.0"
272
+ )
273
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
274
+ return cls.load_config(*args, **kwargs)
275
+
276
+ @classmethod
277
+ def load_config(
278
+ cls,
279
+ pretrained_model_name_or_path: Union[str, os.PathLike],
280
+ return_unused_kwargs=False,
281
+ return_commit_hash=False,
282
+ **kwargs,
283
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
284
+ r"""
285
+ Load a model or scheduler configuration.
286
+
287
+ Parameters:
288
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
289
+ Can be either:
290
+
291
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
292
+ the Hub.
293
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
294
+ [`~ConfigMixin.save_config`].
295
+
296
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
297
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
298
+ is not used.
299
+ force_download (`bool`, *optional*, defaults to `False`):
300
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
301
+ cached versions if they exist.
302
+ resume_download (`bool`, *optional*, defaults to `False`):
303
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
304
+ incompletely downloaded files are deleted.
305
+ proxies (`Dict[str, str]`, *optional*):
306
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
307
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
308
+ output_loading_info(`bool`, *optional*, defaults to `False`):
309
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
310
+ local_files_only (`bool`, *optional*, defaults to `False`):
311
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
312
+ won't be downloaded from the Hub.
313
+ use_auth_token (`str` or *bool*, *optional*):
314
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
315
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
316
+ revision (`str`, *optional*, defaults to `"main"`):
317
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
318
+ allowed by Git.
319
+ subfolder (`str`, *optional*, defaults to `""`):
320
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
321
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
322
+ Whether unused keyword arguments of the config are returned.
323
+ return_commit_hash (`bool`, *optional*, defaults to `False`):
324
+ Whether the `commit_hash` of the loaded configuration is returned.
325
+
326
+ Returns:
327
+ `dict`:
328
+ A dictionary of all the parameters stored in a JSON configuration file.
329
+
330
+ """
331
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
332
+ force_download = kwargs.pop("force_download", False)
333
+ resume_download = kwargs.pop("resume_download", False)
334
+ proxies = kwargs.pop("proxies", None)
335
+ use_auth_token = kwargs.pop("use_auth_token", None)
336
+ local_files_only = kwargs.pop("local_files_only", False)
337
+ revision = kwargs.pop("revision", None)
338
+ _ = kwargs.pop("mirror", None)
339
+ subfolder = kwargs.pop("subfolder", None)
340
+ user_agent = kwargs.pop("user_agent", {})
341
+
342
+ user_agent = {**user_agent, "file_type": "config"}
343
+ user_agent = http_user_agent(user_agent)
344
+
345
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
346
+
347
+ if cls.config_name is None:
348
+ raise ValueError(
349
+ "`self.config_name` is not defined. Note that one should not load a config from "
350
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
351
+ )
352
+
353
+ if os.path.isfile(pretrained_model_name_or_path):
354
+ config_file = pretrained_model_name_or_path
355
+ elif os.path.isdir(pretrained_model_name_or_path):
356
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
357
+ # Load from a PyTorch checkpoint
358
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
359
+ elif subfolder is not None and os.path.isfile(
360
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
361
+ ):
362
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
363
+ else:
364
+ raise EnvironmentError(
365
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
366
+ )
367
+ else:
368
+ try:
369
+ # Load from URL or cache if already cached
370
+ config_file = hf_hub_download(
371
+ pretrained_model_name_or_path,
372
+ filename=cls.config_name,
373
+ cache_dir=cache_dir,
374
+ force_download=force_download,
375
+ proxies=proxies,
376
+ resume_download=resume_download,
377
+ local_files_only=local_files_only,
378
+ use_auth_token=use_auth_token,
379
+ user_agent=user_agent,
380
+ subfolder=subfolder,
381
+ revision=revision,
382
+ )
383
+ except RepositoryNotFoundError:
384
+ raise EnvironmentError(
385
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
386
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
387
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
388
+ " login`."
389
+ )
390
+ except RevisionNotFoundError:
391
+ raise EnvironmentError(
392
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
393
+ " this model name. Check the model page at"
394
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
395
+ )
396
+ except EntryNotFoundError:
397
+ raise EnvironmentError(
398
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
399
+ )
400
+ except HTTPError as err:
401
+ raise EnvironmentError(
402
+ "There was a specific connection error when trying to load"
403
+ f" {pretrained_model_name_or_path}:\n{err}"
404
+ )
405
+ except ValueError:
406
+ raise EnvironmentError(
407
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
408
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
409
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
410
+ " run the library in offline mode at"
411
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
412
+ )
413
+ except EnvironmentError:
414
+ raise EnvironmentError(
415
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
416
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
417
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
418
+ f"containing a {cls.config_name} file"
419
+ )
420
+
421
+ try:
422
+ # Load config dict
423
+ config_dict = cls._dict_from_json_file(config_file)
424
+
425
+ commit_hash = extract_commit_hash(config_file)
426
+ except (json.JSONDecodeError, UnicodeDecodeError):
427
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
428
+
429
+ if not (return_unused_kwargs or return_commit_hash):
430
+ return config_dict
431
+
432
+ outputs = (config_dict,)
433
+
434
+ if return_unused_kwargs:
435
+ outputs += (kwargs,)
436
+
437
+ if return_commit_hash:
438
+ outputs += (commit_hash,)
439
+
440
+ return outputs
441
+
442
+ @staticmethod
443
+ def _get_init_keys(cls):
444
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
445
+
446
+ @classmethod
447
+ def extract_init_dict(cls, config_dict, **kwargs):
448
+ # Skip keys that were not present in the original config, so default __init__ values were used
449
+ used_defaults = config_dict.get("_use_default_values", [])
450
+ config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
451
+
452
+ # 0. Copy origin config dict
453
+ original_dict = dict(config_dict.items())
454
+
455
+ # 1. Retrieve expected config attributes from __init__ signature
456
+ expected_keys = cls._get_init_keys(cls)
457
+ expected_keys.remove("self")
458
+ # remove general kwargs if present in dict
459
+ if "kwargs" in expected_keys:
460
+ expected_keys.remove("kwargs")
461
+ # remove flax internal keys
462
+ if hasattr(cls, "_flax_internal_args"):
463
+ for arg in cls._flax_internal_args:
464
+ expected_keys.remove(arg)
465
+
466
+ # 2. Remove attributes that cannot be expected from expected config attributes
467
+ # remove keys to be ignored
468
+ if len(cls.ignore_for_config) > 0:
469
+ expected_keys = expected_keys - set(cls.ignore_for_config)
470
+
471
+ # load diffusers library to import compatible and original scheduler
472
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
473
+
474
+ if cls.has_compatibles:
475
+ compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
476
+ else:
477
+ compatible_classes = []
478
+
479
+ expected_keys_comp_cls = set()
480
+ for c in compatible_classes:
481
+ expected_keys_c = cls._get_init_keys(c)
482
+ expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
483
+ expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
484
+ config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
485
+
486
+ # remove attributes from orig class that cannot be expected
487
+ orig_cls_name = config_dict.pop("_class_name", cls.__name__)
488
+ if (
489
+ isinstance(orig_cls_name, str)
490
+ and orig_cls_name != cls.__name__
491
+ and hasattr(diffusers_library, orig_cls_name)
492
+ ):
493
+ orig_cls = getattr(diffusers_library, orig_cls_name)
494
+ unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
495
+ config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
496
+ elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
497
+ raise ValueError(
498
+ "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
499
+ )
500
+
501
+ # remove private attributes
502
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
503
+
504
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
505
+ init_dict = {}
506
+ for key in expected_keys:
507
+ # if config param is passed to kwarg and is present in config dict
508
+ # it should overwrite existing config dict key
509
+ if key in kwargs and key in config_dict:
510
+ config_dict[key] = kwargs.pop(key)
511
+
512
+ if key in kwargs:
513
+ # overwrite key
514
+ init_dict[key] = kwargs.pop(key)
515
+ elif key in config_dict:
516
+ # use value from config dict
517
+ init_dict[key] = config_dict.pop(key)
518
+
519
+ # 4. Give nice warning if unexpected values have been passed
520
+ if len(config_dict) > 0:
521
+ logger.warning(
522
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
523
+ "but are not expected and will be ignored. Please verify your "
524
+ f"{cls.config_name} configuration file."
525
+ )
526
+
527
+ # 5. Give nice info if config attributes are initialized to default values because they have not been passed
528
+ passed_keys = set(init_dict.keys())
529
+ if len(expected_keys - passed_keys) > 0:
530
+ logger.info(
531
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
532
+ )
533
+
534
+ # 6. Define unused keyword arguments
535
+ unused_kwargs = {**config_dict, **kwargs}
536
+
537
+ # 7. Define "hidden" config parameters that were saved for compatible classes
538
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
539
+
540
+ return init_dict, unused_kwargs, hidden_config_dict
541
+
542
+ @classmethod
543
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
544
+ with open(json_file, "r", encoding="utf-8") as reader:
545
+ text = reader.read()
546
+ return json.loads(text)
547
+
548
+ def __repr__(self):
549
+ return f"{self.__class__.__name__} {self.to_json_string()}"
550
+
551
+ @property
552
+ def config(self) -> Dict[str, Any]:
553
+ """
554
+ Returns the config of the class as a frozen dictionary
555
+
556
+ Returns:
557
+ `Dict[str, Any]`: Config of the class.
558
+ """
559
+ return self._internal_dict
560
+
561
+ def to_json_string(self) -> str:
562
+ """
563
+ Serializes the configuration instance to a JSON string.
564
+
565
+ Returns:
566
+ `str`:
567
+ String containing all the attributes that make up the configuration instance in JSON format.
568
+ """
569
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
570
+ config_dict["_class_name"] = self.__class__.__name__
571
+ config_dict["_diffusers_version"] = __version__
572
+
573
+ def to_json_saveable(value):
574
+ if isinstance(value, np.ndarray):
575
+ value = value.tolist()
576
+ elif isinstance(value, PosixPath):
577
+ value = str(value)
578
+ return value
579
+
580
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
581
+ # Don't save "_ignore_files" or "_use_default_values"
582
+ config_dict.pop("_ignore_files", None)
583
+ config_dict.pop("_use_default_values", None)
584
+
585
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
586
+
587
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
588
+ """
589
+ Save the configuration instance's parameters to a JSON file.
590
+
591
+ Args:
592
+ json_file_path (`str` or `os.PathLike`):
593
+ Path to the JSON file to save a configuration instance's parameters.
594
+ """
595
+ with open(json_file_path, "w", encoding="utf-8") as writer:
596
+ writer.write(self.to_json_string())
597
+
598
+
599
+ def register_to_config(init):
600
+ r"""
601
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
602
+ automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
603
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
604
+
605
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
606
+ """
607
+
608
+ @functools.wraps(init)
609
+ def inner_init(self, *args, **kwargs):
610
+ # Ignore private kwargs in the init.
611
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
612
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
613
+ if not isinstance(self, ConfigMixin):
614
+ raise RuntimeError(
615
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
616
+ "not inherit from `ConfigMixin`."
617
+ )
618
+
619
+ ignore = getattr(self, "ignore_for_config", [])
620
+ # Get positional arguments aligned with kwargs
621
+ new_kwargs = {}
622
+ signature = inspect.signature(init)
623
+ parameters = {
624
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
625
+ }
626
+ for arg, name in zip(args, parameters.keys()):
627
+ new_kwargs[name] = arg
628
+
629
+ # Then add all kwargs
630
+ new_kwargs.update(
631
+ {
632
+ k: init_kwargs.get(k, default)
633
+ for k, default in parameters.items()
634
+ if k not in ignore and k not in new_kwargs
635
+ }
636
+ )
637
+
638
+ # Take note of the parameters that were not present in the loaded config
639
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
640
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
641
+
642
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
643
+ getattr(self, "register_to_config")(**new_kwargs)
644
+ init(self, *args, **init_kwargs)
645
+
646
+ return inner_init
647
+
648
+
649
+ def flax_register_to_config(cls):
650
+ original_init = cls.__init__
651
+
652
+ @functools.wraps(original_init)
653
+ def init(self, *args, **kwargs):
654
+ if not isinstance(self, ConfigMixin):
655
+ raise RuntimeError(
656
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
657
+ "not inherit from `ConfigMixin`."
658
+ )
659
+
660
+ # Ignore private kwargs in the init. Retrieve all passed attributes
661
+ init_kwargs = dict(kwargs.items())
662
+
663
+ # Retrieve default values
664
+ fields = dataclasses.fields(self)
665
+ default_kwargs = {}
666
+ for field in fields:
667
+ # ignore flax specific attributes
668
+ if field.name in self._flax_internal_args:
669
+ continue
670
+ if type(field.default) == dataclasses._MISSING_TYPE:
671
+ default_kwargs[field.name] = None
672
+ else:
673
+ default_kwargs[field.name] = getattr(self, field.name)
674
+
675
+ # Make sure init_kwargs override default kwargs
676
+ new_kwargs = {**default_kwargs, **init_kwargs}
677
+ # dtype should be part of `init_kwargs`, but not `new_kwargs`
678
+ if "dtype" in new_kwargs:
679
+ new_kwargs.pop("dtype")
680
+
681
+ # Get positional arguments aligned with kwargs
682
+ for i, arg in enumerate(args):
683
+ name = fields[i].name
684
+ new_kwargs[name] = arg
685
+
686
+ # Take note of the parameters that were not present in the loaded config
687
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
688
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
689
+
690
+ getattr(self, "register_to_config")(**new_kwargs)
691
+ original_init(self, *args, **kwargs)
692
+
693
+ cls.__init__ = init
694
+ return cls
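A minimal sketch of how `ConfigMixin` and the `register_to_config` decorator defined above are typically combined; the class name, `config_name`, and init arguments are made up for illustration.

```python
# Minimal sketch (assumption): a toy ConfigMixin subclass; names and arguments are illustrative only.
from diffusers.configuration_utils import ConfigMixin, register_to_config


class ToyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"  # file written by save_config / read by load_config

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 1e-4):
        pass  # every non-private init argument is captured in self.config


scheduler = ToyScheduler(num_train_timesteps=500)
print(scheduler.config.num_train_timesteps)  # 500

scheduler.save_config("./toy_scheduler")  # writes ./toy_scheduler/scheduler_config.json
restored = ToyScheduler.from_config(ToyScheduler.load_config("./toy_scheduler"))
```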
diffusers/dependency_versions_check.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import sys
15
+
16
+ from .dependency_versions_table import deps
17
+ from .utils.versions import require_version, require_version_core
18
+
19
+
20
+ # define which module versions we always want to check at run time
21
+ # (usually the ones defined in `install_requires` in setup.py)
22
+ #
23
+ # order specific notes:
24
+ # - tqdm must be checked before tokenizers
25
+
26
+ pkgs_to_check_at_runtime = "python requests filelock numpy".split()
27
+ for pkg in pkgs_to_check_at_runtime:
28
+ if pkg in deps:
29
+ require_version_core(deps[pkg])
30
+ else:
31
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
32
+
33
+
34
+ def dep_version_check(pkg, hint=None):
35
+ require_version(deps[pkg], hint)
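A minimal sketch of how `dep_version_check` is typically called from other modules to validate an optional dependency; the hint text is illustrative.

```python
# Minimal sketch (assumption): validating an installed dependency against the pin
# recorded in `deps` from dependency_versions_table.py.
from diffusers.dependency_versions_check import dep_version_check

dep_version_check("torch", hint="diffusers requires torch>=1.4")  # raises if the pin is not satisfied
```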
diffusers/dependency_versions_table.py ADDED
@@ -0,0 +1,46 @@
1
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
+ # 1. modify the `_deps` dict in setup.py
3
+ # 2. run `make deps_table_update`
4
+ deps = {
5
+ "Pillow": "Pillow",
6
+ "accelerate": "accelerate>=0.11.0",
7
+ "compel": "compel==0.1.8",
8
+ "black": "black~=23.1",
9
+ "datasets": "datasets",
10
+ "filelock": "filelock",
11
+ "flax": "flax>=0.4.1",
12
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
13
+ "huggingface-hub": "huggingface-hub>=0.13.2",
14
+ "requests-mock": "requests-mock==1.10.0",
15
+ "importlib_metadata": "importlib_metadata",
16
+ "invisible-watermark": "invisible-watermark>=0.2.0",
17
+ "isort": "isort>=5.5.4",
18
+ "jax": "jax>=0.4.1",
19
+ "jaxlib": "jaxlib>=0.4.1",
20
+ "Jinja2": "Jinja2",
21
+ "k-diffusion": "k-diffusion>=0.0.12",
22
+ "torchsde": "torchsde",
23
+ "note_seq": "note_seq",
24
+ "librosa": "librosa",
25
+ "numpy": "numpy",
26
+ "omegaconf": "omegaconf",
27
+ "parameterized": "parameterized",
28
+ "peft": "peft<=0.6.2",
29
+ "protobuf": "protobuf>=3.20.3,<4",
30
+ "pytest": "pytest",
31
+ "pytest-timeout": "pytest-timeout",
32
+ "pytest-xdist": "pytest-xdist",
33
+ "python": "python>=3.8.0",
34
+ "ruff": "ruff==0.0.280",
35
+ "safetensors": "safetensors>=0.3.1",
36
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
37
+ "scipy": "scipy",
38
+ "onnx": "onnx",
39
+ "regex": "regex!=2019.12.17",
40
+ "requests": "requests",
41
+ "tensorboard": "tensorboard",
42
+ "torch": "torch>=1.4",
43
+ "torchvision": "torchvision",
44
+ "transformers": "transformers>=4.25.1",
45
+ "urllib3": "urllib3<=2.0.0",
46
+ }
diffusers/experimental/README.md ADDED
@@ -0,0 +1,5 @@
1
+ # 🧨 Diffusers Experimental
2
+
3
+ We are adding experimental code to support novel applications and usages of the Diffusers library.
4
+ Currently, the following experiments are supported:
5
+ * Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
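A minimal, untested sketch of the reinforcement-learning experiment listed above, assuming a D4RL-style `gym` environment and a matching pretrained value-function checkpoint; the environment name and checkpoint id below are illustrative.

```python
# Minimal sketch (assumption, not a tested recipe): value-guided planning with the
# experimental pipeline. Requires gym + d4rl and a compatible pretrained checkpoint.
import gym

from diffusers.experimental import ValueGuidedRLPipeline

env = gym.make("hopper-medium-v2")  # illustrative D4RL environment id
pipeline = ValueGuidedRLPipeline.from_pretrained(
    "bglick13/hopper-medium-v2-value-function-hor32",  # illustrative checkpoint id
    env=env,
)

obs = env.reset()
for _ in range(10):
    action = pipeline(obs, planning_horizon=32)  # first action of the highest-value trajectory
    obs, reward, done, info = env.step(action)
    if done:
        break
```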
diffusers/experimental/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .rl import ValueGuidedRLPipeline
diffusers/experimental/rl/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .value_guided_sampling import ValueGuidedRLPipeline
diffusers/experimental/rl/value_guided_sampling.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ import tqdm
18
+
19
+ from ...models.unet_1d import UNet1DModel
20
+ from ...pipelines import DiffusionPipeline
21
+ from ...utils.dummy_pt_objects import DDPMScheduler
22
+ from ...utils.torch_utils import randn_tensor
23
+
24
+
25
+ class ValueGuidedRLPipeline(DiffusionPipeline):
26
+ r"""
27
+ Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states.
28
+
29
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
30
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
31
+
32
+ Parameters:
33
+ value_function ([`UNet1DModel`]):
34
+ A specialized UNet for fine-tuning trajectories base on reward.
35
+ unet ([`UNet1DModel`]):
36
+ UNet architecture to denoise the encoded trajectories.
37
+ scheduler ([`SchedulerMixin`]):
38
+ A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
39
+ application is [`DDPMScheduler`].
40
+ env ():
41
+ An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ value_function: UNet1DModel,
47
+ unet: UNet1DModel,
48
+ scheduler: DDPMScheduler,
49
+ env,
50
+ ):
51
+ super().__init__()
52
+ self.value_function = value_function
53
+ self.unet = unet
54
+ self.scheduler = scheduler
55
+ self.env = env
56
+ self.data = env.get_dataset()
57
+ self.means = {}
58
+ for key in self.data.keys():
59
+ try:
60
+ self.means[key] = self.data[key].mean()
61
+ except: # noqa: E722
62
+ pass
63
+ self.stds = {}
64
+ for key in self.data.keys():
65
+ try:
66
+ self.stds[key] = self.data[key].std()
67
+ except: # noqa: E722
68
+ pass
69
+ self.state_dim = env.observation_space.shape[0]
70
+ self.action_dim = env.action_space.shape[0]
71
+
72
+ def normalize(self, x_in, key):
73
+ return (x_in - self.means[key]) / self.stds[key]
74
+
75
+ def de_normalize(self, x_in, key):
76
+ return x_in * self.stds[key] + self.means[key]
77
+
78
+ def to_torch(self, x_in):
79
+ if isinstance(x_in, dict):
80
+ return {k: self.to_torch(v) for k, v in x_in.items()}
81
+ elif torch.is_tensor(x_in):
82
+ return x_in.to(self.unet.device)
83
+ return torch.tensor(x_in, device=self.unet.device)
84
+
85
+ def reset_x0(self, x_in, cond, act_dim):
86
+ for key, val in cond.items():
87
+ x_in[:, key, act_dim:] = val.clone()
88
+ return x_in
89
+
90
+ def run_diffusion(self, x, conditions, n_guide_steps, scale):
91
+ batch_size = x.shape[0]
92
+ y = None
93
+ for i in tqdm.tqdm(self.scheduler.timesteps):
94
+ # create batch of timesteps to pass into model
95
+ timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
96
+ for _ in range(n_guide_steps):
97
+ with torch.enable_grad():
98
+ x.requires_grad_()
99
+
100
+ # permute to match dimension for pre-trained models
101
+ y = self.value_function(x.permute(0, 2, 1), timesteps).sample
102
+ grad = torch.autograd.grad([y.sum()], [x])[0]
103
+
104
+ posterior_variance = self.scheduler._get_variance(i)
105
+ model_std = torch.exp(0.5 * posterior_variance)
106
+ grad = model_std * grad
107
+
108
+ grad[timesteps < 2] = 0
109
+ x = x.detach()
110
+ x = x + scale * grad
111
+ x = self.reset_x0(x, conditions, self.action_dim)
112
+
113
+ prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
114
+
115
+ # TODO: verify deprecation of this kwarg
116
+ x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
117
+
118
+ # apply conditions to the trajectory (set the initial state)
119
+ x = self.reset_x0(x, conditions, self.action_dim)
120
+ x = self.to_torch(x)
121
+ return x, y
122
+
123
+ def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
124
+ # normalize the observations and create batch dimension
125
+ obs = self.normalize(obs, "observations")
126
+ obs = obs[None].repeat(batch_size, axis=0)
127
+
128
+ conditions = {0: self.to_torch(obs)}
129
+ shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
130
+
131
+ # generate initial noise and apply our conditions (to make the trajectories start at current state)
132
+ x1 = randn_tensor(shape, device=self.unet.device)
133
+ x = self.reset_x0(x1, conditions, self.action_dim)
134
+ x = self.to_torch(x)
135
+
136
+ # run the diffusion process
137
+ x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
138
+
139
+ # sort output trajectories by value
140
+ sorted_idx = y.argsort(0, descending=True).squeeze()
141
+ sorted_values = x[sorted_idx]
142
+ actions = sorted_values[:, :, : self.action_dim]
143
+ actions = actions.detach().cpu().numpy()
144
+ denorm_actions = self.de_normalize(actions, key="actions")
145
+
146
+ # select the action with the highest value
147
+ if y is not None:
148
+ selected_index = 0
149
+ else:
150
+ # if we didn't run value guiding, select a random action
151
+ selected_index = np.random.randint(0, batch_size)
152
+
153
+ denorm_actions = denorm_actions[selected_index, 0]
154
+ return denorm_actions
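A minimal sketch that isolates the value-guidance update from `run_diffusion` above, assuming a DDPM-style scheduler that exposes `_get_variance`; the helper name `guide_step` is made up for illustration.

```python
# Minimal sketch (assumption): one value-guidance step. The value function scores the
# noisy trajectory and its gradient, scaled by the scheduler's posterior std, nudges
# the sample toward higher predicted return before the regular denoising step.
import torch


def guide_step(x, timesteps, t, value_function, scheduler, scale=0.1):
    # `timesteps` is a batch-sized tensor filled with the scalar timestep `t`
    with torch.enable_grad():
        x.requires_grad_()
        y = value_function(x.permute(0, 2, 1), timesteps).sample
        grad = torch.autograd.grad([y.sum()], [x])[0]
    model_std = torch.exp(0.5 * scheduler._get_variance(t))
    grad = model_std * grad
    grad[timesteps < 2] = 0  # skip guidance on the final, nearly noise-free steps
    return x.detach() + scale * grad, y
```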
diffusers/image_processor.py ADDED
@@ -0,0 +1,476 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from typing import List, Optional, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ from PIL import Image
22
+
23
+ from .configuration_utils import ConfigMixin, register_to_config
24
+ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
25
+
26
+
27
+ PipelineImageInput = Union[
28
+ PIL.Image.Image,
29
+ np.ndarray,
30
+ torch.FloatTensor,
31
+ List[PIL.Image.Image],
32
+ List[np.ndarray],
33
+ List[torch.FloatTensor],
34
+ ]
35
+
36
+
37
+ class VaeImageProcessor(ConfigMixin):
38
+ """
39
+ Image processor for VAE.
40
+
41
+ Args:
42
+ do_resize (`bool`, *optional*, defaults to `True`):
43
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
44
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
45
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
46
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
47
+ resample (`str`, *optional*, defaults to `lanczos`):
48
+ Resampling filter to use when resizing the image.
49
+ do_normalize (`bool`, *optional*, defaults to `True`):
50
+ Whether to normalize the image to [-1,1].
51
+ do_binarize (`bool`, *optional*, defaults to `False`):
52
+ Whether to binarize the image to 0/1.
53
+ do_convert_rgb (`bool`, *optional*, defaults to `False`):
54
+ Whether to convert the images to RGB format.
55
+ do_convert_grayscale (`bool`, *optional*, defaults to `False`):
56
+ Whether to convert the images to grayscale format.
57
+ """
58
+
59
+ config_name = CONFIG_NAME
60
+
61
+ @register_to_config
62
+ def __init__(
63
+ self,
64
+ do_resize: bool = True,
65
+ vae_scale_factor: int = 8,
66
+ resample: str = "lanczos",
67
+ do_normalize: bool = True,
68
+ do_binarize: bool = False,
69
+ do_convert_rgb: bool = False,
70
+ do_convert_grayscale: bool = False,
71
+ ):
72
+ super().__init__()
73
+ if do_convert_rgb and do_convert_grayscale:
74
+ raise ValueError(
75
+ "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
76
+ " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
77
+ " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
78
+ )
79
+ self.config.do_convert_rgb = False
80
+
81
+ @staticmethod
82
+ def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
83
+ """
84
+ Convert a numpy image or a batch of images to a PIL image.
85
+ """
86
+ if images.ndim == 3:
87
+ images = images[None, ...]
88
+ images = (images * 255).round().astype("uint8")
89
+ if images.shape[-1] == 1:
90
+ # special case for grayscale (single channel) images
91
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
92
+ else:
93
+ pil_images = [Image.fromarray(image) for image in images]
94
+
95
+ return pil_images
96
+
97
+ @staticmethod
98
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
99
+ """
100
+ Convert a PIL image or a list of PIL images to NumPy arrays.
101
+ """
102
+ if not isinstance(images, list):
103
+ images = [images]
104
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
105
+ images = np.stack(images, axis=0)
106
+
107
+ return images
108
+
109
+ @staticmethod
110
+ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
111
+ """
112
+ Convert a NumPy image to a PyTorch tensor.
113
+ """
114
+ if images.ndim == 3:
115
+ images = images[..., None]
116
+
117
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
118
+ return images
119
+
120
+ @staticmethod
121
+ def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
122
+ """
123
+ Convert a PyTorch tensor to a NumPy image.
124
+ """
125
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
126
+ return images
127
+
128
+ @staticmethod
129
+ def normalize(images):
130
+ """
131
+ Normalize an image array to [-1,1].
132
+ """
133
+ return 2.0 * images - 1.0
134
+
135
+ @staticmethod
136
+ def denormalize(images):
137
+ """
138
+ Denormalize an image array to [0,1].
139
+ """
140
+ return (images / 2 + 0.5).clamp(0, 1)
141
+
142
+ @staticmethod
143
+ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
144
+ """
145
+ Converts a PIL image to RGB format.
146
+ """
147
+ image = image.convert("RGB")
148
+
149
+ return image
150
+
151
+ @staticmethod
152
+ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
153
+ """
154
+ Converts a PIL image to grayscale format.
155
+ """
156
+ image = image.convert("L")
157
+
158
+ return image
159
+
160
+ def get_default_height_width(
161
+ self,
162
+ image: [PIL.Image.Image, np.ndarray, torch.Tensor],
163
+ height: Optional[int] = None,
164
+ width: Optional[int] = None,
165
+ ):
166
+ """
167
+ This function return the height and width that are downscaled to the next integer multiple of
168
+ `vae_scale_factor`.
169
+
170
+ Args:
171
+ image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
172
+ The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
173
+ shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
174
+ have shape `[batch, channel, height, width]`.
175
+ height (`int`, *optional*, defaults to `None`):
176
+ The height in preprocessed image. If `None`, will use the height of `image` input.
177
+ width (`int`, *optional*`, defaults to `None`):
178
+ The width in preprocessed. If `None`, will use the width of the `image` input.
179
+ """
180
+
181
+ if height is None:
182
+ if isinstance(image, PIL.Image.Image):
183
+ height = image.height
184
+ elif isinstance(image, torch.Tensor):
185
+ height = image.shape[2]
186
+ else:
187
+ height = image.shape[1]
188
+
189
+ if width is None:
190
+ if isinstance(image, PIL.Image.Image):
191
+ width = image.width
192
+ elif isinstance(image, torch.Tensor):
193
+ width = image.shape[3]
194
+ else:
195
+ width = image.shape[2]
196
+
197
+ width, height = (
198
+ x - x % self.config.vae_scale_factor for x in (width, height)
199
+ ) # resize to integer multiple of vae_scale_factor
200
+
201
+ return height, width
202
+
203
+ def resize(
204
+ self,
205
+ image: [PIL.Image.Image, np.ndarray, torch.Tensor],
206
+ height: Optional[int] = None,
207
+ width: Optional[int] = None,
208
+ ) -> [PIL.Image.Image, np.ndarray, torch.Tensor]:
209
+ """
210
+ Resize image.
211
+ """
212
+ if isinstance(image, PIL.Image.Image):
213
+ image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
214
+ elif isinstance(image, torch.Tensor):
215
+ image = torch.nn.functional.interpolate(
216
+ image,
217
+ size=(height, width),
218
+ )
219
+ elif isinstance(image, np.ndarray):
220
+ image = self.numpy_to_pt(image)
221
+ image = torch.nn.functional.interpolate(
222
+ image,
223
+ size=(height, width),
224
+ )
225
+ image = self.pt_to_numpy(image)
226
+ return image
227
+
228
+ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
229
+ """
230
+ Create a binary mask (e.g. a face/hair mask): pixels below 0.5 are set to 0, pixels at or above 0.5 to 1.
231
+ """
232
+ image[image < 0.5] = 0
233
+ image[image >= 0.5] = 1
234
+ return image
235
+
236
+ def preprocess(
237
+ self,
238
+ image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
239
+ height: Optional[int] = None,
240
+ width: Optional[int] = None,
241
+ ) -> torch.Tensor:
242
+ """
243
+ Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
244
+ """
245
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
246
+
247
+ # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
248
+ if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
249
+ if isinstance(image, torch.Tensor):
250
+ # if image is a pytorch tensor could have 2 possible shapes:
251
+ # 1. batch x height x width: we should insert the channel dimension at position 1
252
+ # 2. channnel x height x width: we should insert batch dimension at position 0,
253
+ # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
254
+ # for simplicity, we insert a dimension of size 1 at position 1 for both cases
255
+ image = image.unsqueeze(1)
256
+ else:
257
+ # if it is a numpy array, it could have 2 possible shapes:
258
+ # 1. batch x height x width: insert channel dimension on last position
259
+ # 2. height x width x channel: insert batch dimension on first position
260
+ if image.shape[-1] == 1:
261
+ image = np.expand_dims(image, axis=0)
262
+ else:
263
+ image = np.expand_dims(image, axis=-1)
264
+
265
+ if isinstance(image, supported_formats):
266
+ image = [image]
267
+ elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
268
+ raise ValueError(
269
+ f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
270
+ )
271
+
272
+ if isinstance(image[0], PIL.Image.Image):
273
+ if self.config.do_convert_rgb:
274
+ image = [self.convert_to_rgb(i) for i in image]
275
+ elif self.config.do_convert_grayscale:
276
+ image = [self.convert_to_grayscale(i) for i in image]
277
+ if self.config.do_resize:
278
+ height, width = self.get_default_height_width(image[0], height, width)
279
+ image = [self.resize(i, height, width) for i in image]
280
+ image = self.pil_to_numpy(image) # to np
281
+ image = self.numpy_to_pt(image) # to pt
282
+
283
+ elif isinstance(image[0], np.ndarray):
284
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
285
+
286
+ image = self.numpy_to_pt(image)
287
+
288
+ height, width = self.get_default_height_width(image, height, width)
289
+ if self.config.do_resize:
290
+ image = self.resize(image, height, width)
291
+
292
+ elif isinstance(image[0], torch.Tensor):
293
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
294
+
295
+ if self.config.do_convert_grayscale and image.ndim == 3:
296
+ image = image.unsqueeze(1)
297
+
298
+ channel = image.shape[1]
299
+ # don't need any preprocess if the image is latents
300
+ if channel == 4:
301
+ return image
302
+
303
+ height, width = self.get_default_height_width(image, height, width)
304
+ if self.config.do_resize:
305
+ image = self.resize(image, height, width)
306
+
307
+ # expected range [0,1], normalize to [-1,1]
308
+ do_normalize = self.config.do_normalize
309
+ if image.min() < 0 and do_normalize:
310
+ warnings.warn(
311
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
312
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
313
+ FutureWarning,
314
+ )
315
+ do_normalize = False
316
+
317
+ if do_normalize:
318
+ image = self.normalize(image)
319
+
320
+ if self.config.do_binarize:
321
+ image = self.binarize(image)
322
+
323
+ return image
324
+
325
+ def postprocess(
326
+ self,
327
+ image: torch.FloatTensor,
328
+ output_type: str = "pil",
329
+ do_denormalize: Optional[List[bool]] = None,
330
+ ):
331
+ if not isinstance(image, torch.Tensor):
332
+ raise ValueError(
333
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
334
+ )
335
+ if output_type not in ["latent", "pt", "np", "pil"]:
336
+ deprecation_message = (
337
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
338
+ "`pil`, `np`, `pt`, `latent`"
339
+ )
340
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
341
+ output_type = "np"
342
+
343
+ if output_type == "latent":
344
+ return image
345
+
346
+ if do_denormalize is None:
347
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
348
+
349
+ image = torch.stack(
350
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
351
+ )
352
+
353
+ if output_type == "pt":
354
+ return image
355
+
356
+ image = self.pt_to_numpy(image)
357
+
358
+ if output_type == "np":
359
+ return image
360
+
361
+ if output_type == "pil":
362
+ return self.numpy_to_pil(image)
363
+
364
+
365
+ class VaeImageProcessorLDM3D(VaeImageProcessor):
366
+ """
367
+ Image processor for VAE LDM3D.
368
+
369
+ Args:
370
+ do_resize (`bool`, *optional*, defaults to `True`):
371
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
372
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
373
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
374
+ resample (`str`, *optional*, defaults to `lanczos`):
375
+ Resampling filter to use when resizing the image.
376
+ do_normalize (`bool`, *optional*, defaults to `True`):
377
+ Whether to normalize the image to [-1,1].
378
+ """
379
+
380
+ config_name = CONFIG_NAME
381
+
382
+ @register_to_config
383
+ def __init__(
384
+ self,
385
+ do_resize: bool = True,
386
+ vae_scale_factor: int = 8,
387
+ resample: str = "lanczos",
388
+ do_normalize: bool = True,
389
+ ):
390
+ super().__init__()
391
+
392
+ @staticmethod
393
+ def numpy_to_pil(images):
394
+ """
395
+ Convert a NumPy image or a batch of images to a PIL image.
396
+ """
397
+ if images.ndim == 3:
398
+ images = images[None, ...]
399
+ images = (images * 255).round().astype("uint8")
400
+ if images.shape[-1] == 1:
401
+ # special case for grayscale (single channel) images
402
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
403
+ else:
404
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
405
+
406
+ return pil_images
407
+
408
+ @staticmethod
409
+ def rgblike_to_depthmap(image):
410
+ """
411
+ Args:
412
+ image: RGB-like depth image
413
+
414
+ Returns: depth map
415
+
416
+ """
417
+ return image[:, :, 1] * 2**8 + image[:, :, 2]
418
+
419
+ def numpy_to_depth(self, images):
420
+ """
421
+ Convert a NumPy depth image or a batch of images to a PIL image.
422
+ """
423
+ if images.ndim == 3:
424
+ images = images[None, ...]
425
+ images_depth = images[:, :, :, 3:]
426
+ if images.shape[-1] == 6:
427
+ images_depth = (images_depth * 255).round().astype("uint8")
428
+ pil_images = [
429
+ Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
430
+ ]
431
+ elif images.shape[-1] == 4:
432
+ images_depth = (images_depth * 65535.0).astype(np.uint16)
433
+ pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
434
+ else:
435
+ raise Exception("Not supported")
436
+
437
+ return pil_images
438
+
439
+ def postprocess(
440
+ self,
441
+ image: torch.FloatTensor,
442
+ output_type: str = "pil",
443
+ do_denormalize: Optional[List[bool]] = None,
444
+ ):
445
+ if not isinstance(image, torch.Tensor):
446
+ raise ValueError(
447
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
448
+ )
449
+ if output_type not in ["latent", "pt", "np", "pil"]:
450
+ deprecation_message = (
451
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
452
+ "`pil`, `np`, `pt`, `latent`"
453
+ )
454
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
455
+ output_type = "np"
456
+
457
+ if do_denormalize is None:
458
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
459
+
460
+ image = torch.stack(
461
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
462
+ )
463
+
464
+ image = self.pt_to_numpy(image)
465
+
466
+ if output_type == "np":
467
+ if image.shape[-1] == 6:
468
+ image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
469
+ else:
470
+ image_depth = image[:, :, :, 3:]
471
+ return image[:, :, :, :3], image_depth
472
+
473
+ if output_type == "pil":
474
+ return self.numpy_to_pil(image), self.numpy_to_depth(image)
475
+ else:
476
+ raise Exception(f"This type {output_type} is not supported")
diffusers/loaders.py ADDED
The diff for this file is too large to render. See raw diff
 
diffusers/models/README.md ADDED
@@ -0,0 +1,3 @@
1
+ # Models
2
+
3
+ For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models/overview).
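As a quick illustration (not part of this commit), the model classes in this package follow the standard `from_pretrained` loading API; the checkpoint id below is only an example of a compatible repository with `vae` and `unet` subfolders.

```python
from diffusers.models import AutoencoderKL, UNet2DConditionModel

# Example checkpoint id; any compatible model repository works.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
```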
diffusers/models/__init__.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
18
+
19
+
20
+ _import_structure = {}
21
+
22
+ if is_torch_available():
23
+ _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
24
+ _import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
25
+ _import_structure["autoencoder_kl"] = ["AutoencoderKL"]
26
+ _import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
27
+ _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
28
+ _import_structure["controlnet"] = ["ControlNetModel"]
29
+ _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
30
+ _import_structure["modeling_utils"] = ["ModelMixin"]
31
+ _import_structure["prior_transformer"] = ["PriorTransformer"]
32
+ _import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
33
+ _import_structure["transformer_2d"] = ["Transformer2DModel"]
34
+ _import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
35
+ _import_structure["unet_1d"] = ["UNet1DModel"]
36
+ _import_structure["unet_2d"] = ["UNet2DModel"]
37
+ _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
38
+ _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
39
+ _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
40
+ _import_structure["vq_model"] = ["VQModel"]
41
+
42
+ if is_flax_available():
43
+ _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
44
+ _import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
45
+ _import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
46
+
47
+
48
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
49
+ if is_torch_available():
50
+ from .adapter import MultiAdapter, T2IAdapter
51
+ from .autoencoder_asym_kl import AsymmetricAutoencoderKL
52
+ from .autoencoder_kl import AutoencoderKL
53
+ from .autoencoder_tiny import AutoencoderTiny
54
+ from .consistency_decoder_vae import ConsistencyDecoderVAE
55
+ from .controlnet import ControlNetModel
56
+ from .dual_transformer_2d import DualTransformer2DModel
57
+ from .modeling_utils import ModelMixin
58
+ from .prior_transformer import PriorTransformer
59
+ from .t5_film_transformer import T5FilmDecoder
60
+ from .transformer_2d import Transformer2DModel
61
+ from .transformer_temporal import TransformerTemporalModel
62
+ from .unet_1d import UNet1DModel
63
+ from .unet_2d import UNet2DModel
64
+ from .unet_2d_condition import UNet2DConditionModel
65
+ from .unet_3d_condition import UNet3DConditionModel
66
+ from .unet_motion_model import MotionAdapter, UNetMotionModel
67
+ from .vq_model import VQModel
68
+
69
+ if is_flax_available():
70
+ from .controlnet_flax import FlaxControlNetModel
71
+ from .unet_2d_condition_flax import FlaxUNet2DConditionModel
72
+ from .vae_flax import FlaxAutoencoderKL
73
+
74
+ else:
75
+ import sys
76
+
77
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
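For context (illustration only, not part of the diff): the `_LazyModule` registration above defers the actual submodule imports until an attribute is first accessed. A minimal sketch of the observable effect, assuming the default lazy path (`DIFFUSERS_SLOW_IMPORT` unset and torch installed):

```python
import sys
import diffusers.models as models

# Listed in _import_structure, but the submodule is not imported yet.
print("diffusers.models.unet_2d_condition" in sys.modules)  # expected: False

# Attribute access triggers the real import of the submodule.
_ = models.UNet2DConditionModel
print("diffusers.models.unet_2d_condition" in sys.modules)  # expected: True
```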
diffusers/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.24 kB). View file
 
diffusers/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (2.23 kB). View file
 
diffusers/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (2.24 kB). View file
 
diffusers/models/__pycache__/activations.cpython-310.pyc ADDED
Binary file (4.2 kB). View file
 
diffusers/models/__pycache__/activations.cpython-38.pyc ADDED
Binary file (4.18 kB). View file
 
diffusers/models/__pycache__/activations.cpython-39.pyc ADDED
Binary file (4.19 kB). View file
 
diffusers/models/__pycache__/attention.cpython-310.pyc ADDED
Binary file (11.3 kB). View file
 
diffusers/models/__pycache__/attention.cpython-38.pyc ADDED
Binary file (11.2 kB). View file
 
diffusers/models/__pycache__/attention.cpython-39.pyc ADDED
Binary file (11.2 kB). View file
 
diffusers/models/__pycache__/attention_processor.cpython-310.pyc ADDED
Binary file (48.6 kB). View file
 
diffusers/models/__pycache__/attention_processor.cpython-38.pyc ADDED
Binary file (48.8 kB). View file
 
diffusers/models/__pycache__/attention_processor.cpython-39.pyc ADDED
Binary file (48.8 kB). View file
 
diffusers/models/__pycache__/autoencoder_asym_kl.cpython-310.pyc ADDED
Binary file (6.49 kB). View file
 
diffusers/models/__pycache__/autoencoder_asym_kl.cpython-38.pyc ADDED
Binary file (6.39 kB). View file
 
diffusers/models/__pycache__/autoencoder_asym_kl.cpython-39.pyc ADDED
Binary file (6.4 kB). View file
 
diffusers/models/__pycache__/autoencoder_kl.cpython-310.pyc ADDED
Binary file (17.4 kB). View file
 
diffusers/models/__pycache__/autoencoder_kl.cpython-38.pyc ADDED
Binary file (17.4 kB). View file
 
diffusers/models/__pycache__/autoencoder_kl.cpython-39.pyc ADDED
Binary file (17.4 kB). View file
 
diffusers/models/__pycache__/controlnet.cpython-310.pyc ADDED
Binary file (27.3 kB). View file
 
diffusers/models/__pycache__/controlnet.cpython-38.pyc ADDED
Binary file (26.8 kB). View file
 
diffusers/models/__pycache__/controlnet.cpython-39.pyc ADDED
Binary file (26.8 kB). View file
 
diffusers/models/__pycache__/dual_transformer_2d.cpython-310.pyc ADDED
Binary file (6.08 kB). View file
 
diffusers/models/__pycache__/dual_transformer_2d.cpython-38.pyc ADDED
Binary file (5.95 kB). View file
 
diffusers/models/__pycache__/dual_transformer_2d.cpython-39.pyc ADDED
Binary file (5.96 kB). View file
 
diffusers/models/__pycache__/embeddings.cpython-310.pyc ADDED
Binary file (23.9 kB). View file
 
diffusers/models/__pycache__/embeddings.cpython-38.pyc ADDED
Binary file (23.7 kB). View file
 
diffusers/models/__pycache__/embeddings.cpython-39.pyc ADDED
Binary file (23.7 kB). View file
 
diffusers/models/__pycache__/lora.cpython-310.pyc ADDED
Binary file (9.27 kB). View file