Spaces:
Running
Running
| # coding: utf-8 | |
| # train_utils.py | |
| import os | |
| import torch | |
| import logging | |
| import random | |
| import datetime | |
| import numpy as np | |
| from tqdm import tqdm | |
| import csv | |
| from torch.utils.data import DataLoader, ConcatDataset | |
| from utils.losses import WeightedCrossEntropyLoss | |
| from utils.measures import uar, war, mf1, wf1 | |
| from models.models import BiFormer, BiGraphFormer, BiGatedGraphFormer | |
| from data_loading.dataset_multimodal import DatasetMultiModal | |
| from data_loading.feature_extractor import AudioEmbeddingExtractor, TextEmbeddingExtractor | |
| from sklearn.utils.class_weight import compute_class_weight | |
def custom_collate_fn(batch):
    """Collate a list of samples into one batch, skipping ``None`` (invalid) samples.

    Every remaining sample is indexed by the keys ``"audio"``, ``"label"``
    and ``"text"``. Audio and label tensors are combined with
    ``torch.stack``; the texts are returned as a plain list.

    :param batch: iterable of sample dicts; ``None`` entries are dropped
    :return: dict with keys ``"audio"``, ``"label"`` and ``"text"``, or
        ``None`` when no valid sample remains
    """
    valid_samples = [sample for sample in batch if sample is not None]
    if len(valid_samples) == 0:
        # Nothing usable in this batch; callers skip ``None`` batches.
        return None

    collated = {
        "audio": torch.stack([sample["audio"] for sample in valid_samples]),
        "label": torch.stack([sample["label"] for sample in valid_samples]),
        "text": [sample["text"] for sample in valid_samples],
    }
    return collated
def get_class_weights_from_loader(train_loader, num_classes):
    """Compute per-class weights from the labels produced by ``train_loader``.

    The loader is iterated once; ``None`` batches are skipped. Class indices
    are taken as ``argmax`` over dim 1 of each batch's ``"label"`` tensor.
    "Balanced" weights are computed only for the classes that actually
    occur; classes that never occur keep a weight of 0.0.

    :param train_loader: DataLoader yielding batches with one-hot labels
    :param num_classes: total number of classes
    :return: ``np.ndarray`` (float32) of length ``num_classes``
    :raises ValueError: if the loader yields no labels at all
    """
    label_ids = []
    for batch in train_loader:
        if batch is not None:
            label_ids += batch["label"].argmax(dim=1).tolist()

    if len(label_ids) == 0:
        raise ValueError("Нет ни одной метки в train_loader для вычисления весов классов.")

    present_classes = np.unique(label_ids)
    if num_classes > len(present_classes):
        missing = set(range(num_classes)) - set(present_classes)
        logging.info(f"[!] Отсутствуют метки для классов: {sorted(missing)}")

    # Weights are computed only over the classes that are present.
    partial_weights = compute_class_weight(
        class_weight="balanced",
        classes=present_classes,
        y=label_ids,
    )

    # Scatter the partial weights into the full vector; absent classes stay 0.
    full_weights = np.zeros(num_classes, dtype=np.float32)
    full_weights[present_classes] = partial_weights
    return full_weights
