harsh99 committed
Commit 16d759f · 1 Parent(s): 70511fe
Files changed (7)
  1. .gitignore +2 -0
  2. README.md +251 -1
  3. diffusion.py +23 -32
  4. model.py +6 -2
  5. model_converter.py +80 -81
  6. pipeline.py +36 -8
  7. test.ipynb +0 -0
.gitignore CHANGED
@@ -1,4 +1,6 @@
 *inkpunk-diffusion-v1.ckpt
+*sd-v1-5-inpainting.ckpt
+*zalando-hd-resized.zip
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
README.md CHANGED
@@ -1,4 +1,254 @@
# stable-diffusion

+ # 🎨 Stable Diffusion & CatVTON Implementation
+
+ <div align="center">
+
+ ![Stable Diffusion](https://img.shields.io/badge/Stable%20Diffusion-From%20Scratch-blue?style=for-the-badge&logo=pytorch)
+ ![CatVTON](https://img.shields.io/badge/CatVTON-Virtual%20Try--On-purple?style=for-the-badge)
+ ![PyTorch](https://img.shields.io/badge/PyTorch-EE4C2C?style=for-the-badge&logo=pytorch&logoColor=white)
+ ![Python](https://img.shields.io/badge/Python-3.10.9-green?style=for-the-badge&logo=python&logoColor=white)
+
+ *A comprehensive implementation of Stable Diffusion from scratch, with CatVTON virtual try-on capabilities*
+
+ </div>
+
+ ---
+
+ ## 📋 Table of Contents
+
+ - [🌟 Overview](#-overview)
+ - [🏗️ Project Structure](#️-project-structure)
+ - [🚀 Features](#-features)
+ - [⚙️ Setup & Installation](#️-setup--installation)
+ - [📥 Model Downloads](#-model-downloads)
+ - [🎯 CatVTON Integration](#-catvton-integration)
+ - [📚 References](#-references)
+ - [👤 Author](#-author)
+ - [🤝 Contributing](#-contributing)
+ - [📜 License](#-license)
+ - [🙏 Acknowledgments](#-acknowledgments)
+
+ ---
+
+ ## 🌟 Overview
+
+ This project implements **Stable Diffusion from scratch** using PyTorch, with an additional **CatVTON (Virtual Clothes Try-On)** model built on top of Stable Diffusion. The implementation includes:
+
+ - ✨ Complete Stable Diffusion pipeline built from the ground up **(Branch: Main)**
+ - 🎭 CatVTON model for virtual garment try-on **(Branch: CatVTON)**
+ - 🧠 Custom attention mechanisms and CLIP integration
+ - 🔄 DDPM (Denoising Diffusion Probabilistic Models) implementation
+ - 🖼️ Inpainting capabilities using pretrained weights
+
+ ---
+
+ ## 🏗️ Project Structure
+
+ ```
+ stable-diffusion/
+ ├── 📁 Core Components
+ │   ├── attention.py              # Attention mechanisms
+ │   ├── clip.py                   # CLIP model implementation
+ │   ├── ddpm.py                   # DDPM sampler
+ │   ├── decoder.py                # VAE decoder
+ │   ├── encoder.py                # VAE encoder
+ │   ├── diffusion.py              # Diffusion process
+ │   ├── model.py                  # Defining model & loading pre-trained weights
+ │   └── pipeline.py               # Main pipeline
+
+ ├── 📁 Utilities & Interface
+ │   ├── interface.py              # User interface
+ │   ├── model_converter.py        # Model conversion utilities
+ │   └── requirements.txt          # Dependencies
+
+ ├── 📁 Data & Models
+ │   ├── vocab.json                # Tokenizer vocabulary
+ │   ├── merges.txt                # BPE merges
+ │   ├── inkpunk-diffusion-v1.ckpt # Inkpunk model weights
+ │   └── sd-v1-5-inpainting.ckpt   # Inpainting model weights
+
+ ├── 📁 Sample Data
+ │   ├── person.jpg                # Person image for try-on
+ │   ├── garment.jpg               # Garment image
+ │   ├── agnostic_mask.png         # Segmentation mask
+ │   ├── dog.jpg                   # Test image
+ │   ├── image.png                 # Generated sample
+ │   └── zalando-hd-resized.zip    # Dataset
+
+ └── 📁 Notebooks & Documentation
+     ├── test.ipynb                # Testing notebook
+     └── README.md                 # This file
+ ```
+
+ ---
+
+ ## 🚀 Features
+
+ ### 🎨 Stable Diffusion Core
+ - **From-scratch implementation** of the Stable Diffusion architecture
+ - **Custom CLIP** text encoder integration
+ - **VAE encoder/decoder** for latent space operations
+ - **DDPM sampling** with configurable steps
+ - **Attention mechanisms** optimized for diffusion
+
+ ### 👕 CatVTON Capabilities
+ - **Virtual garment try-on** using inpainting
+ - **Person-garment alignment** and fitting
+ - **Mask-based inpainting** for realistic results
+
+ ---
+
+ ## ⚙️ Setup & Installation
+
+ ### Prerequisites
+ - Python 3.10.9
+ - CUDA-compatible GPU (recommended)
+ - Git
+
+ ### 1. Clone the Repository
+ ```bash
+ git clone https://github.com/Harsh-Kesharwani/stable-diffusion.git
+ cd stable-diffusion
+ git checkout CatVTON  # Switch to the CatVTON branch for the virtual try-on model
+ ```
+
+ ### 2. Create a Virtual Environment
+ ```bash
+ conda create -n stable-diffusion python=3.10.9
+ conda activate stable-diffusion
+ ```
+
+ ### 3. Install Dependencies
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### 4. Verify the Installation
+ ```bash
+ python -c "import torch; print(f'PyTorch version: {torch.__version__}')"
+ python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
+ ```
+
+ ---
+
+ ## 📥 Model Downloads
+
+ ### Required Files
+
+ #### 1. Tokenizer Files
+ Download from the [Stable Diffusion v1.4 tokenizer](https://huggingface.co/CompVis/stable-diffusion-v1-4/tree/main/tokenizer):
+ - `vocab.json`
+ - `merges.txt`
+
+ #### 2. Model Checkpoints
+ - **Inkpunk Diffusion**: Download `inkpunk-diffusion-v1.ckpt` from [Envvi/Inkpunk-Diffusion](https://huggingface.co/Envvi/Inkpunk-Diffusion/tree/main)
+ - **Inpainting Model**: Download `sd-v1-5-inpainting.ckpt` from [Stable Diffusion v1.5 Inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting)
+
+ ### Download Script
+ ```bash
+ # Create the data directory if it doesn't exist
+ mkdir -p data
+
+ # Download the tokenizer files into data/
+ wget -O data/vocab.json "https://huggingface.co/CompVis/stable-diffusion-v1-4/resolve/main/tokenizer/vocab.json"
+ wget -O data/merges.txt "https://huggingface.co/CompVis/stable-diffusion-v1-4/resolve/main/tokenizer/merges.txt"
+
+ # Note: the large model checkpoints need to be downloaded manually from HuggingFace
+ ```
+
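+ As an alternative to manual downloads, the checkpoints can also be fetched programmatically. A minimal sketch using the `huggingface_hub` package (assumes `pip install huggingface_hub`; repo and file names as listed above):
+
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ # Fetch the model checkpoints into the data/ folder
+ hf_hub_download(repo_id="Envvi/Inkpunk-Diffusion",
+                 filename="inkpunk-diffusion-v1.ckpt", local_dir="data")
+ hf_hub_download(repo_id="stable-diffusion-v1-5/stable-diffusion-inpainting",
+                 filename="sd-v1-5-inpainting.ckpt", local_dir="data")
+ ```
+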
+ ---
+
+ ### Interactive Interface
+ ```bash
+ python interface.py
+ ```
+
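+ The interface drives the same components that can be scripted directly. A rough sketch of programmatic use (`preload_models_from_standard_weights` is the loader defined in `model.py`; whether it returns the models and the exact sampling entry point in `pipeline.py` are assumptions here):
+
+ ```python
+ import torch
+ import model
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load the VAE encoder/decoder, diffusion UNET, and CLIP from one checkpoint
+ # (assumes the loader returns the constructed models)
+ models = model.preload_models_from_standard_weights("data/sd-v1-5-inpainting.ckpt", device)
+
+ # From here, hand the loaded models to the sampling loop in pipeline.py.
+ ```
+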
+ ---
+
+ ## 🎯 CatVTON Integration
+
+ The CatVTON model extends the base Stable Diffusion with specialized capabilities for virtual garment try-on:
+
+ ### Key Components
+ 1. **Inpainting Pipeline**: Uses `sd-v1-5-inpainting.ckpt` for mask-based generation (see the input sketch below)
+ 2. **Garment Alignment**: Automatic alignment of garments to the person's pose
+ 3. **Mask Generation**: Automated or manual mask creation for try-on regions
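+
+ The `in_channels=9` introduced in `diffusion.py` reflects this inpainting formulation: the UNET input concatenates the noisy latent (4 channels), the downsampled try-on mask (1 channel), and the VAE-encoded masked person image (4 channels). A sketch of how such an input is assembled (tensor names are illustrative, not this repo's API):
+
+ ```python
+ import torch
+
+ # Illustrative latent-space shapes for a 512x512 input (latents are H/8 x W/8)
+ latents = torch.randn(1, 4, 64, 64)               # noisy latent being denoised
+ mask = torch.randn(1, 1, 64, 64)                  # try-on mask, resized to latent size
+ masked_image_latents = torch.randn(1, 4, 64, 64)  # VAE encoding of the masked person image
+
+ # 4 + 1 + 4 = 9 channels, matching Diffusion(in_channels=9)
+ unet_input = torch.cat([latents, mask, masked_image_latents], dim=1)
+ print(unet_input.shape)  # torch.Size([1, 9, 64, 64])
+ ```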
+
+ ---
+
+ ## 📚 References
+
+ ### 📖 Implementation Guides
+ - [Implementing Stable Diffusion from Scratch - Medium](https://medium.com/@sayedebad.777/implementing-stable-diffusion-from-scratch-using-pytorch-f07d50efcd97)
+ - [Stable Diffusion Implementation - YouTube](https://www.youtube.com/watch?v=ZBKpAp_6TGI)
+
+ ### 🤗 HuggingFace Resources
+ - [Diffusers: Adapt a Model](https://huggingface.co/docs/diffusers/training/adapt_a_model)
+ - [Stable Diffusion v1.5 Inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting)
+ - [CompVis Stable Diffusion v1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4)
+ - [Inkpunk Diffusion](https://huggingface.co/Envvi/Inkpunk-Diffusion)
+
+ ### 📄 Academic Papers
+ - Stable Diffusion: High-Resolution Image Synthesis with Latent Diffusion Models
+ - DDPM: Denoising Diffusion Probabilistic Models
+ - CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models
+
+ ---
+
+ ## 👤 Author
+
+ <div align="center">
+
+ **Harsh Kesharwani**
+
+ [![GitHub](https://img.shields.io/badge/GitHub-100000?style=for-the-badge&logo=github&logoColor=white)](https://github.com/Harsh-Kesharwani)
+ [![LinkedIn](https://img.shields.io/badge/LinkedIn-0A66C2?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/harsh-kesharwani/)
+ [![Email](https://img.shields.io/badge/Email-D14836?style=for-the-badge&logo=gmail&logoColor=white)](mailto:harshkesharwani777@gmail.com)
+
+ *Passionate about AI, Computer Vision, and Generative Models*
+
+ </div>
+
+ ---
+
+ ## 🤝 Contributing
+
+ Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.
+
+ ### Development Setup
+ 1. Fork the repository
+ 2. Create your feature branch (`git checkout -b feature/amazing-feature`)
+ 3. Commit your changes (`git commit -m 'Add some amazing feature'`)
+ 4. Push to the branch (`git push origin feature/amazing-feature`)
+ 5. Open a Pull Request
+
+ ---
+
+ ## 📜 License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+ ---
+
+ ## 🙏 Acknowledgments
+
+ - CompVis team for the original Stable Diffusion implementation
+ - HuggingFace for providing pre-trained weights, datasets, and references
+ - The open-source community for various implementations and tutorials
+ - Zalando Research for the fashion dataset
+
+ ---
+
+ <div align="center">
+
+ **⭐ Star this repository if you find it helpful!**
+
+ *Built with ❤️ by [Harsh Kesharwani](https://www.linkedin.com/in/harsh-kesharwani/)*
+
+ </div>
+
+
<!-- 1. Download `vocab.json` and `merges.txt` from https://huggingface.co/CompVis/stable-diffusion-v1-4/tree/main/tokenizer and save them in the `data` folder
- 2. Download `inkpunk-diffusion-v1.ckpt` from https://huggingface.co/Envvi/Inkpunk-Diffusion/tree/main and save it in the `data` folder -->
+ 1. Download `inkpunk-diffusion-v1.ckpt` from https://huggingface.co/Envvi/Inkpunk-Diffusion/tree/main and save it in the `data` folder -->
+
+ <!-- IMPORTANT REFERENCE
+ 3. https://huggingface.co/docs/diffusers/training/adapt_a_model -->
+ <!-- 4. https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting -->
diffusion.py CHANGED
@@ -134,11 +134,11 @@ class SwitchSequential(nn.Sequential):
         return x
 
 class UNET(nn.Module):
-    def __init__(self):
+    def __init__(self, in_channels=4):
         super().__init__()
         self.encoders=nn.ModuleList([
             # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
-            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
+            SwitchSequential(nn.Conv2d(in_channels, 320, kernel_size=3, padding=1)),
 
             # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
             SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
@@ -245,11 +245,11 @@ class UNET_OutputLayer(nn.Module):
         return x
 
 class Diffusion(nn.Module):
-    def __init__(self):
+    def __init__(self, in_channels=4, out_channels=4):
         super().__init__()
         self.time_embedding=TimeEmbedding(320)
-        self.unet=UNET()
-        self.final=UNET_OutputLayer(320, 4)
+        self.unet=UNET(in_channels)
+        self.final=UNET_OutputLayer(320, out_channels)
 
     def forward(self, latent, time):
         time=self.time_embedding(time)
@@ -260,38 +260,29 @@ class Diffusion(nn.Module):
 
         return output
 
-if __name__ == "__main__":
-    # Dummy inputs
-    batch_size = 10
-    height = 64
-    width = 64
-    in_channels = 4
-    # context_dim = 768
-    seq_len = 77
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    # Create model and move to device
-    model = Diffusion().to(device)
-
-    # Random input tensor with 4 channels
-    x = torch.randn(batch_size, in_channels, height, width).to(device)
-
-    print('Input shape to UNET: ', x.shape)
-
-    # Time embedding (e.g., timestep from a diffusion schedule)
-    t = torch.randn(batch_size, 320).to(device)
-
-    print('Time Embedding shape to UNET: ', t.shape)
-
-    # Context for cross attention (e.g., text embedding from CLIP or transformer)
-    # context = torch.randn(batch_size, seq_len, context_dim).to(device)
-
-    # print('context shape to UNET: ', context.shape)
+import torch
+from torch import nn
+
+if __name__ == "__main__":
+    # Input configurations
+    batch_size = 1
+    in_channels = 9  # Must match the UNET input channels
+    height = 64      # Height / 8 = 64 → original H = 512
+    width = 64       # Width / 8 = 64 → original W = 512
+
+    # Create dummy inputs
+    latent = torch.randn(batch_size, in_channels, height, width)  # [1, 9, 64, 64]
+    time = torch.randn(batch_size, 320)                           # [1, 320]
+
+    # Initialize model
+    model = Diffusion(in_channels=in_channels, out_channels=4)
 
     # Forward pass
     with torch.no_grad():
-        output = model(x, t)
-        print(output)
-
-    print("Output shape of UNET:", output.shape)
+        output = model(latent, time)
+
+    # Print input and output shapes
+    print("Input latent shape:", latent.shape)
+    print("Time embedding shape:", time.shape)
+    print("Output shape:", output.shape)
model.py CHANGED
@@ -6,6 +6,10 @@ from diffusion import Diffusion
 import model_converter
 
 def preload_models_from_standard_weights(ckpt_path, device):
+    # CatVTON parameters
+    in_channels = 9
+    out_channels = 4
+
     state_dict=model_converter.load_from_standard_weights(ckpt_path, device)
 
     encoder=VAE_Encoder().to(device)
@@ -14,8 +18,8 @@ def preload_models_from_standard_weights(ckpt_path, device):
     decoder=VAE_Decoder().to(device)
     decoder.load_state_dict(state_dict['decoder'], strict=True)
 
-    diffusion=Diffusion().to(device)
-    diffusion.load_state_dict(state_dict['diffusion'], strict=True)
+    diffusion=Diffusion(in_channels=in_channels, out_channels=out_channels).to(device)
+    diffusion.load_state_dict(state_dict['diffusion'], strict=False)
 
     clip=CLIP().to(device)
     clip.load_state_dict(state_dict['clip'], strict=True)
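Note the switch to `strict=False` for the diffusion weights: it tolerates the cross-attention (`attention_2`) keys commented out in `model_converter.py` below, but it also hides genuine mismatches. A small sanity check, reusing `diffusion` and `state_dict` from this function:

```python
# load_state_dict reports exactly what strict=False skipped, so the
# silently-ignored keys can be audited instead of going unnoticed.
result = diffusion.load_state_dict(state_dict['diffusion'], strict=False)
print("Missing keys (left at their random init):", result.missing_keys)
print("Unexpected keys (present in ckpt, unused):", result.unexpected_keys)
```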
model_converter.py CHANGED
@@ -5,7 +5,6 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     # original_model = torch.load(input_file, map_location=device, weights_only = False)["state_dict"]
     original_model=torch.load(input_file, weights_only = False)["state_dict"]
 
-
     converted = {}
     converted['diffusion'] = {}
     converted['encoder'] = {}
@@ -38,11 +37,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.encoders.1.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.encoders.1.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.encoders.1.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.encoders.1.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.encoders.1.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.encoders.1.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.encoders.1.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.encoders.1.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.encoders.1.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.encoders.1.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.encoders.1.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.encoders.1.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.encoders.1.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.encoders.1.1.layernorm_1.weight'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.encoders.1.1.layernorm_1.bias'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.encoders.1.1.layernorm_2.weight'] = original_model['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight']
@@ -71,11 +70,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.encoders.2.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.encoders.2.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.encoders.2.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.encoders.2.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.encoders.2.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.encoders.2.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.encoders.2.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.encoders.2.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.encoders.2.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.encoders.2.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.encoders.2.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.encoders.2.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.encoders.2.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.encoders.2.1.layernorm_1.weight'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.encoders.2.1.layernorm_1.bias'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.encoders.2.1.layernorm_2.weight'] = original_model['model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight']
@@ -108,11 +107,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.encoders.4.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.encoders.4.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.encoders.4.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.encoders.4.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.encoders.4.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.encoders.4.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.encoders.4.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.encoders.4.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.encoders.4.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.encoders.4.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.encoders.4.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.encoders.4.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.encoders.4.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.encoders.4.1.layernorm_1.weight'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.encoders.4.1.layernorm_1.bias'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.encoders.4.1.layernorm_2.weight'] = original_model['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight']
@@ -141,11 +140,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.encoders.5.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.encoders.5.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.encoders.5.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.encoders.5.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.encoders.5.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.encoders.5.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.encoders.5.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.encoders.5.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.encoders.5.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.encoders.5.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.encoders.5.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.encoders.5.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.encoders.5.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.encoders.5.1.layernorm_1.weight'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.encoders.5.1.layernorm_1.bias'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.encoders.5.1.layernorm_2.weight'] = original_model['model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight']
@@ -178,11 +177,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.encoders.7.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.encoders.7.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.encoders.7.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.encoders.7.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.encoders.7.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.encoders.7.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.encoders.7.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.encoders.7.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.encoders.7.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.encoders.7.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.encoders.7.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.encoders.7.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.encoders.7.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.encoders.7.1.layernorm_1.weight'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.encoders.7.1.layernorm_1.bias'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.encoders.7.1.layernorm_2.weight'] = original_model['model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight']
@@ -211,11 +210,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.encoders.8.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.encoders.8.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.encoders.8.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.encoders.8.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.encoders.8.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.encoders.8.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.encoders.8.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.encoders.8.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.encoders.8.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.encoders.8.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.encoders.8.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.encoders.8.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.encoders.8.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.encoders.8.1.layernorm_1.weight'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.encoders.8.1.layernorm_1.bias'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.encoders.8.1.layernorm_2.weight'] = original_model['model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight']
@@ -266,11 +265,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.bottleneck.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.bottleneck.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.bottleneck.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.bottleneck.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.bottleneck.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.bottleneck.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.bottleneck.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.bottleneck.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.bottleneck.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.bottleneck.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.bottleneck.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.bottleneck.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.bottleneck.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.bottleneck.1.layernorm_1.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.bottleneck.1.layernorm_1.bias'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.bottleneck.1.layernorm_2.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight']
@@ -349,11 +348,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.decoders.3.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.decoders.3.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.decoders.3.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.decoders.3.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.decoders.3.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.decoders.3.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.decoders.3.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.decoders.3.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.decoders.3.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.decoders.3.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.decoders.3.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.decoders.3.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.decoders.3.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.decoders.3.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.decoders.3.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.decoders.3.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight']
@@ -384,11 +383,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.decoders.4.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.decoders.4.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.decoders.4.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.decoders.4.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.decoders.4.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.decoders.4.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.decoders.4.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.decoders.4.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.decoders.4.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.decoders.4.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.decoders.4.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.decoders.4.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.decoders.4.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.decoders.4.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.decoders.4.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.decoders.4.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight']
@@ -419,11 +418,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.decoders.5.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.decoders.5.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.decoders.5.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.decoders.5.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.decoders.5.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.decoders.5.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.decoders.5.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.decoders.5.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.decoders.5.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.decoders.5.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.decoders.5.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.decoders.5.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.decoders.5.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.decoders.5.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.decoders.5.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.decoders.5.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight']
@@ -456,11 +455,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.decoders.6.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.decoders.6.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.decoders.6.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.decoders.6.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.decoders.6.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.decoders.6.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.decoders.6.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.decoders.6.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.decoders.6.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.decoders.6.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.decoders.6.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.decoders.6.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.decoders.6.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.decoders.6.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.decoders.6.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.decoders.6.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight']
@@ -491,11 +490,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.decoders.7.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.decoders.7.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.decoders.7.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.decoders.7.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.decoders.7.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.decoders.7.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.decoders.7.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.decoders.7.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.decoders.7.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.decoders.7.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.decoders.7.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.decoders.7.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.decoders.7.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.decoders.7.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.decoders.7.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.decoders.7.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight']
@@ -526,11 +525,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.decoders.8.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.decoders.8.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.decoders.8.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.decoders.8.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.decoders.8.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.decoders.8.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.decoders.8.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.decoders.8.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.decoders.8.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.decoders.8.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.decoders.8.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.decoders.8.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.decoders.8.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.decoders.8.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.decoders.8.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.decoders.8.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight']
@@ -563,11 +562,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.decoders.9.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.decoders.9.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.decoders.9.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.decoders.9.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.decoders.9.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.decoders.9.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.decoders.9.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.decoders.9.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.decoders.9.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.decoders.9.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.decoders.9.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.decoders.9.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.decoders.9.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.decoders.9.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.decoders.9.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.decoders.9.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight']
@@ -598,11 +597,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.decoders.10.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias']
     converted['diffusion']['unet.decoders.10.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.decoders.10.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.decoders.10.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.decoders.10.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.decoders.10.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.decoders.10.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.decoders.10.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.decoders.10.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.decoders.10.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.decoders.10.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.decoders.10.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.decoders.10.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.decoders.10.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.decoders.10.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.decoders.10.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight']
@@ -633,11 +632,11 @@ def load_from_standard_weights(input_file: str, device: str) -> dict[str, torch.
     converted['diffusion']['unet.decoders.11.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias']
    converted['diffusion']['unet.decoders.11.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight']
     converted['diffusion']['unet.decoders.11.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias']
-    converted['diffusion']['unet.decoders.11.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight']
-    converted['diffusion']['unet.decoders.11.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight']
-    converted['diffusion']['unet.decoders.11.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight']
-    converted['diffusion']['unet.decoders.11.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight']
-    converted['diffusion']['unet.decoders.11.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias']
+    # converted['diffusion']['unet.decoders.11.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight']
+    # converted['diffusion']['unet.decoders.11.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight']
+    # converted['diffusion']['unet.decoders.11.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight']
+    # converted['diffusion']['unet.decoders.11.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight']
+    # converted['diffusion']['unet.decoders.11.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias']
     converted['diffusion']['unet.decoders.11.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight']
     converted['diffusion']['unet.decoders.11.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias']
     converted['diffusion']['unet.decoders.11.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight']
 
270
+ # converted['diffusion']['unet.bottleneck.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight']
271
+ # converted['diffusion']['unet.bottleneck.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight']
272
+ # converted['diffusion']['unet.bottleneck.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias']
273
  converted['diffusion']['unet.bottleneck.1.layernorm_1.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight']
274
  converted['diffusion']['unet.bottleneck.1.layernorm_1.bias'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias']
275
  converted['diffusion']['unet.bottleneck.1.layernorm_2.weight'] = original_model['model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight']
 
348
  converted['diffusion']['unet.decoders.3.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias']
349
  converted['diffusion']['unet.decoders.3.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight']
350
  converted['diffusion']['unet.decoders.3.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias']
351
+ # converted['diffusion']['unet.decoders.3.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight']
352
+ # converted['diffusion']['unet.decoders.3.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight']
353
+ # converted['diffusion']['unet.decoders.3.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight']
354
+ # converted['diffusion']['unet.decoders.3.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight']
355
+ # converted['diffusion']['unet.decoders.3.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias']
356
  converted['diffusion']['unet.decoders.3.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight']
357
  converted['diffusion']['unet.decoders.3.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias']
358
  converted['diffusion']['unet.decoders.3.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight']
 
383
  converted['diffusion']['unet.decoders.4.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias']
384
  converted['diffusion']['unet.decoders.4.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight']
385
  converted['diffusion']['unet.decoders.4.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias']
386
+ # converted['diffusion']['unet.decoders.4.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight']
387
+ # converted['diffusion']['unet.decoders.4.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight']
388
+ # converted['diffusion']['unet.decoders.4.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight']
389
+ # converted['diffusion']['unet.decoders.4.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight']
390
+ # converted['diffusion']['unet.decoders.4.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias']
391
  converted['diffusion']['unet.decoders.4.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight']
392
  converted['diffusion']['unet.decoders.4.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias']
393
  converted['diffusion']['unet.decoders.4.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight']
 
418
  converted['diffusion']['unet.decoders.5.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias']
419
  converted['diffusion']['unet.decoders.5.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight']
420
  converted['diffusion']['unet.decoders.5.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias']
421
+ # converted['diffusion']['unet.decoders.5.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight']
422
+ # converted['diffusion']['unet.decoders.5.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight']
423
+ # converted['diffusion']['unet.decoders.5.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight']
424
+ # converted['diffusion']['unet.decoders.5.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight']
425
+ # converted['diffusion']['unet.decoders.5.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias']
426
  converted['diffusion']['unet.decoders.5.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight']
427
  converted['diffusion']['unet.decoders.5.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias']
428
  converted['diffusion']['unet.decoders.5.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight']
 
455
  converted['diffusion']['unet.decoders.6.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias']
456
  converted['diffusion']['unet.decoders.6.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight']
457
  converted['diffusion']['unet.decoders.6.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias']
458
+ # converted['diffusion']['unet.decoders.6.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight']
459
+ # converted['diffusion']['unet.decoders.6.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight']
460
+ # converted['diffusion']['unet.decoders.6.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight']
461
+ # converted['diffusion']['unet.decoders.6.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight']
462
+ # converted['diffusion']['unet.decoders.6.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias']
463
  converted['diffusion']['unet.decoders.6.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight']
464
  converted['diffusion']['unet.decoders.6.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias']
465
  converted['diffusion']['unet.decoders.6.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight']
 
490
  converted['diffusion']['unet.decoders.7.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias']
491
  converted['diffusion']['unet.decoders.7.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight']
492
  converted['diffusion']['unet.decoders.7.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias']
493
+ # converted['diffusion']['unet.decoders.7.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight']
494
+ # converted['diffusion']['unet.decoders.7.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight']
495
+ # converted['diffusion']['unet.decoders.7.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight']
496
+ # converted['diffusion']['unet.decoders.7.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight']
497
+ # converted['diffusion']['unet.decoders.7.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias']
498
  converted['diffusion']['unet.decoders.7.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight']
499
  converted['diffusion']['unet.decoders.7.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias']
500
  converted['diffusion']['unet.decoders.7.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight']
 
525
  converted['diffusion']['unet.decoders.8.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias']
526
  converted['diffusion']['unet.decoders.8.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight']
527
  converted['diffusion']['unet.decoders.8.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias']
528
+ # converted['diffusion']['unet.decoders.8.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight']
529
+ # converted['diffusion']['unet.decoders.8.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight']
530
+ # converted['diffusion']['unet.decoders.8.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight']
531
+ # converted['diffusion']['unet.decoders.8.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight']
532
+ # converted['diffusion']['unet.decoders.8.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias']
533
  converted['diffusion']['unet.decoders.8.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight']
534
  converted['diffusion']['unet.decoders.8.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias']
535
  converted['diffusion']['unet.decoders.8.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight']
 
562
  converted['diffusion']['unet.decoders.9.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias']
563
  converted['diffusion']['unet.decoders.9.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight']
564
  converted['diffusion']['unet.decoders.9.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias']
565
+ # converted['diffusion']['unet.decoders.9.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight']
566
+ # converted['diffusion']['unet.decoders.9.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight']
567
+ # converted['diffusion']['unet.decoders.9.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight']
568
+ # converted['diffusion']['unet.decoders.9.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight']
569
+ # converted['diffusion']['unet.decoders.9.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias']
570
  converted['diffusion']['unet.decoders.9.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight']
571
  converted['diffusion']['unet.decoders.9.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias']
572
  converted['diffusion']['unet.decoders.9.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight']
 
597
  converted['diffusion']['unet.decoders.10.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias']
598
  converted['diffusion']['unet.decoders.10.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight']
599
  converted['diffusion']['unet.decoders.10.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias']
600
+ # converted['diffusion']['unet.decoders.10.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight']
601
+ # converted['diffusion']['unet.decoders.10.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight']
602
+ # converted['diffusion']['unet.decoders.10.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight']
603
+ # converted['diffusion']['unet.decoders.10.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight']
604
+ # converted['diffusion']['unet.decoders.10.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias']
605
  converted['diffusion']['unet.decoders.10.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight']
606
  converted['diffusion']['unet.decoders.10.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias']
607
  converted['diffusion']['unet.decoders.10.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight']
 
632
  converted['diffusion']['unet.decoders.11.1.linear_geglu_1.bias'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias']
633
  converted['diffusion']['unet.decoders.11.1.linear_geglu_2.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight']
634
  converted['diffusion']['unet.decoders.11.1.linear_geglu_2.bias'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias']
635
+ # converted['diffusion']['unet.decoders.11.1.attention_2.q_proj.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight']
636
+ # converted['diffusion']['unet.decoders.11.1.attention_2.k_proj.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight']
637
+ # converted['diffusion']['unet.decoders.11.1.attention_2.v_proj.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight']
638
+ # converted['diffusion']['unet.decoders.11.1.attention_2.out_proj.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight']
639
+ # converted['diffusion']['unet.decoders.11.1.attention_2.out_proj.bias'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias']
640
  converted['diffusion']['unet.decoders.11.1.layernorm_1.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight']
641
  converted['diffusion']['unet.decoders.11.1.layernorm_1.bias'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias']
642
  converted['diffusion']['unet.decoders.11.1.layernorm_2.weight'] = original_model['model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight']
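
The pattern above repeats through every transformer block in the UNet: the feed-forward (GEGLU) and layernorm weights are still loaded, while every cross-attention (`attn2`) mapping is commented out, so the text-conditioned weights are deliberately dropped — consistent with CatVTON removing text conditioning. As a sketch only (this helper does not exist in `model_converter.py`; `key_map` is a hypothetical `{new_key: original_key}` dict standing in for the hand-written assignments above), the same effect could be achieved with a filter:

```python
def convert_weights(original_model, key_map, skip_cross_attention=True):
    """Copy checkpoint tensors, optionally dropping text cross-attention."""
    converted = {}
    for new_key, old_key in key_map.items():
        # '.attn2.' marks the text cross-attention q/k/v/out projections;
        # skipping them mirrors the commented-out assignments above.
        if skip_cross_attention and '.attn2.' in old_key:
            continue
        converted[new_key] = original_model[old_key]
    return converted
```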
pipeline.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
  from tqdm import tqdm
  from ddpm import DDPMSampler
  from PIL import Image
+ import model
 
  WIDTH = 512
  HEIGHT = 512
@@ -225,15 +226,21 @@ def generate(
  else:
  generator.manual_seed(seed)
 
- concat_dim = -2 # FIXME: y axis concat
+ concat_dim = -1 # FIXME: y axis concat
+
  # Prepare inputs to Tensor
  image, condition_image, mask = check_inputs(image, condition_image, mask, width, height)
+ # print(f"Input image shape: {image.shape}, condition image shape: {condition_image.shape}, mask shape: {mask.shape}")
  image = prepare_image(image).to(device)
  condition_image = prepare_image(condition_image).to(device)
  mask = prepare_mask_image(mask).to(device)
+
+ print(f"Prepared image shape: {image.shape}, condition image shape: {condition_image.shape}, mask shape: {mask.shape}")
  # Mask image
  masked_image = image * (mask < 0.5)
 
+ print(f"Masked image shape: {masked_image.shape}")
+
  # VAE encoding
  encoder = models.get('encoder', None)
  if encoder is None:
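
Two details in this hunk matter for the try-on setup: `concat_dim` moves from `-2` (height) to `-1` (width), so the person and garment are placed side by side rather than stacked, and `masked_image = image * (mask < 0.5)` zeroes exactly the pixels that will be repainted. A quick shape check for the width-wise concatenation (values are illustrative; SD v1's VAE maps a 512×512 image to a 4×64×64 latent):

```python
import torch

person_latent = torch.randn(1, 4, 64, 64)   # masked person, after VAE encoding
garment_latent = torch.randn(1, 4, 64, 64)  # garment condition image

# concat_dim = -1: person | garment side by side along the width axis
side_by_side = torch.cat([person_latent, garment_latent], dim=-1)
print(side_by_side.shape)  # torch.Size([1, 4, 64, 128])
```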
@@ -244,12 +251,18 @@
  condition_latent = compute_vae_encodings(condition_image, encoder, device)
  to_idle(encoder)
 
+ print(f"Masked latent shape: {masked_latent.shape}, condition latent shape: {condition_latent.shape}")
+
  # Concatenate latents
  masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)
 
+ print(f"Masked Person latent + garment latent: {masked_latent_concat.shape}")
+
  mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest")
  del image, mask, condition_image
  mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)
+
+ print(f"Mask latent concat shape: {mask_latent_concat.shape}")
 
  # Initialize latents
  latents = torch.randn(
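
Note how the mask is handled here: it is downsampled to latent resolution with nearest-neighbour interpolation, then padded with `torch.zeros_like` on the garment side. The garment half of the canvas is never marked for inpainting, so the UNet treats it as fixed context and only repaints the masked person region. Illustratively (shapes as in the previous sketch):

```python
import torch

mask_latent = torch.ones(1, 1, 64, 64)  # 1 = repaint (person side); illustrative
# garment side gets an all-zero mask: pure context, never inpainted
mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=-1)
print(mask_latent_concat.shape)  # torch.Size([1, 1, 64, 128])
```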
@@ -258,6 +271,8 @@
  device=masked_latent_concat.device,
  dtype=masked_latent_concat.dtype
  )
+
+ print(f"Latents shape: {latents.shape}")
 
  # Prepare timesteps
  if sampler_name == "ddpm":
@@ -267,7 +282,7 @@
  raise ValueError("Unknown sampler value %s. " % sampler_name)
 
  timesteps = sampler.timesteps
- latents = sampler.add_noise(latents, timesteps[0])
+ # latents = sampler.add_noise(latents, timesteps[0])
 
  # Classifier-Free Guidance
  do_classifier_free_guidance = guidance_scale > 1.0
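
Commenting out `sampler.add_noise(latents, timesteps[0])` is consistent with how `latents` is built: it already comes from `torch.randn`, i.e. it is pure Gaussian noise at the first timestep, so re-noising it (`sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps` in DDPM terms) would be redundant. `add_noise` only matters when denoising starts from an existing image latent, img2img style.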
@@ -280,6 +295,9 @@
  )
  mask_latent_concat = torch.cat([mask_latent_concat] * 2)
 
+ print(f"Masked latent concat for classifier-free guidance: {masked_latent_concat.shape}, mask latent concat: {mask_latent_concat.shape}")
+
+
  # Denoising loop - Fixed: removed self references and incorrect scheduler calls
  num_warmup_steps = 0 # For simple DDPM, no warmup needed
 
@@ -287,9 +305,13 @@
  for i, t in enumerate(timesteps):
  # expand the latents if we are doing classifier free guidance
  non_inpainting_latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)
+
+ # print(f"Non-inpainting latent model input shape: {non_inpainting_latent_model_input.shape}")
 
  # prepare the input for the inpainting model
  inpainting_latent_model_input = torch.cat([non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1)
+
+ # print(f"Inpainting latent model input shape: {inpainting_latent_model_input.shape}")
 
  # predict the noise residual
  diffusion = models.get('diffusion', None)
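
This channel-wise concatenation is what makes the 9-channel inpainting checkpoint necessary: 4 noisy latent channels + 1 mask channel + 4 masked-image latent channels, matching the input convolution of `sd-v1-5-inpainting.ckpt`. A shape check (batch of 2 is the classifier-free-guidance pair; 64×128 is the person-plus-garment latent canvas):

```python
import torch

latents = torch.randn(2, 4, 64, 128)               # noisy latents (person | garment)
mask_latent_concat = torch.randn(2, 1, 64, 128)    # inpaint mask + zeroed garment half
masked_latent_concat = torch.randn(2, 4, 64, 128)  # masked person | garment latents

unet_input = torch.cat([latents, mask_latent_concat, masked_latent_concat], dim=1)
print(unet_input.shape)  # torch.Size([2, 9, 64, 128]) — 4 + 1 + 4 = 9 channels
```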
@@ -299,7 +321,9 @@
  diffusion.to(device)
 
  # Create time embedding for the current timestep
- time_embedding = get_time_embedding(t.item()).unsqueeze(0).to(device)
+ time_embedding = get_time_embedding(t.item()).to(device)
+ # print(f"Time embedding shape: {time_embedding.shape}")
+
  if do_classifier_free_guidance:
  time_embedding = torch.cat([time_embedding] * 2)
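
Dropping `.unsqueeze(0)` implies `get_time_embedding` already returns a batched tensor. Its body sits outside this diff, so the following is only an assumed sketch of the usual sinusoidal embedding in this codebase, returning shape `(1, 320)`:

```python
import torch

def get_time_embedding_sketch(timestep: int) -> torch.Tensor:
    # 160 frequencies -> cos/sin pairs -> a 320-dim embedding with a leading
    # batch dimension, so no extra unsqueeze(0) is needed. Assumed, not verified.
    freqs = torch.pow(10000, -torch.arange(0, 160, dtype=torch.float32) / 160)
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)  # (1, 320)
```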
 
@@ -334,6 +358,7 @@
  decoder.to(device)
 
  image = decoder(latents.to(device))
+ # image = rescale(image, (-1, 1), (0, 255), clamp=True)
  image = (image / 2 + 0.5).clamp(0, 1)
  # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
  image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -364,12 +389,12 @@ def get_time_embedding(timestep):
 
  if __name__ == "__main__":
  # Example usage
- image = Image.open("example_image.jpg").convert("RGB")
- condition_image = Image.open("example_condition_image.jpg").convert("RGB")
- mask = Image.open("example_mask.png").convert("L")
+ image = Image.open("person.jpg").convert("RGB")
+ condition_image = Image.open("image.png").convert("RGB")
+ mask = Image.open("agnostic_mask.png").convert("L")
 
- # Resize images to the desired dimensions
- image, condition_image, mask = check_inputs(image, condition_image, mask, WIDTH, HEIGHT)
+ # Load models
+ models = model.preload_models_from_standard_weights("sd-v1-5-inpainting.ckpt", device="cuda")
 
  # Generate image
  generated_image = generate(
@@ -380,6 +405,9 @@ if __name__ == "__main__":
  guidance_scale=2.5,
  width=WIDTH,
  height=HEIGHT,
+ models=models,
+ sampler_name="ddpm",
+ seed=42,
  device="cuda" # or "cpu"
  )
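
The updated entry point wires everything together: VITON-HD-style inputs (a person photo, a garment image, and an agnostic mask over the clothing region) plus the 9-channel `sd-v1-5-inpainting.ckpt` loaded through `model.preload_models_from_standard_weights`. The return value of `generate` is not visible in this diff; assuming it is the `(batch, H, W, 3)` float array in `[0, 1]` built above, saving the try-on result would look like:

```python
import numpy as np
from PIL import Image

# Hypothetical post-processing, not part of the commit.
result = (generated_image[0] * 255).round().astype(np.uint8)
Image.fromarray(result).save("tryon_result.png")
```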
 
 
test.ipynb CHANGED
The diff for this file is too large to render. See raw diff