Spaces:
Paused
Paused
| import random | |
| import spaces | |
| import torch | |
| import gradio as gr | |
| from modeling.dmm_pipeline import StableDiffusionDMMPipeline | |
| from huggingface_hub import snapshot_download | |
ckpt_path = "ckpt"


def _load_pipeline(local_dir):
    """Fetch the MCG-NJU/DMM snapshot into *local_dir* and load it onto CUDA.

    The pipeline is built from the safetensors weights in float16 and then
    moved to the "cuda" device, so a CUDA-capable GPU is required.

    Returns:
        The loaded ``StableDiffusionDMMPipeline`` instance.
    """
    snapshot_download(repo_id="MCG-NJU/DMM", local_dir=local_dir)
    pipeline = StableDiffusionDMMPipeline.from_pretrained(
        local_dir,
        use_safetensors=True,
        torch_dtype=torch.float16,
    )
    pipeline.to("cuda")
    return pipeline


# Loaded once at import time; generate() reads this module-level pipeline.
pipe = _load_pipeline(ckpt_path)
@spaces.GPU
def generate(prompt: str,
             negative_prompt: str,
             model_id: int,
             seed: int = 1234,
             height: int = 512,
             width: int = 512,
             all: bool = True):
    """Render images for *prompt* with one, or every, sub-model of the DMM pipeline.

    Decorated with ``spaces.GPU`` so that a GPU is attached for the duration
    of the call when running on ZeroGPU hardware.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        model_id: Index of the sub-model to use; ignored when ``all`` is true.
            The UI dropdown (``type="index"``) may pass ``None`` when nothing
            is selected.
        seed: Seed for ``torch.Generator``. Every pipeline call gets a freshly
            seeded generator, so all sub-models start from the same seed.
        height: Output height in pixels.
        width: Output width in pixels.
        all: If true, render one image per sub-model
            (``pipe.unet.get_num_models()`` of them) and ignore ``model_id``.
            The name shadows the builtin but is kept for signature compatibility.

    Returns:
        A list holding ``images[0]`` of each pipeline call: one entry per
        sub-model when ``all`` is true, otherwise a single entry.

    Raises:
        gr.Error: If ``all`` is false and no ``model_id`` was selected.
    """
    # gr.Number delivers floats unless precision=0 is set; the pipeline's
    # size arguments and Generator.manual_seed expect ints.
    seed, height, width = int(seed), int(height), int(width)

    def _render(index: int):
        """Run one 25-step, guidance-7 generation with sub-model *index*."""
        return pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=25,
            guidance_scale=7,
            model_id=index,
            # Fresh, identically seeded generator per call for comparability.
            generator=torch.Generator().manual_seed(seed),
        ).images[0]

    if all:
        return [_render(i) for i in range(pipe.unet.get_num_models())]
    if model_id is None:
        # Surface a readable message in the UI instead of int(None) TypeError.
        raise gr.Error("Select a model in the dropdown or tick 'All'.")
    return [_render(int(model_id))]
# Dropdown labels for the sub-models bundled in the DMM checkpoint. The
# dropdown in main() uses type="index", so an entry's position in this list is
# the model_id that generate() receives; the numeric prefix is derived from
# that same position so label and index cannot drift apart.
candidates = [
    f"{index}. {label}"
    for index, label in enumerate((
        "[JuggernautReborn] realistic",
        "[MajicmixRealisticV7] realistic, Asia portrait",
        "[EpicRealismV5] realistic",
        "[RealisticVisionV5] realistic",
        "[MajicmixFantasyV3] animation",
        "[MinimalismV2] illustration",
        "[RealCartoon3dV17] cartoon 3d",
        "[AWPaintingV1.4] animation",
    ))
]
def main():
    """Build the DMM demo UI with Gradio Blocks and launch it.

    The left column holds the generation controls and the right side shows
    the results in a gallery. "Submit" calls ``generate`` with the widget
    values in the order of its positional parameters; the dice button
    replaces the seed with ``random.randint(0, 1000000)``.
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # DMM Demo
            The checkpoint is https://huggingface.co/MCG-NJU/DMM.
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Column():
                    # type="index" passes the selected position in
                    # `candidates` (used as model_id). Preselecting the first
                    # entry keeps the callback from receiving None.
                    model_id = gr.Dropdown(candidates, value=candidates[0],
                                           label="Model Index", type="index")
                    all_check = gr.Checkbox(label="All (ignore the selection above)")
                prompt = gr.Textbox("portrait photo of a girl, long golden hair, flowers, best quality", label="Prompt")
                negative_prompt = gr.Textbox("worst quality,low quality,normal quality,lowres,watermark,nsfw", label="Negative Prompt")
                with gr.Row():
                    seed = gr.Number(0, label="Seed", precision=0, scale=3)
                    update_seed_btn = gr.Button("🎲", scale=1)
                with gr.Row():
                    # precision=0 makes gr.Number return ints rather than floats.
                    height = gr.Number(768, step=8, precision=0, label="Height (suggest 512~768)")
                    width = gr.Number(512, step=8, precision=0, label="Width")
                submit_btn = gr.Button("Submit", variant="primary")
            output = gr.Gallery(label="images")
        submit_btn.click(generate,
                         inputs=[prompt, negative_prompt, model_id, seed, height, width, all_check],
                         outputs=[output])
        update_seed_btn.click(lambda: random.randint(0, 1000000),
                              outputs=[seed])
    # Launch after the Blocks context has closed (the documented pattern).
    demo.launch()


if __name__ == "__main__":
    main()