Spaces:
Runtime error
Runtime error
| import argparse, os, sys, glob | |
| import torch | |
| import time | |
| import numpy as np | |
| from omegaconf import OmegaConf | |
| from PIL import Image | |
| from tqdm import tqdm, trange | |
| from einops import repeat | |
| from main import instantiate_from_config | |
| from taming.modules.transformer.mingpt import sample_with_past | |
def rescale(x):
    """Affinely map values from [-1, 1] onto [0, 1] by computing (x + 1) / 2.

    Works on Python numbers and on array-like objects supporting ``+`` and ``/``.
    """
    return (x + 1.0) / 2.0
def chw_to_pillow(x):
    """Convert one channels-first image tensor into a ``PIL.Image``.

    The tensor is detached and copied to host memory, transposed from
    (C, H, W) to (H, W, C), mapped with ``rescale`` and scaled by 255.
    Values outside [0, 255] are clipped, then cast to uint8 (truncating).
    Input is presumably in [-1, 1], the range ``rescale`` maps onto [0, 1];
    NOTE(review): confirm against the decoder's output range.
    """
    hwc = x.detach().cpu().numpy().transpose(1, 2, 0)
    scaled = (255 * rescale(hwc)).clip(0, 255)
    return Image.fromarray(scaled.astype(np.uint8))
def sample_classconditional(model, batch_size, class_label, steps=256, temperature=None, top_k=None, callback=None,
                            dim_z=256, h=16, w=16, verbose_time=False, top_p=None):
    """Draw ``batch_size`` samples of a single class from a class-conditional model.

    Every row of the batch is conditioned on the same ``class_label`` token.
    ``steps``, ``temperature``, ``top_k``, ``top_p`` and ``callback`` are
    forwarded to ``sample_with_past``. ``dim_z``, ``h`` and ``w`` define the
    latent shape ``[batch_size, dim_z, h, w]`` handed to ``model.decode_to_img``.

    Returns:
        dict with "samples" (the decoded images) and "class_label" (the
        (batch_size, 1) tensor of conditioning tokens on ``model.device``).

    Raises:
        AssertionError: if ``class_label`` is not exactly ``int``, or if the
            model is unconditional.
    """
    assert type(class_label) == int, f'expecting type int but type is {type(class_label)}'
    latent_shape = [batch_size, dim_z, h, w]
    assert not model.be_unconditional, 'Expecting a class-conditional Net2NetTransformer.'
    # '1 -> b 1': replicate the single class token into a (batch_size, 1) column.
    cond_tokens = repeat(torch.tensor([class_label]), '1 -> b 1', b=batch_size).to(model.device)
    started_at = time.time()
    token_indices = sample_with_past(cond_tokens, model.transformer, steps=steps,
                                     sample_logits=True, top_k=top_k, callback=callback,
                                     temperature=temperature, top_p=top_p)
    if verbose_time:
        elapsed = time.time() - started_at
        print(f"Full sampling takes about {elapsed:.2f} seconds.")
    images = model.decode_to_img(token_indices, latent_shape)
    return {"samples": images, "class_label": cond_tokens}
def sample_unconditional(model, batch_size, steps=256, temperature=None, top_k=None, top_p=None, callback=None,
                         dim_z=256, h=16, w=16, verbose_time=False):
    """Draw ``batch_size`` unconditional samples from ``model``.

    Each batch row is seeded with ``model.sos_token``. ``steps``,
    ``temperature``, ``top_k``, ``top_p`` and ``callback`` are forwarded to
    ``sample_with_past``. ``dim_z``, ``h`` and ``w`` define the latent shape
    ``[batch_size, dim_z, h, w]`` handed to ``model.decode_to_img``.

    Returns:
        dict with a single "samples" entry holding the decoded images.

    Raises:
        AssertionError: if the model is not unconditional.
    """
    latent_shape = [batch_size, dim_z, h, w]
    assert model.be_unconditional, 'Expecting an unconditional model.'
    # '1 -> b 1': one start-of-sequence token per batch row, shape (batch_size, 1).
    sos_tokens = repeat(torch.tensor([model.sos_token]), '1 -> b 1', b=batch_size).to(model.device)
    started_at = time.time()
    token_indices = sample_with_past(sos_tokens, model.transformer, steps=steps,
                                     sample_logits=True, top_k=top_k, callback=callback,
                                     temperature=temperature, top_p=top_p)
    if verbose_time:
        elapsed = time.time() - started_at
        print(f"Full sampling takes about {elapsed:.2f} seconds.")
    return {"samples": model.decode_to_img(token_indices, latent_shape)}
