Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import inspect | |
| import os.path as osp | |
| import warnings | |
| from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union | |
| if TYPE_CHECKING: | |
| from matplotlib.font_manager import FontProperties | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from mmengine.config import Config | |
| from mmengine.dist import master_only | |
| from mmengine.registry import VISBACKENDS, VISUALIZERS | |
| from mmengine.structures import BaseDataElement | |
| from mmengine.utils import ManagerMixin, is_seq_of | |
| from mmengine.visualization.utils import (check_type, check_type_and_length, | |
| color_str2rgb, color_val_matplotlib, | |
| convert_overlay_heatmap, | |
| img_from_canvas, tensor2ndarray, | |
| value2list, wait_continue) | |
| from mmengine.visualization.vis_backend import BaseVisBackend | |
# Accepted forms of the ``vis_backends`` argument of ``Visualizer``: a single
# backend config dict, a single ``BaseVisBackend`` instance, a list mixing
# the two, or None (no backends).
VisBackendsType = Union[List[Union[dict, BaseVisBackend]], dict,
                        BaseVisBackend, None]
class Visualizer(ManagerMixin):
    """MMEngine provides a Visualizer class that uses the ``Matplotlib``
    library as the backend. It has the following functions:

    - Basic drawing methods

      - draw_bboxes: draw single or multiple bounding boxes
      - draw_texts: draw single or multiple text boxes
      - draw_points: draw single or multiple points
      - draw_lines: draw single or multiple line segments
      - draw_circles: draw single or multiple circles
      - draw_polygons: draw single or multiple polygons
      - draw_binary_masks: draw single or multiple binary masks
      - draw_featmap: draw feature map

    - Basic visualizer backend methods

      - add_configs: write config to all vis storage backends
      - add_graph: write model graph to all vis storage backends
      - add_image: write image to all vis storage backends
      - add_scalar: write scalar to all vis storage backends
      - add_scalars: write scalars to all vis storage backends
      - add_datasample: the abstract drawing interface used by the user;
        writes a data sample to all vis storage backends

    - Basic info methods

      - set_image: sets the original image data
      - get_image: get the image data in Numpy format after drawing
      - show: visualization
      - close: close all resources that have been opened
      - get_backend: get the specified vis backend

    All the basic drawing methods return the visualizer itself and therefore
    support chain calls, which is convenient for overlay drawing and display.
    Each downstream algorithm library can inherit ``Visualizer`` and
    implement the add_datasample logic. For example, ``DetLocalVisualizer``
    in MMDetection inherits from ``Visualizer`` and implements functions,
    such as visual detection boxes, instance masks, and semantic
    segmentation maps in the add_datasample interface.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): the origin image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Vis backend config dict(s) and/or
            ``BaseVisBackend`` instance(s), given as a single item or a
            list. Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            Backend configs get ``<save_dir>/vis_data`` as their default
            ``save_dir``. If it is None, backends whose constructor requires
            ``save_dir`` are not added.
        fig_save_cfg (dict): Keyword parameters of figure for saving.
            Defaults to ``dict(frameon=False)``.
        fig_show_cfg (dict): Keyword parameters of figure for showing.
            Defaults to ``dict(frameon=False)``.

    Examples:
        >>> # Basic info methods
        >>> vis = Visualizer()
        >>> vis.set_image(image)
        >>> vis.get_image()
        >>> vis.show()

        >>> # Basic drawing methods
        >>> vis = Visualizer(image=image)
        >>> vis.draw_bboxes(np.array([0, 0, 1, 1]), edge_colors='g')
        >>> vis.draw_bboxes(bboxes=np.array([[1, 1, 2, 2], [2, 2, 3, 3]]),
        >>>                 edge_colors=['g', 'r'])
        >>> vis.draw_lines(x_datas=np.array([1, 3]),
        >>>                y_datas=np.array([1, 3]),
        >>>                colors='r', line_widths=1)
        >>> vis.draw_lines(x_datas=np.array([[1, 3], [2, 4]]),
        >>>                y_datas=np.array([[1, 3], [2, 4]]),
        >>>                colors=['r', 'r'], line_widths=[1, 2])
        >>> vis.draw_texts(texts='MMEngine',
        >>>                positions=np.array([2, 2]),
        >>>                colors='b')
        >>> vis.draw_texts(texts=['MMEngine', 'OpenMMLab'],
        >>>                positions=np.array([[2, 2], [5, 5]]),
        >>>                colors=['b', 'b'])
        >>> vis.draw_circles(center=np.array([2, 2]), radius=np.array([1]))
        >>> vis.draw_circles(center=np.array([[2, 2], [3, 5]]),
        >>>                  radius=np.array([1, 2]), edge_colors=['g', 'r'])
        >>> square = np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
        >>> vis.draw_polygons(polygons=square, edge_colors='g')
        >>> squares = [np.array([[0, 0], [100, 0], [100, 100], [0, 100]]),
        >>>            np.array([[0, 0], [50, 0], [50, 50], [0, 50]])]
        >>> vis.draw_polygons(polygons=squares, edge_colors=['g', 'r'])
        >>> vis.draw_binary_masks(binary_mask, alpha=0.6)
        >>> heatmap = vis.draw_featmap(featmap, img,
        >>>                            channel_reduction='select_max')
        >>> heatmap = vis.draw_featmap(featmap, img, channel_reduction=None,
        >>>                            topk=8, arrangement=(4, 2))
        >>> heatmap = vis.draw_featmap(featmap, img, channel_reduction=None,
        >>>                            topk=-1)

        >>> # chain calls
        >>> vis.draw_bboxes().draw_texts().draw_circles().draw_binary_masks()

        >>> # Backend related methods
        >>> vis = Visualizer(vis_backends=[dict(type='LocalVisBackend')],
        >>>                  save_dir='temp_dir')
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> vis.add_config(cfg)
        >>> image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
        >>> vis.add_image('image', image)
        >>> vis.add_scalar('mAP', 0.6)
        >>> vis.add_scalars({'loss': 0.1, 'acc': 0.8})

        >>> # inherit
        >>> class DetLocalVisualizer(Visualizer):
        >>>     def add_datasample(self,
        >>>                        name,
        >>>                        image: np.ndarray,
        >>>                        gt_sample:
        >>>                            Optional['BaseDataElement'] = None,
        >>>                        pred_sample:
        >>>                            Optional['BaseDataElement'] = None,
        >>>                        draw_gt: bool = True,
        >>>                        draw_pred: bool = True,
        >>>                        show: bool = False,
        >>>                        wait_time: int = 0,
        >>>                        step: int = 0) -> None:
        >>>         pass
    """
