MAPSS-measures / metrics.py
AIvry's picture
Upload 11 files
b759ccc verified
import math
import numpy as np
import torch
from scipy.special import gammaincc
from scipy.stats import gamma
from config import COV_TOL
from utils import get_gpu_count, mahalanobis_torch, safe_cov_torch
def pm_tail_gamma(d_out_sq, sq_dists):
    """
    Computes the PM measure based on the Gamma fit.

    A Gamma distribution is fitted to ``sq_dists`` by matching moments
    (shape = mean**2 / variance, scale = variance / mean) and the score is the
    upper-tail probability of that distribution at ``d_out_sq``.

    :param d_out_sq: squared mahalanobis distance from the output to its cluster on the manifold.
    :param sq_dists: squared mahalanobis distance of all distortions in the cluster to their cluster on the manifold.
    :return: PM score as a float; 1.0 when the sample variance of ``sq_dists`` is exactly zero.
    """
    mu = sq_dists.mean().item()
    var = sq_dists.var(unbiased=True).item()
    if var == 0.0:
        # No spread to fit a Gamma to; the original contract returns 1.0 here.
        return 1.0
    k = (mu**2) / var
    theta = var / mu
    # Use the survival function directly: 1 - cdf cancels to exactly 0.0 once
    # the cdf rounds to 1.0, which loses every far-tail score.
    return float(gamma.sf(d_out_sq, a=k, scale=theta))
def pm_tail_rank(d_out_sq, sq_dists):
    """
    A deprecated method to compute the PM measure from the rank of the output distance.

    The rank is the number of entries of ``sq_dists`` that are strictly smaller
    than ``d_out_sq``.

    :param d_out_sq: squared distance of the output.
    :param sq_dists: tensor of squared distances of the distortions.
    :return: ``1 - (rank + 0.5) / (n + 1)`` with ``n`` the number of entries in ``sq_dists``.
    """
    total = sq_dists.numel()
    strictly_below = sq_dists < d_out_sq
    n_below = int(strictly_below.sum().item())
    lower_fraction = (n_below + 0.5) / (total + 1.0)
    return 1.0 - lower_fraction
