Spaces:
Running
Running
Update pages/19_ResNet.py
Browse files- pages/19_ResNet.py +17 -18
pages/19_ResNet.py
CHANGED
|
@@ -66,14 +66,15 @@ def imshow(inp, title=None):
|
|
| 66 |
std = np.array([0.229, 0.224, 0.225])
|
| 67 |
inp = std * inp + mean
|
| 68 |
inp = np.clip(inp, 0, 1)
|
| 69 |
-
plt.
|
|
|
|
| 70 |
if title is not None:
|
| 71 |
-
|
| 72 |
-
|
| 73 |
|
| 74 |
inputs, classes = next(iter(dataloaders['train']))
|
| 75 |
out = torchvision.utils.make_grid(inputs)
|
| 76 |
-
|
| 77 |
|
| 78 |
# Model Preparation Section
|
| 79 |
st.markdown("""
|
|
@@ -169,20 +170,18 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
|
|
| 169 |
|
| 170 |
# Plot training history
|
| 171 |
epochs_range = range(num_epochs)
|
| 172 |
-
plt.
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
plt.show()
|
| 185 |
-
st.pyplot(plt)
|
| 186 |
|
| 187 |
return model
|
| 188 |
|
|
|
|
| 66 |
std = np.array([0.229, 0.224, 0.225])
|
| 67 |
inp = std * inp + mean
|
| 68 |
inp = np.clip(inp, 0, 1)
|
| 69 |
+
fig, ax = plt.subplots()
|
| 70 |
+
ax.imshow(inp)
|
| 71 |
if title is not None:
|
| 72 |
+
ax.set_title(title)
|
| 73 |
+
st.pyplot(fig)
|
| 74 |
|
| 75 |
inputs, classes = next(iter(dataloaders['train']))
|
| 76 |
out = torchvision.utils.make_grid(inputs)
|
| 77 |
+
imshow(out, title=[class_names[x] for x in classes])
|
| 78 |
|
| 79 |
# Model Preparation Section
|
| 80 |
st.markdown("""
|
|
|
|
| 170 |
|
| 171 |
# Plot training history
|
| 172 |
epochs_range = range(num_epochs)
|
| 173 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
|
| 174 |
+
ax1.plot(epochs_range, train_loss_history, label='Training Loss')
|
| 175 |
+
ax1.plot(epochs_range, val_loss_history, label='Validation Loss')
|
| 176 |
+
ax1.legend(loc='upper right')
|
| 177 |
+
ax1.set_title('Training and Validation Loss')
|
| 178 |
+
|
| 179 |
+
ax2.plot(epochs_range, train_acc_history, label='Training Accuracy')
|
| 180 |
+
ax2.plot(epochs_range, val_acc_history, label='Validation Accuracy')
|
| 181 |
+
ax2.legend(loc='lower right')
|
| 182 |
+
ax2.set_title('Training and Validation Accuracy')
|
| 183 |
+
|
| 184 |
+
st.pyplot(fig)
|
|
|
|
|
|
|
| 185 |
|
| 186 |
return model
|
| 187 |
|