| from enum import Enum | |
| from typing import Union | |
| import torch | |
class Format(str, Enum):
    """Tensor layout identifiers understood by the helpers in this module.

    The class mixes in ``str``, so every member compares equal to its plain
    string value (e.g. ``Format.NHWC == 'NHWC'``), and ``Format('NHWC')``
    looks a member up by that value.
    """

    # Layouts with two spatial axes (H, W).
    NCHW = 'NCHW'
    NHWC = 'NHWC'
    # Layouts with a single spatial/sequence axis (L).
    NCL = 'NCL'
    NLC = 'NLC'


# Accepted type for layout arguments: a Format member or its string value.
FormatT = Union[str, Format]
def get_spatial_dim(fmt: FormatT):
    """Return the axis indices that hold spatial extents for a layout.

    Args:
        fmt: A Format member or its string value (e.g. 'NHWC').

    Returns:
        Tuple of axis indices: (1,) for NLC, (2,) for NCL, (1, 2) for NHWC
        and (2, 3) for NCHW.

    Raises:
        ValueError: if ``fmt`` is not a valid Format value.
    """
    layout = Format(fmt)
    spatial_axes_by_layout = {
        Format.NLC: (1,),
        Format.NCL: (2,),
        Format.NHWC: (1, 2),
    }
    # Every remaining member (only NCHW) uses axes 2 and 3.
    return spatial_axes_by_layout.get(layout, (2, 3))
def get_channel_dim(fmt: FormatT):
    """Return the axis index that holds channels for a layout.

    Args:
        fmt: A Format member or its string value (e.g. 'NLC').

    Returns:
        3 for NHWC, 2 for NLC, and 1 for the channels-first layouts
        (NCHW, NCL).

    Raises:
        ValueError: if ``fmt`` is not a valid Format value.
    """
    layout = Format(fmt)
    channels_last_axis = {
        Format.NHWC: 3,
        Format.NLC: 2,
    }
    # Channels-first layouts keep channels right after the batch axis.
    return channels_last_axis.get(layout, 1)
def nchw_to(x: torch.Tensor, fmt: Format):
    """Convert an NCHW tensor to the requested layout.

    Args:
        x: Input tensor, assumed to be 4-D and laid out as (N, C, H, W);
            the caller is responsible for that.
        fmt: Target layout, as a Format member or its string value.

    Returns:
        For NHWC, a permuted view shaped (N, H, W, C). For NLC, a tensor
        shaped (N, H*W, C). For NCL, a tensor shaped (N, C, H*W). For NCHW,
        the input tensor unchanged.

    Raises:
        ValueError: if ``fmt`` is not a valid Format value. Previously an
            unrecognized value (e.g. 'nhwc') silently returned ``x``
            unconverted.
    """
    # Normalize like get_spatial_dim/get_channel_dim so that typos fail loudly.
    fmt = Format(fmt)
    if fmt is Format.NHWC:
        x = x.permute(0, 2, 3, 1)
    elif fmt is Format.NLC:
        # Merge H and W into one sequence axis, then move channels last.
        x = x.flatten(2).transpose(1, 2)
    elif fmt is Format.NCL:
        # Merge H and W into one sequence axis; channels stay at axis 1.
        x = x.flatten(2)
    return x
def nhwc_to(x: torch.Tensor, fmt: Format):
    """Convert an NHWC tensor to the requested layout.

    Args:
        x: Input tensor, assumed to be 4-D and laid out as (N, H, W, C);
            the caller is responsible for that.
        fmt: Target layout, as a Format member or its string value.

    Returns:
        For NCHW, a permuted view shaped (N, C, H, W). For NLC, a tensor
        shaped (N, H*W, C). For NCL, a tensor shaped (N, C, H*W). For NHWC,
        the input tensor unchanged.

    Raises:
        ValueError: if ``fmt`` is not a valid Format value. Previously an
            unrecognized value (e.g. 'nchw') silently returned ``x``
            unconverted.
    """
    # Normalize like get_spatial_dim/get_channel_dim so that typos fail loudly.
    fmt = Format(fmt)
    if fmt is Format.NCHW:
        x = x.permute(0, 3, 1, 2)
    elif fmt is Format.NLC:
        # Merge H and W into one sequence axis; channels stay last.
        x = x.flatten(1, 2)
    elif fmt is Format.NCL:
        # Merge H and W, then move channels ahead of the sequence axis.
        x = x.flatten(1, 2).transpose(1, 2)
    return x