def diffusion_map_torch(
    X_np,
    labels_by_mix,
    *,
    cutoff=0.99,
    tol=1e-3,
    diffusion_time=1,
    alpha=0.0,
    eig_solver="lobpcg",
    k=None,
    device=None,
    return_eigs=False,
    return_complement=False,
    return_cval=False,
):
    """
    Compute diffusion maps from a high dimensional set of points.

    A Gaussian kernel whose bandwidth is the median pairwise squared distance
    is built, optionally alpha-normalised, symmetrically normalised by its row
    sums and eigendecomposed. The first eigenpair is dropped and the remaining
    eigenvectors are scaled by ``eigenvalue ** diffusion_time``.

    :param X_np: high dimensional input, one point per row.
    :param labels_by_mix: used to keep track of each source's coordinates on the manifold;
        only read when ``return_cval`` is set, where entries containing "out" are excluded.
    :param cutoff: the desired ratio between sum of kept and sum of all eigenvalues.
    :param tol: convergence tolerance forwarded to the "lobpcg" solver.
    :param diffusion_time: number of steps taken on the probability transition matrix.
    :param alpha: normalization factor in [0, 1].
    :param eig_solver: "lobpcg" or "full". With "lobpcg", ``k`` (default ``min(N - 1, 50)``)
        eigenpairs are computed including the dropped first one; with "full" and ``k`` given,
        ``k + 1`` pairs are kept. When the input has fewer than ``3 * k`` rows "lobpcg" is
        not applicable and a dense decomposition keeping the same number of pairs is used.
    :param k: pre-defined truncation dimension.
    :param device: "cpu" or "cuda" (string or ``torch.device``); defaults to "cuda:0" when available.
    :param return_eigs: return eigenvalues as well.
    :param return_complement: return complementary coordinates, not just kept coordinates.
    :param return_cval: calculate the psi_2 norm of the coordinates; it is only returned when
        ``return_complement`` and ``return_eigs`` are both set as well.
    :return: ``Psi`` by default; ``(Psi, lam)`` with ``return_eigs``; ``(Psi, Psi_rest)`` with
        ``return_complement``; ``(Psi, Psi_rest, lam)`` with both; ``(Psi, Psi_rest, lam, c_val)``
        with all three flags. Arrays are NumPy arrays, ``c_val`` is a float.
    :raises ValueError: if ``eig_solver`` is not "lobpcg" or "full".
    """
    from contextlib import nullcontext

    # Normalise to a torch.device so the CPU/CUDA decision does not depend on
    # whether the caller passed a string or a device object.
    device = torch.device(
        device or ("cuda:0" if torch.cuda.is_available() else "cpu")
    )
    use_cuda = device.type == "cuda" and torch.cuda.is_available()

    X = torch.as_tensor(X_np, dtype=torch.float32, device=device)
    N = X.shape[0]

    if use_cuda:
        stream = torch.cuda.Stream(device=device)
        # X was produced on the current stream; the side stream must not read
        # it before that work has finished.
        stream.wait_stream(torch.cuda.current_stream(device))
        ctx_dev = torch.cuda.device(device)
        ctx_stream = torch.cuda.stream(stream)
    else:
        stream = None
        ctx_dev = nullcontext()
        ctx_stream = nullcontext()

    with ctx_dev:
        with ctx_stream:
            # Pairwise squared distances, filled block by block for large inputs.
            if N > 1000:
                chunk = min(500, N)
                D2 = torch.zeros(N, N, device=device)
                for r0 in range(0, N, chunk):
                    r1 = min(r0 + chunk, N)
                    for c0 in range(0, N, chunk):
                        c1 = min(c0 + chunk, N)
                        D2[r0:r1, c0:c1] = torch.cdist(X[r0:r1], X[c0:c1]).pow_(2)
            else:
                D2 = torch.cdist(X, X).pow_(2)

            # Bandwidth: median over the strict upper triangle. Floor it so a
            # majority of duplicated points cannot produce 0/0 in the kernel.
            rows, cols = torch.triu_indices(N, N, offset=1, device=device)
            eps = torch.median(D2[rows, cols]).clamp_min(torch.finfo(D2.dtype).tiny)
            K = torch.exp(-D2 / (2 * eps))

            d = K.sum(dim=1)
            if alpha != 0.0:
                d_alpha_inv = d.pow(-alpha)
                K *= d_alpha_inv[:, None] * d_alpha_inv[None, :]
                d = K.sum(dim=1)

            # D^{-1/2} K D^{-1/2} as an element-wise scaling: O(N^2) instead of
            # two dense O(N^3) products with a diagonal matrix.
            inv_sqrt_d = torch.rsqrt(d)
            K_sym = K * torch.outer(inv_sqrt_d, inv_sqrt_d)

            if eig_solver == "lobpcg":
                m = k if k is not None else min(N - 1, 50)
                if N < 3 * m:
                    # torch.lobpcg rejects problems with fewer than 3*m rows;
                    # take the same m leading pairs from a dense solve instead.
                    vals, vecs = torch.linalg.eigh(K_sym)
                    vals, vecs = vals.flip(0)[:m], vecs.flip(1)[:, :m]
                else:
                    init = torch.randn(N, m, device=device)
                    vals, vecs = torch.lobpcg(
                        K_sym, k=m, X=init, niter=200, tol=tol, largest=True
                    )
            elif eig_solver == "full":
                vals, vecs = torch.linalg.eigh(K_sym)
                # eigh returns ascending eigenvalues; reorder to descending.
                vals, vecs = vals.flip(0), vecs.flip(1)
                if k is not None:
                    vecs = vecs[:, : k + 1]
                    vals = vals[: k + 1]
            else:
                raise ValueError(f"Unknown eig_solver '{eig_solver}'")

            # Drop the first eigenpair and keep the smallest prefix whose
            # eigenvalue mass reaches ``cutoff``.
            psi = vecs[:, 1:]
            lam = vals[1:]
            cum = torch.cumsum(lam, dim=0)
            L = int((cum / cum[-1] < cutoff).sum().item()) + 1
            lam_pow = lam.pow(diffusion_time)
            psi_all = psi * lam_pow
            Psi = psi_all[:, :L]
            Psi_rest = psi_all[:, L:]

            if return_cval:
                out_rows = {
                    ii for ii, name in enumerate(labels_by_mix) if "out" in name
                }
                valid_idx = torch.tensor(
                    [ii for ii in range(N) if ii not in out_rows],
                    dtype=torch.long,
                    device=device,
                )
                pi_min = d[valid_idx].min() / d[valid_idx].sum()
                c_val = lam_pow[0] * pi_min.rsqrt() / math.log(2.0)

        if stream is not None:
            stream.synchronize()

    Psi_np = Psi.cpu().numpy()
    if return_complement and return_eigs and return_cval:
        return (
            Psi_np,
            Psi_rest.cpu().numpy(),
            lam.cpu().numpy(),
            float(c_val),
        )
    if return_complement and return_eigs:
        return Psi_np, Psi_rest.cpu().numpy(), lam.cpu().numpy()
    if return_complement:
        return Psi_np, Psi_rest.cpu().numpy()
    if return_eigs:
        return Psi_np, lam.cpu().numpy()
    return Psi_np
