Spaces:
Running
on
Zero
Running
on
Zero
| # MIT License | |
| # Copyright (c) 2023 Alexander Tong | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy | |
| # of this software and associated documentation files (the "Software"), to deal | |
| # in the Software without restriction, including without limitation the rights | |
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| # copies of the Software, and to permit persons to whom the Software is | |
| # furnished to do so, subject to the following conditions: | |
| # The above copyright notice and this permission notice shall be included in all | |
| # copies or substantial portions of the Software. | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| # SOFTWARE. | |
| # Copyright (c) [2023] [Alexander Tong] | |
| # Copyright (c) [2025] [Ziyue Jiang] | |
| # SPDX-License-Identifier: MIT | |
| # This file has been modified by Ziyue Jiang on 2025/03/19 | |
| # Original file was released under MIT, with the full license text available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE. | |
| # This modified file is released under the same license. | |
| import math | |
| import torch | |
| from typing import Union | |
| from torch.distributions import LogisticNormal | |
class LogitNormalTrainingTimesteps:
    """Sampler of training timesteps backed by a ``LogisticNormal`` distribution.

    The distribution is built once at construction time; ``sample`` draws from
    it and keeps only the first coordinate of every draw.

    Parameters
    ----------
    T : float
        must be strictly positive; stored as ``self.T`` (``sample`` does not
        read it)
    loc : float
        location handed to ``LogisticNormal``
    scale : float
        scale handed to ``LogisticNormal``
    """

    def __init__(self, T=1000.0, loc=0.0, scale=1.0):
        assert T > 0
        self.dist = LogisticNormal(loc, scale)
        self.T = T

    def sample(self, size, device):
        """Draw timesteps and move them to ``device``.

        Parameters
        ----------
        size : sample shape forwarded to ``self.dist.sample``
        device : target device for the returned tensor

        Returns
        -------
        t : Tensor, shape (*size)
            first coordinate of each draw from the distribution
        """
        draws = self.dist.sample(size)
        # Each draw has a trailing event axis; only its first entry is used.
        first_coordinate = draws[..., 0]
        return first_coordinate.to(device)
def pad_t_like_x(t, x):
    """Reshape the time vector t so that it broadcasts against x.

    Parameters
    ----------
    x : Tensor, shape (bs, *dim)
        represents the source minibatch
    t : FloatTensor, shape (bs)
        a Python float or int is returned unchanged

    Returns
    -------
    t : Tensor, shape (bs, 1, ..., 1)
        one singleton axis for every non-batch dimension of x

    Example
    -------
    x: Tensor (bs, C, W, H)
    t: Vector (bs)
    pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
    """
    if not isinstance(t, (float, int)):
        # One size-1 axis per non-batch dimension of x.
        singleton_axes = (1,) * (x.dim() - 1)
        return t.reshape((-1, *singleton_axes))
    # Plain Python scalars need no reshaping.
    return t
