import torch
import gradio as gr
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionInstructPix2PixPipeline,
    StableVideoDiffusionPipeline,
    WanPipeline,
)
from diffusers.utils import export_to_video, load_image
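
# Run on the GPU when one is available; half precision saves memory there,
# while CPU inference falls back to float32.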
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| dtype = torch.float16 if device == "cuda" else torch.float32 | |

def make_pipe(cls, model_id, **kwargs):
    pipe = cls.from_pretrained(model_id, torch_dtype=dtype, **kwargs)
    if device == "cuda":
        # Model CPU offload manages device placement itself, so no explicit
        # .to("cuda") call is needed (and none should follow it).
        pipe.enable_model_cpu_offload()
    return pipe
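
# Pipelines are created lazily on first use and cached in module-level globals,
# so switching tabs does not reload weights that are already in memory.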
TXT2IMG_PIPE = None
IMG2IMG_PIPE = None
TXT2VID_PIPE = None
IMG2VID_PIPE = None

def generate_image_from_text(prompt):
    global TXT2IMG_PIPE
    if TXT2IMG_PIPE is None:
        TXT2IMG_PIPE = make_pipe(
            StableDiffusionPipeline,
            "stabilityai/stable-diffusion-2-1-base",
        )
    return TXT2IMG_PIPE(prompt, num_inference_steps=20).images[0]

def generate_image_from_image_and_prompt(image, prompt):
    global IMG2IMG_PIPE
    if IMG2IMG_PIPE is None:
        IMG2IMG_PIPE = make_pipe(
            StableDiffusionInstructPix2PixPipeline,
            "timbrooks/instruct-pix2pix",
        )
    out = IMG2IMG_PIPE(prompt=prompt, image=image, num_inference_steps=8)
    return out.images[0]

def generate_video_from_text(prompt):
    global TXT2VID_PIPE
    if TXT2VID_PIPE is None:
        TXT2VID_PIPE = make_pipe(
            WanPipeline,
            "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        )
    # Wan generates clips of 4k + 1 frames, so request 13 rather than 12.
    frames = TXT2VID_PIPE(prompt=prompt, num_frames=13).frames[0]
    return export_to_video(frames, "/tmp/wan_video.mp4", fps=8)

def generate_video_from_image(image):
    global IMG2VID_PIPE
    if IMG2VID_PIPE is None:
        IMG2VID_PIPE = make_pipe(
            StableVideoDiffusionPipeline,
            "stabilityai/stable-video-diffusion-img2vid-xt",
            variant="fp16" if dtype == torch.float16 else None,
        )
    # Downscale the input and run SVD at the same resolution to keep memory use low.
    image = load_image(image).resize((512, 288))
    frames = IMG2VID_PIPE(image, width=512, height=288, num_inference_steps=16).frames[0]
    return export_to_video(frames, "/tmp/svd_video.mp4", fps=8)
with gr.Blocks() as demo:
    gr.Markdown("## Lightweight Any-to-Any AI Playground")

    with gr.Tab("Text → Image"):
        text_input = gr.Textbox(label="Prompt")
        image_output = gr.Image(label="Generated Image")
        generate_button = gr.Button("Generate")
        generate_button.click(generate_image_from_text, inputs=text_input, outputs=image_output)

    with gr.Tab("Image → Image"):
        # type="pil" hands the pipeline a PIL image instead of a raw NumPy array.
        input_image = gr.Image(label="Input Image", type="pil")
        prompt_input = gr.Textbox(label="Edit Prompt")
        output_image = gr.Image(label="Edited Image")
        edit_button = gr.Button("Generate")
        edit_button.click(generate_image_from_image_and_prompt, inputs=[input_image, prompt_input], outputs=output_image)

    with gr.Tab("Text → Video"):
        video_prompt = gr.Textbox(label="Prompt")
        video_output = gr.Video(label="Generated Video")
        video_button = gr.Button("Generate")
        video_button.click(generate_video_from_text, inputs=video_prompt, outputs=video_output)

    with gr.Tab("Image → Video"):
        anim_image = gr.Image(label="Input Image", type="pil")
        anim_video_output = gr.Video(label="Animated Video")
        anim_button = gr.Button("Animate")
        anim_button.click(generate_video_from_image, inputs=anim_image, outputs=anim_video_output)

demo.launch()