import os
import json

import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer

from .utils import get_class_to_index

class NERDataset(Dataset):

    def __init__(self, args, data_file, split='train'):
        super().__init__()
        self.args = args
        if data_file:
            data_path = os.path.join(args.data_path, data_file)
            with open(data_path) as f:
                self.data = json.load(f)
            self.name = os.path.basename(data_file).split('.')[0]
        self.split = split
        self.is_train = (split == 'train')
        # Alternative: BertTokenizerFast.from_pretrained('allenai/scibert_scivocab_uncased')
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.roberta_checkpoint, cache_dir=self.args.cache_dir)
        self.class_to_index = get_class_to_index(self.args.corpus)
        self.index_to_class = {index: cls for cls, index in self.class_to_index.items()}
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        example = self.data[str(idx)]
        text_tokenized = self.tokenizer(example['text'], truncation=True, max_length=self.args.max_seq_length)
        # Sanity check: truncation should keep inputs within the 512-token model limit.
        if len(text_tokenized['input_ids']) > 512:
            print(len(text_tokenized['input_ids']))
        text_tokenized_untruncated = self.tokenizer(example['text'])
        labels = self.align_labels(text_tokenized, example['entities'], len(example['text']))
        labels_untruncated = self.align_labels(text_tokenized_untruncated, example['entities'], len(example['text']))
        return text_tokenized, labels, labels_untruncated
    def align_labels(self, text_tokenized, entities, length):
        # Map each character offset to a BIO class index; 0 is the 'O' (outside) class.
        char_to_class = {}
        for entity in entities:
            for span in entities[entity]["span"]:
                for i in range(span[0], span[1]):
                    prefix = 'B-' if i == span[0] else 'I-'
                    char_to_class[i] = self.class_to_index[prefix + str(entities[entity]["type"])]
        for i in range(length):
            if i not in char_to_class:
                char_to_class[i] = 0
        # Each token takes the class of its first character; special tokens map to
        # no characters and get -100 so the loss function ignores them.
        classes = []
        for i in range(len(text_tokenized[0])):
            span = text_tokenized.token_to_chars(i)
            if span is not None:
                classes.append(char_to_class[span.start])
            else:
                classes.append(-100)
        return torch.LongTensor(classes)
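    # Illustration (hypothetical entity dict, not taken from the corpus): for the
    # text "dissolved in THF" with entities = {"0": {"span": [[13, 16]], "type": "SOLVENT"}},
    # character 13 maps to the 'B-SOLVENT' index, characters 14-15 to 'I-SOLVENT',
    # every other character to 0, and each token then inherits the class of its
    # first character.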
    def make_html(self, word_tokens, predictions):
        # Render per-token predictions as color-coded, underlined HTML spans.
        toreturn = '''<!DOCTYPE html>
<html>
<head>
<title>Named Entity Recognition Visualization</title>
<style>
.EXAMPLE_LABEL {
    color: red;
    text-decoration: underline red;
}
.REACTION_PRODUCT {
    color: orange;
    text-decoration: underline orange;
}
.STARTING_MATERIAL {
    color: gold;
    text-decoration: underline gold;
}
.REAGENT_CATALYST {
    color: green;
    text-decoration: underline green;
}
.SOLVENT {
    color: cyan;
    text-decoration: underline cyan;
}
.OTHER_COMPOUND {
    color: blue;
    text-decoration: underline blue;
}
.TIME {
    color: purple;
    text-decoration: underline purple;
}
.TEMPERATURE {
    color: magenta;
    text-decoration: underline magenta;
}
.YIELD_OTHER {
    color: palegreen;
    text-decoration: underline palegreen;
}
.YIELD_PERCENT {
    color: pink;
    text-decoration: underline pink;
}
</style>
</head>
<body>
<p>'''
        last_label = None
        for idx, item in enumerate(word_tokens):
            decoded = self.tokenizer.decode(item, skip_special_tokens=True)
            if len(decoded) > 0:
                # WordPiece continuations are prefixed with '##'; anything else
                # starts a new whitespace-separated word.
                if idx != 0 and decoded[0] != '#':
                    toreturn += " "
                text = decoded if decoded[0] != "#" else decoded[2:]
                label = predictions[idx]
                if label != last_label:
                    # Close the previous entity span, then open a new one if needed.
                    if last_label is not None and last_label > 0:
                        toreturn += "</u>"
                    if label > 0:
                        # The CSS classes above omit the BIO prefix, so strip 'B-'/'I-'.
                        toreturn += '<u class="' + self.index_to_class[label].split('-', 1)[-1] + '">'
                toreturn += text
                last_label = label
        # Close an entity span still open at the end of the sequence.
        if last_label is not None and last_label > 0:
            toreturn += "</u>"
        toreturn += ''' </p>
</body>
</html>'''
        return toreturn

def get_collate_fn():
    def collate(batch):
        # Pad each field to the longest sequence in the batch.
        sentences = []
        masks = []
        refs = []
        for ex in batch:
            sentences.append(torch.LongTensor(ex[0]['input_ids']))
            masks.append(torch.Tensor(ex[0]['attention_mask']))
            refs.append(ex[1])
        sentences = pad_sequence(sentences, batch_first=True, padding_value=0)
        masks = pad_sequence(masks, batch_first=True, padding_value=0)
        # Labels are padded with -100 so the loss ignores padded positions.
        refs = pad_sequence(refs, batch_first=True, padding_value=-100)
        return sentences, masks, refs
    return collate
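
# Minimal usage sketch, assuming hypothetical argument values (the checkpoint,
# data path, and corpus name below are placeholders, not settings from this
# repository). It shows how the dataset, collate function, and HTML renderer
# are expected to fit together.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        data_path='data',                   # hypothetical directory
        roberta_checkpoint='roberta-base',  # hypothetical checkpoint
        cache_dir=None,
        corpus='chemu',                     # hypothetical corpus name
        max_seq_length=512,
    )
    dataset = NERDataset(args, 'train.json', split='train')
    loader = DataLoader(dataset, batch_size=8, collate_fn=get_collate_fn())
    sentences, masks, refs = next(iter(loader))
    # sentences and masks are padded to the batch's longest sequence; refs uses
    # -100 for padding and special tokens so the loss skips them.
    print(sentences.shape, masks.shape, refs.shape)
    # Visualize the gold labels of the first example as HTML.
    html = dataset.make_html(sentences[0].tolist(), refs[0].tolist())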