def make_dataset_and_loader(config, split: str, only_dataset: str = None):
    """Build the dataset and its DataLoader for one split.

    Every entry of ``config.datasets`` (or only the entry named
    ``only_dataset`` when that is given) is turned into a
    ``DatasetMultiModal``. A single dataset is returned as-is; several are
    wrapped in a ``ConcatDataset``. The loader shuffles only for the
    ``"train"`` split.

    :param config: configuration object providing ``datasets`` and the
        dataset/loader settings read below
    :param split: split name, substituted into the path templates
    :param only_dataset: optional name of the single dataset to keep
    :return: tuple ``(dataset, loader)``
    :raises ValueError: if ``config.datasets`` is missing or empty, or if no
        dataset was selected
    """
    if not getattr(config, "datasets", None):
        raise ValueError("⛔ В конфиге не указана секция [datasets].")

    selected = []
    for dataset_name, dataset_cfg in config.datasets.items():
        if only_dataset and only_dataset != dataset_name:
            continue

        # Path templates may contain the {base_dir} and {split} placeholders.
        csv_path = dataset_cfg["csv_path"].format(base_dir=dataset_cfg["base_dir"], split=split)
        wav_dir = dataset_cfg["wav_dir"].format(base_dir=dataset_cfg["base_dir"], split=split)
        logging.info(f"[{dataset_name.upper()}] Split={split}: CSV={csv_path}, WAV_DIR={wav_dir}")

        dataset_kwargs = {
            "csv_path": csv_path,
            "wav_dir": wav_dir,
            "emotion_columns": config.emotion_columns,
            "split": split,
            "sample_rate": config.sample_rate,
            "wav_length": config.wav_length,
            "whisper_model": config.whisper_model,
            "text_column": config.text_column,
            "use_whisper_for_nontrain_if_no_text": config.use_whisper_for_nontrain_if_no_text,
            "whisper_device": config.whisper_device,
            "subset_size": config.subset_size,
            "merge_probability": config.merge_probability,
        }
        selected.append(DatasetMultiModal(**dataset_kwargs))

    if len(selected) == 0:
        raise ValueError(f"⚠️ Для split='{split}' не найдено ни одного подходящего датасета.")

    # Concatenate only when there is more than one dataset.
    if len(selected) == 1:
        full_dataset = selected[0]
    else:
        full_dataset = ConcatDataset(selected)

    is_train = split == "train"
    loader = DataLoader(
        full_dataset,
        batch_size=config.batch_size,
        shuffle=is_train,
        num_workers=config.num_workers,
        collate_fn=custom_collate_fn,
    )
    return full_dataset, loader
