# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union

import cv2
import numpy as np
import torch

if TYPE_CHECKING:
    from matplotlib.backends.backend_agg import FigureCanvasAgg


def tensor2ndarray(value: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """If the type of value is torch.Tensor, convert the value to np.ndarray.

    Args:
        value (np.ndarray, torch.Tensor): value.

    Returns:
        np.ndarray: value.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()
    return value
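
# Illustrative usage (not part of the original file): tensors are detached and
# moved to CPU before conversion, while ndarrays pass through unchanged.
#     >>> tensor2ndarray(torch.ones(2, 3)).shape
#     (2, 3)
#     >>> tensor2ndarray(np.zeros(4)).dtype
#     dtype('float64')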


def value2list(value: Any, valid_type: Union[Type, Tuple[Type, ...]],
               expand_dim: int) -> List[Any]:
    """If the type of ``value`` is ``valid_type``, convert the value to list
    and expand to ``expand_dim``.

    Args:
        value (Any): value.
        valid_type (Union[Type, Tuple[Type, ...]]): valid type.
        expand_dim (int): expand dim.

    Returns:
        List[Any]: value.
    """
    if isinstance(value, valid_type):
        value = [value] * expand_dim
    return value
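
# Illustrative usage (not part of the original file): a single value matching
# ``valid_type`` is broadcast into a list of length ``expand_dim``; anything
# else is returned untouched.
#     >>> value2list('red', str, 3)
#     ['red', 'red', 'red']
#     >>> value2list(['red', 'blue'], str, 3)
#     ['red', 'blue']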


def check_type(name: str, value: Any,
               valid_type: Union[Type, Tuple[Type, ...]]) -> None:
    """Check whether the type of value is in ``valid_type``.

    Args:
        name (str): value name.
        value (Any): value.
        valid_type (Type, Tuple[Type, ...]): expected type.
    """
    if not isinstance(value, valid_type):
        raise TypeError(f'`{name}` should be {valid_type} '
                        f'but got {type(value)}')
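
# Illustrative usage (not part of the original file): a mismatched type raises
# ``TypeError`` with the offending name in the message.
#     >>> check_type('alpha', 0.5, float)    # passes silently
#     >>> check_type('alpha', '0.5', float)  # raises TypeError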


def check_length(name: str, value: Any, valid_length: int) -> None:
    """If the type of ``value`` is list, check whether its length is equal to
    or greater than ``valid_length``.

    Args:
        name (str): value name.
        value (Any): value.
        valid_length (int): expected length.
    """
    if isinstance(value, list):
        if len(value) < valid_length:
            raise AssertionError(
                f'The length of {name} must be equal to or '
                f'greater than {valid_length}, but got {len(value)}')
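
# Illustrative usage (not part of the original file): only lists are checked;
# a list shorter than ``valid_length`` raises ``AssertionError``.
#     >>> check_length('colors', ['r', 'g', 'b'], 2)  # passes
#     >>> check_length('colors', ['r'], 2)            # raises AssertionError
#     >>> check_length('colors', 'r', 2)              # non-list, no check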


def check_type_and_length(name: str, value: Any,
                          valid_type: Union[Type, Tuple[Type, ...]],
                          valid_length: int) -> None:
    """Check whether the type of value is in ``valid_type``. If the type of
    ``value`` is list, check whether its length is equal to or greater than
    ``valid_length``.

    Args:
        name (str): value name.
        value (Any): value.
        valid_type (Type, Tuple[Type, ...]): expected type.
        valid_length (int): expected length.
    """
    check_type(name, value, valid_type)
    check_length(name, value, valid_length)
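
# Illustrative usage (not part of the original file): both checks run in
# sequence, so the value must have ``valid_type`` and, if it is a list,
# contain at least ``valid_length`` items.
#     >>> check_type_and_length('colors', ['r', 'g'], list, 2)  # passes
#     >>> check_type_and_length('colors', ('r', 'g'), list, 2)  # raises TypeError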


def color_val_matplotlib(
    colors: Union[str, tuple, List[Union[str, tuple]]]
) -> Union[str, tuple, List[Union[str, tuple]]]:
    """Convert various inputs in RGB order to normalized RGB matplotlib color
    tuples.

    Args:
        colors (Union[str, tuple, List[Union[str, tuple]]]): Color inputs.

    Returns:
        Union[str, tuple, List[Union[str, tuple]]]: A tuple of 3 normalized
        floats indicating RGB channels, the input string unchanged, or a
        list converted element-wise.
    """
    if isinstance(colors, str):
        return colors
    elif isinstance(colors, tuple):
        assert len(colors) == 3
        for channel in colors:
            assert 0 <= channel <= 255
        colors = [channel / 255 for channel in colors]
        return tuple(colors)
    elif isinstance(colors, list):
        colors = [
            color_val_matplotlib(color)  # type: ignore
            for color in colors
        ]
        return colors
    else:
        raise TypeError(f'Invalid type for color: {type(colors)}')
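
# Illustrative usage (not part of the original file): strings are returned
# as-is, 0-255 RGB tuples are normalized to 0-1, and lists are converted
# element-wise.
#     >>> color_val_matplotlib('blue')
#     'blue'
#     >>> color_val_matplotlib((255, 0, 0))
#     (1.0, 0.0, 0.0)
#     >>> color_val_matplotlib(['blue', (0, 255, 0)])
#     ['blue', (0.0, 1.0, 0.0)]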


def color_str2rgb(color: str) -> tuple:
    """Convert a Matplotlib color string to an RGB color whose range is 0 to
    255, silently dropping the alpha channel.

    Args:
        color (str): Matplotlib color string.

    Returns:
        tuple: RGB color.
    """
    import matplotlib
    rgb_color: tuple = matplotlib.colors.to_rgb(color)
    rgb_color = tuple(int(c * 255) for c in rgb_color)
    return rgb_color
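
# Illustrative usage (not part of the original file): named colors and hex
# strings both map to integer RGB tuples in the 0-255 range.
#     >>> color_str2rgb('red')
#     (255, 0, 0)
#     >>> color_str2rgb('#00ff00')
#     (0, 255, 0)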


def convert_overlay_heatmap(feat_map: Union[np.ndarray, torch.Tensor],
                            img: Optional[np.ndarray] = None,
                            alpha: float = 0.5) -> np.ndarray:
    """Convert feat_map to heatmap and overlay on image, if image is not None.

    Args:
        feat_map (np.ndarray, torch.Tensor): The feat_map to convert, with
            shape (H, W) or (C, H, W) where C is 1 or 3, H is the image
            height and W is the image width.
        img (np.ndarray, optional): The origin image. The format
            should be RGB. Defaults to None.
        alpha (float): The transparency of featmap. Defaults to 0.5.

    Returns:
        np.ndarray: heatmap
    """
    assert feat_map.ndim == 2 or (feat_map.ndim == 3
                                  and feat_map.shape[0] in [1, 3])
    if isinstance(feat_map, torch.Tensor):
        feat_map = feat_map.detach().cpu().numpy()

    if feat_map.ndim == 3:
        feat_map = feat_map.transpose(1, 2, 0)

    norm_img = np.zeros(feat_map.shape)
    norm_img = cv2.normalize(feat_map, norm_img, 0, 255, cv2.NORM_MINMAX)
    norm_img = np.asarray(norm_img, dtype=np.uint8)
    heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET)
    heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB)
    if img is not None:
        heat_img = cv2.addWeighted(img, 1 - alpha, heat_img, alpha, 0)
    return heat_img
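
# Illustrative usage (not part of the original file; shapes chosen arbitrarily):
# a (H, W) feature map becomes an (H, W, 3) RGB heatmap, optionally blended
# with an RGB image of the same size.
#     >>> feat = torch.rand(32, 32)
#     >>> convert_overlay_heatmap(feat).shape
#     (32, 32, 3)
#     >>> img = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
#     >>> convert_overlay_heatmap(feat, img, alpha=0.3).shape
#     (32, 32, 3)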


def wait_continue(figure, timeout: float = 0, continue_key: str = ' ') -> int:
    """Show the image and wait for the user's input.

    This implementation refers to
    https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py

    Args:
        figure (matplotlib.figure.Figure): The figure to show.
        timeout (float): If positive, continue after ``timeout`` seconds.
            Defaults to 0.
        continue_key (str): The key for users to continue. Defaults to
            the space key.

    Returns:
        int: If zero, means time out or the user pressed ``continue_key``,
            and if one, means the user closed the shown figure.
    """  # noqa: E501
    import matplotlib.pyplot as plt
    from matplotlib.backend_bases import CloseEvent

    is_inline = 'inline' in plt.get_backend()
    if is_inline:
        # With an inline backend, interactive input and timeout have no effect.
        return 0

    if figure.canvas.manager:  # type: ignore
        # Ensure that the figure is shown
        figure.show()  # type: ignore

    while True:
        # Connect the events to the handler function call.
        event = None

        def handler(ev):
            # Set external event variable
            nonlocal event
            # Qt backend may fire two events at the same time,
            # use a condition to avoid missing close event.
            event = ev if not isinstance(event, CloseEvent) else event
            figure.canvas.stop_event_loop()

        cids = [
            figure.canvas.mpl_connect(name, handler)  # type: ignore
            for name in ('key_press_event', 'close_event')
        ]

        try:
            figure.canvas.start_event_loop(timeout)  # type: ignore
        finally:  # Run even on exception like ctrl-c.
            # Disconnect the callbacks.
            for cid in cids:
                figure.canvas.mpl_disconnect(cid)  # type: ignore

        if isinstance(event, CloseEvent):
            return 1  # Quit for close.
        elif event is None or event.key == continue_key:
            return 0  # Quit for continue.
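
# Illustrative usage (not part of the original file; requires an interactive
# matplotlib backend such as Qt or Tk):
#     >>> import matplotlib.pyplot as plt
#     >>> fig, ax = plt.subplots()
#     >>> _ = ax.imshow(np.random.rand(8, 8))
#     >>> wait_continue(fig, timeout=5)  # 0 on space key / timeout, 1 if closed
#     0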


def img_from_canvas(canvas: 'FigureCanvasAgg') -> np.ndarray:
    """Get an RGB image from ``FigureCanvasAgg``.

    Args:
        canvas (FigureCanvasAgg): The canvas to get the image from.

    Returns:
        np.ndarray: The image in RGB format.
    """  # noqa: E501
    s, (width, height) = canvas.print_to_buffer()
    buffer = np.frombuffer(s, dtype='uint8')
    img_rgba = buffer.reshape(height, width, 4)
    rgb, alpha = np.split(img_rgba, [3], axis=2)
    return rgb.astype('uint8')
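
# Illustrative usage (not part of the original file): render a figure with the
# Agg backend and grab its pixels as an RGB array.
#     >>> from matplotlib.figure import Figure
#     >>> from matplotlib.backends.backend_agg import FigureCanvasAgg
#     >>> fig = Figure(figsize=(2, 2), dpi=100)
#     >>> canvas = FigureCanvasAgg(fig)
#     >>> img_from_canvas(canvas).shape
#     (200, 200, 3)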