Spaces:
Runtime error
Runtime error
Commit
·
0668e89
1
Parent(s):
295bcab
updated zero-dce model
Browse files
enhance_me/zero_dce/__init__.py
CHANGED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .zero_dce import ZeroDCE
|
enhance_me/zero_dce/zero_dce.py
CHANGED
|
@@ -24,7 +24,7 @@ class ZeroDCE(Model):
|
|
| 24 |
super(ZeroDCE, self).__init__(**kwargs)
|
| 25 |
self.experiment_name = experiment_name
|
| 26 |
if wandb_api_key is not None:
|
| 27 |
-
init_wandb("
|
| 28 |
self.using_wandb = True
|
| 29 |
else:
|
| 30 |
self.using_wandb = False
|
|
|
|
| 24 |
super(ZeroDCE, self).__init__(**kwargs)
|
| 25 |
self.experiment_name = experiment_name
|
| 26 |
if wandb_api_key is not None:
|
| 27 |
+
init_wandb("zero-dce", experiment_name, wandb_api_key)
|
| 28 |
self.using_wandb = True
|
| 29 |
else:
|
| 30 |
self.using_wandb = False
|
notebooks/enhance_me_train.ipynb
CHANGED
|
@@ -41,7 +41,8 @@
|
|
| 41 |
"\n",
|
| 42 |
"from PIL import Image\n",
|
| 43 |
"from enhance_me import commons\n",
|
| 44 |
-
"from enhance_me.mirnet import MIRNet"
|
|
|
|
| 45 |
]
|
| 46 |
},
|
| 47 |
{
|
|
@@ -183,7 +184,62 @@
|
|
| 183 |
"id": "dO-IbNQHkB3R"
|
| 184 |
},
|
| 185 |
"outputs": [],
|
| 186 |
-
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
}
|
| 188 |
],
|
| 189 |
"metadata": {
|
|
|
|
| 41 |
"\n",
|
| 42 |
"from PIL import Image\n",
|
| 43 |
"from enhance_me import commons\n",
|
| 44 |
+
"from enhance_me.mirnet import MIRNet\n",
|
| 45 |
+
"from enhance_me.zero_dce import ZeroDCE"
|
| 46 |
]
|
| 47 |
},
|
| 48 |
{
|
|
|
|
| 184 |
"id": "dO-IbNQHkB3R"
|
| 185 |
},
|
| 186 |
"outputs": [],
|
| 187 |
+
"source": [
|
| 188 |
+
"# @title Zero-DCE Train Configs\n",
|
| 189 |
+
"\n",
|
| 190 |
+
"experiment_name = \"lol_dataset_128\" # @param {type:\"string\"}\n",
|
| 191 |
+
"image_size = 128 # @param {type:\"integer\"}\n",
|
| 192 |
+
"dataset_label = \"lol\" # @param [\"lol\"]\n",
|
| 193 |
+
"apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
|
| 194 |
+
"apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
|
| 195 |
+
"apply_random_rotation = True # @param {type:\"boolean\"}\n",
|
| 196 |
+
"use_mixed_precision = False # @param {type:\"boolean\"}\n",
|
| 197 |
+
"wandb_api_key = \"\" # @param {type:\"string\"}\n",
|
| 198 |
+
"val_split = 0.1 # @param {type:\"slider\", min:0.1, max:1.0, step:0.1}\n",
|
| 199 |
+
"batch_size = 32 # @param {type:\"integer\"}\n",
|
| 200 |
+
"learning_rate = 1e-4 # @param {type:\"number\"}\n",
|
| 201 |
+
"epsilon = 1e-3 # @param {type:\"number\"}\n",
|
| 202 |
+
"epochs = 100 # @param {type:\"slider\", min:10, max:100, step:5}"
|
| 203 |
+
]
|
| 204 |
+
},
|
| 205 |
+
{
|
| 206 |
+
"cell_type": "code",
|
| 207 |
+
"execution_count": null,
|
| 208 |
+
"metadata": {},
|
| 209 |
+
"outputs": [],
|
| 210 |
+
"source": [
|
| 211 |
+
"zero_dce = ZeroDCE(\n",
|
| 212 |
+
" experiment_name=experiment_name,\n",
|
| 213 |
+
" wandb_api_key=None if wandb_api_key == \"\" else wandb_api_key\n",
|
| 214 |
+
")"
|
| 215 |
+
]
|
| 216 |
+
},
|
| 217 |
+
{
|
| 218 |
+
"cell_type": "code",
|
| 219 |
+
"execution_count": null,
|
| 220 |
+
"metadata": {},
|
| 221 |
+
"outputs": [],
|
| 222 |
+
"source": [
|
| 223 |
+
"zero_dce.build_datasets(\n",
|
| 224 |
+
" image_size=image_size,\n",
|
| 225 |
+
" dataset_label=dataset_label,\n",
|
| 226 |
+
" apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
|
| 227 |
+
" apply_random_vertical_flip=apply_random_vertical_flip,\n",
|
| 228 |
+
" apply_random_rotation=apply_random_rotation,\n",
|
| 229 |
+
" val_split=val_split,\n",
|
| 230 |
+
" batch_size=batch_size\n",
|
| 231 |
+
")"
|
| 232 |
+
]
|
| 233 |
+
},
|
| 234 |
+
{
|
| 235 |
+
"cell_type": "code",
|
| 236 |
+
"execution_count": null,
|
| 237 |
+
"metadata": {},
|
| 238 |
+
"outputs": [],
|
| 239 |
+
"source": [
|
| 240 |
+
"zero_dce.compile(learning_rate=learning_rate)\n",
|
| 241 |
+
"zero_dce.train(epochs=epochs)"
|
| 242 |
+
]
|
| 243 |
}
|
| 244 |
],
|
| 245 |
"metadata": {
|