| import torch | |
| from omegaconf import OmegaConf | |
| from sgm.util import instantiate_from_config | |
| from sgm.modules.diffusionmodules.sampling import * | |
| SD_XL_BASE_RATIOS = { | |
| "0.5": (704, 1408), | |
| "0.52": (704, 1344), | |
| "0.57": (768, 1344), | |
| "0.6": (768, 1280), | |
| "0.68": (832, 1216), | |
| "0.72": (832, 1152), | |
| "0.78": (896, 1152), | |
| "0.82": (896, 1088), | |
| "0.88": (960, 1088), | |
| "0.94": (960, 1024), | |
| "1.0": (1024, 1024), | |
| "1.07": (1024, 960), | |
| "1.13": (1088, 960), | |
| "1.21": (1088, 896), | |
| "1.29": (1152, 896), | |
| "1.38": (1152, 832), | |
| "1.46": (1216, 832), | |
| "1.67": (1280, 768), | |
| "1.75": (1344, 768), | |
| "1.91": (1344, 704), | |
| "2.0": (1408, 704), | |
| "2.09": (1472, 704), | |
| "2.4": (1536, 640), | |
| "2.5": (1600, 640), | |
| "2.89": (1664, 576), | |
| "3.0": (1728, 576), | |
| } | |
| def init_model(cfgs): | |
| model_cfg = OmegaConf.load(cfgs.model_cfg_path) | |
| ckpt = cfgs.load_ckpt_path | |
| model = instantiate_from_config(model_cfg.model) | |
| model.init_from_ckpt(ckpt) | |
| if cfgs.type == "train": | |
| model.train() | |
| else: | |
| if cfgs.use_gpu: | |
| model.to(torch.device("cuda", index=cfgs.gpu)) | |
| model.eval() | |
| model.freeze() | |
| return model | |
| def init_sampling(cfgs): | |
| discretization_config = { | |
| "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", | |
| } | |
| if cfgs.dual_conditioner: | |
| guider_config = { | |
| "target": "sgm.modules.diffusionmodules.guiders.DualCFG", | |
| "params": {"scale": cfgs.scale}, | |
| } | |
| sampler = EulerEDMDualSampler( | |
| num_steps=cfgs.steps, | |
| discretization_config=discretization_config, | |
| guider_config=guider_config, | |
| s_churn=0.0, | |
| s_tmin=0.0, | |
| s_tmax=999.0, | |
| s_noise=1.0, | |
| verbose=True, | |
| device=torch.device("cuda", index=cfgs.gpu) | |
| ) | |
| else: | |
| guider_config = { | |
| "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", | |
| "params": {"scale": cfgs.scale[0]}, | |
| } | |
| sampler = EulerEDMSampler( | |
| num_steps=cfgs.steps, | |
| discretization_config=discretization_config, | |
| guider_config=guider_config, | |
| s_churn=0.0, | |
| s_tmin=0.0, | |
| s_tmax=999.0, | |
| s_noise=1.0, | |
| verbose=True, | |
| device=torch.device("cuda", index=cfgs.gpu) | |
| ) | |
| return sampler | |
| def deep_copy(batch): | |
| c_batch = {} | |
| for key in batch: | |
| if isinstance(batch[key], torch.Tensor): | |
| c_batch[key] = torch.clone(batch[key]) | |
| elif isinstance(batch[key], (tuple, list)): | |
| c_batch[key] = batch[key].copy() | |
| else: | |
| c_batch[key] = batch[key] | |
| return c_batch | |
| def prepare_batch(cfgs, batch): | |
| for key in batch: | |
| if isinstance(batch[key], torch.Tensor) and cfgs.use_gpu: | |
| batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu)) | |
| if not cfgs.dual_conditioner: | |
| batch_uc = deep_copy(batch) | |
| if "ntxt" in batch: | |
| batch_uc["txt"] = batch["ntxt"] | |
| else: | |
| batch_uc["txt"] = ["" for _ in range(len(batch["txt"]))] | |
| if "label" in batch: | |
| batch_uc["label"] = ["" for _ in range(len(batch["label"]))] | |
| return batch, batch_uc, None | |
| else: | |
| batch_uc_1 = deep_copy(batch) | |
| batch_uc_2 = deep_copy(batch) | |
| batch_uc_1["ref"] = torch.zeros_like(batch["ref"]) | |
| batch_uc_2["ref"] = torch.zeros_like(batch["ref"]) | |
| batch_uc_1["label"] = ["" for _ in range(len(batch["label"]))] | |
| return batch, batch_uc_1, batch_uc_2 |