Source code for kernel1d

from __future__ import annotations
from typing import Callable
import torch
import numpy as np
from torch.nn.functional import grid_sample
from torchquad import Simpson

class Kernel1D:
    """Base class for all 1D kernels implemented in the library.

    Child classes should implement the ``normalization_constant`` and
    ``__call__`` methods; the versions defined here are placeholders that
    return ``None``.
    """

    def __init__(self, a_min: float, a_max: float) -> None:
        """Store the source-detector distance bounds of the kernel.

        Args:
            a_min (float): Minimum source-detector distance for the kernel; any
                distance values passed to __call__ below this value will be
                clamped to this value.
            a_max (float): Maximum source-detector distance for the kernel; any
                distance values passed to __call__ above this value will be
                clamped to this value.
        """
        self.a_min = a_min
        self.a_max = a_max

    def normalization_constant(self, x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Computes the normalization constant for the kernel.

        Args:
            x (torch.Tensor[Lx]): Positions where the kernel is being evaluated
            a (torch.Tensor[Ld]): Source-detector distances at which to compute
                the normalization constant

        Returns:
            torch.Tensor[Ld]: Normalization constant at each source-detector
            distance
        """
        ...

    # ``self`` was missing from the original signature, which made any call
    # on an instance fail with a TypeError. ``normalize`` is accepted
    # (defaulting to False) so the base signature matches the subclasses.
    def __call__(self, x: torch.Tensor, a: torch.Tensor, normalize: bool = False) -> torch.Tensor:
        """Returns the kernel evaluated at each source-detector distance.

        Args:
            x (torch.Tensor[Lx]): x values at which to evaluate the kernel
            a (torch.Tensor[Ld]): Source-detector distances at which to
                evaluate the kernel
            normalize (bool, optional): Whether or not to normalize the output
                of the kernel. Defaults to False.

        Returns:
            torch.Tensor[Ld, Lx]: Kernel evaluated at each source-detector
            distance and x value
        """
        ...
class ArbitraryKernel1D(Kernel1D):
    """1D kernel defined by an arbitrary 1D array sampled with spacing dx0.

    The kernel is evaluated as
    f(x) = amplitude(a,b) * kernel_interp(x/sigma(a,b)), where amplitude and
    sigma are functions of the source-detector distance a and additional
    hyperparameters b, and kernel_interp is the 1D interpolation of the
    absolute value of the stored kernel array at the provided x values.
    """

    def __init__(
        self,
        kernel: torch.Tensor,
        amplitude_fn: Callable,
        sigma_fn: Callable,
        amplitude_params: torch.Tensor,
        sigma_params: torch.Tensor,
        dx0: float,
        a_min: float = -torch.inf,
        a_max: float = torch.inf,
        grid_sample_mode: str = 'bilinear'
    ) -> None:
        """Store the kernel array, scaling functions and their parameters.

        Args:
            kernel (torch.Tensor): 1D array that defines the kernel
            amplitude_fn (Callable): Amplitude function that depends on the
                source-detector distance and additional hyperparameters
            sigma_fn (Callable): Scaling function that depends on the
                source-detector distance and additional hyperparameters
            amplitude_params (torch.Tensor): Hyperparameters for the amplitude
                function
            sigma_params (torch.Tensor): Hyperparameters for the sigma function
            dx0 (float): Spacing of the 1D kernel provided in the same units as
                x used in __call__
            a_min (float, optional): Minimum source-detector distance for the
                kernel; any distance values passed to __call__ below this value
                will be clamped to this value. Defaults to -torch.inf.
            a_max (float, optional): Maximum source-detector distance for the
                kernel; any distance values passed to __call__ above this value
                will be clamped to this value. Defaults to torch.inf.
            grid_sample_mode (str, optional): How to sample the kernel for
                general grids. Defaults to 'bilinear'.
        """
        super().__init__(a_min, a_max)
        self.kernel = kernel
        self.amplitude_fn = amplitude_fn
        self.sigma_fn = sigma_fn
        self.amplitude_params = amplitude_params
        self.sigma_params = sigma_params
        # The kernel array itself is listed alongside the hyperparameters.
        self.params = [self.amplitude_params, self.sigma_params, self.kernel]
        self.dx0 = dx0
        self.Nx0 = kernel.shape[0]
        self.grid_sample_mode = grid_sample_mode

    def normalization_constant(self, x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Computes the normalization constant for the kernel.

        The constant is amplitude * sum(|kernel|) * sigma * dx0 / dx, where dx
        is taken from the first two entries of ``x`` (so ``x`` needs at least
        two elements and is treated as evenly spaced).

        Args:
            x (torch.Tensor[Lx]): Positions where the kernel is being evaluated
            a (torch.Tensor[Ld]): Source-detector distances at which to compute
                the normalization constant

        Returns:
            torch.Tensor: Normalization constant at each source-detector
            distance (shape follows what amplitude_fn and sigma_fn return for
            ``a``)
        """
        # Clamp exactly as __call__ does so the constant always corresponds to
        # the kernel that is actually evaluated (a no-op for clamped input).
        a = torch.clamp(a, self.a_min, self.a_max)
        amplitude = self.amplitude_fn(a, self.amplitude_params)
        sigma = self.sigma_fn(a, self.sigma_params)
        # __call__ interpolates |kernel|, so the matching discrete integral is
        # the sum of absolute values, not the absolute value of the sum.
        kernel_sum = torch.abs(self.kernel).sum()
        return amplitude * kernel_sum * sigma * self.dx0 / (x[1] - x[0])

    def __call__(self, x: torch.Tensor, a: torch.Tensor, normalize: bool = False) -> torch.Tensor:
        """Calls the kernel function.

        Args:
            x (torch.Tensor[Lx]): Positions where the kernel is being evaluated
            a (torch.Tensor[Ld]): Source-detector distances at which to compute
                the kernel
            normalize (bool, optional): Whether or not to normalize the output
                of the kernel. Defaults to False.

        Returns:
            torch.Tensor[Ld,Lx]: Kernel evaluated at each source-detector
            distance and x value
        """
        a = torch.clamp(a, self.a_min, self.a_max)
        sigma = self.sigma_fn(a, self.sigma_params)
        amplitude = self.amplitude_fn(a, self.amplitude_params)
        # Express x in grid_sample's normalised coordinates: the stored kernel
        # spans Nx0*dx0, stretched by sigma, and is mapped onto [-1, 1].
        grid = 2 * x / (self.Nx0 * self.dx0 * sigma.unsqueeze(1).unsqueeze(1))
        # grid_sample expects (x, y) pairs; the kernel has a single row, so the
        # y coordinate is fixed at 0.
        grid = torch.stack([grid, 0 * grid], dim=-1)
        # One copy of |kernel| per source-detector distance.
        kernel = torch.abs(self.kernel).reshape(1, 1, -1).repeat(a.shape[0], 1, 1)
        kernel = grid_sample(
            kernel.unsqueeze(1),
            grid,
            mode=self.grid_sample_mode,
            align_corners=False,
        )[:, 0, 0]
        kernel = amplitude.reshape(-1, 1) * kernel
        if normalize:
            kernel = kernel / self.normalization_constant(x, a.reshape(-1, 1))
        return kernel
class FunctionKernel1D(Kernel1D):
    """Kernel1D with an explicit functional form provided for the kernel.

    The kernel is evaluated as f(x) = amplitude(a,b) * k(x/sigma(a,b)), where
    amplitude and sigma are functions of the source-detector distance a and
    additional hyperparameters b.
    """

    def __init__(
        self,
        kernel_fn: Callable,
        amplitude_fn: Callable,
        sigma_fn: Callable,
        amplitude_params: torch.Tensor,
        sigma_params: torch.Tensor,
        a_min=-torch.inf,
        a_max=torch.inf,
    ) -> None:
        """Store the kernel/scaling functions and precompute the integral of k.

        Args:
            kernel_fn (Callable): Kernel function k(x)
            amplitude_fn (Callable): Amplitude function amplitude(a,b)
            sigma_fn (Callable): Scaling function sigma(a,b)
            amplitude_params (torch.Tensor): Hyperparameters for the amplitude
                function
            sigma_params (torch.Tensor): Hyperparameters for the sigma function
            a_min (float, optional): Minimum source-detector distance for the
                kernel; any distance values passed to __call__ below this value
                will be clamped to this value. Defaults to -torch.inf.
            a_max (float, optional): Maximum source-detector distance for the
                kernel; any distance values passed to __call__ above this value
                will be clamped to this value. Defaults to torch.inf.
        """
        super().__init__(a_min, a_max)
        self.kernel_fn = kernel_fn
        self.amplitude_fn = amplitude_fn
        self.sigma_fn = sigma_fn
        self.amplitude_params = amplitude_params
        self.sigma_params = sigma_params
        self.params = [self.amplitude_params, self.sigma_params]
        # Populates self.kernel_fn_norm, used by normalization_constant.
        self._compute_norm_via_integral()

    def _compute_norm_via_integral(self) -> torch.Tensor:
        """Compute the integral of k(x) from -infinity to infinity.

        The substitution x = tan(t) converts the improper integral into a
        definite integral over [-pi/2, pi/2], which is evaluated with
        Simpson's rule. The result is stored in ``self.kernel_fn_norm`` and
        also returned.

        Returns:
            torch.Tensor: Integral of k(x) from -infinity to infinity
        """
        def kernel_fn_definite(t):
            # dx = dt / cos(t)^2 under the substitution x = tan(t).
            return self.kernel_fn(torch.tan(t)) * torch.cos(t) ** (-2)

        # Should be good for most simple function cases
        self.kernel_fn_norm = Simpson().integrate(
            kernel_fn_definite,
            dim=1,
            N=1001,
            integration_domain=[[-torch.pi / 2, torch.pi / 2]],
        )
        return self.kernel_fn_norm

    def normalization_constant(self, x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Computes the normalization constant for the kernel.

        The constant is integral(k) * amplitude * sigma / dx, where dx is taken
        from the first two entries of ``x`` (so ``x`` needs at least two
        elements and is treated as evenly spaced).

        Args:
            x (torch.Tensor[Lx]): Positions where the kernel is being evaluated
            a (torch.Tensor[Ld]): Source-detector distances at which to compute
                the normalization constant

        Returns:
            torch.Tensor: Normalization constant at each source-detector
            distance (shape follows what amplitude_fn and sigma_fn return for
            ``a``)
        """
        # Clamp exactly as __call__ does so the constant always corresponds to
        # the kernel that is actually evaluated (a no-op for clamped input).
        a = torch.clamp(a, self.a_min, self.a_max)
        amplitude = self.amplitude_fn(a, self.amplitude_params)
        sigma = self.sigma_fn(a, self.sigma_params)
        return self.kernel_fn_norm * amplitude * sigma / (x[1] - x[0])

    def __call__(self, x: torch.Tensor, a: torch.Tensor, normalize: bool = False) -> torch.Tensor:
        """Computes the kernel at each source-detector distance.

        Args:
            x (torch.Tensor[Lx]): Positions where the kernel is being evaluated
            a (torch.Tensor[Ld]): Source-detector distances at which to compute
                the kernel
            normalize (bool, optional): Whether or not to normalize the output.
                Defaults to False.

        Returns:
            torch.Tensor[Ld,Lx]: Kernel evaluated at each source-detector
            distance and x value
        """
        # Column vector so that a broadcasts against x to an [Ld, Lx] result.
        a = torch.clamp(a.reshape(-1, 1), self.a_min, self.a_max)
        sigma = self.sigma_fn(a, self.sigma_params)
        amplitude = self.amplitude_fn(a, self.amplitude_params)
        kernel = amplitude * self.kernel_fn(x / sigma)
        if normalize:
            kernel = kernel / self.normalization_constant(x, a)
        return kernel
class GaussianKernel1D(FunctionKernel1D):
    """Subclass of FunctionKernel1D that implements a Gaussian kernel.

    The kernel is evaluated as
    f(x) = amplitude(a,b) * exp(-0.5*x^2/sigma(a,b)^2), where amplitude and
    sigma are functions of the source-detector distance a and additional
    hyperparameters b.
    """

    def __init__(
        self,
        amplitude_fn: Callable,
        sigma_fn: Callable,
        amplitude_params: torch.Tensor,
        sigma_params: torch.Tensor,
        a_min: float = -torch.inf,
        a_max: float = torch.inf
    ) -> None:
        """Set up a FunctionKernel1D whose kernel function is a unit Gaussian.

        Args:
            amplitude_fn (Callable): Amplitude function amplitude(a,b)
            sigma_fn (Callable): Scaling function sigma(a,b)
            amplitude_params (torch.Tensor): Amplitude hyperparameters
            sigma_params (torch.Tensor): Sigma hyperparameters
            a_min (float, optional): Minimum source-detector distance for the
                kernel; any distance values passed to __call__ below this value
                will be clamped to this value. Defaults to -torch.inf.
            a_max (float, optional): Maximum source-detector distance for the
                kernel; any distance values passed to __call__ above this value
                will be clamped to this value. Defaults to torch.inf.
        """
        def kernel_fn(x):
            # Unit-width Gaussian; the width is applied through x/sigma.
            return torch.exp(-0.5 * x ** 2)

        super().__init__(
            kernel_fn,
            amplitude_fn,
            sigma_fn,
            amplitude_params,
            sigma_params,
            a_min,
            a_max,
        )

    def normalization_constant(self, x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Computes the normalization constant for the kernel.

        When max(x) / max(sigma) is greater than 4, the constant is obtained by
        summing the sampled kernel values over x. Otherwise the closed-form
        Gaussian integral sqrt(2*pi) * amplitude * sigma / dx is used, with dx
        taken from the first two entries of ``x``.

        Args:
            x (torch.Tensor[Lx]): Values at which to compute the kernel
            a (torch.Tensor[Ld]): Source-detector distances at which to compute
                the normalization constant

        Returns:
            torch.Tensor[Ld,1]: Normalization constant at each source-detector
            distance
        """
        # Clamp first so the branch decision uses the same sigma values as the
        # kernel evaluated in __call__ (which always clamps its distances).
        a = torch.clamp(a, self.a_min, self.a_max)
        sigma = self.sigma_fn(a, self.sigma_params)
        FOV_to_sigma_ratio = x.max() / sigma.max()
        if FOV_to_sigma_ratio > 4:
            # Sum the un-normalized samples for each distance.
            return self(x, a).sum(dim=1).reshape(-1, 1)
        a = a.reshape(-1, 1)
        dx = x[1] - x[0]
        amplitude = self.amplitude_fn(a, self.amplitude_params)
        return np.sqrt(2 * torch.pi) * amplitude * self.sigma_fn(a, self.sigma_params) / dx