Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| sys.path.insert(0, os.path.abspath('./')) | |
| import torch | |
| from tqdm.auto import tqdm | |
| from torch.utils.data import DataLoader, random_split | |
| from transformers import AutoTokenizer, AutoModelForMaskedLM | |
| from event_detection_dataset import * | |
| from event_detection_model import * | |
| import gradio as gr | |
| #print(f"Gradio version: {gr.__version__}") | |
# Pretrained tokenizer/backbone identifier shared by the tokenizer and the model.
_BACKBONE_NAME = 'distilbert-base-cased'
# Directory holding the fine-tuned checkpoint and the checkpoint file inside it.
_MODEL_DIR = "./"
_CHECKPOINT_PATH = os.path.join(_MODEL_DIR, "event_domain_final.pt")

# Process-wide cache so the tokenizer and the fine-tuned weights are loaded once,
# not on every call to predict() (which runs once per user request).
_RESOURCE_CACHE = {}

# Class index -> label string returned to the UI.
# NOTE(review): index 1 carries an "I-" prefix and index 8 has a trailing space;
# kept verbatim because they are what the app currently displays — confirm
# whether they are intentional before changing them.
_ID2TAGS = {
    0: "Acquisition",
    1: "I-Positive Clinical Trial & FDA Approval",
    2: "Dividend Cut",
    3: "Dividend Increase",
    4: "Guidance Increase",
    5: "New Contract",
    6: "Dividend",
    7: "Reverse Stock Split",
    8: "Special Dividend ",
    9: "Stock Repurchase",
    10: "Stock Split",
    11: "Others",
}


def _get_device():
    """Return the first CUDA device when available, otherwise the CPU."""
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _get_tokenizer():
    """Return the fast DistilBERT tokenizer, loading it on first use only."""
    tokenizer = _RESOURCE_CACHE.get("tokenizer")
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(_BACKBONE_NAME, use_fast=True)
        _RESOURCE_CACHE["tokenizer"] = tokenizer
    return tokenizer


def _get_model(device):
    """Return the fine-tuned classifier on *device* in eval mode, loading it once."""
    key = ("model", str(device))
    model = _RESOURCE_CACHE.get(key)
    if model is None:
        model = DistillBERTClass(_BACKBONE_NAME)
        # Load onto CPU first so a checkpoint saved on GPU also works on CPU-only hosts.
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(_CHECKPOINT_PATH, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
        _RESOURCE_CACHE[key] = model
    return model


def predict(data):
    """Classify one piece of text into a financial event category.

    Args:
        data: The input text (a single string, as supplied by the Gradio textbox).

    Returns:
        The label string from ``_ID2TAGS`` for the highest-scoring class.
    """
    device = _get_device()
    tokenizer = _get_tokenizer()
    # The text is wrapped in a one-element list and tokenized as pre-split words.
    # NOTE(review): with is_split_into_words=True a flat list of str is treated as
    # a single pre-split sequence, so the encodings below are likely unbatched
    # (1-D); presumably DistillBERTClass handles that — TODO confirm.
    tokenized_inputs = tokenizer(
        [data],
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        return_token_type_ids=True,
        truncation=True,
        is_split_into_words=True,
    )

    model = _get_model(device)
    ids = torch.tensor(tokenized_inputs['input_ids']).to(device)
    mask = torch.tensor(tokenized_inputs['attention_mask']).to(device)
    with torch.no_grad():
        outputs = model(ids, mask)
    # Highest-scoring class along dim 1, as the original torch.max(..., dim=1) did.
    predicted_idx = outputs.argmax(dim=1)
    return _ID2TAGS[predicted_idx.item()]
title = "Financial Event Detection"
description = "Predict Financial Events."
article = "modified the model in the following paper: Zhou, Z., Ma, L., & Liu, H. (2021)."
example_list = [["Investors who receive dividends can choose to take them as cash or as additional shares."]]

# Gradio demo: one free-text input mapped to one predicted event label.
demo = gr.Interface(
    fn=predict,  # maps the input text to a predicted event label
    inputs="text",  # a single free-text box
    outputs="text",  # predict() returns a single label string, so one text output
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch(debug=False, share=True)