Spaces:
Runtime error
Runtime error
| import torch | |
| from diffusers import DDIMScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline, AutoPipelineForImage2Image | |
| from src.eunms import Model_Type, Scheduler_Type | |
| from src.euler_scheduler import MyEulerAncestralDiscreteScheduler | |
| from src.lcm_scheduler import MyLCMScheduler | |
| from src.ddpm_scheduler import MyDDPMScheduler | |
| from src.sdxl_inversion_pipeline import SDXLDDIMPipeline | |
| from src.sd_inversion_pipeline import SDDDIMPipeline | |
def scheduler_type_to_class(scheduler_type):
    """Return the scheduler class implementing ``scheduler_type``.

    Supported members are ``Scheduler_Type.DDIM``, ``EULER``, ``LCM`` and
    ``DDPM``.

    Raises:
        ValueError: if ``scheduler_type`` matches none of the supported members.
    """
    # Checked in this order with ``==``, same as a plain if/elif chain.
    known_schedulers = (
        (Scheduler_Type.DDIM, DDIMScheduler),
        (Scheduler_Type.EULER, MyEulerAncestralDiscreteScheduler),
        (Scheduler_Type.LCM, MyLCMScheduler),
        (Scheduler_Type.DDPM, MyDDPMScheduler),
    )
    for kind, scheduler_cls in known_schedulers:
        if scheduler_type == kind:
            return scheduler_cls
    raise ValueError("Unknown scheduler type")
def model_type_to_class(model_type):
    """Return the ``(inference_pipeline_cls, inversion_pipeline_cls)`` pair for a model.

    SDXL-family models use ``SDXLDDIMPipeline`` for inversion; the SD 1.x/2.x
    family uses ``SDDDIMPipeline``. The inference class is the matching
    diffusers img2img pipeline (``AutoPipelineForImage2Image`` for SDXL Turbo
    and LCM SDXL).

    Raises:
        ValueError: if ``model_type`` is not a known ``Model_Type`` member.
    """
    sd_pair = (StableDiffusionImg2ImgPipeline, SDDDIMPipeline)
    pairing = (
        (Model_Type.SDXL, (StableDiffusionXLImg2ImgPipeline, SDXLDDIMPipeline)),
        (Model_Type.SDXL_Turbo, (AutoPipelineForImage2Image, SDXLDDIMPipeline)),
        (Model_Type.LCM_SDXL, (AutoPipelineForImage2Image, SDXLDDIMPipeline)),
        (Model_Type.SD15, sd_pair),
        (Model_Type.SD14, sd_pair),
        (Model_Type.SD21, sd_pair),
        (Model_Type.SD21_Turbo, sd_pair),
    )
    for kind, classes in pairing:
        if model_type == kind:
            return classes
    raise ValueError("Unknown model type")
def model_type_to_model_name(model_type):
    """Return the Hugging Face Hub repo id whose weights back ``model_type``.

    ``Model_Type.LCM_SDXL`` reuses the SDXL base checkpoint; the LCM
    behaviour comes from the ``latent-consistency/lcm-lora-sdxl`` LoRA that
    ``get_pipes`` loads on top of it.

    Raises:
        ValueError: if ``model_type`` is not a known ``Model_Type`` member.
    """
    if model_type == Model_Type.SDXL:
        return "stabilityai/stable-diffusion-xl-base-1.0"
    elif model_type == Model_Type.SDXL_Turbo:
        return "stabilityai/sdxl-turbo"
    elif model_type == Model_Type.LCM_SDXL:
        return "stabilityai/stable-diffusion-xl-base-1.0"
    elif model_type == Model_Type.SD15:
        # The original ``runwayml/stable-diffusion-v1-5`` repo was taken down
        # from the Hub, so loading it fails; this is the maintained mirror of
        # the same SD 1.5 weights (safetensors included).
        return "stable-diffusion-v1-5/stable-diffusion-v1-5"
    elif model_type == Model_Type.SD14:
        return "CompVis/stable-diffusion-v1-4"
    elif model_type == Model_Type.SD21:
        return "stabilityai/stable-diffusion-2-1"
    elif model_type == Model_Type.SD21_Turbo:
        return "stabilityai/sd-turbo"
    else:
        raise ValueError("Unknown model type")
def model_type_to_size(model_type):
    """Return the ``(height, width)`` working resolution used for ``model_type``.

    Raises:
        ValueError: if ``model_type`` is not a known ``Model_Type`` member.
    """
    square_512 = (512, 512)
    resolutions = (
        (Model_Type.SDXL, (1024, 1024)),
        (Model_Type.SDXL_Turbo, square_512),
        # TODO(review): unverified by the original author ("#TODO: check");
        # the SDXL base checkpoint this model loads is used at 1024 above.
        (Model_Type.LCM_SDXL, (768, 768)),
        (Model_Type.SD15, square_512),
        (Model_Type.SD14, square_512),
        (Model_Type.SD21, square_512),
        (Model_Type.SD21_Turbo, square_512),
    )
    for kind, size in resolutions:
        if model_type == kind:
            return size
    raise ValueError("Unknown model type")
