from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionPipeline,
)
from transformers import CLIPVisionModelWithProjection
import torch
from copy import deepcopy

ENABLE_CPU_CACHE = False
DEFAULT_BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"

cached_models = {}  # cache for models to avoid repeated loading, key is model name

# Decorator: when ENABLE_CPU_CACHE is on, memoize the wrapped loader in
# cached_models, keyed by function name plus its arguments.
def cache_model(func):
    def wrapper(*args, **kwargs):
        if ENABLE_CPU_CACHE:
            model_name = func.__name__ + str(args) + str(kwargs)
            if model_name not in cached_models:
                cached_models[model_name] = func(*args, **kwargs)
            return cached_models[model_name]
        else:
            return func(*args, **kwargs)
    return wrapper

# Same as cache_model, but returns a deepcopy of the cached object so callers
# can modify the returned model without mutating the cached copy.
def copied_cache_model(func):
    def wrapper(*args, **kwargs):
        if ENABLE_CPU_CACHE:
            model_name = func.__name__ + str(args) + str(kwargs)
            if model_name not in cached_models:
                cached_models[model_name] = func(*args, **kwargs)
            return deepcopy(cached_models[model_name])
        else:
            return func(*args, **kwargs)
    return wrapper
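
# Usage sketch (illustrative, not part of the original module): with
# ENABLE_CPU_CACHE = True, decorating a loader caches its first result; the
# name `load_my_controlnet` and the repo id below are hypothetical examples.
#
# @copied_cache_model
# def load_my_controlnet(path="lllyasviel/control_v11p_sd15_canny"):
#     return ControlNetModel.from_pretrained(path, torch_dtype=torch.float16)
#
# cn_a = load_my_controlnet()  # loads and caches
# cn_b = load_my_controlnet()  # served from the cache as an independent deepcopy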

def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs):
    if ckpt_or_pretrained.endswith(".safetensors"):
        pipe = model_cls.from_single_file(ckpt_or_pretrained, original_config_file=original_config_file, torch_dtype=torch_dtype, **kwargs)
    else:
        pipe = model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs)
    return pipe
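
# Example (illustrative): a local .safetensors checkpoint and a Hub repo id both
# resolve to the same pipeline class; "ckpt/my_model.safetensors" is a hypothetical path.
# pipe = model_from_ckpt_or_pretrained("ckpt/my_model.safetensors", StableDiffusionPipeline)
# pipe = model_from_ckpt_or_pretrained(DEFAULT_BASE_MODEL, StableDiffusionPipeline)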

# Load the base SD 1.5 pipeline once on CPU and return its components dict;
# load_common_sd15_pipe passes these modules back into from_pretrained so the
# final pipeline reuses them instead of loading the weights again.
def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16):
    model_kwargs = dict(
        torch_dtype=torch_dtype,
        requires_safety_checker=False,
        safety_checker=None,
    )
    pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        StableDiffusionPipeline,
        **model_kwargs,
    )
    pipe.to("cpu")
    return pipe.components

def load_controlnet(controlnet_path, torch_dtype=torch.float16):
    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype)
    return controlnet


# The SD 1.5 base pipeline ships without an image encoder; IP-Adapter needs this
# CLIP vision model to embed the reference image.
def load_image_encoder():
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter",
        subfolder="models/image_encoder",
        torch_dtype=torch.float16,
    )
    return image_encoder

def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="balanced", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs):
    model_kwargs = dict(
        torch_dtype=torch_dtype,
        # device_map=device,
        requires_safety_checker=False,
        safety_checker=None,
    )
    # Reuse the already-loaded base components so the new pipeline does not
    # reload the SD 1.5 weights.
    components = load_base_model_components(base_model=base_model, torch_dtype=torch_dtype)
    model_kwargs.update(components)
    model_kwargs.update(kwargs)

    if controlnet is not None:
        if isinstance(controlnet, list):
            controlnet = [load_controlnet(controlnet_path, torch_dtype=torch_dtype) for controlnet_path in controlnet]
        else:
            controlnet = load_controlnet(controlnet, torch_dtype=torch_dtype)
        model_kwargs.update(controlnet=controlnet)

    # Pick a pipeline class automatically unless the caller supplied one.
    if pipeline_class is None:
        if controlnet is not None:
            pipeline_class = StableDiffusionControlNetPipeline
        else:
            pipeline_class = StableDiffusionPipeline

    pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        pipeline_class,
        **model_kwargs,
    )

    if ip_adapter:
        image_encoder = load_image_encoder()
        pipe.image_encoder = image_encoder
        if plus_model:
            pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.safetensors")
        else:
            pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.safetensors")
        pipe.set_ip_adapter_scale(1.0)
    else:
        pipe.unload_ip_adapter()

    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    if model_cpu_offload_seq is None:
        if isinstance(pipe, StableDiffusionControlNetPipeline):
            pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae"
        elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline):
            pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae"
    else:
        pipe.model_cpu_offload_seq = model_cpu_offload_seq

    if enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload()
    else:
        pipe.enable_model_cpu_offload()

    if vae_slicing:
        pipe.enable_vae_slicing()

    import gc
    gc.collect()
    return pipe
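

# Minimal usage sketch (illustrative, not part of the original module). It assumes
# a CUDA GPU, `accelerate` installed, and network access to download the
# "stable-diffusion-v1-5/stable-diffusion-v1-5" weights; the prompt and output
# filename are placeholders.
if __name__ == "__main__":
    pipe = load_common_sd15_pipe(ip_adapter=False, vae_slicing=True)
    image = pipe("a photo of an astronaut riding a horse on mars", num_inference_steps=20).images[0]
    image.save("sample.png")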