# data_loading/feature_extractor.py
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from transformers import (
    AutoFeatureExtractor,
    AutoModel,
    AutoModelForAudioClassification,
    AutoTokenizer,
    Wav2Vec2Processor,
)

from data_loading.pretrained_extractors import EmotionModel, Mamba, get_model_mamba
# Module-wide default device: CUDA when torch reports it is available, else CPU.
# To force CPU execution, replace the expression below with torch.device("cpu").
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class PretrainedAudioEmbeddingExtractor:
    """
    Extract audio class logits and hidden features with a two-stage pipeline.

    Stage 1: ``Wav2Vec2Processor`` + ``EmotionModel`` (both loaded from
    ``config.audio_model_name``) turn the waveform into frame-level features.
    Stage 2: a Mamba classifier restored from ``config.audio_classifier_checkpoint``
    maps those features to logits and hidden features. The hidden features are
    optionally pooled over the time axis and L2-normalised.
    """

    def __init__(self, config):
        """
        Expected ``config`` fields:
          - audio_model_name (str): model id for the processor and EmotionModel.
          - emb_device (str): device for all models, e.g. "cpu" or "cuda".
          - audio_pooling (str | None): "mean", "cls", "max", "min", "last", "sum",
            or None to skip pooling. Any other value falls back to "mean".
          - emb_normalize (bool): L2-normalise the output when it is 2-D.
          - max_audio_frames (int, optional, default 0): stored on the instance,
            but not used by this class.
          - audio_classifier_checkpoint (str): checkpoint file; see
            ``load_classifier_model_from_checkpoint`` for the accepted names.
        """
        self.config = config
        self.device = config.emb_device
        self.model_name = config.audio_model_name
        self.pooling = config.audio_pooling  # may be None
        self.normalize_output = config.emb_normalize
        # NOTE(review): kept for parity with AudioEmbeddingExtractor; never applied here.
        self.max_audio_frames = getattr(config, "max_audio_frames", 0)
        self.audio_classifier_checkpoint = config.audio_classifier_checkpoint
        # Stage 1: processor + frame-level embedding model.
        self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
        self.audio_embedder = EmotionModel.from_pretrained(self.model_name).to(self.device)
        # Stage 2: classifier producing logits and hidden features.
        self.classifier_model = self.load_classifier_model_from_checkpoint(
            self.audio_classifier_checkpoint
        )

    def extract(self, waveform: torch.Tensor, sample_rate=16000):
        """
        Return classifier logits and (optionally pooled) hidden features.

        :param waveform: raw audio forwarded unchanged to ``Wav2Vec2Processor``;
            presumably a single 1-D signal (T,) — TODO confirm against callers.
        :param sample_rate: sampling rate forwarded to the processor.
        :return: ``(logits, emb)``. ``emb`` is (B, seq_len, hidden_dim) when
            pooling is None, otherwise (B, hidden_dim).
        """
        embeddings = self.process_audio(waveform, sample_rate)
        tensor_emb = torch.tensor(embeddings, dtype=torch.float32).to(self.device)
        # The classifier takes per-item lengths; a single unpadded sequence is assumed.
        lengths = [tensor_emb.shape[1]]
        with torch.no_grad():
            logits, hidden = self.classifier_model(tensor_emb, lengths, with_features=True)
        return logits, self._pool_and_normalize(hidden)

    def _pool_and_normalize(self, hidden):
        """Pool a 3-D (B, seq_len, dim) tensor per ``self.pooling``, then optionally L2-normalise."""
        if hidden.dim() == 3 and self.pooling is not None:
            if self.pooling == "cls":
                emb = hidden[:, 0, :]
            elif self.pooling == "last":
                emb = hidden[:, -1, :]
            elif self.pooling == "max":
                emb = hidden.max(dim=1).values
            elif self.pooling == "min":
                emb = hidden.min(dim=1).values
            elif self.pooling == "sum":
                emb = hidden.sum(dim=1)
            else:
                # "mean", and the fallback for any unrecognised pooling name.
                emb = hidden.mean(dim=1)
        else:
            # Pooling disabled, or the tensor is not 3-D: pass it through unchanged.
            emb = hidden
        if self.normalize_output and emb.dim() == 2:
            emb = F.normalize(emb, p=2, dim=1)
        return emb

    def process_audio(self, signal: np.ndarray, sampling_rate: int) -> np.ndarray:
        """Run the processor and ``EmotionModel`` on ``signal``; return the model output as a CPU NumPy array."""
        inputs = self.processor(signal, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
        input_values = inputs["input_values"].to(self.device)
        with torch.no_grad():
            outputs = self.audio_embedder(input_values)
        return outputs.detach().cpu().numpy()

    def load_classifier_model_from_checkpoint(self, checkpoint_path):
        """
        Build the Mamba classifier and restore its weights.

        The checkpoint layout is chosen by file name; a directory prefix is allowed.
          - ``best_audio_model.pt``: dict with ``exp_params`` (architecture) and
            ``model_state_dict``.
          - ``best_audio_model_2.pt``: bare state dict; the architecture is hard-coded below.

        :param checkpoint_path: path to one of the checkpoint files above.
        :return: the classifier on ``self.device``, in eval mode.
        :raises ValueError: if the file name is not one of the recognised layouts.
        """
        checkpoint_name = os.path.basename(checkpoint_path)
        if checkpoint_name == "best_audio_model.pt":
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            exp_params = checkpoint['exp_params']
            classifier_model = get_model_mamba(exp_params).to(self.device)
            classifier_model.load_state_dict(checkpoint['model_state_dict'])
        elif checkpoint_name == "best_audio_model_2.pt":
            model_params = {
                "input_size": 1024,
                "d_model": 256,
                "num_layers": 2,
                "num_classes": 7,
                "dropout": 0.2
            }
            classifier_model = get_model_mamba(model_params).to(self.device)
            classifier_model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
        else:
            # Without this branch `classifier_model` would be unbound below.
            raise ValueError(
                f"Unsupported audio classifier checkpoint: {checkpoint_path!r} "
                "(expected 'best_audio_model.pt' or 'best_audio_model_2.pt')"
            )
        classifier_model.eval()
        return classifier_model
class AudioEmbeddingExtractor:
    """
    Extract audio embeddings with a Hugging Face model (e.g. 'amiriparian/ExHuBERT').

    Waveforms are truncated to ``sample_rate * wav_length`` samples, run through
    the model without gradients, then optionally pooled over the time axis and
    L2-normalised.
    """

    def __init__(self, config):
        """
        Expected ``config`` fields:
          - audio_model_name (str): Hugging Face model id.
          - emb_device (str): "cpu" or "cuda".
          - audio_pooling (str | None): "mean", "cls", "max", "min", "last", "sum",
            or None to skip pooling. Any other value falls back to "mean".
          - emb_normalize (bool): L2-normalise the output when it is 2-D.
          - sample_rate, wav_length: their product is the maximum number of samples
            kept per waveform (0 disables truncation).
        """
        self.config = config
        self.device = config.emb_device
        self.model_name = config.audio_model_name
        self.pooling = config.audio_pooling  # may be None
        self.normalize_output = config.emb_normalize
        # int(): wav_length may be fractional, and tensor slicing requires an int.
        self.max_audio_frames = int(config.sample_rate * config.wav_length)

        # Best effort: not every checkpoint ships a feature-extractor config.
        try:
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_name)
            logging.info(f"[Audio] Using AutoFeatureExtractor for '{self.model_name}'")
        except Exception as e:
            self.feature_extractor = None
            logging.warning(f"[Audio] No built-in FeatureExtractor found. Model={self.model_name}. Error: {e}")

        # Prefer a bare encoder; fall back to an audio-classification model if that fails.
        try:
            self.model = AutoModel.from_pretrained(
                self.model_name,
                output_hidden_states=True,
            ).to(self.device)
            logging.info(f"[Audio] Loaded AutoModel with output_hidden_states=True: {self.model_name}")
        except Exception as e:
            logging.warning(f"[Audio] Fallback to AudioClassification model. Reason: {e}")
            self.model = AutoModelForAudioClassification.from_pretrained(
                self.model_name,
                output_hidden_states=True,
            ).to(self.device)

    def extract(self, waveform_batch: torch.Tensor, sample_rate=16000):
        """
        Extract embeddings from a batch of waveforms.

        :param waveform_batch: tensor of shape (T,), (B, T) or (B, 1, T).
        :param sample_rate: sampling rate forwarded to the feature extractor.
        :return: (B, dim) when pooling is set or the model output is already 2-D;
            (B, seq_len, dim) when pooling is None and the output is 3-D.
        :raises ValueError: if the model output is neither 2-D nor 3-D.
        """
        waveform_batch = self._to_batch(waveform_batch)
        if self.max_audio_frames > 0 and waveform_batch.shape[1] > self.max_audio_frames:
            waveform_batch = waveform_batch[:, :self.max_audio_frames]
        # Pure feature extraction: do not build an autograd graph (matches the other extractors).
        with torch.no_grad():
            outputs = self._forward(waveform_batch, sample_rate)
        hidden = self._hidden_from_outputs(outputs)
        return self._pool_and_normalize(hidden)

    @staticmethod
    def _to_batch(waveform_batch):
        """Normalise (B, 1, T) and (T,) inputs to (B, T); other shapes pass through."""
        if waveform_batch.dim() == 3 and waveform_batch.shape[1] == 1:
            return waveform_batch.squeeze(1)
        if waveform_batch.dim() == 1:
            return waveform_batch.unsqueeze(0)
        return waveform_batch

    def _forward(self, waveform_batch, sample_rate):
        """Run the model, going through the feature extractor when one was loaded."""
        if self.feature_extractor is not None:
            limit = self.max_audio_frames if self.max_audio_frames > 0 else None
            inputs = self.feature_extractor(
                # HF feature extractors take NumPy arrays / lists; a 2-D array is a batch of rows.
                waveform_batch.detach().cpu().numpy(),
                sampling_rate=sample_rate,
                return_tensors="pt",
                # HF raises ValueError for truncation=True without a max_length.
                truncation=limit is not None,
                max_length=limit,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            return self.model(input_values=inputs["input_values"])
        return self.model(input_values=waveform_batch.to(self.device))

    @staticmethod
    def _hidden_from_outputs(outputs):
        """Pick last_hidden_state, else logits as a (B, 1, num_labels) sequence, else the raw output."""
        if hasattr(outputs, "last_hidden_state"):
            return outputs.last_hidden_state
        if hasattr(outputs, "logits"):
            # Treat classification logits as a length-1 sequence so pooling still applies.
            hidden = outputs.logits.unsqueeze(1)
            logging.debug(f"[Audio] Found logits shape: {outputs.logits.shape} => hidden={hidden.shape}")
            return hidden
        return outputs

    def _pool_and_normalize(self, hidden):
        """Pool a 3-D tensor per ``self.pooling``, then optionally L2-normalise 2-D results."""
        if hidden.dim() == 2:
            emb = hidden
        elif hidden.dim() == 3:
            if self.pooling is None:
                emb = hidden
            elif self.pooling == "cls":
                emb = hidden[:, 0, :]
            elif self.pooling == "last":
                emb = hidden[:, -1, :]
            elif self.pooling == "max":
                emb = hidden.max(dim=1).values
            elif self.pooling == "min":
                emb = hidden.min(dim=1).values
            elif self.pooling == "sum":
                emb = hidden.sum(dim=1)
            else:
                # "mean", and the fallback for any unrecognised pooling name.
                emb = hidden.mean(dim=1)
        else:
            raise ValueError(f"[Audio] Unexpected hidden shape={hidden.shape}, pooling={self.pooling}")
        if self.normalize_output and emb.dim() == 2:
            emb = F.normalize(emb, p=2, dim=1)
        return emb
class TextEmbeddingExtractor:
    """
    Text embedding extractor backed by a Hugging Face ``AutoModel``
    (for example 'jinaai/jina-embeddings-v3').

    Input is tokenised (padded/truncated to ``max_tokens``) and encoded without
    gradients. The token axis is then optionally pooled and the result
    L2-normalised.
    """

    # Token-axis reductions keyed by the ``text_pooling`` config value.
    _POOLERS = {
        "mean": lambda h: h.mean(dim=1),
        "cls": lambda h: h[:, 0, :],
        "max": lambda h: h.max(dim=1).values,
        "min": lambda h: h.min(dim=1).values,
        "last": lambda h: h[:, -1, :],
        "sum": lambda h: h.sum(dim=1),
    }

    def __init__(self, config):
        """
        Config fields read:
          - text_model_name (str): Hugging Face model id.
          - emb_device (str): target device for the model and inputs.
          - text_pooling (str | None): a key of ``_POOLERS``, or None to keep every
            token. Unknown names fall back to "mean".
          - emb_normalize (bool): L2-normalise the output when it is 2-D.
          - max_tokens (int): tokenizer padding/truncation length.
        """
        self.config = config
        self.device = config.emb_device
        self.model_name = config.text_model_name
        self.pooling = config.text_pooling
        self.normalize_output = config.emb_normalize
        self.max_tokens = config.max_tokens

        # Remote code is trusted because some models (e.g. jina) ship custom modules.
        logging.info(f"[Text] Loading tokenizer for {self.model_name} with trust_remote_code=True")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)

        logging.info(f"[Text] Loading model for {self.model_name} with trust_remote_code=True")
        encoder = AutoModel.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            output_hidden_states=True,
            force_download=False,
        )
        self.model = encoder.to(self.device)

    def extract(self, text_list):
        """
        Encode one string or a list of strings.

        :param text_list: a single ``str`` or a list of strings.
        :return: (B, hidden_dim) when pooling is set, or (B, seq_len, hidden_dim)
            when pooling is None.
        """
        batch = [text_list] if isinstance(text_list, str) else text_list

        encoded = self.tokenizer(
            batch,
            padding="max_length",
            truncation=True,
            max_length=self.max_tokens,
            return_tensors="pt",
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            result = self.model(**encoded)
        features = result.last_hidden_state

        # NOTE(review): padding="max_length" and no attention mask here, so pad
        # positions take part in every pooling mode (including "last").
        if features.dim() == 3 and self.pooling is not None:
            reducer = self._POOLERS.get(self.pooling, self._POOLERS["mean"])
            features = reducer(features)

        if self.normalize_output and features.dim() == 2:
            features = F.normalize(features, p=2, dim=1)
        return features
class PretrainedTextEmbeddingExtractor:
    """
    Extract text class logits and hidden features with a Mamba classifier
    wrapping a text encoder (e.g. 'jinaai/jina-embeddings-v3').

    Weights are restored from ``config.text_classifier_checkpoint``. The hidden
    features are optionally pooled over the sequence axis and L2-normalised.
    """

    def __init__(self, config):
        """
        Config fields read:
          - text_model_name (str): encoder id passed to ``Mamba``.
          - emb_device (str): device for the model and its checkpoint.
          - text_pooling (str | None): "mean", "cls", "max", "min", "last", "sum",
            or None. Unknown names fall back to "mean".
          - emb_normalize (bool): L2-normalise the output when it is 2-D.
          - max_tokens (int): forwarded to ``Mamba``.
          - text_classifier_checkpoint (str): file containing ``model_state_dict``.
        """
        self.config = config
        self.device = config.emb_device
        self.model_name = config.text_model_name
        self.pooling = config.text_pooling  # may be None
        self.normalize_output = config.emb_normalize
        self.max_tokens = config.max_tokens
        self.text_classifier_checkpoint = config.text_classifier_checkpoint
        # pooling=None is passed to Mamba; pooling of the returned features is done in extract().
        self.model = Mamba(
            num_layers=2,
            d_input=1024,
            d_model=512,
            num_classes=7,
            model_name=self.model_name,
            max_tokens=self.max_tokens,
            pooling=None,
        ).to(self.device)
        # Map onto this extractor's device rather than the module-wide DEVICE, so that
        # emb_device="cpu" never stages the checkpoint on a GPU (same as the audio loader).
        checkpoint = torch.load(self.text_classifier_checkpoint, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

    def extract(self, text_list):
        """
        Classify and embed one string or a list of strings.

        :param text_list: a single ``str`` or a list of strings.
        :return: ``(logits, emb)``. ``emb`` is (B, hidden_dim) when pooling is set,
            or (B, seq_len, hidden_dim) when pooling is None.
        """
        if isinstance(text_list, str):
            text_list = [text_list]
        with torch.no_grad():
            logits, hidden = self.model(text_list, with_features=True)
        return logits, self._pool_and_normalize(hidden)

    def _pool_and_normalize(self, hidden):
        """Pool a 3-D (B, seq_len, dim) tensor per ``self.pooling``, then optionally L2-normalise."""
        if hidden.dim() == 3 and self.pooling is not None:
            if self.pooling == "cls":
                emb = hidden[:, 0, :]
            elif self.pooling == "last":
                emb = hidden[:, -1, :]
            elif self.pooling == "max":
                emb = hidden.max(dim=1).values
            elif self.pooling == "min":
                emb = hidden.min(dim=1).values
            elif self.pooling == "sum":
                emb = hidden.sum(dim=1)
            else:
                # "mean", and the fallback for any unrecognised pooling name.
                emb = hidden.mean(dim=1)
        else:
            # Pooling disabled, or the tensor is not 3-D: pass it through unchanged.
            emb = hidden
        if self.normalize_output and emb.dim() == 2:
            emb = F.normalize(emb, p=2, dim=1)
        return emb