import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Sampler
from torchvision import transforms
import matplotlib.pyplot as plt
import os
import sys
import numpy as np
import math

def convert_arg_line_to_args(arg_line):
    # Split one line of an argument file into individual arguments
    # (used as argparse's ArgumentParser.convert_arg_line_to_args hook).
    for arg in arg_line.split():
        if not arg.strip():
            continue
        yield arg

def block_print():
    # Silence stdout (e.g. on non-zero ranks in distributed training).
    sys.stdout = open(os.devnull, 'w')


def enable_print():
    # Restore the original stdout.
    sys.stdout = sys.__stdout__

def get_num_lines(file_path):
    # Count the number of lines in a text file.
    with open(file_path, 'r') as f:
        return len(f.readlines())

def colorize(value, vmin=None, vmax=None, cmap='Greys'):
    # Expects a [1, H, W] (or [H, W]) depth tensor; returns a [3, H, W]
    # uint8 image of the log-scaled, min-max-normalized map for logging.
    value = value.cpu().numpy().squeeze()  # -> H, W
    value = np.log10(value)
    vmin = value.min() if vmin is None else vmin
    vmax = value.max() if vmax is None else vmax
    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)
    else:
        value = value * 0.  # flat map: avoid division by zero
    cmapper = plt.get_cmap(cmap)
    value = cmapper(value, bytes=True)  # H, W, 4 (RGBA, uint8)
    img = value[:, :, :3]               # drop the alpha channel
    return img.transpose((2, 0, 1))     # CHW

def normalize_result(value, vmin=None, vmax=None):
    # Min-max normalize a [1, H, W] tensor to [0, 1] for visualization.
    value = value.cpu().numpy()[0, :, :]
    vmin = value.min() if vmin is None else vmin
    vmax = value.max() if vmax is None else vmax
    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)
    else:
        value = value * 0.  # flat map: avoid division by zero
    return np.expand_dims(value, 0)

# Undo the standard ImageNet input normalization, recovering an RGB image
# in roughly [0, 1] for visualization.
inv_normalize = transforms.Normalize(
    mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
    std=[1 / 0.229, 1 / 0.224, 1 / 0.225]
)

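# Minimal usage sketch (the `image` tensor is hypothetical, not part of this
# file): bring a normalized [3, H, W] network input back to displayable RGB.
def _example_denormalize(image):
    return inv_normalize(image).clamp(0.0, 1.0)
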
eval_metrics = ['silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2', 'd3']

def compute_errors(gt, pred):
    # Standard monocular depth metrics; gt and pred are flattened numpy
    # arrays of strictly positive depths over valid pixels only.
    thresh = np.maximum((gt / pred), (pred / gt))
    d1 = (thresh < 1.25).mean()        # accuracy under threshold 1.25
    d2 = (thresh < 1.25 ** 2).mean()
    d3 = (thresh < 1.25 ** 3).mean()

    rms = (gt - pred) ** 2
    rms = np.sqrt(rms.mean())          # root mean squared error

    log_rms = (np.log(gt) - np.log(pred)) ** 2
    log_rms = np.sqrt(log_rms.mean())  # RMSE in log space

    abs_rel = np.mean(np.abs(gt - pred) / gt)  # absolute relative error
    sq_rel = np.mean(((gt - pred) ** 2) / gt)  # squared relative error

    err = np.log(pred) - np.log(gt)
    silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100  # scale-invariant log error

    err = np.abs(np.log10(pred) - np.log10(gt))
    log10 = np.mean(err)               # mean log10 error

    return [silog, abs_rel, log10, rms, sq_rel, log_rms, d1, d2, d3]

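# Minimal usage sketch (the depth-range constants are assumptions, not from
# this file): evaluate a prediction against ground truth on valid pixels only.
def _example_evaluate(gt, pred, min_depth=1e-3, max_depth=80.0):
    valid = np.logical_and(gt > min_depth, gt < max_depth)
    return dict(zip(eval_metrics, compute_errors(gt[valid], pred[valid])))
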
class silog_loss(nn.Module):
    # Scale-invariant logarithmic loss (Eigen et al.); variance_focus trades
    # off the variance term against plain log RMSE.
    def __init__(self, variance_focus):
        super(silog_loss, self).__init__()
        self.variance_focus = variance_focus

    def forward(self, depth_est, depth_gt, mask):
        d = torch.log(depth_est[mask]) - torch.log(depth_gt[mask])
        return torch.sqrt((d ** 2).mean() - self.variance_focus * (d.mean() ** 2)) * 10.0

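# Usage sketch (the shapes and the variance_focus value are assumptions):
# depth maps are [B, 1, H, W]; the mask keeps pixels with valid ground truth.
def _example_silog(depth_est, depth_gt, min_depth=1e-3):
    criterion = silog_loss(variance_focus=0.85)
    mask = depth_gt > min_depth
    return criterion(depth_est, depth_gt, mask)
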
def entropy_loss(preds, gt_label, mask):
    # preds: B, C, H, W (per-pixel logits)
    # gt_label: B, H, W (integer class / bin indices)
    # mask: B, H, W
    mask = mask > 0.0  # B, H, W
    preds = preds.permute(0, 2, 3, 1)  # B, H, W, C
    preds_mask = preds[mask]  # N, C
    gt_label_mask = gt_label[mask]  # N
    loss = F.cross_entropy(preds_mask, gt_label_mask, reduction='mean')
    return loss

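# Usage sketch (hypothetical shapes): cross-entropy over discretised depth
# bins, supervised only where the mask marks valid ground truth.
def _example_entropy(logits, bin_ids, valid):
    # logits: [B, C, H, W], bin_ids: [B, H, W] int64, valid: [B, H, W] float
    return entropy_loss(logits, bin_ids.long(), valid)
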
def _apply_colormap(inputs, cmap_name, normalize=True, torch_transpose=True):
    # Shared implementation for colormap/colormap_magma: map values to RGB
    # with a matplotlib colormap and return the first element of the batch.
    if isinstance(inputs, torch.Tensor):
        inputs = inputs.detach().cpu().numpy()
    cmapper = plt.get_cmap(cmap_name, 256)  # for plotting
    vis = inputs
    if normalize:
        ma = float(vis.max())
        mi = float(vis.min())
        d = ma - mi if ma != mi else 1e5  # avoid division by zero on flat maps
        vis = (vis - mi) / d
    if vis.ndim == 4:  # B, 1, H, W
        vis = vis.transpose([0, 2, 3, 1])
        vis = cmapper(vis)
        vis = vis[:, :, :, 0, :3]  # drop the channel and alpha axes
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 3:  # B, H, W
        vis = cmapper(vis)
        vis = vis[:, :, :, :3]
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 2:  # H, W: no batch axis, return the map directly
        vis = cmapper(vis)
        vis = vis[..., :3]
        if torch_transpose:
            vis = vis.transpose(2, 0, 1)
        return vis
    return vis[0]  # first element of the batch


def colormap(inputs, normalize=True, torch_transpose=True):
    return _apply_colormap(inputs, 'jet', normalize, torch_transpose)


def colormap_magma(inputs, normalize=True, torch_transpose=True):
    return _apply_colormap(inputs, 'magma', normalize, torch_transpose)

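# Usage sketch: colourise the first depth map of a batch for an image logger
# (the `writer` object is an assumption, e.g. a TensorBoard SummaryWriter).
def _example_log_colormap(writer, depth, step):
    writer.add_image('depth/jet', colormap(depth), step)  # [3, H, W] in [0, 1]
    writer.add_image('depth/magma', colormap_magma(depth), step)
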
def flip_lr(image):
    """
    Flip image horizontally

    Parameters
    ----------
    image : torch.Tensor [B,3,H,W]
        Image to be flipped

    Returns
    -------
    image_flipped : torch.Tensor [B,3,H,W]
        Flipped image
    """
    assert image.dim() == 4, 'You need to provide a [B,C,H,W] image to flip'
    return torch.flip(image, [3])

def fuse_inv_depth(inv_depth, inv_depth_hat, method='mean'):
    """
    Fuse inverse depth and flipped inverse depth maps

    Parameters
    ----------
    inv_depth : torch.Tensor [B,1,H,W]
        Inverse depth map
    inv_depth_hat : torch.Tensor [B,1,H,W]
        Flipped inverse depth map produced from a flipped image
    method : str
        Method that will be used to fuse the inverse depth maps

    Returns
    -------
    fused_inv_depth : torch.Tensor [B,1,H,W]
        Fused inverse depth map
    """
    if method == 'mean':
        return 0.5 * (inv_depth + inv_depth_hat)
    elif method == 'max':
        return torch.max(inv_depth, inv_depth_hat)
    elif method == 'min':
        return torch.min(inv_depth, inv_depth_hat)
    else:
        raise ValueError('Unknown post-process method {}'.format(method))

def post_process_depth(depth, depth_flipped, method='mean'):
    """
    Post-process a depth map using the prediction from the flipped input

    Parameters
    ----------
    depth : torch.Tensor [B,1,H,W]
        Depth map from the original image
    depth_flipped : torch.Tensor [B,1,H,W]
        Depth map produced from the horizontally flipped image
    method : str
        Method that will be used to fuse the two depth maps

    Returns
    -------
    depth_pp : torch.Tensor [B,1,H,W]
        Post-processed depth map
    """
    B, C, H, W = depth.shape
    depth_hat = flip_lr(depth_flipped)  # flip back into the original frame
    depth_fused = fuse_inv_depth(depth, depth_hat, method=method)
    # Edge-aware blending: near the left border trust depth_hat, near the
    # right border trust depth, and use the fused map everywhere else.
    xs = torch.linspace(0., 1., W, device=depth.device,
                        dtype=depth.dtype).repeat(B, C, H, 1)
    mask = 1.0 - torch.clamp(20. * (xs - 0.05), 0., 1.)
    mask_hat = flip_lr(mask)
    return mask_hat * depth + mask * depth_hat + \
        (1.0 - mask - mask_hat) * depth_fused

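# Usage sketch of flip-augmented inference (the model interface is an
# assumption: model(image) returns a [B,1,H,W] depth map).
@torch.no_grad()
def _example_flip_inference(model, image, method='mean'):
    depth = model(image)
    depth_flipped = model(flip_lr(image))
    return post_process_depth(depth, depth_flipped, method=method)
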
class DistributedSamplerNoEvenlyDivisible(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    Unlike :class:`torch.utils.data.DistributedSampler`, it does not pad the
    indices to make the dataset evenly divisible across replicas, so no
    sample is seen twice within an epoch.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        shuffle (optional): If true (default), sampler will shuffle the indices
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Instead of padding, the first `rest` ranks receive one extra sample.
        num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas))
        rest = len(self.dataset) - num_samples * self.num_replicas
        if self.rank < rest:
            num_samples += 1
        self.num_samples = num_samples
        self.total_size = len(dataset)
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # subsample without adding extra samples, so per-rank lengths may
        # differ by one
        indices = indices[self.rank:self.total_size:self.num_replicas]
        self.num_samples = len(indices)

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

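# Usage sketch (assumes torch.distributed is already initialised elsewhere):
# plug the sampler into a DataLoader and re-seed the shuffle every epoch.
def _example_distributed_loader(dataset, batch_size, epoch):
    sampler = DistributedSamplerNoEvenlyDivisible(dataset, shuffle=True)
    sampler.set_epoch(epoch)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=sampler, num_workers=4)
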
class D_to_cloud(nn.Module):
    """Layer to transform depth into a point cloud
    """
    def __init__(self, batch_size, height, width):
        super(D_to_cloud, self).__init__()
        self.batch_size = batch_size
        self.height = height
        self.width = width

        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)  # 2, H, W
        self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
                                      requires_grad=False)  # 2, H, W
        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)  # B, 1, H*W
        # Homogeneous pixel coordinates (u, v, 1) for every pixel, flattened
        # into L = H*W columns.
        self.pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)  # 1, 2, L
        self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)  # B, 2, L
        self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
                                       requires_grad=False)  # B, 3, L

    def forward(self, depth, inv_K):
        # Back-project: X = depth * K^-1 [u, v, 1]^T for every pixel.
        cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)  # B, 3, L
        cam_points = depth.view(self.batch_size, 1, -1) * cam_points  # B, 3, L
        return cam_points.permute(0, 2, 1)  # B, L, 3
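
# Usage sketch (the 4x4 intrinsics matrix K is an assumption): back-project a
# depth map into a camera-space point cloud.
def _example_backproject(depth, K):
    # depth: [B, 1, H, W], K: [B, 4, 4]
    B, _, H, W = depth.shape
    backproject = D_to_cloud(B, H, W).to(depth.device)
    return backproject(depth, torch.inverse(K))  # [B, H*W, 3]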