Spaces:
Runtime error
Runtime error
| import torch | |
| import gradio as gr | |
| import torchvision.transforms as transforms | |
| from neural_network import MNISTNetwork | |
# Preprocessing applied to every sketch before it reaches the network.
transform = transforms.Compose(
    [
        # Image (H x W [x C], uint8 or PIL) -> float tensor (C x H x W) scaled to [0, 1].
        transforms.ToTensor(),
        # Standardize with the commonly used MNIST mean/std; presumably the
        # same normalization the model was trained with — confirm.
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
# Load the trained model for CPU inference.
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# CPU-only host; without it torch.load raises when CUDA is unavailable.
net = MNISTNetwork()
net.load_state_dict(torch.load("mnist_net.pth", map_location="cpu"))
# Put the network in evaluation mode so any dropout / batch-norm layers behave
# deterministically at inference time (no effect if it has none).
net.eval()

# Display label for each model output index; predict() looks these up by the
# class index, presumably digits 0-9 — confirm against the training setup.
LABELS = [str(digit) for digit in range(10)]
def predict(drawing):
    """Classify a hand-drawn digit from the sketchpad.

    Args:
        drawing: The sketchpad value, or ``None`` when nothing has been drawn.
            Assumed to be a single-channel 28x28 image array, which is what the
            ``gr.Sketchpad(shape=(28, 28))`` input is configured to produce —
            TODO confirm against the installed Gradio version.

    Returns:
        A prompt string when ``drawing`` is ``None``; otherwise a dict mapping
        each label in ``LABELS`` to its softmax probability, ordered from most
        to least likely (the shape a ``"label"`` output accepts).
    """
    if drawing is None:
        # NOTE(review): consider a more professional prompt for end users.
        return "Draw something hoe"

    # ToTensor yields (C, H, W). Flatten the whole image into one sample with an
    # explicit batch of 1; view(C, -1) would treat each channel as a separate
    # batch row, so a multi-channel image would be scored channel by channel.
    input_tensor = transform(drawing)
    x = input_tensor.reshape(1, -1)

    with torch.no_grad():
        output = net(x)

    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # Rank every class the model emits instead of a hard-coded top-10;
    # tolist() gives plain Python ints (for indexing) and floats (for display).
    values, indices = torch.sort(probabilities, descending=True)
    return {LABELS[i]: p for i, p in zip(indices.tolist(), values.tolist())}
# Gradio UI: a 28x28 drawing canvas whose value is passed to predict(); the
# returned {label: probability} dict is rendered by a Label output. live=True
# re-runs predict() automatically whenever the drawing changes.
# NOTE(review): `shape=` is a Gradio 3.x Sketchpad argument that was removed in
# Gradio 4, where this call raises TypeError — confirm the pinned gradio version.
sketchpad_input = gr.Sketchpad(shape=(28, 28))
interface = gr.Interface(predict, sketchpad_input, "label", live=True)
interface.launch()