import math

import numpy as np
import torch
from scipy.special import gammaincc
from scipy.stats import gamma

from config import COV_TOL
from utils import get_gpu_count, mahalanobis_torch, safe_cov_torch


def _speaker_of(label):
    """Return the speaker id of a label, i.e. the text before the first '-'."""
    return label.split("-")[0]


def _group_by_speaker(labels):
    """Map each speaker id to the (ascending) list of label indices it owns.

    Membership is decided on the exact speaker id, never on a string prefix,
    so that e.g. speaker "spk1" does not swallow the labels of "spk10".
    """
    groups = {}
    for idx, label in enumerate(labels):
        groups.setdefault(_speaker_of(label), []).append(idx)
    return groups


def _run_on_compute_device(score_fn, coords, labels, max_gpus):
    """Run ``score_fn(coords_tensor, labels)`` on the CPU or on a CUDA stream.

    When ``get_gpu_count(max_gpus)`` reports no GPU the tensor stays on the
    CPU. Otherwise the work runs on ``cuda:0`` (single GPU) or ``cuda:1``
    (two or more GPUs) inside a dedicated stream that is synchronized before
    returning.
    """
    ngpu = get_gpu_count(max_gpus)
    if ngpu == 0:
        return score_fn(torch.tensor(coords), labels)

    device = min(ngpu - 1, 1)
    device_str = f"cuda:{device}"
    coords_t = torch.tensor(coords, device=device_str)
    stream = torch.cuda.Stream(device=device_str)
    with torch.cuda.device(device):
        with torch.cuda.stream(stream):
            out = score_fn(coords_t, labels)
        stream.synchronize()
    return out


def pm_tail_gamma(d_out_sq, sq_dists):
    """
    Computes the PM measure based on the Gamma fit.

    The Gamma shape and scale are obtained by moment matching on ``sq_dists``
    (shape = mean^2 / var, scale = var / mean) and the score is the upper-tail
    probability of ``d_out_sq`` under that Gamma distribution.

    :param d_out_sq: squared mahalanobis distance from the output to its cluster on the manifold.
    :param sq_dists: squared mahalanobis distance of all distortions in the cluster to their
        cluster on the manifold (a torch tensor).
    :return: PM score as a float; 1.0 when the sample variance is exactly zero.
    """
    mu = sq_dists.mean().item()
    var = sq_dists.var(unbiased=True).item()
    if var == 0.0:
        return 1.0
    k = (mu**2) / var
    theta = var / mu
    # The survival function is the upper tail 1 - CDF, evaluated without the
    # cancellation error that "1.0 - cdf" suffers from far in the tail.
    return float(gamma.sf(d_out_sq, a=k, scale=theta))


def pm_tail_rank(d_out_sq, sq_dists):
    """
    A deprecated method to compute the PM measure based on the ranking method of distances.

    :param d_out_sq: squared distance of the output.
    :param sq_dists: tensor of squared distances of the distortions.
    :return: ``1 - (rank + 0.5) / (n + 1)`` where ``rank`` counts the entries of
        ``sq_dists`` strictly smaller than ``d_out_sq`` and ``n`` is their total number.
    """
    rank = int((sq_dists < d_out_sq).sum().item())
    n = sq_dists.numel()
    return 1.0 - (rank + 0.5) / (n + 1.0)


