Commit · b993f12
implementation of stable diffusion from scratch
Browse files
- .gitignore +53 -0
- README.md +4 -0
- attention.py +117 -0
- clip.py +86 -0
- ddpm.py +119 -0
- decoder.py +134 -0
- diffusion.py +297 -0
- dog.jpg +0 -0
- encoder.py +91 -0
- interface.py +151 -0
- merges.txt +0 -0
- model.py +28 -0
- model_converter.py +0 -0
- pipeline.py +174 -0
- requirements.txt +115 -0
- test.ipynb +0 -0
- vocab.json +0 -0
.gitignore
ADDED
@@ -0,0 +1,53 @@
*inkpunk-diffusion-v1.ckpt

# Byte-compiled / optimized / DLL files
__pycache__/
**/__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# Rope
.ropeproject

# Django stuff:
*.log
*.pot

# Sphinx documentation
docs/_build/
README.md
ADDED
@@ -0,0 +1,4 @@
# stable-diffusion

<!-- 1. Download `vocab.json` and `merges.txt` from https://huggingface.co/CompVis/stable-diffusion-v1-4/tree/main/tokenizer and save them in the `data` folder
2. Download `inkpunk-diffusion-v1.ckpt` from https://huggingface.co/Envvi/Inkpunk-Diffusion/tree/main and save it in the `data` folder -->
attention.py
ADDED
@@ -0,0 +1,117 @@
import torch
from torch import nn
from torch.nn import functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # This combines the Wq, Wk and Wv matrices into one matrix
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)

        # This one represents the Wo matrix
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)

        self.n_heads = n_heads
        self.d_head = d_embed // self.n_heads

    def forward(self, x, causal_mask=False):
        # x: (batch_size, seq_len, dim)
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

        # Apply the in_proj to get the queries, keys, and values all at once
        # (batch_size, seq_len, dim) -> (batch_size, seq_len, 3 * dim)
        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        # Reshape to (batch_size, seq_len, n_heads, d_head)
        q = q.view(interim_shape)
        k = k.view(interim_shape)
        v = v.view(interim_shape)

        # Transpose for attention dot product: (batch_size, n_heads, seq_len, d_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # (batch_size, n_heads, seq_len, d_head) @ (batch_size, n_heads, d_head, seq_len) -> (batch_size, n_heads, seq_len, seq_len)
        attention_weights = q @ k.transpose(-1, -2)

        # Scaling by sqrt(d_head)
        attention_weights = attention_weights / math.sqrt(self.d_head)

        # Causal mask to prevent attending to future tokens
        if causal_mask:
            mask = torch.ones_like(attention_weights, dtype=torch.bool).triu(1)
            attention_weights.masked_fill_(mask, -torch.inf)

        # Apply softmax to get attention probabilities
        attention_weights = F.softmax(attention_weights, dim=-1)

        # Apply attention weights: (batch_size, n_heads, seq_len, seq_len) @ (batch_size, n_heads, seq_len, d_head) -> (batch_size, n_heads, seq_len, d_head)
        output = attention_weights @ v

        # Transpose back: (batch_size, seq_len, n_heads, d_head)
        output = output.transpose(1, 2)

        # Reshape back to (batch_size, seq_len, dim)
        output = output.reshape(input_shape)

        # Apply output projection
        output = self.out_proj(output)

        return output

class CrossAttention(nn.Module):
    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, y):
        # x (latent): (Batch_Size, Seq_Len_Q, Dim_Q)
        # y (context): (Batch_Size, Seq_Len_KV, Dim_KV) = (Batch_Size, 77, 768)

        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        # Divide each embedding of Q into multiple heads such that d_head * n_heads = Dim_Q
        interim_shape = (batch_size, -1, self.n_heads, self.d_head)

        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        q = self.q_proj(x)
        # (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
        k = self.k_proj(y)
        v = self.v_proj(y)

        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
        q = q.view(interim_shape).transpose(1, 2)
        # (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) @ (Batch_Size, H, Dim_Q / H, Seq_Len_KV) -> (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
        weight = q @ k.transpose(-1, -2)

        # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
        weight /= math.sqrt(self.d_head)
        weight = F.softmax(weight, dim=-1)

        # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV) @ (Batch_Size, H, Seq_Len_KV, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
        output = weight @ v

        # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H)
        output = output.transpose(1, 2).contiguous()

        # (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        output = output.view(input_shape)

        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        output = self.out_proj(output)

        # (Batch_Size, Seq_Len_Q, Dim_Q)
        return output
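Unlike the other modules in this commit, attention.py ships without a `__main__` harness. The sketch below is an illustrative shape check in the same spirit as the harnesses elsewhere; the batch size, sequence lengths, and widths (320-dim latent tokens, a 77x768 CLIP-style context) are assumed example values, not configuration taken from this repository.

# Illustrative sketch (not part of the commit): shape check for SelfAttention and CrossAttention
import torch
from attention import SelfAttention, CrossAttention

self_attn = SelfAttention(n_heads=8, d_embed=320)
cross_attn = CrossAttention(n_heads=8, d_embed=320, d_cross=768)

x = torch.randn(1, 4096, 320)      # e.g. a flattened 64x64 latent feature map (assumed size)
context = torch.randn(1, 77, 768)  # CLIP-style context: (Batch_Size, Seq_Len, Dim)

with torch.no_grad():
    out_self = self_attn(x, causal_mask=False)  # -> (1, 4096, 320)
    out_cross = cross_attn(x, context)          # -> (1, 4096, 320)

print(out_self.shape, out_cross.shape)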
clip.py
ADDED
@@ -0,0 +1,86 @@
import torch
from torch import nn
from torch.nn import functional as F

from attention import SelfAttention

class CLIPEmbedding(nn.Module):
    def __init__(self, n_vocab, n_embed, n_token):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_embed)
        self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embed)))

    def forward(self, tokens: torch.Tensor):
        x = self.token_embedding(tokens)
        x += self.position_embedding

        return x

class CLIPLayer(nn.Module):
    def __init__(self, n_head, n_embed):
        super().__init__()
        self.layernorm_1 = nn.LayerNorm(n_embed)
        self.attention = SelfAttention(n_head, n_embed)
        self.layernorm_2 = nn.LayerNorm(n_embed)

        self.linear_1 = nn.Linear(n_embed, 4 * n_embed)
        self.linear_2 = nn.Linear(4 * n_embed, n_embed)

    def forward(self, x):
        residue = x

        x = self.layernorm_1(x)

        x = self.attention(x, causal_mask=True)

        x += residue

        residue = x

        x = self.layernorm_2(x)

        x = self.linear_1(x)

        # QuickGELU activation: x * sigmoid(1.702 * x)
        x = x * torch.sigmoid(1.702 * x)

        x = self.linear_2(x)
        x += residue

        return x

class CLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = CLIPEmbedding(49408, 768, 77)
        self.layers = nn.ModuleList([
            CLIPLayer(12, 768) for i in range(12)
        ])

        self.layernorm = nn.LayerNorm(768)

    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        tokens = tokens.type(torch.long)

        state = self.embedding(tokens)

        for layer in self.layers:
            state = layer(state)

        output = self.layernorm(state)

        return output

if __name__ == "__main__":
    dummy_tokens = torch.randint(0, 49408, (1, 77))  # (Batch_Size, Seq_Len)

    # Instantiate the model
    model = CLIP()

    # Forward pass
    with torch.no_grad():  # no need to track gradients for testing
        output = model(dummy_tokens)

    # Print the output shape
    # Output shape: torch.Size([1, 77, 768])
    print("Output shape:", output.shape)
ddpm.py
ADDED
@@ -0,0 +1,119 @@
import torch
import numpy as np

class DDPMSampler:

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        # Params "beta_start" and "beta_end" taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
        # For the naming conventions, refer to the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf)
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        self.generator = generator

        self.num_train_timesteps = num_training_steps
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())

    def set_inference_timesteps(self, num_inference_steps=50):
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
        return prev_t

    def _get_variance(self, timestep: int) -> torch.Tensor:
        prev_t = self._get_previous_timestep(timestep)

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t

        # we always take the log of variance, so clamp it to ensure it's not 0
        variance = torch.clamp(variance, min=1e-20)

        return variance

    def set_strength(self, strength=1):
        """
        Set how much noise to add to the input image.
        More noise (strength ~ 1) means that the output will be further from the input image.
        Less noise (strength ~ 0) means that the output will be closer to the input image.
        """
        # start_step is the number of noise levels to skip
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        t = timestep
        prev_t = self._get_previous_timestep(t)

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        # 6. Add noise
        variance = 0
        if t > 0:
            device = model_output.device
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            # Compute the variance as per formula (7) from https://arxiv.org/pdf/2006.11239.pdf
            variance = (self._get_variance(t) ** 0.5) * noise

        # sample from N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
        # the variable "variance" is already multiplied by the noise N(0, 1)
        pred_prev_sample = pred_prev_sample + variance

        return pred_prev_sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
        # Because N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
        # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
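For orientation, here is a minimal sketch of how the sampler is driven at inference time, using random tensors as a stand-in for the UNet's noise prediction; the real wiring (CLIP context, classifier-free guidance, the Diffusion model) lives in pipeline.py. The step count and latent shape below are illustrative.

# Illustrative sketch (not part of the commit): denoising loop with a dummy noise predictor
import torch
from ddpm import DDPMSampler

generator = torch.Generator().manual_seed(42)
sampler = DDPMSampler(generator)
sampler.set_inference_timesteps(10)  # fewer steps than the default 50, just for the demo

latents = torch.randn((1, 4, 64, 64), generator=generator)

for timestep in sampler.timesteps:
    # Stand-in for: model_output = diffusion(latents, context, time_embedding)
    predicted_noise = torch.randn_like(latents)
    latents = sampler.step(int(timestep), latents, predicted_noise)

print(latents.shape)  # torch.Size([1, 4, 64, 64])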
decoder.py
ADDED
@@ -0,0 +1,134 @@
import torch
from torch import nn
from attention import SelfAttention
from torch.nn import functional as F


class VAE_ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.grpnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.grpnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        residue = x

        x = self.grpnorm_1(x)
        x = F.silu(x)

        x = self.conv_1(x)

        x = self.grpnorm_2(x)
        x = F.silu(x)

        x = self.conv_2(x)

        return x + self.residual_layer(residue)

class VAE_AttentionBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.grpnorm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (Batch_Size, Features, Height, Width)
        residue = x

        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
        x = self.grpnorm(x)
        n, c, h, w = x.shape

        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * Width)
        x = x.view((n, c, h * w))

        # (Batch_Size, Features, Height * Width) -> (Batch_Size, Height * Width, Features)
        x = x.transpose(-1, -2)

        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x = self.attention(x)

        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Features, Height * Width)
        x = x.transpose(-1, -2)

        # (Batch_Size, Features, Height, Width)
        x = x.view((n, c, h, w))

        x += residue

        return x

class VAE_Decoder(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(4, 4, kernel_size=1, padding=0),

            nn.Conv2d(4, 512, kernel_size=3, padding=1),

            VAE_ResidualBlock(512, 512),

            VAE_AttentionBlock(512),

            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            nn.Upsample(scale_factor=2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),

            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            nn.Upsample(scale_factor=2),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),

            VAE_ResidualBlock(512, 256),
            VAE_ResidualBlock(256, 256),
            VAE_ResidualBlock(256, 256),

            nn.Upsample(scale_factor=2),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),

            VAE_ResidualBlock(256, 128),
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),

            nn.GroupNorm(32, 128),

            nn.SiLU(),

            nn.Conv2d(128, 3, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # Undo the latent scaling applied by the encoder
        x /= 0.18215

        for module in self:
            x = module(x)

        return x

if __name__ == "__main__":
    model = VAE_Decoder()
    model.eval()

    # Create a dummy latent tensor: (batch_size=1, channels=4, height=8, width=8)
    x = torch.randn(1, 4, 8, 8)

    with torch.no_grad():
        output = model(x)

    print("Input shape :", x.shape)
    print("Output shape:", output.shape)
diffusion.py
ADDED
@@ -0,0 +1,297 @@
import torch
from torch import nn
from torch.nn import functional as F
from attention import SelfAttention, CrossAttention

class TimeEmbedding(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.linear_1 = nn.Linear(n_embed, 4 * n_embed)
        self.linear_2 = nn.Linear(4 * n_embed, 4 * n_embed)

    def forward(self, x):
        x = self.linear_1(x)
        x = F.silu(x)
        x = self.linear_2(x)
        return x

class UNET_ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_time=1280):
        super().__init__()
        self.grpnorm_feature = nn.GroupNorm(32, in_channels)
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.linear_time = nn.Linear(n_time, out_channels)

        self.grpnorm_merged = nn.GroupNorm(32, out_channels)
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, feature, time):
        residue = feature

        feature = self.grpnorm_feature(feature)
        feature = F.silu(feature)

        feature = self.conv_feature(feature)

        time = F.silu(time)
        time = self.linear_time(time)

        merged = feature + time.unsqueeze(-1).unsqueeze(-1)

        merged = self.grpnorm_merged(merged)
        merged = F.silu(merged)
        merged = self.conv_merged(merged)

        return merged + self.residual_layer(residue)

class UNET_AttentionBlock(nn.Module):
    def __init__(self, n_head, n_embed, d_context=768):
        super().__init__()

        channels = n_head * n_embed

        self.grpnorm = nn.GroupNorm(32, channels, eps=1e-6)
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

        self.layernorm_1 = nn.LayerNorm(channels)
        self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)

        self.layernorm_2 = nn.LayerNorm(channels)
        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)

        self.layernorm_3 = nn.LayerNorm(channels)

        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)

        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self, x, context):
        residue_long = x

        x = self.grpnorm(x)
        x = self.conv_input(x)

        n, c, h, w = x.shape

        x = x.view((n, c, h * w))

        x = x.transpose(-1, -2)
        residue_short = x

        x = self.layernorm_1(x)
        x = self.attention_1(x)

        x += residue_short

        residue_short = x

        x = self.layernorm_2(x)
        x = self.attention_2(x, context)

        x += residue_short

        residue_short = x

        x = self.layernorm_3(x)
        x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)

        x = x * F.gelu(gate)

        x = self.linear_geglu_2(x)

        x += residue_short
        x = x.transpose(-1, -2)

        x = x.view((n, c, h, w))

        return self.conv_output(x) + residue_long

class Upsample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)

# SwitchSequential does not override __init__, so constructor arguments are handled by the parent
# class nn.Sequential; only forward is customized to route the context/time inputs to each layer type.
class SwitchSequential(nn.Sequential):
    def forward(self, x, context, time):
        for layer in self:
            if isinstance(layer, UNET_AttentionBlock):
                x = layer(x, context)
            elif isinstance(layer, UNET_ResidualBlock):
                x = layer(x, time)
            else:
                x = layer(x)
        return x

class UNET(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoders = nn.ModuleList([
            # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),

            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),

            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 16, Width / 16)
            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),

            # (Batch_Size, 320, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),

            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),

            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 32, Width / 32)
            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),

            # (Batch_Size, 640, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),

            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),

            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),

            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
        ])

        self.bottleneck = SwitchSequential(
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_ResidualBlock(1280, 1280),
            UNET_AttentionBlock(8, 160),
            UNET_ResidualBlock(1280, 1280),
        )

        self.decoders = nn.ModuleList([
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),

            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),

            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),

            # (Batch_Size, 1920, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),

            # (Batch_Size, 1920, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),

            # (Batch_Size, 1280, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),

            # (Batch_Size, 960, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),

            # (Batch_Size, 960, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),

            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
        ])

    def forward(self, x, context, time):
        # x: (Batch_Size, 4, Height / 8, Width / 8)
        # context: (Batch_Size, Seq_Len, Dim)
        # time: (1, 1280)

        skip_connections = []
        for layers in self.encoders:
            x = layers(x, context, time)
            skip_connections.append(x)

        x = self.bottleneck(x, context, time)

        for layers in self.decoders:
            # Since we always concatenate with the corresponding skip connection of the encoder,
            # the number of input channels doubles before each decoder layer
            x = torch.cat((x, skip_connections.pop()), dim=1)
            x = layers(x, context, time)

        return x


class UNET_OutputLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.grpnorm = nn.GroupNorm(32, in_channels)

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.grpnorm(x)
        x = F.silu(x)

        x = self.conv(x)
        return x

class Diffusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET()
        self.final = UNET_OutputLayer(320, 4)

    def forward(self, latent, context, time):
        time = self.time_embedding(time)

        output = self.unet(latent, context, time)

        output = self.final(output)

        return output

if __name__ == "__main__":
    # Dummy inputs
    batch_size = 10
    height = 64
    width = 64
    in_channels = 4
    context_dim = 768
    seq_len = 77

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create model and move to device
    model = Diffusion().to(device)

    # Random input tensor with 4 channels
    x = torch.randn(batch_size, in_channels, height, width).to(device)

    print('Input shape to UNET: ', x.shape)

    # Time embedding (e.g., timestep from a diffusion schedule)
    t = torch.randn(batch_size, 320).to(device)

    print('Time Embedding shape to UNET: ', t.shape)

    # Context for cross attention (e.g., text embedding from CLIP or transformer)
    context = torch.randn(batch_size, seq_len, context_dim).to(device)

    print('context shape to UNET: ', context.shape)

    # Forward pass
    with torch.no_grad():
        output = model(x, context, t)
        print(output)

    print("Output shape of UNET:", output.shape)
dog.jpg
ADDED
encoder.py
ADDED
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from decoder import VAE_AttentionBlock, VAE_ResidualBlock

class VAE_Encoder(nn.Sequential):
    def __init__(self):
        super().__init__(
            # (Batch_Size, Channel, Height, Width) -> (Batch_Size, 128, Height, Width)
            nn.Conv2d(3, 128, kernel_size=3, padding=1),

            # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),

            # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height/2, Width/2)
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),

            # (Batch_Size, 128, Height/2, Width/2) -> (Batch_Size, 256, Height/2, Width/2)
            VAE_ResidualBlock(128, 256),
            # (Batch_Size, 256, Height/2, Width/2) -> (Batch_Size, 256, Height/2, Width/2)
            VAE_ResidualBlock(256, 256),

            # (Batch_Size, 256, Height/2, Width/2) -> (Batch_Size, 256, Height/4, Width/4)
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),

            # (Batch_Size, 256, Height/4, Width/4) -> (Batch_Size, 512, Height/4, Width/4)
            VAE_ResidualBlock(256, 512),
            # (Batch_Size, 512, Height/4, Width/4) -> (Batch_Size, 512, Height/4, Width/4)
            VAE_ResidualBlock(512, 512),

            # (Batch_Size, 512, Height/4, Width/4) -> (Batch_Size, 512, Height/8, Width/8)
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),

            # (Batch_Size, 512, Height/8, Width/8) -> (Batch_Size, 512, Height/8, Width/8)
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            # (Batch_Size, 512, Height/8, Width/8) -> (Batch_Size, 512, Height/8, Width/8)
            VAE_AttentionBlock(512),

            VAE_ResidualBlock(512, 512),

            nn.GroupNorm(32, 512),

            nn.SiLU(),

            nn.Conv2d(512, 8, kernel_size=3, padding=1),

            # (Batch_Size, 8, Height/8, Width/8) -> (Batch_Size, 8, Height/8, Width/8)
            nn.Conv2d(8, 8, kernel_size=1, padding=0)
        )

    def forward(self, x: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        for module in self:
            if getattr(module, 'stride', None) == (2, 2):
                # Asymmetric padding (right and bottom only) before each strided convolution
                x = F.pad(x, (0, 1, 0, 1))

            x = module(x)

        # (Batch_Size, 8, Height/8, Width/8) -> two tensors of shape (Batch_Size, 4, Height/8, Width/8)
        mean, log_var = torch.chunk(x, 2, dim=1)

        log_var = torch.clamp(log_var, -30, 20)
        var = log_var.exp()
        stdev = var.sqrt()

        # Reparameterization: sample the latent as mean + stdev * noise
        Z = mean + stdev * noise

        # Scale the latent by the constant used in Stable Diffusion
        Z *= 0.18215

        # print('-'*100)
        # print('Z shape: ', Z.shape)
        # print('-'*100)

        return Z

if __name__ == "__main__":
    model = VAE_Encoder()
    model.eval()

    # Create a dummy input tensor: (batch_size=1, channels=3, height=64, width=64)
    x = torch.randn(1, 3, 64, 64)
    noise = torch.randn(1, 4, 8, 8)  # Match the latent shape (Z)

    with torch.no_grad():
        output = model(x, noise)

    print("Input shape :", x.shape)
    print("Output shape:", output.shape)
interface.py
ADDED
@@ -0,0 +1,151 @@
import gradio as gr
import torch
from PIL import Image
from transformers import CLIPTokenizer

# Import your existing model and pipeline modules
import model
import pipeline

# Device Configuration
ALLOW_CUDA = True
ALLOW_MPS = False

def determine_device():
    if torch.cuda.is_available() and ALLOW_CUDA:
        return "cuda"
    elif (torch.backends.mps.is_built() or torch.backends.mps.is_available()) and ALLOW_MPS:
        return "mps"
    return "cpu"

DEVICE = determine_device()
print(f"Using device: {DEVICE}")

# Load tokenizer and models
tokenizer = CLIPTokenizer("vocab.json", merges_file="merges.txt")
model_file = "inkpunk-diffusion-v1.ckpt"
models = model.preload_models_from_standard_weights(model_file, DEVICE)
# models=None

def generate_image(
    prompt,
    uncond_prompt="",
    do_cfg=True,
    cfg_scale=8,
    sampler="ddpm",
    num_inference_steps=50,
    seed=42,
    input_image=None,
    strength=1.0
):
    """
    Generate an image using the Stable Diffusion pipeline.

    Args:
    - prompt (str): Text description of the image to generate
    - uncond_prompt (str, optional): Negative prompt to guide generation
    - do_cfg (bool): Whether to use classifier-free guidance
    - cfg_scale (float): Classifier-free guidance scale
    - sampler (str): Sampling method
    - num_inference_steps (int): Number of denoising steps
    - seed (int): Random seed for reproducibility
    - input_image (PIL.Image, optional): Input image for image-to-image generation
    - strength (float): Strength of image transformation (0-1)

    Returns:
    - PIL.Image: Generated image
    """
    try:
        # Fall back to full-strength text-to-image generation if no input image is provided
        if input_image is None:
            strength = 1.0

        # Generate the image
        output_image = pipeline.generate(
            prompt=prompt,
            uncond_prompt=uncond_prompt,
            input_image=input_image,
            strength=strength,
            do_cfg=do_cfg,
            cfg_scale=cfg_scale,
            sampler_name=sampler,
            n_inference_steps=num_inference_steps,
            seed=seed,
            models=models,
            device=DEVICE,
            idle_device="cuda",
            tokenizer=tokenizer,
        )

        # Convert numpy array to PIL Image
        return Image.fromarray(output_image)

    except Exception as e:
        print(f"Error generating image: {e}")
        return None

def launch_gradio_interface():
    """
    Create and launch the Gradio interface for Stable Diffusion.
    """
    with gr.Blocks(title="Stable Diffusion Image Generator") as demo:
        gr.Markdown("# 🎨 Stable Diffusion Image Generator")

        with gr.Row():
            with gr.Column():
                # Text Inputs
                prompt = gr.Textbox(label="Prompt",
                                    placeholder="Describe the image you want to generate...")
                uncond_prompt = gr.Textbox(label="Negative Prompt (Optional)",
                                           placeholder="Describe what you don't want in the image...")

                # Generation Parameters
                with gr.Accordion("Advanced Settings", open=False):
                    do_cfg = gr.Checkbox(label="Use Classifier-Free Guidance", value=True)
                    cfg_scale = gr.Slider(minimum=1, maximum=14, value=8, label="CFG Scale")
                    sampler = gr.Dropdown(
                        choices=["ddpm", "ddim", "pndm"],  # Add more samplers if available; only "ddpm" is implemented in pipeline.py
                        value="ddpm",
                        label="Sampling Method"
                    )
                    num_inference_steps = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=50,
                        label="Number of Inference Steps"
                    )
                    seed = gr.Number(value=42, label="Random Seed")

                # Image-to-Image Section
                with gr.Accordion("Image-to-Image", open=False):
                    input_image = gr.Image(type="pil", label="Input Image (Optional)")
                    strength = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.8,
                        label="Image Transformation Strength"
                    )

                # Generate Button
                generate_btn = gr.Button("Generate Image", variant="primary")

        with gr.Row():
            # Output Image
            output_image = gr.Image(label="Generated Image")

        # Connect Button to Generation Function
        generate_btn.click(
            fn=generate_image,
            inputs=[
                prompt, uncond_prompt, do_cfg, cfg_scale,
                sampler, num_inference_steps, seed,
                input_image, strength
            ],
            outputs=output_image
        )

    # Launch the interface
    demo.launch(server_name="0.0.0.0", server_port=7860)

if __name__ == "__main__":
    launch_gradio_interface()
merges.txt
ADDED
The diff for this file is too large to render. See raw diff.
model.py
ADDED
@@ -0,0 +1,28 @@
from clip import CLIP
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from diffusion import Diffusion

import model_converter

def preload_models_from_standard_weights(ckpt_path, device):
    state_dict = model_converter.load_from_standard_weights(ckpt_path, device)

    encoder = VAE_Encoder().to(device)
    encoder.load_state_dict(state_dict['encoder'], strict=True)

    decoder = VAE_Decoder().to(device)
    decoder.load_state_dict(state_dict['decoder'], strict=True)

    diffusion = Diffusion().to(device)
    diffusion.load_state_dict(state_dict['diffusion'], strict=True)

    clip = CLIP().to(device)
    clip.load_state_dict(state_dict['clip'], strict=True)

    return {
        'clip': clip,
        'encoder': encoder,
        'decoder': decoder,
        'diffusion': diffusion,
    }
model_converter.py
ADDED
The diff for this file is too large to render. See raw diff.
pipeline.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from ddpm import DDPMSampler
|
| 5 |
+
|
| 6 |
+
WIDTH = 512
|
| 7 |
+
HEIGHT = 512
|
| 8 |
+
LATENTS_WIDTH = WIDTH // 8
|
| 9 |
+
LATENTS_HEIGHT = HEIGHT // 8
|
| 10 |
+
|
| 11 |
+
def generate(
|
| 12 |
+
prompt,
|
| 13 |
+
uncond_prompt=None,
|
| 14 |
+
input_image=None,
|
| 15 |
+
strength=0.8,
|
| 16 |
+
do_cfg=True,
|
| 17 |
+
cfg_scale=7.5,
|
| 18 |
+
sampler_name="ddpm",
|
| 19 |
+
n_inference_steps=50,
|
| 20 |
+
models={},
|
| 21 |
+
seed=None,
|
| 22 |
+
device=None,
|
| 23 |
+
idle_device=None,
|
| 24 |
+
tokenizer=None,
|
| 25 |
+
):
|
| 26 |
+
with torch.no_grad():
|
| 27 |
+
if not 0 < strength <= 1:
|
| 28 |
+
raise ValueError("strength must be between 0 and 1")
|
| 29 |
+
|
| 30 |
+
if idle_device:
|
| 31 |
+
to_idle = lambda x: x.to(idle_device)
|
| 32 |
+
else:
|
| 33 |
+
to_idle = lambda x: x
|
| 34 |
+
|
| 35 |
+
# Initialize random number generator according to the seed specified
|
| 36 |
+
generator = torch.Generator(device=device)
|
| 37 |
+
if seed is None:
|
| 38 |
+
generator.seed()
|
| 39 |
+
else:
|
| 40 |
+
generator.manual_seed(seed)
|
| 41 |
+
|
| 42 |
+
clip = models["clip"]
|
| 43 |
+
clip.to(device)
|
| 44 |
+
|
| 45 |
+
if do_cfg:
|
| 46 |
+
# Convert into a list of length Seq_Len=77
|
| 47 |
+
cond_tokens = tokenizer.batch_encode_plus(
|
| 48 |
+
[prompt], padding="max_length", max_length=77
|
| 49 |
+
).input_ids
|
| 50 |
+
# (Batch_Size, Seq_Len)
|
| 51 |
+
cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
|
| 52 |
+
# (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
|
| 53 |
+
cond_context = clip(cond_tokens)
|
| 54 |
+
# Convert into a list of length Seq_Len=77
|
| 55 |
+
uncond_tokens = tokenizer.batch_encode_plus(
|
| 56 |
+
[uncond_prompt], padding="max_length", max_length=77
|
| 57 |
+
).input_ids
|
| 58 |
+
# (Batch_Size, Seq_Len)
|
| 59 |
+
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
|
| 60 |
+
# (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
|
| 61 |
+
uncond_context = clip(uncond_tokens)
|
| 62 |
+
# (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (2 * Batch_Size, Seq_Len, Dim)
|
| 63 |
+
context = torch.cat([cond_context, uncond_context])
|
| 64 |
+
else:
|
| 65 |
+
# Convert into a list of length Seq_Len=77
|
| 66 |
+
tokens = tokenizer.batch_encode_plus(
|
| 67 |
+
[prompt], padding="max_length", max_length=77
|
| 68 |
+
).input_ids
|
| 69 |
+
# (Batch_Size, Seq_Len)
|
| 70 |
+
tokens = torch.tensor(tokens, dtype=torch.long, device=device)
|
| 71 |
+
# (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
|
| 72 |
+
context = clip(tokens)
|
| 73 |
+
to_idle(clip)
|
| 74 |
+
|
| 75 |
+
        if sampler_name == "ddpm":
            sampler = DDPMSampler(generator)
            sampler.set_inference_timesteps(n_inference_steps)
        else:
            raise ValueError(f"Unknown sampler value {sampler_name}")

        latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)

        if input_image:
            encoder = models["encoder"]
            encoder.to(device)

            input_image_tensor = input_image.resize((WIDTH, HEIGHT))

            # (Height, Width, Channel)
            input_image_tensor = np.array(input_image_tensor)

            # (Height, Width, Channel), rescaled from [0, 255] to [-1, 1]
            input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
            input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))

            # (Height, Width, Channel) -> (Batch_Size, Height, Width, Channel)
            input_image_tensor = input_image_tensor.unsqueeze(0)

            # (Batch_Size, Height, Width, Channel) -> (Batch_Size, Channel, Height, Width)
            input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)

            # (Batch_Size, 4, Latents_Height, Latents_Width)
            encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
            latents = encoder(input_image_tensor, encoder_noise)

            # Add noise to the latents (the encoded input image).
            # Higher strength -> more noise -> the output departs further from the input image.
            # (Batch_Size, 4, Latents_Height, Latents_Width)
            sampler.set_strength(strength=strength)
            latents = sampler.add_noise(latents, sampler.timesteps[0])
            to_idle(encoder)
        else:
            # (Batch_Size, 4, Latents_Height, Latents_Width)
            latents = torch.randn(latents_shape, generator=generator, device=device)

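A note on shapes at this point (commentary, not part of pipeline.py): whichever branch ran, `latents` now has shape `latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)`. Assuming the usual Stable Diffusion setup defined earlier in the file (512×512 images and an 8× downsampling VAE), that is (1, 4, 64, 64) = 16,384 values versus 3 × 512 × 512 = 786,432 pixel values, roughly a 48× reduction, which is why the denoising loop below runs in latent space rather than pixel space. The `strength` argument only matters in the img-to-img branch: via `sampler.set_strength` it determines how much noise is mixed into the encoded image before denoising starts, so values near 1 let the result drift far from the input image while small values keep it close.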
        diffusion = models["diffusion"]
        diffusion.to(device)

        timesteps = tqdm(sampler.timesteps)
        for i, timestep in enumerate(timesteps):
            # (1, 320)
            time_embedding = get_time_embedding(timestep).to(device)

            # (Batch_Size, 4, Latents_Height, Latents_Width)
            model_input = latents

            if do_cfg:
                # (Batch_Size, 4, Latents_Height, Latents_Width) -> (2 * Batch_Size, 4, Latents_Height, Latents_Width)
                model_input = model_input.repeat(2, 1, 1, 1)

            # model_output is the predicted noise
            # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
            model_output = diffusion(model_input, context, time_embedding)

            if do_cfg:
                output_cond, output_uncond = model_output.chunk(2)
                model_output = cfg_scale * (output_cond - output_uncond) + output_uncond

            # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
            latents = sampler.step(timestep, latents, model_output)

        to_idle(diffusion)

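For reference (commentary, not part of the file): the classifier-free guidance line inside the loop, `cfg_scale * (output_cond - output_uncond) + output_uncond`, is the usual guidance formula rearranged — the unconditional noise prediction plus `cfg_scale` times the direction the prompt pulls it in. A toy 1-D check with the default `cfg_scale = 7.5`: if `output_uncond = 0.10` and `output_cond = 0.20`, the guided prediction is `0.10 + 7.5 × (0.20 − 0.10) = 0.85`; with `cfg_scale = 1` the formula collapses back to `output_cond`.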
        decoder = models["decoder"]
        decoder.to(device)
        # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 3, Height, Width)
        images = decoder(latents)

        to_idle(decoder)

        images = rescale(images, (-1, 1), (0, 255), clamp=True)
        # (Batch_Size, Channel, Height, Width) -> (Batch_Size, Height, Width, Channel)
        images = images.permute(0, 2, 3, 1)
        images = images.to("cpu", torch.uint8).numpy()
        return images[0]

def rescale(x, old_range, new_range, clamp=False):
    old_min, old_max = old_range
    new_min, new_max = new_range
    x -= old_min
    x *= (new_max - new_min) / (old_max - old_min)
    x += new_min
    if clamp:
        x = x.clamp(new_min, new_max)
    return x

def get_time_embedding(timestep):
    # Shape: (160,)
    freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
    # Shape: (1, 160)
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
    # Shape: (1, 320) -- 160 cosines followed by 160 sines, the vector consumed by the
    # diffusion model's time-embedding layer (the "(1, 320)" noted in the loop above)
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
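To tie the pieces together, here is a minimal sketch of how the function above could be driven end-to-end. It is illustrative only, not the repo's interface.py: the function is assumed to be exported as `pipeline.generate` (its `def` line sits earlier in the file than this excerpt), the keyword names before `cfg_scale` are inferred from how they are used in the body, and `load_models` is a hypothetical stand-in for whatever helper in model.py builds the models dict with the "clip", "encoder", "decoder" and "diffusion" keys used above.

# Illustrative driver -- not part of this commit.
from PIL import Image
from transformers import CLIPTokenizer

import pipeline

# vocab.json and merges.txt are added in this commit.
tokenizer = CLIPTokenizer("vocab.json", merges_file="merges.txt")
# Hypothetical helper: returns {"clip": ..., "encoder": ..., "decoder": ..., "diffusion": ...}
models = load_models(device="cuda")

output = pipeline.generate(
    prompt="a dog running on the beach",
    uncond_prompt="",                # negative prompt for classifier-free guidance
    input_image=None,                # pass a PIL.Image (e.g. Image.open("dog.jpg")) for img-to-img
    strength=0.8,
    do_cfg=True,
    cfg_scale=7.5,
    sampler_name="ddpm",
    n_inference_steps=50,
    models=models,
    seed=42,
    device="cuda",
    idle_device="cpu",
    tokenizer=tokenizer,
)

# `output` is a (Height, Width, 3) uint8 numpy array (see `return images[0]` above).
Image.fromarray(output).save("output.png")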
requirements.txt
ADDED
@@ -0,0 +1,115 @@
aiohappyeyeballs==2.6.1
aiohttp==3.11.18
aiosignal==1.3.2
annotated-types==0.7.0
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
attrs==25.3.0
certifi==2025.4.26
charset-normalizer==3.4.2
click==8.2.0
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
contourpy==1.3.2
cycler==0.12.1
datasets==3.6.0
debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1744321233760/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
dill==0.3.8
docker-pycreds==0.4.0
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
filelock==3.18.0
fonttools==4.58.0
frozenlist==1.6.0
fsspec==2025.3.0
gitdb==4.0.12
GitPython==3.1.44
hf-xet==1.1.0
huggingface-hub==0.31.1
idna==3.10
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1737420181517/work
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1745672166/work
ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work
ipywidgets==8.1.7
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
Jinja2==3.1.6
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1727163409502/work
jupyterlab_widgets==3.0.15
kiwisolver==1.4.8
lightning==2.5.1.post0
lightning-utilities==0.14.3
MarkupSafe==3.0.2
matplotlib==3.10.3
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
mpmath==1.3.0
multidict==6.4.3
multiprocess==0.70.16
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
networkx==3.4.2
numpy==2.2.5
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
packaging==24.2
pandas==2.2.3
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
pillow==11.2.1
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_platformdirs_1746710438/work
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
propcache==0.3.1
protobuf==6.30.2
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1740663149797/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
pyarrow==20.0.0
pydantic==2.11.4
pydantic_core==2.33.2
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1736243443484/work
pyparsing==3.2.3
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1733215673016/work
pytorch-lightning==2.5.1.post0
pytz==2025.2
PyYAML==6.0.2
pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1743831245578/work
regex==2024.11.6
requests==2.32.3
safetensors==0.5.3
sentry-sdk==2.27.0
setproctitle==1.3.6
six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work
smmap==5.0.2
stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
sympy==1.14.0
tokenizers==0.21.1
torch==2.7.0
torchmetrics==1.7.1
torchvision==0.22.0
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1732615904614/work
tqdm==4.67.1
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
transformers==4.51.3
triton==3.3.0
typing-inspection==0.4.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1744302253/work
tzdata==2025.2
urllib3==2.4.0
wandb==0.19.11
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work
widgetsnbextension==4.0.14
xxhash==3.5.0
yarl==1.20.0
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work
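The list above is a full freeze of the development environment. To reproduce it:

pip install -r requirements.txt

Note that the `... @ file:///home/conda/feedstock_root/...` entries point at local conda build artifacts and will typically need to be removed or replaced with ordinary PyPI pins when installing anywhere else. The inference code in this commit directly uses only a small subset of the list — torch, numpy and tqdm appear in pipeline.py, with a CLIP tokenizer (transformers) and PIL images handling prompt and image I/O.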
test.ipynb
ADDED
The diff for this file is too large to render. See raw diff.

vocab.json
ADDED
The diff for this file is too large to render. See raw diff.