#!/usr/bin/python
# -*- coding: utf-8 -*-
import os, sys
import json
import numpy as np
from pathlib import Path
import itertools
import evaluate
import disrpt_eval_2025
#from .disrpt_eval_2025 import *
# TODO: should this be conditioned on the task or on the metric indicated in the config file?
def prepare_compute_metrics(LABEL_NAMES):
    '''
    Return the metric function to be used in the trainer loop.
    For seg or conn, this is based on seqeval; tokens with label -100
    (special tokens) are ignored, following standard practice.
    Could be expanded to other sequence / classification tasks.

    Parameters
    ----------
    LABEL_NAMES: dict
        Needed only for BIO labels, maps label ids to the label strings expected by seqeval

    Returns
    -------
    compute_metrics: function
    '''
    def compute_metrics(eval_preds):
        nonlocal LABEL_NAMES
        # nonlocal task
        # Retrieve gold labels and predictions
        logits, labels = eval_preds
        predictions = np.argmax(logits, axis=-1)
        metric = evaluate.load("seqeval")
        # Remove the ignored index (special tokens) and convert ids to label strings
        true_labels = [[LABEL_NAMES[l] for l in label if l != -100] for label in labels]
        true_predictions = [
            [LABEL_NAMES[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
        print_metrics(all_metrics)
        return {
            "precision": all_metrics["overall_precision"],
            "recall": all_metrics["overall_recall"],
            "f1": all_metrics["overall_f1"],
            "accuracy": all_metrics["overall_accuracy"],
        }
    return compute_metrics
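
# Minimal usage sketch (not part of the original pipeline): wiring the metric
# closure into a Hugging Face Trainer. `model`, `args`, `train_ds`, `dev_ds`
# and `ID2LABEL` are hypothetical placeholders for objects built elsewhere.
#
#   from transformers import Trainer
#   trainer = Trainer(
#       model=model,
#       args=args,
#       train_dataset=train_ds,
#       eval_dataset=dev_ds,
#       compute_metrics=prepare_compute_metrics(ID2LABEL),  # id -> BIO label string
#   )
#   trainer.evaluate()  # calls compute_metrics on (logits, labels)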
def print_metrics(all_metrics):
    '''Print the dictionary returned by seqeval's compute().'''
    #print(all_metrics)
    for p, v in all_metrics.items():
        if '_' in p:  # overall scores, e.g. overall_f1
            print(p, v)
        else:  # per-type dict of scores
            print(p + ' = ' + str(v))
def compute_metrics_dirspt(dataset_eval, pred_file, task='seg'):
    '''Compute scores with the official DISRPT evaluation script.'''
    print("\nPerformance computed using the disrpt eval script on", dataset_eval.annotations_file,
          pred_file)
    if task == 'seg':
        #clean_pred_file(pred_file, os.path.basename(pred_file)+"cleaned.preds")
        my_eval = disrpt_eval_2025.SegmentationEvaluation("temp_test_disrpt_eval_seg",
                                                          dataset_eval.annotations_file,
                                                          pred_file)
    elif task == 'conn':
        my_eval = disrpt_eval_2025.ConnectivesEvaluation("temp_test_disrpt_eval_conn",
                                                         dataset_eval.annotations_file,
                                                         pred_file)
    else:
        raise NotImplementedError("Unknown task for disrpt evaluation: " + task)
    my_eval.compute_scores()
    my_eval.print_results()
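
# Usage sketch (illustrative path; `dataset_eval` is assumed to be a dataset
# object exposing an `annotations_file` attribute, as used above):
#
#   compute_metrics_dirspt(dataset_eval, "outputs/eng.rst.gum_dev.preds", task='seg')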
def clean_pred_file(pred_path: str, out_path: str):
    '''Copy a prediction file, skipping meta-tokens such as [LANG=...] and [FRAME=...].'''
    c = 0
    with open(pred_path, "r", encoding="utf8") as fin, open(out_path, "w", encoding="utf8") as fout:
        for line in fin:
            if line.strip() == "" or line.startswith("#"):
                fout.write(line)
                continue
            fields = line.strip().split("\t")
            token = fields[1]
            if token.startswith("[LANG=") or token.startswith("[FRAME="):
                c += 1
                continue  # skip meta-tokens
            fout.write(line)
    print(f"Cleaned {c} meta-tokens from {pred_path}")
# -------------------------------------------------------------------------------------------------
# ------ UTILS FUNCTIONS
# -------------------------------------------------------------------------------------------------
def read_config(config_file):
    '''Read the JSON config file for training'''
    with open(config_file) as f:
        config = json.load(f)
    if 'frozen' in config['trainer_config']:
        config['trainer_config']["frozen"] = update_frozen_set(config['trainer_config']["frozen"])
    return config
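
# A minimal example of the expected config layout, inferred from the keys read
# in this module (all field values are illustrative only):
#
#   {
#     "model_name": "xlm-roberta-base",
#     "task": "seg",
#     "wandb": "my-project",
#     "wandb_ent": "my-entity",
#     "trainer_config": {
#       "batch_size": 16,
#       "learning_rate": 2e-5,
#       "frozen": ["0-3", "12"]
#     }
#   }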
def update_frozen_set(freeze):
    '''Make a set of layer indices from the list of frozen layer specs (strings):
       []                  --> nothing frozen
       ["3"]               --> only layer 3 frozen
       ["0", "3"]          --> only layers 0 and 3
       ["0-3", "12", "15"] --> layers 0 to 3 included, plus layers 12 and 15
    '''
    frozen = set()
    for spec in freeze:
        if "-" in spec:  # range, e.g. "1-9"
            b, e = spec.split("-")
            frozen = frozen | set(range(int(b), int(e) + 1))
        else:
            frozen.add(int(spec))
    return frozen
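
# Example: update_frozen_set(["0-3", "12", "15"]) --> {0, 1, 2, 3, 12, 15}.
# One way the resulting set could be applied, a sketch only, assuming a
# BERT-style PLM exposing its transformer layers as `model.encoder.layer`:
#
#   for i, layer in enumerate(model.encoder.layer):
#       if i in frozen:
#           for param in layer.parameters():
#               param.requires_grad = False  # exclude this layer from training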
def print_config(config):
    '''Print info from the config dictionary'''
    print('\n'.join(['| ' + k + ": " + str(v) for (k, v) in config.items()]))
# -------------------------------------------------------------------------------------------------
def retrieve_files_dataset(input_path, list_dataset, mode='conllu', dset='train'):
    '''Retrieve the .conllu or .tok files for the requested split (train/dev/test),
    restricted to the corpora listed in list_dataset (all corpora if the list is empty).'''
    if mode == 'conllu':
        pat = ".[cC][oO][nN][lL][lL][uU]"
    elif mode == 'tok':
        pat = ".[tT][oO][kK]"
    else:
        sys.exit('Unknown mode for file extension: ' + mode)
    if len(list_dataset) == 0:
        return list(Path(input_path).rglob("*_" + dset + pat))
    else:
        # e.g. files named eng.pdtb.pdtb_train.conllu
        matched = []
        for subdir in os.listdir(input_path):
            if subdir in list_dataset:
                matched.extend(list(Path(os.path.join(input_path, subdir)).rglob("*_" + dset + pat)))
        return matched
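
# Example (illustrative layout): with data/eng.pdtb.pdtb/eng.pdtb.pdtb_train.conllu
# on disk, both calls below would match that file:
#
#   retrieve_files_dataset("data", [], mode='conllu', dset='train')
#   retrieve_files_dataset("data", ["eng.pdtb.pdtb"], mode='conllu', dset='train')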
# -------------------------------------------------------------------------------------------------
# https://wandb.ai/site
def init_wandb(config, model_checkpoint, annotations_file):
    '''
    Initialize a new WANDB project to keep track of the experiments.

    Parameters
    ----------
    config : dict
        Allows to retrieve the names of the entity and project (from the config file)
    model_checkpoint : str
        Name of the PLM used
    annotations_file : str
        Path to the training file

    Returns
    -------
    None
    '''
    print("Initializing a WANDB project to track this run")
    import wandb
    proj_wandb = config["wandb"]
    ent_wandb = config["wandb_ent"]
    # Start a new wandb run to track this script.
    # The project name must be set before initializing the trainer.
    wandb.init(
        # set the wandb project where this run will be logged
        project=proj_wandb,
        entity=ent_wandb,
        # track hyperparameters and run metadata
        config={
            "model_checkpoint": model_checkpoint,
            "dataset": annotations_file,
        }
    )
    # Log f1 against the epoch counter (for a given metric, the last
    # define_metric call is the one that applies).
    wandb.define_metric("epoch")
    wandb.define_metric("f1", step_metric="epoch")
def set_name_output_dir(output_dir, config, corpus_name):
    '''
    Set the path name of the target directory used to store models. The name contains
    info about the task, the PLM and the hyperparameter values.

    Parameters
    ----------
    output_dir : str
        Path to the output directory provided by the user
    config : dict
        Configuration information
    corpus_name : str
        Name of the corpus

    Returns
    -------
    str: Path to the output directory
    '''
    # Format the learning rate in positional notation, to avoid scientific notation
    hyperparam = [
        config['trainer_config']['batch_size'],
        np.format_float_positional(config['trainer_config']['learning_rate'])
    ]
    output_dir = os.path.join(output_dir,
                              '_'.join([
                                  corpus_name,
                                  config["model_name"],
                                  config["task"],
                                  '_'.join([str(p) for p in hyperparam])
                              ]))
    return output_dir
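
# Example (illustrative config values): with model_name="xlm-roberta-base",
# task="seg", batch_size=16 and learning_rate=2e-5,
#
#   set_name_output_dir("models", config, "eng.rst.gum")
#
# returns "models/eng.rst.gum_xlm-roberta-base_seg_16_0.00002".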