senorix-AI / src /streamlit_app.py
Muyumba's picture
src/streamlit_app.py
ec96645 verified
import streamlit as st
from huggingface_hub import hf_hub_download, set_access_token
import torch
import os
import tempfile
import soundfile as sf
st.set_page_config(page_title="🎵 Générateur de Chansons Local", layout="centered")
st.title("🎵 Générateur de Chansons (Local CPU, Hugging Face)")
# -----------------------------
# Configuration Hugging Face Hub
# -----------------------------
# Token Hugging Face (optionnel si le repo est public)
HF_TOKEN = st.secrets.get("HF_TOKEN", None)
if HF_TOKEN:
set_access_token(HF_TOKEN)
# Forcer le cache local dans un dossier où on a les droits
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
# -----------------------------
# Télécharger le modèle SongGeneration
# -----------------------------
@st.cache_resource
def load_song_model():
model_file = hf_hub_download(
repo_id="tencent/SongGeneration",
filename="ckpt/songgeneration_base_zh/model.pt"
)
model = torch.load(model_file, map_location="cpu")
model.eval()
return model
song_model = load_song_model()
# -----------------------------
# Interface utilisateur
# -----------------------------
description = st.text_area(
"Décrivez l'ambiance ou le thème de la chanson",
value="Une chanson nostalgique sur l’amour perdu, style pop moderne."
)
if st.button("🎛️ Générer la chanson"):
if not description.strip():
st.warning("Veuillez fournir une description.")
else:
st.info("Génération en cours… (CPU, cela peut prendre du temps)")
try:
# -----------------------------
# Génération de la chanson (exemple simplifié)
# -----------------------------
with torch.no_grad():
# ⚠️ Adapter selon l'API exacte du modèle SongGeneration
# Ici on suppose qu'il y a une méthode .generate(text) qui renvoie un array audio
audio = song_model.generate(description) # numpy array ou torch tensor
# -----------------------------
# Sauvegarder l'audio temporaire pour Streamlit
# -----------------------------
tmp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
if isinstance(audio, torch.Tensor):
audio = audio.cpu().numpy()
sf.write(tmp_wav.name, audio, 44100)
tmp_wav.close()
# -----------------------------
# Affichage et téléchargement
# -----------------------------
st.audio(tmp_wav.name)
with open(tmp_wav.name, "rb") as f:
st.download_button("⬇️ Télécharger la chanson", f, "generated_song.wav")
st.success("✅ Chanson générée avec succès !")
except Exception as e:
st.error("❌ Erreur lors de la génération de la chanson :")
st.exception(e)