from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
import torch.nn.functional as F
from PIL import Image

from ... import ConfigMixin
from ...configuration_utils import register_to_config
from ...image_processor import PipelineImageInput
from ...utils import CONFIG_NAME, logging
from ...utils.import_utils import is_matplotlib_available


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class MarigoldImageProcessor(ConfigMixin):
    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        vae_scale_factor: int = 8,
        do_normalize: bool = True,
        do_range_check: bool = True,
    ):
        super().__init__()

    @staticmethod
    def expand_tensor_or_array(images: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
        """
        Expand a 2- or 3-dimensional image tensor or array into a 4-dimensional batch of images.
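
        Example (illustrative; shapes are hypothetical):
            >>> MarigoldImageProcessor.expand_tensor_or_array(np.zeros((480, 640))).shape
            (1, 480, 640, 1)
            >>> MarigoldImageProcessor.expand_tensor_or_array(torch.zeros(480, 640)).shape
            torch.Size([1, 1, 480, 640])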
| """ | |
| if isinstance(images, np.ndarray): | |
| if images.ndim == 2: # [H,W] -> [1,H,W,1] | |
| images = images[None, ..., None] | |
| if images.ndim == 3: # [H,W,C] -> [1,H,W,C] | |
| images = images[None] | |
| elif isinstance(images, torch.Tensor): | |
| if images.ndim == 2: # [H,W] -> [1,1,H,W] | |
| images = images[None, None] | |
| elif images.ndim == 3: # [1,H,W] -> [1,1,H,W] | |
| images = images[None] | |
| else: | |
| raise ValueError(f"Unexpected input type: {type(images)}") | |
| return images | |
    @staticmethod
    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
        """
        Convert a batch of PyTorch images [N,C,H,W] into a NumPy array [N,H,W,C].
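
        Example (illustrative; shapes are hypothetical):
            >>> MarigoldImageProcessor.pt_to_numpy(torch.zeros(1, 3, 480, 640)).shape
            (1, 480, 640, 3)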
| """ | |
| images = images.cpu().permute(0, 2, 3, 1).float().numpy() | |
| return images | |
    @staticmethod
    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
        """
        Convert a NumPy image array [N,H,W,C] into a PyTorch tensor [N,C,H,W].
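
        Example (illustrative; shapes are hypothetical):
            >>> MarigoldImageProcessor.numpy_to_pt(np.zeros((1, 480, 640, 3), dtype=np.float32)).shape
            torch.Size([1, 3, 480, 640])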
| """ | |
| if np.issubdtype(images.dtype, np.integer) and not np.issubdtype(images.dtype, np.unsignedinteger): | |
| raise ValueError(f"Input image dtype={images.dtype} cannot be a signed integer.") | |
| if np.issubdtype(images.dtype, np.complexfloating): | |
| raise ValueError(f"Input image dtype={images.dtype} cannot be complex.") | |
| if np.issubdtype(images.dtype, bool): | |
| raise ValueError(f"Input image dtype={images.dtype} cannot be boolean.") | |
| images = torch.from_numpy(images.transpose(0, 3, 1, 2)) | |
| return images | |
    @staticmethod
    def resize_antialias(
        image: torch.Tensor, size: Tuple[int, int], mode: str, is_aa: Optional[bool] = None
    ) -> torch.Tensor:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        antialias = is_aa and mode in ("bilinear", "bicubic")
        image = F.interpolate(image, size, mode=mode, antialias=antialias)

        return image

    @staticmethod
    def resize_to_max_edge(image: torch.Tensor, max_edge_sz: int, mode: str) -> torch.Tensor:
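        """
        Rescale the image, preserving aspect ratio, so that its longer edge equals `max_edge_sz`.

        Example (illustrative sketch; hypothetical shapes):
            >>> x = torch.rand(1, 3, 480, 640)
            >>> MarigoldImageProcessor.resize_to_max_edge(x, 320, "bilinear").shape
            torch.Size([1, 3, 240, 320])
        """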
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        h, w = image.shape[-2:]
        max_orig = max(h, w)
        new_h = h * max_edge_sz // max_orig
        new_w = w * max_edge_sz // max_orig

        if new_h == 0 or new_w == 0:
            raise ValueError(f"Extreme aspect ratio of the input image: [{w} x {h}]")

        image = MarigoldImageProcessor.resize_antialias(image, (new_h, new_w), mode, is_aa=True)

        return image

    @staticmethod
    def pad_image(image: torch.Tensor, align: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
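        """
        Pad the bottom and right edges with replication so that both spatial dimensions become multiples of
        `align`; returns the padded image together with the applied (height, width) padding.

        Example (illustrative sketch; hypothetical shapes):
            >>> x = torch.rand(1, 3, 478, 638)
            >>> padded, padding = MarigoldImageProcessor.pad_image(x, 8)
            >>> padded.shape, padding
            (torch.Size([1, 3, 480, 640]), (2, 2))
        """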
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        h, w = image.shape[-2:]
        ph, pw = -h % align, -w % align

        image = F.pad(image, (0, pw, 0, ph), mode="replicate")

        return image, (ph, pw)

    @staticmethod
    def unpad_image(image: torch.Tensor, padding: Tuple[int, int]) -> torch.Tensor:
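        """
        Undo `pad_image` by cropping the (height, width) padding amounts from the bottom and right edges.

        Example (round trip with `pad_image`; hypothetical shapes):
            >>> x = torch.rand(1, 3, 478, 638)
            >>> padded, padding = MarigoldImageProcessor.pad_image(x, 8)
            >>> MarigoldImageProcessor.unpad_image(padded, padding).shape
            torch.Size([1, 3, 478, 638])
        """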
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        ph, pw = padding
        uh = None if ph == 0 else -ph
        uw = None if pw == 0 else -pw

        image = image[:, :, :uh, :uw]

        return image

    @staticmethod
    def load_image_canonical(
        image: Union[torch.Tensor, np.ndarray, Image.Image],
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        if isinstance(image, Image.Image):
            image = np.array(image)

        image_dtype_max = None
        if isinstance(image, (np.ndarray, torch.Tensor)):
            image = MarigoldImageProcessor.expand_tensor_or_array(image)
            if image.ndim != 4:
                raise ValueError("Input image is not 2-, 3-, or 4-dimensional.")
        if isinstance(image, np.ndarray):
            if np.issubdtype(image.dtype, np.integer) and not np.issubdtype(image.dtype, np.unsignedinteger):
                raise ValueError(f"Input image dtype={image.dtype} cannot be a signed integer.")
            if np.issubdtype(image.dtype, np.complexfloating):
                raise ValueError(f"Input image dtype={image.dtype} cannot be complex.")
            if np.issubdtype(image.dtype, bool):
                raise ValueError(f"Input image dtype={image.dtype} cannot be boolean.")
            if np.issubdtype(image.dtype, np.unsignedinteger):
                image_dtype_max = np.iinfo(image.dtype).max
                image = image.astype(np.float32)  # because torch does not have unsigned dtypes beyond torch.uint8
            image = MarigoldImageProcessor.numpy_to_pt(image)

        if torch.is_tensor(image) and not torch.is_floating_point(image) and image_dtype_max is None:
            if image.dtype != torch.uint8:
                raise ValueError(f"Image dtype={image.dtype} is not supported.")
            image_dtype_max = 255

        if not torch.is_tensor(image):
            raise ValueError(f"Input type unsupported: {type(image)}.")

        if image.shape[1] == 1:
            image = image.repeat(1, 3, 1, 1)  # [N,1,H,W] -> [N,3,H,W]
        if image.shape[1] != 3:
            raise ValueError(f"Input image is not 1- or 3-channel: {image.shape}.")

        image = image.to(device=device, dtype=dtype)

        if image_dtype_max is not None:
            image = image / image_dtype_max

        return image

    @staticmethod
    def check_image_values_range(image: torch.Tensor) -> None:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.min().item() < 0.0 or image.max().item() > 1.0:
            raise ValueError("Input image data is partially outside of the [0,1] range.")

    def preprocess(
        self,
        image: PipelineImageInput,
        processing_resolution: Optional[int] = None,
        resample_method_input: str = "bilinear",
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
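        """
        Load, validate, optionally normalize into [-1,1], resize to the processing resolution, and pad the
        input to a multiple of the VAE scale factor; returns the processed batch, the applied padding, and
        the original resolution.

        Example (illustrative sketch; assumes the default constructor arguments):
            >>> processor = MarigoldImageProcessor()
            >>> image, padding, original_resolution = processor.preprocess(torch.rand(1, 3, 478, 638))
            >>> image.shape, padding, original_resolution
            (torch.Size([1, 3, 480, 640]), (2, 2), torch.Size([478, 638]))
        """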
        if isinstance(image, list):
            images = None
            for i, img in enumerate(image):
                img = self.load_image_canonical(img, device, dtype)  # [N,3,H,W]
                if images is None:
                    images = img
                else:
                    if images.shape[2:] != img.shape[2:]:
                        raise ValueError(
                            f"Input image[{i}] has incompatible dimensions {img.shape[2:]} with the previous images "
                            f"{images.shape[2:]}"
                        )
                    images = torch.cat((images, img), dim=0)
            image = images
            del images
        else:
            image = self.load_image_canonical(image, device, dtype)  # [N,3,H,W]

        original_resolution = image.shape[2:]

        if self.config.do_range_check:
            self.check_image_values_range(image)

        if self.config.do_normalize:
            image = image * 2.0 - 1.0

        if processing_resolution is not None and processing_resolution > 0:
            image = self.resize_to_max_edge(image, processing_resolution, resample_method_input)  # [N,3,PH,PW]

        image, padding = self.pad_image(image, self.config.vae_scale_factor)  # [N,3,PPH,PPW]

        return image, padding, original_resolution

    @staticmethod
    def colormap(
        image: Union[np.ndarray, torch.Tensor],
        cmap: str = "Spectral",
        bytes: bool = False,
        _force_method: Optional[str] = None,
    ) -> Union[np.ndarray, torch.Tensor]:
| """ | |
| Converts a monochrome image into an RGB image by applying the specified colormap. This function mimics the | |
| behavior of matplotlib.colormaps, but allows the user to use the most discriminative color maps ("Spectral", | |
| "binary") without having to install or import matplotlib. For all other cases, the function will attempt to use | |
| the native implementation. | |
| Args: | |
| image: 2D tensor of values between 0 and 1, either as np.ndarray or torch.Tensor. | |
| cmap: Colormap name. | |
| bytes: Whether to return the output as uint8 or floating point image. | |
| _force_method: | |
| Can be used to specify whether to use the native implementation (`"matplotlib"`), the efficient custom | |
| implementation of the select color maps (`"custom"`), or rely on autodetection (`None`, default). | |
| Returns: | |
| An RGB-colorized tensor corresponding to the input image. | |
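
        Example (illustrative sketch):
            >>> out = MarigoldImageProcessor.colormap(torch.rand(480, 640), cmap="Spectral", bytes=True)
            >>> out.shape, out.dtype
            (torch.Size([480, 640, 3]), torch.uint8)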
| """ | |
| if not (torch.is_tensor(image) or isinstance(image, np.ndarray)): | |
| raise ValueError("Argument must be a numpy array or torch tensor.") | |
| if _force_method not in (None, "matplotlib", "custom"): | |
| raise ValueError("_force_method must be either `None`, `'matplotlib'` or `'custom'`.") | |
| supported_cmaps = { | |
| "binary": [ | |
| (1.0, 1.0, 1.0), | |
| (0.0, 0.0, 0.0), | |
| ], | |
| "Spectral": [ # Taken from matplotlib/_cm.py | |
| (0.61960784313725492, 0.003921568627450980, 0.25882352941176473), # 0.0 -> [0] | |
| (0.83529411764705885, 0.24313725490196078, 0.30980392156862746), | |
| (0.95686274509803926, 0.42745098039215684, 0.2627450980392157), | |
| (0.99215686274509807, 0.68235294117647061, 0.38039215686274508), | |
| (0.99607843137254903, 0.8784313725490196, 0.54509803921568623), | |
| (1.0, 1.0, 0.74901960784313726), | |
| (0.90196078431372551, 0.96078431372549022, 0.59607843137254901), | |
| (0.6705882352941176, 0.8666666666666667, 0.64313725490196083), | |
| (0.4, 0.76078431372549016, 0.6470588235294118), | |
| (0.19607843137254902, 0.53333333333333333, 0.74117647058823533), | |
| (0.36862745098039218, 0.30980392156862746, 0.63529411764705879), # 1.0 -> [K-1] | |
| ], | |
| } | |

        def method_matplotlib(image, cmap, bytes=False):
            if is_matplotlib_available():
                import matplotlib
            else:
                return None

            arg_is_pt, device = torch.is_tensor(image), None
            if arg_is_pt:
                image, device = image.cpu().numpy(), image.device

            if cmap not in matplotlib.colormaps:
                raise ValueError(
                    f"Unexpected color map {cmap}; available options are: {', '.join(list(matplotlib.colormaps.keys()))}"
                )

            cmap = matplotlib.colormaps[cmap]
            out = cmap(image, bytes=bytes)  # [?,4]
            out = out[..., :3]  # [?,3]

            if arg_is_pt:
                out = torch.tensor(out, device=device)

            return out

        def method_custom(image, cmap, bytes=False):
            arg_is_np = isinstance(image, np.ndarray)
            if arg_is_np:
                image = torch.tensor(image)
            if image.dtype == torch.uint8:
                image = image.float() / 255
            else:
                image = image.float()

            is_cmap_reversed = cmap.endswith("_r")
            if is_cmap_reversed:
                cmap = cmap[:-2]

            if cmap not in supported_cmaps:
                raise ValueError(
                    f"Only {list(supported_cmaps.keys())} color maps are available without installing matplotlib."
                )

            cmap = supported_cmaps[cmap]
            if is_cmap_reversed:
                cmap = cmap[::-1]
            cmap = torch.tensor(cmap, dtype=torch.float, device=image.device)  # [K,3]
            K = cmap.shape[0]

            pos = image.clamp(min=0, max=1) * (K - 1)
            left = pos.long()
            right = (left + 1).clamp(max=K - 1)

            d = (pos - left.float()).unsqueeze(-1)
            left_colors = cmap[left]
            right_colors = cmap[right]

            out = (1 - d) * left_colors + d * right_colors

            if bytes:
                out = (out * 255).to(torch.uint8)
            if arg_is_np:
                out = out.numpy()
            return out

        if _force_method is None and torch.is_tensor(image) and cmap == "Spectral":
            return method_custom(image, cmap, bytes)

        out = None
        if _force_method != "custom":
            out = method_matplotlib(image, cmap, bytes)

        if _force_method == "matplotlib" and out is None:
            raise ImportError("Make sure to install matplotlib if you want to use a color map other than 'Spectral'.")

        if out is None:
            out = method_custom(image, cmap, bytes)

        return out

    @staticmethod
    def visualize_depth(
        depth: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.Tensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.Tensor],
        ],
        val_min: float = 0.0,
        val_max: float = 1.0,
        color_map: str = "Spectral",
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
| """ | |
| Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`. | |
| Args: | |
| depth (`Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray], | |
| List[torch.Tensor]]`): Depth maps. | |
| val_min (`float`, *optional*, defaults to `0.0`): Minimum value of the visualized depth range. | |
| val_max (`float`, *optional*, defaults to `1.0`): Maximum value of the visualized depth range. | |
| color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel | |
| depth prediction into colored representation. | |
| Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with depth maps visualization. | |
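
        Example (illustrative sketch; hypothetical prediction):
            >>> depth = torch.rand(1, 1, 480, 640)  # [N,1,H,W] in [0,1]
            >>> MarigoldImageProcessor.visualize_depth(depth)[0].size
            (640, 480)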
| """ | |
| if val_max <= val_min: | |
| raise ValueError(f"Invalid values range: [{val_min}, {val_max}].") | |
| def visualize_depth_one(img, idx=None): | |
| prefix = "Depth" + (f"[{idx}]" if idx else "") | |
| if isinstance(img, PIL.Image.Image): | |
| if img.mode != "I;16": | |
| raise ValueError(f"{prefix}: invalid PIL mode={img.mode}.") | |
| img = np.array(img).astype(np.float32) / (2**16 - 1) | |
| if isinstance(img, np.ndarray) or torch.is_tensor(img): | |
| if img.ndim != 2: | |
| raise ValueError(f"{prefix}: unexpected shape={img.shape}.") | |
| if isinstance(img, np.ndarray): | |
| img = torch.from_numpy(img) | |
| if not torch.is_floating_point(img): | |
| raise ValueError(f"{prefix}: unexected dtype={img.dtype}.") | |
| else: | |
| raise ValueError(f"{prefix}: unexpected type={type(img)}.") | |
| if val_min != 0.0 or val_max != 1.0: | |
| img = (img - val_min) / (val_max - val_min) | |
| img = MarigoldImageProcessor.colormap(img, cmap=color_map, bytes=True) # [H,W,3] | |
| img = PIL.Image.fromarray(img.cpu().numpy()) | |
| return img | |
        if depth is None or isinstance(depth, list) and any(o is None for o in depth):
            raise ValueError("Input depth is `None`")
        if isinstance(depth, (np.ndarray, torch.Tensor)):
            depth = MarigoldImageProcessor.expand_tensor_or_array(depth)
            if isinstance(depth, np.ndarray):
                depth = MarigoldImageProcessor.numpy_to_pt(depth)  # [N,H,W,1] -> [N,1,H,W]
            if not (depth.ndim == 4 and depth.shape[1] == 1):  # [N,1,H,W]
                raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].")
            return [visualize_depth_one(img[0], idx) for idx, img in enumerate(depth)]
        elif isinstance(depth, list):
            return [visualize_depth_one(img, idx) for idx, img in enumerate(depth)]
        else:
            raise ValueError(f"Unexpected input type: {type(depth)}")

    @staticmethod
    def export_depth_to_16bit_png(
        depth: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
        val_min: float = 0.0,
        val_max: float = 1.0,
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:

        def export_depth_to_16bit_png_one(img, idx=None):
            prefix = "Depth" + (f"[{idx}]" if idx is not None else "")
            if not isinstance(img, np.ndarray) and not torch.is_tensor(img):
                raise ValueError(f"{prefix}: unexpected type={type(img)}.")
            if img.ndim != 2:
                raise ValueError(f"{prefix}: unexpected shape={img.shape}.")
            if torch.is_tensor(img):
                img = img.cpu().numpy()
            if not np.issubdtype(img.dtype, np.floating):
                raise ValueError(f"{prefix}: unexpected dtype={img.dtype}.")
            if val_min != 0.0 or val_max != 1.0:
                img = (img - val_min) / (val_max - val_min)
            img = (img * (2**16 - 1)).astype(np.uint16)
            img = PIL.Image.fromarray(img, mode="I;16")
            return img
        if depth is None or isinstance(depth, list) and any(o is None for o in depth):
            raise ValueError("Input depth is `None`")
        if isinstance(depth, (np.ndarray, torch.Tensor)):
            depth = MarigoldImageProcessor.expand_tensor_or_array(depth)
            if isinstance(depth, np.ndarray):
                depth = MarigoldImageProcessor.numpy_to_pt(depth)  # [N,H,W,1] -> [N,1,H,W]
            if not (depth.ndim == 4 and depth.shape[1] == 1):
                raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].")
            return [export_depth_to_16bit_png_one(img[0], idx) for idx, img in enumerate(depth)]
        elif isinstance(depth, list):
            return [export_depth_to_16bit_png_one(img, idx) for idx, img in enumerate(depth)]
        else:
            raise ValueError(f"Unexpected input type: {type(depth)}")

    @staticmethod
    def visualize_normals(
        normals: Union[
            np.ndarray,
            torch.Tensor,
            List[np.ndarray],
            List[torch.Tensor],
        ],
        flip_x: bool = False,
        flip_y: bool = False,
        flip_z: bool = False,
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
| """ | |
| Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`. | |
| Args: | |
| normals (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`): | |
| Surface normals. | |
| flip_x (`bool`, *optional*, defaults to `False`): Flips the X axis of the normals frame of reference. | |
| Default direction is right. | |
| flip_y (`bool`, *optional*, defaults to `False`): Flips the Y axis of the normals frame of reference. | |
| Default direction is top. | |
| flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference. | |
| Default direction is facing the observer. | |
| Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with surface normals visualization. | |
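
        Example (illustrative sketch; hypothetical prediction):
            >>> normals = F.normalize(torch.randn(1, 3, 480, 640), dim=1)  # unit normals in [N,3,H,W]
            >>> MarigoldImageProcessor.visualize_normals(normals)[0].size
            (640, 480)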
| """ | |
| flip_vec = None | |
| if any((flip_x, flip_y, flip_z)): | |
| flip_vec = torch.tensor( | |
| [ | |
| (-1) ** flip_x, | |
| (-1) ** flip_y, | |
| (-1) ** flip_z, | |
| ], | |
| dtype=torch.float32, | |
| ) | |
| def visualize_normals_one(img, idx=None): | |
| img = img.permute(1, 2, 0) | |
| if flip_vec is not None: | |
| img *= flip_vec.to(img.device) | |
| img = (img + 1.0) * 0.5 | |
| img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy() | |
| img = PIL.Image.fromarray(img) | |
| return img | |
| if normals is None or isinstance(normals, list) and any(o is None for o in normals): | |
| raise ValueError("Input normals is `None`") | |
| if isinstance(normals, (np.ndarray, torch.Tensor)): | |
| normals = MarigoldImageProcessor.expand_tensor_or_array(normals) | |
| if isinstance(normals, np.ndarray): | |
| normals = MarigoldImageProcessor.numpy_to_pt(normals) # [N,3,H,W] | |
| if not (normals.ndim == 4 and normals.shape[1] == 3): | |
| raise ValueError(f"Unexpected input shape={normals.shape}, expecting [N,3,H,W].") | |
| return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)] | |
| elif isinstance(normals, list): | |
| return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)] | |
| else: | |
| raise ValueError(f"Unexpected input type: {type(normals)}") | |
    @staticmethod
    def visualize_uncertainty(
        uncertainty: Union[
            np.ndarray,
            torch.Tensor,
            List[np.ndarray],
            List[torch.Tensor],
        ],
        saturation_percentile=95,
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
| """ | |
| Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline` or `MarigoldNormalsPipeline`. | |
| Args: | |
| uncertainty (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`): | |
| Uncertainty maps. | |
| saturation_percentile (`int`, *optional*, defaults to `95`): | |
| Specifies the percentile uncertainty value visualized with maximum intensity. | |
| Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with uncertainty visualization. | |
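
        Example (illustrative sketch; hypothetical prediction):
            >>> uncertainty = torch.rand(1, 1, 480, 640)
            >>> MarigoldImageProcessor.visualize_uncertainty(uncertainty)[0].size
            (640, 480)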
| """ | |
| def visualize_uncertainty_one(img, idx=None): | |
| prefix = "Uncertainty" + (f"[{idx}]" if idx else "") | |
| if img.min() < 0: | |
| raise ValueError(f"{prefix}: unexected data range, min={img.min()}.") | |
| img = img.squeeze(0).cpu().numpy() | |
| saturation_value = np.percentile(img, saturation_percentile) | |
| img = np.clip(img * 255 / saturation_value, 0, 255) | |
| img = img.astype(np.uint8) | |
| img = PIL.Image.fromarray(img) | |
| return img | |
| if uncertainty is None or isinstance(uncertainty, list) and any(o is None for o in uncertainty): | |
| raise ValueError("Input uncertainty is `None`") | |
| if isinstance(uncertainty, (np.ndarray, torch.Tensor)): | |
| uncertainty = MarigoldImageProcessor.expand_tensor_or_array(uncertainty) | |
| if isinstance(uncertainty, np.ndarray): | |
| uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty) # [N,1,H,W] | |
| if not (uncertainty.ndim == 4 and uncertainty.shape[1] == 1): | |
| raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,1,H,W].") | |
| return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)] | |
| elif isinstance(uncertainty, list): | |
| return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)] | |
| else: | |
| raise ValueError(f"Unexpected input type: {type(uncertainty)}") | |