harsh99 committed · Commit b993f12 · 0 Parent(s)

implementation of stable diffusion from scratch

Files changed (17)
  1. .gitignore +53 -0
  2. README.md +4 -0
  3. attention.py +117 -0
  4. clip.py +86 -0
  5. ddpm.py +119 -0
  6. decoder.py +134 -0
  7. diffusion.py +297 -0
  8. dog.jpg +0 -0
  9. encoder.py +91 -0
  10. interface.py +151 -0
  11. merges.txt +0 -0
  12. model.py +28 -0
  13. model_converter.py +0 -0
  14. pipeline.py +174 -0
  15. requirements.txt +115 -0
  16. test.ipynb +0 -0
  17. vocab.json +0 -0
.gitignore ADDED
@@ -0,0 +1,53 @@
*inkpunk-diffusion-v1.ckpt

# Byte-compiled / optimized / DLL files
__pycache__/
**/__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# Rope
.ropeproject

# Django stuff:
*.log
*.pot

# Sphinx documentation
docs/_build/
README.md ADDED
@@ -0,0 +1,4 @@
# stable-diffusion

<!-- 1. Download `vocab.json` and `merges.txt` from https://huggingface.co/CompVis/stable-diffusion-v1-4/tree/main/tokenizer and save them in the `data` folder
2. Download `inkpunk-diffusion-v1.ckpt` from https://huggingface.co/Envvi/Inkpunk-Diffusion/tree/main and save it in the `data` folder -->
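For reference, a minimal text-to-image sketch of how the modules in this commit fit together (assumptions: `vocab.json`, `merges.txt`, and `inkpunk-diffusion-v1.ckpt` sit next to the scripts, as `interface.py` expects, and the prompt and output filename are only illustrative):

import torch
from PIL import Image
from transformers import CLIPTokenizer

import model
import pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer files and checkpoint are assumed to be in the working directory
tokenizer = CLIPTokenizer("vocab.json", merges_file="merges.txt")
models = model.preload_models_from_standard_weights("inkpunk-diffusion-v1.ckpt", device)

output = pipeline.generate(
    prompt="a dog wearing sunglasses, highly detailed",  # illustrative prompt
    uncond_prompt="",
    do_cfg=True,
    cfg_scale=8,
    sampler_name="ddpm",
    n_inference_steps=50,
    seed=42,
    models=models,
    device=device,
    idle_device="cpu",
    tokenizer=tokenizer,
)
Image.fromarray(output).save("output.png")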
attention.py ADDED
@@ -0,0 +1,117 @@
import torch
from torch import nn
from torch.nn import functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # This combines the Wq, Wk and Wv matrices into one matrix
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)

        # This one represents the Wo matrix
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)

        self.n_heads = n_heads
        self.d_head = d_embed // self.n_heads

    def forward(self, x, causal_mask=False):
        # x: (batch_size, seq_len, dim)
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

        # Apply the in_proj to get the queries, keys, and values all at once
        # (batch_size, seq_len, dim) -> (batch_size, seq_len, 3 * dim)
        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        # Reshape to (batch_size, seq_len, n_heads, d_head)
        q = q.view(interim_shape)
        k = k.view(interim_shape)
        v = v.view(interim_shape)

        # Transpose for attention dot product: (batch_size, n_heads, seq_len, d_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # (batch_size, n_heads, seq_len, d_head) @ (batch_size, n_heads, d_head, seq_len) -> (batch_size, n_heads, seq_len, seq_len)
        attention_weights = q @ k.transpose(-1, -2)

        # Scaling by sqrt(d_head)
        attention_weights = attention_weights / math.sqrt(self.d_head)

        # Causal mask to prevent attending to future tokens
        if causal_mask:
            mask = torch.ones_like(attention_weights, dtype=torch.bool).triu(1)
            attention_weights.masked_fill_(mask, -torch.inf)

        # Apply softmax to get attention probabilities
        attention_weights = F.softmax(attention_weights, dim=-1)

        # Apply attention weights: (batch_size, n_heads, seq_len, seq_len) @ (batch_size, n_heads, seq_len, d_head) -> (batch_size, n_heads, seq_len, d_head)
        output = attention_weights @ v

        # Transpose back: (batch_size, seq_len, n_heads, d_head)
        output = output.transpose(1, 2)

        # Reshape back to (batch_size, seq_len, dim)
        output = output.reshape(input_shape)

        # Apply output projection
        output = self.out_proj(output)

        return output

class CrossAttention(nn.Module):
    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, y):
        # x (latent): (Batch_Size, Seq_Len_Q, Dim_Q)
        # y (context): (Batch_Size, Seq_Len_KV, Dim_KV) = (Batch_Size, 77, 768)
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        # Divide each embedding of Q into multiple heads such that d_head * n_heads = Dim_Q
        interim_shape = (batch_size, -1, self.n_heads, self.d_head)

        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        q = self.q_proj(x)
        # (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
        k = self.k_proj(y)
        v = self.v_proj(y)

        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
        q = q.view(interim_shape).transpose(1, 2)
        # (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) @ (Batch_Size, H, Dim_Q / H, Seq_Len_KV) -> (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
        weight = q @ k.transpose(-1, -2)

        # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
        weight /= math.sqrt(self.d_head)
        weight = F.softmax(weight, dim=-1)

        # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV) @ (Batch_Size, H, Seq_Len_KV, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
        output = weight @ v

        # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H)
        output = output.transpose(1, 2).contiguous()

        # (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        output = output.view(input_shape)

        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        output = self.out_proj(output)

        # (Batch_Size, Seq_Len_Q, Dim_Q)
        return output
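attention.py ships without the quick self-test the other modules have; a minimal shape check could look like this (a sketch run from the repo root; the sizes mirror how the blocks are used elsewhere in the repo: 768-dim CLIP embeddings for self-attention, and a 77x768 text context for cross-attention):

import torch
from attention import SelfAttention, CrossAttention

# Self-attention over a CLIP-sized sequence: (1, 77, 768) -> (1, 77, 768)
self_attn = SelfAttention(n_heads=12, d_embed=768)
x = torch.randn(1, 77, 768)
with torch.no_grad():
    print("SelfAttention output:", self_attn(x, causal_mask=True).shape)

# Cross-attention between a latent sequence and the text context:
# queries (1, 64, 320) attend to context (1, 77, 768) -> (1, 64, 320)
cross_attn = CrossAttention(n_heads=8, d_embed=320, d_cross=768)
latent = torch.randn(1, 64, 320)
context = torch.randn(1, 77, 768)
with torch.no_grad():
    print("CrossAttention output:", cross_attn(latent, context).shape)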
clip.py ADDED
@@ -0,0 +1,86 @@
import torch
from torch import nn
from torch.nn import functional as F

from attention import SelfAttention

class CLIPEmbedding(nn.Module):
    def __init__(self, n_vocab, n_embed, n_token):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_embed)
        self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embed)))

    def forward(self, tokens: torch.Tensor):
        x = self.token_embedding(tokens)
        x += self.position_embedding

        return x

class CLIPLayer(nn.Module):
    def __init__(self, n_head, n_embed):
        super().__init__()
        self.layernorm_1 = nn.LayerNorm(n_embed)
        self.attention = SelfAttention(n_head, n_embed)
        self.layernorm_2 = nn.LayerNorm(n_embed)

        self.linear_1 = nn.Linear(n_embed, 4 * n_embed)
        self.linear_2 = nn.Linear(4 * n_embed, n_embed)

    def forward(self, x):
        residue = x

        x = self.layernorm_1(x)

        x = self.attention(x, causal_mask=True)

        x += residue

        residue = x

        x = self.layernorm_2(x)

        x = self.linear_1(x)

        # QuickGELU activation: x * sigmoid(1.702 * x)
        x = x * torch.sigmoid(1.702 * x)

        x = self.linear_2(x)
        x += residue

        return x

class CLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = CLIPEmbedding(49408, 768, 77)
        self.layers = nn.ModuleList([
            CLIPLayer(12, 768) for i in range(12)
        ])

        self.layernorm = nn.LayerNorm(768)

    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        tokens = tokens.type(torch.long)

        state = self.embedding(tokens)

        for layer in self.layers:
            state = layer(state)

        output = self.layernorm(state)

        return output

if __name__ == "__main__":
    dummy_tokens = torch.randint(0, 49408, (1, 77))  # (Batch_Size, Seq_Len)

    # Instantiate the model
    model = CLIP()

    # Forward pass
    with torch.no_grad():  # no need to track gradients for testing
        output = model(dummy_tokens)

    # Print the output shape
    # Output shape: torch.Size([1, 77, 768])
    print("Output shape:", output.shape)
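The `x * torch.sigmoid(1.702 * x)` line in CLIPLayer is the QuickGELU activation used by the original CLIP; a tiny standalone comparison against PyTorch's exact GELU (illustrative only):

import torch
from torch.nn import functional as F

x = torch.linspace(-3.0, 3.0, 7)
quick_gelu = x * torch.sigmoid(1.702 * x)   # the approximation used in CLIPLayer above
exact_gelu = F.gelu(x)
print("max abs difference:", (quick_gelu - exact_gelu).abs().max().item())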
ddpm.py ADDED
@@ -0,0 +1,119 @@
import torch
import numpy as np

class DDPMSampler:

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        # Params "beta_start" and "beta_end" taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
        # For the naming conventions, refer to the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf)
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        self.generator = generator

        self.num_train_timesteps = num_training_steps
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())

    def set_inference_timesteps(self, num_inference_steps=50):
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
        return prev_t

    def _get_variance(self, timestep: int) -> torch.Tensor:
        prev_t = self._get_previous_timestep(timestep)

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        # For t > 0, compute the predicted variance βt (see formulas (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get the previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t

        # we always take the log of variance, so clamp it to ensure it's not 0
        variance = torch.clamp(variance, min=1e-20)

        return variance

    def set_strength(self, strength=1):
        """
        Set how much noise to add to the input image.
        More noise (strength ~ 1) means that the output will be further from the input image.
        Less noise (strength ~ 0) means that the output will be closer to the input image.
        """
        # start_step is the number of noise levels to skip
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        t = timestep
        prev_t = self._get_previous_timestep(t)

        # 1. Compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. Compute the predicted original sample from the predicted noise, also called
        # "predicted x_0" in formula (15) of https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 3. Compute the coefficients for pred_original_sample x_0 and the current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 4. Compute the predicted previous sample mean µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        # 5. Add noise
        variance = 0
        if t > 0:
            device = model_output.device
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            # Compute the variance as per formula (7) from https://arxiv.org/pdf/2006.11239.pdf
            variance = (self._get_variance(t) ** 0.5) * noise

        # A sample from N(mu, sigma^2) can be obtained as X = mu + sigma * N(0, 1);
        # the variable "variance" is already multiplied by the noise N(0, 1)
        pred_prev_sample = pred_prev_sample + variance

        return pred_prev_sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
        # Because a sample from N(mu, sigma^2) can be obtained as X = mu + sigma * N(0, 1),
        # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
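A short sketch of how pipeline.py drives this sampler: choose the inference timesteps, noise a clean latent with add_noise, then take one reverse step (the latent and the "predicted noise" below are random placeholders for the UNet output):

import torch
from ddpm import DDPMSampler

generator = torch.Generator().manual_seed(42)
sampler = DDPMSampler(generator)
sampler.set_inference_timesteps(50)              # 50 steps spread over the 1000 training steps

clean_latents = torch.randn(1, 4, 64, 64, generator=generator)
t = sampler.timesteps[0]                         # the noisiest timestep of the schedule

# Forward process: sample from q(x_t | x_0)
noisy_latents = sampler.add_noise(clean_latents, t)

# One reverse step, with random values standing in for the UNet's predicted noise
predicted_noise = torch.randn_like(noisy_latents)
prev_latents = sampler.step(int(t), noisy_latents, predicted_noise)
print(noisy_latents.shape, prev_latents.shape)   # both (1, 4, 64, 64)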
decoder.py ADDED
@@ -0,0 +1,134 @@
import torch
from torch import nn
from attention import SelfAttention
from torch.nn import functional as F


class VAE_ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.grpnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.grpnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        residue = x

        x = self.grpnorm_1(x)
        x = F.silu(x)

        x = self.conv_1(x)

        x = self.grpnorm_2(x)
        x = F.silu(x)

        x = self.conv_2(x)

        return x + self.residual_layer(residue)

class VAE_AttentionBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.grpnorm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (Batch_Size, Features, Height, Width)
        residue = x

        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
        x = self.grpnorm(x)
        n, c, h, w = x.shape

        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * Width)
        x = x.view((n, c, h * w))

        # (Batch_Size, Features, Height * Width) -> (Batch_Size, Height * Width, Features)
        x = x.transpose(-1, -2)

        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x = self.attention(x)

        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Features, Height * Width)
        x = x.transpose(-1, -2)

        # (Batch_Size, Features, Height, Width)
        x = x.view((n, c, h, w))

        x += residue

        return x

class VAE_Decoder(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(4, 4, kernel_size=1, padding=0),

            nn.Conv2d(4, 512, kernel_size=3, padding=1),

            VAE_ResidualBlock(512, 512),

            VAE_AttentionBlock(512),

            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            nn.Upsample(scale_factor=2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),

            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            nn.Upsample(scale_factor=2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),

            VAE_ResidualBlock(512, 256),
            VAE_ResidualBlock(256, 256),
            VAE_ResidualBlock(256, 256),

            nn.Upsample(scale_factor=2),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),

            VAE_ResidualBlock(256, 128),
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),

            nn.GroupNorm(32, 128),

            nn.SiLU(),

            nn.Conv2d(128, 3, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # Undo the latent scaling applied by the encoder
        x /= 0.18215

        for module in self:
            x = module(x)

        return x

if __name__ == "__main__":
    model = VAE_Decoder()
    model.eval()

    # Create a dummy latent tensor: (batch_size=1, channels=4, height=8, width=8)
    x = torch.randn(1, 4, 8, 8)

    with torch.no_grad():
        output = model(x)

    print("Input shape :", x.shape)
    print("Output shape:", output.shape)
diffusion.py ADDED
@@ -0,0 +1,297 @@
import torch
from torch import nn
from torch.nn import functional as F
from attention import SelfAttention, CrossAttention

class TimeEmbedding(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.linear_1 = nn.Linear(n_embed, 4 * n_embed)
        self.linear_2 = nn.Linear(4 * n_embed, 4 * n_embed)

    def forward(self, x):
        x = self.linear_1(x)
        x = F.silu(x)
        x = self.linear_2(x)
        return x

class UNET_ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_time=1280):
        super().__init__()
        self.grpnorm_feature = nn.GroupNorm(32, in_channels)
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.linear_time = nn.Linear(n_time, out_channels)

        self.grpnorm_merged = nn.GroupNorm(32, out_channels)
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, feature, time):
        residue = feature

        feature = self.grpnorm_feature(feature)
        feature = F.silu(feature)

        feature = self.conv_feature(feature)

        time = F.silu(time)
        time = self.linear_time(time)

        merged = feature + time.unsqueeze(-1).unsqueeze(-1)

        merged = self.grpnorm_merged(merged)
        merged = F.silu(merged)
        merged = self.conv_merged(merged)

        return merged + self.residual_layer(residue)

class UNET_AttentionBlock(nn.Module):
    def __init__(self, n_head, n_embed, d_context=768):
        super().__init__()

        channels = n_head * n_embed

        self.grpnorm = nn.GroupNorm(32, channels, eps=1e-6)
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

        self.layernorm_1 = nn.LayerNorm(channels)
        self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)

        self.layernorm_2 = nn.LayerNorm(channels)
        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)

        self.layernorm_3 = nn.LayerNorm(channels)

        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)

        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self, x, context):
        residue_long = x

        x = self.grpnorm(x)
        x = self.conv_input(x)

        n, c, h, w = x.shape

        x = x.view((n, c, h * w))

        x = x.transpose(-1, -2)
        residue_short = x

        x = self.layernorm_1(x)
        x = self.attention_1(x)

        x += residue_short

        residue_short = x

        x = self.layernorm_2(x)
        x = self.attention_2(x, context)

        x += residue_short

        residue_short = x

        x = self.layernorm_3(x)
        # GeGLU: project to twice the hidden size, then gate one half with GELU of the other
        x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)

        x = x * F.gelu(gate)

        x = self.linear_geglu_2(x)

        x += residue_short
        x = x.transpose(-1, -2)

        x = x.view((n, c, h, w))

        return self.conv_output(x) + residue_long

class Upsample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)

# SwitchSequential does not override __init__, so the layers passed to it are handled by
# nn.Sequential's constructor; only forward() is customized so that the extra context/time
# arguments are routed to the layer types that need them.
class SwitchSequential(nn.Sequential):
    def forward(self, x, context, time):
        for layer in self:
            if isinstance(layer, UNET_AttentionBlock):
                x = layer(x, context)
            elif isinstance(layer, UNET_ResidualBlock):
                x = layer(x, time)
            else:
                x = layer(x)
        return x

class UNET(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoders = nn.ModuleList([
            # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),

            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),

            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 16, Width / 16)
            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),

            # (Batch_Size, 320, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),

            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),

            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 32, Width / 32)
            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),

            # (Batch_Size, 640, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),

            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),

            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),

            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
        ])

        self.bottleneck = SwitchSequential(
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_ResidualBlock(1280, 1280),
            UNET_AttentionBlock(8, 160),
            UNET_ResidualBlock(1280, 1280),
        )

        self.decoders = nn.ModuleList([
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),

            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),

            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),

            # (Batch_Size, 1920, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),

            # (Batch_Size, 1920, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),

            # (Batch_Size, 1280, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),

            # (Batch_Size, 960, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),

            # (Batch_Size, 960, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),

            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
        ])

    def forward(self, x, context, time):
        # x: (Batch_Size, 4, Height / 8, Width / 8)
        # context: (Batch_Size, Seq_Len, Dim)
        # time: (1, 1280)

        skip_connections = []
        for layers in self.encoders:
            x = layers(x, context, time)
            skip_connections.append(x)

        x = self.bottleneck(x, context, time)

        for layers in self.decoders:
            # Since we always concatenate with the skip connection of the encoder, the number of features increases before being sent to the decoder's layer
            x = torch.cat((x, skip_connections.pop()), dim=1)
            x = layers(x, context, time)

        return x


class UNET_OutputLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.grpnorm = nn.GroupNorm(32, in_channels)

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.grpnorm(x)
        x = F.silu(x)

        x = self.conv(x)
        return x

class Diffusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET()
        self.final = UNET_OutputLayer(320, 4)

    def forward(self, latent, context, time):
        time = self.time_embedding(time)

        output = self.unet(latent, context, time)

        output = self.final(output)

        return output

if __name__ == "__main__":
    # Dummy inputs
    batch_size = 10
    height = 64
    width = 64
    in_channels = 4
    context_dim = 768
    seq_len = 77

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create model and move to device
    model = Diffusion().to(device)

    # Random input tensor with 4 channels
    x = torch.randn(batch_size, in_channels, height, width).to(device)

    print('Input shape to UNET: ', x.shape)

    # Time embedding (e.g., timestep from a diffusion schedule)
    t = torch.randn(batch_size, 320).to(device)

    print('Time Embedding shape to UNET: ', t.shape)

    # Context for cross attention (e.g., text embedding from CLIP or transformer)
    context = torch.randn(batch_size, seq_len, context_dim).to(device)

    print('context shape to UNET: ', context.shape)

    # Forward pass
    with torch.no_grad():
        output = model(x, context, t)
    print(output)

    print("Output shape of UNET:", output.shape)
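The linear_geglu_1 / chunk / F.gelu(gate) sequence in UNET_AttentionBlock is the GeGLU feed-forward variant; isolated with toy sizes it looks like this (a standalone sketch, not tied to the checkpoint):

import torch
from torch import nn
from torch.nn import functional as F

channels = 320
proj_in = nn.Linear(channels, 4 * channels * 2)   # value and gate halves in one projection
proj_out = nn.Linear(4 * channels, channels)

x = torch.randn(1, 64, channels)                  # (Batch, Seq_Len, Channels)
value, gate = proj_in(x).chunk(2, dim=-1)         # each half is (1, 64, 4 * channels)
hidden = value * F.gelu(gate)                     # GeGLU gating
out = proj_out(hidden)
print(out.shape)                                  # (1, 64, 320)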
dog.jpg ADDED
encoder.py ADDED
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from decoder import VAE_AttentionBlock, VAE_ResidualBlock

class VAE_Encoder(nn.Sequential):
    def __init__(self):
        super().__init__(
            # (Batch_Size, Channel, Height, Width) -> (Batch_Size, 128, Height, Width)
            nn.Conv2d(3, 128, kernel_size=3, padding=1),

            # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),

            # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height/2, Width/2)
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),

            # (Batch_Size, 128, Height/2, Width/2) -> (Batch_Size, 256, Height/2, Width/2)
            VAE_ResidualBlock(128, 256),
            # (Batch_Size, 256, Height/2, Width/2) -> (Batch_Size, 256, Height/2, Width/2)
            VAE_ResidualBlock(256, 256),

            # (Batch_Size, 256, Height/2, Width/2) -> (Batch_Size, 256, Height/4, Width/4)
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),

            # (Batch_Size, 256, Height/4, Width/4) -> (Batch_Size, 512, Height/4, Width/4)
            VAE_ResidualBlock(256, 512),
            # (Batch_Size, 512, Height/4, Width/4) -> (Batch_Size, 512, Height/4, Width/4)
            VAE_ResidualBlock(512, 512),

            # (Batch_Size, 512, Height/4, Width/4) -> (Batch_Size, 512, Height/8, Width/8)
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),

            # (Batch_Size, 512, Height/8, Width/8) -> (Batch_Size, 512, Height/8, Width/8)
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            # (Batch_Size, 512, Height/8, Width/8) -> (Batch_Size, 512, Height/8, Width/8)
            VAE_AttentionBlock(512),

            VAE_ResidualBlock(512, 512),

            nn.GroupNorm(32, 512),

            nn.SiLU(),

            # (Batch_Size, 512, Height/8, Width/8) -> (Batch_Size, 8, Height/8, Width/8)
            nn.Conv2d(512, 8, kernel_size=3, padding=1),

            # (Batch_Size, 8, Height/8, Width/8) -> (Batch_Size, 8, Height/8, Width/8)
            nn.Conv2d(8, 8, kernel_size=1, padding=0)
        )

    def forward(self, x: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        for module in self:
            # The downsampling convolutions use stride 2 with no padding, so pad asymmetrically
            # (right and bottom only) to keep the output at exactly Height/2 x Width/2
            if getattr(module, 'stride', None) == (2, 2):
                x = F.pad(x, (0, 1, 0, 1))

            x = module(x)

        # (Batch_Size, 8, Height/8, Width/8) -> two tensors of shape (Batch_Size, 4, Height/8, Width/8)
        mean, log_var = torch.chunk(x, 2, dim=1)

        log_var = torch.clamp(log_var, -30, 20)
        var = log_var.exp()
        stdev = var.sqrt()

        # Reparameterization trick: Z = mean + stdev * noise
        Z = mean + stdev * noise

        # Scale the latent by the constant used in Stable Diffusion
        Z *= 0.18215

        # print('-'*100)
        # print('Z shape: ', Z.shape)
        # print('-'*100)

        return Z

if __name__ == "__main__":
    model = VAE_Encoder()
    model.eval()

    # Create a dummy input tensor: (batch_size=1, channels=3, height=64, width=64)
    x = torch.randn(1, 3, 64, 64)
    noise = torch.randn(1, 4, 8, 8)  # Match the latent shape (Z)

    with torch.no_grad():
        output = model(x, noise)

    print("Input shape :", x.shape)
    print("Output shape:", output.shape)
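Because the encoder halves the spatial resolution three times and the decoder upsamples three times, a 512x512 image maps to a (4, 64, 64) latent and back; a roundtrip shape check with random weights (a sketch, somewhat heavy on CPU):

import torch
from encoder import VAE_Encoder
from decoder import VAE_Decoder

encoder = VAE_Encoder().eval()
decoder = VAE_Decoder().eval()

image = torch.randn(1, 3, 512, 512)            # (Batch, Channel, Height, Width)
noise = torch.randn(1, 4, 512 // 8, 512 // 8)  # matches the latent resolution

with torch.no_grad():
    latents = encoder(image, noise)            # (1, 4, 64, 64)
    reconstruction = decoder(latents)          # (1, 3, 512, 512)

print("Latents:", latents.shape)
print("Reconstruction:", reconstruction.shape)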
interface.py ADDED
@@ -0,0 +1,151 @@
import gradio as gr
import torch
from PIL import Image
from transformers import CLIPTokenizer

# Import your existing model and pipeline modules
import model
import pipeline

# Device Configuration
ALLOW_CUDA = True
ALLOW_MPS = False

def determine_device():
    if torch.cuda.is_available() and ALLOW_CUDA:
        return "cuda"
    elif (torch.backends.mps.is_built() or torch.backends.mps.is_available()) and ALLOW_MPS:
        return "mps"
    return "cpu"

DEVICE = determine_device()
print(f"Using device: {DEVICE}")

# Load tokenizer and models
tokenizer = CLIPTokenizer("vocab.json", merges_file="merges.txt")
model_file = "inkpunk-diffusion-v1.ckpt"
models = model.preload_models_from_standard_weights(model_file, DEVICE)
# models=None

def generate_image(
    prompt,
    uncond_prompt="",
    do_cfg=True,
    cfg_scale=8,
    sampler="ddpm",
    num_inference_steps=50,
    seed=42,
    input_image=None,
    strength=1.0
):
    """
    Generate an image using the Stable Diffusion pipeline

    Args:
    - prompt (str): Text description of the image to generate
    - uncond_prompt (str, optional): Negative prompt to guide generation
    - do_cfg (bool): Whether to use classifier-free guidance
    - cfg_scale (float): Classifier-free guidance scale
    - sampler (str): Sampling method
    - num_inference_steps (int): Number of denoising steps
    - seed (int): Random seed for reproducibility
    - input_image (PIL.Image, optional): Input image for image-to-image generation
    - strength (float): Strength of image transformation (0-1)

    Returns:
    - PIL.Image: Generated image
    """
    try:
        # Ensure input_image is None if not provided
        if input_image is None:
            strength = 1.0

        # Generate the image
        output_image = pipeline.generate(
            prompt=prompt,
            uncond_prompt=uncond_prompt,
            input_image=input_image,
            strength=strength,
            do_cfg=do_cfg,
            cfg_scale=cfg_scale,
            sampler_name=sampler,
            n_inference_steps=num_inference_steps,
            seed=seed,
            models=models,
            device=DEVICE,
            idle_device="cuda",
            tokenizer=tokenizer,
        )

        # Convert numpy array to PIL Image
        return Image.fromarray(output_image)

    except Exception as e:
        print(f"Error generating image: {e}")
        return None

def launch_gradio_interface():
    """
    Create and launch Gradio interface for Stable Diffusion
    """
    with gr.Blocks(title="Stable Diffusion Image Generator") as demo:
        gr.Markdown("# 🎨 Stable Diffusion Image Generator")

        with gr.Row():
            with gr.Column():
                # Text Inputs
                prompt = gr.Textbox(label="Prompt",
                                    placeholder="Describe the image you want to generate...")
                uncond_prompt = gr.Textbox(label="Negative Prompt (Optional)",
                                           placeholder="Describe what you don't want in the image...")

                # Generation Parameters
                with gr.Accordion("Advanced Settings", open=False):
                    do_cfg = gr.Checkbox(label="Use Classifier-Free Guidance", value=True)
                    cfg_scale = gr.Slider(minimum=1, maximum=14, value=8, label="CFG Scale")
                    sampler = gr.Dropdown(
                        choices=["ddpm", "ddim", "pndm"],  # Add more samplers if available
                        value="ddpm",
                        label="Sampling Method"
                    )
                    num_inference_steps = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=50,
                        label="Number of Inference Steps"
                    )
                    seed = gr.Number(value=42, label="Random Seed")

                # Image-to-Image Section
                with gr.Accordion("Image-to-Image", open=False):
                    input_image = gr.Image(type="pil", label="Input Image (Optional)")
                    strength = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.8,
                        label="Image Transformation Strength"
                    )

                # Generate Button
                generate_btn = gr.Button("Generate Image", variant="primary")

        with gr.Row():
            # Output Image
            output_image = gr.Image(label="Generated Image")

        # Connect Button to Generation Function
        generate_btn.click(
            fn=generate_image,
            inputs=[
                prompt, uncond_prompt, do_cfg, cfg_scale,
                sampler, num_inference_steps, seed,
                input_image, strength
            ],
            outputs=output_image
        )

    # Launch the interface
    demo.launch(server_name="0.0.0.0", server_port=7860)

if __name__ == "__main__":
    launch_gradio_interface()
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.py ADDED
@@ -0,0 +1,28 @@
from clip import CLIP
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from diffusion import Diffusion

import model_converter

def preload_models_from_standard_weights(ckpt_path, device):
    state_dict = model_converter.load_from_standard_weights(ckpt_path, device)

    encoder = VAE_Encoder().to(device)
    encoder.load_state_dict(state_dict['encoder'], strict=True)

    decoder = VAE_Decoder().to(device)
    decoder.load_state_dict(state_dict['decoder'], strict=True)

    diffusion = Diffusion().to(device)
    diffusion.load_state_dict(state_dict['diffusion'], strict=True)

    clip = CLIP().to(device)
    clip.load_state_dict(state_dict['clip'], strict=True)

    return {
        'clip': clip,
        'encoder': encoder,
        'decoder': decoder,
        'diffusion': diffusion,
    }
model_converter.py ADDED
The diff for this file is too large to render. See raw diff
 
pipeline.py ADDED
@@ -0,0 +1,174 @@
import torch
import numpy as np
from tqdm import tqdm
from ddpm import DDPMSampler

WIDTH = 512
HEIGHT = 512
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8

def generate(
    prompt,
    uncond_prompt=None,
    input_image=None,
    strength=0.8,
    do_cfg=True,
    cfg_scale=7.5,
    sampler_name="ddpm",
    n_inference_steps=50,
    models={},
    seed=None,
    device=None,
    idle_device=None,
    tokenizer=None,
):
    with torch.no_grad():
        if not 0 < strength <= 1:
            raise ValueError("strength must be between 0 and 1")

        if idle_device:
            to_idle = lambda x: x.to(idle_device)
        else:
            to_idle = lambda x: x

        # Initialize random number generator according to the seed specified
        generator = torch.Generator(device=device)
        if seed is None:
            generator.seed()
        else:
            generator.manual_seed(seed)

        clip = models["clip"]
        clip.to(device)

        if do_cfg:
            # Convert into a list of length Seq_Len=77
            cond_tokens = tokenizer.batch_encode_plus(
                [prompt], padding="max_length", max_length=77
            ).input_ids
            # (Batch_Size, Seq_Len)
            cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
            # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
            cond_context = clip(cond_tokens)
            # Convert into a list of length Seq_Len=77
            uncond_tokens = tokenizer.batch_encode_plus(
                [uncond_prompt], padding="max_length", max_length=77
            ).input_ids
            # (Batch_Size, Seq_Len)
            uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
            # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
            uncond_context = clip(uncond_tokens)
            # (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (2 * Batch_Size, Seq_Len, Dim)
            context = torch.cat([cond_context, uncond_context])
        else:
            # Convert into a list of length Seq_Len=77
            tokens = tokenizer.batch_encode_plus(
                [prompt], padding="max_length", max_length=77
            ).input_ids
            # (Batch_Size, Seq_Len)
            tokens = torch.tensor(tokens, dtype=torch.long, device=device)
            # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
            context = clip(tokens)
        to_idle(clip)

        if sampler_name == "ddpm":
            sampler = DDPMSampler(generator)
            sampler.set_inference_timesteps(n_inference_steps)
        else:
            raise ValueError(f"Unknown sampler value {sampler_name}.")

        latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)

        if input_image:
            encoder = models["encoder"]
            encoder.to(device)

            input_image_tensor = input_image.resize((WIDTH, HEIGHT))

            # (Height, Width, Channel)
            input_image_tensor = np.array(input_image_tensor)

            # (Height, Width, Channel) -> (Height, Width, Channel)
            input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
            input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))

            # (Height, Width, Channel) -> (Batch_Size, Height, Width, Channel)
            input_image_tensor = input_image_tensor.unsqueeze(0)

            # (Batch_Size, Height, Width, Channel) -> (Batch_Size, Channel, Height, Width)
            input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)

            # (Batch_Size, 4, Latents_Height, Latents_Width)
            encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
            latents = encoder(input_image_tensor, encoder_noise)

            # Add noise to the latents (the encoded input image)
            # (Batch_Size, 4, Latents_Height, Latents_Width)
            sampler.set_strength(strength=strength)
            latents = sampler.add_noise(latents, sampler.timesteps[0])
            to_idle(encoder)
        else:
            # (Batch_Size, 4, Latents_Height, Latents_Width)
            latents = torch.randn(latents_shape, generator=generator, device=device)

        diffusion = models["diffusion"]
        diffusion.to(device)

        timesteps = tqdm(sampler.timesteps)
        for i, timestep in enumerate(timesteps):
            # (1, 320)
            time_embedding = get_time_embedding(timestep).to(device)

            # (Batch_Size, 4, Latents_Height, Latents_Width)
            model_input = latents

            if do_cfg:
                # (Batch_Size, 4, Latents_Height, Latents_Width) -> (2 * Batch_Size, 4, Latents_Height, Latents_Width)
                model_input = model_input.repeat(2, 1, 1, 1)

            # model_output is the predicted noise
            # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
            model_output = diffusion(model_input, context, time_embedding)

            if do_cfg:
                output_cond, output_uncond = model_output.chunk(2)
                model_output = cfg_scale * (output_cond - output_uncond) + output_uncond

            # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
            latents = sampler.step(timestep, latents, model_output)

        to_idle(diffusion)

        decoder = models["decoder"]
        decoder.to(device)
        # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 3, Height, Width)
        images = decoder(latents)

        to_idle(decoder)

        images = rescale(images, (-1, 1), (0, 255), clamp=True)
        # (Batch_Size, Channel, Height, Width) -> (Batch_Size, Height, Width, Channel)
        images = images.permute(0, 2, 3, 1)
        images = images.to("cpu", torch.uint8).numpy()
        return images[0]

def rescale(x, old_range, new_range, clamp=False):
    old_min, old_max = old_range
    new_min, new_max = new_range
    x -= old_min
    x *= (new_max - new_min) / (old_max - old_min)
    x += new_min
    if clamp:
        x = x.clamp(new_min, new_max)
    return x

def get_time_embedding(timestep):
    # Shape: (160,)
    freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
    # Shape: (1, 160)
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
    # Shape: (1, 160 * 2)
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
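The two helpers at the bottom are easy to sanity-check on their own: get_time_embedding turns a scalar timestep into the (1, 320) sinusoidal vector that Diffusion's TimeEmbedding expects, and rescale maps value ranges back and forth (a small illustrative sketch):

import torch
from pipeline import get_time_embedding, rescale

emb = get_time_embedding(980)
print(emb.shape)                                            # torch.Size([1, 320])

# Pixel range to model range and back
print(rescale(torch.tensor([0.0, 127.5, 255.0]), (0, 255), (-1, 1)))   # maps to [-1, 0, 1]
print(rescale(torch.tensor([-1.0, 0.0, 1.0]), (-1, 1), (0, 255)))      # maps to [0, 127.5, 255]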
requirements.txt ADDED
@@ -0,0 +1,115 @@
aiohappyeyeballs==2.6.1
aiohttp==3.11.18
aiosignal==1.3.2
annotated-types==0.7.0
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
attrs==25.3.0
certifi==2025.4.26
charset-normalizer==3.4.2
click==8.2.0
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
contourpy==1.3.2
cycler==0.12.1
datasets==3.6.0
debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1744321233760/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
dill==0.3.8
docker-pycreds==0.4.0
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
filelock==3.18.0
fonttools==4.58.0
frozenlist==1.6.0
fsspec==2025.3.0
gitdb==4.0.12
GitPython==3.1.44
hf-xet==1.1.0
huggingface-hub==0.31.1
idna==3.10
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1737420181517/work
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1745672166/work
ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work
ipywidgets==8.1.7
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
Jinja2==3.1.6
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1727163409502/work
jupyterlab_widgets==3.0.15
kiwisolver==1.4.8
lightning==2.5.1.post0
lightning-utilities==0.14.3
MarkupSafe==3.0.2
matplotlib==3.10.3
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
mpmath==1.3.0
multidict==6.4.3
multiprocess==0.70.16
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
networkx==3.4.2
numpy==2.2.5
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
packaging==24.2
pandas==2.2.3
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
pillow==11.2.1
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_platformdirs_1746710438/work
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
propcache==0.3.1
protobuf==6.30.2
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1740663149797/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
pyarrow==20.0.0
pydantic==2.11.4
pydantic_core==2.33.2
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1736243443484/work
pyparsing==3.2.3
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work
pytorch-lightning==2.5.1.post0
pytz==2025.2
PyYAML==6.0.2
pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1743831245578/work
regex==2024.11.6
requests==2.32.3
safetensors==0.5.3
sentry-sdk==2.27.0
setproctitle==1.3.6
six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work
smmap==5.0.2
stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
sympy==1.14.0
tokenizers==0.21.1
torch==2.7.0
torchmetrics==1.7.1
torchvision==0.22.0
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1732615904614/work
tqdm==4.67.1
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
transformers==4.51.3
triton==3.3.0
typing-inspection==0.4.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1744302253/work
tzdata==2025.2
urllib3==2.4.0
wandb==0.19.11
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work
widgetsnbextension==4.0.14
xxhash==3.5.0
yarl==1.20.0
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work
test.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
vocab.json ADDED
The diff for this file is too large to render. See raw diff