#!/usr/bin/python
# -*- coding: utf-8 -*-

import os, sys
import json
import numpy as np
from pathlib import Path
import itertools

import evaluate

import disrpt_eval_2025
#from .disrpt_eval_2025 import *


# TODO: should this be conditioned on the task or on the metric indicated in the config file?
def prepare_compute_metrics(LABEL_NAMES):
    '''
    Return the metric function to be used in the trainer loop.
    For seg or conn, based on seqeval; tokens with label -100 are ignored (okay?).

    Parameters :
    ------------
    LABEL_NAMES: dict
        Needed only for BIO labels, converts label ids to the right labels for seqeval

    Returns :
    ---------
    compute_metrics: function
    '''
    def compute_metrics(eval_preds):
        nonlocal LABEL_NAMES
        # nonlocal task
        # Retrieve gold and predictions
        logits, labels = eval_preds
        predictions = np.argmax(logits, axis=-1)
        metric = evaluate.load("seqeval")
        # Remove ignored index (special tokens) and convert to labels
        true_labels = [[LABEL_NAMES[l] for l in label if l != -100] for label in labels]
        true_predictions = [
            [LABEL_NAMES[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
        print_metrics(all_metrics)
        return {
            "precision": all_metrics["overall_precision"],
            "recall": all_metrics["overall_recall"],
            "f1": all_metrics["overall_f1"],
            "accuracy": all_metrics["overall_accuracy"],
        }
    return compute_metrics


def print_metrics(all_metrics):
    #print( all_metrics )
    for p, v in all_metrics.items():
        if '_' in p:   # overall_* scores
            print(p, v)
        else:          # per-label dictionaries
            print(p + ' = ' + str(v))


def compute_metrics_dirspt(dataset_eval, pred_file, task='seg'):
    print("\nPerformance computed using disrpt eval script on", dataset_eval.annotations_file, pred_file)
    if task == 'seg':
        #clean_pred_file(pred_file, os.path.basename(pred_file)+"cleaned.preds")
        my_eval = disrpt_eval_2025.SegmentationEvaluation("temp_test_disrpt_eval_seg",
                                                          dataset_eval.annotations_file, pred_file)
    elif task == 'conn':
        my_eval = disrpt_eval_2025.ConnectivesEvaluation("temp_test_disrpt_eval_conn",
                                                         dataset_eval.annotations_file, pred_file)
    else:
        raise NotImplementedError
    my_eval.compute_scores()
    my_eval.print_results()


def clean_pred_file(pred_path: str, out_path: str):
    '''Copy a prediction file, skipping meta-tokens ([LANG=...], [FRAME=...]).'''
    c = 0
    with open(pred_path, "r", encoding="utf8") as fin, open(out_path, "w", encoding="utf8") as fout:
        for line in fin:
            if line.strip() == "" or line.startswith("#"):
                fout.write(line)
                continue
            fields = line.strip().split("\t")
            token = fields[1]
            if token.startswith("[LANG=") or token.startswith("[FRAME="):
                c += 1
                continue  # skip meta-tokens
            fout.write(line)
    print(f"Removed {c} meta-tokens")


# -------------------------------------------------------------------------------------------------
# ------ UTILS FUNCTIONS
# -------------------------------------------------------------------------------------------------
def read_config(config_file):
    '''Read the config file for training'''
    with open(config_file) as f:
        config = json.load(f)
    if 'frozen' in config['trainer_config']:
        config['trainer_config']["frozen"] = update_frozen_set(config['trainer_config']["frozen"])
    return config


def update_frozen_set(freeze):
    # Make a set from the list of frozen layers:
    # []            --> nothing frozen
    # [3]           --> only layer 3 frozen
    # [0,3]         --> only layers 0 and 3 frozen
    # [0-3, 12, 15] --> layers 0 to 3 included, plus layers 12 and 15 frozen
    frozen = set()
    for spec in freeze:
        spec = str(spec)
        if "-" in spec:  # e.g. 1-9
            b, e = spec.split("-")
            frozen |= set(range(int(b), int(e) + 1))
        else:
            frozen.add(int(spec))
    return frozen


def print_config(config):
    '''Print info from config dictionary'''
    print('\n'.join(['| ' + k + ": " + str(v) for (k, v) in config.items()]))


# -------------------------------------------------------------------------------------------------
def retrieve_files_dataset(input_path, list_dataset, mode='conllu', dset='train'):
    if mode == 'conllu':
        pat = ".[cC][oO][nN][lL][lL][uU]"
    elif mode == 'tok':
        pat = ".[tT][oO][kK]"
    else:
        sys.exit('Unknown mode for file extension: ' + mode)
    if len(list_dataset) == 0:
        return list(Path(input_path).rglob("*_" + dset + pat))
    else:
        # files such as eng.pdtb.pdtb_train.conllu
        matched = []
        for subdir in os.listdir(input_path):
            if subdir in list_dataset:
                matched.extend(list(Path(os.path.join(input_path, subdir)).rglob("*_" + dset + pat)))
        return matched


# -------------------------------------------------------------------------------------------------
# https://wandb.ai/site
def init_wandb(config, model_checkpoint, annotations_file):
    '''
    Initialize a new WANDB project to keep track of the experiments.

    Parameters
    ----------
    config : dict
        Used to retrieve the name of the entity and project (from config file)
    model_checkpoint : str
        Name of the PLM used
    annotations_file : str
        Path to the training file

    Returns
    -------
    None
    '''
    print("Initializing a WANDB project")
    import wandb
    proj_wandb = config["wandb"]
    ent_wandb = config["wandb_ent"]
    # Start a new wandb run to track this script.
    # The project name must be set before initializing the trainer.
    wandb.init(
        # set the wandb project where this run will be logged
        project=proj_wandb,
        entity=ent_wandb,
        # track hyperparameters and run metadata
        config={
            "model_checkpoint": model_checkpoint,
            "dataset": annotations_file,
        }
    )
    wandb.define_metric("epoch")
    wandb.define_metric("f1", step_metric="batch")
    wandb.define_metric("f1", step_metric="epoch")


def set_name_output_dir(output_dir, config, corpus_name):
    '''
    Set the path name for the target directory used to store models.
    The name should contain info about the task, the PLM and the hyperparameter values.

    Parameters
    ----------
    output_dir : str
        Path to the output directory provided by the user
    config : dict
        Configuration information
    corpus_name : str
        Name of the corpus

    Returns
    -------
    str: Path to the output directory
    '''
    # Retrieve the decimal form of the learning rate, to avoid scientific notation
    hyperparam = [config['trainer_config']['batch_size'],
                  np.format_float_positional(config['trainer_config']['learning_rate'])]
    output_dir = os.path.join(output_dir,
                              '_'.join([corpus_name,
                                        config["model_name"],
                                        config["task"],
                                        '_'.join([str(p) for p in hyperparam])]))
    return output_dir
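

# -------------------------------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline). It builds a toy config dict that
# mirrors the keys accessed above ("trainer_config", "model_name", "task", "frozen"), so it runs
# without an external JSON file; the model name, corpus name "eng.rst.gum" and output root
# "models" are illustrative placeholders only.
if __name__ == "__main__":
    demo_config = {
        "model_name": "xlm-roberta-base",
        "task": "seg",
        "trainer_config": {
            "batch_size": 16,
            "learning_rate": 2e-5,
            "frozen": ["0-3", "12"],   # layers 0..3 and layer 12 frozen
        },
    }
    # Normalize the frozen-layer spec the same way read_config() would
    demo_config["trainer_config"]["frozen"] = update_frozen_set(demo_config["trainer_config"]["frozen"])
    print_config(demo_config)
    # e.g. models/eng.rst.gum_xlm-roberta-base_seg_16_0.00002
    print(set_name_output_dir("models", demo_config, corpus_name="eng.rst.gum"))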