def is_float16(model_type):
    """Return whether ``model_type``'s weights are loaded in half precision.

    True for the SDXL family (SDXL, SDXL Turbo, LCM SDXL); False for the
    SD 1.4 / 1.5 / 2.1 / 2.1 Turbo family.

    Raises:
        ValueError: if ``model_type`` is in neither family.
    """
    half_precision = (Model_Type.SDXL, Model_Type.SDXL_Turbo, Model_Type.LCM_SDXL)
    full_precision = (Model_Type.SD15, Model_Type.SD14, Model_Type.SD21, Model_Type.SD21_Turbo)
    if any(model_type == kind for kind in half_precision):
        return True
    if any(model_type == kind for kind in full_precision):
        return False
    raise ValueError("Unknown model type")
def is_sd(model_type):
    """Return whether ``model_type`` belongs to the SD 1.x/2.x (non-XL) family.

    False for SDXL, SDXL Turbo and LCM SDXL; True for SD 1.5, SD 1.4, SD 2.1
    and SD 2.1 Turbo.

    Raises:
        ValueError: if ``model_type`` is in neither family.
    """
    xl_family = (Model_Type.SDXL, Model_Type.SDXL_Turbo, Model_Type.LCM_SDXL)
    sd_family = (Model_Type.SD15, Model_Type.SD14, Model_Type.SD21, Model_Type.SD21_Turbo)
    if any(model_type == kind for kind in xl_family):
        return False
    if any(model_type == kind for kind in sd_family):
        return True
    raise ValueError("Unknown model type")
def _get_pipes(model_type, device):
    """Load the inversion and inference pipelines for ``model_type`` onto ``device``.

    Both pipelines come from the same Hub repo id
    (``model_type_to_model_name``) and are loaded with
    ``use_safetensors=True`` and ``safety_checker=None``. When
    ``is_float16(model_type)`` is true, the fp16 weight variant is requested
    with ``torch_dtype=torch.float16``.

    Returns:
        ``(pipe_inversion, pipe_inference)``, both already moved to ``device``.
    """
    repo_id = model_type_to_model_name(model_type)
    inference_cls, inversion_cls = model_type_to_class(model_type)

    load_kwargs = {"use_safetensors": True, "safety_checker": None}
    if is_float16(model_type):
        # Half-precision checkpoints: select the fp16 files and dtype.
        load_kwargs.update(torch_dtype=torch.float16, variant="fp16")

    # The inversion pipeline is loaded first, then the inference pipeline.
    pipe_inversion = inversion_cls.from_pretrained(repo_id, **load_kwargs).to(device)
    pipe_inference = inference_cls.from_pretrained(repo_id, **load_kwargs).to(device)
    return pipe_inversion, pipe_inference
def get_pipes(model_type, scheduler_type, device="cuda"):
    """Build the ``(pipe_inversion, pipe_inference)`` pair with a swapped-in scheduler.

    Each pipeline's scheduler is replaced by ``scheduler_type``'s class,
    rebuilt from that pipeline's existing scheduler config. The inversion
    pipeline also gets a ``scheduler_inference`` attribute built from the
    (already replaced) inference scheduler's config. For SD 1.x/2.x models,
    ``add_noise`` on all three schedulers returns the latents unchanged. For
    ``Model_Type.LCM_SDXL``, the LCM LoRA is loaded into both pipelines (not
    fused).

    Raises:
        ValueError: for an unknown scheduler or model type.
    """
    # Resolved before any pipeline is loaded, so a bad scheduler type fails
    # before the weights are fetched.
    scheduler_cls = scheduler_type_to_class(scheduler_type)
    pipe_inversion, pipe_inference = _get_pipes(model_type, device)

    # Order matters: scheduler_inference is built from the config of the
    # inference scheduler *after* it has been replaced on the line above it.
    pipe_inference.scheduler = scheduler_cls.from_config(pipe_inference.scheduler.config)
    pipe_inversion.scheduler = scheduler_cls.from_config(pipe_inversion.scheduler.config)
    pipe_inversion.scheduler_inference = scheduler_cls.from_config(pipe_inference.scheduler.config)

    if is_sd(model_type):
        def _keep_latents(init_latents, noise, timestep):
            # Identity stand-in for scheduler.add_noise: no noise is mixed in.
            # Presumably so img2img starts from the supplied latents as-is —
            # confirm against the pipeline call sites.
            return init_latents

        for sched in (pipe_inference.scheduler,
                      pipe_inversion.scheduler,
                      pipe_inversion.scheduler_inference):
            sched.add_noise = _keep_latents

    if model_type == Model_Type.LCM_SDXL:
        lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
        # The LoRA is loaded but deliberately left unfused (fuse_lora was
        # disabled in the original code).
        for pipe in (pipe_inversion, pipe_inference):
            pipe.load_lora_weights(lcm_lora_id)

    return pipe_inversion, pipe_inference