# Hugging Face Spaces page header (scrape residue): the hosted Space reported
# "Runtime error" — presumably the MNIST data was absent and download=False
# raised at dataset construction; see the dataset section below.
# Standard library
import multiprocessing

# Third-party
import prettytable
import torch
import torchvision
import torchvision.transforms as transforms

# Local
from neural_network import MNISTNetwork
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
BATCH_SIZE = 64          # samples per mini-batch
NUM_WORKERS = 2          # DataLoader worker processes
EPOCH = 15               # passes over the training split
LEARNING_RATE = 0.01     # SGD step size
MOMENTUM = 0.5           # SGD momentum term
LOSS = torch.nn.CrossEntropyLoss()  # classification criterion (expects raw logits)
## Step 1: define our transforms
# Convert PIL images to tensors and normalize the single MNIST channel.
# FIX: the original passed (0.5) — plain parentheses, i.e. a float, not a
# tuple. torchvision tolerates a scalar, but the intended per-channel
# sequence for 1-channel MNIST is the one-element tuple (0.5,).
transform = transforms.Compose(
    [
        transforms.ToTensor(),                 # PIL [0, 255] -> tensor [0.0, 1.0]
        transforms.Normalize((0.5,), (0.5,)),  # map [0, 1] -> [-1, 1]
    ]
)
## Step 2: get our datasets
# FIX: download=True is a no-op when ./data already holds MNIST, but prevents
# the RuntimeError that download=False raises on a machine without the data.
full_ds = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(full_ds))  # 80% of the data for training
val_size = len(full_ds) - train_size  # remaining 20% for validation
train_ds, valid_ds = torch.utils.data.random_split(full_ds, [train_size, val_size])
test_ds = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

## Step 3: create our dataloaders
# Only the training loader shuffles; evaluation order does not matter.
train_dl = torch.utils.data.DataLoader(train_ds, num_workers=NUM_WORKERS, shuffle=True, batch_size=BATCH_SIZE)
valid_dl = torch.utils.data.DataLoader(valid_ds, num_workers=NUM_WORKERS, shuffle=False, batch_size=BATCH_SIZE)
test_dl = torch.utils.data.DataLoader(test_ds, num_workers=NUM_WORKERS, shuffle=False, batch_size=BATCH_SIZE)
## Step 4: define our model, loss and optimizer
model = MNISTNetwork()
criterion = LOSS  # loss function (FIX: renamed from the original 'criteron' typo)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

## Per-epoch results table
table = prettytable.PrettyTable()
table.field_names = ['Epoch', 'Training Loss', 'Validation Accuracy']

if __name__ == "__main__":
    # Needed on Windows when DataLoader workers are spawned from a frozen exe;
    # harmless elsewhere.
    multiprocessing.freeze_support()

    # ---- training loop ----
    for e in range(EPOCH):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_dl:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # mean loss over batches for this epoch
        train_loss = round(running_loss / len(train_dl), 4)

        # ---- evaluate on the VALIDATION split each epoch ----
        # (FIX: the original comment wrongly said "test set" here.)
        model.eval()
        with torch.no_grad():
            total, correct = 0, 0
            for inputs, labels in valid_dl:
                outputs = model(inputs)
                # argmax over logits = predicted class; '.data' is legacy and
                # unnecessary inside no_grad(), so plain 'outputs' is used.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_acc = round((correct / total) * 100, 3)
        table.add_row([e, train_loss, val_acc])
        print(f'Training Loss: {train_loss}, Validation Accuracy: {val_acc}')
    print(table)

    # ---- final evaluation on the held-out TEST set ----
    model.eval()
    with torch.no_grad():
        total, correct = 0, 0
        for inputs, labels in test_dl:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    test_acc = round((correct / total) * 100, 3)
    print(f'Test Accuracy: {test_acc}')
    # Persist learned weights only (state_dict), not the whole module object.
    torch.save(model.state_dict(), 'MNISTModel.pth')