Spaces:
Running
Running
Developer
Amp
commited on
Commit
·
02d5651
1
Parent(s):
86b05fe
Fix syntax error in avatar_app.py
Browse filesCo-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-a617baec-31b4-4481-b9ff-5899d702f52f
- avatar_app.py +23 -86
avatar_app.py
CHANGED
|
@@ -318,6 +318,29 @@ class OmniAvatarAPI:
|
|
| 318 |
logger.info(f"? SUCCESS: Downloaded models loaded - Video: {video_files}, Audio: {audio_files}")
|
| 319 |
return True
|
| 320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
async def download_file(self, url: str, suffix: str = '') -> str:
|
| 322 |
"""Download file from URL to temporary location"""
|
| 323 |
import aiohttp
|
|
@@ -343,28 +366,6 @@ class OmniAvatarAPI:
|
|
| 343 |
logger.error(f"Failed to download file from {url}: {e}")
|
| 344 |
raise HTTPException(status_code=400, detail=f"Failed to download file: {e}")
|
| 345 |
|
| 346 |
-
|
| 347 |
-
# Fallback: Check traditional OmniAvatar paths
|
| 348 |
-
traditional_paths = [
|
| 349 |
-
"./pretrained_models/Wan2.1-T2V-14B",
|
| 350 |
-
"./pretrained_models/OmniAvatar-14B",
|
| 351 |
-
"./pretrained_models/wav2vec2-base-960h"
|
| 352 |
-
]
|
| 353 |
-
|
| 354 |
-
if all(os.path.exists(path) for path in traditional_paths):
|
| 355 |
-
self.model_loaded = True
|
| 356 |
-
logger.info("? SUCCESS: Traditional OmniAvatar models found")
|
| 357 |
-
return True
|
| 358 |
-
|
| 359 |
-
# No models found
|
| 360 |
-
logger.warning("?? WARNING: No models found (neither downloaded nor traditional)")
|
| 361 |
-
self.model_loaded = False
|
| 362 |
-
return True # App can still run in TTS-only mode
|
| 363 |
-
|
| 364 |
-
except Exception as e:
|
| 365 |
-
logger.error(f"? Error in model detection: {e}")
|
| 366 |
-
self.model_loaded = False
|
| 367 |
-
return True
|
| 368 |
def validate_audio_url(self, url: str) -> bool:
|
| 369 |
"""Validate if URL is likely an audio file"""
|
| 370 |
try:
|
|
@@ -766,8 +767,6 @@ async def download_video_models():
|
|
| 766 |
"audio_model_path": audio_model_path,
|
| 767 |
"status": "READY FOR VIDEO GENERATION"
|
| 768 |
}
|
| 769 |
-
|
| 770 |
-
|
| 771 |
|
| 772 |
except Exception as e:
|
| 773 |
logger.error(f"? Model download failed: {e}")
|
|
@@ -776,21 +775,6 @@ async def download_video_models():
|
|
| 776 |
"message": f"Model download failed: {str(e)}",
|
| 777 |
"error": str(e)
|
| 778 |
}
|
| 779 |
-
|
| 780 |
-
except Exception as e:
|
| 781 |
-
logger.error(f"? Model download failed: {e}")
|
| 782 |
-
return {
|
| 783 |
-
"success": False,
|
| 784 |
-
"message": f"Model download failed: {str(e)}",
|
| 785 |
-
"error": str(e)
|
| 786 |
-
}
|
| 787 |
-
except Exception as e:
|
| 788 |
-
logger.error(f"? Model download failed: {e}")
|
| 789 |
-
return {
|
| 790 |
-
"success": False,
|
| 791 |
-
"message": f"Model download failed: {str(e)}",
|
| 792 |
-
"error": str(e)
|
| 793 |
-
}
|
| 794 |
|
| 795 |
@app.post("/reload-models")
|
| 796 |
async def reload_models():
|
|
@@ -819,53 +803,6 @@ async def reload_models():
|
|
| 819 |
"message": f"Model reload failed: {str(e)}",
|
| 820 |
"error": str(e)
|
| 821 |
}
|
| 822 |
-
|
| 823 |
-
# Download small video generation model
|
| 824 |
-
logger.info("?? Downloading text-to-video model...")
|
| 825 |
-
|
| 826 |
-
model_path = snapshot_download(
|
| 827 |
-
repo_id="ali-vilab/text-to-video-ms-1.7b",
|
| 828 |
-
cache_dir="./downloaded_models/video",
|
| 829 |
-
local_files_only=False
|
| 830 |
-
)
|
| 831 |
-
|
| 832 |
-
logger.info(f"? Video model downloaded: {model_path}")
|
| 833 |
-
|
| 834 |
-
# Download audio model
|
| 835 |
-
audio_model_path = snapshot_download(
|
| 836 |
-
repo_id="facebook/wav2vec2-base-960h",
|
| 837 |
-
cache_dir="./downloaded_models/audio",
|
| 838 |
-
local_files_only=False
|
| 839 |
-
)
|
| 840 |
-
|
| 841 |
-
logger.info(f"? Audio model downloaded: {audio_model_path}")
|
| 842 |
-
|
| 843 |
-
# Check final storage usage
|
| 844 |
-
_, _, free_bytes_after = shutil.disk_usage(".")
|
| 845 |
-
free_gb_after = free_bytes_after / (1024**3)
|
| 846 |
-
used_gb = free_gb - free_gb_after
|
| 847 |
-
|
| 848 |
-
return {
|
| 849 |
-
"success": True,
|
| 850 |
-
"message": "? Video generation models downloaded successfully!",
|
| 851 |
-
"models_downloaded": [
|
| 852 |
-
"ali-vilab/text-to-video-ms-1.7b",
|
| 853 |
-
"facebook/wav2vec2-base-960h"
|
| 854 |
-
],
|
| 855 |
-
"storage_used_gb": round(used_gb, 2),
|
| 856 |
-
"storage_remaining_gb": round(free_gb_after, 2),
|
| 857 |
-
"video_model_path": model_path,
|
| 858 |
-
"audio_model_path": audio_model_path,
|
| 859 |
-
"status": "READY FOR VIDEO GENERATION"
|
| 860 |
-
}
|
| 861 |
-
|
| 862 |
-
except Exception as e:
|
| 863 |
-
logger.error(f"? Model download failed: {e}")
|
| 864 |
-
return {
|
| 865 |
-
"success": False,
|
| 866 |
-
"message": f"Model download failed: {str(e)}",
|
| 867 |
-
"error": str(e)
|
| 868 |
-
}
|
| 869 |
|
| 870 |
@app.get("/model-status")
|
| 871 |
async def get_model_status():
|
|
|
|
| 318 |
logger.info(f"? SUCCESS: Downloaded models loaded - Video: {video_files}, Audio: {audio_files}")
|
| 319 |
return True
|
| 320 |
|
| 321 |
+
# Fallback: Check traditional OmniAvatar paths
|
| 322 |
+
traditional_paths = [
|
| 323 |
+
"./pretrained_models/Wan2.1-T2V-14B",
|
| 324 |
+
"./pretrained_models/OmniAvatar-14B",
|
| 325 |
+
"./pretrained_models/wav2vec2-base-960h"
|
| 326 |
+
]
|
| 327 |
+
|
| 328 |
+
if all(os.path.exists(path) for path in traditional_paths):
|
| 329 |
+
self.model_loaded = True
|
| 330 |
+
logger.info("? SUCCESS: Traditional OmniAvatar models found")
|
| 331 |
+
return True
|
| 332 |
+
|
| 333 |
+
# No models found
|
| 334 |
+
logger.warning("?? WARNING: No models found (neither downloaded nor traditional)")
|
| 335 |
+
self.model_loaded = False
|
| 336 |
+
return True # App can still run in TTS-only mode
|
| 337 |
+
|
| 338 |
+
except Exception as e:
|
| 339 |
+
logger.error(f"? Error in model detection: {e}")
|
| 340 |
+
self.model_loaded = False
|
| 341 |
+
return True
|
| 342 |
+
|
| 343 |
+
|
| 344 |
async def download_file(self, url: str, suffix: str = '') -> str:
|
| 345 |
"""Download file from URL to temporary location"""
|
| 346 |
import aiohttp
|
|
|
|
| 366 |
logger.error(f"Failed to download file from {url}: {e}")
|
| 367 |
raise HTTPException(status_code=400, detail=f"Failed to download file: {e}")
|
| 368 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
def validate_audio_url(self, url: str) -> bool:
|
| 370 |
"""Validate if URL is likely an audio file"""
|
| 371 |
try:
|
|
|
|
| 767 |
"audio_model_path": audio_model_path,
|
| 768 |
"status": "READY FOR VIDEO GENERATION"
|
| 769 |
}
|
|
|
|
|
|
|
| 770 |
|
| 771 |
except Exception as e:
|
| 772 |
logger.error(f"? Model download failed: {e}")
|
|
|
|
| 775 |
"message": f"Model download failed: {str(e)}",
|
| 776 |
"error": str(e)
|
| 777 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 778 |
|
| 779 |
@app.post("/reload-models")
|
| 780 |
async def reload_models():
|
|
|
|
| 803 |
"message": f"Model reload failed: {str(e)}",
|
| 804 |
"error": str(e)
|
| 805 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 806 |
|
| 807 |
@app.get("/model-status")
|
| 808 |
async def get_model_status():
|