Spaces:
Runtime error
Runtime error
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from rembg import remove | |
| from diffusers import StableDiffusionPipeline | |
| # ----------------------------------------------------------------------------- | |
| # Helper function to adjust image size to multiples of 8. | |
| # ----------------------------------------------------------------------------- | |
def adjust_size(w: int, h: int) -> tuple:
    """
    Round width and height down to multiples of 8 for Stable Diffusion.

    Each dimension is floored to the nearest multiple of 8 and then clamped
    to a minimum of 8, so inputs smaller than 8 pixels never yield a zero
    dimension (which is not a usable generation size).

    Parameters:
        w (int): Original width in pixels.
        h (int): Original height in pixels.
    Returns:
        tuple: ``(new_w, new_h)``, both multiples of 8 and at least 8.
    """
    # Floor division drops the remainder; max() guards the sub-8-pixel case.
    new_w = max(8, (w // 8) * 8)
    new_h = max(8, (h // 8) * 8)
    return new_w, new_h
| # ----------------------------------------------------------------------------- | |
| # Core processing function: | |
| # 1. Remove background from the uploaded image. | |
| # 2. Generate a new background image based on the text prompt. | |
| # 3. Composite the foreground onto the generated background. | |
| # ----------------------------------------------------------------------------- | |
def process_image(input_image: Image.Image, bg_prompt: str) -> Image.Image:
    """
    Replace the background of an uploaded image with a generated one.

    The subject is cut out with ``rembg.remove``, a new background is
    generated by the module-level Stable Diffusion pipeline ``pipe`` from
    the text prompt, and the cut-out is alpha-composited on top of it.

    Parameters:
        input_image (PIL.Image.Image): The uploaded image.
        bg_prompt (str): Text prompt describing the new background.
    Returns:
        PIL.Image.Image: The final composited RGBA image, sized to the
        generated background (the upload's size rounded by ``adjust_size``).
    Raises:
        ValueError: If ``input_image`` is None.
    """
    if input_image is None:
        raise ValueError("No image provided.")

    # Step 1: Remove the background from the input image.
    print("Removing background from the uploaded image...")
    foreground = remove(input_image).convert("RGBA")

    # Step 2: Adjust dimensions for background generation.
    orig_w, orig_h = foreground.size
    gen_w, gen_h = adjust_size(orig_w, orig_h)
    print(f"Original size: {orig_w}x{orig_h} | Adjusted size: {gen_w}x{gen_h}")

    # Step 3: Generate a new background using the provided text prompt.
    print("Generating new background using Stable Diffusion...")
    bg_output = pipe(
        bg_prompt,
        height=gen_h,
        width=gen_w,
        num_inference_steps=50,  # Adjust if needed.
        guidance_scale=7.5,  # Adjust for more/less prompt adherence.
    )
    # alpha_composite requires both images to be RGBA.
    background = bg_output.images[0].convert("RGBA")

    # Step 4: Ensure the foreground matches the background dimensions.
    # alpha_composite also requires identical sizes; they differ whenever the
    # upload's dimensions were not already multiples of 8.
    if foreground.size != background.size:
        print("Resizing foreground to match background dimensions...")
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        # under its current name and also exists in older Pillow releases.
        foreground = foreground.resize(background.size, Image.LANCZOS)

    # Step 5: Composite the images.
    print("Compositing images...")
    return Image.alpha_composite(background, foreground)
| # ----------------------------------------------------------------------------- | |
| # Load the Stable Diffusion pipeline from Hugging Face. | |
| # ----------------------------------------------------------------------------- | |
# Hugging Face model id of the checkpoint to load; change it to use another model.
MODEL_ID = "stabilityai/stable-diffusion-2"

# Half precision is selected only when a CUDA device is available;
# otherwise the weights are loaded in full precision.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print("Loading Stable Diffusion pipeline...")
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
)
# Move the pipeline onto the GPU when one is present.
if torch.cuda.is_available():
    pipe = pipe.to("cuda")
print("Stable Diffusion pipeline loaded.")
| # ----------------------------------------------------------------------------- | |
| # Create the Gradio Interface using the updated API. | |
| # ----------------------------------------------------------------------------- | |
# Text shown at the top of the Gradio page.
title = "Background Removal & Replacement"
description = (
    "Upload an image (e.g., a person or an animal) and provide a text prompt "
    "describing the new background. "
    "The app will remove the original background and "
    "composite the subject onto a generated background."
)

# One image input and one prompt input feed process_image; its return value
# is displayed as the output image. Flagging is turned off.
iface = gr.Interface(
    title=title,
    description=description,
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Your Image"),
        gr.Textbox(
            lines=2,
            placeholder="Describe the new background...",
            label="Background Prompt",
        ),
    ],
    outputs=gr.Image(label="Output Image"),
    allow_flagging="never",
)

# -----------------------------------------------------------------------------
# Launch the app only when this file is executed as a script.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    iface.launch()