make style
modeling_glide.py CHANGED  (+38 -17)
@@ -18,7 +18,14 @@ import numpy as np
 import torch
 
 import tqdm
 
-from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, GlideDDIMScheduler, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from diffusers import (
+    ClassifierFreeGuidanceScheduler,
+    CLIPTextModel,
+    DiffusionPipeline,
+    GlideDDIMScheduler,
+    GLIDESuperResUNetModel,
+    GLIDETextToImageUNetModel,
+)
 from transformers import GPT2Tokenizer
 
 
@@ -46,12 +53,16 @@ class GLIDE(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
         upscale_unet: GLIDESuperResUNetModel,
-        upscale_noise_scheduler: GlideDDIMScheduler
+        upscale_noise_scheduler: GlideDDIMScheduler,
     ):
         super().__init__()
         self.register_modules(
-            text_unet=text_unet,
-            text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer, upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
+            text_unet=text_unet,
+            text_noise_scheduler=text_noise_scheduler,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            upscale_unet=upscale_unet,
+            upscale_noise_scheduler=upscale_noise_scheduler,
         )
 
     def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
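Note: `register_modules` is the `DiffusionPipeline` mechanism that attaches each component as an attribute on the pipeline and records it so the whole pipeline can be saved and reloaded as a unit. A minimal usage sketch under that assumption (the variable names below are illustrative, not from this file):

    pipeline = GLIDE(
        text_unet=text_unet,
        text_noise_scheduler=text_noise_scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        upscale_unet=upscale_unet,
        upscale_noise_scheduler=upscale_noise_scheduler,
    )
    pipeline.upscale_unet  # registered components become attributes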
@@ -67,9 +78,7 @@ class GLIDE(DiffusionPipeline):
             + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
         )
         posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(
-            scheduler.posterior_log_variance_clipped, t, x_t.shape
-        )
+        posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
         assert (
             posterior_mean.shape[0]
             == posterior_variance.shape[0]
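Note: the `posterior_*` buffers indexed here are the closed-form coefficients of the DDPM posterior q(x_{t-1} | x_t, x_0) = N(coef1 * x_0 + coef2 * x_t, posterior_variance * I) from Ho et al. (2020), Eqs. 6-7. A sketch of how a scheduler typically precomputes them from a beta schedule; the linear schedule below is an illustrative assumption, not necessarily what `GlideDDIMScheduler` uses:

    import numpy as np

    betas = np.linspace(1e-4, 0.02, 1000)  # example linear schedule
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

    # q(x_{t-1} | x_t, x_0) = N(coef1 * x_0 + coef2 * x_t, posterior_variance * I)
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    posterior_mean_coef1 = betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)

    # "clipped" because posterior_variance[0] == 0 at t == 0, so its log is patched
    posterior_log_variance_clipped = np.log(np.append(posterior_variance[1], posterior_variance[1:]))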
@@ -190,19 +199,30 @@ class GLIDE(DiffusionPipeline):
         # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
         upsample_temp = 0.997
 
-        image = (
-            self.upscale_noise_scheduler.sample_noise((batch_size, 3, 256, 256), device=torch_device, generator=generator) * upsample_temp
-        )
+        image = (
+            self.upscale_noise_scheduler.sample_noise(
+                (batch_size, 3, 256, 256), device=torch_device, generator=generator
+            )
+            * upsample_temp
+        )
 
         num_timesteps = len(self.upscale_noise_scheduler)
-        for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)):
+        for t in tqdm.tqdm(
+            reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
+        ):
             # i) define coefficients for time step t
             clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
             clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
-            image_coeff = (
-                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
-            )
-            clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta(t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            image_coeff = (
+                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )
+            clipped_coeff = (
+                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * self.upscale_noise_scheduler.get_beta(t)
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )
 
             # ii) predict noise residual
             time_input = torch.tensor([t] * image.shape[0], device=torch_device)
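Note: the four coefficients in step i) are the standard DDPM sampling quantities, written against the scheduler accessors `get_alpha(t)` (alpha_t), `get_beta(t)` (beta_t), and `get_alpha_prod(t)` (cumulative alpha-bar_t). A self-contained sketch of the same algebra with plain tensors standing in for the scheduler calls (the schedule itself is an assumption):

    import torch

    betas = torch.linspace(1e-4, 0.02, 1000)  # stand-in schedule
    alphas = 1.0 - betas
    alpha_prod = torch.cumprod(alphas, dim=0)
    t = 500

    # invert the forward process: pred_x0 = clipped_image_coeff * x_t - clipped_noise_coeff * eps
    clipped_image_coeff = 1 / torch.sqrt(alpha_prod[t])
    clipped_noise_coeff = torch.sqrt(1 / alpha_prod[t] - 1)

    # posterior-mean weights on the (clipped) x_0 estimate and on x_t (Ho et al. 2020, Eq. 7),
    # used in step iii) as: prev_image = clipped_coeff * pred_mean + image_coeff * image
    clipped_coeff = torch.sqrt(alpha_prod[t - 1]) * betas[t] / (1 - alpha_prod[t])
    image_coeff = (1 - alpha_prod[t - 1]) * torch.sqrt(alphas[t]) / (1 - alpha_prod[t])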
@@ -216,8 +236,9 @@ class GLIDE(DiffusionPipeline):
             prev_image = clipped_coeff * pred_mean + image_coeff * image
 
             # iv) sample variance
-            prev_variance = self.upscale_noise_scheduler.sample_variance(
-                t, prev_image.shape, device=torch_device, generator=generator)
+            prev_variance = self.upscale_noise_scheduler.sample_variance(
+                t, prev_image.shape, device=torch_device, generator=generator
+            )
 
             # v) sample x_{t-1} ~ N(prev_image, prev_variance)
             sampled_prev_image = prev_image + prev_variance
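Note: the addition in step v) only yields a valid draw from N(prev_image, sigma_t**2 * I) if `sample_variance` returns noise already scaled by the posterior standard deviation, not the variance its name suggests. A hypothetical sketch of that reading (not the actual scheduler method; reuses `betas` and `alpha_prod` from the sketch above):

    def sample_variance(t, shape, device, generator=None):
        # unit Gaussian noise scaled by sigma_t, so that
        # prev_image + sample_variance(...) ~ N(prev_image, sigma_t**2 * I)
        noise = torch.randn(shape, device=device, generator=generator)
        sigma_t = torch.sqrt(betas[t] * (1 - alpha_prod[t - 1]) / (1 - alpha_prod[t]))
        return sigma_t * noise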
|