from __future__ import annotations
from typing import Callable, Sequence
import torch
from torch.nn.functional import conv1d, conv2d
from torchvision.transforms.functional import rotate
from torchvision.transforms import InterpolationMode
from torch.nn.functional import grid_sample
from fft_conv_pytorch import fft_conv
from spectpsftoolbox.utils import pad_object, unpad_object
from spectpsftoolbox.kernel1d import GaussianKernel1D, Kernel1D
from spectpsftoolbox.kernel2d import Kernel2D
import dill
class Operator:
    """Base class for operators; operators are used to apply linear shift invariant operations to a sequence of 2D images.

    The base ``__call__`` only scales its input by the area of one meshgrid pixel.
    ``set_device``, ``detach`` and ``set_requires_grad`` iterate over ``self.params``,
    which this base class never assigns; subclasses must define it (a list of tensors).
    """

    def __call__(
        self,
        input: torch.Tensor,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor
    ) -> torch.Tensor:
        """Evaluates the operator on the input. The meshgrid xv and yv is used to compute the kernel size; it is assumed that the spacing in xv and yv is the same as that in input. The output is multiplied by the area of a pixel in the meshgrid.

        Args:
            input (torch.Tensor[Ld,Li,Lj]): Input 3D map to be operated on
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances

        Returns:
            torch.Tensor[Ld,Li,Lj]: Output of the operator
        """
        return input * self._area(xv, yv)

    def _area(self, xv: torch.Tensor, yv: torch.Tensor) -> torch.Tensor:
        """Computes the area of a single pixel of the meshgrid.

        The x spacing is read along the second axis of ``xv`` and the y spacing
        along the first axis of ``yv`` (consistent with a meshgrid built with
        ``indexing='xy'``).

        Args:
            xv (torch.Tensor): Meshgrid x coordinates
            yv (torch.Tensor): Meshgrid y coordinates

        Returns:
            torch.Tensor: 0-dimensional tensor holding ``dx * dy``
        """
        return (xv[0, 1] - xv[0, 0]) * (yv[1, 0] - yv[0, 0])

    def set_device(self, device: str):
        """Sets the device of all parameters in the operator (in place, via ``.data``).

        Args:
            device (str): Device to set parameters to
        """
        for p in self.params:
            p.data = p.data.to(device)

    def detach(self):
        """Detaches all parameters from autograd (in place).
        """
        for p in self.params:
            p.detach_()

    def set_requires_grad(self):
        """Sets all parameters to require grad.
        """
        for p in self.params:
            p.requires_grad_(True)

    def save(self, path: str):
        """Saves the operator to ``path`` using ``dill``.

        Note that this mutates the operator before pickling: all parameters are
        moved to the CPU and detached in place.

        Args:
            path (str): Path where to save the operator
        """
        self.set_device('cpu')
        self.detach()
        # Context manager guarantees the file handle is flushed and closed
        with open(path, 'wb') as f:
            dill.dump(self, f)

    def normalization_constant(
        self,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor
    ) -> torch.Tensor:
        """Computes the normalization constant of the operator.

        Subclasses must override this; the base class has no constant of its own.

        Args:
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances

        Returns:
            torch.Tensor[Ld]: Normalization constant at each source-detector distance

        Raises:
            NotImplementedError: Always, for the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not define a normalization constant"
        )

    def normalize(
        self,
        input: torch.Tensor,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor
    ) -> torch.Tensor:
        """Normalizes the input by the normalization constant. This ensures that the operator maintains the total sum of the input at each source-detector distance.

        Args:
            input (torch.Tensor[Ld,Li,Lj]): Input to be normalized
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances

        Returns:
            torch.Tensor[Ld,Li,Lj]: Normalized input
        """
        return input / self.normalization_constant(xv, yv, a)

    def __add__(self, other: Operator) -> Operator:
        """Implementation of addition to allow for adding operators. Addition of two operators yields a new operator that corresponds to the sum of two linear operators.

        Args:
            other (Operator): Operator to add

        Returns:
            Operator: New operator corresponding to the sum of the two operators
        """
        def combined_operator(x, *args, **kwargs):
            return self(x, *args, **kwargs) + other(x, *args, **kwargs)
        return CombinedOperator(combined_operator, [self, other], type='additive')

    def __mul__(self, other: Operator) -> Operator:
        """Implementation of multiplication to allow for multiplying operators. Multiplication of two operators yields a new operator that corresponds to the composition of the two operators (``other`` is applied first, then ``self``).

        Args:
            other (Operator): Operator to use in composition

        Returns:
            Operator: Composed operators
        """
        def combined_operator(x, *args, **kwargs):
            return self(other(x, *args, **kwargs), *args, **kwargs)
        return CombinedOperator(combined_operator, [self, other], type='sequential')
class CombinedOperator(Operator):
    """Operator assembled from two component operators.

    Args:
        func (Callable): Callable ``func(input, xv, yv, a)`` that evaluates the combination of the components
        operators (Sequence[Operator]): The component operators; parameters and normalization constants are taken from the first two
        type (str): ``'additive'`` for a sum of operators; any other value (e.g. ``'sequential'``) is treated as a composition
    """

    def __init__(
        self,
        func: Callable,
        operators: Sequence[Operator],
        type: str
    ) -> None:
        first, second = operators[0], operators[1]
        # Flat list of both components' parameters, used by set_requires_grad
        self.params = list(first.params) + list(second.params)
        self.func = func
        self.type = type
        self.operators = operators

    def set_device(self, device: str) -> None:
        """Moves every component operator's parameters to ``device``.

        Args:
            device (str): Target device
        """
        for component in self.operators:
            component.set_device(device)

    def detach(self) -> None:
        """Detaches the parameters of every component operator.
        """
        for component in self.operators:
            component.detach()

    def normalization_constant(
        self,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor
    ) -> torch.Tensor:
        """Combines the components' normalization constants.

        Additive combinations add the two constants; all others multiply them.

        Args:
            xv (torch.Tensor): Meshgrid x coordinates
            yv (torch.Tensor): Meshgrid y coordinates
            a (torch.Tensor): Source-detector distances

        Returns:
            torch.Tensor: Normalization constant
        """
        first_constant = self.operators[0].normalization_constant(xv, yv, a)
        second_constant = self.operators[1].normalization_constant(xv, yv, a)
        if self.type == 'additive':
            return first_constant + second_constant
        return first_constant * second_constant

    def __call__(
        self,
        input: torch.Tensor,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor,
        normalize: bool = False
    ) -> torch.Tensor:
        """Evaluates the combined operator.

        Args:
            input (torch.Tensor[Ld,Li,Lj]): Input to the operator
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances
            normalize (bool, optional): Divide the result by the combined normalization constant. Defaults to False.

        Returns:
            torch.Tensor[Ld,Li,Lj]: Output of the operator
        """
        output = self.func(input, xv, yv, a)
        if not normalize:
            return output
        return output / self.normalization_constant(xv, yv, a)
class Rotate1DConvOperator(Operator):
    """Operator that functions by rotating the input by a number of angles and applying a 1D convolution at each angle.

    At each angle the input is rotated, convolved along its last axis, and then
    rotated back by the same angle.

    Args:
        kernel1D (Kernel1D): 1D kernel to apply at each rotation angle
        N_angles (int): Number of angles to convolve at. Evenly distributes these angles between 0 and 180 degrees (2 angles would be 0, 90 degrees)
        additive (bool, optional): Use in additive mode; in this case, the initial input is used at each rotation angle and the results are summed. If False, then output from each previous angle is used in succeeding angles. Defaults to False.
        use_fft_conv (bool, optional): Whether or not to use FFT based convolution. Defaults to False.
        rot (float, optional): Initial angle offset. Defaults to 0.
    """

    def __init__(
        self,
        kernel1D: Kernel1D,
        N_angles: int,
        additive: bool = False,
        use_fft_conv: bool = False,
        rot: float = 0
    ) -> None:
        self.params = kernel1D.params
        self.kernel1D = kernel1D
        self.N_angles = N_angles
        self.angles = [180 * i / N_angles + rot for i in range(N_angles)]
        self.additive = additive
        # Angles within this tolerance of 0 / a multiple of 90 degrees avoid interpolation
        self.angle_delta = 1e-4
        self.use_fft_conv = use_fft_conv

    def _conv(self, input: torch.Tensor) -> torch.Tensor:
        """Applies a grouped 1D convolution along the last axis of the input.

        ``self.kernel`` (set by ``__call__``) is the conv weight; its first
        dimension is used as the number of groups.

        Args:
            input (torch.Tensor): Input tensor
        Returns:
            torch.Tensor: Convolved input tensor
        """
        if self.use_fft_conv:
            return fft_conv(input, self.kernel, padding='same', groups=self.kernel.shape[0])
        return conv1d(input, self.kernel, padding='same', groups=self.kernel.shape[0])

    def normalization_constant(
        self,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor
    ) -> torch.Tensor:
        """Computes the normalization constant from the most recently built kernel.

        Relies on ``self.kernel``, which is only assigned inside ``__call__``.
        Additive mode scales by sqrt(pixel area), matching the scaling applied
        in ``__call__``; regular mode scales by the full pixel area.

        Args:
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances
        Returns:
            torch.Tensor: Normalization constant, broadcastable against the output
        """
        if self.additive:
            return (self.N_angles * self.kernel.sum(dim=-1)).unsqueeze(-1) * torch.sqrt(self._area(xv, yv))
        return ((self.kernel.sum(dim=-1)) ** self.N_angles).unsqueeze(-1) * self._area(xv, yv)

    def _rotate(self, input: torch.Tensor, angle: float) -> torch.Tensor:
        """Rotates the input (over dims 1 and 2) by ``angle`` degrees.

        Angles within ``angle_delta`` of a multiple of 90 degrees (on either
        side) use the lossless ``torch.rot90``; other angles use bilinear
        interpolation.

        Args:
            input (torch.Tensor): Input tensor
            angle (float): Angle to rotate by, in degrees
        Returns:
            torch.Tensor: Rotated input
        """
        if abs(angle) < self.angle_delta:
            return input
        quarter_turns = round(angle / 90)
        if abs(angle - 90 * quarter_turns) < self.angle_delta:
            return torch.rot90(input, quarter_turns, dims=[1, 2])
        return rotate(input, angle, interpolation=InterpolationMode.BILINEAR)

    def _conv_at_angle(self, input: torch.Tensor, angle: float) -> torch.Tensor:
        """Rotates by ``angle``, convolves along the last axis, and rotates back by ``-angle``."""
        output = self._rotate(input, angle)
        # Move dim 0 to dim 1 so it becomes the conv channel (group) axis
        output = self._conv(output.swapaxes(0, 1)).swapaxes(0, 1)
        return self._rotate(output, -angle)

    def _apply_additive(self, input: torch.Tensor) -> torch.Tensor:
        """Applies the operator in additive mode.

        Args:
            input (torch.Tensor): Input tensor
        Returns:
            torch.Tensor: Sum over angles of the rotated + convolved + unrotated input
        """
        output = 0
        for angle in self.angles:
            output = output + self._conv_at_angle(input, angle)
        return output

    def _apply_regular(self, input: torch.Tensor) -> torch.Tensor:
        """Applies the operator in non-additive mode (angles applied in sequence).

        Args:
            input (torch.Tensor): Input tensor
        Returns:
            torch.Tensor: Output tensor after convolving at every angle in turn
        """
        for angle in self.angles:
            # Each step rotates back by -angle so the image orientation is preserved
            input = self._conv_at_angle(input, angle)
        return input

    def __call__(
        self,
        input: torch.Tensor,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor,
        normalize: bool = False
    ) -> torch.Tensor:
        """Evaluates the operator.

        Pads the input, builds the 1D kernel on a grid matching the padded
        width, applies the rotated convolutions, unpads and scales by the pixel
        area (sqrt of the area in additive mode).

        Args:
            input (torch.Tensor[Ld,Li,Lj]): Input to the operator
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances
            normalize (bool, optional): Whether to normalize the output. Defaults to False.
        Returns:
            torch.Tensor[Ld,Li,Lj]: Output of the operator
        """
        input = pad_object(input)
        # Kernel grid spans the padded width; its length is forced odd so it has a center sample
        dx = xv[0, 1] - xv[0, 0]
        Nx_padded = input.shape[-1]
        if Nx_padded % 2 == 0:
            Nx_padded += 1
        x_padded = torch.arange(-(Nx_padded - 1) / 2, (Nx_padded + 1) / 2, 1).to(input.device) * dx
        # unsqueeze(1) inserts the in_channels/groups axis of a grouped conv1d weight
        self.kernel = self.kernel1D(x_padded, a, normalize=False).unsqueeze(1)
        if self.additive:
            input = self._apply_additive(input)
        else:
            input = self._apply_regular(input)
        input = unpad_object(input)
        if normalize:
            input = self.normalize(input, xv, yv, a)
        if self.additive:
            return input * torch.sqrt(self._area(xv, yv))
        return input * self._area(xv, yv)
class RotateSeperable2DConvOperator(Operator):
    """Operator that applies rotations followed by convolutions with two perpendicular 1D kernels (x/y) at each angle.

    At each angle the input is rotated, convolved along both image axes with
    the same 1D kernel, and rotated back by the same angle.

    Args:
        kernel1D (Kernel1D): Kernel1D to use for convolution
        N_angles (int): Number of angles to rotate at; angles are evenly spaced over 90 degrees
        additive (bool, optional): Use in additive mode; in this case, the initial input is used at each rotation angle and the results are summed. If False, then output from each previous angle is used in succeeding angles. Defaults to False.
        use_fft_conv (bool, optional): Whether or not to use FFT based convolution. Defaults to False.
        rot (float, optional): Initial rotation angle. Defaults to 0.
    """

    def __init__(
        self,
        kernel1D: Kernel1D,
        N_angles: int,
        additive: bool = False,
        use_fft_conv: bool = False,
        rot: float = 0
    ) -> None:
        self.params = kernel1D.params
        self.kernel1D = kernel1D
        self.N_angles = N_angles
        self.angles = [90 * i / N_angles + rot for i in range(N_angles)]
        self.additive = additive
        # Angles within this tolerance of 0 / a multiple of 90 degrees avoid interpolation
        self.angle_delta = 1e-4
        self.use_fft_conv = use_fft_conv

    def _conv(self, input: torch.Tensor) -> torch.Tensor:
        """Applies a grouped 1D convolution along the last axis of the input.

        Args:
            input (torch.Tensor): Input tensor to convolve
        Returns:
            torch.Tensor: Convolved input tensor
        """
        if self.use_fft_conv:
            return fft_conv(input, self.kernel, padding='same', groups=self.kernel.shape[0])
        return conv1d(input, self.kernel, padding='same', groups=self.kernel.shape[0])

    def normalization_constant(self, xv: torch.Tensor, yv: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Computes the normalization constant from the most recently built kernel.

        Relies on ``self.kernel``, which is only assigned inside ``__call__``.
        The kernel sum is squared because it is applied along both axes.

        Args:
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances
        Returns:
            torch.Tensor: Normalization constant, broadcastable against the output
        """
        if self.additive:
            return (self.N_angles * self.kernel.sum(dim=-1) ** 2).unsqueeze(-1) * self._area(xv, yv)
        return ((self.kernel.sum(dim=-1) ** 2) ** self.N_angles).unsqueeze(-1) * self._area(xv, yv)

    def _rotate(self, input: torch.Tensor, angle: float) -> torch.Tensor:
        """Rotates the input (over dims 1 and 2) by ``angle`` degrees.

        Angles within ``angle_delta`` of a multiple of 90 degrees (on either
        side) use the lossless ``torch.rot90``; other angles use bilinear
        interpolation.

        Args:
            input (torch.Tensor): Input tensor to be rotated
            angle (float): Rotation angle, in degrees
        Returns:
            torch.Tensor: Rotated input tensor
        """
        if abs(angle) < self.angle_delta:
            return input
        quarter_turns = round(angle / 90)
        if abs(angle - 90 * quarter_turns) < self.angle_delta:
            return torch.rot90(input, quarter_turns, dims=[1, 2])
        return rotate(input, angle, interpolation=InterpolationMode.BILINEAR)

    def _conv_at_angle(self, input: torch.Tensor, angle: float) -> torch.Tensor:
        """Rotates by ``angle``, convolves along both image axes, and rotates back by ``-angle``."""
        output = self._rotate(input, angle)
        # Move dim 0 to dim 1 so it becomes the conv channel (group) axis
        output = output.swapaxes(0, 1)
        # Separable 2D conv: along the last axis, then along the (swapped-in) first axis
        output = self._conv(output)
        output = self._conv(output.swapaxes(0, -1)).swapaxes(0, -1)
        output = output.swapaxes(0, 1)
        return self._rotate(output, -angle)

    def _apply_additive(self, input: torch.Tensor) -> torch.Tensor:
        """Applies the operator in additive mode.

        Args:
            input (torch.Tensor): Input tensor
        Returns:
            torch.Tensor: Sum over angles of the rotated + convolved + unrotated input
        """
        output = 0
        for angle in self.angles:
            output = output + self._conv_at_angle(input, angle)
        return output

    def _apply_regular(self, input: torch.Tensor) -> torch.Tensor:
        """Applies the operator in non-additive mode (angles applied in sequence).

        Args:
            input (torch.Tensor): Input tensor
        Returns:
            torch.Tensor: Output tensor after convolving at every angle in turn
        """
        for angle in self.angles:
            input = self._conv_at_angle(input, angle)
        return input

    def __call__(
        self,
        input: torch.Tensor,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor,
        normalize: bool = False
    ) -> torch.Tensor:
        """Evaluates the operator.

        Args:
            input (torch.Tensor[Ld,Li,Lj]): Input to the operator
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances
            normalize (bool, optional): Whether to normalize the output. Defaults to False.
        Returns:
            torch.Tensor[Ld,Li,Lj]: Output of the operator
        """
        # Kernel is always built unnormalized; normalization (if any) happens on the output
        self.kernel = self.kernel1D(xv[0], a, normalize=False).unsqueeze(1)
        input = pad_object(input)
        if self.additive:
            input = self._apply_additive(input)
        else:
            input = self._apply_regular(input)
        input = unpad_object(input)
        if normalize:
            input = self.normalize(input, xv, yv, a)
        return input * self._area(xv, yv)
