Source code for rubix.core.ifu

import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from beartype.typing import Callable
from jax import lax
from jaxtyping import Array, Float, jaxtyped

from rubix import config as rubix_config
from rubix.core.data import GasData, StarsData
from rubix.logger import get_logger
from rubix.spectra.dust.extinction_models import RV_MODELS, Rv_model_dict
from rubix.spectra.ifu import (
    _velocity_doppler_shift_single,
    cosmological_doppler_shift,
    resample_spectrum,
)

from .data import RubixData
from .ssp import get_lookup_interpolation, get_ssp
from .telescope import get_telescope


@jaxtyped(typechecker=typechecker)
def get_calculate_datacube_particlewise(config: dict) -> Callable:
    """Build a function that assembles the stellar IFU datacube one star at a time.

    Work that depends only on ``config`` is done once, here. That work is the
    telescope geometry, the SSP interpolator, and the SSP wavelength grid
    shifted to the galaxy redshift. The returned callable then walks over the
    stars with ``jax.lax.scan``. For each star it:

    1. interpolates an SSP spectrum from the star's metallicity and age,
    2. multiplies that spectrum by the star's mass,
    3. Doppler-shifts the SSP wavelength grid by the star's velocity,
    4. resamples the spectrum onto the telescope wavelength grid, and
    5. adds the result to the spatial pixel named by ``pixel_assignment``.

    Args:
        config (dict): Configuration dictionary containing telescope and
            galaxy parameters.

    Returns:
        Callable[[RubixData], RubixData]: Function that sets
        ``stars.datacube`` to an array of shape ``(sbin, sbin, n_wave_tel)``.
    """
    logger = get_logger(config.get("logger", None))

    # Spatial layout: an n_side x n_side grid of spaxels. It is kept flat
    # (one row per spaxel) while accumulating and reshaped at the end.
    telescope = get_telescope(config)
    n_side = int(telescope.sbin)
    n_spaxels = n_side * n_side
    tel_wave = telescope.wave_seq  # telescope wavelength grid

    # SSP interpolator, called below as ssp_lookup(metallicity, age).
    ssp_lookup = get_lookup_interpolation(config)

    # Doppler settings, plus the SSP wavelengths shifted to the galaxy redshift.
    doppler_direction = rubix_config["ifu"]["doppler"]["velocity_direction"]
    redshift = config["galaxy"]["dist_z"]
    ssp = get_ssp(config)
    ssp_wave_obs = cosmological_doppler_shift(
        z=redshift, wavelength=ssp.wavelength
    )

    @jaxtyped(typechecker=typechecker)
    def calculate_datacube_particlewise(rubixdata: RubixData) -> RubixData:
        """Accumulate every star's shifted, resampled spectrum into a cube.

        Args:
            rubixdata (RubixData): Particle data whose ``stars`` carry
                ``age``, ``metallicity``, ``mass``, ``velocity`` and
                ``pixel_assignment``.

        Returns:
            RubixData: The same object, with ``stars.datacube`` set.
        """
        logger.info("Calculating Data Cube (combined per‐particle)…")
        stars = rubixdata.stars
        age = stars.age
        metal = stars.metallicity
        mass = stars.mass
        vel = stars.velocity  # per-star entry may be a scalar or a vector
        pixel = stars.pixel_assignment

        def add_star(cube, k):
            # Steps 1-2: mass-weighted SSP spectrum on the SSP grid.
            weighted = ssp_lookup(metal[k], age[k]) * mass[k]
            # Step 3: per-star velocity shift of the SSP wavelengths.
            star_wave = _velocity_doppler_shift_single(
                wavelength=ssp_wave_obs,
                velocity=vel[k],
                direction=doppler_direction,
            )
            # Step 4: move the spectrum onto the telescope grid.
            on_tel_grid = resample_spectrum(
                initial_spectrum=weighted,
                initial_wavelength=star_wave,
                target_wavelength=tel_wave,
            )
            # Step 5: add into the row of the star's spaxel.
            target_row = pixel[k].astype(jnp.int32)
            return cube.at[target_row].add(on_tel_grid), None

        empty_cube = jnp.zeros((n_spaxels, tel_wave.shape[-1]))
        star_indices = jnp.arange(age.shape[0], dtype=jnp.int32)
        flat_cube, _ = lax.scan(add_star, empty_cube, star_indices)

        cube = flat_cube.reshape(n_side, n_side, -1)
        rubixdata.stars.datacube = cube
        logger.debug(f"Datacube shape: {cube.shape}")
        return rubixdata

    return calculate_datacube_particlewise


