Spaces:
Running
Running
| # train.py | |
| # coding: utf-8 | |
| import logging | |
| import os | |
| import shutil | |
| import datetime | |
| import whisper | |
| import toml | |
| # os.environ["HF_HOME"] = "models" | |
| from utils.config_loader import ConfigLoader | |
| from utils.logger_setup import setup_logger | |
| from utils.search_utils import greedy_search, exhaustive_search | |
| from training.train_utils import ( | |
| make_dataset_and_loader, | |
| train_once | |
| ) | |
| from data_loading.feature_extractor import PretrainedAudioEmbeddingExtractor, PretrainedTextEmbeddingExtractor | |
def _invalid_search_type_error(search_type):
    """Build the ValueError reported for an unsupported ``search_type`` value."""
    return ValueError(
        f"⛔️ Неверное значение search_type в конфиге: '{search_type}'. Используй 'greedy', 'exhaustive' или 'none'."
    )


def main():
    """Run one training session driven by ``config.toml``.

    Steps, in order:

    1. Load the config and create ``results_<model>_<timestamp>/`` with a
       ``metrics_by_epoch/`` subdirectory.
    2. Configure logging into ``session_log.txt`` inside that directory,
       show the config and keep a copy of ``config.toml`` next to the results.
    3. Build the audio/text embedding extractors, load the Whisper model once,
       and create one shared train loader plus per-dataset dev/test loaders.
    4. If ``prepare_only`` is set in the config, stop after data preparation.
    5. Otherwise dispatch on ``search_type``: ``"greedy"`` or ``"exhaustive"``
       run a parameter search configured by ``search_params.toml``;
       ``"none"`` runs a single training.

    Raises:
        ValueError: if training is requested (``prepare_only`` is false) and
            ``search_type`` is not ``"greedy"``, ``"exhaustive"`` or ``"none"``.
    """
    valid_search_types = ("greedy", "exhaustive", "none")

    # Load the config.
    base_config = ConfigLoader("config.toml")

    # Results directory name: normalised model name + start timestamp.
    model_name = base_config.model_name.replace("/", "_").replace(" ", "_").lower()
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    results_dir = f"results_{model_name}_{timestamp}"
    os.makedirs(results_dir, exist_ok=True)

    epochlog_dir = os.path.join(results_dir, "metrics_by_epoch")
    os.makedirs(epochlog_dir, exist_ok=True)

    # Set up logging.
    log_file = os.path.join(results_dir, "session_log.txt")
    setup_logger(logging.DEBUG, log_file=log_file)

    # Show the loaded config and keep a copy alongside the results.
    base_config.show_config()
    shutil.copy("config.toml", os.path.join(results_dir, "config_copy.toml"))

    # Fail fast on a misspelled search_type instead of discovering it only
    # after the Whisper model and all datasets have been prepared.
    # Skipped in prepare_only mode, where search_type is never consulted.
    if not base_config.prepare_only and base_config.search_type not in valid_search_types:
        raise _invalid_search_type_error(base_config.search_type)

    # File handed to the search routines as `overrides_file`
    # (per the original note, the greedy search writes to it).
    overrides_file = os.path.join(results_dir, "overrides.txt")
    csv_prefix = os.path.join(epochlog_dir, "metrics_epochlog")

    audio_feature_extractor = PretrainedAudioEmbeddingExtractor(base_config)
    text_feature_extractor = PretrainedTextEmbeddingExtractor(base_config)

    # Initialise the Whisper model once and share it between all loaders.
    logging.info(f"Инициализация Whisper: модель={base_config.whisper_model}, устройство={base_config.whisper_device}")
    whisper_model = whisper.load_model(base_config.whisper_model, device=base_config.whisper_device)

    # Build datasets/loaders.
    # Shared train loader.
    _, train_loader = make_dataset_and_loader(
        base_config, "train", audio_feature_extractor, text_feature_extractor, whisper_model
    )

    # Separate dev/test loaders, one (name, loader) pair per dataset.
    dev_loaders = []
    test_loaders = []
    for dataset_name in base_config.datasets:
        _, dev_loader = make_dataset_and_loader(
            base_config, "dev", audio_feature_extractor, text_feature_extractor,
            whisper_model, only_dataset=dataset_name
        )
        _, test_loader = make_dataset_and_loader(
            base_config, "test", audio_feature_extractor, text_feature_extractor,
            whisper_model, only_dataset=dataset_name
        )
        dev_loaders.append((dataset_name, dev_loader))
        test_loaders.append((dataset_name, test_loader))

    if base_config.prepare_only:
        logging.info("== Режим prepare_only: только подготовка данных, без обучения ==")
        return

    search_type = base_config.search_type

    if search_type == "none":
        # Single training run: search_params.toml is not needed here, so it is
        # deliberately not loaded (it may not exist in this mode).
        logging.info("== Режим одиночной тренировки (без поиска параметров) ==")
        run_stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_file_path = f"{csv_prefix}_single_{run_stamp}.csv"
        train_once(
            config=base_config,
            train_loader=train_loader,
            dev_loaders=dev_loaders,
            test_loaders=test_loaders,
            metrics_csv_path=csv_file_path
        )
        return

    if search_type not in ("greedy", "exhaustive"):
        # Defensive: only reachable if search_type changed after the early check.
        raise _invalid_search_type_error(search_type)

    # Both search modes read their parameter grid from search_params.toml.
    search_config = toml.load("search_params.toml")
    param_grid = dict(search_config["grid"])

    if search_type == "greedy":
        # Only the greedy search takes default values, so the "defaults"
        # table is required only in this mode.
        default_values = dict(search_config["defaults"])
        greedy_search(
            base_config=base_config,
            train_loader=train_loader,
            dev_loader=dev_loaders,
            test_loader=test_loaders,
            train_fn=train_once,
            overrides_file=overrides_file,
            param_grid=param_grid,
            default_values=default_values,
            csv_prefix=csv_prefix
        )
    else:
        exhaustive_search(
            base_config=base_config,
            train_loader=train_loader,
            dev_loader=dev_loaders,
            test_loader=test_loaders,
            train_fn=train_once,
            overrides_file=overrides_file,
            param_grid=param_grid,
            csv_prefix=csv_prefix
        )
# Script entry point: run the training session only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()