Spaces:
Runtime error
Runtime error
| import math | |
| import torch | |
| from torch import nn | |
| import k_diffusion | |
| from k_diffusion import sampling, utils | |
class VDenoiser(nn.Module):
    """A v-diffusion-pytorch model wrapper for k-diffusion.

    The wrapped ``inner_model`` is called with a rescaled input and a
    timestep derived from ``sigma``; its output is recombined with the input
    using the coefficients returned by :meth:`get_scalings`.
    """

    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return the ``(c_skip, c_out, c_in)`` coefficients for ``sigma``."""
        # sigma^2 + sigma_data^2 is shared by all three coefficients.
        total_var = sigma ** 2 + self.sigma_data ** 2
        c_skip = self.sigma_data ** 2 / total_var
        c_out = -sigma * self.sigma_data / total_var ** 0.5
        c_in = 1 / total_var ** 0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigma):
        """Map a noise level to the model timestep: ``atan(sigma) / pi * 2``."""
        angle = sigma.atan()
        return angle / math.pi * 2

    def t_to_sigma(self, t):
        """Inverse of :meth:`sigma_to_t`: ``tan(t * pi / 2)``."""
        angle = t * math.pi / 2
        return angle.tan()

    def loss(self, input, noise, sigma, **kwargs):
        """Return the per-sample mean squared error of the model output.

        ``input`` is noised with ``noise`` scaled by ``sigma``, the inner
        model is run on the rescaled noised input, and the squared error
        against the target is averaged over every dimension after the first.
        Extra keyword arguments are forwarded to the inner model.
        """
        # append_dims presumably pads trailing dims so the coefficients
        # broadcast against ``input`` -- confirm against k_diffusion.utils.
        scalings = self.get_scalings(sigma)
        c_skip, c_out, c_in = [utils.append_dims(s, input.ndim) for s in scalings]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        timestep = self.sigma_to_t(sigma)
        model_output = self.inner_model(noised_input * c_in, timestep, **kwargs)
        target = (input - c_skip * noised_input) / c_out
        squared_error = (model_output - target).pow(2)
        return squared_error.flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Return ``inner_model(input * c_in, t) * c_out + input * c_skip``."""
        scalings = self.get_scalings(sigma)
        c_skip, c_out, c_in = [utils.append_dims(s, input.ndim) for s in scalings]
        timestep = self.sigma_to_t(sigma)
        scaled_output = self.inner_model(input * c_in, timestep, **kwargs) * c_out
        return scaled_output + input * c_skip
class DiscreteSchedule(nn.Module):
    """A mapping between continuous noise levels (sigmas) and a list of discrete noise
    levels.

    ``sigmas`` and its element-wise log are registered as buffers
    (``sigmas`` / ``log_sigmas``). ``quantize`` is the default used by
    :meth:`sigma_to_t` when the call does not specify one.
    """

    def __init__(self, sigmas, quantize):
        super().__init__()
        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())
        self.quantize = quantize

    # NOTE(review): sigma_min / sigma_max are plain methods here (no
    # @property is visible), so callers must call them -- confirm intended.
    def sigma_min(self):
        """Return the first entry of the sigma table."""
        return self.sigmas[0]

    def sigma_max(self):
        """Return the last entry of the sigma table."""
        return self.sigmas[-1]

    def get_sigmas(self, n=None):
        """Return a sigma sequence passed through ``sampling.append_zero``.

        With ``n`` of ``None`` the stored table is used in reversed order.
        Otherwise ``n`` sigmas are interpolated at evenly spaced fractional
        indices running from the last table index down to 0.
        """
        if n is None:
            return sampling.append_zero(self.sigmas.flip(0))
        t_max = len(self.sigmas) - 1
        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
        return sampling.append_zero(self.t_to_sigma(t))

    def sigma_to_t(self, sigma, quantize=None):
        """Map sigma values to (possibly fractional) indices into the table.

        Args:
            sigma: Tensor of noise levels, of any shape.
            quantize: If true, return the index of the nearest table entry
                in log space. If false, linearly interpolate in log space
                between the two neighbouring entries. ``None`` falls back
                to ``self.quantize``.

        Returns:
            A tensor with the same shape as ``sigma``.
        """
        quantize = self.quantize if quantize is None else quantize
        # Flatten before comparing against the table so that the
        # (table, element) broadcast is well defined for any input shape;
        # the original shape is restored on return.
        log_sigma = sigma.log().reshape(-1)
        dists = log_sigma - self.log_sigmas[:, None]
        if quantize:
            return dists.abs().argmin(dim=0).view(sigma.shape)
        # Count table entries whose log is <= log_sigma; the first position
        # where that running count peaks is the lower neighbour. Clamp so
        # that low_idx + 1 stays inside the table.
        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
        w = (low - log_sigma) / (low - high)
        # Values outside the bracketing pair snap to the nearer endpoint.
        w = w.clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)

    def t_to_sigma(self, t):
        """Map fractional table indices to sigmas by log-linear interpolation."""
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp()
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output eps (the predicted
    noise).

    The sigma table handed to the base class is
    ``sqrt((1 - alphas_cumprod) / alphas_cumprod)``.
    """

    def __init__(self, model, alphas_cumprod, quantize):
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        super().__init__(sigmas, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return ``(c_out, c_in)`` for ``sigma``; ``c_out`` is ``-sigma``."""
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return -sigma, c_in

    def get_eps(self, *args, **kwargs):
        """Call the wrapped model; subclasses override to adapt its API."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """Return the per-sample mean squared error between eps and ``noise``.

        The error is averaged over every dimension after the first. Extra
        keyword arguments are forwarded to :meth:`get_eps`.
        """
        # Both scalings are expanded as before, although only c_in is used.
        scalings = self.get_scalings(sigma)
        _c_out, c_in = [utils.append_dims(s, input.ndim) for s in scalings]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        timestep = self.sigma_to_t(sigma)
        eps = self.get_eps(noised_input * c_in, timestep, **kwargs)
        squared_error = (eps - noise).pow(2)
        return squared_error.flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Return ``input + eps * c_out`` for the predicted eps."""
        scalings = self.get_scalings(sigma)
        c_out, c_in = [utils.append_dims(s, input.ndim) for s in scalings]
        timestep = self.sigma_to_t(sigma)
        eps = self.get_eps(input * c_in, timestep, **kwargs)
        # !!! fix for special models (controlnet, inpaint, depth, ..):
        # keep only as many leading dim-1 entries of the input as eps has.
        trimmed = input[:, :eps.shape[1], ...]
        return trimmed + eps * c_out
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for OpenAI diffusion models.

    ``diffusion.alphas_cumprod`` is converted to a float32 tensor on
    ``device`` and passed to the base class.
    """

    def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
        alphas_cumprod = torch.tensor(
            diffusion.alphas_cumprod,
            device=device,
            dtype=torch.float32,
        )
        super().__init__(model, alphas_cumprod, quantize=quantize)
        self.has_learned_sigmas = has_learned_sigmas

    def get_eps(self, *args, **kwargs):
        """Return the model output, or its first half when sigmas are learned."""
        output = self.inner_model(*args, **kwargs)
        if not self.has_learned_sigmas:
            return output
        # Take the first of two chunks along dim 1 as eps.
        return output.chunk(2, dim=1)[0]
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for CompVis diffusion models.

    The schedule is built from ``model.alphas_cumprod`` and eps is obtained
    through ``model.apply_model``.
    """

    def __init__(self, model, quantize=False, device='cpu'):
        # ``device`` is accepted but not used in this constructor.
        alphas_cumprod = model.alphas_cumprod
        super().__init__(model, alphas_cumprod, quantize=quantize)

    def get_eps(self, *args, **kwargs):
        """Forward all arguments to ``inner_model.apply_model``."""
        apply_model = self.inner_model.apply_model
        return apply_model(*args, **kwargs)
class DiscreteVDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output v.

    The sigma table handed to the base class is
    ``sqrt((1 - alphas_cumprod) / alphas_cumprod)``.
    """

    def __init__(self, model, alphas_cumprod, quantize):
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        super().__init__(sigmas, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return the ``(c_skip, c_out, c_in)`` coefficients for ``sigma``."""
        # sigma^2 + sigma_data^2 is shared by all three coefficients.
        total_var = sigma ** 2 + self.sigma_data ** 2
        c_skip = self.sigma_data ** 2 / total_var
        c_out = -sigma * self.sigma_data / total_var ** 0.5
        c_in = 1 / total_var ** 0.5
        return c_skip, c_out, c_in

    def get_v(self, *args, **kwargs):
        """Call the wrapped model; subclasses override to adapt its API."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """Return the per-sample mean squared error of the v prediction.

        The error against the target is averaged over every dimension after
        the first. Extra keyword arguments are forwarded to :meth:`get_v`.
        """
        scalings = self.get_scalings(sigma)
        c_skip, c_out, c_in = [utils.append_dims(s, input.ndim) for s in scalings]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        timestep = self.sigma_to_t(sigma)
        model_output = self.get_v(noised_input * c_in, timestep, **kwargs)
        target = (input - c_skip * noised_input) / c_out
        squared_error = (model_output - target).pow(2)
        return squared_error.flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Return ``v * c_out + input * c_skip`` for the predicted v."""
        scalings = self.get_scalings(sigma)
        c_skip, c_out, c_in = [utils.append_dims(s, input.ndim) for s in scalings]
        timestep = self.sigma_to_t(sigma)
        vout = self.get_v(input * c_in, timestep, **kwargs) * c_out
        # !!! fix for special models (controlnet, upscale, ..):
        # keep only as many leading dim-1 entries of the input as vout has.
        trimmed = input[:, :vout.shape[1], ...]
        return vout + trimmed * c_skip
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
    """A wrapper for CompVis diffusion models that output v.

    The schedule is built from ``model.alphas_cumprod`` and v is obtained
    through ``model.apply_model``.
    """

    def __init__(self, model, quantize=False, device='cpu'):
        # ``device`` is accepted but not used in this constructor.
        alphas_cumprod = model.alphas_cumprod
        super().__init__(model, alphas_cumprod, quantize=quantize)

    def get_v(self, x, t, cond, **kwargs):
        """Call ``inner_model.apply_model(x, t, cond)``.

        Extra keyword arguments are accepted but not forwarded.
        """
        apply_model = self.inner_model.apply_model
        return apply_model(x, t, cond)