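# EvoSDXL-JP loader: builds a Japanese SDXL pipeline by linearly merging the
# UNets of several SDXL variants (base, DPO-tuned, Juggernaut, Japanese SDXL)
# with SDXL-Lightning weights for few-step inference, then pairing the merged
# UNet with the Japanese SDXL text encoder and tokenizer.
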
import os
from typing import List, Dict, Union

from tqdm import tqdm
import torch
import safetensors.torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, CLIPTextModelWithProjection
from diffusers import (
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
    EulerDiscreteScheduler,
)
from diffusers.loaders import LoraLoaderMixin

| SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0" | |
| JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl" | |
| L_REPO = "ByteDance/SDXL-Lightning" | |
def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
    # Dispatch on file extension: safetensors files need their own loader,
    # everything else (e.g. .bin) goes through torch.load.
    file_extension = os.path.basename(checkpoint_file).split(".")[-1]
    if file_extension == "safetensors":
        return safetensors.torch.load_file(checkpoint_file, device=device)
    else:
        return torch.load(checkpoint_file, map_location=device)

def load_from_pretrained(
    repo_id,
    filename="diffusion_pytorch_model.fp16.safetensors",
    subfolder="unet",
    device="cuda",
) -> Dict[str, torch.Tensor]:
    return load_state_dict(
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
        ),
        device=device,
    )

def reshape_weight_task_tensors(task_tensors, weights):
    """
    Reshapes `weights` to match the shape of `task_tensors` by unsqueezing the remaining dimensions.

    Args:
        task_tensors (`torch.Tensor`): The tensors whose shape `weights` is reshaped to match.
        weights (`torch.Tensor`): The tensor to be reshaped.

    Returns:
        `torch.Tensor`: The reshaped tensor.
    """
    new_shape = weights.shape + (1,) * (task_tensors.dim() - weights.dim())
    weights = weights.view(new_shape)
    return weights

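# reshape_weight_task_tensors example: weights of shape (4,) against stacked
# task tensors of shape (4, 320, 4, 3, 3) are viewed as (4, 1, 1, 1, 1), so
# each scalar weight broadcasts over its model's tensor.
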
def linear(task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
    """
    Merge the task tensors with a weighted `linear` combination.

    Args:
        task_tensors (`List[torch.Tensor]`): The task tensors to merge.
        weights (`torch.Tensor`): The weights of the task tensors.

    Returns:
        `torch.Tensor`: The merged tensor.
    """
    task_tensors = torch.stack(task_tensors, dim=0)
    # Broadcast the per-model weights over each tensor and take the weighted sum.
    weights = reshape_weight_task_tensors(task_tensors, weights)
    weighted_task_tensors = task_tensors * weights
    mixed_task_tensors = weighted_task_tensors.sum(dim=0)
    return mixed_task_tensors

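# linear example: linear([torch.ones(2, 2), torch.zeros(2, 2)], torch.tensor([0.3, 0.7]))
# stacks to shape (2, 2, 2), reshapes the weights to (2, 1, 1), and returns a
# (2, 2) tensor filled with 0.3.
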
def merge_models(task_tensors, weights):
    keys = list(task_tensors[0].keys())
    weights = torch.tensor(weights, device=task_tensors[0][keys[0]].device)
    state_dict = {}
    for key in tqdm(keys, desc="Merging"):
        # Pop each tensor out of its source state dict so memory is released
        # as soon as the merged tensor is computed.
        w_list = [sd.pop(key) for sd in task_tensors]
        state_dict[key] = linear(task_tensors=w_list, weights=weights)
    return state_dict

def split_conv_attn(weights):
    # Attention projections vs. everything else (convs, norms, embeddings).
    attn_tensors = {}
    conv_tensors = {}
    for key in list(weights.keys()):
        if any(k in key for k in ["to_k", "to_q", "to_v", "to_out.0"]):
            attn_tensors[key] = weights.pop(key)
        else:
            conv_tensors[key] = weights.pop(key)
    return {"conv": conv_tensors, "attn": attn_tensors}

def load_evosdxl_jp(device="cuda") -> StableDiffusionXLPipeline:
    # Load the UNet state dicts of the four source models, each split into
    # attention and non-attention parameter groups.
    sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
    dpo_weights = split_conv_attn(
        load_from_pretrained(
            "mhdang/dpo-sdxl-text2image-v1",
            "diffusion_pytorch_model.safetensors",
            device=device,
        )
    )
    jn_weights = split_conv_attn(
        load_from_pretrained("RunDiffusion/Juggernaut-XL-v9", device=device)
    )
    jsdxl_weights = split_conv_attn(load_from_pretrained(JSDXL_REPO, device=device))
    tensors = [sdxl_weights, dpo_weights, jn_weights, jsdxl_weights]
    # First merge round: combine the four models, with separate coefficients
    # for the non-attention ("conv") and attention parameter groups.
    new_conv = merge_models(
        [sd["conv"] for sd in tensors],
        [
            0.15928833971605916,
            0.1032449268871776,
            0.6503217149752791,
            0.08714501842148402,
        ],
    )
    new_attn = merge_models(
        [sd["attn"] for sd in tensors],
        [
            0.1877279276437178,
            0.20014114603909822,
            0.3922685507065275,
            0.2198623756106564,
        ],
    )
    del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
    torch.cuda.empty_cache()
    unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
    unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
    unet.load_state_dict({**new_conv, **new_attn})
    # Fuse the SDXL-Lightning 4-step LoRA into the merged UNet.
    state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
        L_REPO, weight_name="sdxl_lightning_4step_lora.safetensors"
    )
    LoraLoaderMixin.load_lora_into_unet(state_dict, network_alphas, unet)
    unet.fuse_lora(lora_scale=3.224682864579401)
    new_weights = split_conv_attn(unet.state_dict())
    l_weights = split_conv_attn(
        load_from_pretrained(
            L_REPO,
            "sdxl_lightning_4step_unet.safetensors",
            subfolder=None,
            device=device,
        )
    )
    jnl_weights = split_conv_attn(
        load_from_pretrained(
            "RunDiffusion/Juggernaut-XL-Lightning",
            "diffusion_pytorch_model.bin",
            device=device,
        )
    )
    tensors = [l_weights, jnl_weights, new_weights]
    # Second merge round: blend the two Lightning UNets with the LoRA-fused
    # merge from above, then load the result into a fresh UNet.
    new_conv = merge_models(
        [sd["conv"] for sd in tensors],
        [0.47222002022088533, 0.48419531030361584, 0.04358466947549889],
    )
    new_attn = merge_models(
        [sd["attn"] for sd in tensors],
        [0.023119324530758375, 0.04924981616469831, 0.9276308593045434],
    )
    new_weights = {**new_conv, **new_attn}
    unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
    unet.load_state_dict(new_weights)
    # Pair the merged UNet with the Japanese SDXL text encoder and tokenizer
    # so the pipeline accepts Japanese prompts.
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        JSDXL_REPO, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        JSDXL_REPO, subfolder="tokenizer", use_fast=False
    )
    pipe = StableDiffusionXLPipeline.from_pretrained(
        SDXL_REPO,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    # Ensure the sampler uses "trailing" timesteps, as SDXL-Lightning expects.
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )
    pipe = pipe.to(device, dtype=torch.float16)
    return pipe

| if __name__ == "__main__": | |
| pipe: StableDiffusionXLPipeline = load_evosdxl_jp() | |
| images = pipe("犬", num_inference_steps=4, guidance_scale=0).images | |
| images[0].save("out.png") | |