MAPSS-measures / utils.py
AIvry's picture
Upload 11 files
b759ccc verified
import gc
import threading
import warnings
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
# Silence Python warnings whose text starts with "Some weights of Wav2Vec2Model".
# NOTE(review): this only affects messages issued through the `warnings` module;
# if the message is emitted through `logging` instead, this filter has no
# effect — confirm against the library that produces it.
warnings.filterwarnings("ignore", message="Some weights of Wav2Vec2Model")
def get_gpu_count(max_gpus=None):
    """
    Get the number of GPUs that may be used.

    :param max_gpus: optional upper bound on the number of GPUs to utilize;
        ``None`` means every visible CUDA device. Negative values are
        treated as 0.
    :return: number of usable GPUs, never negative.
    """
    ngpu = torch.cuda.device_count()
    if max_gpus is not None:
        # Clamp at zero: a negative limit would otherwise be returned as a
        # negative device count, which callers testing ``ngpu == 0`` to pick
        # the CPU path would misinterpret as "GPUs available".
        ngpu = max(0, min(ngpu, max_gpus))
    return ngpu
def clear_gpu_memory():
    """
    Release cached CUDA memory on every visible device, then garbage-collect.

    Each device gets its cache emptied and is synchronized; afterwards the
    Python garbage collector runs and the CUDA cache is emptied once more so
    that memory freed by the collection is also handed back.
    """
    # Only enumerate devices when CUDA is usable at all.
    device_total = torch.cuda.device_count() if torch.cuda.is_available() else 0
    for device_index in range(device_total):
        with torch.cuda.device(device_index):
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    gc.collect()
    torch.cuda.empty_cache()
def get_gpu_memory_info(verbose=False):
    """
    Print a one-line memory summary for every visible CUDA device.

    :param verbose: when falsy, nothing is queried or printed.
    :return: None.
    """
    if not verbose:
        return

    bytes_per_gb = 1024**3
    for i in range(torch.cuda.device_count()):
        try:
            free_bytes, total_bytes = torch.cuda.mem_get_info(i)
        except Exception:
            # Fallback when the direct query fails: estimate free memory as
            # the device total minus what the caching allocator has reserved.
            total_bytes = torch.cuda.get_device_properties(i).total_memory
            free_bytes = total_bytes - torch.cuda.memory_reserved(i)

        free_gb = free_bytes / bytes_per_gb
        total_gb = total_bytes / bytes_per_gb
        mem_allocated = torch.cuda.memory_allocated(i) / bytes_per_gb
        print(f"GPU {i}: {mem_allocated:.2f}GB allocated, {free_gb:.2f}GB free / {total_gb:.2f}GB total")
class GPUWorkDistributor:
    """
    Run callables on the least-loaded GPU, one call per GPU at a time.

    The lock and load tables hold at most two entries, so only ``cuda:0`` and
    ``cuda:1`` are ever selected. When no GPU is available the callable is
    run directly in the calling thread.
    """

    def __init__(self, max_gpus=None):
        """
        :param max_gpus: optional upper bound on the number of GPUs to use.
        """
        # A negative count is meaningless; treat it as "no GPU".
        ngpu = max(0, get_gpu_count(max_gpus))
        slots = max(1, min(ngpu, 2))
        self.gpu_locks = [threading.Lock() for _ in range(slots)]
        self.gpu_load = [0] * slots
        # Guards reads and writes of ``gpu_load`` so that choosing a GPU and
        # reserving a slot on it happen as one atomic step.
        self._load_lock = threading.Lock()
        self.ngpu = ngpu

    def get_least_loaded_gpu(self):
        """
        :return: index of the GPU with the smallest load counter (lowest index wins ties).
        """
        return int(np.argmin(self.gpu_load))

    def execute_on_gpu(self, func, *args, **kwargs):
        """
        Call ``func`` on the least-loaded GPU, or on the CPU when none exists.

        :param func: callable to run. On the GPU path it receives a
            ``device="cuda:<index>"`` keyword argument (replacing any value
            the caller supplied); on the CPU path any ``device`` keyword is
            removed before the call.
        :param args: positional arguments forwarded to ``func``.
        :param kwargs: keyword arguments forwarded to ``func``.
        :return: whatever ``func`` returns.
        """
        if self.ngpu <= 0:
            kwargs.pop("device", None)
            return func(*args, **kwargs)

        # Reserve the slot *before* waiting on the per-GPU lock. Counting
        # only after the lock is held would hide queued callers from the
        # load table and send every waiter to the same device.
        with self._load_lock:
            gid = self.get_least_loaded_gpu()
            self.gpu_load[gid] += 1
        try:
            with self.gpu_locks[gid]:
                with torch.cuda.device(gid):
                    kwargs["device"] = f"cuda:{gid}"
                    result = func(*args, **kwargs)
                    torch.cuda.empty_cache()
                    return result
        finally:
            with self._load_lock:
                self.gpu_load[gid] -= 1
@dataclass
class Mixture:
    """
    Canonical description of one mixture, as built by ``canonicalize_mixtures``.
    """

    # Identifier of the mixture.
    mixture_id: str
    # Paths of the mixture's reference files, in the same order as ``speaker_ids``.
    refs: list[Path]
    # Maps a system (algorithm) name to the list of that system's output paths.
    systems: dict[str, list[Path]]
    # One ID per entry of ``refs``; ``canonicalize_mixtures`` formats them as
    # "<mixture_id>__<name>".
    speaker_ids: list[str]
