from __future__ import annotations
from typing import Callable
import torch
import numpy as np
from functools import partial
from torch.nn.functional import conv1d, conv2d, grid_sample
from torchvision.transforms.functional import rotate
from torchquad import Simpson
# Default device for tensors allocated by this module: CUDA when PyTorch reports
# it as available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Kernel2D:
    def __init__(self, a_min: float, a_max: float) -> None:
        """Parent class for 2D kernels. All child classes should inherit from this class and implement the ``__call__`` and ``normalization_constant`` methods.

        Args:
            a_min (float): Minimum source-detector distance for the kernel; distance values passed to ``__call__`` below this value are expected to be clamped to this value by the child class.
            a_max (float): Maximum source-detector distance for the kernel; distance values passed to ``__call__`` above this value are expected to be clamped to this value by the child class.
        """
        self.a_min = a_min
        self.a_max = a_max

    def normalization_constant(self, a: torch.Tensor) -> torch.Tensor:
        """Computes the normalization constant for the kernel at each source-detector distance a. This method should be implemented in the child class; the base implementation is a placeholder that returns ``None``.

        Args:
            a (torch.Tensor[Ld]): Source-detector distances

        Returns:
            torch.Tensor[Ld]: Normalization at each source-detector distance
        """
        ...

    def __call__(self, xv: torch.Tensor, yv: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Computes the kernel value at each point in the meshgrid defined by xv and yv for each source-detector distance a. This method should be implemented in the child class; the base implementation is a placeholder that returns ``None``.

        Args:
            xv (torch.Tensor[Lx,Ly]): Meshgrid x-coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y-coordinates
            a (torch.Tensor[Ld]): Source-detector distances

        Returns:
            torch.Tensor[Ld,Lx,Ly]: Kernel values at each point in the meshgrid for each source-detector distance
        """
        # NOTE: ``self`` was missing from this signature, which made instance
        # calls bind the instance to ``xv`` and fail on the argument count.
        ...
class FunctionalKernel2D(Kernel2D):
    def __init__(
        self,
        kernel_fn: Callable,
        amplitude_fn: Callable,
        sigma_fn: Callable,
        amplitude_params: torch.Tensor,
        sigma_params: torch.Tensor,
        a_min: float = -torch.inf,
        a_max: float = torch.inf,
    ) -> None:
        """2D kernel where the kernel is specified explicitly given a function of x and y. The kernel is evaluated as f(x,y) = amplitude(a,b) * k(x/sigma(a,b),y/sigma(a,b)) where amplitude and sigma are functions of the source-detector distance a and additional hyperparameters b.

        Constructing the kernel immediately integrates ``kernel_fn`` numerically (see ``_compute_norm_via_integral``).

        Args:
            kernel_fn (Callable): Kernel function k(x,y)
            amplitude_fn (Callable): Amplitude function amplitude(a,b)
            sigma_fn (Callable): Scaling function sigma(a,b)
            amplitude_params (torch.Tensor): Amplitude hyperparameters b
            sigma_params (torch.Tensor): Scaling hyperparameters b
            a_min (float, optional): Minimum source-detector distance for the kernel; any distance values passed to __call__ below this value will be clamped to this value. Defaults to -torch.inf.
            a_max (float, optional): Maximum source-detector distance for the kernel; any distance values passed to __call__ above this value will be clamped to this value. Defaults to torch.inf.
        """
        super().__init__(a_min, a_max)
        self.kernel_fn = kernel_fn
        self.amplitude_fn = amplitude_fn
        self.sigma_fn = sigma_fn
        self.amplitude_params = amplitude_params
        self.sigma_params = sigma_params
        self.all_parameters = [amplitude_params, sigma_params]
        # Alias under the attribute name used by NGonKernel2D so that both kernel
        # types expose their parameter list under a common name.
        self.params = self.all_parameters
        self._compute_norm_via_integral()

    def _compute_norm_via_integral(self) -> None:
        """Compute the integral of k(x,y) over the whole plane and store it in ``self.kernel_fn_norm``.

        The substitution x = tan(t_x), y = tan(t_y) converts the improper integral into a definite integral over [-pi/2, pi/2]^2, which is evaluated with torchquad's Simpson rule.
        """
        def kernel_fn_definite(t: torch.Tensor) -> torch.Tensor:
            # Column 0 holds t_x and column 1 holds t_y for every sample point.
            tx, ty = t[:, 0], t[:, 1]
            # 1/cos^2 factors are the Jacobian of the tan substitution.
            return self.kernel_fn(torch.tan(tx), torch.tan(ty)) * torch.cos(tx) ** (-2) * torch.cos(ty) ** (-2)

        half_pi = torch.pi / 2
        # Should be good for most simple function cases
        self.kernel_fn_norm = Simpson().integrate(
            kernel_fn_definite,
            dim=2,
            N=1001,
            integration_domain=[[-half_pi, half_pi], [-half_pi, half_pi]],
        )

    def normalization_constant(self, xv: torch.Tensor, yv: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Obtains the normalization constant for the 2D kernel, i.e. the continuous integral of the kernel divided by the meshgrid pixel area.

        Args:
            xv (torch.Tensor[Lx,Ly]): Meshgrid x-coordinates; the spacing is read from ``xv[0,1] - xv[0,0]``
            yv (torch.Tensor[Lx,Ly]): Meshgrid y-coordinates; the spacing is read from ``yv[1,0] - yv[0,0]``
            a (torch.Tensor[Ld]): Source-detector distances

        Returns:
            torch.Tensor: Normalization constant at each source-detector distance (follows the shape of ``a`` for elementwise amplitude/sigma functions; ``[Ld,1,1]`` when called from ``__call__``)
        """
        # Clamp exactly as __call__ does, so the constant corresponds to the
        # kernel that is actually evaluated for out-of-range distances.
        a = torch.clamp(a, self.a_min, self.a_max)
        sigma = self.sigma_fn(a, self.sigma_params)
        dx = xv[0, 1] - xv[0, 0]
        dy = yv[1, 0] - yv[0, 0]
        return self.kernel_fn_norm * self.amplitude_fn(a, self.amplitude_params) * sigma**2 / dx / dy

    def __call__(self, xv: torch.Tensor, yv: torch.Tensor, a: torch.Tensor, normalize: bool = False) -> torch.Tensor:
        """Computes the kernel at each source detector distance

        Args:
            xv (torch.Tensor[Lx,Ly]): Meshgrid x-coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y-coordinates
            a (torch.Tensor[Ld]): Source-detector distances; clamped to [a_min, a_max] before use
            normalize (bool, optional): Whether or not to normalize the output of the kernel. Defaults to False.

        Returns:
            torch.Tensor[Ld,Lx,Ly]: Kernel at each source-detector distance
        """
        # Reshape to [Ld,1,1] so that a broadcasts against the [Lx,Ly] meshgrids.
        a = torch.clamp(a, self.a_min, self.a_max).reshape(-1, 1, 1)
        # sigma is needed for both coordinates; evaluate it once.
        sigma = self.sigma_fn(a, self.sigma_params)
        kernel = self.amplitude_fn(a, self.amplitude_params) * self.kernel_fn(xv / sigma, yv / sigma)
        if normalize:
            kernel = kernel / self.normalization_constant(xv, yv, a)
        return kernel
class NGonKernel2D(Kernel2D):
    def __init__(
        self,
        N_sides: int,
        Nx: int,
        collimator_width: float,
        amplitude_fn: Callable,
        sigma_fn: Callable,
        amplitude_params: torch.Tensor,
        sigma_params: torch.Tensor,
        a_min: float = -torch.inf,
        a_max: float = torch.inf,
        rot: float = 0,
    ) -> None:
        """Implementation of the arbitrary polygon kernel. This kernel is composed of a polygon shape convolved with itself, which is shown to be the true geometric component of the SPECT PSF when averaged over random collimator movement to get a linear shift invariant approximation. The kernel is computed as f(x,y) = amplitude(a,b) * k(x/sigma(a,b),y/sigma(a,b)) where k(x,y) is the convolved polygon shape.

        The polygon template is built once at construction time on the module-level ``device``.

        Args:
            N_sides (int): Number of sides of the polygon. Currently only supports an even number of sides
            Nx (int): Number of voxels to use for constructing the polygon (separate from any meshgrid stuff done later on). Must be large enough that ``floor(Nx/6 * cos(pi/N_sides))`` is at least 1, since the pixel size is obtained by dividing by twice that value.
            collimator_width (float): Width of the polygon (from flat edge to flat edge)
            amplitude_fn (Callable): Amplitude function amplitude(a,b)
            sigma_fn (Callable): Scaling function sigma(a,b)
            amplitude_params (torch.Tensor): Amplitude hyperparameters b
            sigma_params (torch.Tensor): Scaling hyperparameters b
            a_min (float, optional): Minimum source-detector distance for the kernel; any distance values passed to __call__ below this value will be clamped to this value. Defaults to -torch.inf.
            a_max (float, optional): Maximum source-detector distance for the kernel; any distance values passed to __call__ above this value will be clamped to this value. Defaults to torch.inf.
            rot (float, optional): Initial rotation of the polygon flat side. Defaults to 0 (first flat side aligned with +y axis).
        """
        super().__init__(a_min, a_max)
        self.N_sides = N_sides
        self.Nx = Nx
        # Centre-to-face distance in template voxels (apothem of a regular
        # polygon whose circumradius is Nx/6 voxels).
        self.N_voxels_to_face = int(np.floor(Nx / 6 * np.cos(np.pi / self.N_sides)))
        self.collimator_width_voxels = 2 * self.N_voxels_to_face
        # Physical size of one template voxel.
        self.pixel_size = collimator_width / self.collimator_width_voxels
        self.collimator_width = collimator_width
        self.amplitude_fn = amplitude_fn
        self.sigma_fn = sigma_fn
        self.amplitude_params = amplitude_params
        self.sigma_params = sigma_params
        self.rot = rot
        self._compute_convolved_polygon()
        self.params = [amplitude_params, sigma_params]
        # Alias under the attribute name used by FunctionalKernel2D so that both
        # kernel types expose their parameter list under a common name.
        self.all_parameters = self.params

    def _compute_convolved_polygon(self) -> None:
        """Builds the polygon template, convolves it with itself and stores the result.

        Sets ``self.convolved_polygon`` (an [Nx,Nx] tensor scaled to a maximum of 1) and ``self.convolved_polygon_sum`` (its sum).
        """
        # Half-plane mask: ones in the leading rows, up to N_voxels_to_face rows
        # past the centre row.
        half_plane = torch.zeros((self.Nx, self.Nx)).to(device)
        half_plane[: (self.Nx - 1) // 2 + self.N_voxels_to_face] = 1
        # The polygon is the intersection (product) of the half-plane rotated to
        # each of the N_sides face orientations; angles are in degrees.
        rotated = [
            rotate(half_plane.unsqueeze(0), 360 * i / self.N_sides + self.rot).squeeze()
            for i in range(self.N_sides)
        ]
        polygon = torch.stack(rotated).prod(dim=0)
        polygon_4d = polygon.unsqueeze(0).unsqueeze(0)
        # The polygon is used as both the image and the filter.
        convolved_polygon = conv2d(polygon_4d, polygon_4d, padding='same').squeeze()
        self.convolved_polygon = convolved_polygon / convolved_polygon.max()
        self.convolved_polygon_sum = self.convolved_polygon.sum()

    def _prepare_grid(self, xv: torch.Tensor, yv: torch.Tensor, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Clamps/reshapes the distances and builds the grid_sample coordinates.

        Returns:
            tuple: distances as [Ld,1,1] and the sampling grid as [Ld,Lx,Ly,2]
        """
        a = torch.clamp(a, self.a_min, self.a_max).reshape(-1, 1, 1)
        sigma = self.sigma_fn(a, self.sigma_params)
        # grid_sample uses normalized coordinates where [-1, 1] spans the
        # template, whose physical extent at distance a is Nx*pixel_size*sigma.
        extent = self.Nx * self.pixel_size * sigma
        grid = torch.stack([2 * xv / extent, 2 * yv / extent], dim=-1)
        return a, grid

    def _sample_kernel(self, a: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
        """Samples the amplitude-scaled, unnormalized kernel on ``grid``; returns [Ld,Lx,Ly]."""
        n_dist = a.shape[0]
        amplitude = self.amplitude_fn(a, self.amplitude_params).reshape(n_dist, 1, 1)
        template = self.convolved_polygon.unsqueeze(0).unsqueeze(0).repeat(n_dist, 1, 1, 1)
        return amplitude * grid_sample(template, grid, mode='bilinear', align_corners=False)[:, 0]

    def _normalization(
        self,
        xv: torch.Tensor,
        yv: torch.Tensor,
        a: torch.Tensor,
        grid: torch.Tensor,
        grid_max: torch.Tensor,
        kernel: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Normalization for already clamped/reshaped ``a``; samples the kernel only if it is needed and not supplied."""
        if grid_max < 1:
            # For cases where the polygon kernel exceeds the boundary
            dx = xv[0, 1] - xv[0, 0]
            dy = yv[1, 0] - yv[0, 0]
            return (
                self.amplitude_fn(a, self.amplitude_params)
                * self.pixel_size**2
                * self.sigma_fn(a, self.sigma_params)**2
                * self.convolved_polygon_sum
                / dx
                / dy
            )
        # This is called nearly 100% of the time
        if kernel is None:
            kernel = self._sample_kernel(a, grid)
        return kernel.sum(dim=(1, 2)).reshape(-1, 1, 1)

    def normalization_constant(self, xv: torch.Tensor, yv: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Computes the normalization constant for the kernel at each source-detector distance a.

        The result is computed from the arguments alone, so this may be called before ``__call__`` and does not depend on (or modify) ``self.kernel`` / ``self.grid_max``. When the largest grid coordinate is below 1, the template extends past the meshgrid and the constant is computed analytically from the template sum; otherwise the sampled kernel is summed over the meshgrid.

        Args:
            xv (torch.Tensor[Lx,Ly]): Meshgrid x-coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y-coordinates
            a (torch.Tensor[Ld]): Source-detector distances

        Returns:
            torch.Tensor[Ld,1,1]: Normalization constant at each source detector distance.
        """
        a, grid = self._prepare_grid(xv, yv, a)
        return self._normalization(xv, yv, a, grid, grid.max())

    def __call__(self, xv: torch.Tensor, yv: torch.Tensor, a: torch.Tensor, normalize: bool = False) -> torch.Tensor:
        """Computes the kernel at each source detector distance

        The returned kernel and the largest grid coordinate are also stored in ``self.kernel`` and ``self.grid_max``.

        Args:
            xv (torch.Tensor[Lx,Ly]): Meshgrid x-coordinates
            yv (torch.Tensor[Lx,Ly]): Meshgrid y-coordinates
            a (torch.Tensor[Ld]): Source-detector distances
            normalize (bool, optional): Whether or not to normalize the output of the kernel. Defaults to False.

        Returns:
            torch.Tensor[Ld,Lx,Ly]: Kernel at each source-detector distance
        """
        a, grid = self._prepare_grid(xv, yv, a)
        self.grid_max = grid.max()
        kernel = self._sample_kernel(a, grid)
        if normalize:
            # Reuse the kernel sampled above rather than sampling it again.
            kernel = kernel / self._normalization(xv, yv, a, grid, self.grid_max, kernel)
        self.kernel = kernel
        return self.kernel