| def __init__( | |
| self, | |
| name='visualizer', | |
| image: Optional[np.ndarray] = None, | |
| vis_backends: VisBackendsType = None, | |
| save_dir: Optional[str] = None, | |
| fig_save_cfg=dict(frameon=False), | |
| fig_show_cfg=dict(frameon=False) | |
| ) -> None: | |
| super().__init__(name) | |
| self._dataset_meta: Optional[dict] = None | |
| self._vis_backends: Dict[str, BaseVisBackend] = {} | |
| if vis_backends is None: | |
| vis_backends = [] | |
| if isinstance(vis_backends, (dict, BaseVisBackend)): | |
| vis_backends = [vis_backends] # type: ignore | |
| if not is_seq_of(vis_backends, (dict, BaseVisBackend)): | |
| raise TypeError('vis_backends must be a list of dicts or a list ' | |
| 'of BaseBackend instances') | |
| if save_dir is not None: | |
| save_dir = osp.join(save_dir, 'vis_data') | |
| for vis_backend in vis_backends: # type: ignore | |
| name = None | |
| if isinstance(vis_backend, dict): | |
| name = vis_backend.pop('name', None) | |
| vis_backend.setdefault('save_dir', save_dir) | |
| vis_backend = VISBACKENDS.build(vis_backend) | |
| # If vis_backend requires `save_dir` (with no default value) | |
| # but is initialized with None, then don't add this | |
| # vis_backend to the visualizer. | |
| save_dir_arg = inspect.signature( | |
| vis_backend.__class__.__init__).parameters.get('save_dir') | |
| if (save_dir_arg is not None | |
| and save_dir_arg.default is save_dir_arg.empty | |
| and getattr(vis_backend, '_save_dir') is None): | |
| # warnings.warn(f'Failed to add {vis_backend.__class__}, please provide the `save_dir` argument.') | |
| continue | |
| type_name = vis_backend.__class__.__name__ | |
| name = name or type_name | |
| if name in self._vis_backends: | |
| raise RuntimeError(f'vis_backend name {name} already exists') | |
| self._vis_backends[name] = vis_backend # type: ignore | |
| self.fig_save = None | |
| self.fig_save_cfg = fig_save_cfg | |
| self.fig_show_cfg = fig_show_cfg | |
| (self.fig_save_canvas, self.fig_save, | |
| self.ax_save) = self._initialize_fig(fig_save_cfg) | |
| self.dpi = self.fig_save.get_dpi() | |
| if image is not None: | |
| self.set_image(image) | |
| # type: ignore | |
| def dataset_meta(self) -> Optional[dict]: | |
| """Optional[dict]: Meta info of the dataset.""" | |
| return self._dataset_meta | |
| # type: ignore | |
| def dataset_meta(self, dataset_meta: dict) -> None: | |
| """Set the dataset meta info to the Visualizer.""" | |
| self._dataset_meta = dataset_meta | |
| def show(self, | |
| drawn_img: Optional[np.ndarray] = None, | |
| win_name: str = 'image', | |
| wait_time: float = 0., | |
| continue_key: str = ' ', | |
| backend: str = 'matplotlib') -> None: | |
| """Show the drawn image. | |
| Args: | |
| drawn_img (np.ndarray, optional): The image to show. If drawn_img | |
| is None, it will show the image got by Visualizer. Defaults | |
| to None. | |
| win_name (str): The image title. Defaults to 'image'. | |
| wait_time (float): Delay in seconds. 0 is the special | |
| value that means "forever". Defaults to 0. | |
| continue_key (str): The key for users to continue. Defaults to | |
| the space key. | |
| backend (str): The backend to show the image. Defaults to | |
| 'matplotlib'. `New in version 0.7.3.` | |
| """ | |
| if backend == 'matplotlib': | |
| import matplotlib.pyplot as plt | |
| is_inline = 'inline' in plt.get_backend() | |
| img = self.get_image() if drawn_img is None else drawn_img | |
| self._init_manager(win_name) | |
| fig = self.manager.canvas.figure | |
| # remove white edges by set subplot margin | |
| fig.subplots_adjust(left=0, right=1, bottom=0, top=1) | |
| fig.clear() | |
| ax = fig.add_subplot() | |
| ax.axis(False) | |
| ax.imshow(img) | |
| self.manager.canvas.draw() | |
| # Find a better way for inline to show the image | |
| if is_inline: | |
| return fig | |
| wait_continue(fig, timeout=wait_time, continue_key=continue_key) | |
| elif backend == 'cv2': | |
| # Keep images are shown in the same window, and the title of window | |
| # will be updated with `win_name`. | |
| cv2.namedWindow(winname=f'{id(self)}') | |
| cv2.setWindowTitle(f'{id(self)}', win_name) | |
| cv2.imshow( | |
| str(id(self)), | |
| self.get_image() if drawn_img is None else drawn_img) | |
| cv2.waitKey(int(np.ceil(wait_time * 1000))) | |
| else: | |
| raise ValueError('backend should be "matplotlib" or "cv2", ' | |
| f'but got {backend} instead') | |
| def set_image(self, image: np.ndarray) -> None: | |
| """Set the image to draw. | |
| Args: | |
| image (np.ndarray): The image to draw. | |
| """ | |
| assert image is not None | |
| image = image.astype('uint8') | |
| self._image = image | |
| self.width, self.height = image.shape[1], image.shape[0] | |
| self._default_font_size = max( | |
| np.sqrt(self.height * self.width) // 90, 10) | |
| # add a small 1e-2 to avoid precision lost due to matplotlib's | |
| # truncation (https://github.com/matplotlib/matplotlib/issues/15363) | |
| self.fig_save.set_size_inches( # type: ignore | |
| (self.width + 1e-2) / self.dpi, (self.height + 1e-2) / self.dpi) | |
| # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) | |
| self.ax_save.cla() | |
| self.ax_save.axis(False) | |
| self.ax_save.imshow( | |
| image, | |
| extent=(0, self.width, self.height, 0), | |
| interpolation='none') | |
| def get_image(self) -> np.ndarray: | |
| """Get the drawn image. The format is RGB. | |
| Returns: | |
| np.ndarray: the drawn image which channel is RGB. | |
| """ | |
| assert self._image is not None, 'Please set image using `set_image`' | |
| return img_from_canvas(self.fig_save_canvas) # type: ignore | |
| def _initialize_fig(self, fig_cfg) -> tuple: | |
| """Build figure according to fig_cfg. | |
| Args: | |
| fig_cfg (dict): The config to build figure. | |
| Returns: | |
| tuple: build canvas figure and axes. | |
| """ | |
| from matplotlib.backends.backend_agg import FigureCanvasAgg | |
| from matplotlib.figure import Figure | |
| fig = Figure(**fig_cfg) | |
| ax = fig.add_subplot() | |
| ax.axis(False) | |
| # remove white edges by set subplot margin | |
| fig.subplots_adjust(left=0, right=1, bottom=0, top=1) | |
| canvas = FigureCanvasAgg(fig) | |
| return canvas, fig, ax | |
| def _init_manager(self, win_name: str) -> None: | |
| """Initialize the matplot manager. | |
| Args: | |
| win_name (str): The window name. | |
| """ | |
| from matplotlib.figure import Figure | |
| from matplotlib.pyplot import new_figure_manager | |
| if getattr(self, 'manager', None) is None: | |
| self.manager = new_figure_manager( | |
| num=1, FigureClass=Figure, **self.fig_show_cfg) | |
| try: | |
| self.manager.set_window_title(win_name) | |
| except Exception: | |
| self.manager = new_figure_manager( | |
| num=1, FigureClass=Figure, **self.fig_show_cfg) | |
| self.manager.set_window_title(win_name) | |
| def get_backend(self, name) -> 'BaseVisBackend': | |
| """get vis backend by name. | |
| Args: | |
| name (str): The name of vis backend | |
| Returns: | |
| BaseVisBackend: The vis backend. | |
| """ | |
| return self._vis_backends.get(name) # type: ignore | |
| def _is_posion_valid(self, position: np.ndarray) -> bool: | |
| """Judge whether the position is in image. | |
| Args: | |
| position (np.ndarray): The position to judge which last dim must | |
| be two and the format is [x, y]. | |
| Returns: | |
| bool: Whether the position is in image. | |
| """ | |
| flag = (position[..., 0] < self.width).all() and \ | |
| (position[..., 0] >= 0).all() and \ | |
| (position[..., 1] < self.height).all() and \ | |
| (position[..., 1] >= 0).all() | |
| return flag | |
| def draw_points(self, | |
| positions: Union[np.ndarray, torch.Tensor], | |
| colors: Union[str, tuple, List[str], List[tuple]] = 'g', | |
| marker: Optional[str] = None, | |
| sizes: Optional[Union[np.ndarray, torch.Tensor]] = None): | |
| """Draw single or multiple points. | |
| Args: | |
| positions (Union[np.ndarray, torch.Tensor]): Positions to draw. | |
| colors (Union[str, tuple, List[str], List[tuple]]): The colors | |
| of points. ``colors`` can have the same length with points or | |
| just single value. If ``colors`` is single value, all the | |
| points will have the same colors. Reference to | |
| https://matplotlib.org/stable/gallery/color/named_colors.html | |
| for more details. Defaults to 'g. | |
| marker (str, optional): The marker style. | |
| See :mod:`matplotlib.markers` for more information about | |
| marker styles. Defaults to None. | |
| sizes (Optional[Union[np.ndarray, torch.Tensor]]): The marker size. | |
| Defaults to None. | |
| """ | |
| check_type('positions', positions, (np.ndarray, torch.Tensor)) | |
| positions = tensor2ndarray(positions) | |
| if len(positions.shape) == 1: | |
| positions = positions[None] | |
| assert positions.shape[-1] == 2, ( | |
| 'The shape of `positions` should be (N, 2), ' | |
| f'but got {positions.shape}') | |
| colors = color_val_matplotlib(colors) # type: ignore | |
| self.ax_save.scatter( | |
| positions[:, 0], positions[:, 1], c=colors, s=sizes, marker=marker) | |
| return self | |
| def draw_texts( | |
| self, | |
| texts: Union[str, List[str]], | |
| positions: Union[np.ndarray, torch.Tensor], | |
| font_sizes: Optional[Union[int, List[int]]] = None, | |
| colors: Union[str, tuple, List[str], List[tuple]] = 'g', | |
| vertical_alignments: Union[str, List[str]] = 'top', | |
| horizontal_alignments: Union[str, List[str]] = 'left', | |
| font_families: Union[str, List[str]] = 'sans-serif', | |
| bboxes: Optional[Union[dict, List[dict]]] = None, | |
| font_properties: Optional[Union['FontProperties', | |
| List['FontProperties']]] = None | |
| ) -> 'Visualizer': | |
| """Draw single or multiple text boxes. | |
| Args: | |
| texts (Union[str, List[str]]): Texts to draw. | |
| positions (Union[np.ndarray, torch.Tensor]): The position to draw | |
| the texts, which should have the same length with texts and | |
| each dim contain x and y. | |
| font_sizes (Union[int, List[int]], optional): The font size of | |
| texts. ``font_sizes`` can have the same length with texts or | |
| just single value. If ``font_sizes`` is single value, all the | |
| texts will have the same font size. Defaults to None. | |
| colors (Union[str, tuple, List[str], List[tuple]]): The colors | |
| of texts. ``colors`` can have the same length with texts or | |
| just single value. If ``colors`` is single value, all the | |
| texts will have the same colors. Reference to | |
| https://matplotlib.org/stable/gallery/color/named_colors.html | |
| for more details. Defaults to 'g. | |
| vertical_alignments (Union[str, List[str]]): The verticalalignment | |
| of texts. verticalalignment controls whether the y positional | |
| argument for the text indicates the bottom, center or top side | |
| of the text bounding box. | |
| ``vertical_alignments`` can have the same length with | |
| texts or just single value. If ``vertical_alignments`` is | |
| single value, all the texts will have the same | |
| verticalalignment. verticalalignment can be 'center' or | |
| 'top', 'bottom' or 'baseline'. Defaults to 'top'. | |
| horizontal_alignments (Union[str, List[str]]): The | |
| horizontalalignment of texts. Horizontalalignment controls | |
| whether the x positional argument for the text indicates the | |
| left, center or right side of the text bounding box. | |
| ``horizontal_alignments`` can have | |
| the same length with texts or just single value. | |
| If ``horizontal_alignments`` is single value, all the texts | |
| will have the same horizontalalignment. Horizontalalignment | |
| can be 'center','right' or 'left'. Defaults to 'left'. | |
| font_families (Union[str, List[str]]): The font family of | |
| texts. ``font_families`` can have the same length with texts or | |
| just single value. If ``font_families`` is single value, all | |
| the texts will have the same font family. | |
| font_familiy can be 'serif', 'sans-serif', 'cursive', 'fantasy' | |
| or 'monospace'. Defaults to 'sans-serif'. | |
| bboxes (Union[dict, List[dict]], optional): The bounding box of the | |
| texts. If bboxes is None, there are no bounding box around | |
| texts. ``bboxes`` can have the same length with texts or | |
| just single value. If ``bboxes`` is single value, all | |
| the texts will have the same bbox. Reference to | |
| https://matplotlib.org/stable/api/_as_gen/matplotlib.patches.FancyBboxPatch.html#matplotlib.patches.FancyBboxPatch | |
| for more details. Defaults to None. | |
| font_properties (Union[FontProperties, List[FontProperties]], optional): | |
| The font properties of texts. FontProperties is | |
| a ``font_manager.FontProperties()`` object. | |
| If you want to draw Chinese texts, you need to prepare | |
| a font file that can show Chinese characters properly. | |
| For example: `simhei.ttf`, `simsun.ttc`, `simkai.ttf` and so on. | |
| Then set ``font_properties=matplotlib.font_manager.FontProperties(fname='path/to/font_file')`` | |
| ``font_properties`` can have the same length with texts or | |
| just single value. If ``font_properties`` is single value, | |
| all the texts will have the same font properties. | |
| Defaults to None. | |
| `New in version 0.6.0.` | |
| """ # noqa: E501 | |
| from matplotlib.font_manager import FontProperties | |
| check_type('texts', texts, (str, list)) | |
| if isinstance(texts, str): | |
| texts = [texts] | |
| num_text = len(texts) | |
| check_type('positions', positions, (np.ndarray, torch.Tensor)) | |
| positions = tensor2ndarray(positions) | |
| if len(positions.shape) == 1: | |
| positions = positions[None] | |
| assert positions.shape == (num_text, 2), ( | |
| '`positions` should have the shape of ' | |
| f'({num_text}, 2), but got {positions.shape}') | |
| if not self._is_posion_valid(positions): | |
| warnings.warn( | |
| 'Warning: The text is out of bounds,' | |
| ' the drawn text may not be in the image', UserWarning) | |
| positions = positions.tolist() | |
| if font_sizes is None: | |
| font_sizes = self._default_font_size | |
| check_type_and_length('font_sizes', font_sizes, (int, float, list), | |
| num_text) | |
| font_sizes = value2list(font_sizes, (int, float), num_text) | |
| check_type_and_length('colors', colors, (str, tuple, list), num_text) | |
| colors = value2list(colors, (str, tuple), num_text) | |
| colors = color_val_matplotlib(colors) # type: ignore | |
| check_type_and_length('vertical_alignments', vertical_alignments, | |
| (str, list), num_text) | |
| vertical_alignments = value2list(vertical_alignments, str, num_text) | |
| check_type_and_length('horizontal_alignments', horizontal_alignments, | |
| (str, list), num_text) | |
| horizontal_alignments = value2list(horizontal_alignments, str, | |
| num_text) | |
| check_type_and_length('font_families', font_families, (str, list), | |
| num_text) | |
| font_families = value2list(font_families, str, num_text) | |
| if font_properties is None: | |
| font_properties = [None for _ in range(num_text)] # type: ignore | |
| else: | |
| check_type_and_length('font_properties', font_properties, | |
| (FontProperties, list), num_text) | |
| font_properties = value2list(font_properties, FontProperties, | |
| num_text) | |
| if bboxes is None: | |
| bboxes = [None for _ in range(num_text)] # type: ignore | |
| else: | |
| check_type_and_length('bboxes', bboxes, (dict, list), num_text) | |
| bboxes = value2list(bboxes, dict, num_text) | |
| for i in range(num_text): | |
| self.ax_save.text( | |
| positions[i][0], | |
| positions[i][1], | |
| texts[i], | |
| size=font_sizes[i], # type: ignore | |
| bbox=bboxes[i], # type: ignore | |
| verticalalignment=vertical_alignments[i], | |
| horizontalalignment=horizontal_alignments[i], | |
| family=font_families[i], | |
| fontproperties=font_properties[i], | |
| color=colors[i]) | |
| return self | |
| def draw_lines( | |
| self, | |
| x_datas: Union[np.ndarray, torch.Tensor], | |
| y_datas: Union[np.ndarray, torch.Tensor], | |
| colors: Union[str, tuple, List[str], List[tuple]] = 'g', | |
| line_styles: Union[str, List[str]] = '-', | |
| line_widths: Union[Union[int, float], List[Union[int, float]]] = 2 | |
| ) -> 'Visualizer': | |
| """Draw single or multiple line segments. | |
| Args: | |
| x_datas (Union[np.ndarray, torch.Tensor]): The x coordinate of | |
| each line' start and end points. | |
| y_datas (Union[np.ndarray, torch.Tensor]): The y coordinate of | |
| each line' start and end points. | |
| colors (Union[str, tuple, List[str], List[tuple]]): The colors of | |
| lines. ``colors`` can have the same length with lines or just | |
| single value. If ``colors`` is single value, all the lines | |
| will have the same colors. Reference to | |
| https://matplotlib.org/stable/gallery/color/named_colors.html | |
| for more details. Defaults to 'g'. | |
| line_styles (Union[str, List[str]]): The linestyle | |
| of lines. ``line_styles`` can have the same length with | |
| texts or just single value. If ``line_styles`` is single | |
| value, all the lines will have the same linestyle. | |
| Reference to | |
| https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle | |
| for more details. Defaults to '-'. | |
| line_widths (Union[Union[int, float], List[Union[int, float]]]): | |
| The linewidth of lines. ``line_widths`` can have | |
| the same length with lines or just single value. | |
| If ``line_widths`` is single value, all the lines will | |
| have the same linewidth. Defaults to 2. | |
| """ | |
| from matplotlib.collections import LineCollection | |
| check_type('x_datas', x_datas, (np.ndarray, torch.Tensor)) | |
| x_datas = tensor2ndarray(x_datas) | |
| check_type('y_datas', y_datas, (np.ndarray, torch.Tensor)) | |
| y_datas = tensor2ndarray(y_datas) | |
| assert x_datas.shape == y_datas.shape, ( | |
| '`x_datas` and `y_datas` should have the same shape') | |
| assert x_datas.shape[-1] == 2, ( | |
| f'The shape of `x_datas` should be (N, 2), but got {x_datas.shape}' | |
| ) | |
| if len(x_datas.shape) == 1: | |
| x_datas = x_datas[None] | |
| y_datas = y_datas[None] | |
| colors = color_val_matplotlib(colors) # type: ignore | |
| lines = np.concatenate( | |
| (x_datas.reshape(-1, 2, 1), y_datas.reshape(-1, 2, 1)), axis=-1) | |
| if not self._is_posion_valid(lines): | |
| warnings.warn( | |
| 'Warning: The line is out of bounds,' | |
| ' the drawn line may not be in the image', UserWarning) | |
| line_collect = LineCollection( | |
| lines.tolist(), | |
| colors=colors, | |
| linestyles=line_styles, | |
| linewidths=line_widths) | |
| self.ax_save.add_collection(line_collect) | |
| return self | |
| def draw_circles( | |
| self, | |
| center: Union[np.ndarray, torch.Tensor], | |
| radius: Union[np.ndarray, torch.Tensor], | |
| edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', | |
| line_styles: Union[str, List[str]] = '-', | |
| line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, | |
| face_colors: Union[str, tuple, List[str], List[tuple]] = 'none', | |
| alpha: Union[float, int] = 0.8, | |
| ) -> 'Visualizer': | |
| """Draw single or multiple circles. | |
| Args: | |
| center (Union[np.ndarray, torch.Tensor]): The x coordinate of | |
| each line' start and end points. | |
| radius (Union[np.ndarray, torch.Tensor]): The y coordinate of | |
| each line' start and end points. | |
| edge_colors (Union[str, tuple, List[str], List[tuple]]): The | |
| colors of circles. ``colors`` can have the same length with | |
| lines or just single value. If ``colors`` is single value, | |
| all the lines will have the same colors. Reference to | |
| https://matplotlib.org/stable/gallery/color/named_colors.html | |
| for more details. Defaults to 'g. | |
| line_styles (Union[str, List[str]]): The linestyle | |
| of lines. ``line_styles`` can have the same length with | |
| texts or just single value. If ``line_styles`` is single | |
| value, all the lines will have the same linestyle. | |
| Reference to | |
| https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle | |
| for more details. Defaults to '-'. | |
| line_widths (Union[Union[int, float], List[Union[int, float]]]): | |
| The linewidth of lines. ``line_widths`` can have | |
| the same length with lines or just single value. | |
| If ``line_widths`` is single value, all the lines will | |
| have the same linewidth. Defaults to 2. | |
| face_colors (Union[str, tuple, List[str], List[tuple]]): | |
| The face colors. Defaults to None. | |
| alpha (Union[int, float]): The transparency of circles. | |
| Defaults to 0.8. | |
| """ | |
| from matplotlib.collections import PatchCollection | |
| from matplotlib.patches import Circle | |
| check_type('center', center, (np.ndarray, torch.Tensor)) | |
| center = tensor2ndarray(center) | |
| check_type('radius', radius, (np.ndarray, torch.Tensor)) | |
| radius = tensor2ndarray(radius) | |
| if len(center.shape) == 1: | |
| center = center[None] | |
| assert center.shape == (radius.shape[0], 2), ( | |
| 'The shape of `center` should be (radius.shape, 2), ' | |
| f'but got {center.shape}') | |
| if not (self._is_posion_valid(center - | |
| np.tile(radius.reshape((-1, 1)), (1, 2))) | |
| and self._is_posion_valid( | |
| center + np.tile(radius.reshape((-1, 1)), (1, 2)))): | |
| warnings.warn( | |
| 'Warning: The circle is out of bounds,' | |
| ' the drawn circle may not be in the image', UserWarning) | |
| center = center.tolist() | |
| radius = radius.tolist() | |
| edge_colors = color_val_matplotlib(edge_colors) # type: ignore | |
| face_colors = color_val_matplotlib(face_colors) # type: ignore | |
| circles = [] | |
| for i in range(len(center)): | |
| circles.append(Circle(tuple(center[i]), radius[i])) | |
| if isinstance(line_widths, (int, float)): | |
| line_widths = [line_widths] * len(circles) | |
| line_widths = [ | |
| min(max(linewidth, 1), self._default_font_size / 4) | |
| for linewidth in line_widths | |
| ] | |
| p = PatchCollection( | |
| circles, | |
| alpha=alpha, | |
| facecolors=face_colors, | |
| edgecolors=edge_colors, | |
| linewidths=line_widths, | |
| linestyles=line_styles) | |
| self.ax_save.add_collection(p) | |
| return self | |
| def draw_bboxes( | |
| self, | |
| bboxes: Union[np.ndarray, torch.Tensor], | |
| edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', | |
| line_styles: Union[str, List[str]] = '-', | |
| line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, | |
| face_colors: Union[str, tuple, List[str], List[tuple]] = 'none', | |
| alpha: Union[int, float] = 0.8, | |
| ) -> 'Visualizer': | |
| """Draw single or multiple bboxes. | |
| Args: | |
| bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw with | |
| the format of(x1,y1,x2,y2). | |
| edge_colors (Union[str, tuple, List[str], List[tuple]]): The | |
| colors of bboxes. ``colors`` can have the same length with | |
| lines or just single value. If ``colors`` is single value, all | |
| the lines will have the same colors. Refer to `matplotlib. | |
| colors` for full list of formats that are accepted. | |
| Defaults to 'g'. | |
| line_styles (Union[str, List[str]]): The linestyle | |
| of lines. ``line_styles`` can have the same length with | |
| texts or just single value. If ``line_styles`` is single | |
| value, all the lines will have the same linestyle. | |
| Reference to | |
| https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle | |
| for more details. Defaults to '-'. | |
| line_widths (Union[Union[int, float], List[Union[int, float]]]): | |
| The linewidth of lines. ``line_widths`` can have | |
| the same length with lines or just single value. | |
| If ``line_widths`` is single value, all the lines will | |
| have the same linewidth. Defaults to 2. | |
| face_colors (Union[str, tuple, List[str], List[tuple]]): | |
| The face colors. Defaults to None. | |
| alpha (Union[int, float]): The transparency of bboxes. | |
| Defaults to 0.8. | |
| """ | |
| check_type('bboxes', bboxes, (np.ndarray, torch.Tensor)) | |
| bboxes = tensor2ndarray(bboxes) | |
| if len(bboxes.shape) == 1: | |
| bboxes = bboxes[None] | |
| assert bboxes.shape[-1] == 4, ( | |
| f'The shape of `bboxes` should be (N, 4), but got {bboxes.shape}') | |
| assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] <= | |
| bboxes[:, 3]).all() | |
| if not self._is_posion_valid(bboxes.reshape((-1, 2, 2))): | |
| warnings.warn( | |
| 'Warning: The bbox is out of bounds,' | |
| ' the drawn bbox may not be in the image', UserWarning) | |
| poly = np.stack( | |
| (bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 1], | |
| bboxes[:, 2], bboxes[:, 3], bboxes[:, 0], bboxes[:, 3]), | |
| axis=-1).reshape(-1, 4, 2) | |
| poly = [p for p in poly] | |
| return self.draw_polygons( | |
| poly, | |
| alpha=alpha, | |
| edge_colors=edge_colors, | |
| line_styles=line_styles, | |
| line_widths=line_widths, | |
| face_colors=face_colors) | |
def draw_polygons(
    self,
    polygons: Union[Union[np.ndarray, torch.Tensor],
                    List[Union[np.ndarray, torch.Tensor]]],
    edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g',
    line_styles: Union[str, List[str]] = '-',
    line_widths: Union[Union[int, float], List[Union[int, float]]] = 2,
    face_colors: Union[str, tuple, List[str], List[tuple]] = 'none',
    alpha: Union[int, float] = 0.8,
) -> 'Visualizer':
    """Draw single or multiple polygons.

    Args:
        polygons (np.ndarray | torch.Tensor | list of them): The polygons
            to draw. Each polygon is an array of shape (M, 2) holding its
            M vertices as (x, y) pairs. A single array/tensor is treated
            as one polygon.
        edge_colors (Union[str, tuple, List[str], List[tuple]]): The edge
            colors of the polygons. Either one value shared by all
            polygons or one value per polygon. Refer to
            `matplotlib.colors` for the accepted formats.
            Defaults to 'g'.
        line_styles (Union[str, List[str]]): The line style of the
            polygon edges, either one value or one per polygon. Refer to
            https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
            for more details. Defaults to '-'.
        line_widths (Union[Union[int, float], List[Union[int, float]]]):
            The line width of the polygon edges, either one value or one
            per polygon. Each width is clamped into the range
            [1, self._default_font_size / 4]. Defaults to 2.
        face_colors (Union[str, tuple, List[str], List[tuple]]):
            The fill colors. Defaults to 'none' (no fill).
        alpha (Union[int, float]): The transparency of the polygons.
            Defaults to 0.8.

    Returns:
        Visualizer: ``self``, so drawing calls can be chained.
    """
    from matplotlib.collections import PolyCollection
    check_type('polygons', polygons, (list, np.ndarray, torch.Tensor))
    edge_colors = color_val_matplotlib(edge_colors)  # type: ignore
    face_colors = color_val_matplotlib(face_colors)  # type: ignore

    # A single array/tensor is one polygon; normalize to a list.
    if isinstance(polygons, (np.ndarray, torch.Tensor)):
        polygons = [polygons]
    for polygon in polygons:
        # Check ``ndim`` first so that a 1-D input fails with this
        # descriptive assertion instead of an IndexError on ``shape[1]``.
        assert polygon.ndim == 2 and polygon.shape[1] == 2, (
            'The shape of each polygon in `polygons` should be (M, 2),'
            f' but got {polygon.shape}')
    polygons = [tensor2ndarray(polygon) for polygon in polygons]

    for polygon in polygons:
        if not self._is_posion_valid(polygon):
            warnings.warn(
                'Warning: The polygon is out of bounds,'
                ' the drawn polygon may not be in the image', UserWarning)

    if isinstance(line_widths, (int, float)):
        line_widths = [line_widths] * len(polygons)
    # Clamp each width into [1, default_font_size / 4].
    line_widths = [
        min(max(linewidth, 1), self._default_font_size / 4)
        for linewidth in line_widths
    ]
    polygon_collection = PolyCollection(
        polygons,
        alpha=alpha,
        facecolor=face_colors,
        linestyles=line_styles,
        edgecolors=edge_colors,
        linewidths=line_widths)

    self.ax_save.add_collection(polygon_collection)
    return self
