Source code for rubix.telescope.noise.noise
import jax.numpy as jnp
from jax import random as jrandom
from jaxtyping import Array, Float
# Distribution names accepted by ``sample_noise`` via its ``type`` argument;
# this list is echoed in the ValueError raised for any other value.
SUPPORTED_NOISE_DISTRIBUTIONS = ["normal", "uniform"]
[docs]
def sample_noise(shape, type="normal", key=None):
    """Draw an array of random noise from a named distribution.

    Parameters
    ----------
    shape : tuple
        Shape of the array to generate.
    type : str, optional
        Name of the distribution to draw from: "normal" uses
        ``jax.random.normal`` and "uniform" uses ``jax.random.uniform``.
        By default "normal".
    key : jnp.array, optional
        JAX PRNG key used for sampling. When None, ``PRNGKey(0)`` is used,
        so repeated calls without a key yield identical values.

    Returns
    -------
    jnp.array
        The sampled noise, with the requested shape.

    Raises
    ------
    ValueError
        If ``type`` is neither "normal" nor "uniform".
    """
    # Fall back to a fixed seed when the caller does not supply a key.
    rng_key = jrandom.PRNGKey(0) if key is None else key

    if type == "normal":
        return jrandom.normal(rng_key, shape)
    if type == "uniform":
        return jrandom.uniform(rng_key, shape)

    message = (
        f"Invalid noise type: {type}. Supported types: {SUPPORTED_NOISE_DISTRIBUTIONS}"
    )
    raise ValueError(message)
[docs]
def calculate_S2N(
    datacube: Float[Array, "n_x n_y n_wave_bins"], observation_signal_to_noise: float
) -> Float[Array, "n_x n_y"]:
    """
    Calculate the per-pixel noise scaling map from a data cube.

    Adapted from: https://github.com/kateharborne/SimSpin/blob/4e8f0af0ebc0e43cc31729978deb3a554e039f6b/R/build_datacube.R#L386
    which implements equation 4 from Nanni et al. 2022

    For every spatial pixel with positive summed flux ``F`` the result is
    ``sqrt(median(F > 0)) / observation_signal_to_noise / sqrt(F)``, where the
    median is taken over the positive-flux pixels only. Pixels whose summed
    flux is not positive get 0.

    Parameters
    ----------
    datacube : jnp.array (n_x, n_y, n_wave_bins)
        The data cube with dimensions (n_x, n_y, n_wave_bins).
    observation_signal_to_noise : float
        The signal-to-noise ratio of the observation.

    Returns
    -------
    jnp.array (n_x, n_y)
        The scaling map described above (named S2N after the reference
        implementation). It is all zeros when no pixel has positive flux.
    """
    # Sum up the spectra along the wavelength bins to get the flux image
    flux_image = jnp.sum(datacube, axis=-1)

    # Only pixels with positive flux take part in the statistics
    nonzero_mask = flux_image > 0

    # Median over the positive-flux pixels only. The masked-out pixels are
    # marked NaN, so the NaN-ignoring median is required: a plain median
    # propagates NaN and would collapse to 0 as soon as a single pixel is empty.
    median_flux = jnp.nanmedian(jnp.where(nonzero_mask, flux_image, jnp.nan))
    # An all-empty image yields NaN here; treat that as zero flux.
    median_flux = jnp.nan_to_num(median_flux, nan=0.0)

    # Calculate the noise factor
    noise_factor = jnp.sqrt(median_flux) / observation_signal_to_noise

    # Substitute a harmless denominator for masked pixels so that no inf/NaN
    # intermediate is produced (these values are discarded by the mask below).
    safe_flux = jnp.where(nonzero_mask, flux_image, 1.0)
    S2N = noise_factor / jnp.sqrt(safe_flux)

    # Apply the mask to set S2N to zero where the flux is not positive
    return jnp.where(nonzero_mask, S2N, 0)
[docs]
def calculate_noise_cube(
    cube: Float[Array, "n_x n_y n_wave_bins"],
    signal_to_noise: float,
    noise_distribution="normal",
    key=None,
) -> Float[Array, "n_x n_y n_wave_bins"]:
    """Calculate the noise cube given the cube and the signal-to-noise ratio.

    Adapted from: https://github.com/kateharborne/SimSpin/blob/4e8f0af0ebc0e43cc31729978deb3a554e039f6b/R/utilities.R#L587

    The result is ``cube * r * s[:, :, None]``, where ``r`` is random noise of
    the same shape as ``cube`` drawn by ``sample_noise`` and ``s`` is the
    per-pixel map returned by ``calculate_S2N``.

    Parameters
    ----------
    cube : jnp.array (n_x, n_y, n_wave_bins)
        The data cube.
    signal_to_noise : float
        The signal-to-noise ratio of the observation.
    noise_distribution: str, optional
        The type of distribution to sample from. Can be either "normal" or "uniform", by default "normal".
    key : jnp.array, optional
        JAX PRNG key used to draw the noise. By default None, in which case
        ``PRNGKey(0)`` is used and every call returns the same realisation;
        pass distinct keys to obtain independent noise cubes.

    Returns
    -------
    jnp.array (n_x, n_y, n_wave_bins)
        The noise cube.

    Raises
    ------
    ValueError
        Propagated from ``sample_noise`` when ``noise_distribution`` is not
        supported.
    """
    # Keep the historical fixed seed as the default for reproducibility.
    if key is None:
        key = jrandom.PRNGKey(0)

    S2N = calculate_S2N(cube, signal_to_noise)

    # Generate noise for each element in the cube, scaled per spatial pixel;
    # the (n_x, n_y) map is broadcast along the wavelength axis.
    noise = sample_noise(cube.shape, type=noise_distribution, key=key) * S2N[:, :, None]

    # Scale the noise by the cube so it is proportional to the local flux
    return cube * noise