Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -275,6 +275,18 @@ def safe_remove_dir(directory):
|
|
| 275 |
print(f"Unexpected error while deleting {directory}: {e}")
|
| 276 |
|
| 277 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
class Music2emo:
|
| 279 |
def __init__(
|
| 280 |
self,
|
|
@@ -377,12 +389,13 @@ class Music2emo:
|
|
| 377 |
mert_dir.mkdir(parents=True, exist_ok=True)
|
| 378 |
|
| 379 |
waveform, sample_rate = torchaudio.load(audio)
|
|
|
|
|
|
|
| 380 |
if waveform.shape[0] > 1:
|
| 381 |
waveform = waveform.mean(dim=0).unsqueeze(0)
|
| 382 |
waveform = waveform.squeeze()
|
| 383 |
waveform, sample_rate = resample_waveform(waveform, sample_rate, resample_rate)
|
| 384 |
|
| 385 |
-
|
| 386 |
# π Check duration
|
| 387 |
duration_sec = waveform.shape[-1] / sample_rate
|
| 388 |
is_split = duration_sec > 30.0
|
|
|
|
| 275 |
print(f"Unexpected error while deleting {directory}: {e}")
|
| 276 |
|
| 277 |
|
| 278 |
+
def pad_to_30_seconds(waveform, sample_rate):
|
| 279 |
+
target_len = 30 * sample_rate
|
| 280 |
+
current_len = waveform.shape[-1]
|
| 281 |
+
|
| 282 |
+
if current_len >= target_len:
|
| 283 |
+
return waveform[:, :target_len] # Truncate if longer
|
| 284 |
+
else:
|
| 285 |
+
pad_len = target_len - current_len
|
| 286 |
+
padding = torch.zeros((waveform.shape[0], pad_len))
|
| 287 |
+
padded_waveform = torch.cat((waveform, padding), dim=1)
|
| 288 |
+
return padded_waveform
|
| 289 |
+
|
| 290 |
class Music2emo:
|
| 291 |
def __init__(
|
| 292 |
self,
|
|
|
|
| 389 |
mert_dir.mkdir(parents=True, exist_ok=True)
|
| 390 |
|
| 391 |
waveform, sample_rate = torchaudio.load(audio)
|
| 392 |
+
waveform = pad_to_30_seconds(waveform, sample_rate)
|
| 393 |
+
|
| 394 |
if waveform.shape[0] > 1:
|
| 395 |
waveform = waveform.mean(dim=0).unsqueeze(0)
|
| 396 |
waveform = waveform.squeeze()
|
| 397 |
waveform, sample_rate = resample_waveform(waveform, sample_rate, resample_rate)
|
| 398 |
|
|
|
|
| 399 |
# π Check duration
|
| 400 |
duration_sec = waveform.shape[-1] / sample_rate
|
| 401 |
is_split = duration_sec > 30.0
|