def _ps_scores_on_tensor(coords_t, labels):
    """
    Score every source found in ``labels`` from coordinates already placed on their device.

    :param coords_t: tensor of manifold coordinates, indexed like ``labels``.
    :param labels: list of "<source>-<kind>" strings; each source needs a "<source>-out" entry.
    :return: dict mapping each source id to its PS score.
    """
    # Group by the exact source id. Prefix matching would let source "1"
    # absorb the points of "10", "11", ...
    src_of = [l.split("-")[0] for l in labels]
    spks_here = sorted(set(src_of))

    competitor_cache = {}

    def competitor_stats(o):
        # Mean and inverse covariance of source ``o`` without its "-out" rows.
        # Cached because they do not depend on which source is being scored
        # (assumes safe_cov_torch is deterministic — TODO confirm).
        if o not in competitor_cache:
            o_idxs = [
                i
                for i, l in enumerate(labels)
                if src_of[i] == o and not l.endswith("-out")
            ]
            mu_o = coords_t[o_idxs].mean(0)
            inv_o = torch.linalg.inv(safe_cov_torch(coords_t[o_idxs]))
            competitor_cache[o] = (mu_o, inv_o)
        return competitor_cache[o]

    out = {}
    for s in spks_here:
        idxs = [i for i, src in enumerate(src_of) if src == s]
        out_i = labels.index(f"{s}-out")
        ref_is = [i for i in idxs if i != out_i]
        mu = coords_t[ref_is].mean(0)
        inv = torch.linalg.inv(safe_cov_torch(coords_t[ref_is]))
        # A: output against its own cluster; B_min: against the closest other cluster.
        A = mahalanobis_torch(coords_t[out_i], mu, inv)
        B_list = [
            mahalanobis_torch(coords_t[out_i], *competitor_stats(o))
            for o in spks_here
            if o != s
        ]
        B_min = (
            torch.min(torch.stack(B_list))
            if B_list
            else torch.tensor(0.0, device=coords_t.device)
        )
        out[s] = (1 - A / (A + B_min + 1e-6)).item()
    return out


def compute_ps(coords, labels, max_gpus=None):
    """
    Computes the PS measure.

    For every source the score is ``1 - A / (A + B_min + 1e-6)``, where ``A`` is the
    ``mahalanobis_torch`` value of the source's output against its own cluster and
    ``B_min`` is the smallest such value against any other source's cluster
    (0 when there is no other source).

    :param coords: coordinates on the manifold.
    :param labels: assign source index per coordinate, as a list of "<source>-<kind>" strings.
    :param max_gpus: maximal number of GPUs to use.
    :return: the PS measure, as a dict keyed by source id.
    :raises ValueError: if a source has no "<source>-out" label.
    """
    ngpu = get_gpu_count(max_gpus)
    if ngpu == 0:
        return _ps_scores_on_tensor(torch.tensor(coords), labels)

    # Second GPU when at least two are usable, otherwise the first.
    device = min(ngpu - 1, 1)
    device_str = f"cuda:{device}"
    coords_t = torch.tensor(coords, device=device_str)
    stream = torch.cuda.Stream(device=device_str)
    # The side stream must not read coords_t before the stream that created it is done.
    stream.wait_stream(torch.cuda.current_stream(device))
    with torch.cuda.device(device):
        with torch.cuda.stream(stream):
            out = _ps_scores_on_tensor(coords_t, labels)
        stream.synchronize()
    return out
