Spaces:
Runtime error
Runtime error
Commit
·
246dd98
1
Parent(s):
60407c7
added model loading function
Browse files- enhance_me/mirnet/mirnet.py +11 -1
enhance_me/mirnet/mirnet.py
CHANGED
|
@@ -5,7 +5,7 @@ from typing import List
|
|
| 5 |
from datetime import datetime
|
| 6 |
|
| 7 |
from tensorflow import keras
|
| 8 |
-
from tensorflow.keras import optimizers
|
| 9 |
|
| 10 |
from wandb.keras import WandbCallback
|
| 11 |
|
|
@@ -76,6 +76,16 @@ class MIRNet:
|
|
| 76 |
metrics=[peak_signal_noise_ratio],
|
| 77 |
)
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
|
| 80 |
self.model.save_weights(
|
| 81 |
filepath, overwrite=overwrite, save_format=save_format, options=options
|
|
|
|
| 5 |
from datetime import datetime
|
| 6 |
|
| 7 |
from tensorflow import keras
|
| 8 |
+
from tensorflow.keras import optimizers, models
|
| 9 |
|
| 10 |
from wandb.keras import WandbCallback
|
| 11 |
|
|
|
|
| 76 |
metrics=[peak_signal_noise_ratio],
|
| 77 |
)
|
| 78 |
|
| 79 |
+
def load_model(
|
| 80 |
+
self, filepath, custom_objects=None, compile=True, options=None
|
| 81 |
+
) -> None:
|
| 82 |
+
self.model = models.load_model(
|
| 83 |
+
filepath=filepath,
|
| 84 |
+
custom_objects=custom_objects,
|
| 85 |
+
compile=compile,
|
| 86 |
+
options=options,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
|
| 90 |
self.model.save_weights(
|
| 91 |
filepath, overwrite=overwrite, save_format=save_format, options=options
|