@jaxtyped(typechecker=typechecker)
def get_calculate_dusty_datacube_particlewise(config: dict) -> Callable:
    """Prepare a dusty per-particle datacube builder for the star component.

    This works like :func:`get_calculate_datacube_particlewise`, but also
    attenuates each star's resampled spectrum with the extinction model named
    in ``config["ssp"]["dust"]``. For every star, the returned callable:

    1. looks up the SSP spectrum from the star's age and metallicity,
    2. scales it by the star's mass,
    3. Doppler-shifts it by the star's velocity,
    4. resamples it onto the telescope wavelength grid,
    5. multiplies it by the extinction curve for the star's ``extinction``
       value, and
    6. adds it to the star's spatial pixel.

    Args:
        config (dict): Configuration dictionary containing telescope and
            galaxy parameters as well as ``ssp.dust`` settings
            (``extinction_model`` and ``Rv``).

    Returns:
        Callable[[RubixData], RubixData]: Function that computes
        ``stars.datacube`` with extinction applied.
    """
    logger = get_logger(config.get("logger", None))
    telescope = get_telescope(config)
    ns = int(telescope.sbin)
    nseg = ns * ns
    target_wave = telescope.wave_seq  # (n_wave_tel,)

    # prepare SSP lookup, called as lookup_ssp(metallicity, age)
    lookup_ssp = get_lookup_interpolation(config)

    # prepare Doppler machinery: SSP wavelengths shifted to the galaxy redshift
    velocity_direction = rubix_config["ifu"]["doppler"]["velocity_direction"]
    z_obs = config["galaxy"]["dist_z"]
    ssp_model = get_ssp(config)
    ssp_wave0 = cosmological_doppler_shift(
        z=z_obs, wavelength=ssp_model.wavelength
    )  # (n_wave_ssp,)

    @jaxtyped(typechecker=typechecker)
    def calculate_dusty_datacube_particlewise(
        rubixdata: RubixData,
    ) -> RubixData:
        """Apply SSP spectra, Doppler shifts, and extinction per particle.

        Args:
            rubixdata (RubixData): Particle data whose ``stars`` carry
                ``age``, ``metallicity``, ``mass``, ``velocity``,
                ``pixel_assignment`` and ``extinction``.

        Returns:
            RubixData: Input data updated with ``stars.datacube`` of shape
            ``(sbin, sbin, n_wave_tel)``.

        Raises:
            ValueError: If the configured extinction model is not one of
                ``RV_MODELS``.
        """
        logger.info("Calculating Data Cube (combined per‐particle)…")
        stars = rubixdata.stars
        ages = stars.age  # (n_stars,)
        metallicity = stars.metallicity  # (n_stars,)
        masses = stars.mass  # (n_stars,)
        velocities = stars.velocity  # per-star entry: scalar or vector
        pix_idx = stars.pixel_assignment  # (n_stars,)
        # NOTE(review): the per-star shape of these values is not visible
        # here. Each entry is passed straight to ext.extinguish as its second
        # argument, so confirm it matches what the extinction model expects.
        Av_array = stars.extinction
        nstar = ages.shape[0]

        # Dust settings are read and validated on every call of this
        # function, not when the factory runs.
        ext_model = config["ssp"]["dust"]["extinction_model"]
        Rv = config["ssp"]["dust"]["Rv"]

        # Dynamically choose the extinction model based on the string name
        if ext_model not in RV_MODELS:  # pragma: no cover
            # Both parts need the f-prefix so the model name is interpolated.
            raise ValueError(
                f"Extinction model '{ext_model}' is not available. "
                f"Choose from {RV_MODELS}."
            )
        ext_model_class = Rv_model_dict[ext_model]
        ext = ext_model_class(Rv=Rv)

        # Loop-invariant rescaling of the telescope grid for the extinction
        # model. It is computed once instead of inside every scan step.
        # Presumably this converts Angstrom -> micron; confirm the units.
        target_wave_scaled = target_wave / 1e4

        # init flat cube: (nseg, n_wave_tel)
        init_cube = jnp.zeros((nseg, target_wave.shape[-1]))

        def body(cube, i):
            age_i = ages[i]  # scalar
            Z_i = metallicity[i]  # scalar
            m_i = masses[i]  # scalar
            v_i = velocities[i]  # scalar or vector
            pix_i = pix_idx[i].astype(jnp.int32)
            av_i = Av_array[i]

            # 1) SSP lookup
            spec_ssp = lookup_ssp(Z_i, age_i)  # (n_wave_ssp,)
            # 2) scale by mass
            spec_mass = spec_ssp * m_i  # (n_wave_ssp,)
            # 3) Doppler-shift wavelengths
            shifted_wave = _velocity_doppler_shift_single(
                wavelength=ssp_wave0,
                velocity=v_i,
                direction=velocity_direction,
            )  # (n_wave_ssp,)
            # 4) resample onto telescope grid
            spec_tel = resample_spectrum(
                initial_spectrum=spec_mass,
                initial_wavelength=shifted_wave,
                target_wavelength=target_wave,
            )  # (n_wave_tel,)
            # 5) apply extinction on the telescope grid
            extinction = ext.extinguish(target_wave_scaled, av_i)
            spec_extincted = spec_tel * extinction  # (n_wave_tel,)
            # 6) accumulate into this star's spaxel row
            cube = cube.at[pix_i].add(spec_extincted)
            return cube, None

        cube_flat, _ = lax.scan(
            body,
            init_cube,
            jnp.arange(nstar, dtype=jnp.int32),
        )
        cube_3d = cube_flat.reshape(ns, ns, -1)
        setattr(rubixdata.stars, "datacube", cube_3d)
        logger.debug(f"Datacube shape: {cube_3d.shape}")
        return rubixdata

    return calculate_dusty_datacube_particlewise