def run(logdir, model, batch_size, temperature, top_k, unconditional=True, num_samples=50000,
        given_classes=None, top_p=None):
    """Draw samples from ``model`` and save them as PNG files under ``logdir``.

    Sampling happens in batches of ``batch_size``; one final smaller batch
    covers any remainder of ``num_samples``.

    Args:
        logdir: Output directory handed to ``save_from_logs``.
        model: The model passed to ``sample_unconditional`` /
            ``sample_classconditional``.
        batch_size: Samples per sampling call; must be positive.
        temperature, top_k, top_p: Forwarded to the sampling function.
        unconditional: If False, sample class-conditionally for every label in
            ``given_classes``. ``num_samples`` then applies *per class*.
        num_samples: Number of samples (per class in class-conditional mode).
        given_classes: Sequence of int labels (``len()`` is taken); required
            when ``unconditional`` is False.
    """
    n_full, remainder = divmod(num_samples, batch_size)
    # Full batches, plus a partial batch only when there is a remainder. This
    # avoids a zero-size trailing batch and gives tqdm an accurate length.
    batches = [batch_size] * n_full + ([remainder] if remainder else [])
    if not unconditional:
        assert given_classes is not None
        print("Running in pure class-conditional sampling mode. I will produce "
              f"{num_samples} samples for each of the {len(given_classes)} classes, "
              f"i.e. {num_samples*len(given_classes)} in total.")
        for class_label in tqdm(given_classes, desc="Classes"):
            # enumerate() has no len(), so pass total= for a real progress bar/ETA.
            for n, bs in tqdm(enumerate(batches), desc="Sampling Class", total=len(batches)):
                logs = sample_classconditional(model, batch_size=bs, class_label=class_label,
                                               temperature=temperature, top_k=top_k, top_p=top_p)
                # Numbering restarts at 0 for every class; save_from_logs puts each
                # image in a subdirectory named after its label, so names do not clash.
                save_from_logs(logs, logdir, base_count=n * batch_size, cond_key=logs["class_label"])
    else:
        print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
        for n, bs in tqdm(enumerate(batches), desc="Sampling", total=len(batches)):
            logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
            save_from_logs(logs, logdir, base_count=n * batch_size)
def save_from_logs(logs, logdir, base_count, key="samples", cond_key=None):
    """Write every image in ``logs[key]`` to disk as a zero-padded PNG.

    The i-th image is converted with ``chw_to_pillow`` and saved as
    ``f"{base_count + i:06}.png"``.

    Args:
        logs: Mapping holding the batch of images under ``key``.
        logdir: Root output directory (assumed to exist when ``cond_key`` is None).
        base_count: Index of the first image, used for file numbering.
        key: Which entry of ``logs`` holds the images.
        cond_key: Optional per-image labels. When given, image i goes into
            ``logdir/<label>/`` (created if needed), where ``label`` is
            ``cond_key[i]``. Tensor labels are unwrapped with ``.item()``.
    """
    for i, image in enumerate(logs[key]):
        pil_image = chw_to_pillow(image)
        filename = f"{base_count + i:06}.png"
        if cond_key is None:
            pil_image.save(os.path.join(logdir, filename))
            continue
        label = cond_key[i]
        # isinstance (rather than an exact type() comparison) also unwraps Tensor subclasses.
        if isinstance(label, torch.Tensor):
            label = label.item()
        label_dir = os.path.join(logdir, str(label))
        os.makedirs(label_dir, exist_ok=True)
        pil_image.save(os.path.join(label_dir, filename))
