Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from typing import Sequence, Union | |
| import mmengine | |
| import numpy as np | |
| import torch | |
| from .base import BaseTransform | |
| from .builder import TRANSFORMS | |
| def to_tensor( | |
| data: Union[torch.Tensor, np.ndarray, Sequence, int, | |
| float]) -> torch.Tensor: | |
| """Convert objects of various python types to :obj:`torch.Tensor`. | |
| Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, | |
| :class:`Sequence`, :class:`int` and :class:`float`. | |
| Args: | |
| data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to | |
| be converted. | |
| Returns: | |
| torch.Tensor: the converted data. | |
| """ | |
| if isinstance(data, torch.Tensor): | |
| return data | |
| elif isinstance(data, np.ndarray): | |
| return torch.from_numpy(data) | |
| elif isinstance(data, Sequence) and not mmengine.is_str(data): | |
| return torch.tensor(data) | |
| elif isinstance(data, int): | |
| return torch.LongTensor([data]) | |
| elif isinstance(data, float): | |
| return torch.FloatTensor([data]) | |
| else: | |
| raise TypeError(f'type {type(data)} cannot be converted to tensor.') | |
| class ToTensor(BaseTransform): | |
| """Convert some results to :obj:`torch.Tensor` by given keys. | |
| Required keys: | |
| - all these keys in `keys` | |
| Modified Keys: | |
| - all these keys in `keys` | |
| Args: | |
| keys (Sequence[str]): Keys that need to be converted to Tensor. | |
| """ | |
| def __init__(self, keys: Sequence[str]) -> None: | |
| self.keys = keys | |
| def transform(self, results: dict) -> dict: | |
| """Transform function to convert data to `torch.Tensor`. | |
| Args: | |
| results (dict): Result dict from loading pipeline. | |
| Returns: | |
| dict: `keys` in results will be updated. | |
| """ | |
| for key in self.keys: | |
| key_list = key.split('.') | |
| cur_item = results | |
| for i in range(len(key_list)): | |
| if key_list[i] not in cur_item: | |
| raise KeyError(f'Can not find key {key}') | |
| if i == len(key_list) - 1: | |
| cur_item[key_list[i]] = to_tensor(cur_item[key_list[i]]) | |
| break | |
| cur_item = cur_item[key_list[i]] | |
| return results | |
| def __repr__(self) -> str: | |
| return self.__class__.__name__ + f'(keys={self.keys})' | |
| class ImageToTensor(BaseTransform): | |
| """Convert image to :obj:`torch.Tensor` by given keys. | |
| The dimension order of input image is (H, W, C). The pipeline will convert | |
| it to (C, H, W). If only 2 dimension (H, W) is given, the output would be | |
| (1, H, W). | |
| Required keys: | |
| - all these keys in `keys` | |
| Modified Keys: | |
| - all these keys in `keys` | |
| Args: | |
| keys (Sequence[str]): Key of images to be converted to Tensor. | |
| """ | |
| def __init__(self, keys: dict) -> None: | |
| self.keys = keys | |
| def transform(self, results: dict) -> dict: | |
| """Transform function to convert image in results to | |
| :obj:`torch.Tensor` and transpose the channel order. | |
| Args: | |
| results (dict): Result dict contains the image data to convert. | |
| Returns: | |
| dict: The result dict contains the image converted | |
| to :obj:``torch.Tensor`` and transposed to (C, H, W) order. | |
| """ | |
| for key in self.keys: | |
| img = results[key] | |
| if len(img.shape) < 3: | |
| img = np.expand_dims(img, -1) | |
| results[key] = (to_tensor(img.transpose(2, 0, 1))).contiguous() | |
| return results | |
| def __repr__(self) -> str: | |
| return self.__class__.__name__ + f'(keys={self.keys})' | |