Spaces:
Runtime error
Runtime error
| from collections import defaultdict | |
| import numpy as np | |
| import torch | |
| from seqeval.metrics.v1 import _prf_divide | |
def extract_tp_actual_correct(y_true, y_pred):
    """Count predictions, exact matches and gold entities per entity type.

    Args:
        y_true: Iterable of gold entities, each ``(type_name, (start, end), idx)``
            where ``idx`` tags the example the entity came from.
        y_pred: Iterable of predicted entities in the same format.

    Returns:
        Tuple ``(pred_sum, tp_sum, true_sum, target_names)``. ``target_names``
        is the sorted union of entity types seen in either input; the three
        integer arrays are aligned with it and hold, per type, the number of
        predicted entities, exactly matching (start, end, idx) entities, and
        gold entities. Entities are stored in sets, so duplicates count once.
    """
    entities_true = defaultdict(set)
    entities_pred = defaultdict(set)
    for type_name, (start, end), idx in y_true:
        entities_true[type_name].add((start, end, idx))
    for type_name, (start, end), idx in y_pred:
        entities_pred[type_name].add((start, end, idx))

    target_names = sorted(set(entities_true) | set(entities_pred))

    # Collect into plain lists and convert once: np.append reallocates on each
    # call and upcasts a narrow seed array, making the dtype input-dependent.
    pred_counts = []
    tp_counts = []
    true_counts = []
    for type_name in target_names:
        true_set = entities_true.get(type_name, set())
        pred_set = entities_pred.get(type_name, set())
        pred_counts.append(len(pred_set))
        tp_counts.append(len(true_set & pred_set))
        true_counts.append(len(true_set))

    pred_sum = np.array(pred_counts, dtype=np.int64)
    tp_sum = np.array(tp_counts, dtype=np.int64)
    true_sum = np.array(true_counts, dtype=np.int64)
    return pred_sum, tp_sum, true_sum, target_names
def flatten_for_eval(y_true, y_pred):
    """Flatten per-example entity lists into two flat lists tagged by example index.

    Each entity ``[type_name, (start, end)]`` found in the ``i``-th example
    becomes ``[type_name, (start, end), i]``, so that identical spans from
    different examples remain distinct when compared downstream.

    Args:
        y_true: Sequence (one item per example) of gold entity sequences.
        y_pred: Sequence (one item per example) of predicted entity sequences.
            Examples are paired with ``zip``; any surplus on the longer side
            is ignored.

    Returns:
        ``(all_true, all_pred)``: flat lists of index-tagged entity lists.
    """
    all_true = []
    all_pred = []
    for i, (true, pred) in enumerate(zip(y_true, y_pred)):
        # ``[*t, i]`` accepts tuple entities as well as lists (``t + [i]`` did not).
        all_true.extend([*t, i] for t in true)
        all_pred.extend([*p, i] for p in pred)
    return all_true, all_pred
def compute_prf(y_true, y_pred, average='micro'):
    """Compute entity-level precision, recall and F1.

    A predicted entity is a true positive only when its type, start, end and
    example index all match a gold entity.

    Args:
        y_true: Per-example lists of gold entities ``[type_name, (start, end)]``.
        y_pred: Per-example lists of predicted entities in the same format.
        average: ``'micro'`` (pool counts over all types; the default),
            ``'macro'`` (unweighted mean of per-type scores) or
            ``'weighted'`` (per-type scores weighted by gold support).

    Returns:
        Dict with keys ``'precision'``, ``'recall'`` and ``'f_score'``.
        Ratios with a zero denominator are set to 0.0 by seqeval's
        ``_prf_divide`` (called with ``zero_division='warn'``).

    Raises:
        ValueError: if ``average`` is not a supported mode.
    """
    if average not in ('micro', 'macro', 'weighted'):
        raise ValueError(
            "average must be one of 'micro', 'macro', 'weighted'; "
            f"got {average!r}"
        )

    y_true, y_pred = flatten_for_eval(y_true, y_pred)
    pred_sum, tp_sum, true_sum, _ = extract_tp_actual_correct(y_true, y_pred)

    if average == 'micro':
        # Pool counts across every entity type before dividing.
        tp_sum = np.array([tp_sum.sum()])
        pred_sum = np.array([pred_sum.sum()])
        true_sum = np.array([true_sum.sum()])

    precision = _prf_divide(
        numerator=tp_sum,
        denominator=pred_sum,
        metric='precision',
        modifier='predicted',
        average=average,
        warn_for=('precision', 'recall', 'f-score'),
        zero_division='warn'
    )
    recall = _prf_divide(
        numerator=tp_sum,
        denominator=true_sum,
        metric='recall',
        modifier='true',
        average=average,
        warn_for=('precision', 'recall', 'f-score'),
        zero_division='warn'
    )
    denominator = precision + recall
    denominator[denominator == 0.] = 1  # F1 is 0 wherever P + R == 0
    f_score = 2 * (precision * recall) / denominator

    if average == 'micro':
        return {'precision': precision[0], 'recall': recall[0], 'f_score': f_score[0]}

    # Macro / weighted: average over every entity type instead of reporting
    # only the first element of the per-type arrays.
    weights = true_sum if average == 'weighted' else None
    if precision.size == 0 or (weights is not None and weights.sum() == 0):
        # No entity types, or no gold support to weight by: nothing to average.
        return {'precision': 0.0, 'recall': 0.0, 'f_score': 0.0}
    return {
        'precision': np.average(precision, weights=weights),
        'recall': np.average(recall, weights=weights),
        'f_score': np.average(f_score, weights=weights),
    }
