Spaces:
Runtime error
Runtime error
| import os | |
| import shutil | |
| import gradio as gr | |
| import numpy as np | |
| import wfdb | |
| import torch | |
| from wfdb.plot.plot import plot_wfdb | |
| from wfdb.io.record import Record, rdrecord | |
| from models.CNN import CNN, MMCNN_CAT | |
| from models.RNN import MMRNN | |
| from utils.helper_functions import predict | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| from transformers import AutoTokenizer, AutoModel | |
| from langdetect import detect | |
# Base directory used to locate the bundled demo assets. It is the process
# working directory, so edit this if the app is started from somewhere else.
CWD = os.getcwd()

# Checkpoint files for the two multimodal classifiers.
MMCNN_CAT_ckpt_path = f"{CWD}/demo_data/model_MMCNN_CAT_epoch_30_acc_84.pt"
MMRNN_ckpt_path = f"{CWD}/demo_data/model_MMRNN_undersampled_augmented_rn_epoch_20_acc_84.pt"

# Hugging Face identifiers of the English and German clinical text encoders.
en_clin_bert = 'emilyalsentzer/Bio_ClinicalBERT'
ger_clin_bert = 'smanjil/German-MedBERT'


def _load_encoder(model_id):
    """Return the (tokenizer, model) pair for the given Hugging Face model id."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    encoder = AutoModel.from_pretrained(model_id)
    return tokenizer, encoder


# Both encoders are loaded once at import time and reused by every request.
en_tokenizer, en_model = _load_encoder(en_clin_bert)
g_tokenizer, g_model = _load_encoder(ger_clin_bert)
def preprocess(data_file_path):
    """Read a WFDB record and return its signal with a leading axis of size 1.

    Args:
        data_file_path: Record path as accepted by ``wfdb.rdsamp``.

    Returns:
        A NumPy array holding the single signal array read from the record,
        wrapped in an outer dimension of length one.
    """
    # rdsamp returns a (signal, metadata) pair; only the signal is needed.
    signal, _metadata = wfdb.rdsamp(data_file_path)
    return np.array([signal])
def embed(notes):
    """Embed free-text clinical notes with a language-matched clinical BERT.

    The English encoder is used when ``langdetect`` classifies the text as
    English; every other outcome (including text whose language cannot be
    detected) uses the German encoder.

    Args:
        notes: The clinical notes as a single string. May be empty.

    Returns:
        A tensor obtained by averaging the encoder's ``last_hidden_state``
        over dimension 1 and squeezing dimension 0. The tensor is detached
        from autograd because it is only used for inference.
    """
    # Function-scope import: only this function needs the exception type.
    from langdetect.lang_detect_exception import LangDetectException

    try:
        is_english = detect(notes) == 'en'
    except LangDetectException:
        # detect() raises when the text has no usable features (e.g. an
        # empty notes box); treat that like any other non-English input.
        is_english = False

    if is_english:
        tokenizer, encoder = en_tokenizer, en_model
    else:
        tokenizer, encoder = g_tokenizer, g_model

    # Truncate to the encoder's position limit; without this, long notes
    # produce more tokens than the model has position embeddings for.
    tokens = tokenizer(
        notes,
        return_tensors='pt',
        truncation=True,
        max_length=encoder.config.max_position_embeddings,
    )

    # No gradients are needed for inference.
    with torch.no_grad():
        outputs = encoder(**tokens)

    embeddings = outputs.last_hidden_state
    return torch.mean(embeddings, dim=1).squeeze(0)
def plot_ecg(path):
    """Load the WFDB record at *path* and return its plot as a figure.

    Args:
        path: Record path as accepted by ``rdrecord``.

    Returns:
        The figure object produced by ``plot_wfdb`` (``return_fig=True``).
    """
    ecg_record = rdrecord(path)
    figure = plot_wfdb(
        record=ecg_record,
        title='ECG Signal Graph',
        figsize=(12, 10),
        return_fig=True,
    )
    return figure
# Output labels in the order of the model's output indices 0..4.
_ECG_CLASS_LABELS = (
    'Conduction Disturbance',
    'Hypertrophy',
    'Myocardial Infarction',
    'Normal ECG',
    'ST/T Change',
)


def infer(model, data, notes):
    """Classify one ECG record together with its clinical notes.

    Args:
        model: Name of the architecture to use, either ``"CNN"`` or ``"RNN"``.
        data: Array-like ECG signal convertible with ``torch.tensor``. The CNN
            branch swaps dimensions 1 and 2, so it must have at least three
            dimensions (presumably batch, samples, leads - TODO confirm).
        notes: Clinical notes text, embedded with ``embed``.

    Returns:
        A dict mapping each of the five class labels to its sigmoid score for
        the first item of the batch, rounded to two decimals.

    Raises:
        ValueError: If ``model`` is not ``"CNN"`` or ``"RNN"`` (for example
            ``None`` when no model was selected in the UI).
    """
    # Validate first: previously an unknown/unselected model fell through
    # both branches and failed later with an opaque AttributeError.
    if model not in ("CNN", "RNN"):
        raise ValueError(
            f"Unknown model {model!r}; expected 'CNN' or 'RNN'."
        )

    embed_notes = embed(notes).unsqueeze(0)
    data = torch.tensor(data)

    if model == "CNN":
        net = MMCNN_CAT()
        ckpt_path = MMCNN_CAT_ckpt_path
        # The CNN consumes the signal with dimensions 1 and 2 swapped.
        data = data.transpose(1, 2).float()
    else:
        net = MMRNN(device='cpu')
        ckpt_path = MMRNN_ckpt_path
        data = data.float()

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs, _predicted = predict(net, data, embed_notes, device='cpu')
        scores = torch.sigmoid(outputs)[0]

    return {
        label: round(scores[index].item(), 2)
        for index, label in enumerate(_ECG_CLASS_LABELS)
    }
def run(model_name, header_file, data_file, notes):
    """Gradio callback: classify an uploaded WFDB record and plot it.

    The header and data files are copied side by side into a private
    temporary directory (keeping their uploaded basenames) so the WFDB
    reader can find both; the directory is removed even when reading or
    plotting fails.

    Args:
        model_name: ``"CNN"`` or ``"RNN"``, forwarded to ``infer``.
        header_file: Uploaded ``.hea`` file object exposing a ``.name`` path.
        data_file: Uploaded ``.dat`` file object exposing a ``.name`` path.
        notes: Clinical notes text, forwarded to ``infer``.

    Returns:
        A ``(scores, figure)`` tuple: the dict returned by ``infer`` and the
        figure returned by ``plot_ecg``.

    Raises:
        ValueError: If either file is missing.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    if header_file is None or data_file is None:
        raise ValueError(
            "Both a WFDB header (.hea) file and a data (.dat) file are required."
        )

    hdr_basename = os.path.basename(header_file.name)
    data_basename = os.path.basename(data_file.name)
    # The record name is the header filename minus its final extension;
    # splitext keeps any earlier dots that split('.')[0] would drop.
    record_name = os.path.splitext(hdr_basename)[0]

    # A per-request directory avoids collisions between concurrent uploads
    # with the same filename and guarantees cleanup on errors.
    with tempfile.TemporaryDirectory() as work_dir:
        shutil.copyfile(data_file.name, os.path.join(work_dir, data_basename))
        shutil.copyfile(header_file.name, os.path.join(work_dir, hdr_basename))
        record_path = os.path.join(work_dir, record_name)
        data = preprocess(record_path)
        ECG_graph = plot_ecg(record_path)

    output = infer(model_name, data, notes)
    return output, ECG_graph