def diffusion_map_torch(
    X_np,
    labels_by_mix,
    *,
    cutoff=0.99,
    tol=1e-3,
    diffusion_time=1,
    alpha=0.0,
    eig_solver="lobpcg",
    k=None,
    device=None,
    return_eigs=False,
    return_complement=False,
    return_cval=False,
):
    """
    Compute diffusion maps from a high dimensional set of points.

    :param X_np: high dimensional input, one point per row.
    :param labels_by_mix: used to keep track of each source's coordinates on the manifold;
        only read when ``return_cval`` is set, to exclude labels containing "out".
    :param cutoff: the desired ratio between sum of kept and sum of all eigenvalues.
    :param tol: convergence tolerance handed to ``torch.lobpcg``; ignored by the "full" solver.
    :param diffusion_time: number of steps taken on the probability transition matrix.
    :param alpha: normalization factor in [0, 1].
    :param eig_solver: "lobpcg" or "full".
    :param k: pre-defined truncation dimension. With "full" the leading ``k + 1`` eigenpairs
        are kept; with "lobpcg" exactly ``k`` eigenpairs are requested.
    :param device: "cpu" or "cuda"; defaults to "cuda:0" when CUDA is available.
    :param return_eigs: return eigenvalues (the non-trivial ones, as a numpy array).
    :param return_complement: return complementary coordinates, not just kept coordinates.
    :param return_cval: calculate the psi_2 norm of the coordinates. It is only part of the
        returned tuple when ``return_complement`` and ``return_eigs`` are set as well.
    :return: numpy arrays, depending on the flags:
        ``Psi``; ``(Psi, lam)``; ``(Psi, Psi_rest)``; ``(Psi, Psi_rest, lam)``; or
        ``(Psi, Psi_rest, lam, c_val)``.
    :raises ValueError: if ``eig_solver`` is neither "lobpcg" nor "full".
    """
    from contextlib import nullcontext

    device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
    X = torch.as_tensor(X_np, dtype=torch.float32, device=device)
    N = X.shape[0]
    dev = X.device

    # Decide on the tensor's actual device type so that both strings and
    # torch.device objects are handled the same way.
    if dev.type == "cuda":
        stream = torch.cuda.Stream(device=dev)
        ctx_dev = torch.cuda.device(dev)
        ctx_stream = torch.cuda.stream(stream)
    else:
        stream = None
        ctx_dev = nullcontext()
        ctx_stream = nullcontext()

    c_val = None
    with ctx_dev:
        with ctx_stream:
            # Pairwise squared Euclidean distances; tiled for large inputs.
            if N > 1000:
                chunk = min(500, N)
                D2 = torch.zeros(N, N, device=dev)
                for r0 in range(0, N, chunk):
                    r1 = min(r0 + chunk, N)
                    for c0 in range(0, N, chunk):
                        c1 = min(c0 + chunk, N)
                        D2[r0:r1, c0:c1] = torch.cdist(X[r0:r1], X[c0:c1]).pow_(2)
            else:
                D2 = torch.cdist(X, X).pow_(2)

            # Kernel bandwidth: median of the strictly-upper-triangular distances.
            iu, ju = torch.triu_indices(N, N, offset=1, device=dev)
            eps = torch.median(D2[iu, ju])
            K = torch.exp(-D2 / (2 * eps))

            d = K.sum(dim=1)
            if alpha != 0.0:
                d_alpha_inv = d.pow(-alpha)
                K *= d_alpha_inv[:, None] * d_alpha_inv[None, :]
                d = K.sum(dim=1)

            # Symmetric normalisation D^{-1/2} K D^{-1/2}, done by broadcasting
            # instead of two dense N x N matrix products.
            d_rsqrt = torch.rsqrt(d)
            K_sym = d_rsqrt[:, None] * K * d_rsqrt[None, :]

            if eig_solver == "lobpcg":
                # torch.lobpcg needs at least 3x as many rows as requested
                # eigenpairs, hence the N // 3 cap on the default block size.
                m = k if k is not None else min(max(N // 3, 1), 50)
                init = torch.randn(N, m, device=dev)
                vals, vecs = torch.lobpcg(
                    K_sym, k=m, X=init, niter=200, tol=tol, largest=True
                )
            elif eig_solver == "full":
                vals, vecs = torch.linalg.eigh(K_sym)
                # eigh returns ascending eigenvalues; flip to descending.
                vals, vecs = vals.flip(0), vecs.flip(1)
                if k is not None:
                    vecs = vecs[:, : k + 1]
                    vals = vals[: k + 1]
            else:
                raise ValueError(f"Unknown eig_solver '{eig_solver}'")

            # Drop the leading (trivial) eigenpair.
            psi = vecs[:, 1:]
            lam = vals[1:]

            # Smallest L whose cumulative eigenvalue ratio reaches the cutoff.
            cum = torch.cumsum(lam, dim=0)
            L = int((cum / cum[-1] < cutoff).sum().item()) + 1

            lam_pow = lam.pow(diffusion_time)
            psi_all = psi * lam_pow
            Psi = psi_all[:, :L]
            Psi_rest = psi_all[:, L:]

            if return_cval:
                out_rows = {
                    ii for ii, name in enumerate(labels_by_mix) if "out" in name
                }
                valid_idx = torch.tensor(
                    [ii for ii in range(N) if ii not in out_rows], device=dev
                )
                pi_min = d[valid_idx].min() / d[valid_idx].sum()
                c_val = lam_pow[0] * pi_min.rsqrt() / math.log(2.0)

        if stream is not None:
            stream.synchronize()

    if return_complement and return_eigs and return_cval:
        return (
            Psi.cpu().numpy(),
            Psi_rest.cpu().numpy(),
            lam.cpu().numpy(),
            float(c_val),
        )
    if return_complement and return_eigs:
        return Psi.cpu().numpy(), Psi_rest.cpu().numpy(), lam.cpu().numpy()
    if return_complement:
        return Psi.cpu().numpy(), Psi_rest.cpu().numpy()
    if return_eigs:
        return Psi.cpu().numpy(), lam.cpu().numpy()
    return Psi.cpu().numpy()


def _ps_scores(coords_t, labels):
    """Compute the PS score of every speaker from an already-placed tensor."""
    groups = _group_by_speaker(labels)
    speakers = sorted(groups)

    # (mean, inverse covariance) of a speaker's non-output points, computed
    # lazily and at most once per speaker instead of once per speaker pair.
    rival_cache = {}

    def rival_stats(spk):
        if spk not in rival_cache:
            idx = [i for i in groups[spk] if not labels[i].endswith("-out")]
            pts = coords_t[idx]
            rival_cache[spk] = (
                pts.mean(0),
                torch.linalg.inv(safe_cov_torch(pts)),
            )
        return rival_cache[spk]

    out = {}
    for s in speakers:
        out_i = labels.index(f"{s}-out")
        ref_is = [i for i in groups[s] if i != out_i]
        mu = coords_t[ref_is].mean(0)
        inv = torch.linalg.inv(safe_cov_torch(coords_t[ref_is]))
        A = mahalanobis_torch(coords_t[out_i], mu, inv)

        B_list = []
        for o in speakers:
            if o == s:
                continue
            mu_o, inv_o = rival_stats(o)
            B_list.append(mahalanobis_torch(coords_t[out_i], mu_o, inv_o))

        # With no rival speaker the closest-rival distance defaults to zero.
        B_min = (
            torch.min(torch.stack(B_list))
            if B_list
            else torch.tensor(0.0, device=coords_t.device)
        )
        out[s] = (1 - A / (A + B_min + 1e-6)).item()
    return out


def compute_ps(coords, labels, max_gpus=None):
    """
    Computes the PS measure.

    For every speaker, ``A`` is the value of ``mahalanobis_torch`` between the
    speaker's "<id>-out" point and the speaker's own remaining points, and
    ``B_min`` is the smallest such value to any other speaker's non-output
    points. The score is ``1 - A / (A + B_min + 1e-6)``.

    :param coords: coordinates on the manifold.
    :param labels: assign source index per coordinate; a list of strings of the
        form "<speaker>-<tag>" containing one "<speaker>-out" entry per speaker.
    :param max_gpus: maximal number of GPUs to use.
    :return: the PS measure, as a dict mapping speaker id to a float.
    :raises ValueError: if a speaker has no "<speaker>-out" label.
    """
    return _run_on_compute_device(_ps_scores, coords, labels, max_gpus)


def _pm_scores(coords_t, labels, pm_method):
    """Compute the PM score of every speaker from an already-placed tensor."""
    groups = _group_by_speaker(labels)
    out = {}
    for s in sorted(groups):
        ref_i = labels.index(f"{s}-ref")
        out_i = labels.index(f"{s}-out")
        d_idx = [i for i in groups[s] if i not in {ref_i, out_i}]
        # A covariance needs at least two distortion points.
        if len(d_idx) < 2:
            out[s] = 0.0
            continue

        ref_v = coords_t[ref_i]
        dist = coords_t[d_idx] - ref_v
        N, D = dist.shape
        # Second-moment matrix of the distortions around the reference point.
        cov = dist.T @ dist / (N - 1)
        if torch.linalg.matrix_rank(cov) < D:
            # Regularise a rank-deficient matrix so that it can be inverted.
            cov += torch.eye(D, device=cov.device) * COV_TOL
        inv = torch.linalg.inv(cov)

        sq_dists = torch.stack(
            [mahalanobis_torch(coords_t[i], ref_v, inv) ** 2 for i in d_idx]
        )
        d_out_sq = float(mahalanobis_torch(coords_t[out_i], ref_v, inv) ** 2)

        # Any pm_method other than "rank" selects the Gamma fit.
        pm_score = (
            pm_tail_rank(d_out_sq, sq_dists)
            if pm_method == "rank"
            else pm_tail_gamma(d_out_sq, sq_dists)
        )
        out[s] = float(np.clip(pm_score, 0.0, 1.0))
    return out


def compute_pm(coords, labels, pm_method, max_gpus=None):
    """
    Computes the PM measure.

    :param coords: coordinates on the manifold.
    :param labels: assign source index per coordinate; a list of strings of the
        form "<speaker>-<tag>" containing one "<speaker>-ref" and one
        "<speaker>-out" entry per speaker. All other entries of a speaker are
        treated as distortions.
    :param pm_method: "rank" or "gamma" (any value other than "rank" uses the Gamma fit).
    :param max_gpus: maximal number of GPUs to use.
    :return: the PM measure, as a dict mapping speaker id to a float in [0, 1];
        0.0 for speakers with fewer than two distortion points.
    :raises ValueError: if a speaker lacks a "-ref" or "-out" label.
    """
    return _run_on_compute_device(
        lambda coords_t, lbls: _pm_scores(coords_t, lbls, pm_method),
        coords,
        labels,
        max_gpus,
    )


def pm_ci_components_full(
    coords_d, coords_rest, eigvals, labels, *, delta=0.05, K=1.0, C1=1.0, C2=1.0
):
    """
    Computes the error radius and tail bounds for the PM measure.

    :param coords_d: Retained diffusion maps coordinates.
    :param coords_rest: Complement diffusion maps coordinates.
    :param eigvals: Eigenvalues of the diffusion maps (currently not used).
    :param labels: Assign source index per coordinate
    :param delta: 1 - delta is the confidence score.
    :param K: Absolute constant (currently not used).
    :param C1: Absolute constant (currently not used).
    :param C2: Absolute constant (currently not used).
    :return: error radius and tail bounds for the PM measure, as two dicts
        keyed by speaker id. Both are 0.0 for every speaker when there are no
        complement coordinates, and for speakers with fewer than two distortions.
    """
    _EPS = 1e-12

    def _safe_x(a, theta):
        # Guard the division against a vanishing scale parameter.
        return a / max(theta, _EPS)

    m = coords_rest.shape[1]
    if m == 0:
        z = {s: 0.0 for s in {_speaker_of(l) for l in labels}}
        return z.copy(), z.copy()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    X_d = torch.tensor(coords_d, device=device)
    X_c = torch.tensor(coords_rest, device=device)

    groups = _group_by_speaker(labels)
    bias_ci = {}
    prob_ci = {}

    for s in sorted(groups):
        ref_i = labels.index(f"{s}-ref")
        out_i = labels.index(f"{s}-out")
        dist_is = [i for i in groups[s] if i not in {ref_i, out_i}]
        n_p = len(dist_is)
        if n_p < 2:
            bias_ci[s] = 0.0
            prob_ci[s] = 0.0
            continue

        ref_d = X_d[ref_i]
        ref_c = X_c[ref_i]
        D_mat = X_d[dist_is] - ref_d
        C_mat = X_c[dist_is] - ref_c

        Sigma_d = safe_cov_torch(D_mat)
        Sigma_c = safe_cov_torch(C_mat)
        C_dc = D_mat.T @ C_mat / (n_p - 1)
        inv_Sigma_d = torch.linalg.inv(Sigma_d)

        # Regression of the complement coordinates on the retained ones.
        proj = C_dc.T @ inv_Sigma_d
        # Schur-complement style residual covariance, with a small ridge.
        S_i = (
            Sigma_c
            - proj @ C_dc
            + torch.eye(X_c.shape[1], device=X_c.device) * 1e-9
        )
        S_inv = torch.linalg.inv(S_i)

        diff_out_d = X_d[out_i] - ref_d
        diff_out_c = X_c[out_i] - ref_c
        r_out = diff_out_c - proj @ diff_out_d
        delta_Gi_a = float(r_out @ S_inv @ r_out)

        # Residuals of every distortion point and their quadratic forms.
        R_p = torch.stack(
            [c_p - proj @ d_p for d_p, c_p in zip(D_mat, C_mat)], dim=0
        )
        delta_Gi_p = torch.sum(R_p @ S_inv * R_p, dim=1)
        delta_Gi_mu_max = float(delta_Gi_p.max())

        mah_sq = torch.stack([d_p @ inv_Sigma_d @ d_p for d_p in D_mat])
        mu_g = float(mah_sq.mean())
        var_g = mah_sq.var(unbiased=True)
        sigma2_g = float(var_g + 1e-12)
        sigma_g = math.sqrt(sigma2_g)

        full_sq = mah_sq + delta_Gi_p
        mu_full = float(full_sq.mean())
        sigma2_full = float(full_sq.var(unbiased=True) + 1e-12)

        # Test the raw variance: sigma2_g already carries the 1e-12 offset and
        # therefore can never compare equal to zero.
        if float(var_g) == 0.0:
            delta_Gi_k = delta_Gi_theta = 0.0
        else:
            factor = delta_Gi_mu_max * n_p / (n_p - 1)
            delta_Gi_k = 1.0 * factor * (mu_full + mu_g) / sigma2_g
            delta_Gi_theta = (
                1.0 * factor * (sigma2_full + sigma2_g) / (mu_g**2 + 1e-9)
            )

        # Moment-matched Gamma parameters and the PM value at the centre.
        k_d = (mu_g**2) / max(sigma2_g, 1e-12)
        theta_d = sigma2_g / max(mu_g, 1e-12)
        a_d = float(diff_out_d @ inv_Sigma_d @ diff_out_d)
        pm_center = gammaincc(k_d, _safe_x(a_d, theta_d))

        # Largest deviation of PM over the 8 corners of the (k, theta, a) box
        # spanned by the truncation perturbations.
        corner_vals = []
        for s_k in (-1, 1):
            for s_theta in (-1, 1):
                for s_a in (-1, 1):
                    k_c = max(k_d + s_k * delta_Gi_k, 1e-6)
                    theta_c = max(theta_d + s_theta * delta_Gi_theta, 1e-6)
                    a_c = max(a_d + s_a * delta_Gi_a, 1e-8)
                    corner_vals.append(gammaincc(k_c, _safe_x(a_c, theta_c)))
        bias_ci[s] = max(abs(v - pm_center) for v in corner_vals)

        # Finite-sample deviations of the mean and the variance estimates.
        R_sq = float(mah_sq.max()) + 1e-12
        log_term = math.log(6.0 / delta)
        eps_mu = (
            math.sqrt(2 * sigma2_g * log_term / n_p) + 3 * R_sq * log_term / n_p
        )
        eps_sigma = (
            math.sqrt(2 * R_sq**2 * log_term / n_p)
            + 3 * R_sq**2 * log_term / n_p
        )

        # Sensitivities of k and theta to the mean / spread estimates.
        g1_x = 2.0 * mu_g / (sigma2_g + 1e-9)
        g1_y = -2.0 * mu_g**2 / (sigma_g**3 + 1e-9)
        g2_x = -sigma2_g / (mu_g**2 + 1e-9)
        g2_y = 2.0 * sigma_g / (mu_g + 1e-9)

        # Each perturbation is capped at half of the centre value.
        delta_k = min(abs(g1_x) * eps_mu + abs(g1_y) * eps_sigma, 0.5 * k_d)
        delta_theta = min(
            abs(g2_x) * eps_mu + abs(g2_y) * eps_sigma, 0.5 * theta_d
        )
        delta_a = min(R_sq * math.sqrt(2 * log_term / n_p), 0.5 * a_d + 1e-12)

        pm_corners = []
        for s_k in (-1, 1):
            for s_theta in (-1, 1):
                for s_a in (-1, 1):
                    k_c = k_d + s_k * delta_k
                    theta_c = theta_d + s_theta * delta_theta
                    a_c = max(a_d + s_a * delta_a, 1e-8)
                    pm_corners.append(gammaincc(k_c, _safe_x(a_c, theta_c)))
        prob_ci[s] = max(abs(pm - pm_center) for pm in pm_corners)

    return bias_ci, prob_ci


def ps_ci_components_full(coords_d, coords_rest, eigvals, labels, *, delta=0.05):
    """
    Computes the error radius and tail bounds for the PS measure.

    :param coords_d: Retained diffusion maps coordinates.
    :param coords_rest: Complement diffusion maps coordinates.
    :param eigvals: Eigenvalues of the diffusion maps (currently not used).
    :param labels: Assign source index per coordinate
    :param delta: 1 - delta is the confidence score.
    :return: error radius and tail bounds for the PS measure, as two dicts
        keyed by speaker id. Both are 0.0 for every speaker when there are no
        complement coordinates, and for a speaker that has no rival cluster.
    """

    def _mean_dev(lam_max, delta, n_eff):
        return math.sqrt(2 * lam_max * math.log(2 / delta) / n_eff)

    def _rel_cov_dev(lam_max, trace, delta, n_eff, C=1.0):
        # r = trace / lam_max acts as the effective rank of the covariance.
        r = trace / lam_max
        abs_dev = (
            C * lam_max * (math.sqrt(r / n_eff) + (r + math.log(2 / delta)) / n_eff)
        )
        return abs_dev / lam_max

    def _maha_eps_m(a_hat, lam_min, lam_max, mean_dev, rel_cov_dev):
        term1 = 2 * math.sqrt(a_hat) * mean_dev * math.sqrt(lam_max / lam_min)
        term2 = a_hat * rel_cov_dev
        return term1 + term2

    if coords_rest.shape[1] == 0:
        z = {s: 0.0 for s in {_speaker_of(l) for l in labels}}
        return z.copy(), z.copy()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    X_d = torch.tensor(coords_d, device=device)
    X_c = torch.tensor(coords_rest, device=device)

    groups = _group_by_speaker(labels)
    spk_ids = sorted(groups)
    ridge = torch.eye(X_c.shape[1], device=X_c.device) * 1e-9

    bias = {}
    prob = {}

    for s in spk_ids:
        out_i = labels.index(f"{s}-out")
        ref_is = [i for i in groups[s] if i != out_i]

        mu_d = X_d[ref_is].mean(0)
        mu_c = X_c[ref_is].mean(0)
        Sigma_d = safe_cov_torch(X_d[ref_is])
        Sigma_c = safe_cov_torch(X_c[ref_is])
        C_dc = (X_d[ref_is] - mu_d).T @ (X_c[ref_is] - mu_c) / (len(ref_is) - 1)
        inv_Sd = torch.linalg.inv(Sigma_d)

        # One eigen-decomposition serves both spectrum extremes.
        eig_d = torch.linalg.eigvalsh(Sigma_d)
        lam_min = eig_d.min().clamp_min(1e-9).item()
        lam_max = eig_d.max().item()
        trace = torch.trace(Sigma_d).item()

        diff_d = X_d[out_i] - mu_d
        diff_c = X_c[out_i] - mu_c
        A_d = float(mahalanobis_torch(X_d[out_i], mu_d, inv_Sd))

        # Residual of the output in the complement space and its norm under
        # the (ridge-regularised) residual covariance.
        r_i = diff_c - C_dc.T @ inv_Sd @ diff_d
        S_i = Sigma_c - C_dc.T @ inv_Sd @ C_dc + ridge
        term_i = math.sqrt(float(r_i @ torch.linalg.solve(S_i, r_i)))

        # Search for the closest rival cluster; keep its statistics.
        B_d, term_j = float("inf"), 0.0
        Sig_o = None
        for o in spk_ids:
            if o == s:
                continue
            o_idxs = [i for i in groups[o] if not labels[i].endswith("-out")]
            muo_d = X_d[o_idxs].mean(0)
            muo_c = X_c[o_idxs].mean(0)
            Sig_o_tmp = safe_cov_torch(X_d[o_idxs])
            inv_So = torch.linalg.inv(Sig_o_tmp)
            this_B = float(mahalanobis_torch(X_d[out_i], muo_d, inv_So))
            if this_B < B_d:
                B_d = this_B
                Sig_o = Sig_o_tmp
                diff_do = X_d[out_i] - muo_d
                diff_co = X_c[out_i] - muo_c
                C_oc = (
                    (X_d[o_idxs] - muo_d).T
                    @ (X_c[o_idxs] - muo_c)
                    / (len(o_idxs) - 1)
                )
                r_j = diff_co - C_oc.T @ inv_So @ diff_do
                S_j = safe_cov_torch(X_c[o_idxs]) - C_oc.T @ inv_So @ C_oc + ridge
                term_j = math.sqrt(float(r_j @ torch.linalg.solve(S_j, r_j)))

        if Sig_o is None:
            # No rival cluster was found, so B_d is still infinite and the
            # bias expression would evaluate to inf / inf = NaN. Its limit for
            # B_d -> inf is zero, which is what gets reported.
            bias[s] = 0.0
            prob[s] = 0.0
            continue

        denom = A_d + B_d
        bias[s] = (B_d * term_i + A_d * term_j) / (denom**2)

        eig_o = torch.linalg.eigvalsh(Sig_o)
        lam_min_o = eig_o.min().clamp_min(1e-9).item()
        lam_max_o = eig_o.max().item()
        trace_o = torch.trace(Sig_o).item()

        # The same effective sample size, derived from this speaker's own
        # cluster, is applied to both clusters.
        n_eff = max(int(0.7 * len(ref_is)), 3)

        # Floor the smallest eigenvalue at a fraction of the largest one to
        # bound the condition number used in the deviation terms.
        RIDGE = 0.05
        lam_min_eff = max(lam_min, RIDGE * lam_max)
        lam_min_o_eff = max(lam_min_o, RIDGE * lam_max_o)

        eps_i_sg = _maha_eps_m(
            A_d,
            lam_min_eff,
            lam_max,
            _mean_dev(lam_max, delta / 2, n_eff),
            _rel_cov_dev(lam_max, trace, delta / 2, n_eff),
        )
        eps_j_sg = _maha_eps_m(
            B_d,
            lam_min_o_eff,
            lam_max_o,
            _mean_dev(lam_max_o, delta / 2, n_eff),
            _rel_cov_dev(lam_max_o, trace_o, delta / 2, n_eff),
        )

        grad_l2 = math.hypot(A_d, B_d) / (A_d + B_d) ** 2
        ps_radius = grad_l2 * math.hypot(eps_i_sg, eps_j_sg)
        prob[s] = min(1.0, ps_radius)

    return bias, prob