Spaces:
Runtime error
Runtime error
Commit
·
2dd1081
1
Parent(s):
4b640eb
updated mirnet
Browse files- enhance_me/mirnet/mirnet.py +19 -28
enhance_me/mirnet/mirnet.py
CHANGED
|
@@ -24,48 +24,39 @@ class MIRNet:
|
|
| 24 |
def __init__(
|
| 25 |
self,
|
| 26 |
experiment_name: str,
|
| 27 |
-
image_size: int = 256,
|
| 28 |
-
dataset_label: str = "lol",
|
| 29 |
-
build_datasets: bool = True,
|
| 30 |
-
val_split: float = 0.2,
|
| 31 |
-
batch_size: int = 16,
|
| 32 |
-
apply_random_horizontal_flip: bool = True,
|
| 33 |
-
apply_random_vertical_flip: bool = True,
|
| 34 |
-
apply_random_rotation: bool = True,
|
| 35 |
wandb_api_key=None,
|
| 36 |
) -> None:
|
| 37 |
self.experiment_name = experiment_name
|
| 38 |
-
if dataset_label == "lol":
|
| 39 |
-
(low_images, enhanced_images), (
|
| 40 |
-
self.test_low_images,
|
| 41 |
-
self.test_enhanced_images,
|
| 42 |
-
) = download_lol_dataset()
|
| 43 |
-
if build_datasets:
|
| 44 |
-
self.data_loader = LowLightDataset(
|
| 45 |
-
image_size=image_size,
|
| 46 |
-
apply_random_horizontal_flip=apply_random_horizontal_flip,
|
| 47 |
-
apply_random_vertical_flip=apply_random_vertical_flip,
|
| 48 |
-
apply_random_rotation=apply_random_rotation,
|
| 49 |
-
)
|
| 50 |
-
self._build_datasets(
|
| 51 |
-
low_images, enhanced_images, val_split=val_split, batch_size=batch_size
|
| 52 |
-
)
|
| 53 |
if wandb_api_key is not None:
|
| 54 |
init_wandb("mirnet", experiment_name, wandb_api_key)
|
| 55 |
self.using_wandb = True
|
| 56 |
else:
|
| 57 |
self.using_wandb = False
|
| 58 |
|
| 59 |
-
def
|
| 60 |
self,
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
| 63 |
val_split: float = 0.2,
|
| 64 |
batch_size: int = 16,
|
| 65 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
(self.train_dataset, self.val_dataset) = self.data_loader.get_datasets(
|
| 67 |
-
low_light_images=
|
| 68 |
-
enhanced_images=enhanced_images,
|
| 69 |
val_split=val_split,
|
| 70 |
batch_size=batch_size,
|
| 71 |
)
|
|
|
|
| 24 |
def __init__(
|
| 25 |
self,
|
| 26 |
experiment_name: str,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
wandb_api_key=None,
|
| 28 |
) -> None:
|
| 29 |
self.experiment_name = experiment_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
if wandb_api_key is not None:
|
| 31 |
init_wandb("mirnet", experiment_name, wandb_api_key)
|
| 32 |
self.using_wandb = True
|
| 33 |
else:
|
| 34 |
self.using_wandb = False
|
| 35 |
|
| 36 |
+
def build_datasets(
|
| 37 |
self,
|
| 38 |
+
image_size: int = 256,
|
| 39 |
+
dataset_label: str = "lol",
|
| 40 |
+
apply_random_horizontal_flip: bool = True,
|
| 41 |
+
apply_random_vertical_flip: bool = True,
|
| 42 |
+
apply_random_rotation: bool = True,
|
| 43 |
val_split: float = 0.2,
|
| 44 |
batch_size: int = 16,
|
| 45 |
):
|
| 46 |
+
if dataset_label == "lol":
|
| 47 |
+
(self.low_images, self.enhanced_images), (
|
| 48 |
+
self.test_low_images,
|
| 49 |
+
self.test_enhanced_images,
|
| 50 |
+
) = download_lol_dataset()
|
| 51 |
+
self.data_loader = LowLightDataset(
|
| 52 |
+
image_size=image_size,
|
| 53 |
+
apply_random_horizontal_flip=apply_random_horizontal_flip,
|
| 54 |
+
apply_random_vertical_flip=apply_random_vertical_flip,
|
| 55 |
+
apply_random_rotation=apply_random_rotation,
|
| 56 |
+
)
|
| 57 |
(self.train_dataset, self.val_dataset) = self.data_loader.get_datasets(
|
| 58 |
+
low_light_images=self.low_images,
|
| 59 |
+
enhanced_images=self.enhanced_images,
|
| 60 |
val_split=val_split,
|
| 61 |
batch_size=batch_size,
|
| 62 |
)
|