from typing import Union
import jax
import jax.numpy as jnp
import numpy as np
from beartype import beartype as typechecker
from jaxtyping import Array, Float, Int, jaxtyped
from rubix import config
# Axis-name strings used inside the jaxtyping shape annotations of this module
# (e.g. ``Float[Array, N_BINS_AXIS]``).

# Single bin axis; used for the wavelength grids of the Doppler-shift helpers.
N_BINS_AXIS = "n_bins"
# Bin axis of the input spectrum / wavelength grid in ``resample_spectrum``.
N_BINS_INITIAL_AXIS = "n_bins_initial"
# Bin axis of the target wavelength grid in ``resample_spectrum``.
N_BINS_TARGET_AXIS = "n_bins_target"
# Per-particle axis returned by ``_get_velocity_component_multiple``.
N_PARTICLE_AXIS = "n_particles"
# Axis spec for a single velocity vector with three components.
VELOCITY_AXIS = "3"
# Axis spec for a stack of velocity vectors, one row per particle.
PARTICLE_MATRIX_AXES = "n_particles 3"
# Any number of leading axes followed by a trailing axis of length 3.
ELLIPSIS_THREE_AXES = "... 3"
# Axis spec of the per-star spectra passed to ``calculate_cube``.
STAR_WAVE_AXES = "n_stars n_wave_bins"
# Axis of the per-star spaxel indices; shares the "n_stars" name with
# STAR_WAVE_AXES.
SPAXEL_INDEX_AXIS = "n_stars"
# Axis spec of the data cube returned by ``calculate_cube``.
SPAXEL_CUBE_AXES = "num_spaxels_x num_spaxels_y n_wave_bins"
[docs]
@jaxtyped(typechecker=typechecker)
def convert_luminoisty_to_flux(
    luminosity: Float[Array, "..."],
    observation_lum_dist: Union[Float[Array, "..."], float],
    observation_z: float,
    pixel_size: float,
    CONSTANTS: dict = config["constants"],
) -> Float[Array, "..."]:
    """Scale an intrinsic luminosity to the flux seen by the telescope.

    The luminosity is multiplied by a single factor built from the unit
    conversion constants, the luminosity distance, the redshift and the
    pixel size::

        LSOL_TO_ERG / MPC_TO_CM**2
            / (4 * pi * observation_lum_dist**2)
            / (1 + observation_z)
            / pixel_size

    Args:
        luminosity (Float[Array, "..."]): Intrinsic luminosity per bin.
        observation_lum_dist (Union[Float[Array, "..."], float]): Luminosity
            distance (documented as Mpc).
        observation_z (float): Object redshift.
        pixel_size (float): Telescope pixel size (documented as cm).
        CONSTANTS (dict, optional): Mapping providing ``"LSOL_TO_ERG"`` and
            ``"MPC_TO_CM"``. Defaults to ``config["constants"]``.

    Returns:
        Float[Array, "..."]: Flux, documented as erg/s/cm^2/Å.
    """
    lsol_to_erg = float(CONSTANTS.get("LSOL_TO_ERG"))
    mpc_to_cm = float(CONSTANTS.get("MPC_TO_CM"))
    unit_conversion = lsol_to_erg / (mpc_to_cm**2)

    # Surface of the sphere over which the emitted light is spread.
    sphere_area = 4 * jnp.pi * observation_lum_dist**2

    flux_per_luminosity = (
        unit_conversion / sphere_area / (1 + observation_z) / pixel_size
    )
    return luminosity * flux_per_luminosity
[docs]
@jaxtyped(typechecker=typechecker)
def convert_luminoisty_to_flux_factor(
    observation_lum_dist,
    observation_z,
    pixel_size,
    CONSTANTS=config["constants"],
):
    """Compute the multiplicative luminosity-to-flux factor.

    The factor is evaluated with NumPy ``float64`` arithmetic and only
    converted with ``jnp.float64`` at the very end. It has the same form as
    the factor applied in ``convert_luminoisty_to_flux``.

    Args:
        observation_lum_dist: Luminosity distance.
        observation_z: Object redshift.
        pixel_size: Telescope pixel size.
        CONSTANTS: Mapping providing ``"LSOL_TO_ERG"`` and ``"MPC_TO_CM"``.
            Defaults to ``config["constants"]``.

    Returns:
        The conversion factor after passing through ``jnp.float64``.
    """
    lsol_to_erg = float(CONSTANTS.get("LSOL_TO_ERG"))
    mpc_to_cm = float(CONSTANTS.get("MPC_TO_CM"))
    unit_conversion = np.float64(lsol_to_erg / (mpc_to_cm**2))

    # Keep every intermediate in NumPy float64 before handing over to JAX.
    sphere_area = 4 * np.pi * np.float64(observation_lum_dist) ** 2
    redshift_term = 1 + np.float64(observation_z)
    factor = (
        unit_conversion / sphere_area / redshift_term / np.float64(pixel_size)
    )
    return jnp.float64(factor)
