from typing import Callable, Union
import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from jax import lax
from jaxtyping import Array, Float, jaxtyped
from rubix import config as rubix_config
from rubix.core.data import GasData, StarsData
from rubix.logger import get_logger
from rubix.spectra.dust.extinction_models import *
from rubix.spectra.ifu import (
_velocity_doppler_shift_single,
cosmological_doppler_shift,
resample_spectrum,
)
from .data import RubixData
from .ssp import get_lookup_interpolation, get_ssp
from .telescope import get_telescope
[docs]
@jaxtyped(typechecker=typechecker)
def get_calculate_datacube_particlewise(config: dict) -> Callable:
    """
    Build a per-particle datacube calculator for the stellar component.

    The returned function walks over every star with ``lax.scan``. For each
    star it:

    1. interpolates the SSP spectrum at the star's metallicity and age;
    2. multiplies the spectrum by the star's mass;
    3. Doppler-shifts the (cosmologically redshifted) SSP wavelengths by the
       star's velocity;
    4. resamples the result onto the telescope wavelength grid;
    5. adds it to the star's pixel.

    Args:
        config (dict): Configuration dictionary containing telescope and galaxy
            parameters. ``config["galaxy"]["dist_z"]`` is used as the redshift
            for the cosmological shift of the SSP wavelengths.

    Returns:
        Callable: A function that takes a RubixData object, stores a cube of
        shape ``(sbin, sbin, len(wave_seq))`` on ``rubixdata.stars.datacube``
        and returns the same object.
    """
    logger = get_logger(config.get("logger", None))
    telescope = get_telescope(config)
    n_side = int(telescope.sbin)
    n_spaxels = n_side * n_side
    wave_grid = telescope.wave_seq

    # Interpolator: (metallicity, age) -> SSP spectrum.
    interpolate_ssp = get_lookup_interpolation(config)

    doppler_direction = rubix_config["ifu"]["doppler"]["velocity_direction"]
    redshift = config["galaxy"]["dist_z"]
    ssp = get_ssp(config)
    # The cosmological redshift is the same for every star, so apply it once.
    observed_ssp_wave = cosmological_doppler_shift(
        z=redshift, wavelength=ssp.wavelength
    )

    @jaxtyped(typechecker=typechecker)
    def calculate_datacube_particlewise(rubixdata: RubixData) -> RubixData:
        logger.info("Calculating Data Cube (combined per‐particle)…")
        star_data = rubixdata.stars
        age = star_data.age
        metal = star_data.metallicity
        mass = star_data.mass
        vel = star_data.velocity
        pixel = star_data.pixel_assignment
        n_particles = age.shape[0]

        def star_spectrum(k):
            """Mass-scaled spectrum of star ``k`` resampled onto the telescope grid."""
            scaled_spectrum = interpolate_ssp(metal[k], age[k]) * mass[k]
            star_wave = _velocity_doppler_shift_single(
                wavelength=observed_ssp_wave,
                velocity=vel[k],
                direction=doppler_direction,
            )
            return resample_spectrum(
                initial_spectrum=scaled_spectrum,
                initial_wavelength=star_wave,
                target_wavelength=wave_grid,
            )

        def accumulate(flat_cube, k):
            """Scan step: add star ``k``'s spectrum into its (flattened) pixel row."""
            target_pixel = pixel[k].astype(jnp.int32)
            return flat_cube.at[target_pixel].add(star_spectrum(k)), None

        # One row per spaxel, one column per telescope wavelength bin.
        empty_cube = jnp.zeros((n_spaxels, wave_grid.shape[-1]))
        summed, _ = lax.scan(
            accumulate, empty_cube, jnp.arange(n_particles, dtype=jnp.int32)
        )
        datacube = summed.reshape(n_side, n_side, -1)
        rubixdata.stars.datacube = datacube
        logger.debug(f"Datacube shape: {datacube.shape}")
        return rubixdata

    return calculate_datacube_particlewise