def draw_binary_masks(
        self,
        binary_masks: Union[np.ndarray, torch.Tensor],
        colors: Union[str, tuple, List[str], List[tuple]] = 'g',
        alphas: Union[float, List[float]] = 0.8) -> 'Visualizer':
    """Draw single or multiple binary masks.

    Each mask is filled with its color and blended into the current
    image: ``img = (1 - alpha) * img + alpha * colored``, where
    ``colored`` equals the mask color inside the mask and the current
    image outside it. Pixels outside the mask are therefore unchanged.

    Args:
        binary_masks (np.ndarray, torch.Tensor): Boolean masks of shape
            (N, H, W), or a single mask of shape (H, W). H and W must
            match the current image's height and width. The dtype must
            be bool.
        colors (Union[str, tuple, List[str], List[tuple]]): The fill
            colors, either one value shared by all masks or one value
            per mask. A string is converted with ``color_str2rgb``. A
            tuple must have 3 channels, each in [0, 255] (RGB).
            Defaults to 'g'.
        alphas (Union[float, List[float]]): The transparency of the
            masks, either one scalar (int, float or numpy scalar) shared
            by all masks, or a sequence with at least one value per mask.
            Defaults to 0.8.

    Returns:
        Visualizer: ``self``, so drawing calls can be chained.
    """
    check_type('binary_masks', binary_masks, (np.ndarray, torch.Tensor))
    binary_masks = tensor2ndarray(binary_masks)
    assert binary_masks.dtype == np.bool_, (
        'The dtype of binary_masks should be np.bool_, '
        f'but got {binary_masks.dtype}')
    # OpenCV mask ops expect uint8 masks with 0 / 255 values.
    binary_masks = binary_masks.astype('uint8') * 255
    img = self.get_image()
    if binary_masks.ndim == 2:
        binary_masks = binary_masks[None]
    assert img.shape[:2] == binary_masks.shape[1:], (
        '`binary_masks` must have the same shape with image')
    binary_mask_len = binary_masks.shape[0]

    check_type_and_length('colors', colors, (str, tuple, list),
                          binary_mask_len)
    colors = value2list(colors, (str, tuple), binary_mask_len)
    colors = [
        color_str2rgb(color) if isinstance(color, str) else color
        for color in colors
    ]
    for color in colors:
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255  # type: ignore

    # Broadcast any scalar alpha (int, float or numpy scalar). Previously
    # only ``float`` was broadcast, so e.g. ``alphas=1`` reached ``zip``
    # and raised TypeError.
    if isinstance(alphas, (int, float, np.number)):
        alphas = [alphas] * binary_mask_len
    # ``zip`` below would otherwise silently skip masks lacking an alpha.
    assert len(alphas) >= binary_mask_len, (
        f'Expected {binary_mask_len} alphas (one per mask), '
        f'but got {len(alphas)}')

    for binary_mask, color, alpha in zip(binary_masks, colors, alphas):
        binary_mask_complement = cv2.bitwise_not(binary_mask)
        rgb = np.zeros_like(img)
        rgb[...] = color
        # Color inside the mask, original image outside it.
        rgb = cv2.bitwise_and(rgb, rgb, mask=binary_mask)
        img_complement = cv2.bitwise_and(
            img, img, mask=binary_mask_complement)
        rgb = rgb + img_complement
        img = cv2.addWeighted(img, 1 - alpha, rgb, alpha, 0)
    self.ax_save.imshow(
        img,
        extent=(0, self.width, self.height, 0),
        interpolation='nearest')
    return self
