Source code for utils

import torch
import numpy as np
from torch.nn.functional import pad

[docs]def compute_pad_size(width: int) -> int: """Computes the pad size required for rotation / inverse rotation so that pad + rotate + inverse rotate + unpad = original object Args: width (int): Width of input tensor (assumed to be square) Returns: int: How much to pad each side by """ return int(np.ceil((np.sqrt(2)*width - width)/2))
[docs]def compute_pad_size_padded(width: int) -> int: """Given a padded tensor, computes how much padding it has had Args: width (int): Width of input padded tensor (assumed square) Returns: int: How much padding was applied to the input tensor """ a = (np.sqrt(2) - 1)/2 if width%2==0: width_old = int(2*np.floor((width/2)/(1+2*a))) else: width_old = int(2*np.floor(((width-1)/2)/(1+2*a))) return int((width-width_old)/2)
[docs]def pad_object(object: torch.Tensor, mode='constant') -> torch.Tensor: """Pads an input tensor so that pad + rotate + inverse rotate + unpad = original object. This is useful for rotating objects without losing information at the edges. Args: object (torch.Tensor): Object to be padded mode (str, optional): Mode for extrapolation beyonf out of bounds. Defaults to 'constant'. Returns: torch.Tensor: Padded object """ pad_size = compute_pad_size(object.shape[-2]) return pad(object, [pad_size,pad_size,pad_size,pad_size], mode=mode)
[docs]def unpad_object(object: torch.Tensor) -> torch.Tensor: """Given a padded object, removes the padding to return the original object Args: object (torch.Tensor): Padded object Returns: torch.Tensor: Unpadded, original object """ pad_size = compute_pad_size_padded(object.shape[-2]) return object[:,pad_size:-pad_size,pad_size:-pad_size]
[docs]def get_kernel_meshgrid( xv_input: torch.Tensor, yv_input: torch.Tensor, k_width: float ) -> tuple[torch.Tensor, torch.Tensor]: """Obtains a kernel meshgrid of given spatial width k_width (in same units as meshgrid). Enforces the kernel size is odd Args: xv_input (torch.Tensor): Meshgrid x-coordinates corresponding to the input of some operator yv_input (torch.Tensor): Meshgrid y-coordinates corresponding to the input of some operator k_width (float): Width of kernel in same units as meshgrid Returns: tuple[torch.Tensor, torch.Tensor]: Meshgrid of kernel """ dx = xv_input[0,1] - xv_input[0,0] dy = yv_input[1,0] - yv_input[0,0] x_kernel = torch.arange(0,k_width/2,dx).to(xv_input.device) x_kernel = torch.cat([-x_kernel.flip(dims=(0,))[:-1], x_kernel]) y_kernel = torch.arange(0,k_width/2,dy).to(xv_input.device) y_kernel = torch.cat([-y_kernel.flip(dims=(0,))[:-1], y_kernel]) xv_kernel, yv_kernel = torch.meshgrid(x_kernel, y_kernel, indexing='xy') return xv_kernel, yv_kernel