[docs]
def cosmological_doppler_shift(
    z: float,
    wavelength: Float[Array, N_BINS_AXIS],
) -> Float[Array, N_BINS_AXIS]:
    """Stretch a wavelength grid by the cosmological redshift factor.

    Args:
        z (float): Object redshift.
        wavelength (Float[Array, N_BINS_AXIS]): Wavelengths in Å.

    Returns:
        Float[Array, N_BINS_AXIS]: The input wavelengths multiplied by
        ``1 + z``.
    """
    stretch = 1 + z
    return stretch * wavelength
[docs]
@jaxtyped(typechecker=typechecker)
def calculate_diff(
    vec: Float[Array, "..."], pad_with_zero: bool = True
) -> Float[Array, "..."]:
    """Calculate consecutive differences along the last axis.

    Args:
        vec (Float[Array, "..."]): Input grid. Differences are taken along
            the last axis, so a 1-D vector or a batch of vectors both work.
        pad_with_zero (bool, optional): If ``True`` the first element along
            the last axis is prepended before differencing. The output then
            has the same shape as ``vec`` and its first entry is ``0``
            (hence the parameter name). If ``False`` the output is one
            element shorter along the last axis. Defaults to ``True``.

    Returns:
        Float[Array, "..."]: Finite differences of ``vec``.
    """
    if not pad_with_zero:
        return jnp.diff(vec)

    # Slice (rather than index) the leading element so that the prepended
    # block keeps the rank of ``vec``. For a 1-D vector this is equivalent to
    # ``prepend=vec[0]``; it additionally supports batched (N-D) input and an
    # empty last axis, where ``vec[0]`` has the wrong rank or is out of
    # bounds.
    return jnp.diff(vec, prepend=vec[..., :1])
def _get_velocity_component_single(
    vec: Float[Array, "..."],
    direction: str,
) -> Float:
    """Pick the ``x``, ``y`` or ``z`` entry of a single three-element vector.

    Args:
        vec (Float[Array, "..."]): Vector holding exactly three elements.
        direction (str): One of ``'x'``, ``'y'`` or ``'z'``.

    Returns:
        The entry at index 0, 1 or 2 for ``'x'``, ``'y'`` or ``'z'``.

    Raises:
        ValueError: If ``vec`` does not hold three elements or ``direction``
            is not supported.
    """
    if vec.size != 3:
        raise ValueError(f"Expected vector of size 3, but got {vec.size}.")

    # Position in this tuple doubles as the index into ``vec``.
    axis_names = ("x", "y", "z")
    if direction not in axis_names:
        raise ValueError(
            f"{direction} is not a valid direction. Supported directions are "
            f"'x', 'y', or 'z'."
        )
    return vec[axis_names.index(direction)]
def _get_velocity_component_multiple(
    vecs: Float[Array, PARTICLE_MATRIX_AXES],
    direction: str,
) -> Float[Array, N_PARTICLE_AXIS]:
    """Pick the ``x``, ``y`` or ``z`` column of a stack of vectors.

    Args:
        vecs (Float[Array, PARTICLE_MATRIX_AXES]): Array whose second axis
            has length 3.
        direction (str): One of ``'x'``, ``'y'`` or ``'z'``.

    Returns:
        Float[Array, N_PARTICLE_AXIS]: Column 0, 1 or 2 of ``vecs`` for
        ``'x'``, ``'y'`` or ``'z'``.

    Raises:
        ValueError: If the second axis of ``vecs`` is not of length 3 or
            ``direction`` is not supported.
    """
    if vecs.shape[1] != 3:
        raise ValueError(
            f"Expected vectors of shape (n_particles, 3), but got {vecs.shape}."
        )

    # Position in this tuple doubles as the column index into ``vecs``.
    axis_names = ("x", "y", "z")
    if direction not in axis_names:
        raise ValueError(
            f"{direction} is not a valid direction. Supported directions are "
            f"'x', 'y', or 'z'."
        )
    column = axis_names.index(direction)
    return vecs[:, column]
