# coding: utf-8
import copy
import datetime
import logging
from itertools import product
from typing import Any

import numpy as np

def format_result_box_dual(step_num, param_name, candidate, fixed_params, dev_metrics, test_metrics, is_best=False):
    """Render a framed text summary of one trial: fixed params, dev/test metrics, and the dev-test gap."""
    title = f"Step {step_num}: {param_name} = {candidate}"
    fixed_lines = [f"{k} = {v}" for k, v in fixed_params.items()]

    def format_metrics_block(metrics, label):
        lines = [f"  Results ({label.upper()}):"]
        for k in ["uar", "war", "mf1", "wf1", "loss", "mean"]:
            if k in metrics:
                val = metrics[k]
                line = f"    {k.upper():12} = {val:.4f}" if isinstance(val, float) else f"    {k.upper():12} = {val}"
                # Mark the winning dev mean so the best trial stands out in the log.
                if is_best and label.lower() == "dev" and k.lower() == "mean":
                    line += " ✅"
                lines.append(line)
        return lines

    content_lines = [title, "  Fixed:"]
    content_lines += [f"    {line}" for line in fixed_lines]
    # DEV block
    content_lines += format_metrics_block(dev_metrics, "dev")
    content_lines.append("")
    # TEST block
    content_lines += format_metrics_block(test_metrics, "test")
    # GAP: positive means dev outperforms test (a possible sign of overfitting to dev).
    if "mean" in dev_metrics and "mean" in test_metrics:
        gap_val = dev_metrics["mean"] - test_metrics["mean"]
        content_lines.append(f"  GAP = {gap_val:+.4f}")
    max_width = max(len(line) for line in content_lines)
    border_top = "┌" + "─" * (max_width + 2) + "┐"
    border_bot = "└" + "─" * (max_width + 2) + "┘"
    box = [border_top]
    for line in content_lines:
        box.append(f"│ {line.ljust(max_width)} │")
    box.append(border_bot)
    return "\n".join(box)
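
# Usage sketch for the box renderer (illustrative values, not from a real run;
# the metric keys match the ones the function reads above):
#
#     print(format_result_box_dual(
#         step_num=1,
#         param_name="lr",
#         candidate=1e-4,
#         fixed_params={"batch_size": 32},
#         dev_metrics={"uar": 0.71, "mf1": 0.69, "mean": 0.70},
#         test_metrics={"uar": 0.68, "mf1": 0.66, "mean": 0.67},
#         is_best=True,
#     ))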

def greedy_search(
    base_config,
    train_loader,
    dev_loader,
    test_loader,
    train_fn,
    overrides_file: str,
    param_grid: dict[str, list],
    default_values: dict[str, Any],
    csv_prefix: str | None = None,
):
    """Coordinate-descent search: tune one hyperparameter at a time, keeping the best dev mean."""
    current_best_params = copy.deepcopy(default_values)
    all_param_names = list(param_grid.keys())
    model_name = getattr(base_config, "model_name", "UNKNOWN_MODEL")
    with open(overrides_file, "a", encoding="utf-8") as f:
        f.write("=== Greedy (stepwise) hyperparameter search (Dev-based) ===\n")
        f.write(f"Model: {model_name}\n")
    for i, param_name in enumerate(all_param_names):
        candidates = param_grid[param_name]
        tried_value = current_best_params[param_name]
        if i == 0:
            candidates_to_try = candidates
        else:
            # Skip the incumbent value here; it is re-evaluated separately below.
            candidates_to_try = [v for v in candidates if v != tried_value]
        best_val_for_param = tried_value
        best_metric_for_param = float("-inf")
        # After the first step, re-run the current best combination first so the
        # incumbent value has a dev score that the candidates must beat.
        if i != 0:
            config_default = copy.deepcopy(base_config)
            for k, v in current_best_params.items():
                setattr(config_default, k, v)
            logging.info(f"[STEP {i + 1}] {param_name} = {tried_value} (incumbent)")
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            csv_filename = f"{csv_prefix}_{model_name}_{param_name}_{tried_value}_{timestamp}.csv" if csv_prefix else None
            dev_mean_default, dev_metrics_default, test_metrics_default = train_fn(
                config_default,
                train_loader,
                dev_loader,
                test_loader,
                metrics_csv_path=csv_filename,
            )
            box_text = format_result_box_dual(
                step_num=i + 1,
                param_name=param_name,
                candidate=tried_value,
                fixed_params={k: v for k, v in current_best_params.items() if k != param_name},
                dev_metrics=dev_metrics_default,
                test_metrics=test_metrics_default,
                is_best=True,
            )
            with open(overrides_file, "a", encoding="utf-8") as f:
                f.write("\n" + box_text + "\n")
            _log_dataset_metrics(dev_metrics_default, overrides_file, label="dev")
            _log_dataset_metrics(test_metrics_default, overrides_file, label="test")
            best_metric_for_param = dev_mean_default
        for candidate in candidates_to_try:
            config = copy.deepcopy(base_config)
            for k, v in current_best_params.items():
                setattr(config, k, v)
            setattr(config, param_name, candidate)
            logging.info(f"[STEP {i + 1}] {param_name} = {candidate} (others: {current_best_params})")
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            csv_filename = f"{csv_prefix}_{model_name}_{param_name}_{candidate}_{timestamp}.csv" if csv_prefix else None
            dev_mean, dev_metrics, test_metrics = train_fn(
                config,
                train_loader,
                dev_loader,
                test_loader,
                metrics_csv_path=csv_filename,
            )
            is_better = dev_mean > best_metric_for_param
            box_text = format_result_box_dual(
                step_num=i + 1,
                param_name=param_name,
                candidate=candidate,
                fixed_params={k: v for k, v in current_best_params.items() if k != param_name},
                dev_metrics=dev_metrics,
                test_metrics=test_metrics,
                is_best=is_better,
            )
            with open(overrides_file, "a", encoding="utf-8") as f:
                f.write("\n" + box_text + "\n")
            _log_dataset_metrics(dev_metrics, overrides_file, label="dev")
            _log_dataset_metrics(test_metrics, overrides_file, label="test")
            if is_better:
                best_val_for_param = candidate
                best_metric_for_param = dev_mean
        current_best_params[param_name] = best_val_for_param
        with open(overrides_file, "a", encoding="utf-8") as f:
            f.write(f"\n>> [Step {i + 1} result]: best {param_name}={best_val_for_param}, dev_mean={best_metric_for_param:.4f}\n")
    with open(overrides_file, "a", encoding="utf-8") as f:
        f.write("\n=== Final combination (Dev-based) ===\n")
        for k, v in current_best_params.items():
            f.write(f"{k} = {v}\n")
    logging.info("Done! Best hyperparameters selected.")

def exhaustive_search(
    base_config,
    train_loader,
    dev_loader,
    test_loader,
    train_fn,
    overrides_file: str,
    param_grid: dict[str, list],
    csv_prefix: str | None = None,
):
    """Full grid search: evaluate every combination in param_grid and keep the best dev mean."""
    all_param_names = list(param_grid.keys())
    model_name = getattr(base_config, "model_name", "UNKNOWN_MODEL")
    with open(overrides_file, "a", encoding="utf-8") as f:
        f.write("=== Exhaustive hyperparameter search (Dev-based) ===\n")
        f.write(f"Model: {model_name}\n")
    best_config = None
    best_metric = float("-inf")
    best_metrics = {}
    combo_id = 0
    for combo in product(*(param_grid[param] for param in all_param_names)):
        combo_id += 1
        param_combo = dict(zip(all_param_names, combo))
        config = copy.deepcopy(base_config)
        for k, v in param_combo.items():
            setattr(config, k, v)
        logging.info(f"\n[Combination #{combo_id}] {param_combo}")
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_filename = f"{csv_prefix}_{model_name}_combo{combo_id}_{timestamp}.csv" if csv_prefix else None
        dev_mean, dev_metrics, test_metrics = train_fn(
            config,
            train_loader,
            dev_loader,
            test_loader,
            metrics_csv_path=csv_filename,
        )
        is_better = dev_mean > best_metric
        box_text = format_result_box_dual(
            step_num=combo_id,
            param_name=" + ".join(all_param_names),
            candidate=str(combo),
            fixed_params={},
            dev_metrics=dev_metrics,
            test_metrics=test_metrics,
            is_best=is_better,
        )
        with open(overrides_file, "a", encoding="utf-8") as f:
            f.write("\n" + box_text + "\n")
        _log_dataset_metrics(dev_metrics, overrides_file, label="dev")
        _log_dataset_metrics(test_metrics, overrides_file, label="test")
        if is_better:
            best_metric = dev_mean
            best_config = param_combo
            best_metrics = dev_metrics
    with open(overrides_file, "a", encoding="utf-8") as f:
        f.write("\n=== Best combination (Dev-based) ===\n")
        # Guard against an empty grid, in which case best_config is still None.
        for k, v in (best_config or {}).items():
            f.write(f"{k} = {v}\n")
    logging.info("Exhaustive search finished! Best parameters selected.")
    return best_metric, best_config, best_metrics
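
# Cost note: exhaustive_search runs the full Cartesian product, so a
# hypothetical grid {"lr": 3 values, "dropout": 4 values} costs 3 * 4 = 12
# trainings, while greedy_search above needs roughly one run per value
# (plus one incumbent re-run per step after the first).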

def _compute_combined_avg(dev_metrics):
    """Average of all per-dataset UAR/WAR/MF1/WF1 values; None if there are no per-dataset entries."""
    if "by_dataset" not in dev_metrics:
        return None
    values = []
    for entry in dev_metrics["by_dataset"]:
        for key in ["uar", "war", "mf1", "wf1"]:
            if key in entry:
                values.append(entry[key])
    return float(np.mean(values)) if values else None

def _log_dataset_metrics(metrics, file_path, label="dev"):
    """Append per-dataset metrics to the overrides file, if the metrics dict provides them."""
    if "by_dataset" not in metrics:
        return
    with open(file_path, "a", encoding="utf-8") as f:
        f.write(f"\n>>> Per-dataset metrics ({label}):\n")
        for ds in metrics["by_dataset"]:
            name = ds.get("name", "unknown")
            f.write(f"  - {name}:\n")
            for k in ["loss", "uar", "war", "mf1", "wf1", "mean"]:
                if k in ds:
                    f.write(f"      {k.upper():4} = {ds[k]:.4f}\n")
        f.write(f"<<< End of per-dataset metrics ({label})\n")