Spaces:
Running
on
Zero
Running
on
Zero
| # This file includes code derived from the SiT project (https://github.com/willisma/SiT), | |
| # which is licensed under the MIT License. | |
| # | |
| # MIT License | |
| # | |
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy | |
| # of this software and associated documentation files (the "Software"), to deal | |
| # in the Software without restriction, including without limitation the rights | |
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| # copies of the Software, and to permit persons to whom the Software is | |
| # furnished to do so, subject to the following conditions: | |
| # | |
| # The above copyright notice and this permission notice shall be included in all | |
| # copies or substantial portions of the Software. | |
| # | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| # SOFTWARE. | |
| import torch as th | |
| import numpy as np | |
| from functools import partial | |
def expand_t_like_x(t, x):
    """Reshape a per-sample time vector so it broadcasts against ``x``.

    Args:
        t: [batch_dim,] time tensor.
        x: [batch_dim, ...] data tensor; only its rank is used.

    Returns:
        A view of ``t`` with shape [batch_dim, 1, ..., 1] and the same rank as ``x``.
    """
    # One singleton axis for every non-batch dimension of x.
    trailing_ones = (1,) * (x.dim() - 1)
    return t.view(t.size(0), *trailing_ones)
| #################### Coupling Plans #################### | |
class ICPlan:
    """Linear Coupling Plan.

    Interpolates x_t = alpha_t * x1 + sigma_t * x0 with alpha_t = t and
    sigma_t = 1 - t, so t = 0 gives x0 (the noise term) and t = 1 gives x1
    (the data term). Subclasses override the coefficient methods
    (``compute_alpha_t``, ``compute_sigma_t``,
    ``compute_d_alpha_alpha_ratio_t``, optionally ``compute_drift``) to define
    other paths; the conversion helpers below are written in terms of them.

    Args:
        sigma: float, stored as ``self.sigma``; not read by any method of this
            class.
    """

    def __init__(self, sigma=0.0):
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Compute the data coefficient along the path.

        Returns:
            (alpha_t, d_alpha_t): alpha_t = t and its time derivative 1.
        """
        return t, 1

    def compute_sigma_t(self, t):
        """Compute the noise coefficient along the path.

        Returns:
            (sigma_t, d_sigma_t): sigma_t = 1 - t and its time derivative -1.
        """
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute the ratio d_alpha_t / alpha_t (= 1 / t here; diverges at t = 0)."""
        return 1 / t

    def compute_drift(self, x, t):
        """Return the SDE drift and diffusion coefficient (score parametrization).

        Args:
            x: [batch_dim, ...], data point
            t: [batch_dim,], time vector

        Returns:
            (-drift, diffusion) with drift = (d_alpha_t / alpha_t) * x and
            diffusion = (d_alpha_t / alpha_t) * sigma_t**2 - sigma_t * d_sigma_t.
        """
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion term of the SDE.

        Args:
            x: [batch_dim, ...], data point
            t: [batch_dim,], time vector
            form: str, form of the diffusion term. One of "constant", "SBDM",
                "sigma", "linear", "decreasing", "inccreasing-decreasing"
                (the correctly spelled "increasing-decreasing" is accepted as
                an alias; the original spelling is kept for compatibility).
            norm: float, norm of the diffusion term

        Returns:
            ``norm`` itself for "constant", otherwise a tensor broadcastable
            against ``x``.

        Raises:
            NotImplementedError: if ``form`` is not a known form.
        """
        t = expand_t_like_x(t, x)
        # Each entry is a thunk so that only the requested form is evaluated;
        # building every value eagerly would run compute_drift and several
        # trig kernels on every call just to discard all but one result.
        choices = {
            "constant": lambda: norm,
            "SBDM": lambda: norm * self.compute_drift(x, t)[1],
            "sigma": lambda: norm * self.compute_sigma_t(t)[0],
            "linear": lambda: norm * (1 - t),
            "decreasing": lambda: 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
            "inccreasing-decreasing": lambda: norm * th.sin(np.pi * t) ** 2,
        }
        choices["increasing-decreasing"] = choices["inccreasing-decreasing"]
        try:
            compute = choices[form]
        except KeyError:
            raise NotImplementedError(f"Diffusion form {form} not implemented") from None
        return compute()

    def get_score_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction into a score.

        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor

        Returns:
            (r * velocity - x) / var with r = alpha_t / d_alpha_t and
            var = sigma_t**2 - r * d_sigma_t * sigma_t.
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction into the noise x0.

        Solves x_t = alpha_t * x1 + sigma_t * x0 and
        velocity = d_alpha_t * x1 + d_sigma_t * x0 for x0.

        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Wrapper function: transform a score prediction into a velocity.

        Args:
            score: [batch_dim, ...] shaped tensor; score model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        # compute_drift returns the *negated* drift, hence the subtraction.
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def compute_mu_t(self, t, x0, x1):
        """Compute the mean of time-dependent density p_t."""
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        # t*x1 + (1-t)*x0 ; t=0 x0; t=1 x1
        return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Return x_t on the path; deterministic given (t, x0, x1), equal to compute_mu_t."""
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Compute the vector field corresponding to p_t.

        ``xt`` is accepted for interface compatibility but is not used.
        """
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        """Return (t, x_t, u_t) for the given times and endpoints."""
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut
class VPCPlan(ICPlan):
    """class for VP path flow matching.

    Uses alpha_t = exp(L(t)) and sigma_t = sqrt(1 - exp(2 * L(t))) with
    L(t) = -0.25 * (1 - t)**2 * (sigma_max - sigma_min) - 0.5 * (1 - t) * sigma_min,
    so alpha_t**2 + sigma_t**2 == 1 (variance preserving).

    Args:
        sigma_min: float, value of beta_t at t = 1 (see ``compute_drift``).
        sigma_max: float, value of beta_t at t = 0 (see ``compute_drift``).
    """

    def __init__(self, sigma_min=0.1, sigma_max=20.0):
        # Initialise the base plan so ``self.sigma`` exists, as on the other plans.
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    # These two were previously lambdas assigned in __init__. As methods they
    # keep the same call syntax (``self.log_mean_coeff(t)``) but the instance
    # stays picklable and no longer holds a closure referencing itself.
    def log_mean_coeff(self, t):
        """Return L(t) = log(alpha_t)."""
        return -0.25 * ((1 - t) ** 2) * \
            (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min

    def d_log_mean_coeff(self, t):
        """Return dL/dt, i.e. d_alpha_t / alpha_t."""
        return 0.5 * (1 - t) * \
            (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min

    def compute_alpha_t(self, t):
        """Compute coefficient of x1 and its time derivative."""
        alpha_t = self.log_mean_coeff(t)
        alpha_t = th.exp(alpha_t)
        d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Compute coefficient of x0 and its time derivative.

        Note: L(1) = 0, so sigma_t is 0 at t = 1 and d_sigma_t divides by zero there.
        """
        p_sigma_t = 2 * self.log_mean_coeff(t)
        sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
        d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
        # Closed form dL/dt avoids dividing two exponentials.
        return self.d_log_mean_coeff(t)

    def compute_drift(self, x, t):
        """Compute the drift term of the SDE.

        Returns:
            (-0.5 * beta_t * x, beta_t / 2) with the linear schedule
            beta_t = sigma_min + (1 - t) * (sigma_max - sigma_min).
        """
        t = expand_t_like_x(t, x)
        beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
        return -0.5 * beta_t * x, beta_t / 2
class GVPCPlan(ICPlan):
    """Trigonometric variance-preserving path.

    alpha_t = sin(pi * t / 2) multiplies x1 and sigma_t = cos(pi * t / 2)
    multiplies x0, so alpha_t**2 + sigma_t**2 == 1 for every t.
    """

    def __init__(self, sigma=0.0):
        super().__init__(sigma)

    @staticmethod
    def _half_pi_angle(t):
        # Shared argument of every trig call below: pi * t / 2.
        return t * np.pi / 2

    def compute_alpha_t(self, t):
        """Compute coefficient of x1 and its time derivative."""
        angle = self._half_pi_angle(t)
        return th.sin(angle), np.pi / 2 * th.cos(angle)

    def compute_sigma_t(self, t):
        """Compute coefficient of x0 and its time derivative."""
        angle = self._half_pi_angle(t)
        return th.cos(angle), -np.pi / 2 * th.sin(angle)

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Return d_alpha_t / alpha_t in closed form, (pi / 2) / tan(pi * t / 2)."""
        angle = self._half_pi_angle(t)
        return np.pi / (2 * th.tan(angle))