import os

import gradio as gr
import laion_clap
from qdrant_client import QdrantClient

# Connection settings are read from the environment, with local defaults.
QDRANT_HOST = os.environ.get('QDRANT_HOST', 'localhost')
QDRANT_PORT = int(os.environ.get('QDRANT_PORT', 6333))

# Connect to Qdrant. The host string is passed as QdrantClient's first
# positional argument; the port is given explicitly.
client = QdrantClient(QDRANT_HOST, port=QDRANT_PORT)
print("[INFO] Client created...")

# Load the CLAP model used to embed text queries.
print("[INFO] Loading the model...")
# NOTE(review): model_name is never used. The CLAP_Module below loads the
# laion_clap package's default checkpoint, not this (Hugging Face-looking) id.
# Confirm which model the Qdrant collection was embedded with before wiring
# this name in.
model_name = "laion/larger_clap_music"
model = laion_clap.CLAP_Module(enable_fusion=False)
# No checkpoint path given: download/load the default pretrained weights.
model.load_ckpt()
# Gradio interface
# Number of result slots shown in the UI; also used as the Qdrant search limit.
max_results = 10


def sound_search(query):
    """Return audio-player updates for the sounds closest to a text query.

    The query is embedded with the CLAP text encoder and matched against the
    ``demo_db7`` Qdrant collection.

    Args:
        query: Free-text description typed by the user.

    Returns:
        A list of exactly ``max_results`` ``gr.Audio`` updates, one per output
        slot. Slots without a matching hit are cleared and given their index
        as a label, because Gradio rejects a handler that returns fewer values
        than it has output components.
    """
    # Blank query (e.g. the textbox was just cleared): skip the model and the
    # database round-trip and clear every result slot.
    if not query or not query.strip():
        return [gr.Audio(value=None, label=f"{i}") for i in range(max_results)]

    # get_text_embedding is fed a dummy second entry because it does not accept
    # a single-element batch; only the first row (the real query) is kept.
    text_embed = model.get_text_embedding([query, ''])[0]
    hits = client.search(
        collection_name="demo_db7",
        query_vector=text_embed,
        limit=max_results,
    )
    results = [
        gr.Audio(
            hit.payload['audio_path'],
            label=f"style: {hit.payload['style']} -- score: {hit.score}",
        )
        for hit in hits
    ]
    # The collection may return fewer than max_results hits. Pad so the number
    # of returned values always matches the number of output components.
    results.extend(
        gr.Audio(value=None, label=f"{i}") for i in range(len(results), max_results)
    )
    return results
with gr.Blocks() as demo:
    gr.Markdown("# Sound search database ")
    query_box = gr.Textbox(placeholder="What sound are you looking for ?")
    # One distinct Audio component per result slot. A list built with `*`
    # would repeat the same component object instead.
    result_players = [gr.Audio(label=str(slot)) for slot in range(max_results)]
    # Re-run the search every time the textbox content changes.
    query_box.change(fn=sound_search, inputs=query_box, outputs=result_players)

demo.launch()