update
Browse files- bigvgan.py +34 -20
bigvgan.py
CHANGED
|
@@ -257,14 +257,18 @@ class BigVGAN(
|
|
| 257 |
return x
|
| 258 |
|
| 259 |
def remove_weight_norm(self):
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
for
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
l.
|
| 266 |
-
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
##################################################################
|
| 270 |
# additional methods for huggingface_hub support
|
|
@@ -304,17 +308,21 @@ class BigVGAN(
|
|
| 304 |
##################################################################
|
| 305 |
# download and load hyperparameters (h) used by BigVGAN
|
| 306 |
##################################################################
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
h = load_hparams_from_json(config_file)
|
| 319 |
|
| 320 |
##################################################################
|
|
@@ -347,6 +355,12 @@ class BigVGAN(
|
|
| 347 |
)
|
| 348 |
|
| 349 |
checkpoint_dict = torch.load(model_file, map_location=map_location)
|
| 350 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
|
| 352 |
return model
|
|
|
|
| 257 |
return x
|
| 258 |
|
| 259 |
def remove_weight_norm(self):
|
| 260 |
+
try:
|
| 261 |
+
print('Removing weight norm...')
|
| 262 |
+
for l in self.ups:
|
| 263 |
+
for l_i in l:
|
| 264 |
+
remove_weight_norm(l_i)
|
| 265 |
+
for l in self.resblocks:
|
| 266 |
+
l.remove_weight_norm()
|
| 267 |
+
remove_weight_norm(self.conv_pre)
|
| 268 |
+
remove_weight_norm(self.conv_post)
|
| 269 |
+
except ValueError:
|
| 270 |
+
print('[INFO] Model already removed weight norm. Skipping!')
|
| 271 |
+
pass
|
| 272 |
|
| 273 |
##################################################################
|
| 274 |
# additional methods for huggingface_hub support
|
|
|
|
| 308 |
##################################################################
|
| 309 |
# download and load hyperparameters (h) used by BigVGAN
|
| 310 |
##################################################################
|
| 311 |
+
if os.path.isdir(model_id):
|
| 312 |
+
print("Loading config.json from local directory")
|
| 313 |
+
config_file = os.path.join(model_id, 'config.json')
|
| 314 |
+
else:
|
| 315 |
+
config_file = hf_hub_download(
|
| 316 |
+
repo_id=model_id,
|
| 317 |
+
filename='config.json',
|
| 318 |
+
revision=revision,
|
| 319 |
+
cache_dir=cache_dir,
|
| 320 |
+
force_download=force_download,
|
| 321 |
+
proxies=proxies,
|
| 322 |
+
resume_download=resume_download,
|
| 323 |
+
token=token,
|
| 324 |
+
local_files_only=local_files_only,
|
| 325 |
+
)
|
| 326 |
h = load_hparams_from_json(config_file)
|
| 327 |
|
| 328 |
##################################################################
|
|
|
|
| 355 |
)
|
| 356 |
|
| 357 |
checkpoint_dict = torch.load(model_file, map_location=map_location)
|
| 358 |
+
|
| 359 |
+
try:
|
| 360 |
+
model.load_state_dict(checkpoint_dict['generator'])
|
| 361 |
+
except RuntimeError:
|
| 362 |
+
print(f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!")
|
| 363 |
+
model.remove_weight_norm()
|
| 364 |
+
model.load_state_dict(checkpoint_dict['generator'])
|
| 365 |
|
| 366 |
return model
|