| import torch | |
def tensor_to_size(source, dest_size):
    """Pad or trim ``source`` along dim 0 so it has exactly ``dest_size`` rows.

    When ``source`` is shorter than the target, its last slice along dim 0 is
    repeated to fill the gap. When it is longer, the trailing slices are
    dropped.

    Args:
        source: Tensor with at least one dimension.
        dest_size: Target length along dim 0. This is either an ``int`` or a
            tensor, in which case its ``shape[0]`` is used as the target.

    Returns:
        A tensor with ``shape[0] == dest_size`` and the same trailing
        dimensions as ``source``. If the sizes already agree, ``source`` is
        returned as-is.

    Raises:
        ValueError: If ``dest_size`` is negative, or if ``source`` is empty
            along dim 0 while ``dest_size`` is positive (there is no slice to
            repeat as padding).
    """
    if isinstance(dest_size, torch.Tensor):
        dest_size = dest_size.shape[0]
    if dest_size < 0:
        # A negative slice bound would silently drop rows from the end instead.
        raise ValueError(f"dest_size must be non-negative, got {dest_size}")
    source_size = source.shape[0]
    if source_size < dest_size:
        if source_size == 0:
            # source[-1:] would be empty, so padding would add nothing and the
            # result would have the wrong length.
            raise ValueError("cannot pad an empty tensor: no slice to repeat")
        # expand() gives a broadcast view of the last slice, so the padding is
        # not materialized before cat() writes the final tensor.
        padding = source[-1:].expand(dest_size - source_size, *source.shape[1:])
        source = torch.cat((source, padding), dim=0)
    elif source_size > dest_size:
        source = source[:dest_size]
    return source
| def tensor_to_image(tensor): | |
| image = tensor.mul(255).clamp(0, 255).byte().cpu() | |
| image = image[..., [2, 1, 0]].numpy() | |
| return image | |
def image_to_tensor(image):
    """Turn a NumPy image array into a float32 tensor in [0, 1] with channels reversed.

    The array is wrapped with ``torch.from_numpy``, cast to float32, divided
    by 255 and clamped to [0, 1]. Its last axis is then indexed with
    ``[2, 1, 0]``, which swaps BGR <-> RGB ordering.

    Args:
        image: NumPy array whose last axis holds at least 3 channels. It is
            presumably uint8 BGR data (TODO confirm with callers); only
            channels 0-2 are kept.

    Returns:
        A float32 ``torch.Tensor`` with the same leading shape and 3 channels.
    """
    scaled = torch.from_numpy(image).float().div(255.)
    rgb = scaled.clamp(0, 1)[..., [2, 1, 0]]
    return rgb