@staticmethod
def draw_featmap(featmap: torch.Tensor,
                 overlaid_image: Optional[np.ndarray] = None,
                 channel_reduction: Optional[str] = 'squeeze_mean',
                 topk: int = 20,
                 arrangement: Tuple[int, int] = (4, 5),
                 resize_shape: Optional[tuple] = None,
                 alpha: float = 0.5) -> np.ndarray:
    """Draw a feature map, optionally blended with an image.

    This is a static method: it takes no ``self`` and uses no visualizer
    state.

    - If ``overlaid_image`` is not None, the output is the weighted sum
      of the image and the colorized feature map.
    - If ``resize_shape`` is specified, both ``featmap`` and
      ``overlaid_image`` are resized to it.
    - If ``resize_shape`` is None and ``overlaid_image`` is not None, the
      feature map is interpolated to the image's spatial size when the
      two spatial sizes differ.
    - If ``channel_reduction`` is "squeeze_mean" or "select_max", the
      feature map is reduced to a single channel before being blended
      with ``overlaid_image``.
    - If ``channel_reduction`` is None:

      - If ``topk <= 0``, ``featmap`` must have 1 or 3 channels and is
        blended with ``overlaid_image`` as is.
      - If ``topk > 0``, the ``topk`` channels with the largest sum are
        drawn in a grid whose layout is given by ``arrangement``.

    Args:
        featmap (torch.Tensor): The feature map to draw, of shape
            (C, H, W).
        overlaid_image (np.ndarray, optional): The image to overlay the
            feature map on. A 2-D (grayscale) image is first converted to
            3 channels. Defaults to None.
        channel_reduction (str, optional): How to reduce multiple
            channels to one: 'squeeze_mean' averages over channels,
            'select_max' keeps the channel with the largest sum.
            Defaults to 'squeeze_mean'.
        topk (int): Only used when ``channel_reduction`` is None. If > 0,
            the ``topk`` channels with the largest sum are shown (capped
            at C). If <= 0, ``featmap`` must have 1 or 3 channels.
            Defaults to 20.
        arrangement (Tuple[int, int]): (rows, cols) of the grid used when
            ``channel_reduction`` is None and ``topk > 0``. ``rows * cols``
            must be at least the number of channels actually shown.
            Defaults to (4, 5).
        resize_shape (tuple, optional): (H, W) to resize the feature map
            (and ``overlaid_image``) to. Defaults to None.
        alpha (float): The transparency of the feature map.
            Defaults to 0.5.

    Returns:
        np.ndarray: RGB image.
    """
    import matplotlib.pyplot as plt
    assert isinstance(featmap,
                      torch.Tensor), (f'`featmap` should be torch.Tensor,'
                                      f' but got {type(featmap)}')
    assert featmap.ndim == 3, f'Input dimension must be 3, ' \
                              f'but got {featmap.ndim}'
    featmap = featmap.detach().cpu()

    if overlaid_image is not None:
        if overlaid_image.ndim == 2:
            overlaid_image = cv2.cvtColor(overlaid_image,
                                          cv2.COLOR_GRAY2RGB)

        if overlaid_image.shape[:2] != featmap.shape[1:]:
            warnings.warn(
                f'Since the spatial dimensions of '
                f'overlaid_image: {overlaid_image.shape[:2]} and '
                f'featmap: {featmap.shape[1:]} are not same, '
                f'the feature map will be interpolated. '
                f'This may cause mismatch problems !')
            if resize_shape is None:
                featmap = F.interpolate(
                    featmap[None],
                    overlaid_image.shape[:2],
                    mode='bilinear',
                    align_corners=False)[0]

    if resize_shape is not None:
        featmap = F.interpolate(
            featmap[None],
            resize_shape,
            mode='bilinear',
            align_corners=False)[0]
        if overlaid_image is not None:
            # cv2.resize takes (width, height), hence the reversal.
            overlaid_image = cv2.resize(overlaid_image, resize_shape[::-1])

    if channel_reduction is not None:
        assert channel_reduction in [
            'squeeze_mean', 'select_max'], \
            f'Mode only support "squeeze_mean", "select_max", ' \
            f'but got {channel_reduction}'
        if channel_reduction == 'select_max':
            sum_channel_featmap = torch.sum(featmap, dim=(1, 2))
            _, indices = torch.topk(sum_channel_featmap, 1)
            feat_map = featmap[indices]
        else:
            feat_map = torch.mean(featmap, dim=0)
        return convert_overlay_heatmap(feat_map, overlaid_image, alpha)
    elif topk <= 0:
        featmap_channel = featmap.shape[0]
        assert featmap_channel in [
            1, 3
        ], ('The input tensor channel dimension must be 1 or 3 '
            'when topk is less than 1, but the channel '
            f'dimension you input is {featmap_channel}, you can use the'
            ' channel_reduction parameter or set topk greater than '
            '0 to solve the error')
        return convert_overlay_heatmap(featmap, overlaid_image, alpha)
    else:
        row, col = arrangement
        channel, height, width = featmap.shape
        # Clamp topk to the available channels *before* validating the
        # layout; otherwise a feature map with fewer channels than topk
        # is wrongly rejected when rows * cols < topk.
        topk = min(channel, topk)
        assert row * col >= topk, 'The product of row and col in ' \
                                  'the `arrangement` is less than ' \
                                  'topk, please set the ' \
                                  '`arrangement` correctly'

        # Extract the feature map of topk
        sum_channel_featmap = torch.sum(featmap, dim=(1, 2))
        _, indices = torch.topk(sum_channel_featmap, topk)
        topk_featmap = featmap[indices]

        fig = plt.figure(frameon=False)
        # Set the window layout
        fig.subplots_adjust(
            left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
        dpi = fig.get_dpi()
        fig.set_size_inches((width * col + 1e-2) / dpi,
                            (height * row + 1e-2) / dpi)
        for i in range(topk):
            axes = fig.add_subplot(row, col, i + 1)
            axes.axis('off')
            axes.text(2, 15, f'channel: {indices[i]}', fontsize=10)
            axes.imshow(
                convert_overlay_heatmap(topk_featmap[i], overlaid_image,
                                        alpha))
        image = img_from_canvas(fig.canvas)
        plt.close(fig)
        return image
def add_config(self, config: Config, **kwargs):
    """Forward a config to every registered visualization backend.

    Args:
        config (Config): The Config object to record.
        **kwargs: Extra keyword arguments passed unchanged to each
            backend's ``add_config``.
    """
    backends = self._vis_backends.values()
    for backend in backends:
        backend.add_config(config, **kwargs)
def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict],
              **kwargs) -> None:
    """Forward a model graph to every registered visualization backend.

    Args:
        model (torch.nn.Module): The model whose graph is recorded.
        data_batch (Sequence[dict]): A batch of data from the dataloader,
            passed unchanged to each backend.
        **kwargs: Extra keyword arguments passed unchanged to each
            backend's ``add_graph``.
    """
    backends = self._vis_backends.values()
    for backend in backends:
        backend.add_graph(model, data_batch, **kwargs)
