# MAPSS-measures / models.py
import queue
import threading
import gc
import torch
import torch.nn.functional as F
from transformers import (
HubertModel,
Wav2Vec2FeatureExtractor,
Wav2Vec2Model,
WavLMModel,
ASTModel,
AutoFeatureExtractor,
)
from config import BATCH_SIZE, ENERGY_HOP_MS, ENERGY_WIN_MS, SR
from utils import get_gpu_count
class BalancedDualGPUModel:
    """Replicates one pretrained speech model on up to two GPUs and shards
    each batch across them via per-GPU worker threads."""
def __init__(self, model_name, layer, max_gpus=None):
self.layer = layer
self.models = []
self.extractors = []
self.devices = []
ngpu = get_gpu_count(max_gpus)
for gpu_id in range(min(ngpu, 2)):
device = f"cuda:{gpu_id}"
self.devices.append(device)
ckpt, cls, _ = get_model_config(layer)[model_name]
if cls is ASTModel:
extractor = AutoFeatureExtractor.from_pretrained(ckpt)
else:
extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)
attn_impl = "eager" if cls in (WavLMModel, ASTModel) else "sdpa"
model = cls.from_pretrained(
ckpt,
output_hidden_states=True,
use_safetensors=True,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation=attn_impl
)
model.eval()
model = model.to(device)
for param in model.parameters():
param.requires_grad = False
self.extractors.append(extractor)
self.models.append(model)
self.gpu_queues = [queue.Queue() for _ in range(len(self.devices))]
self.result_queue = queue.Queue()
self.workers = []
for i in range(len(self.devices)):
worker = threading.Thread(target=self._gpu_worker, args=(i,))
worker.daemon = True
worker.start()
self.workers.append(worker)
    def _gpu_worker(self, gpu_id):
        """Worker loop: pull (signals, masks, use_mlm, task_id) tasks from this
        GPU's queue, embed them, and push (task_id, result-or-exception)."""
device = self.devices[gpu_id]
model = self.models[gpu_id]
extractor = self.extractors[gpu_id]
while True:
task = self.gpu_queues[gpu_id].get()
if task is None:
break
signals, masks, use_mlm, task_id = task
try:
inputs = extractor(
signals, sampling_rate=SR, return_tensors="pt", padding=True
)
input_values = inputs.input_values.to(device, non_blocking=True)
torch.cuda.empty_cache()
                orig_mode = model.training
                model.train(use_mlm)  # dropout active only when use_mlm is set
with torch.no_grad():
with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
hs = model(
input_values, output_hidden_states=True
).hidden_states[self.layer]
model.train(orig_mode)
B, T, D = hs.shape
keep = []
for b in range(B):
                    # as_tensor accepts either torch tensors or numpy masks
                    mask_b = torch.as_tensor(masks[b], dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
keep.append(hs[b][mask_t].cpu())
# Aggressive cleanup
del hs, input_values, inputs
torch.cuda.empty_cache()
if keep:
L_max = max(x.shape[0] for x in keep)
keep_padded = [
F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in keep
]
result = torch.stack(keep_padded, dim=0)
else:
result = torch.empty(0, 0, 0)
self.result_queue.put((task_id, result))
except Exception as e:
self.result_queue.put((task_id, e))
finally:
# Always clear cache after processing
torch.cuda.empty_cache()
    def process_batch(self, signals, masks, use_mlm=False):
        """Scatter a batch across the GPU queues and gather padded results."""
if not signals:
return torch.empty(0, 0, 0)
batch_size = len(signals)
        split = (batch_size + len(self.devices) - 1) // len(self.devices)  # ceil split per GPU
results = {}
task_id = 0
for i in range(0, batch_size, split):
end = min(i + split, batch_size)
gpu_id = (i // split) % len(self.devices)
self.gpu_queues[gpu_id].put(
(signals[i:end], masks[i:end], use_mlm, task_id)
)
task_id += 1
for _ in range(task_id):
tid, result = self.result_queue.get()
if isinstance(result, Exception):
raise result
results[tid] = result
parts = [results[i] for i in range(task_id) if results[i].numel() > 0]
return torch.cat(parts, dim=0) if parts else torch.empty(0, 0, 0)
def cleanup(self):
"""Explicit cleanup method"""
for q in self.gpu_queues:
q.put(None)
for w in self.workers:
w.join(timeout=5.0)
        # Looping "del model" over the lists would only unbind the loop
        # variable; clearing the lists is what actually drops the references.
        self.models.clear()
        self.extractors.clear()
torch.cuda.empty_cache()
gc.collect()
    def __del__(self):
        try:
            self.cleanup()
        except Exception:
            pass  # best effort: interpreter may already be shutting down
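
# A minimal usage sketch for the dual-GPU wrapper (assumes at least two visible
# GPUs; "wavlm" and layer 6 are illustrative choices, and `signals`/`masks` are
# assumed to be lists of 1-D waveforms and per-frame masks):
#
#     wrapper = BalancedDualGPUModel("wavlm", layer=6)
#     emb = wrapper.process_batch(signals, masks)  # (B, L_max, D), zero-padded
#     wrapper.cleanup()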
# NO CACHE - we need to clean up models properly between runs
def get_model_config(layer):
    """Map model names to (checkpoint, model class, layer index to extract)."""
return {
"raw": (None, None, None),
"wavlm": ("microsoft/wavlm-large", WavLMModel, layer),
"wav2vec2": ("facebook/wav2vec2-large-lv60", Wav2Vec2Model, layer),
"hubert": ("facebook/hubert-large-ll60k", HubertModel, layer),
"wavlm_base": ("microsoft/wavlm-base", WavLMModel, layer),
"wav2vec2_base": ("facebook/wav2vec2-base", Wav2Vec2Model, layer),
"hubert_base": ("facebook/hubert-base-ls960", HubertModel, layer),
"wav2vec2_xlsr": ("facebook/wav2vec2-large-xlsr-53", Wav2Vec2Model, layer),
"ast": ("MIT/ast-finetuned-audioset-10-10-0.4593", ASTModel, layer),
}
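
# A quick sketch of how a config entry is consumed (the layer index 9 here is
# purely illustrative):
#
#     ckpt, cls, layer_eff = get_model_config(9)["wavlm"]
#     # -> ("microsoft/wavlm-large", WavLMModel, 9)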
# Store loaded models globally to properly manage them
_loaded_models = {}
def load_model(name, layer, max_gpus=None):
    """Load a model by name, tearing down any previously loaded models first."""
    # The teardown logic is identical to cleanup_all_models() below, so
    # delegate to it rather than duplicating the loop here.
    cleanup_all_models()
if name.lower() in {"raw", "waveform"}:
return "raw", layer
ngpu = get_gpu_count(max_gpus)
if ngpu > 1:
model = BalancedDualGPUModel(name, layer, max_gpus)
_loaded_models[name] = (model, layer)
return model, layer
else:
ckpt, cls, layer_eff = get_model_config(layer)[name]
if cls is ASTModel:
extractor = AutoFeatureExtractor.from_pretrained(ckpt)
else:
extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
attn_impl = "eager" if cls in (WavLMModel, ASTModel) else "sdpa"
model = cls.from_pretrained(
ckpt,
output_hidden_states=True,
use_safetensors=True,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation=attn_impl
)
model.eval()
model = model.to(device)
for param in model.parameters():
param.requires_grad = False
model_tuple = ((extractor, model), layer_eff)
_loaded_models[name] = model_tuple
return (extractor, model), layer_eff
def cleanup_all_models():
    """Call this at the end of each experiment to ensure complete cleanup"""
    for model_data in _loaded_models.values():
        if isinstance(model_data, tuple) and len(model_data) == 2:
            if isinstance(model_data[0], BalancedDualGPUModel):
                model_data[0].cleanup()
            # Single-GPU (extractor, model) tuples need no explicit shutdown;
            # clearing the registry below drops the last references.
    _loaded_models.clear()
    torch.cuda.empty_cache()
    gc.collect()
def embed_batch_raw(signals, masks_audio):
    """Frame each waveform with the energy window/hop, keep masked frames, and
    pad to a common length; returns a (B, L_max, win) float tensor."""
win = int(ENERGY_WIN_MS * SR / 1000)
hop = int(ENERGY_HOP_MS * SR / 1000)
reps, L_max = [], 0
for sig_np, mask_np in zip(signals, masks_audio):
x = torch.as_tensor(sig_np[:-1], dtype=torch.float32)
frames = x.unfold(0, win, hop)
mask = torch.as_tensor(mask_np[: len(frames)], dtype=torch.bool)
keep = frames[mask] if mask.any() else frames[:1]
reps.append(keep)
L_max = max(L_max, keep.size(0))
reps = [F.pad(r, (0, 0, 0, L_max - r.size(0))) for r in reps]
return torch.stack(reps, dim=0)
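
# Frame-count sanity check for the raw path, a sketch assuming illustrative
# config values (ENERGY_WIN_MS=25, ENERGY_HOP_MS=10, SR=16000 would give
# win=400, hop=160):
#
#     n_frames = (len(sig) - 1 - win) // hop + 1  # matches x.unfold(0, win, hop)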
def embed_batch_single_gpu(
    signals, masks_audio, extractor, model, layer, use_mlm=False
):
    """Single-GPU path: embed in small micro-batches, keep masked frames, and
    pad the per-utterance hidden states to a common length."""
if not signals:
return torch.empty(0, 0, 0)
device = next(model.parameters()).device
    max_batch = 2  # small micro-batches to bound peak GPU memory
all_keeps = []
for i in range(0, len(signals), max_batch):
batch_signals = signals[i:i + max_batch]
batch_masks = masks_audio[i:i + max_batch]
inputs = extractor(batch_signals, sampling_rate=SR, return_tensors="pt", padding=True)
input_values = inputs.input_values.to(device, non_blocking=True)
        orig_mode = model.training
        model.train(use_mlm)  # dropout active only when use_mlm is set
with torch.no_grad():
with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
hs = model(input_values, output_hidden_states=True).hidden_states[layer]
model.train(orig_mode)
B, T, D = hs.shape
for b in range(B):
            # as_tensor accepts either torch tensors or numpy masks
            mask_b = torch.as_tensor(batch_masks[b], dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
all_keeps.append(hs[b][mask_t].cpu())
# Aggressive cleanup
del hs, input_values, inputs
torch.cuda.empty_cache()
if all_keeps:
L_max = max(x.shape[0] for x in all_keeps)
keep_padded = [F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in all_keeps]
result = torch.stack(keep_padded, dim=0)
# Clean up intermediate lists
del all_keeps, keep_padded
return result
else:
return torch.empty(0, 0, 0)
def embed_batch(signals, masks_audio, model_wrapper, layer, use_mlm=False):
    """Dispatch to the raw-waveform, dual-GPU, or single-GPU embedding path."""
if model_wrapper == "raw":
return embed_batch_raw(signals, masks_audio)
if isinstance(model_wrapper, BalancedDualGPUModel):
all_embeddings = []
        batch_size = min(BATCH_SIZE, 2)  # cap the per-call batch to bound GPU memory
for i in range(0, len(signals), batch_size):
batch_emb = model_wrapper.process_batch(
signals[i: i + batch_size], masks_audio[i: i + batch_size], use_mlm
)
if batch_emb.numel() > 0:
all_embeddings.append(batch_emb)
# Clear cache after each batch
torch.cuda.empty_cache()
if all_embeddings:
result = torch.cat(all_embeddings, dim=0)
del all_embeddings
return result
else:
return torch.empty(0, 0, 0)
else:
extractor, model = model_wrapper
return embed_batch_single_gpu(
signals, masks_audio, extractor, model, layer, use_mlm=use_mlm
)
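
if __name__ == "__main__":
    # Smoke test for the "raw" (waveform-framing) path; it needs no GPU and no
    # model download. Signal lengths and all-ones masks here are illustrative.
    import numpy as np

    wrapper, layer = load_model("raw", layer=0)
    sigs = [np.random.randn(SR).astype(np.float32) for _ in range(2)]
    # Over-long masks are fine: embed_batch_raw truncates them to the frame count.
    masks = [np.ones(SR, dtype=bool) for _ in range(2)]
    emb = embed_batch(sigs, masks, wrapper, layer)
    print(emb.shape)  # -> (2, n_frames, win)
    cleanup_all_models()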