class Kernel2DOperator(Operator):
    """Operator whose output is the 2D convolution of the input with a Kernel2D instance.

    Args:
        kernel2D (Kernel2D): Kernel2D instance that produces the 2D kernel for each distance
        use_fft_conv (bool, optional): Whether or not to use FFT based convolution. Defaults to False.
    """

    def __init__(
        self,
        kernel2D: Kernel2D,
        use_fft_conv: bool = False,
    ) -> None:
        self.params = kernel2D.params
        self.kernel2D = kernel2D
        self.use_fft_conv = use_fft_conv

    def _conv(self, input: torch.Tensor) -> torch.Tensor:
        """Convolves each input plane with its own kernel (grouped 2D convolution).

        Args:
            input (torch.Tensor): Input
        Returns:
            torch.Tensor: Output
        """
        weight = self.kernel.unsqueeze(1)
        n_groups = self.kernel.shape[0]
        conv_fn = fft_conv if self.use_fft_conv else conv2d
        return conv_fn(input, weight, padding='same', groups=n_groups)

    def normalization_constant(
        self,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor
    ) -> torch.Tensor:
        """Returns the Kernel2D normalization constant scaled by the pixel area.

        Args:
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances
        Returns:
            torch.Tensor: Normalization constant
        """
        kernel_constant = self.kernel2D.normalization_constant(xv, yv, a)
        return kernel_constant * self._area(xv, yv)

    def __call__(
        self,
        input: torch.Tensor,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor,
        normalize: bool = False
    ) -> torch.Tensor:
        """Evaluates the operator: builds the kernel, optionally normalizes it, and convolves.

        Args:
            input (torch.Tensor[Ld,Li,Lj]): Input to the operator
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances
            normalize (bool, optional): Divide the kernel by the normalization constant. Defaults to False.
        Returns:
            torch.Tensor[Ld,Li,Lj]: Output of the operator
        """
        kernel = self.kernel2D(xv, yv, a)
        self.kernel = kernel
        if normalize:
            self.kernel = kernel / self.normalization_constant(xv, yv, a)
        scale = self._area(xv, yv)
        return self._conv(input) * scale