def _pm_scores_on_tensor(coords_t, labels, pm_method):
    """
    Score every source found in ``labels`` from coordinates already placed on their device.

    :param coords_t: tensor of manifold coordinates, indexed like ``labels``.
    :param labels: list of "<source>-<kind>" strings with "<source>-ref" and "<source>-out" entries.
    :param pm_method: "rank" uses ``pm_tail_rank``; any other value uses ``pm_tail_gamma``.
    :return: dict mapping each source id to its PM score clipped to [0, 1].
    """
    # Group by the exact source id. Prefix matching would let source "1"
    # absorb the points of "10", "11", ...
    src_of = [l.split("-")[0] for l in labels]
    out = {}
    for s in sorted(set(src_of)):
        idxs = [i for i, src in enumerate(src_of) if src == s]
        ref_i = labels.index(f"{s}-ref")
        out_i = labels.index(f"{s}-out")
        d_idx = [i for i in idxs if i not in {ref_i, out_i}]
        if len(d_idx) < 2:
            # Too few distortions to estimate a covariance.
            out[s] = 0.0
            continue
        ref_v = coords_t[ref_i]
        dist = coords_t[d_idx] - ref_v
        N, D = dist.shape
        # Second moment of the distortions around the reference point.
        cov = dist.T @ dist / (N - 1)
        if torch.linalg.matrix_rank(cov) < D:
            # Rank deficient: load the diagonal so the inverse exists.
            cov += torch.eye(D, dtype=cov.dtype, device=cov.device) * COV_TOL
        inv = torch.linalg.inv(cov)
        sq_dists = torch.stack(
            [mahalanobis_torch(coords_t[i], ref_v, inv) ** 2 for i in d_idx]
        )
        d_out_sq = float(mahalanobis_torch(coords_t[out_i], ref_v, inv) ** 2)
        if pm_method == "rank":
            pm_score = pm_tail_rank(d_out_sq, sq_dists)
        else:
            pm_score = pm_tail_gamma(d_out_sq, sq_dists)
        out[s] = float(np.clip(pm_score, 0.0, 1.0))
    return out


def compute_pm(coords, labels, pm_method, max_gpus=None):
    """
    Computes the PM measure.

    For every source, the distortions (all of its points except the "-ref" and "-out"
    entries) define a covariance around the reference point; the squared
    ``mahalanobis_torch`` value of the output is then scored against those of the distortions.

    :param coords: coordinates on the manifold.
    :param labels: assign source index per coordinate, as a list of "<source>-<kind>" strings.
    :param pm_method: "rank" or "gamma"; any value other than "rank" selects the gamma fit.
    :param max_gpus: maximal number of GPUs to use.
    :return: the PM measure, as a dict keyed by source id; 0.0 for sources with fewer than
        two distortions.
    :raises ValueError: if a source has no "<source>-ref" or "<source>-out" label.
    """
    ngpu = get_gpu_count(max_gpus)
    if ngpu == 0:
        return _pm_scores_on_tensor(torch.tensor(coords), labels, pm_method)

    # Second GPU when at least two are usable, otherwise the first.
    device = min(ngpu - 1, 1)
    device_str = f"cuda:{device}"
    coords_t = torch.tensor(coords, device=device_str)
    stream = torch.cuda.Stream(device=device_str)
    # The side stream must not read coords_t before the stream that created it is done.
    stream.wait_stream(torch.cuda.current_stream(device))
    with torch.cuda.device(device):
        with torch.cuda.stream(stream):
            out = _pm_scores_on_tensor(coords_t, labels, pm_method)
        stream.synchronize()
    return out
