# Source code for rubix.telescope.psf.kernels

import jax.numpy as jnp
from jaxtyping import Float, Array


def gaussian_kernel_2d(m: int, n: int, sigma: float) -> Float[Array, "m n"]:
    """Build a normalized, centered 2D Gaussian kernel.

    Samples the isotropic Gaussian ``exp(-(x**2 + y**2) / (2 * sigma**2))``
    on an ``m`` x ``n`` grid whose samples are one unit apart and whose
    origin sits at the middle of the grid, then rescales the result so its
    entries sum to 1. For an even size the grid center falls between two
    samples, so the offsets along that axis are half-integers (for example
    ``-1.5, -0.5, 0.5, 1.5`` for a size of 4).

    Parameters
    ----------
    m : int
        Number of rows (size along axis 0).
    n : int
        Number of columns (size along axis 1).
    sigma : float
        Standard deviation of the Gaussian, measured in grid samples.
        Must be non-zero: ``sigma == 0`` leads to a division by zero.

    Returns
    -------
    Float[Array, "m n"]
        Array of shape ``(m, n)`` whose elements sum to 1 (up to
        floating-point rounding).
    """
    # Signed distance of every row / column from the grid center.
    row_offsets = jnp.arange(m) - (m - 1) / 2
    col_offsets = jnp.arange(n) - (n - 1) / 2

    # Broadcasting (m, 1) against (1, n) yields the squared radius at every
    # grid point without materializing two full coordinate grids.
    radius_sq = row_offsets[:, None] ** 2 + col_offsets[None, :] ** 2

    weights = jnp.exp(-radius_sq / (2 * sigma**2))
    return weights / weights.sum()