import spaces
import os
import torch
import PIL.Image  # import the Image submodule explicitly so the PIL.Image.* calls below resolve
import gradio as gr
from typing import Optional
from accelerate import Accelerator
from diffusers import (
    AutoencoderKL,
    StableDiffusionXLControlNetPipeline,
    ControlNetModel,
    UNet2DConditionModel,
)
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
)
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, snapshot_download
# ========== Initialization ==========

# Ensure required directories exist
os.makedirs("sdxl_light_caption_output", exist_ok=True)

# Download controlnet model snapshot
snapshot_download(
    repo_id='nickpai/sdxl_light_caption_output',
    local_dir='sdxl_light_caption_output'
)

# Device and precision setup
accelerator = Accelerator(mixed_precision="fp16")
weight_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32
device = accelerator.device
print(f"[INFO] Accelerator device: {device}")
# ========== Models ==========

# Pretrained paths
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
safetensors_ckpt = "sdxl_lightning_8step_unet.safetensors"
controlnet_path = "sdxl_light_caption_output/checkpoint-30000/controlnet"

# Load diffusion components
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae")
unet = UNet2DConditionModel.from_config(base_model_path, subfolder="unet")
unet.load_state_dict(load_file(hf_hub_download("ByteDance/SDXL-Lightning", safetensors_ckpt)))
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=weight_dtype)

pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path, vae=vae, unet=unet, controlnet=controlnet
)
pipe.to(device, dtype=weight_dtype)
pipe.safety_checker = None
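# Note: SDXL-Lightning checkpoints are distilled for an Euler sampler with "trailing"
# timestep spacing. A minimal sketch of the commonly recommended scheduler setup is left
# here as a comment (an assumption, not part of the original configuration):
#
#   from diffusers import EulerDiscreteScheduler
#   pipe.scheduler = EulerDiscreteScheduler.from_config(
#       pipe.scheduler.config, timestep_spacing="trailing"
#   )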
# Load BLIP captioning model
caption_model_name = "blip-image-captioning-large"
processor = BlipProcessor.from_pretrained(f"Salesforce/{caption_model_name}")
caption_model = BlipForConditionalGeneration.from_pretrained(
    f"Salesforce/{caption_model_name}", torch_dtype=weight_dtype
).to(device)
# ========== Utility Functions ==========

def apply_color(image: PIL.Image.Image, color_map: PIL.Image.Image) -> PIL.Image.Image:
    """Transfer the chroma (A/B) channels of `color_map` onto the luminance (L) of `image`."""
    # Convert both images to LAB color space
    image_lab = image.convert('LAB')
    color_map_lab = color_map.convert('LAB')

    # Keep luminance from the original image, take the A/B channels from the color map
    l, _, _ = image_lab.split()
    _, a_map, b_map = color_map_lab.split()
    merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map))
    return merged_lab.convert('RGB')
def remove_unlikely_words(prompt: str) -> str:
    """Removes predefined unlikely phrases from prompt text."""
    unlikely_words = []

    # Year phrases such as "1950s", "1950", "year 1950", "circa 1950",
    # plus space-separated variants like "1 9 5 0 s" that the captioner sometimes emits
    a1 = [f'{i}s' for i in range(1900, 2000)]
    a2 = [f'{i}' for i in range(1900, 2000)]
    a3 = [f'year {i}' for i in range(1900, 2000)]
    a4 = [f'circa {i}' for i in range(1900, 2000)]
    b1 = [f"{y[0]} {y[1]} {y[2]} {y[3]} s" for y in a1]
    b2 = [f"{y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]
    b3 = [f"year {y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]
    b4 = [f"circa {y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]

    manual = [  # manually curated phrases describing monochrome, archival, or low-quality photos
        "black and white,", "black and white", "black & white,", "black & white", "circa",
        "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
        "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
        "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
        "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
        "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
        "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
        "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
        "black-and-white photo,", "black-and-white photo", "black - and - white photography",
        "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
        "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
        "black - and - white photograph,", "black - and - white photograph", "black on white,",
        "black on white", "black-and-white", "historical image,", "historical picture,",
        "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
        "taken in the late", "taken in the", "historic photograph,", "restored,", "restored",
        "historical photo", "historical setting,",
        "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
        "taken in", "shot on leica", "shot on leica sl2", "sl2",
        "taken with a leica camera", "leica sl2", "leica", "setting",
        "overcast day", "overcast weather", "slight overcast", "overcast",
        "picture taken in", "photo taken in",
        ", photo", ", photo", ", photo", ", photo", ", photograph",
        ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
    ]
    unlikely_words.extend(a1 + a2 + a3 + a4 + b1 + b2 + b3 + b4 + manual)

    for word in unlikely_words:
        prompt = prompt.replace(word, "")
    return prompt
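# Illustrative example (assumed input, not taken from the original code): a BLIP caption such as
#   "a photography of a city street, black and white, 1950s"
# has "1950s" and "black and white," removed, leaving roughly
#   "a photography of a city street,  " (surplus commas/whitespace are left as-is).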
def get_image_paths(folder_path: str) -> list:
    """Collect example image paths as single-item lists, the format gr.Interface examples expect."""
    return [[os.path.join(folder_path, f)] for f in os.listdir(folder_path)
            if f.lower().endswith((".jpg", ".png"))]
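# ZeroGPU Spaces ("Running on Zero") normally need GPU-bound functions wrapped with the
# `spaces.GPU` decorator so a GPU is attached for the call; adding it here is an assumption
# based on the otherwise unused `import spaces` above.
@spaces.GPU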
def process_image(image_path: str,
                  positive_prompt: Optional[str],
                  negative_prompt: Optional[str],
                  seed: int) -> tuple[PIL.Image.Image, str]:
    """Colorize a grayscale or low-color image using automatic captioning and text-guided diffusion.

    This function performs image-to-image generation using a ControlNet model and Stable Diffusion XL,
    guided by a text caption extracted from the image itself using a BLIP captioning model. Optional
    prompts (positive and negative) can further influence the output style or content.

    Process Overview:
        1. The input image is loaded and resized to 512x512 for inference.
        2. A BLIP model generates a caption describing the image content.
        3. The caption is cleaned using a filtering function to remove misleading or unwanted terms.
        4. A prompt is constructed by combining the user-provided positive prompt with the caption.
        5. A ControlNet-guided image is generated using the SDXL pipeline.
        6. The output image's color channels (A and B in LAB space) are applied to the original
           luminance (L) of the control image to preserve structure while transferring color.
        7. The image is resized back to the original resolution and returned.

    Args:
        image_path: Path to the grayscale or lightly colored input image (JPEG/PNG).
        positive_prompt: Additional descriptive text to enhance or guide the generation.
        negative_prompt: Words or phrases to avoid during generation (e.g., "blurry", "monochrome").
        seed: Random seed for reproducible generation.

    Returns:
        A tuple containing:
            - A colorized PIL image based on the input and generated caption.
            - The cleaned caption string used to guide the generation.
    """
    torch.manual_seed(seed)
    image = PIL.Image.open(image_path)
    original_size = image.size
    control_image = image.convert("L").convert("RGB").resize((512, 512))

    # Image captioning
    input_text = "a photography of"
    inputs = processor(image, input_text, return_tensors="pt").to(device, dtype=weight_dtype)
    caption_ids = caption_model.generate(**inputs)
    caption = processor.decode(caption_ids[0], skip_special_tokens=True)
    caption = remove_unlikely_words(caption)

    # Inference
    final_prompt = [f"{positive_prompt}, {caption}"]
    result = pipe(prompt=final_prompt,
                  negative_prompt=negative_prompt,
                  num_inference_steps=8,
                  generator=torch.manual_seed(seed),
                  image=control_image)
    colorized = apply_color(control_image, result.images[0]).resize(original_size)
    return colorized, caption
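# Usage sketch (hypothetical prompt values; the image path matches the default used in the Gradio UI below):
#   colorized, caption = process_image(
#       "example/legacy_images/Hollywood-Sign.jpg",
#       positive_prompt="vibrant colors, sunny day",
#       negative_prompt="low quality, monochrome",
#       seed=123,
#   )
#   colorized.save("Hollywood-Sign-colorized.jpg")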
# ========== Gradio UI ==========

def create_interface():
    examples = get_image_paths("example/legacy_images")

    return gr.Interface(
        fn=process_image,
        inputs=[
            gr.Image(label="Upload Image", type='filepath',
                     value="example/legacy_images/Hollywood-Sign.jpg"),
            gr.Textbox(label="Positive Prompt", placeholder="Enter details to enhance the caption"),
            gr.Textbox(label="Negative Prompt",
                       value="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate"),
        ],
        outputs=[
            gr.Image(label="Colorized Image", format="jpeg",
                     value="example/UUColor_results/Hollywood-Sign.jpeg"),
            gr.Textbox(label="Caption", show_copy_button=True)
        ],
        examples=examples,
        additional_inputs=[gr.Slider(0, 1000, 123, label="Seed")],
        title="Text-Guided Image Colorization",
        description="Upload a grayscale image and generate a color version guided by automatic captioning.",
        cache_examples=False
    )
def main():
    interface = create_interface()
    interface.launch(ssr_mode=False, mcp_server=True)


if __name__ == "__main__":
    main()
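# To run locally (assuming the usual dependencies -- gradio, diffusers, transformers, accelerate,
# safetensors, huggingface_hub -- and a CUDA GPU are available), save this file as e.g. app.py
# and launch it with:  python app.py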