Spaces:
Runtime error
Runtime error
| # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| """Custom replacement for `torch.nn.functional.grid_sample` that | |
| supports arbitrarily high order gradients between the input and output. | |
| Only works on 2D images and assumes | |
| `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" | |
| import torch | |
| from pkg_resources import parse_version | |
| # pylint: disable=redefined-builtin | |
| # pylint: disable=arguments-differ | |
| # pylint: disable=protected-access | |
| #---------------------------------------------------------------------------- | |
# Master switch for this module: set to True to make `grid_sample()` go
# through the custom autograd op instead of the stock PyTorch function.
enabled = False

# True when the installed PyTorch is 1.11 or newer. The '1.11.0a' bound lets
# prerelease builds of 1.11 qualify. `_GridSample2dBackward.forward` reads this
# flag to decide whether to pass an extra `output_mask` argument to the
# backward operator.
_use_pytorch_1_11_api = (
    parse_version(torch.__version__) >= parse_version('1.11.0a')
)
| #---------------------------------------------------------------------------- | |
def grid_sample(input, grid):
    """Sample `input` at the locations given by `grid`.

    Always uses `mode='bilinear'`, `padding_mode='zeros'` and
    `align_corners=False`.

    Args:
        input: Tensor to sample from.
        grid: Tensor of sampling locations.

    Returns:
        The sampled tensor. When `_should_use_custom_op()` is true it is
        produced by `_GridSample2dForward`, otherwise by
        `torch.nn.functional.grid_sample`.
    """
    if not _should_use_custom_op():
        # Stock path: defer directly to PyTorch's own implementation.
        return torch.nn.functional.grid_sample(
            input=input,
            grid=grid,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False,
        )
    return _GridSample2dForward.apply(input, grid)
| #---------------------------------------------------------------------------- | |
def _should_use_custom_op():
    """Return the current value of the module-level `enabled` switch."""
    use_custom_op = enabled
    return use_custom_op
| #---------------------------------------------------------------------------- | |
class _GridSample2dForward(torch.autograd.Function):
    """Autograd function for bilinear, zero-padded 2D grid sampling.

    The forward pass defers to `torch.nn.functional.grid_sample`. The backward
    pass is delegated to `_GridSample2dBackward`, which is itself an autograd
    function, so the gradient computation can be differentiated again.
    """

    @staticmethod
    def forward(ctx, input, grid):
        """Sample `input` at `grid` and stash both tensors for backward.

        Args:
            ctx: Autograd context used to save tensors for the backward pass.
            input: Rank-4 tensor to sample from.
            grid: Rank-4 tensor of sampling locations.

        Returns:
            The result of `torch.nn.functional.grid_sample` with
            `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.

        Raises:
            AssertionError: If `input` or `grid` is not rank 4.
        """
        assert input.ndim == 4
        assert grid.ndim == 4
        output = torch.nn.functional.grid_sample(
            input=input,
            grid=grid,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False,
        )
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients with respect to `input` and `grid`.

        The work is done by `_GridSample2dBackward.apply` rather than inline
        so that the gradient computation is recorded by autograd as well.
        """
        input, grid = ctx.saved_tensors
        grad_input, grad_grid = _GridSample2dBackward.apply(
            grad_output, input, grid)
        return grad_input, grad_grid
| #---------------------------------------------------------------------------- | |
class _GridSample2dBackward(torch.autograd.Function):
    """Autograd function that computes the gradients of 2D grid sampling.

    Its forward pass calls the `aten::grid_sampler_2d_backward` operator. Its
    own backward pass supplies the second-order gradient with respect to
    `grad_output` only.
    """

    @staticmethod
    def forward(ctx, grad_output, input, grid):
        """Compute the gradients of the sampled output w.r.t. `input` and `grid`.

        Args:
            ctx: Autograd context; `grid` is saved on it for the backward pass.
            grad_output: Gradient flowing in from the sampled output.
            input: The tensor that was sampled.
            grid: The sampling locations that were used.

        Returns:
            A `(grad_input, grad_grid)` tuple as produced by the operator.
        """
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        # Newer PyTorch releases return an `(op, overload_names)` tuple here
        # instead of the bare callable; unwrap it so the calls below work on
        # both old and new versions.
        if isinstance(op, tuple):
            op = op[0]
        # The positional `0, 0, False` arguments select the bilinear /
        # zero-padding / align_corners=False configuration this module assumes.
        if _use_pytorch_1_11_api:
            # From 1.11 on the operator also takes a mask saying which of the
            # two gradients are actually needed.
            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False,
                                       output_mask)
        else:
            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        """Compute the second-order gradient with respect to `grad_output`.

        Args:
            ctx: Autograd context holding the saved `grid`.
            grad2_grad_input: Gradient flowing in through `grad_input`.
            grad2_grad_grid: Gradient flowing in through `grad_grid`; ignored.

        Returns:
            A `(grad2_grad_output, grad2_input, grad2_grid)` tuple. The last
            two entries are always None.
        """
        _ = grad2_grad_grid  # unused
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None
        if ctx.needs_input_grad[0]:
            # Re-sample the incoming gradient at the same grid locations.
            grad2_grad_output = _GridSample2dForward.apply(
                grad2_grad_input, grid)
        # NOTE: a check of the form `assert not ctx.needs_input_grad[2]` is
        # deliberately not enforced here, so a requested second-order gradient
        # with respect to `grid` is reported as None instead of raising.
        return grad2_grad_output, grad2_input, grad2_grid