[docs]
@jaxtyped(typechecker=typechecker)
def get_velocity_component(
    vec: Float[Array, "..."], direction: str
) -> Float[Array, "..."]:
    """Return the velocity component along a given direction.

    One-dimensional input is treated as a single velocity vector,
    two-dimensional input as one velocity vector per row.

    Args:
        vec (Float[Array, "..."]): The velocity vector(s).
        direction (str): The direction in which to get the velocity
            component. Supported directions are 'x', 'y', or 'z'.

    Returns:
        Float[Array, "..."]: Component extracted from ``vec``.

    Raises:
        ValueError: If ``vec`` does not have 1 or 2 dimensions or the
            direction is invalid.
    """
    is_jax_array = isinstance(vec, jax.Array)

    if is_jax_array and vec.ndim == 1:
        return _get_velocity_component_single(vec, direction)
    if is_jax_array and vec.ndim == 2:
        return _get_velocity_component_multiple(vec, direction)

    raise ValueError(
        f"Got wrong shapes. Expected vec.ndim =2 or vec.ndim=1, but got "
        f"vec.ndim = {vec.ndim}"
    )
def _velocity_doppler_shift_single(
    wavelength: Float[Array, N_BINS_AXIS],
    velocity: Float[Array, VELOCITY_AXIS],
    direction: str = "y",
    SPEED_OF_LIGHT: float = config["constants"]["SPEED_OF_LIGHT"],
) -> Float[Array, N_BINS_AXIS]:
    """Doppler-shift a wavelength grid for one velocity vector.

    Args:
        wavelength (Float[Array, N_BINS_AXIS]): Rest wavelengths in Å.
        velocity (Float[Array, VELOCITY_AXIS]): Velocity components
            (documented as km/s).
        direction (str, optional): Component axis. Defaults to ``"y"``.
        SPEED_OF_LIGHT (float, optional): Speed of light in the same unit as
            ``velocity``. Defaults to
            ``config["constants"]["SPEED_OF_LIGHT"]``.

    Returns:
        Float[Array, N_BINS_AXIS]: ``wavelength * exp(v / c)`` where ``v`` is
        the selected velocity component.
    """
    component = get_velocity_component(velocity, direction)

    # exp(v / c) reduces to the classic 1 + v / c shift for small v / c.
    # NOTE: the relativistic form sqrt((1 + v/c) / (1 - v/c)) is
    # intentionally not used here.
    shift_factor = jnp.exp(component / SPEED_OF_LIGHT)
    return wavelength * shift_factor
[docs]
@jaxtyped(typechecker=typechecker)
def velocity_doppler_shift(
    wavelength: Float[Array, "..."],
    velocity: Float[Array, ELLIPSIS_THREE_AXES],
    direction: str = config["ifu"]["doppler"]["velocity_direction"],
    SPEED_OF_LIGHT: float = config["constants"]["SPEED_OF_LIGHT"],
) -> Float[Array, "..."]:
    """Vectorized Doppler shift over multiple velocity vectors.

    The same ``wavelength`` grid is shifted once per entry along the leading
    axis of ``velocity`` (after leading singleton axes have been removed).

    Args:
        wavelength (Float[Array, "..."]): Rest wavelengths in Å.
        velocity (Float[Array, ELLIPSIS_THREE_AXES]): Velocity components per
            sample, with a trailing axis of length 3. Leading axes of length
            1 are squeezed away as long as more than two axes remain, so
            ``(1, n, 3)`` is treated like ``(n, 3)`` while a single particle
            given as ``(1, 3)`` is kept as is.
        direction (str, optional): Axis to project onto. Defaults to
            ``config["ifu"]["doppler"]["velocity_direction"]``.
        SPEED_OF_LIGHT (float, optional): Speed of light in km/s. Defaults to
            ``config["constants"]["SPEED_OF_LIGHT"]``.

    Returns:
        Float[Array, "..."]: Doppler shifted wavelengths per velocity entry.
    """
    # Only strip leading singleton axes while the array has more than two
    # axes. Squeezing a (1, 3) input down to (3,) would make the vmap below
    # iterate over the three scalar components instead of over particles,
    # which get_velocity_component rejects.
    while velocity.ndim > 2 and velocity.shape[0] == 1:
        velocity = jnp.squeeze(velocity, axis=0)

    def _shift_for(single_velocity):
        # Shift the shared wavelength grid for one velocity entry.
        return _velocity_doppler_shift_single(
            wavelength, single_velocity, direction, SPEED_OF_LIGHT
        )

    # Map over the leading axis: many velocities, one wavelength grid.
    return jax.vmap(_shift_for)(velocity)