def run_eval(model, loader, audio_extractor, text_extractor, criterion, device="cuda"):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    The model is switched to eval mode. For every non-``None`` batch the
    audio and text embeddings are extracted, the model is run on them, and
    the loss is computed against class indices obtained as ``argmax`` over
    dim 1 of the batch labels. The reported loss is the sample-weighted
    average over all evaluated samples.

    :param model: model called as ``model(audio_emb, text_emb)``
    :param loader: iterable of batches with keys ``"audio"``, ``"label"``
        and ``"text"``; ``None`` batches are skipped
    :param audio_extractor: object whose ``extract`` turns audio into embeddings
    :param text_extractor: object whose ``extract`` turns texts into embeddings
    :param criterion: loss called as ``criterion(logits, target)``
    :param device: device the audio and label tensors are moved to
    :return: tuple ``(loss, uar, war, mf1, wf1)``
    :raises ValueError: if the loader yields no valid batch
    """
    model.eval()
    total_loss = 0.0
    total_preds = []
    total_targets = []
    total = 0

    with torch.no_grad():
        for batch in tqdm(loader):
            if batch is None:
                continue

            audio = batch["audio"].to(device)
            labels = batch["label"].to(device)
            texts = batch["text"]

            audio_emb = audio_extractor.extract(audio)
            text_emb = text_extractor.extract(texts)
            logits = model(audio_emb, text_emb)

            target = labels.argmax(dim=1)
            loss = criterion(logits, target)

            # Weight the batch loss by batch size so the final value is a
            # per-sample average even when the last batch is smaller.
            bs = audio.shape[0]
            total_loss += loss.item() * bs
            total += bs

            preds = logits.argmax(dim=1)
            total_preds.extend(preds.cpu().tolist())
            total_targets.extend(target.cpu().tolist())

    if total == 0:
        # An empty loader (or one whose batches were all dropped) would
        # otherwise surface as an opaque ZeroDivisionError below.
        raise ValueError("Нет ни одного валидного батча в loader для оценки модели.")

    avg_loss = total_loss / total
    uar_m = uar(total_targets, total_preds)
    war_m = war(total_targets, total_preds)
    mf1_m = mf1(total_targets, total_preds)
    wf1_m = wf1(total_targets, total_preds)
    return avg_loss, uar_m, war_m, mf1_m, wf1_m
def _build_model(config, num_classes, device):
    """Instantiate the model class named by ``config.model_name`` and move it to ``device``."""
    dict_models = {
        'BiFormer': BiFormer,
        'BiGraphFormer': BiGraphFormer,
        'BiGatedGraphFormer': BiGatedGraphFormer,
    }
    model_cls = dict_models[config.model_name]
    model = model_cls(
        audio_dim=config.audio_embedding_dim,
        text_dim=config.text_embedding_dim,
        hidden_dim=config.hidden_dim,
        hidden_dim_gated=config.hidden_dim_gated,
        num_transformer_heads=config.num_transformer_heads,
        num_graph_heads=config.num_graph_heads,
        seg_len=config.max_tokens,
        mode=config.mode,
        dropout=config.dropout,
        positional_encoding=config.positional_encoding,
        out_features=config.out_features,
        tr_layer_number=config.tr_layer_number,
        device=device,
        num_classes=num_classes,
    )
    return model.to(device)


def _train_one_epoch(model, train_loader, audio_extractor, text_extractor,
                     criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``; return ``(loss, uar, war, mf1, wf1)``."""
    model.train()
    total_loss = 0.0
    total_samples = 0
    total_preds = []
    total_targets = []

    for batch in tqdm(train_loader):
        if batch is None:
            continue

        audio = batch["audio"].to(device)
        labels = batch["label"].to(device)
        texts = batch["text"]

        audio_emb = audio_extractor.extract(audio)
        text_emb = text_extractor.extract(texts)
        logits = model(audio_emb, text_emb)

        # Labels are converted to class indices before the loss is computed.
        target = labels.argmax(dim=1)
        loss = criterion(logits, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch loss is a per-sample average.
        bs = audio.shape[0]
        total_loss += loss.item() * bs
        total_samples += bs

        preds = logits.argmax(dim=1)
        total_preds.extend(preds.cpu().tolist())
        total_targets.extend(target.cpu().tolist())

    if total_samples == 0:
        # Avoid an opaque ZeroDivisionError when no batch was usable.
        raise ValueError("Нет ни одного валидного батча в train_loader для обучения.")

    train_loss = total_loss / total_samples
    uar_m = uar(total_targets, total_preds)
    war_m = war(total_targets, total_preds)
    mf1_m = mf1(total_targets, total_preds)
    wf1_m = wf1(total_targets, total_preds)
    return train_loss, uar_m, war_m, mf1_m, wf1_m


def _run_training(config, train_loader, dev_loaders, test_loaders, csv_writer):
    """Train the model with early stopping on dev; return ``(best_dev_mean, best_dev_metrics)``."""
    # Seed
    if config.random_seed > 0:
        random.seed(config.random_seed)
        # NumPy keeps its own global RNG; seed it too so that the
        # "seed fixed" log message below is actually true for it.
        np.random.seed(config.random_seed)
        torch.manual_seed(config.random_seed)
        logging.info(f"== Фиксируем random seed: {config.random_seed}")
    else:
        logging.info("== Random seed не фиксирован (0).")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Embedding extractors
    audio_extractor = AudioEmbeddingExtractor(config)
    text_extractor = TextEmbeddingExtractor(config)

    # Hyper-parameters
    num_classes = len(config.emotion_columns)
    lr = config.lr
    num_epochs = config.num_epochs
    max_patience = config.max_patience

    model = _build_model(config, num_classes, device)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    class_weights = get_class_weights_from_loader(train_loader, num_classes)
    criterion = WeightedCrossEntropyLoss(class_weights)
    logging.info("Class weights: " + ", ".join(f"{name}={weight:.4f}" for name, weight in zip(config.emotion_columns, class_weights)))

    # LR scheduler driven by the mean dev metric (higher is better).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.5,
        patience=2,
        min_lr=1e-7
    )

    # Early stopping on dev
    best_dev_mean = float("-inf")
    best_dev_metrics = {}
    patience_counter = 0

    for epoch in range(num_epochs):
        logging.info(f"\n=== Эпоха {epoch} ===")
        train_loss, uar_m, war_m, mf1_m, wf1_m = _train_one_epoch(
            model, train_loader, audio_extractor, text_extractor,
            criterion, optimizer, device
        )
        mean_train = np.mean([uar_m, war_m, mf1_m, wf1_m])
        logging.info(
            f"[TRAIN] Loss={train_loss:.4f}, UAR={uar_m:.4f}, WAR={war_m:.4f}, "
            f"MF1={mf1_m:.4f}, WF1={wf1_m:.4f}, MEAN={mean_train:.4f}"
        )

        # --- DEV ---
        dev_means = []
        dev_metrics_by_dataset = []
        for name, loader in dev_loaders:
            d_loss, d_uar, d_war, d_mf1, d_wf1 = run_eval(
                model, loader, audio_extractor, text_extractor, criterion, device
            )
            d_mean = np.mean([d_uar, d_war, d_mf1, d_wf1])
            dev_means.append(d_mean)
            if csv_writer:
                csv_writer.writerow(["dev", epoch, name, d_loss, d_uar, d_war, d_mf1, d_wf1, d_mean])
            logging.info(
                f"[DEV:{name}] Loss={d_loss:.4f}, UAR={d_uar:.4f}, WAR={d_war:.4f}, "
                f"MF1={d_mf1:.4f}, WF1={d_wf1:.4f}, MEAN={d_mean:.4f}"
            )
            dev_metrics_by_dataset.append({
                "name": name,
                "loss": d_loss,
                "uar": d_uar,
                "war": d_war,
                "mf1": d_mf1,
                "wf1": d_wf1,
                "mean": d_mean,
            })

        mean_dev = np.mean(dev_means)
        scheduler.step(mean_dev)

        if mean_dev > best_dev_mean:
            best_dev_mean = mean_dev
            patience_counter = 0
            best_dev_metrics = {
                "mean": mean_dev,
                "by_dataset": dev_metrics_by_dataset,
            }
        else:
            patience_counter += 1
            if patience_counter >= max_patience:
                # Stops before the test evaluation of this epoch.
                logging.info(f"Early stopping: {max_patience} эпох без улучшения.")
                break

        # --- TEST ---
        for name, loader in test_loaders:
            t_loss, t_uar, t_war, t_mf1, t_wf1 = run_eval(
                model, loader, audio_extractor, text_extractor, criterion, device
            )
            t_mean = np.mean([t_uar, t_war, t_mf1, t_wf1])
            logging.info(
                f"[TEST:{name}] Loss={t_loss:.4f}, UAR={t_uar:.4f}, WAR={t_war:.4f}, "
                f"MF1={t_mf1:.4f}, WF1={t_wf1:.4f}, MEAN={t_mean:.4f}"
            )
            if csv_writer:
                csv_writer.writerow(["test", epoch, name, t_loss, t_uar, t_war, t_mf1, t_wf1, t_mean])

    return best_dev_mean, best_dev_metrics


