Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| import copy | |
def update_sample(bin_edges, target_bin_left, target_bin_right, depth_r, pred_label,
                  depth_num, min_depth, max_depth, uncertainty_range, *, mode='direct'):
    """Rebuild per-pixel depth bin edges around the current depth estimate.

    A new set of ``depth_num`` equal-width bins is laid out for every pixel,
    starting at a mode-dependent depth and spanning a mode-dependent range.
    All edges are clamped to ``[min_depth, max_depth]``.

    Args:
        bin_edges: Previous bin edges; must be 4-D. Only its shape
            ``(b, _, h, w)`` and its device are read.
        target_bin_left: Left edge of the currently selected bin. Only used
            when ``mode='bin'``.
        target_bin_right: Right edge of the currently selected bin. Only used
            when ``mode='bin'``.
        depth_r: Current depth estimate. Only used when ``mode='direct'``.
            Presumably shaped ``(b, 1, h, w)`` -- confirm against the caller.
        pred_label: Unused; kept for interface compatibility.
        depth_num: Number of bins to create (yields ``depth_num + 1`` edges).
        min_depth: Lower clamp applied to the start depth and to every edge.
        max_depth: Upper clamp applied to every edge.
        uncertainty_range: Width of the search range. It is tiled
            ``depth_num`` times along dim 1, so it is presumably shaped
            ``(b, 1, h, w)`` -- confirm against the caller.
        mode: ``'direct'`` (default, the behaviour this function always had)
            centres a range of width ``uncertainty_range`` on ``depth_r``.
            ``'bin'`` starts half an uncertainty range below
            ``target_bin_left`` and widens the range by the target bin width.

    Returns:
        Tuple ``(bin_edges, curr_depth)``: the new clamped edges along dim 1
        and the midpoints of consecutive edges. Both are detached.

    Raises:
        ValueError: If ``mode`` is neither ``'direct'`` nor ``'bin'``.
    """
    if mode not in ('direct', 'bin'):
        raise ValueError(f"mode must be 'direct' or 'bin', got {mode!r}")

    with torch.no_grad():
        b, _, h, w = bin_edges.shape

        if mode == 'direct':
            depth_range = uncertainty_range
            depth_start_update = torch.clamp_min(depth_r - 0.5 * depth_range, min_depth)
        else:
            depth_range = uncertainty_range + (target_bin_right - target_bin_left).abs()
            depth_start_update = torch.clamp_min(target_bin_left - 0.5 * uncertainty_range, min_depth)

        # One equal-width step per bin, tiled along the channel dimension.
        step = depth_range / depth_num
        steps = step.repeat(1, depth_num, 1, 1)

        # Prepend the start depth so that a cumulative sum over dim 1 yields
        # start, start + step, start + 2 * step, ...
        start = torch.ones([b, 1, h, w], device=bin_edges.device) * depth_start_update
        new_edges = torch.cumsum(torch.cat([start, steps], 1), 1).clamp(min_depth, max_depth)

        # Bin centres are the midpoints of neighbouring edges.
        curr_depth = 0.5 * (new_edges[:, :-1] + new_edges[:, 1:])
        return new_edges.detach(), curr_depth.detach()
@torch.no_grad()
def get_label(gt_depth_img, bin_edges, depth_num):
    """Assign each ground-truth depth value the index of the bin containing it.

    Bin ``k`` is the half-open interval
    ``[bin_edges[:, k], bin_edges[:, k + 1])`` for ``k`` in
    ``range(depth_num)``. Bins are visited in increasing order, so if several
    bins match a pixel the highest index wins. Pixels that match no bin keep
    the initial label 0.

    Args:
        gt_depth_img: Ground-truth depth values. Presumably shaped like
            ``bin_edges[:, k]`` (i.e. without the bin dimension) -- confirm
            against the caller.
        bin_edges: Per-pixel bin edges with the bin dimension at dim 1; at
            least ``depth_num + 1`` edges are indexed.
        depth_num: Number of bins to test.

    Returns:
        An int64 tensor with the shape and device of ``gt_depth_img`` holding
        the bin index for every pixel.
    """
    labels = torch.zeros_like(gt_depth_img, dtype=torch.int64)
    for idx in range(depth_num):
        lower = bin_edges[:, idx]
        upper = bin_edges[:, idx + 1]
        inside = (gt_depth_img >= lower) & (gt_depth_img < upper)
        labels.masked_fill_(inside, idx)
    return labels