[docs]
@jaxtyped(typechecker=typechecker)
def resample_spectrum(
    initial_spectrum: Float[Array, N_BINS_INITIAL_AXIS],
    initial_wavelength: Float[Array, N_BINS_INITIAL_AXIS],
    target_wavelength: Float[Array, N_BINS_TARGET_AXIS],
) -> Float[Array, N_BINS_TARGET_AXIS]:
    """Interpolate a spectrum onto a new wavelength grid and rescale it.

    The spectrum is linearly interpolated onto ``target_wavelength`` and then
    multiplied by a single factor so that ``sum(spectrum * bin_width)`` on
    the target grid equals the same sum over the part of the initial grid
    that lies inside the target wavelength range.

    Args:
        initial_spectrum (Float[Array, N_BINS_INITIAL_AXIS]): Input spectrum.
        initial_wavelength (Float[Array, N_BINS_INITIAL_AXIS]): Input grid in
            Å.
        target_wavelength (Float[Array, N_BINS_TARGET_AXIS]): Target grid in
            Å.

    Returns:
        Float[Array, N_BINS_TARGET_AXIS]: Rescaled spectrum on the target
        grid.
    """
    # Restrict the reference integral to bins the target grid can see.
    lower_edge = jnp.min(target_wavelength)
    upper_edge = jnp.max(target_wavelength)
    inside_target = (initial_wavelength >= lower_edge) & (
        initial_wavelength <= upper_edge
    )
    source_bin_widths = calculate_diff(initial_wavelength) * inside_target

    # Total luminosity of the input spectrum within the target range.
    source_total = jnp.sum(initial_spectrum * source_bin_widths)

    # Spectrum sampled at the target wavelengths.
    interpolated = jnp.interp(
        target_wavelength,
        initial_wavelength,
        initial_spectrum,
    )

    # Total luminosity of the interpolated spectrum on the target grid.
    target_bin_widths = calculate_diff(target_wavelength)
    target_total = jnp.sum(interpolated * target_bin_widths)

    # A 0 / 0 ratio yields NaN; map it to 0 so the output stays finite.
    rescale = jnp.nan_to_num(source_total / target_total, nan=0.0)
    return interpolated * rescale
[docs]
@jaxtyped(typechecker=typechecker)
def calculate_cube(
    spectra: Float[Array, STAR_WAVE_AXES],
    spaxel_index: Int[Array, SPAXEL_INDEX_AXIS],
    num_spaxels: int,
) -> Float[Array, SPAXEL_CUBE_AXES]:
    """Sum per-star spectra into a square spaxel data cube.

    Args:
        spectra (Float[Array, STAR_WAVE_AXES]): One spectrum per star.
        spaxel_index (Int[Array, SPAXEL_INDEX_AXIS]): Flat spaxel index of
            each star, used as the segment id for the summation.
        num_spaxels (int): Number of spaxels per axis; the cube has
            ``num_spaxels**2`` spaxels in total.

    Returns:
        Float[Array, SPAXEL_CUBE_AXES]: Cube of shape
        ``(num_spaxels, num_spaxels, n_wave_bins)``.
    """
    n_wave_bins = spectra.shape[-1]
    total_spaxels = num_spaxels**2

    # Accumulate all spectra that fall into the same flat spaxel.
    flat_cube = jax.ops.segment_sum(
        spectra,
        spaxel_index,
        num_segments=total_spaxels,
    )

    # Unfold the flat spaxel axis into the two spatial axes.
    return flat_cube.reshape(num_spaxels, num_spaxels, n_wave_bins)