def canonicalize_mixtures(mixtures, systems=None):
    """
    Normalize a mixture specification into a list of ``Mixture`` objects.

    The layout is chosen from the type of the first element:

    * dict layout: every element is a dict with ``mixture_id`` (or ``id``),
      a non-empty ``references`` sequence and, optionally, a ``systems``
      dict mapping a system name to its output paths. Speaker IDs are
      ``"<mixture_id>__<reference file stem>"``, falling back to
      ``spkNN`` when the stem is blank.
    * legacy layout: every element is a list of dicts with ``id`` and
      ``ref``. Mixtures are named ``mix_000``, ``mix_001``, ... and system
      outputs are taken from ``systems[name][mixture_id]`` when present.

    :param mixtures: mixture specification in one of the layouts above.
    :param systems: used by the legacy layout only; maps a system name to a
        container of outputs keyed by mixture ID.
    :return: list of ``Mixture``; empty when ``mixtures`` is empty or ``None``.
    :raises ValueError: for a missing mixture ID, empty references, a
        malformed legacy entry, or an unrecognized layout.
    """
    if not mixtures:
        return []

    def to_path_list(items):
        # Leave existing Path objects alone; coerce everything else via str().
        return [Path(str(item)) if not isinstance(item, Path) else item for item in items]

    def derive_speaker_id(ref_path, idx, mixture_id):
        stem = (ref_path.stem or "").strip() or f"spk{idx:02d}"
        return f"{mixture_id}__{stem}"

    head = mixtures[0]
    result = []

    if isinstance(head, dict):
        for entry in mixtures:
            mid = str(entry.get("mixture_id") or entry.get("id") or "").strip()
            if mid == "":
                raise ValueError("Each mixture must include 'mixture_id'.")

            references = to_path_list(entry.get("references", []))
            if len(references) == 0:
                raise ValueError(f"Mixture {mid}: 'references' must be non-empty.")

            declared = entry.get("systems")
            if isinstance(declared, dict):
                outputs = {str(name): to_path_list(paths) for name, paths in declared.items()}
            else:
                outputs = {}

            speakers = [
                derive_speaker_id(ref, position, mid)
                for position, ref in enumerate(references)
            ]
            result.append(Mixture(mid, references, outputs, speakers))
        return result

    if isinstance(head, list):
        for i, group in enumerate(mixtures):
            mid = f"mix_{i:03d}"
            references = []
            speakers = []
            # Validate and convert entry by entry so the first offending
            # entry determines which error is raised.
            for d in group:
                well_formed = isinstance(d, dict) and "ref" in d and "id" in d
                if not well_formed:
                    raise ValueError("Legacy mixtures expect dicts with 'id' and 'ref'.")
                references.append(Path(d["ref"]))
                speakers.append(f"{mid}__{str(d['id']).strip()}")

            outputs = {}
            if systems:
                outputs = {
                    name: to_path_list(per_mixture[mid])
                    for name, per_mixture in systems.items()
                    if mid in per_mixture
                }
            result.append(Mixture(mid, references, outputs, speakers))
        return result

    raise ValueError("Unsupported 'mixtures' format.")
def safe_cov_torch(X):
    """
    Compute a regularized sample covariance matrix of X.

    :param X: 2-D tensor whose rows are observations and whose columns are variables.
    :return: covariance matrix of shape (columns, columns). When the estimate
        is rank-deficient, 1e-6 is added to its diagonal.
    """
    Xc = X - X.mean(dim=0, keepdim=True)
    # Unbiased (n - 1) normalization. The divisor is clamped to 1 so that a
    # single observation yields a zero matrix (regularized below) instead of
    # 0 / 0 = NaN, which would also break the rank computation.
    denom = max(Xc.shape[0] - 1, 1)
    cov = Xc.T @ Xc / denom
    if torch.linalg.matrix_rank(cov) < cov.shape[0]:
        # Build the ridge in cov's own dtype so the added value is not first
        # rounded to float32 when cov is double precision.
        cov += torch.eye(cov.shape[0], device=cov.device, dtype=cov.dtype) * 1e-6
    return cov
def mahalanobis_torch(x, mu, inv):
    """
    Compute the Mahalanobis distance of x from mu under inverse covariance inv.

    :param x: point (1-D) or batch of points (2-D or more) to measure.
    :param mu: center the distance is measured from.
    :param inv: the inverse covariance matrix.
    :return: ``sqrt(d @ inv @ d^T + 1e-6)`` with ``d = x - mu``. For a 1-D
        ``x`` this is a scalar tensor; for a batched ``x`` it is the full
        square matrix of pairwise terms, whose diagonal holds the per-row
        distances (negative off-diagonal terms become NaN under the root).
    """
    delta = x - mu
    if delta.ndim >= 2:
        delta_t = delta.transpose(-1, -2)
    else:
        # A 1-D vector needs no transpose; matmul treats it as row or column.
        delta_t = delta
    quadratic = torch.matmul(torch.matmul(delta, inv), delta_t)
    # The small epsilon keeps the root away from exactly zero.
    return torch.sqrt(quadratic + 1e-6)