# Hugging Face Spaces page header (scrape residue): "Running on Zero"
| import inspect | |
| import warnings | |
| from typing import Callable, List, Optional, Union | |
| import numpy as np | |
| import torch | |
| from packaging import version | |
| from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection | |
| from ...configuration_utils import FrozenDict | |
| from ...image_processor import PipelineImageInput | |
| from ...loaders import IPAdapterMixin | |
| from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel | |
| from ...schedulers import KarrasDiffusionSchedulers | |
| from ...utils import deprecate, logging | |
| from ...utils.torch_utils import randn_tensor | |
| from ..pipeline_utils import DiffusionPipeline | |
| from . import StableDiffusionSafePipelineOutput | |
| from .safety_checker import SafeStableDiffusionSafetyChecker | |
# Module-level logger; the pipeline below uses it for truncation, safety-checker and NSFW warnings.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
class StableDiffusionPipelineSafe(DiffusionPipeline, IPAdapterMixin):
    r"""
    Pipeline based on the [`StableDiffusionPipeline`] for text-to-image generation using Safe Latent Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`SafeStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
        image_encoder ([`~transformers.CLIPVisionModelWithProjection`], *optional*):
            Vision encoder used to embed `ip_adapter_image` when an IP Adapter is loaded.
        requires_safety_checker (`bool`, *optional*, defaults to `True`):
            Whether a safety checker is expected. A warning is logged when this is `True` but `safety_checker` is
            `None`.
    """

    # Sub-model order consumed by DiffusionPipeline's model CPU offloading (defined in the base class).
    model_cpu_offload_seq = "text_encoder->unet->vae"
    # Components that may be passed as `None` when the pipeline is built or loaded.
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
| def __init__( | |
| self, | |
| vae: AutoencoderKL, | |
| text_encoder: CLIPTextModel, | |
| tokenizer: CLIPTokenizer, | |
| unet: UNet2DConditionModel, | |
| scheduler: KarrasDiffusionSchedulers, | |
| safety_checker: SafeStableDiffusionSafetyChecker, | |
| feature_extractor: CLIPImageProcessor, | |
| image_encoder: Optional[CLIPVisionModelWithProjection] = None, | |
| requires_safety_checker: bool = True, | |
| ): | |
| super().__init__() | |
| safety_concept: Optional[str] = ( | |
| "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity," | |
| " bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child" | |
| " abuse, brutality, cruelty" | |
| ) | |
| if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: | |
| deprecation_message = ( | |
| f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" | |
| f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " | |
| "to update the config accordingly as leaving `steps_offset` might led to incorrect results" | |
| " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," | |
| " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" | |
| " file" | |
| ) | |
| deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) | |
| new_config = dict(scheduler.config) | |
| new_config["steps_offset"] = 1 | |
| scheduler._internal_dict = FrozenDict(new_config) | |
| if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: | |
| deprecation_message = ( | |
| f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." | |
| " `clip_sample` should be set to False in the configuration file. Please make sure to update the" | |
| " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" | |
| " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" | |
| " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" | |
| ) | |
| deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) | |
| new_config = dict(scheduler.config) | |
| new_config["clip_sample"] = False | |
| scheduler._internal_dict = FrozenDict(new_config) | |
| if safety_checker is None and requires_safety_checker: | |
| logger.warning( | |
| f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" | |
| " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" | |
| " results in services or applications open to the public. Both the diffusers team and Hugging Face" | |
| " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" | |
| " it only for use-cases that involve analyzing network behavior or auditing its results. For more" | |
| " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." | |
| ) | |
| if safety_checker is not None and feature_extractor is None: | |
| raise ValueError( | |
| "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" | |
| " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." | |
| ) | |
| is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( | |
| version.parse(unet.config._diffusers_version).base_version | |
| ) < version.parse("0.9.0.dev0") | |
| is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 | |
| if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: | |
| deprecation_message = ( | |
| "The configuration file of the unet has set the default `sample_size` to smaller than" | |
| " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" | |
| " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" | |
| " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" | |
| " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" | |
| " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" | |
| " in the config might lead to incorrect results in future versions. If you have downloaded this" | |
| " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" | |
| " the `unet/config.json` file" | |
| ) | |
| deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) | |
| new_config = dict(unet.config) | |
| new_config["sample_size"] = 64 | |
| unet._internal_dict = FrozenDict(new_config) | |
| self.register_modules( | |
| vae=vae, | |
| text_encoder=text_encoder, | |
| tokenizer=tokenizer, | |
| unet=unet, | |
| scheduler=scheduler, | |
| safety_checker=safety_checker, | |
| feature_extractor=feature_extractor, | |
| image_encoder=image_encoder, | |
| ) | |
| self._safety_text_concept = safety_concept | |
| self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) | |
| self.register_to_config(requires_safety_checker=requires_safety_checker) | |
| def safety_concept(self): | |
| r""" | |
| Getter method for the safety concept used with SLD | |
| Returns: | |
| `str`: The text describing the safety concept | |
| """ | |
| return self._safety_text_concept | |
| def safety_concept(self, concept): | |
| r""" | |
| Setter method for the safety concept used with SLD | |
| Args: | |
| concept (`str`): | |
| The text of the new safety concept | |
| """ | |
| self._safety_text_concept = concept | |
| def _encode_prompt( | |
| self, | |
| prompt, | |
| device, | |
| num_images_per_prompt, | |
| do_classifier_free_guidance, | |
| negative_prompt, | |
| enable_safety_guidance, | |
| ): | |
| r""" | |
| Encodes the prompt into text encoder hidden states. | |
| Args: | |
| prompt (`str` or `List[str]`): | |
| prompt to be encoded | |
| device: (`torch.device`): | |
| torch device | |
| num_images_per_prompt (`int`): | |
| number of images that should be generated per prompt | |
| do_classifier_free_guidance (`bool`): | |
| whether to use classifier free guidance or not | |
| negative_prompt (`str` or `List[str]`): | |
| The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored | |
| if `guidance_scale` is less than `1`). | |
| """ | |
| batch_size = len(prompt) if isinstance(prompt, list) else 1 | |
| text_inputs = self.tokenizer( | |
| prompt, | |
| padding="max_length", | |
| max_length=self.tokenizer.model_max_length, | |
| truncation=True, | |
| return_tensors="pt", | |
| ) | |
| text_input_ids = text_inputs.input_ids | |
| untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids | |
| if not torch.equal(text_input_ids, untruncated_ids): | |
| removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) | |
| logger.warning( | |
| "The following part of your input was truncated because CLIP can only handle sequences up to" | |
| f" {self.tokenizer.model_max_length} tokens: {removed_text}" | |
| ) | |
| if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: | |
| attention_mask = text_inputs.attention_mask.to(device) | |
| else: | |
| attention_mask = None | |
| prompt_embeds = self.text_encoder( | |
| text_input_ids.to(device), | |
| attention_mask=attention_mask, | |
| ) | |
| prompt_embeds = prompt_embeds[0] | |
| # duplicate text embeddings for each generation per prompt, using mps friendly method | |
| bs_embed, seq_len, _ = prompt_embeds.shape | |
| prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) | |
| prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) | |
| # get unconditional embeddings for classifier free guidance | |
| if do_classifier_free_guidance: | |
| uncond_tokens: List[str] | |
| if negative_prompt is None: | |
| uncond_tokens = [""] * batch_size | |
| elif type(prompt) is not type(negative_prompt): | |
| raise TypeError( | |
| f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" | |
| f" {type(prompt)}." | |
| ) | |
| elif isinstance(negative_prompt, str): | |
| uncond_tokens = [negative_prompt] | |
| elif batch_size != len(negative_prompt): | |
| raise ValueError( | |
| f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" | |
| f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" | |
| " the batch size of `prompt`." | |
| ) | |
| else: | |
| uncond_tokens = negative_prompt | |
| max_length = text_input_ids.shape[-1] | |
| uncond_input = self.tokenizer( | |
| uncond_tokens, | |
| padding="max_length", | |
| max_length=max_length, | |
| truncation=True, | |
| return_tensors="pt", | |
| ) | |
| if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: | |
| attention_mask = uncond_input.attention_mask.to(device) | |
| else: | |
| attention_mask = None | |
| negative_prompt_embeds = self.text_encoder( | |
| uncond_input.input_ids.to(device), | |
| attention_mask=attention_mask, | |
| ) | |
| negative_prompt_embeds = negative_prompt_embeds[0] | |
| # duplicate unconditional embeddings for each generation per prompt, using mps friendly method | |
| seq_len = negative_prompt_embeds.shape[1] | |
| negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) | |
| negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) | |
| # Encode the safety concept text | |
| if enable_safety_guidance: | |
| safety_concept_input = self.tokenizer( | |
| [self._safety_text_concept], | |
| padding="max_length", | |
| max_length=max_length, | |
| truncation=True, | |
| return_tensors="pt", | |
| ) | |
| safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0] | |
| # duplicate safety embeddings for each generation per prompt, using mps friendly method | |
| seq_len = safety_embeddings.shape[1] | |
| safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1) | |
| safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) | |
| # For classifier free guidance + sld, we need to do three forward passes. | |
| # Here we concatenate the unconditional and text embeddings into a single batch | |
| # to avoid doing three forward passes | |
| prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, safety_embeddings]) | |
| else: | |
| # For classifier free guidance, we need to do two forward passes. | |
| # Here we concatenate the unconditional and text embeddings into a single batch | |
| # to avoid doing two forward passes | |
| prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) | |
| return prompt_embeds | |
| def run_safety_checker(self, image, device, dtype, enable_safety_guidance): | |
| if self.safety_checker is not None: | |
| images = image.copy() | |
| safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) | |
| image, has_nsfw_concept = self.safety_checker( | |
| images=image, clip_input=safety_checker_input.pixel_values.to(dtype) | |
| ) | |
| flagged_images = np.zeros((2, *image.shape[1:])) | |
| if any(has_nsfw_concept): | |
| logger.warning( | |
| "Potential NSFW content was detected in one or more images. A black image will be returned" | |
| " instead." | |
| f"{'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}" | |
| ) | |
| for idx, has_nsfw_concept in enumerate(has_nsfw_concept): | |
| if has_nsfw_concept: | |
| flagged_images[idx] = images[idx] | |
| image[idx] = np.zeros(image[idx].shape) # black image | |
| else: | |
| has_nsfw_concept = None | |
| flagged_images = None | |
| return image, has_nsfw_concept, flagged_images | |
| # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents | |
| def decode_latents(self, latents): | |
| deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" | |
| deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) | |
| latents = 1 / self.vae.config.scaling_factor * latents | |
| image = self.vae.decode(latents, return_dict=False)[0] | |
| image = (image / 2 + 0.5).clamp(0, 1) | |
| # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 | |
| image = image.cpu().permute(0, 2, 3, 1).float().numpy() | |
| return image | |
| # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs | |
| def prepare_extra_step_kwargs(self, generator, eta): | |
| # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | |
| # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | |
| # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | |
| # and should be between [0, 1] | |
| accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) | |
| extra_step_kwargs = {} | |
| if accepts_eta: | |
| extra_step_kwargs["eta"] = eta | |
| # check if the scheduler accepts generator | |
| accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) | |
| if accepts_generator: | |
| extra_step_kwargs["generator"] = generator | |
| return extra_step_kwargs | |
| # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs | |
| def check_inputs( | |
| self, | |
| prompt, | |
| height, | |
| width, | |
| callback_steps, | |
| negative_prompt=None, | |
| prompt_embeds=None, | |
| negative_prompt_embeds=None, | |
| callback_on_step_end_tensor_inputs=None, | |
| ): | |
| if height % 8 != 0 or width % 8 != 0: | |
| raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | |
| if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): | |
| raise ValueError( | |
| f"`callback_steps` has to be a positive integer but is {callback_steps} of type" | |
| f" {type(callback_steps)}." | |
| ) | |
| if callback_on_step_end_tensor_inputs is not None and not all( | |
| k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs | |
| ): | |
| raise ValueError( | |
| f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" | |
| ) | |
| if prompt is not None and prompt_embeds is not None: | |
| raise ValueError( | |
| f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" | |
| " only forward one of the two." | |
| ) | |
| elif prompt is None and prompt_embeds is None: | |
| raise ValueError( | |
| "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." | |
| ) | |
| elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): | |
| raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | |
| if negative_prompt is not None and negative_prompt_embeds is not None: | |
| raise ValueError( | |
| f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" | |
| f" {negative_prompt_embeds}. Please make sure to only forward one of the two." | |
| ) | |
| if prompt_embeds is not None and negative_prompt_embeds is not None: | |
| if prompt_embeds.shape != negative_prompt_embeds.shape: | |
| raise ValueError( | |
| "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" | |
| f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" | |
| f" {negative_prompt_embeds.shape}." | |
| ) | |
| # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents | |
| def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): | |
| shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) | |
| if isinstance(generator, list) and len(generator) != batch_size: | |
| raise ValueError( | |
| f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" | |
| f" size of {batch_size}. Make sure the batch size matches the length of the generators." | |
| ) | |
| if latents is None: | |
| latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) | |
| else: | |
| latents = latents.to(device) | |
| # scale the initial noise by the standard deviation required by the scheduler | |
| latents = latents * self.scheduler.init_noise_sigma | |
| return latents | |
| def perform_safety_guidance( | |
| self, | |
| enable_safety_guidance, | |
| safety_momentum, | |
| noise_guidance, | |
| noise_pred_out, | |
| i, | |
| sld_guidance_scale, | |
| sld_warmup_steps, | |
| sld_threshold, | |
| sld_momentum_scale, | |
| sld_mom_beta, | |
| ): | |
| # Perform SLD guidance | |
| if enable_safety_guidance: | |
| if safety_momentum is None: | |
| safety_momentum = torch.zeros_like(noise_guidance) | |
| noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1] | |
| noise_pred_safety_concept = noise_pred_out[2] | |
| # Equation 6 | |
| scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0) | |
| # Equation 6 | |
| safety_concept_scale = torch.where( | |
| (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale | |
| ) | |
| # Equation 4 | |
| noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale) | |
| # Equation 7 | |
| noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum | |
| # Equation 8 | |
| safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety | |
| if i >= sld_warmup_steps: # Warmup | |
| # Equation 3 | |
| noise_guidance = noise_guidance - noise_guidance_safety | |
| return noise_guidance, safety_momentum | |
| # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image | |
| def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): | |
| dtype = next(self.image_encoder.parameters()).dtype | |
| if not isinstance(image, torch.Tensor): | |
| image = self.feature_extractor(image, return_tensors="pt").pixel_values | |
| image = image.to(device=device, dtype=dtype) | |
| if output_hidden_states: | |
| image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] | |
| image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) | |
| uncond_image_enc_hidden_states = self.image_encoder( | |
| torch.zeros_like(image), output_hidden_states=True | |
| ).hidden_states[-2] | |
| uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( | |
| num_images_per_prompt, dim=0 | |
| ) | |
| return image_enc_hidden_states, uncond_image_enc_hidden_states | |
| else: | |
| image_embeds = self.image_encoder(image).image_embeds | |
| image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) | |
| uncond_image_embeds = torch.zeros_like(image_embeds) | |
| return image_embeds, uncond_image_embeds | |
| def __call__( | |
| self, | |
| prompt: Union[str, List[str]], | |
| height: Optional[int] = None, | |
| width: Optional[int] = None, | |
| num_inference_steps: int = 50, | |
| guidance_scale: float = 7.5, | |
| negative_prompt: Optional[Union[str, List[str]]] = None, | |
| num_images_per_prompt: Optional[int] = 1, | |
| eta: float = 0.0, | |
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | |
| latents: Optional[torch.FloatTensor] = None, | |
| ip_adapter_image: Optional[PipelineImageInput] = None, | |
| output_type: Optional[str] = "pil", | |
| return_dict: bool = True, | |
| callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | |
| callback_steps: int = 1, | |
| sld_guidance_scale: Optional[float] = 1000, | |
| sld_warmup_steps: Optional[int] = 10, | |
| sld_threshold: Optional[float] = 0.01, | |
| sld_momentum_scale: Optional[float] = 0.3, | |
| sld_mom_beta: Optional[float] = 0.4, | |
| ): | |
| r""" | |
| The call function to the pipeline for generation. | |
| Args: | |
| prompt (`str` or `List[str]`): | |
| The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. | |
| height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): | |
| The height in pixels of the generated image. | |
| width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): | |
| The width in pixels of the generated image. | |
| num_inference_steps (`int`, *optional*, defaults to 50): | |
| The number of denoising steps. More denoising steps usually lead to a higher quality image at the | |
| expense of slower inference. | |
| guidance_scale (`float`, *optional*, defaults to 7.5): | |
| A higher guidance scale value encourages the model to generate images closely linked to the text | |
| `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. | |
| negative_prompt (`str` or `List[str]`, *optional*): | |
| The prompt or prompts to guide what to not include in image generation. If not defined, you need to | |
| pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). | |
| num_images_per_prompt (`int`, *optional*, defaults to 1): | |
| The number of images to generate per prompt. | |
| eta (`float`, *optional*, defaults to 0.0): | |
| Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies | |
| to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. | |
| generator (`torch.Generator` or `List[torch.Generator]`, *optional*): | |
| A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make | |
| generation deterministic. | |
| latents (`torch.FloatTensor`, *optional*): | |
| Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image | |
| generation. Can be used to tweak the same generation with different prompts. If not provided, a latents | |
| tensor is generated by sampling using the supplied random `generator`. | |
| ip_adapter_image: (`PipelineImageInput`, *optional*): | |
| Optional image input to work with IP Adapters. | |
| output_type (`str`, *optional*, defaults to `"pil"`): | |
| The output format of the generated image. Choose between `PIL.Image` or `np.array`. | |
| return_dict (`bool`, *optional*, defaults to `True`): | |
| Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | |
| plain tuple. | |
| callback (`Callable`, *optional*): | |
| A function that calls every `callback_steps` steps during inference. The function is called with the | |
| following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. | |
| callback_steps (`int`, *optional*, defaults to 1): | |
| The frequency at which the `callback` function is called. If not specified, the callback is called at | |
| every step. | |
| sld_guidance_scale (`float`, *optional*, defaults to 1000): | |
| If `sld_guidance_scale < 1`, safety guidance is disabled. | |
| sld_warmup_steps (`int`, *optional*, defaults to 10): | |
| Number of warmup steps for safety guidance. SLD is only be applied for diffusion steps greater than | |
| `sld_warmup_steps`. | |
| sld_threshold (`float`, *optional*, defaults to 0.01): | |
| Threshold that separates the hyperplane between appropriate and inappropriate images. | |
| sld_momentum_scale (`float`, *optional*, defaults to 0.3): | |
| Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0, | |
| momentum is disabled. Momentum is built up during warmup for diffusion steps smaller than | |
| `sld_warmup_steps`. | |
| sld_mom_beta (`float`, *optional*, defaults to 0.4): | |
| Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous | |
| momentum is kept. Momentum is built up during warmup for diffusion steps smaller than | |
| `sld_warmup_steps`. | |
| Returns: | |
| [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: | |
| If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, | |
| otherwise a `tuple` is returned where the first element is a list with the generated images and the | |
| second element is a list of `bool`s indicating whether the corresponding generated image contains | |
| "not-safe-for-work" (nsfw) content. | |
| Examples: | |
| ```py | |
| import torch | |
| from diffusers import StableDiffusionPipelineSafe | |
| from diffusers.pipelines.stable_diffusion_safe import SafetyConfig | |
| pipeline = StableDiffusionPipelineSafe.from_pretrained( | |
| "AIML-TUDA/stable-diffusion-safe", torch_dtype=torch.float16 | |
| ).to("cuda") | |
| prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker" | |
| image = pipeline(prompt=prompt, **SafetyConfig.MEDIUM).images[0] | |
| ``` | |
| """ | |
| # 0. Default height and width to unet | |
| height = height or self.unet.config.sample_size * self.vae_scale_factor | |
| width = width or self.unet.config.sample_size * self.vae_scale_factor | |
| # 1. Check inputs. Raise error if not correct | |
| self.check_inputs(prompt, height, width, callback_steps) | |
| # 2. Define call parameters | |
| batch_size = 1 if isinstance(prompt, str) else len(prompt) | |
| device = self._execution_device | |
| # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | |
| # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | |
| # corresponds to doing no classifier free guidance. | |
| do_classifier_free_guidance = guidance_scale > 1.0 | |
| enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance | |
| if not enable_safety_guidance: | |
| warnings.warn("Safety checker disabled!") | |
| if ip_adapter_image is not None: | |
| output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True | |
| image_embeds, negative_image_embeds = self.encode_image( | |
| ip_adapter_image, device, num_images_per_prompt, output_hidden_state | |
| ) | |
| if do_classifier_free_guidance: | |
| if enable_safety_guidance: | |
| image_embeds = torch.cat([negative_image_embeds, image_embeds, image_embeds]) | |
| else: | |
| image_embeds = torch.cat([negative_image_embeds, image_embeds]) | |
| # 3. Encode input prompt | |
| prompt_embeds = self._encode_prompt( | |
| prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance | |
| ) | |
| # 4. Prepare timesteps | |
| self.scheduler.set_timesteps(num_inference_steps, device=device) | |
| timesteps = self.scheduler.timesteps | |
| # 5. Prepare latent variables | |
| num_channels_latents = self.unet.config.in_channels | |
| latents = self.prepare_latents( | |
| batch_size * num_images_per_prompt, | |
| num_channels_latents, | |
| height, | |
| width, | |
| prompt_embeds.dtype, | |
| device, | |
| generator, | |
| latents, | |
| ) | |
| # 6. Prepare extra step kwargs. | |
| extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | |
| # 6.1 Add image embeds for IP-Adapter | |
| added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None | |
| safety_momentum = None | |
| num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order | |
| with self.progress_bar(total=num_inference_steps) as progress_bar: | |
| for i, t in enumerate(timesteps): | |
| # expand the latents if we are doing classifier free guidance | |
| latent_model_input = ( | |
| torch.cat([latents] * (3 if enable_safety_guidance else 2)) | |
| if do_classifier_free_guidance | |
| else latents | |
| ) | |
| latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | |
| # predict the noise residual | |
| noise_pred = self.unet( | |
| latent_model_input, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs | |
| ).sample | |
| # perform guidance | |
| if do_classifier_free_guidance: | |
| noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2)) | |
| noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] | |
| # default classifier free guidance | |
| noise_guidance = noise_pred_text - noise_pred_uncond | |
| # Perform SLD guidance | |
| if enable_safety_guidance: | |
| if safety_momentum is None: | |
| safety_momentum = torch.zeros_like(noise_guidance) | |
| noise_pred_safety_concept = noise_pred_out[2] | |
| # Equation 6 | |
| scale = torch.clamp( | |
| torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0 | |
| ) | |
| # Equation 6 | |
| safety_concept_scale = torch.where( | |
| (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, | |
| torch.zeros_like(scale), | |
| scale, | |
| ) | |
| # Equation 4 | |
| noise_guidance_safety = torch.mul( | |
| (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale | |
| ) | |
| # Equation 7 | |
| noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum | |
| # Equation 8 | |
| safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety | |
| if i >= sld_warmup_steps: # Warmup | |
| # Equation 3 | |
| noise_guidance = noise_guidance - noise_guidance_safety | |
| noise_pred = noise_pred_uncond + guidance_scale * noise_guidance | |
| # compute the previous noisy sample x_t -> x_t-1 | |
| latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | |
| # call the callback, if provided | |
| if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | |
| progress_bar.update() | |
| if callback is not None and i % callback_steps == 0: | |
| step_idx = i // getattr(self.scheduler, "order", 1) | |
| callback(step_idx, t, latents) | |
| # 8. Post-processing | |
| image = self.decode_latents(latents) | |
| # 9. Run safety checker | |
| image, has_nsfw_concept, flagged_images = self.run_safety_checker( | |
| image, device, prompt_embeds.dtype, enable_safety_guidance | |
| ) | |
| # 10. Convert to PIL | |
| if output_type == "pil": | |
| image = self.numpy_to_pil(image) | |
| if flagged_images is not None: | |
| flagged_images = self.numpy_to_pil(flagged_images) | |
| if not return_dict: | |
| return ( | |
| image, | |
| has_nsfw_concept, | |
| self._safety_text_concept if enable_safety_guidance else None, | |
| flagged_images, | |
| ) | |
| return StableDiffusionSafePipelineOutput( | |
| images=image, | |
| nsfw_content_detected=has_nsfw_concept, | |
| applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None, | |
| unsafe_images=flagged_images, | |
| ) | |