def add_image(self, name: str, image: np.ndarray, step: int = 0) -> None:
    """Forward an image to every registered visualization backend.

    Args:
        name (str): The identifier under which the image is recorded.
        image (np.ndarray): The image to record, in RGB order.
        step (int): Global step value to record. Defaults to 0.
    """
    backends = self._vis_backends.values()
    for backend in backends:
        backend.add_image(name, image, step)  # type: ignore
def add_scalar(self,
               name: str,
               value: Union[int, float],
               step: int = 0,
               **kwargs) -> None:
    """Forward a single scalar to every registered visualization backend.

    Args:
        name (str): The identifier of the scalar.
        value (float, int): The value to record.
        step (int): Global step value to record. Defaults to 0.
        **kwargs: Extra keyword arguments passed unchanged to each
            backend's ``add_scalar``.
    """
    backends = self._vis_backends.values()
    for backend in backends:
        backend.add_scalar(name, value, step, **kwargs)  # type: ignore
def add_scalars(self,
                scalar_dict: dict,
                step: int = 0,
                file_path: Optional[str] = None,
                **kwargs) -> None:
    """Forward several scalars to every registered visualization backend.

    Args:
        scalar_dict (dict): Mapping from tag to the value to record.
        step (int): Global step value to record. Defaults to 0.
        file_path (str, optional): If given, it is passed to each backend
            so the scalars can also be saved to this file.
            Defaults to None.
        **kwargs: Extra keyword arguments passed unchanged to each
            backend's ``add_scalars``.
    """
    backends = self._vis_backends.values()
    for backend in backends:
        backend.add_scalars(scalar_dict, step, file_path, **kwargs)
