import io
import os
from pathlib import Path
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline
from dotenv import load_dotenv
from huggingface_hub import login
from PIL import Image
from transformers import pipeline
# Load environment variables
load_dotenv()
hf_token = os.getenv("HF_TKN")
if hf_token:
    login(token=hf_token)
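# Example .env contents (the token value below is a placeholder, not a real
# credential):
#   HF_TKN=hf_xxxxxxxxxxxxxxxxxxxx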
# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
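# float16 roughly halves GPU memory use during inference; float16 support on
# CPU is limited in PyTorch, so we fall back to float32 there.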
# Load models
def load_models():
    """Load both models with proper device placement."""
    caption_pipe = pipeline(
        "image-to-text",
        model="nlpconnect/vit-gpt2-image-captioning",
        device=device,
    )
    audio_pipe = DiffusionPipeline.from_pretrained(
        "cvssp/audioldm2",
        token=hf_token,
        torch_dtype=torch_dtype,
    )
    return caption_pipe, audio_pipe

caption_pipe, audio_pipe = load_models()
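# Both pipelines are instantiated once at import time, so model downloads and
# initialization happen at app startup rather than on the first user request.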
def analyze_image(image_bytes: bytes) -> Tuple[str, bool]:
    """Generate a caption from image bytes; returns (message, is_error)."""
    try:
        image = Image.open(io.BytesIO(image_bytes))
        if image.mode != "RGB":
            image = image.convert("RGB")
        results = caption_pipe(image)
        if not results or not isinstance(results, list):
            return "Error: Invalid response from caption model", True
        caption = results[0].get("generated_text", "").strip()
        return caption or "No caption generated", not bool(caption)
    except Exception as e:
        return f"Image processing error: {str(e)}", True
def generate_audio(caption: str) -> Optional[Tuple[int, np.ndarray]]:
    """Generate audio from a caption; returns (sample_rate, waveform) or None."""
    # Record where the pipeline currently lives so we can restore it afterwards.
    # DiffusionPipeline is not an nn.Module, so use its .device property rather
    # than next(pipe.parameters()). Assigning before the try block also ensures
    # the name is always defined when the finally clause runs.
    original_device = audio_pipe.device
    try:
        audio_pipe.to(device)
        audio = audio_pipe(
            prompt=caption,
            num_inference_steps=50,
            guidance_scale=7.5,
            audio_length_in_s=5.0,  # keep audio generation short
        ).audios[0]
        # Post-processing
        audio = audio.squeeze()        # collapse to a mono waveform
        audio = np.clip(audio, -1, 1)  # keep samples in the valid [-1, 1] range
        return (16000, audio)  # AudioLDM2 produces 16 kHz audio
    except Exception as e:
        print(f"Audio generation error: {str(e)}")
        return None
    finally:
        # Move the pipeline back and release any cached GPU memory
        audio_pipe.to(original_device)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
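# Sketch of saving a generated clip to disk (assumes scipy is installed; the
# prompt and filename here are arbitrary examples):
#   from scipy.io import wavfile
#   result = generate_audio("rain falling on a tin roof")
#   if result is not None:
#       rate, waveform = result
#       wavfile.write("output.wav", rate, waveform.astype(np.float32))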
# UI components
css = """
#col-container {
    max-width: 800px;
    margin: 0 auto;
}
.disclaimer {
    font-size: 0.9em;
    color: #666;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
            <h1 style="text-align: center;">🎶 Image to Sound Effect Generator</h1>
            <p style="text-align: center;">
                ⚡ Powered by <a href="https://bilsimaging.com" target="_blank">Bilsimaging</a>
            </p>
        """)
        with gr.Row():
            image_input = gr.Image(type="filepath", label="Upload Image")
            caption_output = gr.Textbox(label="Generated Description", interactive=False)
        with gr.Row():
            generate_btn = gr.Button("Generate Description", variant="primary")
            audio_output = gr.Audio(label="Generated Sound", interactive=False)
            sound_btn = gr.Button("Generate Sound", variant="secondary")
        gr.Examples(
            examples=[str(Path(__file__).parent / "examples" / f) for f in ["storm.jpg", "city.jpg"]],
            inputs=image_input,
            outputs=[caption_output, audio_output],
            fn=lambda x: (analyze_image(Path(x).read_bytes())[0], None),
            cache_examples=True,
        )
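        # Note: with cache_examples=True, Gradio runs fn on each example at
        # startup, so the files under examples/ must exist when the app launches.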
| gr.Markdown("### 🛠️ Usage Tips") | |
| gr.Markdown(""" | |
| - Use clear, high-contrast images for best results | |
| - Complex scenes may require multiple generations | |
| - Keep sound generation under 10 seconds for quick results | |
| """) | |
| gr.Markdown("### ⚠️ Disclaimer", elem_classes="disclaimer") | |
| gr.Markdown(""" | |
| Generated content may not always be accurate. Use at your own discretion. | |
| [Privacy Policy](https://bilsimaging.com/privacy) | | |
| [Terms of Service](https://bilsimaging.com/terms) | |
| """) | |
    # Event handling; guard against a missing upload so Path(None) never raises
    generate_btn.click(
        fn=lambda x: analyze_image(Path(x).read_bytes())[0] if x else "Please upload an image first",
        inputs=image_input,
        outputs=caption_output,
        api_name="describe",
    )
    sound_btn.click(
        fn=generate_audio,
        inputs=caption_output,
        outputs=audio_output,
        api_name="generate_sound",
    )
    # Clear stale outputs whenever a new image is selected
    image_input.change(
        fn=lambda: [gr.update(value=""), gr.update(value=None)],
        outputs=[caption_output, audio_output],
    )
if __name__ == "__main__":
    # Bind to all interfaces when running inside a Hugging Face Space
    demo.launch(server_name="0.0.0.0" if os.getenv("SPACE_ID") else "127.0.0.1")
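# To run locally (assuming this file is saved as app.py and the imports above
# are installed): `python app.py`, then open http://127.0.0.1:7860 in a
# browser (Gradio's default port).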