Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Update src/streamlit_app.py
Browse files- src/streamlit_app.py +148 -38
    	
        src/streamlit_app.py
    CHANGED
    
    | @@ -1,40 +1,150 @@ | |
| 1 | 
            -
            import altair as alt
         | 
| 2 | 
            -
            import numpy as np
         | 
| 3 | 
            -
            import pandas as pd
         | 
| 4 | 
             
            import streamlit as st
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 5 |  | 
| 6 | 
            -
             | 
| 7 | 
            -
             | 
| 8 | 
            -
             | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
| 13 | 
            -
             | 
| 14 | 
            -
            """
         | 
| 15 | 
            -
             | 
| 16 | 
            -
            num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
         | 
| 17 | 
            -
            num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
         | 
| 18 | 
            -
             | 
| 19 | 
            -
            indices = np.linspace(0, 1, num_points)
         | 
| 20 | 
            -
            theta = 2 * np.pi * num_turns * indices
         | 
| 21 | 
            -
            radius = indices
         | 
| 22 | 
            -
             | 
| 23 | 
            -
            x = radius * np.cos(theta)
         | 
| 24 | 
            -
            y = radius * np.sin(theta)
         | 
| 25 | 
            -
             | 
| 26 | 
            -
            df = pd.DataFrame({
         | 
| 27 | 
            -
                "x": x,
         | 
| 28 | 
            -
                "y": y,
         | 
| 29 | 
            -
                "idx": indices,
         | 
| 30 | 
            -
                "rand": np.random.randn(num_points),
         | 
| 31 | 
            -
            })
         | 
| 32 | 
            -
             | 
| 33 | 
            -
            st.altair_chart(alt.Chart(df, height=700, width=700)
         | 
| 34 | 
            -
                .mark_point(filled=True)
         | 
| 35 | 
            -
                .encode(
         | 
| 36 | 
            -
                    x=alt.X("x", axis=None),
         | 
| 37 | 
            -
                    y=alt.Y("y", axis=None),
         | 
| 38 | 
            -
                    color=alt.Color("idx", legend=None, scale=alt.Scale()),
         | 
| 39 | 
            -
                    size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
         | 
| 40 | 
            -
                ))
         | 
|  | |
|  | |
|  | |
|  | |
| 1 | 
             
            import streamlit as st
         | 
| 2 | 
            +
            import io
         | 
| 3 | 
            +
            from PIL import Image
         | 
| 4 | 
            +
            import soundfile as sf
         | 
| 5 | 
            +
            import librosa
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            import torch # Importa torch
         | 
| 8 | 
            +
            import sys
         | 
| 9 | 
            +
            sys.setrecursionlimit(2000) # Aumentiamo il limite di ricorsione
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            # --- Configurazione del Dispositivo ---
         | 
| 12 | 
            +
            # Questo rileva automaticamente se MPS (GPU Apple Silicon) è disponibile
         | 
| 13 | 
            +
            # Per ora, useremo la CPU come fallback se MPS è problematico per Stable Audio
         | 
| 14 | 
            +
            device = "mps" if torch.backends.mps.is_available() else "cpu"
         | 
| 15 | 
            +
            # ******************** MODIFICA QUI: Forza device = "cpu" ********************
         | 
| 16 | 
            +
            # Per superare i problemi di Stable Audio su MPS con float16/float32
         | 
| 17 | 
            +
            # FORZA LA CPU PER TUTTI I MODELLI, per semplicità.
         | 
| 18 | 
            +
            # Se la caption genera velocemente, potremmo tornare indietro e mettere il modello vit_gpt2 su MPS
         | 
| 19 | 
            +
            device = "cpu"
         | 
| 20 | 
            +
            # **************************************************************************
         | 
| 21 | 
            +
            st.write(f"Utilizzo del dispositivo: {device}")
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            # --- 1. Caricamento dei Modelli AI (spostati qui, fuori dalle funzioni Streamlit) ---
         | 
| 25 | 
            +
            @st.cache_resource
         | 
| 26 | 
            +
            def load_models():
         | 
| 27 | 
            +
                # Caricamento del modello per la captioning (ViT-GPT2)
         | 
| 28 | 
            +
                from transformers import AutoFeatureExtractor, AutoTokenizer, AutoModelForVision2Seq
         | 
| 29 | 
            +
                st.write("Caricamento del modello ViT-GPT2 per la captioning dell'immagine...")
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                vit_gpt2_feature_extractor = AutoFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
         | 
| 32 | 
            +
                vit_gpt2_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
         | 
| 33 | 
            +
                
         | 
| 34 | 
            +
                # Questo modello andrà sulla CPU
         | 
| 35 | 
            +
                vit_gpt2_model = AutoModelForVision2Seq.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                st.write("Modello ViT-GPT2 caricato.")
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                # Caricamento del modello Text-to-Audio (Stable Audio Open - 1.0)
         | 
| 40 | 
            +
                from diffusers import DiffusionPipeline
         | 
| 41 | 
            +
                st.write("Caricamento del modello Stable Audio Open - 1.0 per la generazione del soundscape...")
         | 
| 42 | 
            +
                # ******************** MODIFICA QUI ********************
         | 
| 43 | 
            +
                # Assicurati che non ci sia torch_dtype=torch.float16 e che vada sulla CPU
         | 
| 44 | 
            +
                stable_audio_pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", force_download=True).to(device) 
         | 
| 45 | 
            +
                # ******************************************************
         | 
| 46 | 
            +
                st.write("Modello Stable Audio Open 1.0 caricato.")
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                return vit_gpt2_feature_extractor, vit_gpt2_model, vit_gpt2_tokenizer, stable_audio_pipeline
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            # Carica i modelli all'avvio dell'app
         | 
| 51 | 
            +
            vit_gpt2_feature_extractor, vit_gpt2_model, vit_gpt2_tokenizer, stable_audio_pipeline = load_models()
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            # --- 2. Funzioni della Pipeline ---
         | 
| 55 | 
            +
            def generate_image_caption(image_pil):
         | 
| 56 | 
            +
                pixel_values = vit_gpt2_feature_extractor(images=image_pil.convert("RGB"), return_tensors="pt").pixel_values
         | 
| 57 | 
            +
                pixel_values = pixel_values.to(device) # Sposta input su CPU
         | 
| 58 | 
            +
                
         | 
| 59 | 
            +
                # Token di inizio per GPT-2, assicurandosi che sia su CPU
         | 
| 60 | 
            +
                # Ottieni il decoder_start_token_id dal modello o dal tokenizer
         | 
| 61 | 
            +
                if hasattr(vit_gpt2_model.config, "decoder_start_token_id"):
         | 
| 62 | 
            +
                    decoder_start_token_id = vit_gpt2_model.config.decoder_start_token_id
         | 
| 63 | 
            +
                else:
         | 
| 64 | 
            +
                    if vit_gpt2_tokenizer.pad_token_id is not None:
         | 
| 65 | 
            +
                        decoder_start_token_id = vit_gpt2_tokenizer.pad_token_id
         | 
| 66 | 
            +
                    else:
         | 
| 67 | 
            +
                        decoder_start_token_id = 50256 # Default GPT-2 EOS token
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                # Crea un input_ids iniziale con il decoder_start_token_id e spostalo su CPU
         | 
| 70 | 
            +
                input_ids = torch.ones((1, 1), device=device, dtype=torch.long) * decoder_start_token_id
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
                output_ids = vit_gpt2_model.generate(
         | 
| 74 | 
            +
                    pixel_values=pixel_values,
         | 
| 75 | 
            +
                    input_ids=input_ids,
         | 
| 76 | 
            +
                    max_length=50,
         | 
| 77 | 
            +
                    do_sample=True,
         | 
| 78 | 
            +
                    top_k=50,
         | 
| 79 | 
            +
                    temperature=0.7,
         | 
| 80 | 
            +
                    no_repeat_ngram_size=2,
         | 
| 81 | 
            +
                    early_stopping=True
         | 
| 82 | 
            +
                )
         | 
| 83 | 
            +
                caption = vit_gpt2_tokenizer.decode(output_ids[0], skip_special_tokens=True)
         | 
| 84 | 
            +
                return caption
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            def generate_soundscape_from_caption(caption: str, duration_seconds: int = 10):
         | 
| 88 | 
            +
                st.write(f"Generazione soundscape per: '{caption}' (durata: {duration_seconds}s)")
         | 
| 89 | 
            +
                with st.spinner("Generazione audio in corso..."):
         | 
| 90 | 
            +
                    try:
         | 
| 91 | 
            +
                        # Assicurati che il modello sia già su CPU dal caricamento
         | 
| 92 | 
            +
                        audio_output = stable_audio_pipeline(
         | 
| 93 | 
            +
                            prompt=caption,
         | 
| 94 | 
            +
                            audio_end_in_s=duration_seconds 
         | 
| 95 | 
            +
                        ).audios
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                        audio_data = audio_output[0].cpu().numpy() 
         | 
| 98 | 
            +
                        sample_rate = stable_audio_pipeline.sample_rate
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                        audio_data = audio_data.astype(np.float32)
         | 
| 101 | 
            +
                        audio_data = librosa.util.normalize(audio_data)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                        buffer = io.BytesIO()
         | 
| 104 | 
            +
                        sf.write(buffer, audio_data, sample_rate, format='WAV')
         | 
| 105 | 
            +
                        buffer.seek(0)
         | 
| 106 | 
            +
                        return buffer.getvalue(), sample_rate
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    except Exception as e:
         | 
| 109 | 
            +
                        st.error(f"Errore durante la generazione dell'audio: {e}")
         | 
| 110 | 
            +
                        return None, None
         | 
| 111 | 
            +
             | 
| 112 | 
            +
             | 
| 113 | 
            +
            # --- 3. Interfaccia Streamlit ---
         | 
| 114 | 
            +
            st.title("Generatore di Paesaggi Sonori da Immagini")
         | 
| 115 | 
            +
            st.write("Carica un'immagine e otterrai una descrizione testuale e un paesaggio sonoro generato!")
         | 
| 116 | 
            +
             | 
| 117 | 
            +
            uploaded_file = st.file_uploader("Scegli un'immagine...", type=["jpg", "jpeg", "png"])
         | 
| 118 | 
            +
             | 
| 119 | 
            +
            if uploaded_file is not None:
         | 
| 120 | 
            +
                input_image = Image.open(uploaded_file)
         | 
| 121 | 
            +
                st.image(input_image, caption='Immagine Caricata.', use_column_width=True)
         | 
| 122 | 
            +
                st.write("")
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                audio_duration = st.slider("Durata audio (secondi):", 5, 30, 10, key="audio_duration_slider")
         | 
| 125 | 
            +
             | 
| 126 | 
            +
             | 
| 127 | 
            +
                if st.button("Genera Paesaggio Sonoro"):
         | 
| 128 | 
            +
                    st.subheader("Processo in Corso...")
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    # PASSO 1: Genera la caption
         | 
| 131 | 
            +
                    st.write("Generazione della caption...")
         | 
| 132 | 
            +
                    caption = generate_image_caption(input_image)
         | 
| 133 | 
            +
                    st.write(f"Caption generata: **{caption}**")
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    # PASSO 2: Genera il soundscape
         | 
| 136 | 
            +
                    st.write("Generazione del paesaggio sonoro...")
         | 
| 137 | 
            +
                    audio_data_bytes, sample_rate = generate_soundscape_from_caption(caption, duration_seconds=audio_duration)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    if audio_data_bytes is not None:
         | 
| 140 | 
            +
                        st.subheader("Paesaggio Sonoro Generato")
         | 
| 141 | 
            +
                        st.audio(audio_data_bytes, format='audio/wav', sample_rate=sample_rate)
         | 
| 142 |  | 
| 143 | 
            +
                        st.download_button(
         | 
| 144 | 
            +
                            label="Scarica Audio WAV",
         | 
| 145 | 
            +
                            data=audio_data_bytes,
         | 
| 146 | 
            +
                            file_name="paesaggio_sonoro_generato.wav",
         | 
| 147 | 
            +
                            mime="audio/wav"
         | 
| 148 | 
            +
                        )
         | 
| 149 | 
            +
                    else:
         | 
| 150 | 
            +
                        st.error("La generazione del paesaggio sonoro è fallita.")
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
