Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Copyright (c) Alibaba, Inc. and its affiliates. | |
| import torch | |
| from scepter.modules.model.registry import NOISE_SCHEDULERS | |
| from scepter.modules.model.diffusion.schedules import BaseNoiseScheduler | |
class LinearScheduler(BaseNoiseScheduler):
    """Noise scheduler whose per-step betas grow linearly.

    Betas are spaced evenly from ``BETA_MIN`` to ``BETA_MAX`` (read from
    ``self.cfg``) over ``self.num_timesteps`` steps; :meth:`get_schedule`
    derives sigma/alpha/timestep tensors from them and caches everything on
    the instance.

    NOTE(review): ``NOISE_SCHEDULERS`` is imported by this module, but this
    class carries no registration decorator here. Confirm whether it is meant
    to be resolvable through that registry, or whether the import is unused.
    """

    # No scheduler-specific parameter descriptions are declared here;
    # presumably read by scepter's config tooling — confirm against the base.
    para_dict = {}

    def init_params(self):
        """Run base-class initialisation, then read the beta range from cfg.

        ``BETA_MIN`` falls back to 0.00085 and ``BETA_MAX`` to 0.012 when the
        config does not supply them.
        """
        super().init_params()
        cfg = self.cfg
        self.beta_min = cfg.get('BETA_MIN', 0.00085)
        self.beta_max = cfg.get('BETA_MAX', 0.012)

    def betas_to_sigmas(self, betas):
        """Map per-step betas to cumulative noise levels.

        Computes ``sqrt(1 - prod_{i<=t} (1 - beta_i))``, with the running
        product taken along dimension 0.

        Args:
            betas: Tensor of per-step beta values.

        Returns:
            Tensor with the same shape as ``betas`` holding the sigmas.
        """
        keep = 1 - betas
        keep_cumulative = torch.cumprod(keep, dim=0)
        return torch.sqrt(1 - keep_cumulative)

    def get_schedule(self):
        """Build the linear-beta schedule and cache it on the instance.

        Sets ``_betas`` (float32, ``self.num_timesteps`` values evenly spaced
        from ``self.beta_min`` to ``self.beta_max``), ``_sigmas`` (via
        :meth:`betas_to_sigmas`, so subclasses may override the mapping),
        ``_alphas`` (``sqrt(1 - sigma**2)``) and ``_timesteps`` (float32
        indices ``0 .. N-1``). Returns ``None``.
        """
        beta_schedule = torch.linspace(start=self.beta_min,
                                       end=self.beta_max,
                                       steps=self.num_timesteps,
                                       dtype=torch.float32)
        noise_levels = self.betas_to_sigmas(beta_schedule)
        self._sigmas = noise_levels
        self._betas = beta_schedule
        self._alphas = torch.sqrt(1 - noise_levels ** 2)
        # One float32 index per schedule entry: 0, 1, ..., N-1.
        self._timesteps = torch.arange(len(noise_levels),
                                       dtype=torch.float32)