class NearestKernelOperator(Operator):
    """Operator that uses a set of PSFs and distances to compute the output of the operator. The PSF is obtained by selecting the nearest PSF to each distance provided in __call__ so that each plane in input is convolved with the appropriate kernel.

    Args:
        psf_data (torch.Tensor[LD,LX,LY]): Provided PSF data
        distances (torch.Tensor[LD]): Source-detector distance for each PSF
        dr0 (float | Sequence[float]): Spacing in the PSF data; either a single value used for both axes or an (x, y) pair
        use_fft_conv (bool, optional): Whether or not to use FFT based convolutions. Defaults to True.
        grid_sample_mode (str, optional): How to sample the PSF when the input spacing is not the same as the PSF. Defaults to 'bilinear'.
    """

    def __init__(
        self,
        psf_data: torch.Tensor,
        distances: torch.Tensor,
        dr0: float | Sequence[float],
        use_fft_conv: bool = True,
        grid_sample_mode: str = 'bilinear'
    ) -> None:
        self.psf_data = psf_data
        self.Nx0 = psf_data.shape[1]
        self.Ny0 = psf_data.shape[2]
        self.distances_original = distances
        # _get_kernel indexes dr0[0] and dr0[1]; promote a scalar spacing to an (x, y) pair
        if isinstance(dr0, (int, float)) or (torch.is_tensor(dr0) and dr0.dim() == 0):
            dr0 = (dr0, dr0)
        self.dr0 = dr0
        self.use_fft_conv = use_fft_conv
        # Empty: the PSF data is not exposed to set_requires_grad/detach
        self.params = []
        self.grid_sample_mode = grid_sample_mode

    def set_device(self, device: str) -> None:
        """Moves the PSF data and the PSF distances to ``device``.

        Args:
            device (str): Target device
        """
        self.psf_data = self.psf_data.to(device)
        self.distances_original = self.distances_original.to(device)

    def _conv(self, input: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
        """Convolves each input plane with its own kernel (grouped 2D convolution).

        Args:
            input (torch.Tensor[Ld,Li,Lj]): Input data
            kernel (torch.Tensor[Ld,Lx,Ly]): One kernel per input plane
        Returns:
            torch.Tensor[Ld,Li,Lj]: Convolved input data
        """
        groups = input.shape[0]
        conv_fn = fft_conv if self.use_fft_conv else conv2d
        # Planes become channels of a single batch element; only that batch dim is squeezed
        # so a single-plane input (Ld == 1) keeps its leading dimension
        return conv_fn(input.unsqueeze(0), kernel.unsqueeze(1), padding='same', groups=groups).squeeze(0)

    def _get_nearest_distance_idxs(self, distances: torch.Tensor) -> torch.Tensor:
        """Selects, for each requested distance, the PSF whose distance is nearest.

        Despite the name, this returns the selected PSF slices, not their indices.

        Args:
            distances (torch.Tensor[Ld]): Distances to find the nearest PSF for
        Returns:
            torch.Tensor[Ld,LX,LY]: Nearest PSF for each distance
        """
        differences = torch.abs(distances[:, None] - self.distances_original)
        indices = torch.argmin(differences, dim=1)
        return self.psf_data[indices]

    def _get_kernel(
        self,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor
    ) -> torch.Tensor:
        """Obtains the kernel by sampling the nearest PSF at the meshgrid locations.

        The sampled values are rescaled by ``(dx / dr0[0])**2`` to account for the
        change in pixel size (this uses dx for both axes, i.e. assumes square pixels).

        Args:
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances
        Returns:
            torch.Tensor[Ld,Lx,Ly]: Kernel obtained by sampling the nearest PSF
        """
        dx = xv[0, 1] - xv[0, 0]
        psf = self._get_nearest_distance_idxs(a)
        # Map physical coordinates onto grid_sample's normalized [-1, 1] range
        grid = torch.stack([
            2 * xv / (self.Nx0 * self.dr0[0]),
            2 * yv / (self.Ny0 * self.dr0[1])],
            dim=-1).unsqueeze(0).repeat(a.shape[0], 1, 1, 1)
        sampled = grid_sample(psf.unsqueeze(1), grid, align_corners=False, mode=self.grid_sample_mode)
        # Drop only the singleton channel dim; keep Ld even when it is 1
        return (dx / self.dr0[0]) ** 2 * sampled.squeeze(1)

    @staticmethod
    def _kernel_sum(kernel: torch.Tensor) -> torch.Tensor:
        """Sums each kernel over its two spatial axes, shaped [Ld,1,1] for broadcasting."""
        return kernel.sum(dim=(-2, -1)).reshape(-1, 1, 1)

    def normalization_constant(
        self,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor
    ) -> torch.Tensor:
        """Computes the per-distance normalization constant (sum of each sampled kernel).

        The output of this operator is not scaled by the pixel area, so the
        constant is the kernel sum alone.

        Args:
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances
        Returns:
            torch.Tensor[Ld,1,1]: Normalization constant at each source-detector distance
        """
        return self._kernel_sum(self._get_kernel(xv, yv, a))

    def __call__(
        self,
        input: torch.Tensor,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor,
        normalize: bool = False
    ) -> torch.Tensor:
        """Convolves each input plane with the nearest PSF for its distance.

        Args:
            input (torch.Tensor[Ld,Li,Lj]): Input to the operator
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances
            normalize (bool, optional): Divide each plane by the sum of its kernel. Defaults to False.
        Returns:
            torch.Tensor[Ld,Li,Lj]: Output of the operator
        """
        kernel = self._get_kernel(xv, yv, a)
        output = self._conv(input, kernel)
        if normalize:
            output = output / self._kernel_sum(kernel)
        return output
# TODO: Make subclass of RotateSeperable2DConvOperator with one angle and Gaussian kernel
class GaussianOperator(Operator):
    """Gaussian operator; works by convolving the input with two perpendicular 1D kernels. This is implemented separately from the Kernel2DOperator since it is more efficient to convolve with two 1D kernels than a 2D kernel.

    Args:
        amplitude_fn (Callable): Amplitude function for the Gaussian; each 1D kernel uses the square root of its absolute value, so the product of the two 1D amplitudes equals its absolute value
        sigma_fn (Callable): Scale function for 1D Gaussian kernel
        amplitude_params (torch.Tensor): Amplitude hyperparameters
        sigma_params (torch.Tensor): Scaling hyperparameters
        a_min (float, optional): Minimum source-detector distance for the kernel; any distance values passed to __call__ below this value will be clamped to this value. Defaults to -torch.inf.
        a_max (float, optional): Maximum source-detector distance for the kernel; any distance values passed to __call__ above this value will be clamped to this value. Defaults to torch.inf.
        use_fft_conv (bool, optional): Whether or not to use FFT based convolution. Defaults to False.
    """

    def __init__(
        self,
        amplitude_fn: Callable,
        sigma_fn: Callable,
        amplitude_params: torch.Tensor,
        sigma_params: torch.Tensor,
        a_min: float = -torch.inf,
        a_max: float = torch.inf,
        use_fft_conv: bool = False,
    ) -> None:
        self.amplitude_fn = amplitude_fn
        self.sigma_fn = sigma_fn
        self.amplitude_params = amplitude_params
        self.sigma_params = sigma_params
        # Split the 2D amplitude evenly between the two 1D passes
        amplitude_fn1D = lambda a, bs: torch.sqrt(torch.abs(amplitude_fn(a, bs)))
        self.kernel1D = GaussianKernel1D(amplitude_fn1D, sigma_fn, amplitude_params, sigma_params, a_min, a_max)
        self.params = self.kernel1D.params
        self.a_min = a_min
        self.a_max = a_max
        self.use_fft_conv = use_fft_conv

    def normalization_constant(
        self,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor
    ) -> torch.Tensor:
        """Squared 1D kernel normalization constant (one factor per axis) times the pixel area.

        Args:
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances
        Returns:
            torch.Tensor: Normalization constant
        """
        return self.kernel1D.normalization_constant(xv[0], a).unsqueeze(1) ** 2 * self._area(xv, yv)

    def _conv(self, input: torch.Tensor, kernel1D: torch.Tensor) -> torch.Tensor:
        """Applies a grouped 1D convolution along the last axis of the input.

        Args:
            input (torch.Tensor): Input tensor
            kernel1D (torch.Tensor): Gaussian 1D kernel used as the conv weight
        Returns:
            torch.Tensor: Output convolved tensor
        """
        if self.use_fft_conv:
            return fft_conv(input, kernel1D, padding='same', groups=kernel1D.shape[0])
        return conv1d(input, kernel1D, padding='same', groups=kernel1D.shape[0])

    def __call__(
        self,
        input: torch.Tensor,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor,
        normalize: bool = False
    ) -> torch.Tensor:
        """Convolves each plane with the separable Gaussian (first along Li, then along Lj).

        Args:
            input (torch.Tensor[Ld,Li,Lj]): Input to the operator
            xv (torch.Tensor[Lx,Ly]): Meshgrid x coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y coordinates
            a (torch.Tensor[Ld]): Source-detector distances
            normalize (bool, optional): Whether to normalize the output. Defaults to False.
        Returns:
            torch.Tensor[Ld,Li,Lj]: Output of the operator
        """
        kernel = self.kernel1D(xv[0], a).unsqueeze(1)
        # [Ld,Li,Lj] -> [Li,Ld,Lj]: the distance axis becomes the conv channel (group) axis
        output = input.swapaxes(0, 1)
        # Pass 1: bring Li to the last axis, convolve, and restore the layout
        output = self._conv(output.swapaxes(0, 2), kernel).swapaxes(0, 2)
        # Pass 2: Lj is already the last axis
        output = self._conv(output, kernel)
        output = output.swapaxes(0, 1)
        if normalize:
            output = self.normalize(output, xv, yv, a)
        return output * self._area(xv, yv)