Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Copyright (c) Alibaba, Inc. and its affiliates. | |
| import torch | |
| from scepter.modules.model.registry import DIFFUSION_SAMPLERS | |
| from scepter.modules.model.diffusion.samplers import BaseDiffusionSampler | |
| from scepter.modules.model.diffusion.util import _i | |
| def _i(tensor, t, x): | |
| """ | |
| Index tensor using t and format the output according to x. | |
| """ | |
| shape = (x.size(0), ) + (1, ) * (x.ndim - 1) | |
| if isinstance(t, torch.Tensor): | |
| t = t.to(tensor.device) | |
| return tensor[t].view(shape).to(x.device) | |
class DDIMSampler(BaseDiffusionSampler):
    """
    DDIM-style sampler whose update rule is expressed through ``sigmas_vp``.

    ``preprare_sampler`` derives ``sigmas_vp`` from the parent sampler's
    sigmas as ``sqrt(sigma^2 / (1 + sigma^2))`` and ``step`` consumes it one
    index at a time.
    """

    def init_params(self):
        """
        Read sampler options from ``self.cfg`` after the parent setup.

        ``ETA`` (default ``0.``) scales the random noise added in ``step``.
        ``DISCRETIZATION_TYPE`` (default ``'trailing'``) is stored on the
        instance; it is not read anywhere else in this class.
        """
        super().init_params()
        cfg = self.cfg
        self.eta = cfg.get('ETA', 0.)
        self.discretization_type = cfg.get('DISCRETIZATION_TYPE', 'trailing')

    def preprare_sampler(self,
                         noise,
                         steps=20,
                         scheduler_ins=None,
                         prediction_type='',
                         sigmas=None,
                         betas=None,
                         alphas=None,
                         callback_fn=None,
                         **kwargs):
        """
        Run the parent preparation and attach a ``sigmas_vp`` custom field.

        Every argument is forwarded unchanged (positionally, plus
        ``**kwargs``) to the parent ``preprare_sampler``. The object it
        returns is extended with ``sigmas_vp`` and returned.
        """
        output = super().preprare_sampler(noise, steps, scheduler_ins,
                                          prediction_type, sigmas, betas,
                                          alphas, callback_fn, **kwargs)
        base_sigmas = output.sigmas
        # Append a zero entry: ``step`` reads index ``step + 1``, so this
        # supplies a value past the last sigma of the parent schedule.
        padded = torch.cat([base_sigmas, base_sigmas.new_zeros([1])])
        padded_sq = padded**2
        sigmas_vp = (padded_sq / (1 + padded_sq))**0.5
        # sigma == inf gives inf / inf = nan above; pin those entries to 1.
        sigmas_vp[padded == float('inf')] = 1.
        output.add_custom_field('sigmas_vp', sigmas_vp)
        return output

    def step(self, sampler_output):
        """
        Advance ``sampler_output`` by one sampling step and return it.

        Reads ``x_t``, ``step``, ``ts``, ``sigmas_vp``, ``alphas_init``,
        ``sigmas_init`` and ``callback_fn`` from ``sampler_output``; writes
        ``x_0``, ``x_t``, ``step`` (incremented by one) and ``msg``.
        """
        x_t = sampler_output.x_t
        idx = sampler_output.step
        timestep = sampler_output.ts[idx]
        sigmas_vp = sampler_output.sigmas_vp.to(x_t.device)

        # A single-sample slice is enough to give ``_i`` the target
        # rank and device for the per-step coefficients.
        alpha_init = _i(sampler_output.alphas_init, idx, x_t[:1])
        sigma_init = _i(sampler_output.sigmas_init, idx, x_t[:1])
        pred = sampler_output.callback_fn(x_t, timestep, sigma_init,
                                          alpha_init)

        cur = sigmas_vp[idx]
        nxt = sigmas_vp[idx + 1]
        cur_sq = cur**2
        nxt_sq = nxt**2

        # NOTE(review): the DDIM paper's sigma takes a square root of this
        # product; here it is used unrooted and then treated as a standard
        # deviation (squared below, multiplied into randn). Only matters
        # when ETA > 0 -- confirm whether this is intended.
        noise_factor = self.eta * (nxt_sq / cur_sq *
                                   (1 - (1 - cur_sq) / (1 - nxt_sq)))
        direction = (x_t - (1 - cur_sq)**0.5 * pred) / cur
        x_next = (1 - nxt_sq)**0.5 * pred + \
            (nxt_sq - noise_factor**2)**0.5 * direction

        # NOTE(review): ``x_0`` is bound to the same tensor that receives
        # the in-place noise below, so it ends up equal to the new ``x_t``
        # -- confirm this aliasing is intended.
        sampler_output.x_0 = x_next
        if nxt > 0:
            x_next += noise_factor * torch.randn_like(x_next)
        sampler_output.x_t = x_next
        sampler_output.step += 1
        sampler_output.msg = f'step {idx}'
        return sampler_output