[docs]
@jaxtyped(typechecker=typechecker)
def get_calculate_dusty_datacube_particlewise(config: dict) -> Callable:
    """
    Create a function that calculates a dust-attenuated datacube for the stars
    component of a RubixData object on a per-particle basis.

    For each star the returned function:

    1. looks up the SSP spectrum from the star's age and metallicity;
    2. scales it by the star's mass;
    3. applies a Doppler shift based on the star's velocity;
    4. resamples the spectrum onto the telescope's wavelength grid;
    5. attenuates it with the configured dust extinction curve, using the
       star's ``extinction`` value;
    6. accumulates the result into the star's pixel of the datacube.

    The extinction curve is evaluated at the rest-frame wavelengths of the
    telescope grid, ``wave_seq / (1 + z)`` with ``z = config["galaxy"]["dist_z"]``.
    The attenuating dust is part of the observed galaxy, so it acts on light
    at the galaxy's rest-frame wavelengths. Per-star peculiar velocities are
    neglected for the extinction term.

    Args:
        config (dict): Configuration dictionary containing telescope, galaxy
            and dust parameters. Reads ``config["galaxy"]["dist_z"]``,
            ``config["ssp"]["dust"]["extinction_model"]`` and
            ``config["ssp"]["dust"]["Rv"]``. The dust keys are read when the
            returned function is called, not when it is created.

    Returns:
        Callable: A function that takes a RubixData object, stores a cube of
        shape ``(sbin, sbin, len(wave_seq))`` on ``rubixdata.stars.datacube``
        and returns the same object.

    Raises:
        ValueError: Raised by the returned function if the configured
            extinction model is not listed in ``RV_MODELS``.
    """
    logger = get_logger(config.get("logger", None))
    telescope = get_telescope(config)
    ns = int(telescope.sbin)
    nseg = ns * ns
    target_wave = telescope.wave_seq  # (n_wave_tel,)

    # prepare SSP lookup
    lookup_ssp = get_lookup_interpolation(config)

    # prepare Doppler machinery
    velocity_direction = rubix_config["ifu"]["doppler"]["velocity_direction"]
    z_obs = config["galaxy"]["dist_z"]
    ssp_model = get_ssp(config)
    ssp_wave0 = cosmological_doppler_shift(
        z=z_obs, wavelength=ssp_model.wavelength
    )  # (n_wave_ssp,)

    # Evaluate the dust curve at rest-frame wavelengths: the dust sits in the
    # observed galaxy, so using the redshifted telescope grid would
    # under-attenuate for z > 0. Dividing by 1e4 converts to microns; this
    # assumes wave_seq is in Angstrom (TODO confirm against the extinction
    # model's expected input). Computed once, shared by every star.
    rest_wave_micron = target_wave / (1.0 + z_obs) / 1e4  # (n_wave_tel,)

    @jaxtyped(typechecker=typechecker)
    def calculate_dusty_datacube_particlewise(rubixdata: RubixData) -> RubixData:
        logger.info("Calculating Data Cube (combined per‐particle)…")
        stars = rubixdata.stars
        ages = stars.age  # (n_stars,)
        metallicity = stars.metallicity  # (n_stars,)
        masses = stars.mass  # (n_stars,)
        velocities = stars.velocity
        pix_idx = stars.pixel_assignment  # (n_stars,)
        # NOTE(review): presumably one Av value per star. It is combined with
        # the (n_wave_tel,) grid below, so a per-SSP-wavelength array would not
        # broadcast. Confirm against the code that fills `stars.extinction`.
        Av_array = stars.extinction
        nstar = ages.shape[0]

        # Dust model, resolved on every call from the config.
        ext_model = config["ssp"]["dust"]["extinction_model"]
        Rv = config["ssp"]["dust"]["Rv"]
        # Dynamically choose the extinction model based on the string name
        if ext_model not in RV_MODELS:
            raise ValueError(
                f"Extinction model '{ext_model}' is not available. Choose from {RV_MODELS}."
            )
        ext_model_class = Rv_model_dict[ext_model]
        ext = ext_model_class(Rv=Rv)

        # init flat cube: (nseg, n_wave_tel)
        init_cube = jnp.zeros((nseg, target_wave.shape[-1]))

        def body(cube, i):
            """Scan step: add star ``i``'s attenuated spectrum to its pixel row."""
            age_i = ages[i]  # scalar
            Z_i = metallicity[i]  # scalar
            m_i = masses[i]  # scalar
            v_i = velocities[i]  # scalar or vector
            pix_i = pix_idx[i].astype(jnp.int32)
            av_i = Av_array[i]

            # 1) SSP lookup
            spec_ssp = lookup_ssp(Z_i, age_i)  # (n_wave_ssp,)
            # 2) scale by mass
            spec_mass = spec_ssp * m_i  # (n_wave_ssp,)
            # 3) Doppler-shift wavelengths
            shifted_wave = _velocity_doppler_shift_single(
                wavelength=ssp_wave0,
                velocity=v_i,
                direction=velocity_direction,
            )  # (n_wave_ssp,)
            # 4) resample onto telescope grid
            spec_tel = resample_spectrum(
                initial_spectrum=spec_mass,
                initial_wavelength=shifted_wave,
                target_wavelength=target_wave,
            )  # (n_wave_tel,)
            # 5) apply dust extinction, evaluated at rest-frame wavelengths
            extinction = ext.extinguish(rest_wave_micron, av_i)
            spec_extincted = spec_tel * extinction  # (n_wave_tel,)
            # 6) accumulate
            cube = cube.at[pix_i].add(spec_extincted)
            return cube, None

        cube_flat, _ = lax.scan(body, init_cube, jnp.arange(nstar, dtype=jnp.int32))
        cube_3d = cube_flat.reshape(ns, ns, -1)
        setattr(rubixdata.stars, "datacube", cube_3d)
        logger.debug(f"Datacube shape: {cube_3d.shape}")
        return rubixdata

    return calculate_dusty_datacube_particlewise