Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import torch | |
| from scipy.spatial.transform import Rotation | |
| MIN_EPS, MAX_EPS, N_EPS = 0.01, 2, 1000 | |
| X_N = 2000 | |
| """ | |
| Preprocessing for the SO(3) sampling and score computations, truncated infinite series are computed and then | |
| cached to memory, therefore the precomputation is only run the first time the repository is run on a machine | |
| """ | |
| omegas = np.linspace(0, np.pi, X_N + 1)[1:] | |
| def _compose(r1, r2): # R1 @ R2 but for Euler vecs | |
| return Rotation.from_matrix(Rotation.from_rotvec(r1).as_matrix() @ Rotation.from_rotvec(r2).as_matrix()).as_rotvec() | |
| def _expansion(omega, eps, L=2000): # the summation term only | |
| p = 0 | |
| for l in range(L): | |
| p += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) * np.sin(omega * (l + 1 / 2)) / np.sin(omega / 2) | |
| return p | |
| def _density(expansion, omega, marginal=True): # if marginal, density over [0, pi], else over SO(3) | |
| if marginal: | |
| return expansion * (1 - np.cos(omega)) / np.pi | |
| else: | |
| return expansion / 8 / np.pi ** 2 # the constant factor doesn't affect any actual calculations though | |
| def _score(exp, omega, eps, L=2000): # score of density over SO(3) | |
| dSigma = 0 | |
| for l in range(L): | |
| hi = np.sin(omega * (l + 1 / 2)) | |
| dhi = (l + 1 / 2) * np.cos(omega * (l + 1 / 2)) | |
| lo = np.sin(omega / 2) | |
| dlo = 1 / 2 * np.cos(omega / 2) | |
| dSigma += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) * (lo * dhi - hi * dlo) / lo ** 2 | |
| return dSigma / exp | |
| if os.path.exists('.so3_omegas_array2.npy'): | |
| _omegas_array = np.load('.so3_omegas_array2.npy') | |
| _cdf_vals = np.load('.so3_cdf_vals2.npy') | |
| _score_norms = np.load('.so3_score_norms2.npy') | |
| _exp_score_norms = np.load('.so3_exp_score_norms2.npy') | |
| else: | |
| _eps_array = 10 ** np.linspace(np.log10(MIN_EPS), np.log10(MAX_EPS), N_EPS) | |
| _omegas_array = np.linspace(0, np.pi, X_N + 1)[1:] | |
| _exp_vals = np.asarray([_expansion(_omegas_array, eps) for eps in _eps_array]) | |
| _pdf_vals = np.asarray([_density(_exp, _omegas_array, marginal=True) for _exp in _exp_vals]) | |
| _cdf_vals = np.asarray([_pdf.cumsum() / X_N * np.pi for _pdf in _pdf_vals]) | |
| _score_norms = np.asarray([_score(_exp_vals[i], _omegas_array, _eps_array[i]) for i in range(len(_eps_array))]) | |
| _exp_score_norms = np.sqrt(np.sum(_score_norms**2 * _pdf_vals, axis=1) / np.sum(_pdf_vals, axis=1) / np.pi) | |
| np.save('.so3_omegas_array2.npy', _omegas_array) | |
| np.save('.so3_cdf_vals2.npy', _cdf_vals) | |
| np.save('.so3_score_norms2.npy', _score_norms) | |
| np.save('.so3_exp_score_norms2.npy', _exp_score_norms) | |
| def sample(eps): | |
| eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * N_EPS | |
| eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS - 1) | |
| x = np.random.rand() | |
| return np.interp(x, _cdf_vals[eps_idx], _omegas_array) | |
| def sample_vec(eps): | |
| x = np.random.randn(3) | |
| x /= np.linalg.norm(x) | |
| return x * sample(eps) | |
| def score_vec(eps, vec): | |
| eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * N_EPS | |
| eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS - 1) | |
| om = np.linalg.norm(vec) | |
| return np.interp(om, _omegas_array, _score_norms[eps_idx]) * vec / om | |
| def score_norm(eps): | |
| eps = eps.numpy() | |
| eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * N_EPS | |
| eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS-1) | |
| return torch.from_numpy(_exp_score_norms[eps_idx]).float() | |