Spaces:
Running
Running
Update pages/13_FFNN.py
Browse files- pages/13_FFNN.py +20 -5
pages/13_FFNN.py
CHANGED
|
@@ -59,6 +59,8 @@ def train_network(net, trainloader, criterion, optimizer, epochs):
|
|
| 59 |
def test_network(net, testloader):
|
| 60 |
correct = 0
|
| 61 |
total = 0
|
|
|
|
|
|
|
| 62 |
with torch.no_grad():
|
| 63 |
for data in testloader:
|
| 64 |
images, labels = data
|
|
@@ -66,9 +68,11 @@ def test_network(net, testloader):
|
|
| 66 |
_, predicted = torch.max(outputs.data, 1)
|
| 67 |
total += labels.size(0)
|
| 68 |
correct += (predicted == labels).sum().item()
|
|
|
|
|
|
|
| 69 |
accuracy = 100 * correct / total
|
| 70 |
st.write(f'Accuracy of the network on the 10000 test images: {accuracy:.2f}%')
|
| 71 |
-
return accuracy
|
| 72 |
|
| 73 |
# Load the data
|
| 74 |
trainloader, testloader = load_data()
|
|
@@ -102,11 +106,22 @@ if st.sidebar.button('Train Network'):
|
|
| 102 |
plt.ylabel('Loss')
|
| 103 |
plt.grid(True)
|
| 104 |
st.pyplot(plt)
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
# Test the network
|
| 107 |
-
if st.sidebar.button('Test Network'):
|
| 108 |
-
accuracy = test_network(
|
| 109 |
st.write(f'Test Accuracy: {accuracy:.2f}%')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
# Visualize some test results
|
| 112 |
def imshow(img):
|
|
@@ -115,12 +130,12 @@ def imshow(img):
|
|
| 115 |
plt.imshow(np.transpose(npimg, (1, 2, 0)))
|
| 116 |
plt.show()
|
| 117 |
|
| 118 |
-
if st.sidebar.button('Show Test Results'):
|
| 119 |
dataiter = iter(testloader)
|
| 120 |
images, labels = next(dataiter) # Use next function
|
| 121 |
imshow(torchvision.utils.make_grid(images))
|
| 122 |
|
| 123 |
-
outputs =
|
| 124 |
_, predicted = torch.max(outputs, 1)
|
| 125 |
|
| 126 |
st.write('GroundTruth vs Predicted')
|
|
|
|
| 59 |
def test_network(net, testloader):
|
| 60 |
correct = 0
|
| 61 |
total = 0
|
| 62 |
+
all_labels = []
|
| 63 |
+
all_predicted = []
|
| 64 |
with torch.no_grad():
|
| 65 |
for data in testloader:
|
| 66 |
images, labels = data
|
|
|
|
| 68 |
_, predicted = torch.max(outputs.data, 1)
|
| 69 |
total += labels.size(0)
|
| 70 |
correct += (predicted == labels).sum().item()
|
| 71 |
+
all_labels.extend(labels.numpy())
|
| 72 |
+
all_predicted.extend(predicted.numpy())
|
| 73 |
accuracy = 100 * correct / total
|
| 74 |
st.write(f'Accuracy of the network on the 10000 test images: {accuracy:.2f}%')
|
| 75 |
+
return accuracy, all_labels, all_predicted
|
| 76 |
|
| 77 |
# Load the data
|
| 78 |
trainloader, testloader = load_data()
|
|
|
|
| 106 |
plt.ylabel('Loss')
|
| 107 |
plt.grid(True)
|
| 108 |
st.pyplot(plt)
|
| 109 |
+
|
| 110 |
+
# Store the trained model in the session state
|
| 111 |
+
st.session_state['trained_model'] = net
|
| 112 |
|
| 113 |
# Test the network
|
| 114 |
+
if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
|
| 115 |
+
accuracy, all_labels, all_predicted = test_network(st.session_state['trained_model'], testloader)
|
| 116 |
st.write(f'Test Accuracy: {accuracy:.2f}%')
|
| 117 |
+
|
| 118 |
+
# Display results in a table
|
| 119 |
+
st.write('GroundTruth vs Predicted')
|
| 120 |
+
results = pd.DataFrame({
|
| 121 |
+
'Ground Truth': all_labels,
|
| 122 |
+
'Predicted': all_predicted
|
| 123 |
+
})
|
| 124 |
+
st.table(results.head(50)) # Display first 50 results for brevity
|
| 125 |
|
| 126 |
# Visualize some test results
|
| 127 |
def imshow(img):
|
|
|
|
| 130 |
plt.imshow(np.transpose(npimg, (1, 2, 0)))
|
| 131 |
plt.show()
|
| 132 |
|
| 133 |
+
if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'):
|
| 134 |
dataiter = iter(testloader)
|
| 135 |
images, labels = next(dataiter) # Use next function
|
| 136 |
imshow(torchvision.utils.make_grid(images))
|
| 137 |
|
| 138 |
+
outputs = st.session_state['trained_model'](images)
|
| 139 |
_, predicted = torch.max(outputs, 1)
|
| 140 |
|
| 141 |
st.write('GroundTruth vs Predicted')
|