Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import numpy as np | |
| from tqdm import trange | |
| from PIL import Image | |
def get_state(gpu):
    """Load the MiDaS model and its default input transform via torch.hub.

    Args:
        gpu: when truthy, the model is moved to the default CUDA device.

    Returns:
        dict with key "model" (the MiDaS network, switched to eval mode) and
        key "transform" (the hub's ``default_transform``).
    """
    hub_repo = "intel-isl/MiDaS"
    network = torch.hub.load(hub_repo, "MiDaS")
    if gpu:
        network.cuda()
    network.eval()
    preprocess = torch.hub.load(hub_repo, "transforms").default_transform
    return {"model": network, "transform": preprocess}
def depth_to_rgba(x):
    """Pack a float32 depth map into a uint8 RGBA image, bit for bit.

    Each float32 value is reinterpreted as its 4 raw bytes (native byte
    order), which become the 4 channels of that pixel. This is a lossless
    storage encoding, not a visualisation; ``rgba_to_depth`` reverses it.

    Args:
        x: 2-D ``np.float32`` array of shape (H, W).

    Returns:
        C-contiguous ``np.uint8`` array of shape (H, W, 4) that does not share
        memory with ``x``.

    Raises:
        ValueError: if ``x`` is not float32 or not 2-D.
    """
    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``.
    if x.dtype != np.float32:
        raise ValueError("expected float32 depth map, got {}".format(x.dtype))
    if x.ndim != 2:
        raise ValueError("expected 2-D depth map, got shape {}".format(x.shape))
    # copy() gives a fresh C-contiguous buffer, so the byte view is valid for
    # any input layout and the result never aliases the caller's array.
    # ``.view`` is NumPy's supported way to reinterpret memory (assigning to
    # ``.dtype`` is discouraged).
    raw_bytes = x.copy().view(np.uint8)
    return raw_bytes.reshape(x.shape + (4,))
def rgba_to_depth(x):
    """Unpack a uint8 RGBA image produced by ``depth_to_rgba`` into float32 depth.

    The 4 channel bytes of each pixel are reinterpreted (native byte order) as
    one float32 value.

    Args:
        x: ``np.uint8`` array of shape (H, W, 4).

    Returns:
        C-contiguous ``np.float32`` array of shape (H, W) that does not share
        memory with ``x``.

    Raises:
        ValueError: if ``x`` is not uint8 or not shaped (H, W, 4).
    """
    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``.
    if x.dtype != np.uint8:
        raise ValueError("expected uint8 RGBA image, got {}".format(x.dtype))
    if x.ndim != 3 or x.shape[2] != 4:
        raise ValueError("expected (H, W, 4) RGBA image, got shape {}".format(x.shape))
    # copy() makes the last axis contiguous, which ``.view`` needs in order to
    # fold 4 bytes into one float32. The result is (H, W, 1) and is reshaped
    # to (H, W).
    depth = x.copy().view(np.float32)
    return depth.reshape(x.shape[:2])
def run(x, state):
    """Predict a depth map for one image using the model in *state*.

    Args:
        x: NumPy image whose first two axes are (H, W). It is mapped through
            ``(x + 1.0) * 127.5`` before the transform, so it is presumably an
            RGB image scaled to [-1, 1] (giving [0, 255]); TODO confirm against
            the dataset.
        state: dict as returned by ``get_state``, with "model" and "transform".

    Returns:
        NumPy array of shape (H, W): the model prediction resized (bicubic)
        back to the spatial size of ``x``.
    """
    model = state["model"]
    transform = state["transform"]
    hw = x.shape[:2]
    # Send the input to wherever the model's weights live rather than always
    # calling .cuda(): get_state(gpu=False) keeps the model on the CPU.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")  # no parameters to inspect: old default
    with torch.no_grad():
        batch = transform((x + 1.0) * 127.5).to(device)
        prediction = model(batch)
        # Add a channel axis so interpolate sees (N, C, H', W').
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=hw,
            mode="bicubic",
            align_corners=False,
        )
    # Pick the single (batch, channel) entry explicitly (assumes the transform
    # yields a batch of one). ``.squeeze()`` would also drop H or W when either
    # is 1.
    return prediction[0, 0].cpu().numpy()
def get_filename(relpath, level=-2):
    """Return ``(folder, stem)`` for the last two components of *relpath*.

    ``relpath`` is split on ``os.sep`` and its last ``-level`` components are
    kept. The folder is the second-to-last component. The stem is the last
    component with only its final extension removed. Note that ``level`` must
    be <= -2, or indexing the folder raises IndexError.

    Args:
        relpath: path string such as ``"<class>/<image>.JPEG"``.
        level: negative slice start for the kept components (default -2).

    Returns:
        tuple ``(folder, stem)`` of strings.
    """
    # save class folder structure and filename:
    fn = relpath.split(os.sep)[level:]
    folder = fn[-2]
    # splitext strips only the real extension, so "a.b.jpg" -> "a.b" instead
    # of collapsing to "a" and colliding with "a.c.jpg".
    file = os.path.splitext(fn[-1])[0]
    return folder, file
def save_depth(dataset, path, debug=False):
    """Run MiDaS over *dataset* and write one RGBA-encoded depth PNG per item.

    Each item is written to ``<path>/<folder>/<filename>.png``, where folder
    and filename come from ``get_filename(item["relpath"])``. The PNG pixels
    hold the raw float32 depth bytes produced by ``depth_to_rgba``.

    Args:
        dataset: sized, integer-indexable collection whose items are mappings
            with "image" and "relpath" keys.
        path: output root directory. It is created with ``os.makedirs`` and so
            must not exist yet.
        debug: if true, process at most the first 10 items.
    """
    os.makedirs(path)
    # Size the loop from the dataset argument; the old code read the global
    # ``dset``, which only exists when run as a script.
    n_items = len(dataset)
    if debug:
        # Cap rather than overwrite, so datasets with < 10 items don't IndexError.
        n_items = min(n_items, 10)
    state = get_state(gpu=True)
    for idx in trange(n_items, desc="Data"):
        ex = dataset[idx]
        image, relpath = ex["image"], ex["relpath"]
        folder, filename = get_filename(relpath)
        # prepare
        folderabspath = os.path.join(path, folder)
        os.makedirs(folderabspath, exist_ok=True)
        savepath = os.path.join(folderabspath, filename)
        # run model
        rgba = depth_to_rgba(run(image, state))
        Image.fromarray(rgba).save("{}.png".format(savepath))
if __name__ == "__main__":
    # Script entry point: extract depth for the ImageNet validation and train
    # splits into data/imagenet_depth/{val,train}. Any split whose output
    # folder already exists is skipped.
    import sys

    from taming.data.imagenet import ImageNetTrain, ImageNetValidation

    out = "data/imagenet_depth"
    if not os.path.exists(out):
        print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
              "(be prepared that the output size will be larger than ImageNet itself).")
        # sys.exit rather than the site-provided exit(), which is meant for
        # the interactive shell and is undefined under `python -S`.
        sys.exit(1)
    # go: (output subfolder, label for the final message, dataset class)
    splits = (
        ("val", "validation", ImageNetValidation),
        ("train", "train", ImageNetTrain),
    )
    for subdir, label, dataset_cls in splits:
        # The dataset is built before the existence check, so whatever its
        # constructor does also happens for skipped splits.
        dset = dataset_cls()
        abspath = os.path.join(out, subdir)
        if os.path.exists(abspath):
            print("{} exists - not doing anything.".format(abspath))
        else:
            print("preparing {}".format(abspath))
            save_depth(dset, abspath)
        print("done with {} split".format(label))
    print("done done.")