Source code for rubix.telescope.lsf.lsf
"""
Mainly reimplemented from SimSpin:
https://github.com/kateharborne/SimSpin/blob/4e8f0af0ebc0e43cc31729978deb3a554e039f6b/R/utilities.R#L570
"""
import jax.numpy as jnp
from jax.scipy.signal import convolve
from jax import vmap
from jaxtyping import Float, Array
[docs]
def gaussian1d(x: Float[Array, " n_x"], sigma: float) -> Float[Array, " n_x"]:
    """Evaluate a zero-centred Gaussian on ``x`` and normalise it to unit sum.

    The profile is normalised by the sum of its own samples rather than by
    the analytic ``1 / (sigma * sqrt(2 * pi))`` factor, so the returned values
    always add up to one on the given grid.

    Parameters
    ----------
    x : ndarray
        Positions at which the Gaussian is sampled.
    sigma : float
        Width (standard deviation) of the Gaussian, in the units of ``x``.

    Returns
    -------
    profile : ndarray
        ``exp(-0.5 * x**2 / sigma**2)`` divided by the sum over all samples.
    """
    exponent = -0.5 * (x**2) / sigma**2
    weights = jnp.exp(exponent)
    return weights / weights.sum()
def _convolve_kernel(spec, kernel, mode="full"):
return convolve(spec, kernel, mode=mode)
def _get_kernel(sigma: float, wave_res: float, factor: int = 12):
    """Build the normalised Gaussian LSF kernel on the wavelength grid.

    The kernel is sampled at ``2 * factor + 1`` offsets, spaced by
    ``wave_res`` and centred on zero, so the Gaussian peak sits at index
    ``factor``.

    Parameters
    ----------
    sigma : float
        Width (sigma) of the Gaussian kernel.
    wave_res : float
        Spacing between neighbouring kernel samples.
    factor : int
        Number of samples on each side of the kernel centre.

    Returns
    -------
    kernel : ndarray
        The Gaussian evaluated on the offset grid by ``gaussian1d``.
    """
    # Build the offsets from an integer index range and scale afterwards.
    # A float-step ``arange(-f*w, f*w + w, step=w)`` can gain a spurious extra
    # sample through rounding (e.g. w = 0.1), which breaks the symmetric
    # ``2 * factor + 1`` length the caller's trimming relies on.
    offsets = jnp.arange(-factor, factor + 1) * wave_res
    return gaussian1d(offsets, sigma)
[docs]
def apply_lsf_spectra(
    spectra: Float[Array, "n_spectra wave_bins"],
    lsf_sigma: float,
    wave_resolution: float,
    extend_factor: int = 12,
) -> Float[Array, "n_spectra wave_bins"]:
    """Apply the Line Spread Function (LSF) to multiple spectra.

    The same kernel is convolved with every spectrum in parallel using JAX's
    vmap. Currently only a Gaussian kernel and a fixed wave resolution across
    all spectra and wavelengths are supported.

    Parameters
    ----------
    spectra : ndarray
        The input spectra to apply the LSF to, one spectrum per row.
    lsf_sigma : float
        The sigma of the LSF. Currently a Gaussian kernel.
    wave_resolution : float
        The wave resolution of the spectra.
    extend_factor : int
        The factor to extend the kernel by, i.e. the number of kernel samples
        on each side of its centre.

    Returns
    -------
    convolved : ndarray
        The convolved spectra, with the same shape as ``spectra``.
    """
    kernel = _get_kernel(lsf_sigma, wave_resolution, factor=extend_factor)

    # Map the 1D "full" convolution over the leading (spectrum) axis while
    # sharing the single kernel between all spectra.
    convolved = vmap(_convolve_kernel, in_axes=(0, None))(spectra, kernel)

    # The kernel centre sits at index ``extend_factor``, so the part of the
    # full convolution aligned with the input starts there. Derive the end of
    # the window from the input width instead of the kernel length so the
    # output always has exactly ``wave_bins`` columns.
    n_wave = spectra.shape[1]
    return convolved[:, extend_factor : extend_factor + n_wave]
[docs]
def apply_lsf(
    datacube: Float[Array, "n1 n2 wave_bins"],
    lsf_sigma: float,
    wave_resolution: float,
    extend_factor: int = 12,
) -> Float[Array, "n1 n2 wave_bins"]:
    """Apply the Line Spread Function (LSF) to a datacube.

    The datacube is collapsed into a 2D stack of spectra, the LSF is applied
    to that stack via ``apply_lsf_spectra``, and the result is restored to
    the shape of the input datacube.

    Parameters
    ----------
    datacube : ndarray
        The input datacube; the last axis is the wavelength axis.
    lsf_sigma : float
        The sigma of the LSF. Currently a Gaussian kernel.
    wave_resolution : float
        The wave resolution of the spectra inside the datacube.
    extend_factor : int
        The factor to extend the kernel by.

    Returns
    -------
    convolved : ndarray
        The convolved datacube, in the shape of the input.
    """
    original_shape = datacube.shape
    n_wave = original_shape[-1]

    # Collapse every leading axis so each row is a single spectrum.
    flat_spectra = datacube.reshape(-1, n_wave)

    smoothed = apply_lsf_spectra(
        flat_spectra, lsf_sigma, wave_resolution, extend_factor
    )

    # Restoring the shape relies on the convolution preserving the number of
    # wavelength bins.
    return smoothed.reshape(original_shape)