#!/usr/bin/env python3
"""
Test script for the new Facebook VITS & SpeechT5 TTS system
"""

import asyncio
import logging
import os
import sys
import traceback

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def test_advanced_tts():
    """Test the new advanced TTS system.

    Imports ``AdvancedTTSClient`` from ``advanced_tts_client``, loads its
    models, synthesizes one fixed sentence and inspects the resulting file.

    Returns:
        True only when the models load, the reported audio path exists and
        the file is larger than 1000 bytes. False in every other case,
        including any exception (which is printed with its traceback).
    """
    print("=" * 60)
    print("Testing Facebook VITS & SpeechT5 TTS System")
    print("=" * 60)

    try:
        # Imported lazily so a missing dependency is reported as a failed
        # test instead of aborting the whole script at import time.
        from advanced_tts_client import AdvancedTTSClient

        client = AdvancedTTSClient()

        print(f"Device: {client.device}")
        print("Loading TTS models...")

        # Load models
        if not await client.load_models():
            print("❌ Model loading failed")
            return False

        print("✅ Models loaded successfully!")

        # Get model info
        info = client.get_model_info()
        print(f"SpeechT5 available: {info['speecht5_available']}")
        print(f"VITS available: {info['vits_available']}")
        print(f"Primary method: {info['primary_method']}")

        # Test TTS generation
        test_text = "Hello! This is a test of the Facebook VITS and SpeechT5 text-to-speech system."
        voice_id = "21m00Tcm4TlvDq8ikWAM"

        print(f"\nTesting with text: {test_text}")
        print(f"Voice ID: {voice_id}")

        audio_path = await client.text_to_speech(test_text, voice_id)
        print(f"✅ TTS SUCCESS: Generated audio at {audio_path}")

        # Check file
        if not os.path.exists(audio_path):
            print("❌ Audio file not found")
            return False

        size = os.path.getsize(audio_path)
        print(f"📁 Audio file size: {size} bytes")

        # Anything at or below 1000 bytes is reported as too small.
        if size > 1000:
            print("✅ Audio file appears valid!")
            return True

        print("⚠️ Audio file seems too small")
        return False

    except Exception as e:
        # Top-level boundary of this test: report the failure and keep the
        # suite running rather than propagating.
        print(f"❌ Test failed: {e}")
        traceback.print_exc()
        return False


async def test_tts_manager():
    """Test the TTS manager with fallback.

    Imports ``TTSManager`` from ``app`` (after adding the current directory
    to ``sys.path``), loads its models and synthesizes one fixed sentence.

    Returns:
        True when the models load and the reported audio path exists on
        disk; False otherwise, including any exception (which is printed
        with its traceback). Unlike ``test_advanced_tts`` no minimum file
        size is enforced.
    """
    print("\n" + "=" * 60)
    print("Testing TTS Manager with Fallback System")
    print("=" * 60)

    try:
        # Import from the main app
        sys.path.append('.')

        from app import TTSManager

        manager = TTSManager()

        # Load models
        print("Loading TTS manager...")
        if not await manager.load_models():
            print("❌ TTS Manager loading failed")
            return False

        print("✅ TTS Manager loaded successfully!")

        # Get info
        info = manager.get_tts_info()
        print(f"Advanced TTS available: {info.get('advanced_tts_available', False)}")
        print(f"Primary method: {info.get('primary_method', 'Unknown')}")

        # Test generation
        test_text = "Testing the TTS manager with automatic fallback capabilities."
        voice_id = "pNInz6obpgDQGcFmaJgB"

        print(f"\nTesting with text: {test_text}")
        print(f"Voice ID: {voice_id}")

        # The manager returns the output path together with the name of the
        # method that produced it.
        audio_path, method = await manager.text_to_speech(test_text, voice_id)
        print(f"✅ TTS Manager SUCCESS: Generated audio at {audio_path}")
        print(f"🎙️ Method used: {method}")

        # Check file
        if not os.path.exists(audio_path):
            print("❌ Audio file not found")
            return False

        size = os.path.getsize(audio_path)
        print(f"📁 Audio file size: {size} bytes")
        return True

    except Exception as e:
        # Top-level boundary of this test: report the failure and keep the
        # suite running rather than propagating.
        print(f"❌ TTS Manager test failed: {e}")
        traceback.print_exc()
        return False


async def main():
    """Run all tests and print a summary.

    Returns:
        True when at least one of the two tests passed, False when both
        failed.
    """
    print("🧪 FACEBOOK VITS & SPEECHT5 TTS TEST SUITE")
    print("Testing the new open-source TTS system...")
    print()

    results = []

    # Test 1: Advanced TTS direct
    results.append(await test_advanced_tts())

    # Test 2: TTS Manager with fallback
    results.append(await test_tts_manager())

    # Summary
    print("\n" + "=" * 60)
    print("TEST SUMMARY")
    print("=" * 60)

    test_names = ["Advanced TTS Direct", "TTS Manager with Fallback"]
    for number, (name, result) in enumerate(zip(test_names, results), start=1):
        status = "✅ PASS" if result else "❌ FAIL"
        print(f"{number}. {name}: {status}")

    passed = sum(results)
    total = len(results)

    print(f"\nOverall: {passed}/{total} tests passed")

    if passed >= 1:
        print("🎉 New TTS system is functional!")
        if passed == total:
            print("🌟 All components working perfectly!")
        else:
            print("⚠️ Some components failed, but system should still work")
    else:
        print("💥 All tests failed - check dependencies and installation")

    # The next-steps hints are printed regardless of the outcome.
    print("\n📝 Next steps:")
    print("1. Install missing dependencies: pip install transformers datasets")
    print("2. Run the main app: python app.py")
    print("3. Test via /health endpoint")
    print("4. Test generation via /generate endpoint or Gradio interface")

    return passed >= 1


if __name__ == "__main__":
    success = asyncio.run(main())
    # sys.exit instead of the site-provided exit(), which is not guaranteed
    # to exist (e.g. under ``python -S`` or in frozen executables).
    sys.exit(0 if success else 1)