class ConditionalFlowMatcher:
    r"""Base class for conditional flow matching methods.

    Implements the independent conditional flow matching method from [1] and
    acts as the parent class for the other flow matching variants.

    It provides:
    - sampling from the gaussian probability path N(t * x1 + (1 - t) * x0, sigma)
    - the conditional flow ut(x1|x0) = x1 - x0
    - the score weighting used for $\nabla log p_t(x|x0, x1)$

    References
    ----------
    [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
    """

    def __init__(self, sigma: Union[float, int] = 0.0):
        r"""Set up the matcher with the hyper-parameter $\sigma$.

        Parameters
        ----------
        sigma : Union[float, int]
            standard deviation of the probability path
        """
        # Timesteps are drawn from this sampler whenever the caller passes t=None.
        self.time_sampler = LogitNormalTrainingTimesteps()
        self.sigma = sigma

    def compute_mu_t(self, x0, x1, t):
        """
        Mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: t * x1 + (1 - t) * x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        t_padded = pad_t_like_x(t, x0)
        target_part = t_padded * x1
        source_part = (1 - t_padded) * x0
        return target_part + source_part

    def compute_sigma_t(self, t):
        """
        Standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        t : FloatTensor, shape (bs)
            accepted for interface compatibility; the value is not used here

        Returns
        -------
        standard deviation sigma

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        # The standard deviation is the same at every time level.
        return self.sigma

    def sample_xt(self, x0, x1, t, epsilon):
        """
        Sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        epsilon : Tensor, shape (bs, *dim)
            noise sample from N(0, 1)

        Returns
        -------
        xt : Tensor, shape (bs, *dim)

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        mean = self.compute_mu_t(x0, x1, t)
        # Padding leaves scalars untouched and reshapes per-sample tensors
        # so that they broadcast against x0.
        std = pad_t_like_x(self.compute_sigma_t(t), x0)
        return mean + std * epsilon

    def compute_conditional_flow(self, x0, x1, t, xt):
        """
        Conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
            not used by this implementation
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt;
            not used by this implementation

        Returns
        -------
        ut : conditional vector field ut(x1|x0) = x1 - x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        # The field depends only on the two endpoints.
        return x1 - x0

    def sample_noise_like(self, x):
        """Return standard normal noise with the same shape, dtype and device as x."""
        return torch.randn_like(x)

    def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
        """
        Draw the sample xt (from N(t * x1 + (1 - t) * x0, sigma)) together with
        the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        (optionally) t : Tensor, shape (bs)
            represents the time levels
            if None, drawn from ``self.time_sampler``
        return_noise : bool
            return the noise sample epsilon

        Returns
        -------
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt
        ut : conditional vector field ut(x1|x0) = x1 - x0
        (optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        batch_size = x0.shape[0]
        if t is None:
            sampled = self.time_sampler.sample([batch_size], x0.device)
            t = sampled.type_as(x0)
        assert len(t) == batch_size, "t has to have batch size dimension"

        eps = self.sample_noise_like(x0)
        xt = self.sample_xt(x0, x1, t, eps)
        ut = self.compute_conditional_flow(x0, x1, t, xt)
        if not return_noise:
            return t, xt, ut
        return t, xt, ut, eps

    def compute_lambda(self, t):
        """Compute the lambda function, cited as Eq.(23) of the reference below.

        Parameters
        ----------
        t : FloatTensor, shape (bs)

        Returns
        -------
        lambda : score weighting function, 2 * sigma_t / (sigma**2 + 1e-8)

        References
        ----------
        Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
        """
        numerator = 2 * self.compute_sigma_t(t)
        # The small constant keeps the denominator non-zero when sigma is 0.
        denominator = self.sigma**2 + 1e-8
        return numerator / denominator
class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
    """Trigonometric interpolants of Albergo et al. 2023.

    Subclass of ConditionalFlowMatcher that overrides compute_mu_t and
    compute_conditional_flow to compute the trigonometric interpolants of [3].

    [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
    """

    def compute_mu_t(self, x0, x1, t):
        r"""Mean of the probability path, (Eq.5) from [3].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: cos(pi t/2)x0 + sin(pi t/2)x1

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        # Angle pi/2 * t, shaped to broadcast against x0.
        angle = math.pi / 2 * pad_t_like_x(t, x0)
        source_part = torch.cos(angle) * x0
        target_part = torch.sin(angle) * x1
        return source_part + target_part

    def compute_conditional_flow(self, x0, x1, t, xt):
        r"""Conditional vector field similar to [3].

        ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0),
        see Eq.(21) [3].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt;
            not used by this implementation

        Returns
        -------
        ut : conditional vector field
        ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0)

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        half_pi = math.pi / 2
        angle = half_pi * pad_t_like_x(t, x0)
        # Time derivative of cos(angle) * x0 + sin(angle) * x1.
        return half_pi * (torch.cos(angle) * x1 - torch.sin(angle) * x0)