import os
import argparse
from typing import List

import torch
import numpy as np
from huggingface_hub import hf_hub_download

from .model import build_model
from .dataset import NERDataset, get_collate_fn
from .utils import get_class_to_index
class ChemNER:
    """Inference wrapper around a fine-tuned chemical NER token-classification model."""

    def __init__(self, model_path, device=None, cache_dir=None):
        self.args = self._get_args(cache_dir)
        # Load the checkpoint on CPU; the model itself is moved to `device` in get_model().
        states = torch.load(model_path, map_location=torch.device('cpu'))
        if device is None:
            device = torch.device('cpu')
        self.device = device
        self.model = self.get_model(self.args, device, states['state_dict'])
        self.collate = get_collate_fn()
        # The dataset object is only used for its tokenizer, so no data file is loaded.
        self.dataset = NERDataset(self.args, data_file=None)
        # Mappings between BIO tag names and label indices for the chosen corpus.
        self.class_to_index = get_class_to_index(self.args.corpus)
        self.index_to_class = {index: name for name, index in self.class_to_index.items()}
    def _get_args(self, cache_dir):
        # Recreate the argument namespace expected by build_model/NERDataset without
        # touching the real command line (hence parse_args([])).
        parser = argparse.ArgumentParser()
        parser.add_argument('--roberta_checkpoint', default='dmis-lab/biobert-large-cased-v1.1',
                            type=str, help='which roberta config to use')
        parser.add_argument('--corpus', default='chemdner', type=str,
                            help='which corpus the tags should come from')
        args = parser.parse_args([])
        args.cache_dir = cache_dir
        return args
    def get_model(self, args, device, model_states):
        model = build_model(args)

        def remove_prefix(state_dict):
            # Checkpoint keys were saved from a wrapper module and carry a 'model.' prefix; strip it.
            return {k.replace('model.', ''): v for k, v in state_dict.items()}

        model.load_state_dict(remove_prefix(model_states), strict=False)
        model.to(device)
        model.eval()
        return model
    def predict_strings(self, strings: List[str], batch_size=8):
        """Run NER over a list of strings; return, per string, a list of
        (entity_class, [char_start, char_end]) spans."""
        device = self.device

        def prepare_output(char_span, prediction):
            # Merge token-level BIO tags into character-level spans: a 'B-*' tag starts
            # a new span, and a following tag of the same class extends the last one.
            toreturn = []
            i = 0
            while i < len(char_span):
                if prediction[i][0] == 'B':
                    toreturn.append((prediction[i][2:], [char_span[i].start, char_span[i].end]))
                elif len(toreturn) > 0 and prediction[i][2:] == toreturn[-1][0]:
                    toreturn[-1] = (toreturn[-1][0], [toreturn[-1][1][0], char_span[i].end])
                i += 1
            return toreturn

        output = []
        for idx in range(0, len(strings), batch_size):
            batch_strings = strings[idx:idx + batch_size]
            # Tokenize each string; the dummy (-1) tensors stand in for the label
            # tensors that the training-time collate function expects.
            batch_strings_tokenized = [(self.dataset.tokenizer(s, truncation=True, max_length=512),
                                        torch.Tensor([-1]), torch.Tensor([-1])) for s in batch_strings]
            sentences, masks, refs = self.collate(batch_strings_tokenized)
            with torch.no_grad():
                predictions = self.model(input_ids=sentences.to(device),
                                         attention_mask=masks.to(device))[0].argmax(dim=2).to('cpu')
            sentences_list = list(sentences)
            predictions_list = list(predictions)
            # Character offsets for every token that decodes to visible text
            # (padding and special tokens decode to '' and are skipped).
            char_spans = []
            for j, sentence in enumerate(sentences_list):
                to_add = [batch_strings_tokenized[j][0].token_to_chars(i) for i, word in enumerate(sentence)
                          if len(self.dataset.tokenizer.decode(int(word.item()), skip_special_tokens=True)) > 0]
                char_spans.append(to_add)
            # Predicted tag names for the same filtered tokens.
            class_predictions = [[self.index_to_class[int(pred.item())] for (pred, word) in zip(sentence_p, sentence_w)
                                  if len(self.dataset.tokenizer.decode(int(word.item()), skip_special_tokens=True)) > 0]
                                 for (sentence_p, sentence_w) in zip(predictions_list, sentences_list)]
            output += [prepare_output(char_span, prediction)
                       for char_span, prediction in zip(char_spans, class_predictions)]
        return output
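

# --- Hypothetical usage sketch (not part of the original file) ---
# The checkpoint path and the example sentence below are placeholders; substitute the
# actual fine-tuned checkpoint for this Space (the hf_hub_download import above suggests
# it may be fetched from the Hugging Face Hub, but the repo id is not shown here).
if __name__ == '__main__':
    ner = ChemNER('path/to/chemner_checkpoint.pt')
    results = ner.predict_strings(['Aspirin irreversibly inhibits cyclooxygenase.'])
    # Each element of `results` is the span list for one input string, e.g.
    # [('CLASS_NAME', [char_start, char_end]), ...] using the corpus's BIO classes.
    for spans in results:
        print(spans)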