"""Source code for rubix.telescope.psf.psf: PSF kernel lookup and datacube convolution."""

import jax.numpy as jnp
from jax.scipy.signal import convolve2d
from jaxtyping import Array, Float
from jax import vmap
from .kernels import gaussian_kernel_2d


def _convolve_plane(plane, kernel):
    """Convolve a single plane of a datacube with a kernel."""
    return convolve2d(plane, kernel, mode="same")


def get_psf_kernel(name: str, m: int, n: int, **kwargs) -> Float[Array, "m n"]:
    """Get a point spread function (PSF) kernel.

    Parameters
    ----------
    name : str
        The name of the PSF kernel to get. ``"gaussian"`` is the only name
        handled here.
    m : int
        Forwarded to the kernel function as ``m``; per the return
        annotation, the first dimension of the kernel.
    n : int
        Forwarded to the kernel function as ``n``; per the return
        annotation, the second dimension of the kernel.
    **kwargs
        Additional keyword arguments to pass to the PSF kernel function.

    Returns
    -------
    Float[Array, "m n"]
        The PSF kernel.

    Raises
    ------
    ValueError
        If ``name`` is not a known PSF kernel name.
    """
    if name == "gaussian":
        return gaussian_kernel_2d(m=m, n=n, **kwargs)

    # Any other name is rejected explicitly rather than silently defaulted.
    raise ValueError(f"Unknown PSF kernel name: {name}")


def apply_psf(
    datacube: Float[Array, "n_pixel n_pixel wave_bins"],
    psf_kernel: Float[Array, "m n"],
) -> Float[Array, "n_pixel n_pixel wave_bins"]:
    """Apply a point spread function (PSF) to the spectral datacube.

    The PSF kernel is convolved with each spectral plane of the datacube to
    simulate the blurring effect of the telescope.

    Parameters
    ----------
    datacube : Float[Array, "n_pixel n_pixel wave_bins"]
        The spectral datacube to convolve with the PSF kernel. The spectral
        dimension is the last axis (axis 2).
    psf_kernel : Float[Array, "m n"]
        The 2D PSF kernel to apply to the datacube. The same kernel is used
        for every spectral plane.

    Returns
    -------
    Float[Array, "n_pixel n_pixel wave_bins"]
        The datacube convolved with the PSF kernel.
    """
    # Map the 2D convolution over the spectral axis (axis 2) while sharing the
    # kernel across planes. ``out_axes=2`` puts the mapped axis back in last
    # position, so the result already has the input layout and no separate
    # transpose is needed.
    convolve_all_planes = vmap(_convolve_plane, in_axes=(2, None), out_axes=2)
    return convolve_all_planes(datacube, psf_kernel)