#!/usr/bin/python
# -*- coding: utf-8 -*-

import os, sys
import numpy as np
import transformers

import utils
import reading

SUBTOKEN_START = '##'
'''
TODOs:
- for now, if the dataset is cached, we can't use the word ids, and the predictions
  written are not based on the original eval file, thus not exactly the same number
  of tokens (contractions are ignored) --> doesn't work in the disrpt eval script

Change in the newest version of transformers:
from seqeval.metrics import accuracy_score
from seqeval.metrics import classification_report
from seqeval.metrics import f1_score
'''
def simple_eval( dataset_eval, model_checkpoint, tokenizer, output_path,
        config, trace=False ):
    '''
    Run the pre-trained model on the (dev) dataset to get predictions,
    then write the predictions in an output file.

    Parameters:
    -----------
    dataset_eval: DatasetDisc
        The dataset to evaluate on
    model_checkpoint: str
        Path to the saved model
    tokenizer: Tokenizer
        Tokenizer of the saved model (TODO: retrieve from model? or should be removed?)
    output_path: str
        Path to the output directory where prediction files will be written
    config: dict
        Configuration options read from the JSON config file
    trace: bool
        Whether to print verbose debugging information
    '''
    # Retrieve predictions (logits per subtoken)
    print("\n-- PREDICT on:", dataset_eval.annotations_file )
    model_checkpoint = os.path.normpath(model_checkpoint)
    print("model_checkpoint", model_checkpoint)
    preds_from_model, label_ids, metrics = retrieve_predictions( model_checkpoint,
            dataset_eval, output_path, tokenizer, config )
    print("preds_from_model.shape", preds_from_model.shape)
    print("label_ids.shape", label_ids.shape)

    # - Compute metrics
    print("\n-- COMPUTE METRICS" )
    compute_metrics = utils.prepare_compute_metrics( dataset_eval.LABEL_NAMES_BIO )
    metrics = compute_metrics([preds_from_model, label_ids])
    max_preds_from_model = np.argmax(preds_from_model, axis=-1)

    # - Write predictions
    pred_file = os.path.join( output_path, dataset_eval.basename+'.preds' )
    print("\n-- WRITE PREDS in:", pred_file )
    pred_file_success = True
    try:
        try:
            # * Retrieve the original words: will fail if the cache was not emptied
            print( "Write predictions based on words" )
            predictions = align_tokens_labels_from_wordids( max_preds_from_model, dataset_eval,
                    tokenizer )
            write_pred_file( dataset_eval.annotations_file, pred_file, predictions, trace=trace )
        except IndexError:
            # If it fails, write the predictions with the model tokens, trying to merge
            # subtokens based on SUBTOKEN_START and on labels set to -100
            print( "Write predictions based on model tokenisation" )
            aligned_tokens, aligned_golds, aligned_preds = align_tokens_labels_from_subtokens(
                    max_preds_from_model, dataset_eval, tokenizer, pred_file, trace=trace )
            write_pred_file_from_scratch( aligned_tokens, aligned_golds, aligned_preds,
                    pred_file, trace=trace )
    except Exception as e:
        print( "Problem when trying to write predictions in file", pred_file )
        print( "Exception:", e )
        print( "We skip the prediction writing step" )
        pred_file_success = False

    if pred_file_success:
        print( "\n-- EVAL DISRPT script" )
        clean_pred_path = pred_file.replace('.preds', '.cleaned.preds')
        utils.clean_pred_file(pred_file, clean_pred_path)
        utils.compute_metrics_dirspt( dataset_eval, clean_pred_path, task=config['task'] )

    return metrics
def write_pred_file(annotations_file, pred_file, predictions, trace=False):
    '''
    Write a file containing the predictions, based on the original annotation file.
    Each line of the original evaluation file is copied with the prediction appended
    at the end. Predictions and original tokens need to be perfectly aligned.

    Parameters:
    -----------
    annotations_file: str | file path OR raw text
        Path to the original evaluation file, or the text content itself
    pred_file: str
        Path to the output prediction file
    predictions: list of str
        Flat list of all predictions (DISRPT format) for all tokens in eval
    '''
    count_pred_B, count_gold_B = 0, 0
    count_line_dash = 0
    count_line_dot = 0
    # --- Determine whether annotations_file is a path or raw text
    if os.path.isfile(annotations_file):
        with open(annotations_file, 'r', encoding='utf-8') as fin:
            mylines = fin.readlines()
    else:
        # Treat it as a raw string
        mylines = annotations_file.strip().splitlines()
    out_dir = os.path.dirname(pred_file)
    if out_dir:  # makedirs fails on an empty path (file in the current directory)
        os.makedirs(out_dir, exist_ok=True)
    with open(pred_file, 'w', encoding='utf-8') as fout:
        count = 0
        if trace:
            print("len(predictions)", len(predictions))
        for l in mylines:
            l = l.strip()
            if l.startswith("#"):  # Keep metadata
                fout.write(l + '\n')
            elif l == '':  # Keep line breaks
                fout.write('\n')
            elif '-' in l.split('\t')[0]:  # Keep lines for contractions, but add no label
                if trace:
                    print("WARNING: line with - in the token id, no label will be added")
                count_line_dash += 1
                fout.write(l + '\t' + '_' + '\n')
            elif '.' in l.split('\t')[0]:  # Strange case in GUM: keep lines, but add no label
                count_line_dot += 1
                if trace:
                    print("WARNING: line with . in the token id, no label will be added")
                fout.write(l + '\t' + '_' + '\n')
            else:
                if 'B' in predictions[count]:
                    count_pred_B += 1
                if 'Seg=B-seg' in l or 'Conn=B-conn' in l:
                    count_gold_B += 1
                fout.write(l + '\t' + predictions[count] + '\n')
                count += 1
    print("Number of predicted B labels:", count_pred_B, "vs gold B labels:", count_gold_B)
    print("Number of lines with - in the token id:", count_line_dash)
    print("Number of lines with . in the token id:", count_line_dot)
def write_pred_file_from_scratch( aligned_tokens, aligned_golds, aligned_preds, pred_file, trace=False ):
    '''
    Write a prediction file based on an alignment between the tokenisation and the predictions.
    Since we are not sure that we retrieved the exact alignment, the writing here is not based
    on the original annotation file, but we use a similar format:
    # Sent ID
    tok_ID token gold_label pred_label
    Using the DISRPT script will show whether the alignment worked or not...

    Parameters:
    ----------
    aligned_tokens / aligned_golds / aligned_preds: list of list of str
        The tokens / gold labels / predicted labels for each sentence
    '''
    count_pred_B, count_gold_B = 0, 0
    with open( pred_file, 'w' ) as fout:
        if trace:
            print( 'len(tokens)', len(aligned_tokens) )
            print( 'len(predictions)', len(aligned_preds) )
            print( 'len(golds)', len(aligned_golds) )
        for s, tok_sent in enumerate( aligned_tokens ):
            fout.write( "# sent_id = "+str(s)+"\n" )
            for i, tok in enumerate( tok_sent ):
                g = aligned_golds[s][i]
                p = aligned_preds[s][i]
                fout.write( '\t'.join([str(i), tok, g, p])+'\n' )
                if 'B' in p:
                    count_pred_B += 1
                if 'Seg=B-seg' in g or 'Conn=B-conn' in g:
                    count_gold_B += 1
            fout.write( "\n" )
    print("Number of predicted B labels:", count_pred_B, "vs gold B labels:", count_gold_B)
def align_tokens_labels_from_wordids( preds_from_model, dataset_eval, tokenizer, trace=False ):
    '''
    Align the predictions for segmentation or connective tasks with the original words.
    Easiest way (?): use the word_ids information to merge the words that have been
    split, retrieve the original tokens from the input .tok / .conllu files and run
    the evaluation --> but word ids are not kept in the cached datasets.

    Parameters:
    -----------
    preds_from_model: list of int
        The predicted labels (numeric ids)
    dataset_eval: DatasetDisc
        Dataset for evaluation
    tokenizer: Tokenizer
        Tokenizer of the saved model

    Return:
    -------
    predictions: list of str
        The predicted labels (DISRPT format) for each original input word
    '''
    word_ids = dataset_eval.all_word_ids
    id2label = dataset_eval.id2label
    predictions = []
    for i in range( preds_from_model.shape[0] ):
        sent_input_ids = dataset_eval.tokenized_datasets['input_ids'][i]
        tokens = dataset_eval.dataset['tokens'][i]
        sent_tokens = tokenizer.decode(sent_input_ids[1:-1])
        aligned_preds = _merge_tokens_preds_sent( word_ids[i], preds_from_model[i], tokens )
        if trace:
            print( '\n', i, sent_tokens )
            print( sent_input_ids )
            print( preds_from_model[i] )
            print( ' '.join( tokens ) )
            print( "aligned_preds", aligned_preds )
        for k, tok in enumerate( tokens ):
            # Ignore special tokens
            if tok.startswith('[LANG=') or tok.startswith('[FRAME='):
                if trace:
                    print(f"Skip special token: {tok}")
                continue
            label = aligned_preds[k]
            predictions.append( id2label[label] )
    return predictions
def _merge_tokens_preds_sent( word_ids, preds, tokens ):
    '''
    The tokenizer splits the tokens into subtokens, with labels added on subwords.
    For the evaluation, we need to merge the subtokens back and keep only the
    labels on the plain tokens.
    The function takes the word ids and the predictions for one sentence and
    returns the merged version.
    We also get rid of the tokens and labels associated with [CLS] and [SEP],
    and don't keep predictions for padding tokens.
    TODO: inspired from the method used to split the labels, but the two
    'continue' branches could be removed (kept for debugging).

    Parameters:
    -----------
    word_ids: list
        For each (sub)token produced by the (BERT-like) tokenizer, the index of
        the original word it comes from (None for special tokens)
    preds: list
        The predictions of the model
    tokens: list
        The original words of the sentence
    '''
    aligned_toks = []
    new_labels = []
    current_word = None
    for i, word_id in enumerate( word_ids ):
        if word_id != current_word:
            # New word
            current_word = word_id
            if word_id is not None:
                new_labels.append( preds[i] )
                aligned_toks.append( tokens[word_id] )
        elif word_id is None:
            # Special token
            continue
        else:
            # Same word as the previous subtoken
            continue
    if len(new_labels) != len(aligned_toks) or len(new_labels) != len(tokens):
        print( "WARNING: something is wrong, not the same number of tokens and predictions" )
        print( len(new_labels), len(aligned_toks), len(tokens) )
    return new_labels
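# A minimal sketch of the word_ids merging (documentation only, hypothetical
# values): a fast tokenizer maps every subtoken to the index of its original
# word ([CLS]/[SEP] map to None); only the first subtoken of each word keeps
# its prediction.
def _demo_merge_tokens_preds_sent():
    word_ids = [None, 0, 1, 1, 2, None]   # [CLS] it ' s fine [SEP]
    preds = [0, 1, 0, 0, 0, 0]            # one prediction per subtoken
    tokens = ["it", "'s", "fine"]         # original words
    assert _merge_tokens_preds_sent(word_ids, preds, tokens) == [1, 0, 0]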
def map_labels_list( list_labels, id2label ):
    return [id2label[l] for l in list_labels]
def align_tokens_labels_from_subtokens( preds_from_model, dataset_eval, tokenizer, pred_file, trace=False ):
    '''
    Align tokens and labels (merging subtokens, assigning the right label)
    based on the specific characters that start a subtoken (e.g. ## for BERT)
    and on the label -100 assigned to contractions or MWEs (e.g. "it's").
    But we are not completely sure to get the exact alignment with the original words here.
    Note: pred_file is not written here but by the caller (the parameter is kept
    in the signature for compatibility).
    '''
    aligned_tokens, aligned_golds, aligned_preds = [], [], []
    id2label = dataset_eval.id2label
    tokenized_dataset = dataset_eval.tokenized_datasets
    # Iterate on sentences
    for i in range( preds_from_model.shape[0] ):
        sent_input_ids = tokenized_dataset['input_ids'][i]
        sent_gold_labels = tokenized_dataset['labels'][i]
        sent_pred_labels = preds_from_model[i]
        aligned_t, aligned_g, aligned_p = _retrieve_tokens_from_sent( sent_input_ids, sent_pred_labels,
                sent_gold_labels, tokenizer, trace=trace )
        aligned_tokens.append(aligned_t)
        aligned_golds.append( map_labels_list(aligned_g, id2label) )
        aligned_preds.append( map_labels_list(aligned_p, id2label) )
    return aligned_tokens, aligned_golds, aligned_preds
def _retrieve_tokens_from_sent( sent_input_ids, preds_from_model, sent_gold_labels, tokenizer, trace=False ):
    cur_token, cur_pred, cur_gold = None, None, None
    tokens, golds, preds = [], [], []
    if trace:
        print( '\n\nlen(sent_input_ids)', len(sent_input_ids) )
        print( 'len(preds_from_model)', len(preds_from_model) )  # with padding
        print( 'sent_gold_labels', sent_gold_labels )
    # Ignore the first and last tokens / labels ([CLS] and [SEP])
    for j, input_id in enumerate( sent_input_ids[1:-1] ):
        gold_label = sent_gold_labels[j+1]
        pred_label = preds_from_model[j+1]
        subtoken = tokenizer.decode( input_id )
        if trace:
            print( subtoken, gold_label, pred_label )
        # Deal with tokens split into subtokens: keep the label of the first subtoken
        if subtoken.startswith( SUBTOKEN_START ) or gold_label == -100:
            if cur_token is None:
                print( "WARNING: first subtoken without a token, probably a contraction or MWE" )
                cur_token = ""
            # Strip the subtoken marker before merging (e.g. 'play'+'##ing' -> 'playing')
            if subtoken.startswith( SUBTOKEN_START ):
                subtoken = subtoken[len(SUBTOKEN_START):]
            cur_token += subtoken
        else:
            if cur_token is not None:
                tokens.append( cur_token )
                golds.append( cur_gold )
                preds.append( cur_pred )
            cur_token = subtoken
            cur_pred = pred_label
            cur_gold = gold_label
    # Add the last token
    tokens.append( cur_token )
    golds.append( cur_gold )
    preds.append( cur_pred )
    if trace:
        print( "\ntokens:", len(tokens), tokens )
        print( "golds", len(golds), golds )
        print( "preds", len(preds), preds )
        for i, tok in enumerate(tokens):
            print( tok, golds[i], preds[i] )
    return tokens, golds, preds
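# A minimal sketch of the subtoken merging (documentation only): "play ##ing"
# is merged back into "playing", keeping the labels of the first subtoken.
# The stub tokenizer below is hypothetical and only mimics the id -> subtoken
# decoding of a real BERT-like tokenizer.
class _StubTokenizer:
    _vocab = {0: '[CLS]', 1: 'play', 2: '##ing', 3: '[SEP]'}
    def decode(self, input_id):
        return self._vocab[input_id]

def _demo_retrieve_tokens_from_sent():
    toks, golds, preds = _retrieve_tokens_from_sent(
        [0, 1, 2, 3],           # [CLS] play ##ing [SEP]
        [-100, 1, -100, -100],  # predictions aligned on subtokens
        [-100, 1, -100, -100],  # gold labels, -100 on continuation subtokens
        _StubTokenizer())
    assert toks == ['playing'] and golds == [1] and preds == [1]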
def retrieve_predictions(model_checkpoint, dataset_eval, output_path, tokenizer, config):
    """
    Load the trainer in eval mode and compute predictions on dataset_eval
    (either a HuggingFace dataset, or a list of sentences).
    """
    model_path = model_checkpoint
    if os.path.isfile(model_checkpoint):
        print(f"[INFO] The model path points to a file, using the parent directory: {os.path.dirname(model_checkpoint)}")
        model_path = os.path.dirname(model_checkpoint)
    # Load the model
    model = transformers.AutoModelForTokenClassification.from_pretrained(model_path)
    # Collator
    data_collator = transformers.DataCollatorForTokenClassification(
        tokenizer=tokenizer,
        padding=config["tok_config"]["padding"]
    )
    compute_metrics = utils.prepare_compute_metrics(
        getattr(dataset_eval, "LABEL_NAMES_BIO", None) or []
    )
    # Eval mode
    model.eval()
    test_args = transformers.TrainingArguments(
        output_dir=output_path,
        do_train=False,
        do_predict=True,
        dataloader_drop_last=False,
        report_to=config.get("report_to", "none"),
    )
    trainer = transformers.Trainer(
        model=model,
        args=test_args,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    # If dataset_eval is just a list of sentences, build a Dataset from it
    from datasets import Dataset
    if isinstance(dataset_eval, list):
        dataset_eval = Dataset.from_dict({"text": dataset_eval})
        def tokenize(batch):
            return tokenizer(batch["text"], truncation=True, padding=True)
        dataset_eval = dataset_eval.map(tokenize, batched=True)
        predictions, label_ids, metrics = trainer.predict(dataset_eval)
    else:
        # Make predictions on the eval dataset
        predictions, label_ids, metrics = trainer.predict(dataset_eval.tokenized_datasets)
    return predictions, label_ids, metrics
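# A shape sketch (documentation only, hypothetical sizes): trainer.predict
# returns logits of shape (n_sentences, max_seq_len, n_labels); taking the
# argmax over the last axis, as simple_eval() does, yields one label id per
# (sub)token position.
def _demo_prediction_shapes():
    logits = np.zeros((2, 8, 3))   # 2 sentences, 8 positions, 3 labels
    label_ids = np.argmax(logits, axis=-1)
    assert label_ids.shape == (2, 8)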
# --------------------------------------------------------------------------
# --------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse
    import shutil

    # Empty the HuggingFace datasets cache, so that the word ids are available
    # (see the module-level TODO above)
    path = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
    if os.path.exists(path):
        shutil.rmtree(path)
        print(f"The directory '{path}' has been removed.")
    else:
        print(f"The directory '{path}' does not exist.")

    parser = argparse.ArgumentParser(
        description='DISCUT: Discourse segmentation and connective detection'
    )
    # EVAL FILE
    parser.add_argument("-t", "--test",
        help="Eval file. Default: data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu",
        default="data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu")
    # PRE-FINE-TUNED MODEL
    parser.add_argument("-m", "--model",
        help="Path to the directory containing the model file.",
        default=None)
    # OUTPUT DIRECTORY
    parser.add_argument("-o", "--output",
        help="Directory where models and predictions will be saved. Default: ./data/temp_expe/",
        default="./data/temp_expe/")
    # CONFIG FILE FROM THE FINE-TUNED MODEL
    parser.add_argument("-c", "--config",
        help="Config file. Default: ./config_seg.json",
        default="./config_seg.json")
    # TRACE / VERBOSITY
    parser.add_argument('-v', '--trace',
        action='store_true',
        default=False,
        help="Whether to print full messages. If used, it overrides the value in the config file.")
    # TODO Add an option for choosing the tool used to split the sentences
    args = parser.parse_args()

    eval_path = args.test
    output_path = args.output
    if not os.path.isdir( output_path ):
        os.makedirs( output_path, exist_ok=True )
    config_file = args.config
    model = args.model
    trace = args.trace

    print( '\n-[DISCUT]--PROGRAM (eval) ARGUMENTS' )
    print( '| Mode', 'eval' )
    if not model:
        sys.exit( "Please provide a path to a model for eval mode." )
    print( '| Test_path:', eval_path )
    print( "| Output_path:", output_path )
    print( "| Model:", model )
    print( '| Config:', config_file )

    print( '\n-[DISCUT]--CONFIG INFO' )
    config = utils.read_config( config_file )
    utils.print_config( config )

    print( "\n-[DISCUT]--READING DATASET" )
    datasets = {}
    datasets['dev'], tokenizer = reading.read_dataset( eval_path, output_path, config )

    # The model path is also available in config['best_model_path']
    metrics = simple_eval( datasets['dev'], model, tokenizer, output_path, config, trace=trace )