def add_datasample(self,
                   name,
                   image: np.ndarray,
                   data_sample: Optional['BaseDataElement'] = None,
                   draw_gt: bool = True,
                   draw_pred: bool = True,
                   show: bool = False,
                   wait_time: int = 0,
                   step: int = 0) -> None:
    """Draw a data sample.

    The implementation here does nothing and returns None; presumably
    task-specific subclasses override it (confirm against subclasses).
    """
    return None
def close(self) -> None:
    """Close every registered visualization backend, in registration order."""
    backends = self._vis_backends.values()
    for backend in backends:
        backend.close()
@classmethod
def get_instance(cls, name: str, **kwargs) -> 'Visualizer':
    """Get the instance called ``name`` and also register it on
    :class:`Visualizer` so that it becomes the latest visualizer.

    Downstream codebases may need the latest created instance without
    knowing the specific Visualizer type. For example, mmdetection builds
    its visualizer in the runner, and some components that cannot access
    the runner still want the latest created visualizer. Such components
    do not know which visualizer type was built, so they cannot fetch the
    target instance directly. Therefore :class:`Visualizer` overrides
    :meth:`get_instance`. Instances created through any subclass are
    additionally registered in ``Visualizer._instance_dict``, and
    :meth:`get_current_instance` then returns the latest created subclass
    instance.

    This must be a classmethod. Without the decorator, ``name`` would be
    bound to ``cls`` and the call would fail.

    Examples:
        >>> class DetLocalVisualizer(Visualizer):
        ...     def __init__(self, name):
        ...         super().__init__(name)
        >>>
        >>> visualizer1 = DetLocalVisualizer.get_instance('name1')
        >>> visualizer2 = Visualizer.get_current_instance()
        >>> visualizer3 = DetLocalVisualizer.get_current_instance()
        >>> assert id(visualizer1) == id(visualizer2) == id(visualizer3)

    Args:
        name (str): Name of instance.
        **kwargs: Passed unchanged to the parent class's ``get_instance``.

    Returns:
        object: Corresponding name instance.
    """
    # Delegate to the parent implementation for ``cls``, then also record
    # the instance under the base ``Visualizer`` registry.
    instance = super().get_instance(name, **kwargs)
    Visualizer._instance_dict[name] = instance
    return instance