# This file includes code derived from the SiT project (https://github.com/willisma/SiT), # which is licensed under the MIT License. # # MIT License # # Copyright (c) Meta Platforms, Inc. and affiliates. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. import numpy as np import torch as th import torch.nn as nn from torchdiffeq import odeint from functools import partial from tqdm import tqdm class sde: """SDE solver class""" def __init__( self, drift, diffusion, *, t0, t1, num_steps, sampler_type, ): assert t0 < t1, "SDE sampler has to be in forward time" self.num_timesteps = num_steps self.t = th.linspace(t0, t1, num_steps) self.dt = self.t[1] - self.t[0] self.drift = drift self.diffusion = diffusion self.sampler_type = sampler_type def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs): w_cur = th.randn(x.size()).to(x) t = th.ones(x.size(0)).to(x) * t dw = w_cur * th.sqrt(self.dt) drift = self.drift(x, t, model, **model_kwargs) diffusion = self.diffusion(x, t) mean_x = x + drift * self.dt x = mean_x + th.sqrt(2 * diffusion) * dw return x, mean_x def __Heun_step(self, x, _, t, model, **model_kwargs): w_cur = th.randn(x.size()).to(x) dw = w_cur * th.sqrt(self.dt) t_cur = th.ones(x.size(0)).to(x) * t diffusion = self.diffusion(x, t_cur) xhat = x + th.sqrt(2 * diffusion) * dw K1 = self.drift(xhat, t_cur, model, **model_kwargs) xp = xhat + self.dt * K1 K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs) return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step def __forward_fn(self): """TODO: generalize here by adding all private functions ending with steps to it""" sampler_dict = { "Euler": self.__Euler_Maruyama_step, "Heun": self.__Heun_step, } try: sampler = sampler_dict[self.sampler_type] except: raise NotImplementedError("Smapler type not implemented.") return sampler def sample(self, init, model, **model_kwargs): """forward loop of sde""" x = init mean_x = init samples = [] sampler = self.__forward_fn() for ti in self.t[:-1]: with th.no_grad(): x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs) samples.append(x) return samples class ode: """ODE solver class""" def __init__( self, drift, *, t0, t1, sampler_type, num_steps, atol, rtol, ): assert t0 < t1, "ODE sampler has to be in forward time" self.drift = drift self.t = th.linspace(t0, t1, num_steps) self.atol = atol self.rtol = rtol self.sampler_type = sampler_type def sample(self, x, model, **model_kwargs): device = x[0].device if isinstance(x, tuple) else x.device def _fn(t, x): t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t model_output = self.drift(x, t, model, **model_kwargs) return model_output t = self.t.to(device) atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol] rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol] samples = odeint( _fn, x, t, method=self.sampler_type, atol=atol, rtol=rtol ) return samples