Spaces:
Runtime error
Runtime error
Updating app file
Browse files
app.py
CHANGED
|
@@ -61,13 +61,13 @@ def infer(model,data, notes):
|
|
| 61 |
data= torch.tensor(data)
|
| 62 |
if model == "CNN":
|
| 63 |
model = MMCNN_CAT()
|
| 64 |
-
checkpoint = torch.load(MMCNN_CAT_ckpt_path)
|
| 65 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 66 |
data = data.transpose(1,2).float()
|
| 67 |
|
| 68 |
elif model == "RNN":
|
| 69 |
model = MMRNN(device='cpu')
|
| 70 |
-
model.load_state_dict(torch.load(MMRNN_ckpt_path)['model_state_dict'])
|
| 71 |
data = data.float()
|
| 72 |
model.eval()
|
| 73 |
outputs, predicted = predict(model, data, embed_notes, device='cpu')
|
|
|
|
| 61 |
data= torch.tensor(data)
|
| 62 |
if model == "CNN":
|
| 63 |
model = MMCNN_CAT()
|
| 64 |
+
checkpoint = torch.load(MMCNN_CAT_ckpt_path, map_location="cpu")
|
| 65 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 66 |
data = data.transpose(1,2).float()
|
| 67 |
|
| 68 |
elif model == "RNN":
|
| 69 |
model = MMRNN(device='cpu')
|
| 70 |
+
model.load_state_dict(torch.load(MMRNN_ckpt_path, map_location="cpu")['model_state_dict'])
|
| 71 |
data = data.float()
|
| 72 |
model.eval()
|
| 73 |
outputs, predicted = predict(model, data, embed_notes, device='cpu')
|