#!/usr/bin/python
# -*- coding: utf-8 -*-
import os, sys
import json
import numpy as np
from pathlib import Path
import evaluate
import disrpt_eval_2025
#from .disrpt_eval_2025 import *
# TODO: should this be conditioned on the task or on the metric indicated in the config file?
def prepare_compute_metrics(LABEL_NAMES):
    '''
    Return the metric function to be used in the trainer loop.
    For seg or conn, scores are based on seqeval; tokens with label
    -100 (special tokens) are ignored (okay?).

    Parameters
    ----------
    LABEL_NAMES: dict
        Maps label ids to BIO label strings, as expected by seqeval.
        (A 'task' parameter, 'seg' or 'conn', could be added later to expand to
        other sequence / classification tasks.)

    Returns
    -------
    compute_metrics: function
    '''
    def compute_metrics(eval_preds):
        # Retrieve gold labels and predictions
        logits, labels = eval_preds
        predictions = np.argmax(logits, axis=-1)
        metric = evaluate.load("seqeval")
        # Remove the ignored index (special tokens, label -100) and map ids to label strings
        true_labels = [[LABEL_NAMES[l] for l in label if l != -100] for label in labels]
        true_predictions = [
            [LABEL_NAMES[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
        print_metrics(all_metrics)
        return {
            "precision": all_metrics["overall_precision"],
            "recall": all_metrics["overall_recall"],
            "f1": all_metrics["overall_f1"],
            "accuracy": all_metrics["overall_accuracy"],
        }

    return compute_metrics
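
# A minimal usage sketch (the label map and trainer arguments are made up; only the
# wiring into a HF Trainer's compute_metrics argument is the point):
#   compute_metrics = prepare_compute_metrics({0: "O", 1: "B-Conn", 2: "I-Conn"})
#   trainer = Trainer(model=model, args=training_args, ..., compute_metrics=compute_metrics)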
def print_metrics(all_metrics):
    '''Print seqeval results, one metric per line.'''
    for p, v in all_metrics.items():
        if '_' in p:
            # overall_* entries are single float scores
            print(p, v)
        else:
            # per-class entries map a label to a dict of scores
            print(p + ' = ' + str(v))
def compute_metrics_dirspt(dataset_eval, pred_file, task='seg'):
    '''Evaluate predictions with the official DISRPT evaluation script.'''
    print("\nPerformance computed using disrpt eval script on", dataset_eval.annotations_file,
          pred_file)
    if task == 'seg':
        #clean_pred_file(pred_file, os.path.basename(pred_file) + "cleaned.preds")
        my_eval = disrpt_eval_2025.SegmentationEvaluation("temp_test_disrpt_eval_seg",
                                                          dataset_eval.annotations_file,
                                                          pred_file)
    elif task == 'conn':
        my_eval = disrpt_eval_2025.ConnectivesEvaluation("temp_test_disrpt_eval_conn",
                                                         dataset_eval.annotations_file,
                                                         pred_file)
    else:
        raise NotImplementedError("Unknown task for disrpt eval: " + task)
    my_eval.compute_scores()
    my_eval.print_results()
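
# Illustrative call, assuming dataset_eval exposes .annotations_file; the prediction
# path below is hypothetical:
#   compute_metrics_dirspt(dataset_eval, "outputs/eng.rst.gum_test.preds", task='seg')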
def clean_pred_file(pred_path: str, out_path: str):
    '''Copy a prediction file, dropping meta-tokens such as [LANG=...] and [FRAME=...].'''
    c = 0
    with open(pred_path, "r", encoding="utf8") as fin, open(out_path, "w", encoding="utf8") as fout:
        for line in fin:
            if line.strip() == "" or line.startswith("#"):
                fout.write(line)
                continue
            fields = line.strip().split("\t")
            if len(fields) < 2:
                # malformed line: keep it as is rather than crashing
                fout.write(line)
                continue
            token = fields[1]
            if token.startswith("[LANG=") or token.startswith("[FRAME="):
                c += 1
                continue  # skip meta-tokens
            fout.write(line)
    print(f"we've cleaned {c} tokens")
# -------------------------------------------------------------------------------------------------
# ------ UTILS FUNCTIONS
# -------------------------------------------------------------------------------------------------
def read_config(config_file):
    '''Read the JSON config file for training'''
    with open(config_file) as f:
        config = json.load(f)
    if 'frozen' in config['trainer_config']:
        config['trainer_config']["frozen"] = update_frozen_set(config['trainer_config']["frozen"])
    return config
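
# Minimal sketch of the expected config layout, inferred from the keys read in this
# module (values are hypothetical; real config files may hold more fields):
#   {
#     "model_name": "xlm-roberta-base",
#     "task": "seg",
#     "wandb": "my_project",
#     "wandb_ent": "my_entity",
#     "trainer_config": {"batch_size": 16, "learning_rate": 2e-05, "frozen": ["0-3", 12]}
#   }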
def update_frozen_set(freeze):
    '''
    Make a set from the list of frozen layer specs:
      []              --> nothing frozen
      [3]             --> only layer 3 frozen
      [0, 3]          --> only layers 0 and 3
      ["0-3", 12, 15] --> layers 0 to 3 included, plus layers 12 and 15
    '''
    frozen = set()
    for spec in freeze:
        spec = str(spec)  # specs may be ints or strings in the config
        if "-" in spec:  # a range, e.g. "1-9"
            b, e = spec.split("-")
            frozen |= set(range(int(b), int(e) + 1))
        else:
            frozen.add(int(spec))
    return frozen
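
# For example:
#   update_frozen_set([])              -> set()
#   update_frozen_set(["0-3", 12, 15]) -> {0, 1, 2, 3, 12, 15}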
def print_config(config):
    '''Print info from the config dictionary'''
    print('\n'.join(['| ' + k + ": " + str(v) for (k, v) in config.items()]))
# -------------------------------------------------------------------------------------------------
def retrieve_files_dataset(input_path, list_dataset, mode='conllu', dset='train'):
    '''Retrieve the dataset files (e.g. eng.pdtb.pdtb_train.conllu) under input_path,
    restricted to the corpora listed in list_dataset if it is non-empty.'''
    if mode == 'conllu':
        pat = ".[cC][oO][nN][lL][lL][uU]"
    elif mode == 'tok':
        pat = ".[tT][oO][kK]"
    else:
        sys.exit('Unknown mode for file extension: ' + mode)
    if len(list_dataset) == 0:
        return list(Path(input_path).rglob("*_" + dset + pat))
    # e.g. files eng.pdtb.pdtb_train.conllu
    matched = []
    for subdir in os.listdir(input_path):
        if subdir in list_dataset:
            matched.extend(Path(os.path.join(input_path, subdir)).rglob("*_" + dset + pat))
    return matched
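
# Illustrative call, assuming a hypothetical layout data/eng.pdtb.pdtb/eng.pdtb.pdtb_train.conllu:
#   retrieve_files_dataset("data", ["eng.pdtb.pdtb"], mode='conllu', dset='train')
# returns the *_train.conllu paths found under data/eng.pdtb.pdtb/.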
# -------------------------------------------------------------------------------------------------
# https://wandb.ai/site
def init_wandb(config, model_checkpoint, annotations_file):
    '''
    Initialize a new WANDB project to keep track of the experiments.

    Parameters
    ----------
    config : dict
        Used to retrieve the names of the entity and the project (from the config file)
    model_checkpoint : str
        Name of the PLM used
    annotations_file : str
        Path to the training file

    Returns
    -------
    None
    '''
    print("HERE WE INITIALIZE A WANDB PROJECT")
    import wandb
    proj_wandb = config["wandb"]
    ent_wandb = config["wandb_ent"]
    # Start a new wandb run to track this script.
    # The project name must be set before initializing the trainer.
    wandb.init(
        # set the wandb project where this run will be logged
        project=proj_wandb,
        entity=ent_wandb,
        # track hyperparameters and run metadata
        config={
            "model_checkpoint": model_checkpoint,
            "dataset": annotations_file,
        }
    )
    wandb.define_metric("epoch")
    # Log f1 against the epoch axis (a later define_metric call overrides an earlier one)
    wandb.define_metric("f1", step_metric="epoch")
def set_name_output_dir(output_dir, config, corpus_name):
    '''
    Set the path name for the target directory used to store models. The name contains
    info about the task, the PLM and the hyperparameter values.

    Parameters
    ----------
    output_dir : str
        Path to the output directory provided by the user
    config : dict
        Configuration information
    corpus_name : str
        Name of the corpus

    Returns
    -------
    str: Path to the output directory
    '''
    # Retrieve the decimal form of the learning rate, to avoid scientific notation
    hyperparam = [
        config['trainer_config']['batch_size'],
        np.format_float_positional(config['trainer_config']['learning_rate'])
    ]
    output_dir = os.path.join(output_dir,
                              '_'.join([
                                  corpus_name,
                                  config["model_name"],
                                  config["task"],
                                  '_'.join([str(p) for p in hyperparam])
                              ]))
    return output_dir
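
# For example, with hypothetical values batch_size=16, learning_rate=2e-05,
# model_name="xlm-r" and task="seg" in the config:
#   set_name_output_dir("models", config, "eng.rst.gum")
#   -> "models/eng.rst.gum_xlm-r_seg_16_0.00002"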