Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import cv2 | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForMaskedLM | |
| from collections import defaultdict | |
# Hugging Face Hub checkpoint shared by the tokenizer and the masked-LM head.
MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Place the model on the GPU when one exists, otherwise keep it on the CPU.
# Previously the model always stayed on the CPU while inference code sent
# its inputs to 'cuda', which fails with a device-mismatch error.
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
def mlm(image, text, topk=5):
    """Probe the masked-language model for typical attributes of ``text``.

    For each attribute (location, color, shape) a cloze prompt such as
    ``"[CLS] The typical color of {text} is [MASK]. [SEP]"`` is run through
    the module-level ``model``. The ``topk`` highest-scoring vocabulary
    tokens at the ``[MASK]`` position are collected for that attribute.

    Args:
        image: Returned unchanged so the Gradio interface can echo it.
        text: Entity name substituted into every prompt.
        topk: Number of candidate tokens returned per attribute. Defaults
            to 5; the original code referenced an undefined ``topk``.

    Returns:
        ``(image, res)`` where ``res`` maps each attribute name
        (``'location'``, ``'color'``, ``'shape'``) to a list of ``topk``
        predicted tokens, highest score first.
    """
    # Earlier prompt variants, kept for reference ({cls_name} == text):
    #   location: '[CLS] Only [MASK] cells have a {cls_name}. [SEP]'
    #   location: '[CLS] The {cls_name} normally appears at or near the
    #              [MASK] of a cell. [SEP]'
    #   color:    '[CLS] When a cell is histologically stained, the
    #              {cls_name} are in [MASK] color. [SEP]'
    #   shape:    '[CLS] Mostly the shape of {cls_name} is [MASK]. [SEP]'
    questions_dict = {
        'location': f'[CLS] The location of {text} is at [MASK]. [SEP]',
        'color': f'[CLS] The typical color of {text} is [MASK]. [SEP]',
        'shape': f'[CLS] The typical shape of {text} is [MASK]. [SEP]',
    }
    # Build inputs on whatever device the model lives on instead of
    # hard-coding 'cuda'. The hard-coded device crashed on CPU-only hosts
    # and whenever the model had not been moved to the GPU.
    device = next(model.parameters()).device
    res = {}
    for k, v in questions_dict.items():
        tokenized_text = tokenizer.tokenize(v)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        # Single-segment input: every token belongs to segment 0.
        segments_ids = [0] * len(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens], device=device)
        segments_tensors = torch.tensor([segments_ids], device=device)
        masked_index = tokenized_text.index('[MASK]')
        with torch.no_grad():
            # token_type_ids must be passed by keyword. The second positional
            # parameter of the HF forward() is attention_mask, and an
            # all-zero attention mask hides every token from self-attention.
            predictions = model(tokens_tensor, token_type_ids=segments_tensors)
        # predictions[0] holds the logits; take batch row 0 and the
        # vocabulary scores at the [MASK] position.
        _, predicted_index = torch.topk(predictions[0][0][masked_index], topk)
        res[k] = tokenizer.convert_ids_to_tokens(predicted_index.tolist())
    return image, res
def to_black(image, text):
    """Return a grayscale copy of ``image`` together with ``text``.

    Assumes RGB(A) channel order, which is what Gradio's numpy image input
    delivers. The previous ``COLOR_BGR2GRAY`` code swapped the red and blue
    luminance weights for such images.

    Args:
        image: Either a 3- or 4-channel RGB(A) array, an already
            single-channel 2-D array, or ``None`` when no image was supplied.
        text: Returned unchanged.

    Returns:
        A list ``[gray_image, text]``. ``gray_image`` is ``None`` when
        ``image`` is ``None``, and is ``image`` itself when the input is
        already 2-D.
    """
    if image is None:
        return [None, text]
    if image.ndim == 2:
        # Already single-channel; cvtColor rejects 1-channel input for
        # this conversion code.
        return [image, text]
    # COLOR_RGB2GRAY accepts both 3- and 4-channel input.
    output = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return [output, text]
# Web UI for the MLM probe. Inputs are an image (echoed back unchanged)
# and the entity name to query. Outputs are the echoed image and the
# predicted attribute tokens rendered as text.
interface = gr.Interface(
    fn=mlm,
    inputs=["image", "text"],
    outputs=["image", "text"],
)
interface.launch()