def train_once(config, train_loader, dev_loaders, test_loaders, metrics_csv_path=None):
    """Train a model with per-epoch dev and test evaluation.

    Each epoch runs one training pass, evaluates on every dev loader, steps
    the LR scheduler on the mean dev metric, and then evaluates on every test
    loader. The dev metric for a dataset is the mean of UAR, WAR, MF1 and
    WF1; the value used for early stopping is the mean over dev datasets.
    Training stops once that value has not improved for
    ``config.max_patience`` consecutive epochs; in that epoch the test
    evaluation is skipped.

    When ``metrics_csv_path`` is given, dev and test metrics are written to
    that file with the header
    ``split, epoch, dataset, loss, uar, war, mf1, wf1, mean``. The file is
    closed even if training raises.

    :param config: configuration object with model and training settings
    :param train_loader: loader of training batches
    :param dev_loaders: re-iterable collection of ``(name, loader)`` pairs
    :param test_loaders: re-iterable collection of ``(name, loader)`` pairs
    :param metrics_csv_path: optional path of the metrics CSV to (over)write
    :return: tuple ``(best_dev_mean, best_dev_metrics)``; the dict holds
        ``"mean"`` and ``"by_dataset"`` for the best epoch and is empty if
        no epoch improved on the initial ``-inf``
    """
    logging.info("== Запуск тренировки (train/dev/test) ==")

    if metrics_csv_path:
        # The context manager guarantees the handle is released on any exit
        # path, including exceptions raised in the middle of training.
        with open(metrics_csv_path, mode="w", newline="", encoding="utf-8") as csv_file:
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(["split", "epoch", "dataset", "loss", "uar", "war", "mf1", "wf1", "mean"])
            result = _run_training(config, train_loader, dev_loaders, test_loaders, csv_writer)
    else:
        result = _run_training(config, train_loader, dev_loaders, test_loaders, None)

    logging.info("Тренировка завершена. Все split'ы обработаны!")
    return result