def get_parser():
    """Build the command-line parser for the sampling script.

    Defaults: 50000 samples, batch size 25, top_k 250, temperature 1.0,
    top_p 1.0, and all 1000 ImageNet classes ("imagenet").
    Config overrides are not declared here. They are left as unknown args for
    ``parse_known_args`` and given to ``OmegaConf.from_dotlist``, so they
    must use ``key=value`` syntax.
    """
    # NOTE: a str2bool helper used to be defined here but was never referenced; removed.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-o",
        "--outdir",
        type=str,
        nargs="?",
        help="path where the samples will be logged to.",
        default="",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `key=value` "
             "(OmegaConf dot-list syntax, e.g. `model.params.foo=1`).",
        default=list(),
    )
    parser.add_argument(
        "-n",
        "--num_samples",
        type=int,
        nargs="?",
        help="num_samples to draw",
        default=50000,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        nargs="?",
        help="the batch size",
        default=25,
    )
    parser.add_argument(
        "-k",
        "--top_k",
        type=int,
        nargs="?",
        help="top-k value to sample with",
        default=250,
    )
    parser.add_argument(
        "-t",
        "--temperature",
        type=float,
        nargs="?",
        help="temperature value to sample with",
        default=1.0,
    )
    parser.add_argument(
        "-p",
        "--top_p",
        type=float,
        nargs="?",
        help="top-p value to sample with",
        default=1.0,
    )
    parser.add_argument(
        "--classes",
        type=str,
        nargs="?",
        help="specify comma-separated classes to sample from. Uses 1000 classes per default.",
        default="imagenet",
    )
    return parser
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate a model from ``config`` and optionally load weights and set its mode.

    Args:
        config: Passed unchanged to ``instantiate_from_config``.
        sd: State dict for ``load_state_dict``; ``None`` keeps the freshly
            constructed weights.
        gpu: If true, call ``.cuda()`` on the model.
        eval_mode: If true, call ``.eval()`` on the model.

    Returns:
        dict with the model under the "model" key.
    """
    net = instantiate_from_config(config)
    if sd is not None:
        net.load_state_dict(sd)
    if gpu:
        # Return value ignored: relies on .cuda() moving the module in place
        # (true for torch.nn.Module).
        net.cuda()
    if eval_mode:
        net.eval()
    result = {"model": net}
    return result
def load_model(config, ckpt, gpu, eval_mode):
    """Build ``config.model`` and, if ``ckpt`` is truthy, load its weights from that file.

    Returns:
        ``(model, global_step)``. ``global_step`` is read from the checkpoint,
        or is ``None`` when no checkpoint path is given.

    Raises:
        KeyError: if the checkpoint lacks "global_step" or "state_dict".
    """
    if not ckpt:
        # No checkpoint: keep the weights produced by instantiation.
        checkpoint, global_step = {"state_dict": None}, None
    else:
        # NOTE(review): torch.load unpickles the file, so load only trusted
        # checkpoints. Newer PyTorch versions default to weights_only=True,
        # which may reject Lightning checkpoints with non-tensor entries — confirm.
        checkpoint = torch.load(ckpt, map_location="cpu")
        global_step = checkpoint["global_step"]
        print(f"loaded model from global step {global_step}.")
    loaded = load_model_from_config(config.model, checkpoint["state_dict"], gpu=gpu, eval_mode=eval_mode)
    return loaded["model"], global_step
if __name__ == "__main__":
    # Script entry point: resolve the run directory and checkpoint, merge configs,
    # load the model, then sample into <logdir>/samples/<settings>/<global_step>/.
    # NOTE(review): this runs after the module-level imports (e.g. `from main import ...`),
    # so appending the cwd here cannot help those imports resolve.
    sys.path.append(os.getcwd())
    parser = get_parser()
    # Unknown args are treated as OmegaConf dot-list overrides (key=value).
    opt, unknown = parser.parse_known_args()
    # Validate explicitly: an assert would be stripped under `python -O`, and the code
    # would then fail obscurely in os.path.exists(None).
    if not opt.resume:
        parser.error("-r/--resume is required (a log directory or a checkpoint file)")
    ckpt = None
    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))
    if os.path.isfile(opt.resume):
        paths = opt.resume.split("/")
        try:
            # Keep the path up to and including the directory right after the last "logs".
            idx = len(paths) - paths[::-1].index("logs") + 1
        except ValueError:
            idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
        logdir = "/".join(paths[:idx])
        ckpt = opt.resume
    else:
        assert os.path.isdir(opt.resume), opt.resume
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

    # The run's saved project configs come first; user-supplied --base configs and
    # command-line overrides are merged on top, in that order.
    base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
    opt.base = base_configs + opt.base
    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    # Fall back to CPU instead of crashing in model.cuda() on hosts without CUDA.
    use_gpu = torch.cuda.is_available()
    if not use_gpu:
        print("CUDA is not available; loading the model on CPU (sampling will be slow).")
    model, global_step = load_model(config, ckpt, gpu=use_gpu, eval_mode=True)

    if opt.outdir:
        print(f"Switching logdir from '{logdir}' to '{opt.outdir}'")
        logdir = opt.outdir

    if opt.classes == "imagenet":
        given_classes = list(range(1000))
    else:
        cls_str = opt.classes
        assert not cls_str.endswith(","), 'class string should not end with a ","'
        given_classes = [int(c) for c in cls_str.split(",")]

    logdir = os.path.join(logdir, "samples", f"top_k_{opt.top_k}_temp_{opt.temperature:.2f}_top_p_{opt.top_p}",
                          f"{global_step}")
    print(f"Logging to {logdir}")
    os.makedirs(logdir, exist_ok=True)
    run(logdir, model, opt.batch_size, opt.temperature, opt.top_k, unconditional=model.be_unconditional,
        given_classes=given_classes, num_samples=opt.num_samples, top_p=opt.top_p)
    print("done.")