# Gradio UI: model selector, file/notes inputs next to the score panel,
# the ECG plot, the predict button and a set of bundled examples.
with gr.Blocks() as demo:
    with gr.Row():
        model = gr.Radio(['CNN', 'RNN'], label="Select Model")

    with gr.Row():
        # Left column: the WFDB record files and the clinical notes.
        with gr.Column(scale=1):
            header_file = gr.File(label="header_file", file_types=[".hea"])
            data_file = gr.File(label="data_file", file_types=[".dat"])
            notes = gr.Textbox(label="Clinical Notes")
        # Right column: per-class scores, all zero until a prediction runs.
        with gr.Column(scale=1):
            _initial_scores = {
                'Normal ECG': 0,
                'Myocardial Infarction': 0,
                'ST/T Change': 0,
                'Conduction Disturbance': 0,
                'Hypertrophy': 0,
            }
            output_prob = gr.Label(_initial_scores, show_label=False)

    with gr.Row():
        ecg_graph = gr.Plot(label="ECG Signal Visualisation")

    with gr.Row():
        predict_btn = gr.Button("Predict Class")
        predict_btn.click(
            fn=run,
            inputs=[model, header_file, data_file, notes],
            outputs=[output_prob, ecg_graph],
        )

    with gr.Row():
        # Each example row is [header path, data path, clinical notes].
        _example_rows = [
            [
                f"{CWD}/demo_data/test/00001_lr.hea",
                f"{CWD}/demo_data/test/00001_lr.dat",
                "sinusrhythmus periphere niederspannung",
            ],
            [
                f"{CWD}/demo_data/test/00008_lr.hea",
                f"{CWD}/demo_data/test/00008_lr.dat",
                "sinusrhythmus linkstyp qrs(t) abnormal inferiorer infarkt alter unbest.",
            ],
            [
                f"{CWD}/demo_data/test/00045_lr.hea",
                f"{CWD}/demo_data/test/00045_lr.dat",
                "sinusrhythmus unvollstÄndiger rechtsschenkelblock sonst normales ekg",
            ],
            [
                f"{CWD}/demo_data/test/00257_lr.hea",
                f"{CWD}/demo_data/test/00257_lr.dat",
                "premature atrial contraction(s). sinus rhythm. left atrial enlargement. qs complexes in v2. st segments are slightly elevated in v2,3. st segments are depressed in i, avl. t waves are low or flat in i, v5,6 and inverted in avl. consistent with ischaemic h",
            ],
        ]
        gr.Examples(
            examples=_example_rows,
            inputs=[header_file, data_file, notes],
        )

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()