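# Conversion script for the DeepFloyd IF checkpoints to the diffusers format.
# It builds the stage 1 text-to-image pipeline (IFPipeline) and the stage 2/3
# super-resolution pipelines (IFSuperResolutionPipeline) from the original unet
# configs/checkpoints, and assembles the safety checker from the provided
# p_head/w_head numpy weight files.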
import argparse
import inspect
import os

import numpy as np
import torch
import yaml
from torch.nn import functional as F
from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer

from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet2DConditionModel
from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", required=False, default=None, type=str)
    parser.add_argument("--dump_path_stage_2", required=False, default=None, type=str)
    parser.add_argument("--dump_path_stage_3", required=False, default=None, type=str)
    parser.add_argument("--unet_config", required=False, default=None, type=str, help="Path to unet config file")
    parser.add_argument(
        "--unet_checkpoint_path", required=False, default=None, type=str, help="Path to unet checkpoint file"
    )
    parser.add_argument(
        "--unet_checkpoint_path_stage_2",
        required=False,
        default=None,
        type=str,
        help="Path to stage 2 unet checkpoint file",
    )
    parser.add_argument(
        "--unet_checkpoint_path_stage_3",
        required=False,
        default=None,
        type=str,
        help="Path to stage 3 unet checkpoint file",
    )
    parser.add_argument("--p_head_path", type=str, required=True)
    parser.add_argument("--w_head_path", type=str, required=True)

    args = parser.parse_args()
    return args
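
# Example invocation (sketch only -- the script name and all paths below are
# illustrative placeholders, not the actual locations of the DeepFloyd IF weights):
#
#   python convert_if.py \
#       --p_head_path ./safety/p_head.npz \
#       --w_head_path ./safety/w_head.npz \
#       --unet_config ./IF-I-XL/config.yml \
#       --unet_checkpoint_path ./IF-I-XL/pytorch_model.bin \
#       --dump_path ./IF-I-XL-diffusers \
#       --unet_checkpoint_path_stage_2 ./IF-II-L \
#       --dump_path_stage_2 ./IF-II-L-diffusers
#
# Stage 1 takes an explicit config file plus checkpoint file, while the stage 2/3
# arguments point at a checkpoint directory containing config.yml and pytorch_model.bin
# (see get_super_res_unet below).
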

def main(args):
    tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
    text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")

    feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
    safety_checker = convert_safety_checker(p_head_path=args.p_head_path, w_head_path=args.w_head_path)

    if args.unet_config is not None and args.unet_checkpoint_path is not None and args.dump_path is not None:
        convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args)

    if args.unet_checkpoint_path_stage_2 is not None and args.dump_path_stage_2 is not None:
        convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=2)

    if args.unet_checkpoint_path_stage_3 is not None and args.dump_path_stage_3 is not None:
        convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=3)


def convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args):
    unet = get_stage_1_unet(args.unet_config, args.unet_checkpoint_path)

    scheduler = DDPMScheduler(
        variance_type="learned_range",
        beta_schedule="squaredcos_cap_v2",
        prediction_type="epsilon",
        thresholding=True,
        dynamic_thresholding_ratio=0.95,
        sample_max_value=1.5,
    )

    pipe = IFPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=True,
    )

    pipe.save_pretrained(args.dump_path)
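
# Note on the schedulers: both the stage 1 and super-resolution pipelines use a cosine
# beta schedule with learned-range variance and dynamic thresholding (the Imagen-style
# clipping of the predicted sample). The only difference is the clipping value -- 1.5
# for the stage 1 scheduler above versus 1.0 for the super-resolution scheduler below;
# these values come from the pipeline construction in this script, not from the original
# config files.
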

def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage):
    if stage == 2:
        unet_checkpoint_path = args.unet_checkpoint_path_stage_2
        sample_size = None
        dump_path = args.dump_path_stage_2
    elif stage == 3:
        unet_checkpoint_path = args.unet_checkpoint_path_stage_3
        sample_size = 1024
        dump_path = args.dump_path_stage_3
    else:
        raise ValueError(f"unknown stage: {stage}")

    unet = get_super_res_unet(unet_checkpoint_path, verify_params=False, sample_size=sample_size)

    image_noising_scheduler = DDPMScheduler(
        beta_schedule="squaredcos_cap_v2",
    )

    scheduler = DDPMScheduler(
        variance_type="learned_range",
        beta_schedule="squaredcos_cap_v2",
        prediction_type="epsilon",
        thresholding=True,
        dynamic_thresholding_ratio=0.95,
        sample_max_value=1.0,
    )

    pipe = IFSuperResolutionPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        image_noising_scheduler=image_noising_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=True,
    )

    pipe.save_pretrained(dump_path)
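
# Stages 2 and 3 share the same conversion path and pipeline class; the only stage 3
# specific handling is the hardcoded sample_size=1024 above, because (per the comment in
# get_super_res_unet) the second upscaler's sample size is not correctly reflected in its
# original config file.
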

def get_stage_1_unet(unet_config, unet_checkpoint_path):
    # unet_config is a path to the original yaml config file
    with open(unet_config, "r") as f:
        original_unet_config = yaml.safe_load(f)

    original_unet_config = original_unet_config["params"]

    unet_diffusers_config = create_unet_diffusers_config(original_unet_config)

    unet = UNet2DConditionModel(**unet_diffusers_config)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    unet_checkpoint = torch.load(unet_checkpoint_path, map_location=device)

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
    )

    unet.load_state_dict(converted_unet_checkpoint)

    return unet


def convert_safety_checker(p_head_path, w_head_path):
    state_dict = {}

    # p head
    p_head = np.load(p_head_path)

    p_head_weights = p_head["weights"]
    p_head_weights = torch.from_numpy(p_head_weights)
    p_head_weights = p_head_weights.unsqueeze(0)

    p_head_biases = p_head["biases"]
    p_head_biases = torch.from_numpy(p_head_biases)
    p_head_biases = p_head_biases.unsqueeze(0)

    state_dict["p_head.weight"] = p_head_weights
    state_dict["p_head.bias"] = p_head_biases

    # w head
    w_head = np.load(w_head_path)

    w_head_weights = w_head["weights"]
    w_head_weights = torch.from_numpy(w_head_weights)
    w_head_weights = w_head_weights.unsqueeze(0)

    w_head_biases = w_head["biases"]
    w_head_biases = torch.from_numpy(w_head_biases)
    w_head_biases = w_head_biases.unsqueeze(0)

    state_dict["w_head.weight"] = w_head_weights
    state_dict["w_head.bias"] = w_head_biases

    # vision model
    vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    vision_model_state_dict = vision_model.state_dict()

    for key, value in vision_model_state_dict.items():
        key = f"vision_model.{key}"
        state_dict[key] = value

    # full model
    config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14")
    safety_checker = IFSafetyChecker(config)
    safety_checker.load_state_dict(state_dict)

    return safety_checker
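
# The p_head/w_head files are expected to be numpy archives with "weights" and "biases"
# arrays for a single linear classification head each (presumably an NSFW head and a
# watermark head, going by the names). The unsqueeze(0) calls above reshape them into
# the (out_features, in_features) / (out_features,) layout of IFSafetyChecker's linear
# heads; the exact array shapes are an assumption based only on how they are loaded here.
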

def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
    attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
    attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]

    channel_mult = parse_list(original_unet_config["channel_mult"])
    block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]

    down_block_types = []
    resolution = 1

    for i in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnDownBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetDownsampleBlock2D"
        else:
            block_type = "DownBlock2D"

        down_block_types.append(block_type)

        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnUpBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetUpsampleBlock2D"
        else:
            block_type = "UpBlock2D"

        up_block_types.append(block_type)
        resolution //= 2

    head_dim = original_unet_config["num_head_channels"]

    use_linear_projection = (
        original_unet_config["use_linear_in_transformer"]
        if "use_linear_in_transformer" in original_unet_config
        else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim = [5, 10, 20, 20]

    projection_class_embeddings_input_dim = None

    if class_embed_type is None:
        if "num_classes" in original_unet_config:
            if original_unet_config["num_classes"] == "sequential":
                class_embed_type = "projection"
                assert "adm_in_channels" in original_unet_config
                projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
            else:
                raise NotImplementedError(
                    f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
                )

    config = {
        "sample_size": original_unet_config["image_size"],
        "in_channels": original_unet_config["in_channels"],
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": original_unet_config["num_res_blocks"],
        "cross_attention_dim": original_unet_config["encoder_channels"],
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "out_channels": original_unet_config["out_channels"],
        "up_block_types": tuple(up_block_types),
        "upcast_attention": False,  # TODO: guessing
        "cross_attention_norm": "group_norm",
        "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
        "addition_embed_type": "text",
        "act_fn": "gelu",
    }

    if original_unet_config["use_scale_shift_norm"]:
        config["resnet_time_scale_shift"] = "scale_shift"

    if "encoder_dim" in original_unet_config:
        config["encoder_hid_dim"] = original_unet_config["encoder_dim"]

    return config
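
# Worked example of the block-type selection above (values are illustrative, not taken
# from an actual IF config): with image_size=64, attention_resolutions="32,16,8" and
# channel_mult="1,2,3,4", the attention resolutions become downsampling factors [2, 4, 8],
# so the four down blocks at factors 1, 2, 4, 8 come out as
# ("ResnetDownsampleBlock2D" or "DownBlock2D" depending on resblock_updown,
#  "SimpleCrossAttnDownBlock2D", "SimpleCrossAttnDownBlock2D", "SimpleCrossAttnDownBlock2D").
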

def convert_ldm_unet_checkpoint(unet_state_dict, config, path=None):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] in [None, "identity"]:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)

        # TODO need better check than i in [4, 8, 12, 16]
        block_type = config["down_block_types"][block_id]
        if (block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D") and i in [
            4,
            8,
            12,
            16,
        ]:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
        else:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}

        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            old_path = f"input_blocks.{i}.1"
            new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"

            assign_attention_to_checkpoint(
                new_checkpoint=new_checkpoint,
                unet_state_dict=unet_state_dict,
                old_path=old_path,
                new_path=new_path,
                config=config,
            )

            paths = renew_attention_paths(attentions)
            meta_path = {"old": old_path, "new": new_path}
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                unet_state_dict,
                additional_replacements=[meta_path],
                config=config,
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    old_path = "middle_block.1"
    new_path = "mid_block.attentions.0"

    assign_attention_to_checkpoint(
        new_checkpoint=new_checkpoint,
        unet_state_dict=unet_state_dict,
        old_path=old_path,
        new_path=new_path,
        config=config,
    )

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        # len(output_block_list) == 1 -> resnet
        # len(output_block_list) == 2 -> resnet, attention
        # len(output_block_list) == 3 -> resnet, attention, upscale resnet

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}

            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                old_path = f"output_blocks.{i}.1"
                new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"

                assign_attention_to_checkpoint(
                    new_checkpoint=new_checkpoint,
                    unet_state_dict=unet_state_dict,
                    old_path=old_path,
                    new_path=new_path,
                    config=config,
                )

                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": old_path,
                    "new": new_path,
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )

            if len(output_block_list) == 3:
                resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]
                paths = renew_resnet_paths(resnets)
                meta_path = {"old": f"output_blocks.{i}.2", "new": f"up_blocks.{block_id}.upsamplers.0"}
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
                new_checkpoint[new_path] = unet_state_dict[old_path]

    if "encoder_proj.weight" in unet_state_dict:
        new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict.pop("encoder_proj.weight")
        new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict.pop("encoder_proj.bias")

    if "encoder_pooling.0.weight" in unet_state_dict:
        new_checkpoint["add_embedding.norm1.weight"] = unet_state_dict.pop("encoder_pooling.0.weight")
        new_checkpoint["add_embedding.norm1.bias"] = unet_state_dict.pop("encoder_pooling.0.bias")

        new_checkpoint["add_embedding.pool.positional_embedding"] = unet_state_dict.pop(
            "encoder_pooling.1.positional_embedding"
        )
        new_checkpoint["add_embedding.pool.k_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.k_proj.weight")
        new_checkpoint["add_embedding.pool.k_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.k_proj.bias")
        new_checkpoint["add_embedding.pool.q_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.q_proj.weight")
        new_checkpoint["add_embedding.pool.q_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.q_proj.bias")
        new_checkpoint["add_embedding.pool.v_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.v_proj.weight")
        new_checkpoint["add_embedding.pool.v_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.v_proj.bias")

        new_checkpoint["add_embedding.proj.weight"] = unet_state_dict.pop("encoder_pooling.2.weight")
        new_checkpoint["add_embedding.proj.bias"] = unet_state_dict.pop("encoder_pooling.2.bias")

        new_checkpoint["add_embedding.norm2.weight"] = unet_state_dict.pop("encoder_pooling.3.weight")
        new_checkpoint["add_embedding.norm2.bias"] = unet_state_dict.pop("encoder_pooling.3.bias")

    return new_checkpoint

def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])
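
# For example, shave_segments("output_blocks.3.0.in_layers.0.weight", 2) drops the two
# leading segments and returns "0.in_layers.0.weight", while a negative value keeps the
# prefix and drops segments from the end instead.
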

def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping
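
# Example of the local renaming: a key such as "input_blocks.1.0.in_layers.2.weight"
# becomes "input_blocks.1.0.conv1.weight"; the surrounding "input_blocks.1.0" prefix is
# then rewritten globally by assign_to_checkpoint via the meta_path replacements.
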

def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        if "qkv" in new_item:
            continue

        if "encoder_kv" in new_item:
            continue

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")

        new_item = new_item.replace("norm_encoder.weight", "norm_cross.weight")
        new_item = new_item.replace("norm_encoder.bias", "norm_cross.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def assign_attention_to_checkpoint(new_checkpoint, unet_state_dict, old_path, new_path, config):
    qkv_weight = unet_state_dict.pop(f"{old_path}.qkv.weight")
    qkv_weight = qkv_weight[:, :, 0]

    qkv_bias = unet_state_dict.pop(f"{old_path}.qkv.bias")

    is_cross_attn_only = "only_cross_attention" in config and config["only_cross_attention"]

    split = 1 if is_cross_attn_only else 3

    weights, bias = split_attentions(
        weight=qkv_weight,
        bias=qkv_bias,
        split=split,
        chunk_size=config["attention_head_dim"],
    )

    if is_cross_attn_only:
        query_weight, q_bias = weights, bias
        new_checkpoint[f"{new_path}.to_q.weight"] = query_weight[0]
        new_checkpoint[f"{new_path}.to_q.bias"] = q_bias[0]
    else:
        [query_weight, key_weight, value_weight], [q_bias, k_bias, v_bias] = weights, bias
        new_checkpoint[f"{new_path}.to_q.weight"] = query_weight
        new_checkpoint[f"{new_path}.to_q.bias"] = q_bias
        new_checkpoint[f"{new_path}.to_k.weight"] = key_weight
        new_checkpoint[f"{new_path}.to_k.bias"] = k_bias
        new_checkpoint[f"{new_path}.to_v.weight"] = value_weight
        new_checkpoint[f"{new_path}.to_v.bias"] = v_bias

    encoder_kv_weight = unet_state_dict.pop(f"{old_path}.encoder_kv.weight")
    encoder_kv_weight = encoder_kv_weight[:, :, 0]

    encoder_kv_bias = unet_state_dict.pop(f"{old_path}.encoder_kv.bias")

    [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions(
        weight=encoder_kv_weight,
        bias=encoder_kv_bias,
        split=2,
        chunk_size=config["attention_head_dim"],
    )

    new_checkpoint[f"{new_path}.add_k_proj.weight"] = encoder_k_weight
    new_checkpoint[f"{new_path}.add_k_proj.bias"] = encoder_k_bias
    new_checkpoint[f"{new_path}.add_v_proj.weight"] = encoder_v_weight
    new_checkpoint[f"{new_path}.add_v_proj.bias"] = encoder_v_bias
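
# The original attention blocks store the fused self-attention projection as a 1D conv
# whose weight has a trailing kernel dimension; slicing [:, :, 0] above turns it into the
# (out_features, in_features) linear layout diffusers expects, and split_attentions then
# separates it into the to_q/to_k/to_v projections. encoder_kv is handled the same way
# but only splits into the added key/value projections used for cross attention.
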

def assign_to_checkpoint(paths, checkpoint, old_checkpoint, additional_replacements=None, config=None):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    for path in paths:
        new_path = path["new"]

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path or "to_out.0.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
def split_attentions(*, weight, bias, split, chunk_size):
    weights = [None] * split
    biases = [None] * split

    weights_biases_idx = 0

    for starting_row_index in range(0, weight.shape[0], chunk_size):
        row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)

        weight_rows = weight[row_indices, :]
        bias_rows = bias[row_indices]

        if weights[weights_biases_idx] is None:
            weights[weights_biases_idx] = weight_rows
            biases[weights_biases_idx] = bias_rows
        else:
            assert weights[weights_biases_idx] is not None
            weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
            biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])

        weights_biases_idx = (weights_biases_idx + 1) % split

    return weights, biases
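
# split_attentions assumes the fused projection interleaves its outputs per attention
# head. For example, with split=3 and chunk_size=64 (i.e. attention_head_dim=64), rows
# 0-63 of the fused weight go to q, rows 64-127 to k, rows 128-191 to v, rows 192-255 are
# appended to q again, and so on round-robin until all rows are distributed.
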

def parse_list(value):
    if isinstance(value, str):
        value = value.split(",")
        value = [int(v) for v in value]
    elif isinstance(value, list):
        pass
    else:
        raise ValueError(f"Can't parse list for type: {type(value)}")

    return value
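
# e.g. parse_list("1,2,4,4") -> [1, 2, 4, 4]; a value that is already a list is returned
# unchanged. The string branch matters because the original configs can store fields like
# channel_mult as comma-separated strings (get_super_res_unet splits on "," directly).
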

# below is copy and pasted from original convert_if_stage_2.py script


def get_super_res_unet(unet_checkpoint_path, verify_params=True, sample_size=None):
    orig_path = unet_checkpoint_path

    original_unet_config_path = os.path.join(orig_path, "config.yml")
    with open(original_unet_config_path, "r") as f:
        original_unet_config = yaml.safe_load(f)

    original_unet_config = original_unet_config["params"]

    unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config)
    unet_diffusers_config["time_embedding_dim"] = original_unet_config["model_channels"] * int(
        original_unet_config["channel_mult"].split(",")[-1]
    )
    if original_unet_config["encoder_dim"] != original_unet_config["encoder_channels"]:
        unet_diffusers_config["encoder_hid_dim"] = original_unet_config["encoder_dim"]

    unet_diffusers_config["class_embed_type"] = "timestep"
    unet_diffusers_config["addition_embed_type"] = "text"

    unet_diffusers_config["time_embedding_act_fn"] = "gelu"
    unet_diffusers_config["resnet_skip_time_act"] = True
    unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071
    unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071
    unet_diffusers_config["only_cross_attention"] = (
        bool(original_unet_config["disable_self_attentions"])
        if (
            "disable_self_attentions" in original_unet_config
            and isinstance(original_unet_config["disable_self_attentions"], int)
        )
        else True
    )

    if sample_size is None:
        unet_diffusers_config["sample_size"] = original_unet_config["image_size"]
    else:
        # The second upscaler unet's sample size is incorrectly specified
        # in the config and is instead hardcoded in source
        unet_diffusers_config["sample_size"] = sample_size

    unet_checkpoint = torch.load(os.path.join(unet_checkpoint_path, "pytorch_model.bin"), map_location="cpu")

    if verify_params:
        # check that architecture matches - is a bit slow
        verify_param_count(orig_path, unet_diffusers_config)

    converted_unet_checkpoint = superres_convert_ldm_unet_checkpoint(
        unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
    )
    converted_keys = converted_unet_checkpoint.keys()

    model = UNet2DConditionModel(**unet_diffusers_config)

    expected_weights = model.state_dict().keys()

    diff_c_e = set(converted_keys) - set(expected_weights)
    diff_e_c = set(expected_weights) - set(converted_keys)

    assert len(diff_e_c) == 0, f"Expected, but not converted: {diff_e_c}"
    assert len(diff_c_e) == 0, f"Converted, but not expected: {diff_c_e}"

    model.load_state_dict(converted_unet_checkpoint)

    return model
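
# Unlike stage 1, the super-resolution stages are loaded from a checkpoint directory that
# is expected to contain config.yml and pytorch_model.bin; the directory path doubles as
# orig_path for the optional parameter-count verification against the original
# deepfloyd_if modules.
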

def superres_create_unet_diffusers_config(original_unet_config):
    attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
    attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]

    channel_mult = parse_list(original_unet_config["channel_mult"])
    block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]

    down_block_types = []
    resolution = 1

    for i in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnDownBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetDownsampleBlock2D"
        else:
            block_type = "DownBlock2D"

        down_block_types.append(block_type)

        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnUpBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetUpsampleBlock2D"
        else:
            block_type = "UpBlock2D"

        up_block_types.append(block_type)
        resolution //= 2

    head_dim = original_unet_config["num_head_channels"]

    use_linear_projection = (
        original_unet_config["use_linear_in_transformer"]
        if "use_linear_in_transformer" in original_unet_config
        else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim = [5, 10, 20, 20]

    class_embed_type = None
    projection_class_embeddings_input_dim = None

    if "num_classes" in original_unet_config:
        if original_unet_config["num_classes"] == "sequential":
            class_embed_type = "projection"
            assert "adm_in_channels" in original_unet_config
            projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
        else:
            raise NotImplementedError(
                f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
            )

    config = {
        "in_channels": original_unet_config["in_channels"],
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": tuple(original_unet_config["num_res_blocks"]),
        "cross_attention_dim": original_unet_config["encoder_channels"],
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "out_channels": original_unet_config["out_channels"],
        "up_block_types": tuple(up_block_types),
        "upcast_attention": False,  # TODO: guessing
        "cross_attention_norm": "group_norm",
        "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
        "act_fn": "gelu",
    }

    if original_unet_config["use_scale_shift_norm"]:
        config["resnet_time_scale_shift"] = "scale_shift"

    return config

def superres_convert_ldm_unet_checkpoint(unet_state_dict, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["aug_proj.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["aug_proj.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["aug_proj.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["aug_proj.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    if "encoder_proj.weight" in unet_state_dict:
        new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict["encoder_proj.weight"]
        new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict["encoder_proj.bias"]

    if "encoder_pooling.0.weight" in unet_state_dict:
        mapping = {
            "encoder_pooling.0": "add_embedding.norm1",
            "encoder_pooling.1": "add_embedding.pool",
            "encoder_pooling.2": "add_embedding.proj",
            "encoder_pooling.3": "add_embedding.norm2",
        }
        for key in unet_state_dict.keys():
            if key.startswith("encoder_pooling"):
                prefix = key[: len("encoder_pooling.0")]
                new_key = key.replace(prefix, mapping[prefix])
                new_checkpoint[new_key] = unet_state_dict[key]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }

    if not isinstance(config["layers_per_block"], int):
        layers_per_block_list = [e + 1 for e in config["layers_per_block"]]
        layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))
        downsampler_ids = layers_per_block_cumsum
    else:
        # TODO need better check than i in [4, 8, 12, 16]
        downsampler_ids = [4, 8, 12, 16]

    for i in range(1, num_input_blocks):
        if isinstance(config["layers_per_block"], int):
            layers_per_block = config["layers_per_block"]
            block_id = (i - 1) // (layers_per_block + 1)
            layer_in_block_id = (i - 1) % (layers_per_block + 1)
        else:
            block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if (i - 1) < n)
            passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
            layer_in_block_id = (i - 1) - passed_blocks

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)

        block_type = config["down_block_types"][block_id]
        if (
            block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D"
        ) and i in downsampler_ids:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
        else:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}

        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            old_path = f"input_blocks.{i}.1"
            new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"

            assign_attention_to_checkpoint(
                new_checkpoint=new_checkpoint,
                unet_state_dict=unet_state_dict,
                old_path=old_path,
                new_path=new_path,
                config=config,
            )

            paths = renew_attention_paths(attentions)
            meta_path = {"old": old_path, "new": new_path}
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                unet_state_dict,
                additional_replacements=[meta_path],
                config=config,
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    old_path = "middle_block.1"
    new_path = "mid_block.attentions.0"

    assign_attention_to_checkpoint(
        new_checkpoint=new_checkpoint,
        unet_state_dict=unet_state_dict,
        old_path=old_path,
        new_path=new_path,
        config=config,
    )

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    if not isinstance(config["layers_per_block"], int):
        layers_per_block_list = list(reversed([e + 1 for e in config["layers_per_block"]]))
        layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))

    for i in range(num_output_blocks):
        if isinstance(config["layers_per_block"], int):
            layers_per_block = config["layers_per_block"]
            block_id = i // (layers_per_block + 1)
            layer_in_block_id = i % (layers_per_block + 1)
        else:
            block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if i < n)
            passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
            layer_in_block_id = i - passed_blocks

        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        # len(output_block_list) == 1 -> resnet
        # len(output_block_list) == 2 -> resnet, attention or resnet, upscale resnet
        # len(output_block_list) == 3 -> resnet, attention, upscale resnet

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]

            has_attention = True
            if len(output_block_list) == 2 and any("in_layers" in k for k in output_block_list["1"]):
                has_attention = False

            maybe_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}

            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # this layer has no attention
                has_attention = False
                maybe_attentions = []

            if has_attention:
                old_path = f"output_blocks.{i}.1"
                new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"

                assign_attention_to_checkpoint(
                    new_checkpoint=new_checkpoint,
                    unet_state_dict=unet_state_dict,
                    old_path=old_path,
                    new_path=new_path,
                    config=config,
                )

                paths = renew_attention_paths(maybe_attentions)
                meta_path = {
                    "old": old_path,
                    "new": new_path,
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )

            if len(output_block_list) == 3 or (not has_attention and len(maybe_attentions) > 0):
                layer_id = len(output_block_list) - 1
                resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.{layer_id}" in key]
                paths = renew_resnet_paths(resnets)
                meta_path = {"old": f"output_blocks.{i}.{layer_id}", "new": f"up_blocks.{block_id}.upsamplers.0"}
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)

            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint

def verify_param_count(orig_path, unet_diffusers_config):
    if "-II-" in orig_path:
        from deepfloyd_if.modules import IFStageII

        if_II = IFStageII(device="cpu", dir_or_name=orig_path)
    elif "-III-" in orig_path:
        from deepfloyd_if.modules import IFStageIII

        if_II = IFStageIII(device="cpu", dir_or_name=orig_path)
    else:
        raise ValueError(f"Weird name. Should have -II- or -III- in path: {orig_path}")

    unet = UNet2DConditionModel(**unet_diffusers_config)

    # in params
    assert_param_count(unet.time_embedding, if_II.model.time_embed)
    assert_param_count(unet.conv_in, if_II.model.input_blocks[:1])

    # downblocks
    assert_param_count(unet.down_blocks[0], if_II.model.input_blocks[1:4])
    assert_param_count(unet.down_blocks[1], if_II.model.input_blocks[4:7])
    assert_param_count(unet.down_blocks[2], if_II.model.input_blocks[7:11])

    if "-II-" in orig_path:
        assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:17])
        assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[17:])

    if "-III-" in orig_path:
        assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:15])
        assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[15:20])
        assert_param_count(unet.down_blocks[5], if_II.model.input_blocks[20:])

    # mid block
    assert_param_count(unet.mid_block, if_II.model.middle_block)

    # up block
    if "-II-" in orig_path:
        assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:6])
        assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[6:12])
        assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[12:16])
        assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[16:19])
        assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[19:])

    if "-III-" in orig_path:
        assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:5])
        assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[5:10])
        assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[10:14])
        assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[14:18])
        assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[18:21])
        assert_param_count(unet.up_blocks[5], if_II.model.output_blocks[21:24])

    # out params
    assert_param_count(unet.conv_norm_out, if_II.model.out[0])
    assert_param_count(unet.conv_out, if_II.model.out[2])

    # make sure the full models have the same param count
    assert_param_count(unet, if_II.model)


def assert_param_count(model_1, model_2):
    count_1 = sum(p.numel() for p in model_1.parameters())
    count_2 = sum(p.numel() for p in model_2.parameters())
    assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}"

def superres_check_against_original(dump_path, unet_checkpoint_path):
    # Debug helper (not called from main): compares the converted unet's output against
    # the original deepfloyd_if model on random inputs.
    model_path = dump_path
    model = UNet2DConditionModel.from_pretrained(model_path)
    model.to("cuda")

    orig_path = unet_checkpoint_path

    if "-II-" in orig_path:
        from deepfloyd_if.modules import IFStageII

        if_II_model = IFStageII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model
    elif "-III-" in orig_path:
        from deepfloyd_if.modules import IFStageIII

        if_II_model = IFStageIII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model
    else:
        raise ValueError(f"Weird name. Should have -II- or -III- in path: {orig_path}")

    batch_size = 1
    channels = model.config.in_channels // 2
    height = model.config.sample_size
    width = model.config.sample_size

    height = 1024
    width = 1024

    torch.manual_seed(0)

    latents = torch.randn((batch_size, channels, height, width), device=model.device)
    image_small = torch.randn((batch_size, channels, height // 4, width // 4), device=model.device)

    interpolate_antialias = {}
    if "antialias" in inspect.signature(F.interpolate).parameters:
        interpolate_antialias["antialias"] = True

    image_upscaled = F.interpolate(
        image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
    )

    latent_model_input = torch.cat([latents, image_upscaled], dim=1).to(model.dtype)
    t = torch.tensor([5], device=model.device).to(model.dtype)

    seq_len = 64
    encoder_hidden_states = torch.randn((batch_size, seq_len, model.config.encoder_hid_dim), device=model.device).to(
        model.dtype
    )

    fake_class_labels = torch.tensor([t], device=model.device).to(model.dtype)

    with torch.no_grad():
        out = if_II_model(latent_model_input, t, aug_steps=fake_class_labels, text_emb=encoder_hidden_states)

    if_II_model.to("cpu")
    del if_II_model
    import gc

    torch.cuda.empty_cache()
    gc.collect()
    print(50 * "=")

    with torch.no_grad():
        noise_pred = model(
            sample=latent_model_input,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=fake_class_labels,
            timestep=t,
        ).sample

    print("Out shape", noise_pred.shape)
    print("Diff", (out - noise_pred).abs().sum())


if __name__ == "__main__":
    main(parse_args())