Spaces:
Runtime error
Runtime error
if __name__ == '__main__':
    # Demo entry point: load the fine-tuned checkpoint once, then serve a
    # Gradio UI whose callback runs single-sample inference on CPU.
    import os

    import gradio as gr
    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    from configuration import CFG
    from dataset import SingleInputDataset
    from model import CustomModel
    from utils import inference_fn, get_char_probs, get_results, get_text

    # Throwaway strings used only by the smoke test at the bottom. Kept under
    # their own name so the Gradio component list below cannot shadow them.
    sample_inputs = ['gbjjhbdjhbdgjhdbfjhsdkjrkjf', 'fdjhbjhsbd']

    device = torch.device('cpu')
    config_path = os.path.join('models_file', 'config.pth')
    model_path = os.path.join('models_file', 'microsoft-deberta-base_0.9449373420387531_8_best.pth')

    tokenizer = AutoTokenizer.from_pretrained('models_file/tokenizer')
    model = CustomModel(CFG, config_path=config_path, pretrained=False)
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # ship with this app.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['model'])

    def get_answer(context, feature):
        """Run the model on one (context, feature) pair and return the answer text.

        Args:
            context: The passage to search (first tokenizer sequence).
            feature: The question / feature text (second tokenizer sequence).

        Returns:
            Whatever ``get_text`` produces for ``context`` and the first entry
            of the thresholded results (presumably the predicted span text --
            confirm against ``utils.get_text``).
        """
        # Encode the pair to exactly CFG.max_len tokens. Truncation is
        # required: without it a long context yields more than CFG.max_len
        # tokens and the reshape below fails.
        inputs_single = tokenizer(context, feature,
                                  add_special_tokens=True,
                                  max_length=CFG.max_len,
                                  padding="max_length",
                                  truncation=True,
                                  return_offsets_mapping=False)
        for k, v in inputs_single.items():
            inputs_single[k] = torch.tensor(v, dtype=torch.long)

        # Wrap the single encoded sample so the shared inference_fn, which
        # consumes a DataLoader, can be reused unchanged.
        single_input_dataset = SingleInputDataset(inputs_single)
        # num_workers=0: starting worker processes for one in-memory sample
        # on every request is pure overhead.
        single_input_loader = DataLoader(single_input_dataset,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=0)

        # Perform inference on the single input.
        output = inference_fn(single_input_loader, model, device)
        prediction = output.reshape((1, CFG.max_len))

        char_probs = get_char_probs([context], prediction, tokenizer)
        # Mean over a one-element list: a no-op kept so the data keeps the
        # same array form that get_results received before.
        predictions = np.mean([char_probs], axis=0)
        results = get_results(predictions, th=0.5)
        print(results)
        return get_text(context, results[0])

    # gr.inputs / gr.outputs were removed in Gradio 4; the top-level
    # gr.Textbox component serves as both input and output.
    input_boxes = [gr.Textbox(label="Context Para", lines=10),
                   gr.Textbox(label="Question", lines=1)]
    output_box = gr.Textbox(label="Answer")
    app = gr.Interface(fn=get_answer, inputs=input_boxes, outputs=output_box,
                       allow_flagging='never')
    app.launch()

    # Smoke test on the sample strings. launch() normally blocks when run as
    # a script, so this only executes once the server has stopped.
    print(get_answer(sample_inputs[0], sample_inputs[1]))