def pm_ci_components_full(
    coords_d, coords_rest, eigvals, labels, *, delta=0.05, K=1.0, C1=1.0, C2=1.0
):
    """
    Computes the error radius and tail bounds for the PM measure.

    For each source, the Gamma parameters (shape, scale) and the output's quadratic
    form are perturbed by two sets of deltas, one derived from the complement
    coordinates and one from finite-sample terms in ``delta``. Each returned value is
    the largest change of ``gammaincc`` over the eight sign combinations.

    :param coords_d: Retained diffusion maps coordinates.
    :param coords_rest: Complement diffusion maps coordinates.
    :param eigvals: Eigenvalues of the diffusion maps (accepted but not read here).
    :param labels: Assign source index per coordinate, as a list of "<source>-<kind>" strings.
    :param delta: 1 - delta is the confidence score.
    :param K: Absolute constant (accepted but not read here).
    :param C1: Absolute constant (accepted but not read here).
    :param C2: Absolute constant (accepted but not read here).
    :return: error radius and tail bounds for the PM measure, as two dicts keyed by
        source id. Both are 0.0 for every source when there are no complement
        coordinates, and for any source with fewer than two distortions.
    :raises ValueError: if a source has no "<source>-ref" or "<source>-out" label.
    """
    _EPS = 1e-12

    def _safe_x(a, theta):
        # Gamma argument a / theta with the scale floored away from zero.
        return a / max(theta, _EPS)

    # Group by the exact source id. Prefix matching would let source "1"
    # absorb the points of "10", "11", ...
    src_of = [l.split("-")[0] for l in labels]
    spk_ids = sorted(set(src_of))

    m = coords_rest.shape[1]
    if m == 0:
        z = {s: 0.0 for s in spk_ids}
        return z.copy(), z.copy()

    dev = "cuda:0" if torch.cuda.is_available() else "cpu"
    X_d = torch.tensor(coords_d, device=dev)
    X_c = torch.tensor(coords_rest, device=dev)

    bias_ci = {}
    prob_ci = {}
    for s in spk_ids:
        idxs = [i for i, src in enumerate(src_of) if src == s]
        ref_i = labels.index(f"{s}-ref")
        out_i = labels.index(f"{s}-out")
        dist_is = [i for i in idxs if i not in {ref_i, out_i}]
        n_p = len(dist_is)
        if n_p < 2:
            bias_ci[s] = 0.0
            prob_ci[s] = 0.0
            continue

        ref_d = X_d[ref_i]
        ref_c = X_c[ref_i]
        D_mat = X_d[dist_is] - ref_d
        C_mat = X_c[dist_is] - ref_c
        Sigma_d = safe_cov_torch(D_mat)
        Sigma_c = safe_cov_torch(C_mat)
        C_dc = D_mat.T @ C_mat / (n_p - 1)
        inv_Sigma_d = torch.linalg.inv(Sigma_d)
        # Loop-invariant projection from retained to complement coordinates.
        proj = C_dc.T @ inv_Sigma_d
        # Schur complement of the retained block, with a small diagonal load.
        S_i = (
            Sigma_c
            - proj @ C_dc
            + torch.eye(X_c.shape[1], device=X_c.device) * 1e-9
        )
        S_inv = torch.linalg.inv(S_i)

        diff_out_d = X_d[out_i] - ref_d
        diff_out_c = X_c[out_i] - ref_c
        r_out = diff_out_c - proj @ diff_out_d
        delta_Gi_a = float(r_out @ S_inv @ r_out)

        # Residual of every distortion after removing what the retained part explains.
        R_p = torch.stack(
            [c_p - proj @ d_p for d_p, c_p in zip(D_mat, C_mat)], dim=0
        )
        delta_Gi_p = torch.sum(R_p @ S_inv * R_p, dim=1)
        delta_Gi_mu_max = float(delta_Gi_p.max())

        mah_sq = torch.stack([d_p @ inv_Sigma_d @ d_p for d_p in D_mat])
        var_g = mah_sq.var(unbiased=True)
        mu_g = float(mah_sq.mean())
        sigma2_g = float(var_g + 1e-12)
        sigma_g = math.sqrt(sigma2_g)
        full_sq = mah_sq + delta_Gi_p
        mu_full = float(full_sq.mean())
        sigma2_full = float(full_sq.var(unbiased=True) + 1e-12)

        # Test the raw variance: sigma2_g already carries the 1e-12 offset and
        # can therefore never compare equal to zero.
        if float(var_g) == 0.0:
            delta_Gi_k = delta_Gi_theta = 0.0
        else:
            factor = delta_Gi_mu_max * n_p / (n_p - 1)
            delta_Gi_k = 1.0 * factor * (mu_full + mu_g) / sigma2_g
            delta_Gi_theta = 1.0 * factor * (sigma2_full + sigma2_g) / (mu_g**2 + 1e-9)

        # Moment-matched Gamma parameters and the output's quadratic form.
        k_d = (mu_g**2) / max(sigma2_g, 1e-12)
        theta_d = sigma2_g / max(mu_g, 1e-12)
        a_d = float(diff_out_d @ inv_Sigma_d @ diff_out_d)
        pm_center = gammaincc(k_d, _safe_x(a_d, theta_d))

        corner_vals = []
        for s_k in (-1, 1):
            for s_theta in (-1, 1):
                for s_a in (-1, 1):
                    k_c = max(k_d + s_k * delta_Gi_k, 1e-6)
                    theta_c = max(theta_d + s_theta * delta_Gi_theta, 1e-6)
                    a_c = max(a_d + s_a * delta_Gi_a, 1e-8)
                    corner_vals.append(gammaincc(k_c, _safe_x(a_c, theta_c)))
        bias_ci[s] = max(abs(v - pm_center) for v in corner_vals)

        # Finite-sample deviations; each is capped at half of the quantity it perturbs.
        R_sq = float(mah_sq.max()) + 1e-12
        log_term = math.log(6.0 / delta)
        eps_mu = math.sqrt(2 * sigma2_g * log_term / n_p) + 3 * R_sq * log_term / n_p
        eps_sigma = (
            math.sqrt(2 * R_sq**2 * log_term / n_p) + 3 * R_sq**2 * log_term / n_p
        )
        g1_x = 2.0 * mu_g / (sigma2_g + 1e-9)
        g1_y = -2.0 * mu_g**2 / (sigma_g**3 + 1e-9)
        g2_x = -sigma2_g / (mu_g**2 + 1e-9)
        g2_y = 2.0 * sigma_g / (mu_g + 1e-9)
        delta_k = min(abs(g1_x) * eps_mu + abs(g1_y) * eps_sigma, 0.5 * k_d)
        delta_theta = min(abs(g2_x) * eps_mu + abs(g2_y) * eps_sigma, 0.5 * theta_d)
        delta_a = min(R_sq * math.sqrt(2 * log_term / n_p), 0.5 * a_d + 1e-12)

        pm_corners = []
        for s_k in (-1, 1):
            for s_theta in (-1, 1):
                for s_a in (-1, 1):
                    k_c = k_d + s_k * delta_k
                    theta_c = theta_d + s_theta * delta_theta
                    a_c = max(a_d + s_a * delta_a, 1e-8)
                    pm_corners.append(gammaincc(k_c, _safe_x(a_c, theta_c)))
        prob_ci[s] = max(abs(pm - pm_center) for pm in pm_corners)
    return bias_ci, prob_ci
