# Source code for rubix.telescope.psf.psf

import jax.numpy as jnp
from jax import vmap
from jax.scipy.signal import convolve2d
from jaxtyping import Array, Float

from .kernels import gaussian_kernel_2d


def _convolve_plane(plane, kernel):
    """Convolve a single plane of a datacube with a kernel."""
    return convolve2d(plane, kernel, mode="same")


def get_psf_kernel(name: str, m: int, n: int, **kwargs) -> Float[Array, "m n"]:
    """
    Get a point spread function (PSF) kernel.

    Args:
        name (str): The name of the PSF kernel to get. Only ``"gaussian"``
            is currently supported.
        m (int): Kernel height in pixels.
        n (int): Kernel width in pixels.
        **kwargs: Additional keyword arguments to pass to the PSF kernel
            function.

    Returns:
        Float[Array, "m n"]: The PSF kernel.

    Raises:
        ValueError: If ``name`` is not a supported kernel type.
    """
    if name == "gaussian":
        return gaussian_kernel_2d(m=m, n=n, **kwargs)

    # Any other name is unsupported; fail loudly rather than return None.
    raise ValueError(f"Unknown PSF kernel name: {name}")


def apply_psf(
    datacube: Float[Array, "n_pixel n_pixel wave_bins"],
    psf_kernel: Float[Array, "m n"],
) -> Float[Array, "n_pixel n_pixel wave_bins"]:
    """
    Apply a point spread function (PSF) to the spectral datacube.

    The PSF kernel is convolved with each spectral plane of the datacube to
    simulate the blurring effect of the telescope.

    Args:
        datacube (Float[Array, "n_pixel n_pixel wave_bins"]): The spectral
            datacube to convolve with the PSF kernel.
        psf_kernel (Float[Array, "m n"]): The 2D PSF kernel to apply to the
            datacube.

    Returns:
        (Float[Array, "n_pixel n_pixel wave_bins"]): The datacube convolved
            with the PSF kernel.
    """
    # Map the per-plane convolution over the spectral axis (axis 2) while
    # sharing the same kernel across all planes. ``out_axes=2`` puts the
    # mapped axis back in its original position, so no transpose is needed.
    convolve_all_planes = vmap(_convolve_plane, in_axes=(2, None), out_axes=2)
    return convolve_all_planes(datacube, psf_kernel)