class Evaluator:
    """Entity-level evaluation of predicted spans against gold spans.

    ``all_true`` and ``all_outs`` are sequences with one item per example;
    each item is a sequence of ``(start, end, label)`` triples.
    """

    def __init__(self, all_true, all_outs):
        # Gold spans per example.
        self.all_true = all_true
        # Predicted spans per example, aligned with ``all_true`` by position.
        self.all_outs = all_outs

    def get_entities_fr(self, ents):
        """Convert ``(start, end, label)`` triples into ``[label, (start, end)]`` lists."""
        return [[lab, (s, e)] for s, e, lab in ents]

    def transform_data(self):
        """Return ``(gold, predicted)`` per-example lists in ``[label, (start, end)]`` form.

        Examples are paired with ``zip``, so any surplus on the longer side is dropped.
        """
        all_true_ent = []
        all_outs_ent = []
        for true_ents, pred_ents in zip(self.all_true, self.all_outs):
            all_true_ent.append(self.get_entities_fr(true_ents))
            all_outs_ent.append(self.get_entities_fr(pred_ents))
        return all_true_ent, all_outs_ent

    def evaluate(self):
        """Compute micro-averaged P/R/F1.

        Returns:
            ``(output_str, f1)``: a tab-separated percentage summary line and
            the F1 value.
        """
        all_true_typed, all_outs_typed = self.transform_data()
        scores = compute_prf(all_true_typed, all_outs_typed)
        # Look metrics up by key; unpacking ``.values()`` silently depends on
        # the dict's key order and size.
        precision = scores['precision']
        recall = scores['recall']
        f1 = scores['f_score']
        output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
        return output_str, f1
def is_nested(idx1, idx2):
    """Return True when either span contains the other (inclusive bounds).

    Only the first two items, ``(start, end)``, of each span are read;
    identical spans count as nested.
    """
    start1, end1 = idx1[0], idx1[1]
    start2, end2 = idx2[0], idx2[1]
    first_contains_second = start1 <= start2 and end2 <= end1
    second_contains_first = start2 <= start1 and end1 <= end2
    return first_contains_second or second_contains_first
def has_overlapping(idx1, idx2):
    """Return True if two spans share at least one position (inclusive ends).

    Spans are compared on their first two items ``(start, end)``; two spans
    with the same ``(start, end)`` always count as overlapping, regardless of
    any further items they carry.
    """
    if idx1[:2] == idx2[:2]:
        return True
    # Disjoint means one span starts strictly after the other ends.
    disjoint = idx1[0] > idx2[1] or idx2[0] > idx1[1]
    return not disjoint
def has_overlapping_nested(idx1, idx2):
    """Return True if two spans conflict when nested entities are allowed.

    Spans conflict when they share the same ``(start, end)`` or when they
    partially overlap. Disjoint spans, and spans where one contains the
    other (see ``is_nested``), do not conflict.
    """
    if idx1[:2] == idx2[:2]:
        return True
    disjoint = idx1[0] > idx2[1] or idx2[0] > idx1[1]
    # ``idx1 != idx2`` mirrors the original guard; for same-typed spans it is
    # always true once the (start, end) check above has failed.
    if idx1 != idx2 and (disjoint or is_nested(idx1, idx2)):
        return False
    return True
def greedy_search(spans, flat_ner=True):  # start, end, class, score
    """Greedily keep the highest-scoring spans that do not conflict.

    Args:
        spans: Sequence of span tuples whose last item is a numeric score,
            e.g. ``(start, end, class, score)``.
        flat_ner: If True, any two overlapping spans conflict
            (``has_overlapping``); otherwise spans may nest and only partial
            overlaps conflict (``has_overlapping_nested``).

    Returns:
        List of kept spans with the score removed (``span[:-1]``), ordered by
        start position; ties keep descending-score order (stable sort).
    """
    conflicts = has_overlapping if flat_ner else has_overlapping_nested
    kept = []
    for span in sorted(spans, key=lambda s: -s[-1]):
        candidate = span[:-1]
        if not any(conflicts(candidate, chosen) for chosen in kept):
            kept.append(candidate)
    kept.sort(key=lambda s: s[0])
    return kept