#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import sys
import json
import numpy as np
from pathlib import Path

import evaluate
import disrpt_eval_2025
#from .disrpt_eval_2025 import *

# TODO : should be conditioned on the task or the metric indicated in the config file ??
def prepare_compute_metrics(LABEL_NAMES):
	'''
	Return the method to be used in the trainer loop.
	For 'seg' or 'conn', the metrics are based on seqeval; tokens whose label is
		-100 (special tokens) are ignored.

	Parameters :
	------------
	LABEL_NAMES: Dict
		Mapping from label ids to label names; needed for BIO labels, to convert
			predictions and gold labels to the strings expected by seqeval
	
	Returns :
	---------
	compute_metrics: function 
	'''
	def compute_metrics(eval_preds):
		nonlocal LABEL_NAMES
		# nonlocal task
		# Retrieve gold and predictions 
		logits, labels = eval_preds
		
		predictions = np.argmax(logits, axis=-1)
		metric = evaluate.load("seqeval")
		# Remove ignored index (special tokens) and convert to labels
		true_labels = [[LABEL_NAMES[l] for l in label if l != -100] for label in labels]
		true_predictions = [
				[LABEL_NAMES[p] for (p, l) in zip(prediction, label) if l != -100]
					for prediction, label in zip(predictions, labels)
				]
		all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
		print_metrics( all_metrics )
		return {
				"precision": all_metrics["overall_precision"],
				"recall": all_metrics["overall_recall"],
				"f1": all_metrics["overall_f1"],
				"accuracy": all_metrics["overall_accuracy"],
				}
	return compute_metrics
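
# Illustrative usage (a sketch, not part of this module's pipeline): the closure
# returned by prepare_compute_metrics() is meant to be passed to a Hugging Face
# Trainer. The label map below is hypothetical, for a BIO segmentation scheme:
#
#   label_names = {0: "O", 1: "B-Seg", 2: "I-Seg"}
#   trainer = Trainer(
#       model=model,
#       args=training_args,
#       train_dataset=train_dataset,
#       eval_dataset=eval_dataset,
#       compute_metrics=prepare_compute_metrics(label_names),
#   )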


def print_metrics( all_metrics ):
	'''Print the seqeval metrics: overall scores and per-class dictionaries'''
	#print( all_metrics )
	for p,v in all_metrics.items():
		if '_' in p:
			print( p, v )
		else:
			print( p+' = '+str(v))

def compute_metrics_dirspt( dataset_eval, pred_file, task='seg' ):
	'''Evaluate the predictions in pred_file with the official disrpt 2025 evaluation script'''
	print( "\nPerformance computed using disrpt eval script on", dataset_eval.annotations_file, 
			pred_file )
	if task == 'seg':
		#clean_pred_file(pred_file, os.path.basename(pred_file)+"cleaned.preds")
		my_eval = disrpt_eval_2025.SegmentationEvaluation("temp_test_disrpt_eval_seg", 
			dataset_eval.annotations_file, 
			pred_file )
	elif task == 'conn':
		my_eval = disrpt_eval_2025.ConnectivesEvaluation("temp_test_disrpt_eval_conn", 
			dataset_eval.annotations_file, 
			pred_file )
	else:
		raise NotImplementedError
	my_eval.compute_scores()
	my_eval.print_results()
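
# Illustrative call (a sketch): dataset_eval is assumed to expose the path of the
# gold .tok/.conllu file as .annotations_file, and pred_file to be a file in the
# same format holding the model predictions (the file name below is hypothetical):
#
#   compute_metrics_dirspt(dataset_eval, "out/eng.rst.gum_dev.preds", task='seg')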

def clean_pred_file(pred_path: str, out_path: str):
	'''Remove the [LANG=...] / [FRAME=...] meta-tokens from a prediction file'''
	c = 0
	with open(pred_path, "r", encoding="utf8") as fin, open(out_path, "w", encoding="utf8") as fout:
		for line in fin:
			if line.strip() == "" or line.startswith("#"):
				fout.write(line)
				continue
			fields = line.strip().split("\t")
			token = fields[1]
			if token.startswith("[LANG=") or token.startswith("[FRAME="):
				c+=1
				continue  # skip meta-tokens
			fout.write(line)
	print(f"we've cleaned {c} tokens")	
# -------------------------------------------------------------------------------------------------
# ------ UTILS FUNCTIONS 
# -------------------------------------------------------------------------------------------------
def read_config( config_file ):
	'''Read the config file for training'''
	with open(config_file, encoding="utf8") as f:
		config = json.load(f)
	if 'frozen' in config['trainer_config']:
		config['trainer_config']["frozen"] = update_frozen_set( config['trainer_config']["frozen"] )
	return config
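
# A minimal sketch of the expected JSON config, inferred from the keys read in
# this module (the values below are placeholders, not recommendations):
#
#   {
#     "model_name": "xlm-roberta-base",
#     "task": "seg",
#     "wandb": "my-project",
#     "wandb_ent": "my-entity",
#     "trainer_config": {
#       "batch_size": 16,
#       "learning_rate": 2e-5,
#       "frozen": ["0-3", "12"]
#     }
#   }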

def update_frozen_set( freeze ):
	# Make a set of layer indices from the list of frozen-layer specs (strings):
	# []                  --> nothing frozen
	# ["3"]               --> only layer 3 frozen
	# ["0", "3"]          --> only layers 0 and 3
	# ["0-3", "12", "15"] --> layers 0 to 3 included, plus layers 12 and 15
	frozen = set()
	for spec in freeze:
		if "-" in spec: # eg 1-9
			b, e = spec.split("-")
			frozen = frozen | set(range(int(b),int(e)+1))
		else:
			frozen.add(int(spec))
	return frozen
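
# Example conversions (illustrative), matching the spec format described above:
#
#   update_frozen_set([])             # -> set()            (nothing frozen)
#   update_frozen_set(["3"])          # -> {3}
#   update_frozen_set(["0-3", "12"])  # -> {0, 1, 2, 3, 12}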

def print_config(config):
	'''Print info from config dictionary'''
	print('\n'.join([ '| '+k+": "+str(v) for (k,v) in config.items() ]))

# -------------------------------------------------------------------------------------------------
def retrieve_files_dataset( input_path, list_dataset, mode='conllu', dset='train' ):
	'''Collect the .conllu or .tok files of the requested split (dset),
	optionally restricted to the corpora listed in list_dataset'''
	if mode == 'conllu':
		pat = ".[cC][oO][nN][lL][lL][uU]"
	elif mode == 'tok':
		pat = ".[tT][oO][kK]"
	else:
		sys.exit('Unknown mode for file extension: '+mode)
	if len(list_dataset) == 0: 
		return list(Path(input_path).rglob("*_"+dset+pat))
	else:
		# file names look like eng.pdtb.pdtb_train.conllu, one sub-directory per corpus
		matched = []
		for subdir in os.listdir( input_path ):
			if subdir in list_dataset:
				matched.extend( list(Path(os.path.join(input_path,subdir)).rglob("*_"+dset+pat)) )
		return matched 
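
# Illustrative calls (a sketch; the directory layout is assumed to have one
# sub-directory per corpus, with files named like eng.pdtb.pdtb_train.conllu):
#
#   # every *_train.conllu file under data/
#   retrieve_files_dataset("data", [], mode='conllu', dset='train')
#   # only the .tok dev files of the listed corpora
#   retrieve_files_dataset("data", ["eng.pdtb.pdtb", "fra.sdrt.annodis"], mode='tok', dset='dev')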


# -------------------------------------------------------------------------------------------------
# https://wandb.ai/site
def init_wandb( config, model_checkpoint, annotations_file ):
	'''
	Initialize a new WANDB project to keep track of the experiments.
	Parameters
	----------
	config : dict
		Used to retrieve the names of the wandb entity and project (from the config file)
	model_checkpoint : str
		Name of the PLM used
	annotations_file : str
		Path to the training file

	Returns
	-------
	None
	'''
	print("HERE WE INITIALIZE A WANDB PROJECT")
	
	import wandb
	proj_wandb = config["wandb"]
	ent_wandb = config["wandb_ent"]
	# start a new wandb run to track this script
	# The project name must be set before initializing the trainer 
	wandb.init(
		# set the wandb project where this run will be logged
		project=proj_wandb,
		entity=ent_wandb,
		# track hyperparameters and run metadata
		config={
			"model_checkpoint": model_checkpoint,
			"dataset": annotations_file,
		}
	)
	wandb.define_metric("epoch")
	wandb.define_metric("epoch")
	wandb.define_metric("f1", step_metric="batch")
	wandb.define_metric("f1", step_metric="epoch")
	
def set_name_output_dir( output_dir, config, corpus_name ):
	'''
	Set the path name for the target directory used to store models. The name should contain
	info about the task, the PLM and the hyperparameter values.

	Parameters
	----------
	output_dir : str
		Path to the output directory provided by the user
	config: dict
		Configuration dictionary (model name, task, hyperparameter values)
	corpus_name: str
		Name of the corpus

	Returns
	-------
	Str: Path to the output directory
	'''
	# Format the learning rate as a decimal number, to avoid scientific notation
	hyperparam = [
				config['trainer_config']['batch_size'], 
				np.format_float_positional(config['trainer_config']['learning_rate'])
				  ]
	output_dir = os.path.join( output_dir, 
				'_'.join( [
					corpus_name, 
					config["model_name"], 
					config["task"], 
					'_'.join([str(p) for p in hyperparam]) 
					] ) )
	return output_dir
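
# Illustrative result (a sketch), on a POSIX system and with the hypothetical
# config sketched after read_config() above:
#
#   set_name_output_dir("models", config, "eng.rst.gum")
#   # -> "models/eng.rst.gum_xlm-roberta-base_seg_16_0.00002"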