def ps_ci_components_full(coords_d, coords_rest, eigvals, labels, *, delta=0.05):
    """
    Computes the error radius and tail bounds for the PS measure.

    :param coords_d: Retained diffusion maps coordinates.
    :param coords_rest: Complement diffusion maps coordinates.
    :param eigvals: Eigenvalues of the diffusion maps (accepted but not read here).
    :param labels: Assign source index per coordinate, as a list of "<source>-<kind>" strings.
    :param delta: 1 - delta is the confidence score.
    :return: error radius and tail bounds for the PS measure, as two dicts keyed by
        source id. Both are 0.0 for every source when there are no complement
        coordinates, and for any source without a competing source at finite distance.
    :raises ValueError: if a source has no "<source>-out" label.
    """

    def _mean_dev(lam_max, delta, n_eff):
        return math.sqrt(2 * lam_max * math.log(2 / delta) / n_eff)

    def _rel_cov_dev(lam_max, trace, delta, n_eff, C=1.0):
        r = trace / lam_max
        abs_dev = (
            C * lam_max * (math.sqrt(r / n_eff) + (r + math.log(2 / delta)) / n_eff)
        )
        return abs_dev / lam_max

    def _maha_eps_m(a_hat, lam_min, lam_max, mean_dev, rel_cov_dev):
        term1 = 2 * math.sqrt(a_hat) * mean_dev * math.sqrt(lam_max / lam_min)
        term2 = a_hat * rel_cov_dev
        return term1 + term2

    def _residual_norm(r, S):
        # sqrt(r^T S^{-1} r). Round-off can push the quadratic form slightly
        # below zero, which would make math.sqrt raise; clamp it at zero.
        return math.sqrt(max(float(r @ torch.linalg.solve(S, r)), 0.0))

    # Group by the exact source id. Prefix matching would let source "1"
    # absorb the points of "10", "11", ...
    src_of = [l.split("-")[0] for l in labels]
    spk_ids = sorted(set(src_of))

    m = coords_rest.shape[1]
    if m == 0:
        z = {s: 0.0 for s in spk_ids}
        return z.copy(), z.copy()

    dev = "cuda:0" if torch.cuda.is_available() else "cpu"
    X_d = torch.tensor(coords_d, device=dev)
    X_c = torch.tensor(coords_rest, device=dev)
    ridge_eye = torch.eye(X_c.shape[1], device=X_c.device) * 1e-9

    bias = {}
    prob = {}
    for s in spk_ids:
        idxs = [i for i, src in enumerate(src_of) if src == s]
        out_i = labels.index(f"{s}-out")
        ref_is = [i for i in idxs if i != out_i]

        mu_d = X_d[ref_is].mean(0)
        mu_c = X_c[ref_is].mean(0)
        Sigma_d = safe_cov_torch(X_d[ref_is])
        Sigma_c = safe_cov_torch(X_c[ref_is])
        C_dc = (X_d[ref_is] - mu_d).T @ (X_c[ref_is] - mu_c) / (len(ref_is) - 1)
        inv_Sd = torch.linalg.inv(Sigma_d)
        # One eigenvalue solve serves both extremes.
        evals = torch.linalg.eigvalsh(Sigma_d)
        lam_min = evals.min().clamp_min(1e-9).item()
        lam_max = evals.max().item()
        trace = torch.trace(Sigma_d).item()

        diff_d = X_d[out_i] - mu_d
        diff_c = X_c[out_i] - mu_c
        A_d = float(mahalanobis_torch(X_d[out_i], mu_d, inv_Sd))
        proj = C_dc.T @ inv_Sd
        r_i = diff_c - proj @ diff_d
        S_i = Sigma_c - proj @ C_dc + ridge_eye
        term_i = _residual_norm(r_i, S_i)

        # Closest competing source in the retained coordinates.
        B_d, term_j = float("inf"), 0.0
        Sig_o = None
        for o in spk_ids:
            if o == s:
                continue
            o_idxs = [
                i
                for i, l in enumerate(labels)
                if src_of[i] == o and not l.endswith("-out")
            ]
            muo_d = X_d[o_idxs].mean(0)
            muo_c = X_c[o_idxs].mean(0)
            Sig_o_tmp = safe_cov_torch(X_d[o_idxs])
            inv_So = torch.linalg.inv(Sig_o_tmp)
            this_B = float(mahalanobis_torch(X_d[out_i], muo_d, inv_So))
            if this_B < B_d:
                B_d = this_B
                Sig_o = Sig_o_tmp
                diff_do = X_d[out_i] - muo_d
                diff_co = X_c[out_i] - muo_c
                C_oc = (
                    (X_d[o_idxs] - muo_d).T @ (X_c[o_idxs] - muo_c) / (len(o_idxs) - 1)
                )
                proj_o = C_oc.T @ inv_So
                r_j = diff_co - proj_o @ diff_do
                S_j = safe_cov_torch(X_c[o_idxs]) - proj_o @ C_oc + ridge_eye
                term_j = _residual_norm(r_j, S_j)

        if Sig_o is None:
            # No competitor at finite distance: B_d is still inf, and the bias
            # expression would evaluate inf / inf = NaN. Its limit for
            # B_d -> inf is 0, matching the 0.0 already used for the tail bound.
            bias[s] = 0.0
            prob[s] = 0.0
            continue

        denom = A_d + B_d
        bias[s] = (B_d * term_i + A_d * term_j) / (denom**2)

        evals_o = torch.linalg.eigvalsh(Sig_o)
        lam_min_o = evals_o.min().clamp_min(1e-9).item()
        lam_max_o = evals_o.max().item()
        trace_o = torch.trace(Sig_o).item()
        n_eff = max(int(0.7 * len(ref_is)), 3)
        # Floor the smallest eigenvalue at a fraction of the largest one.
        RIDGE = 0.05
        lam_min_eff = max(lam_min, RIDGE * lam_max)
        lam_min_o_eff = max(lam_min_o, RIDGE * lam_max_o)
        eps_i_sg = _maha_eps_m(
            A_d,
            lam_min_eff,
            lam_max,
            _mean_dev(lam_max, delta / 2, n_eff),
            _rel_cov_dev(lam_max, trace, delta / 2, n_eff),
        )
        eps_j_sg = _maha_eps_m(
            B_d,
            lam_min_o_eff,
            lam_max_o,
            _mean_dev(lam_max_o, delta / 2, n_eff),
            _rel_cov_dev(lam_max_o, trace_o, delta / 2, n_eff),
        )
        grad_l2 = math.hypot(A_d, B_d) / (A_d + B_d) ** 2
        ps_radius = grad_l2 * math.hypot(eps_i_sg, eps_j_sg)
        prob[s] = min(1.0, ps_radius)
    return bias, prob