Source code for rubix.core.ifu

from typing import Callable

import jax
import jax.numpy as jnp

from rubix import config as rubix_config
from rubix.logger import get_logger
from rubix.spectra.ifu import (
    cosmological_doppler_shift,
    resample_spectrum,
    velocity_doppler_shift,
    calculate_cube,
)

from .ssp import get_lookup_interpolation_pmap, get_ssp
from .telescope import get_telescope

from jaxtyping import Array, Float, jaxtyped
from beartype import beartype as typechecker


[docs] @jaxtyped(typechecker=typechecker) def get_calculate_spectra(config: dict) -> Callable: """ The function gets the lookup function that performs the lookup to the SSP model, and parallelizes the funciton across all GPUs. Args: config (dict): The configuration dictionary Returns: The function that calculates the spectra of the stars. Example ------- >>> config = { ... "ssp": { ... "template": { ... "name": "BruzualCharlot2003" ... }, ... }, ... } >>> from rubix.core.ifu import get_calculate_spectra >>> calcultae_spectra = get_calculate_spectra(config) >>> rubixdata = calcultae_spectra(rubixdata) >>> # Access the spectra of the stars >>> rubixdata.stars.spectra """ logger = get_logger(config.get("logger", None)) lookup_interpolation_pmap = get_lookup_interpolation_pmap(config) @jaxtyped(typechecker=typechecker) def calculate_spectra(rubixdata: object) -> object: logger.info("Calculating IFU cube...") logger.debug( f"Input shapes: Metallicity: {len(rubixdata.stars.metallicity)}, Age: {len(rubixdata.stars.age)}" ) # Ensure metallicity and age are arrays and reshape them to be at least 1-dimensional age_data = jax.device_get(rubixdata.stars.age) metallicity_data = jax.device_get(rubixdata.stars.metallicity) # Ensure they are not scalars or empty; convert to 1D arrays if necessary age = jnp.atleast_1d(age_data) metallicity = jnp.atleast_1d(metallicity_data) spectra = lookup_interpolation_pmap( # rubixdata.stars.metallicity, rubixdata.stars.age metallicity, age, ) # * inputs["mass"] logger.debug(f"Calculation Finished! Spectra shape: {spectra.shape}") spectra_jax = jnp.array(spectra) rubixdata.stars.spectra = spectra_jax # setattr(rubixdata.gas, "spectra", spectra) # jax.debug.print("Calculate Spectra: Spectra {}", spectra) return rubixdata return calculate_spectra
[docs] @jaxtyped(typechecker=typechecker) def get_scale_spectrum_by_mass(config: dict) -> Callable: """ The spectra of the stellar particles are scaled by the mass of the stars. Args: config (dict): The configuration dictionary Returns: The function that scales the spectra by the mass of the stars. Example ------- >>> from rubix.core.ifu import get_scale_spectrum_by_mass >>> scale_spectrum_by_mass = get_scale_spectrum_by_mass(config) >>> rubixdata = scale_spectrum_by_mass(rubixdata) >>> # Access the spectra of the stars, which is now scaled by the stellar mass >>> rubixdata.stars.spectra """ logger = get_logger(config.get("logger", None)) @jaxtyped(typechecker=typechecker) def scale_spectrum_by_mass(rubixdata: object) -> object: logger.info("Scaling Spectra by Mass...") mass = jnp.expand_dims(rubixdata.stars.mass, axis=-1) # rubixdata.stars.spectra = rubixdata.stars.spectra * mass spectra_mass = rubixdata.stars.spectra * mass setattr(rubixdata.stars, "spectra", spectra_mass) # jax.debug.print("mass mult: Spectra {}", inputs["spectra"]) return rubixdata return scale_spectrum_by_mass
# Vectorize the resample_spectrum function
[docs] @jaxtyped(typechecker=typechecker) def get_resample_spectrum_vmap(target_wavelength) -> Callable: """ The spectra of the stars are resampled to the telescope wavelength grid. Args: target_wavelength (jax.Array): The telescope wavelength grid Returns: The function that resamples the spectra to the telescope wavelength grid. """ @jaxtyped(typechecker=typechecker) def resample_spectrum_vmap(initial_spectrum, initial_wavelength): return resample_spectrum( initial_spectrum=initial_spectrum, initial_wavelength=initial_wavelength, target_wavelength=target_wavelength, ) return jax.vmap(resample_spectrum_vmap, in_axes=(0, 0))
# Parallelize the vectorized function across devices
[docs] @jaxtyped(typechecker=typechecker) def get_resample_spectrum_pmap(target_wavelength) -> Callable: """ Pmap the function that resamples the spectra of the stars to the telescope wavelength grid. Args: target_wavelength (jax.Array): The telescope wavelength grid Returns: The function that resamples the spectra to the telescope wavelength grid. """ vmapped_resample_spectrum = get_resample_spectrum_vmap(target_wavelength) return jax.pmap(vmapped_resample_spectrum)
[docs] @jaxtyped(typechecker=typechecker) def get_velocities_doppler_shift_vmap( ssp_wave: Float[Array, "..."], velocity_direction: str ) -> Callable: """ The function doppler shifts the wavelength based on the velocity of the stars. Args: ssp_wave (jax.Array): The wavelength of the SSP grid velocity_direction (str): The velocity component of the stars that is used to doppler shift the wavelength Returns: The function that doppler shifts the wavelength based on the velocity of the stars. """ def func(velocity): return velocity_doppler_shift( wavelength=ssp_wave, velocity=velocity, direction=velocity_direction ) return jax.vmap(func, in_axes=0)
[docs] @jaxtyped(typechecker=typechecker) def get_doppler_shift_and_resampling(config: dict) -> Callable: """ The function doppler shifts the wavelength based on the velocity of the stars and resamples the spectra to the telescope wavelength grid. Args: config (dict): The configuration dictionary Returns: The function that doppler shifts the wavelength based on the velocity of the stars and resamples the spectra to the telescope wavelength grid. Example ------- >>> from rubix.core.ifu import get_doppler_shift_and_resampling >>> doppler_shift_and_resampling = get_doppler_shift_and_resampling(config) >>> rubixdata = doppler_shift_and_resampling(rubixdata) >>> # Access the spectra of the stars, which is now doppler shifted and resampled to the telescope wavelength grid >>> rubixdata.stars.spectra """ logger = get_logger(config.get("logger", None)) # The velocity component of the stars that is used to doppler shift the wavelength velocity_direction = rubix_config["ifu"]["doppler"]["velocity_direction"] # The redshift at which the user wants to observe the galaxy galaxy_redshift = config["galaxy"]["dist_z"] # Get the telescope wavelength bins telescope = get_telescope(config) telescope_wavelenght = telescope.wave_seq # Get the SSP grid to doppler shift the wavelengths ssp = get_ssp(config) # Doppler shift the SSP wavelenght based on the cosmological distance of the observed galaxy ssp_wave = cosmological_doppler_shift(z=galaxy_redshift, wavelength=ssp.wavelength) logger.debug(f"SSP Wave: {ssp_wave.shape}") # Function to Doppler shift the wavelength based on the velocity of the stars particles # This binds the velocity direction, such that later we only need the velocity during the pipeline doppler_shift = get_velocities_doppler_shift_vmap(ssp_wave, velocity_direction) @jaxtyped(typechecker=typechecker) def doppler_shift_and_resampling(rubixdata: object) -> object: if rubixdata.stars.spectra is not None: # Doppler shift the SSP Wavelengths based on the velocity of the stars doppler_shifted_ssp_wave = doppler_shift(rubixdata.stars.velocity) logger.info("Doppler shifting and resampling stellar spectra...") logger.debug(f"Doppler Shifted SSP Wave: {doppler_shifted_ssp_wave.shape}") logger.debug(f"Telescope Wave Seq: {telescope.wave_seq.shape}") # Function to resample the spectrum to the telescope wavelength grid resample_spectrum_pmap = get_resample_spectrum_pmap(telescope_wavelenght) # jax.debug.print("doppler shifted ssp wave {}", doppler_shifted_ssp_wave) # jax.debug.print("Spectra before resampling {}", inputs["spectra"]) spectrum_resampled = resample_spectrum_pmap( rubixdata.stars.spectra, doppler_shifted_ssp_wave ) # rubixdata.stars.spectra = spectrum_resampled setattr(rubixdata.stars, "spectra", spectrum_resampled) # jax.debug.print("doppler shift and resampl: Spectra {}", inputs["spectra"]) if rubixdata.gas.spectra is not None: # Doppler shift the SSP Wavelengths based on the velocity of the gas particles doppler_shifted_ssp_wave = doppler_shift(rubixdata.gas.velocity) logger.info("Doppler shifting and resampling gas spectra...") logger.debug(f"Doppler Shifted SSP Wave: {doppler_shifted_ssp_wave.shape}") logger.debug(f"Telescope Wave Seq: {telescope.wave_seq.shape}") # Function to resample the spectrum to the telescope wavelength grid resample_spectrum_pmap = get_resample_spectrum_pmap(telescope_wavelenght) spectrum_resampled = resample_spectrum_pmap( rubixdata.gas.spectra, doppler_shifted_ssp_wave ) return rubixdata return doppler_shift_and_resampling
[docs] @jaxtyped(typechecker=typechecker) def get_calculate_datacube(config: dict) -> Callable: """ The function returns the function that calculates the datacube of the stars. Args: config (dict): The configuration dictionary Returns: The function that calculates the datacube of the stars. Example ------- >>> from rubix.core.ifu import get_calculate_datacube >>> calculate_datacube = get_calculate_datacube(config) >>> rubixdata = calculate_datacube(rubixdata) >>> # Access the datacube of the stars >>> rubixdata.stars.datacube """ logger = get_logger(config.get("logger", None)) telescope = get_telescope(config) num_spaxels = int(telescope.sbin) # Bind the num_spaxels to the function calculate_cube_fn = jax.tree_util.Partial(calculate_cube, num_spaxels=num_spaxels) calculate_cube_pmap = jax.pmap(calculate_cube_fn) @jaxtyped(typechecker=typechecker) def calculate_datacube(rubixdata: object) -> object: logger.info("Calculating Data Cube...") ifu_cubes = calculate_cube_pmap( spectra=rubixdata.stars.spectra, spaxel_index=rubixdata.stars.pixel_assignment, ) datacube = jnp.sum(ifu_cubes, axis=0) logger.debug(f"Datacube Shape: {datacube.shape}") # logger.debug(f"This is the datacube: {datacube}") datacube_jax = jnp.array(datacube) setattr(rubixdata.stars, "datacube", datacube_jax) # rubixdata.stars.datacube = datacube return rubixdata return calculate_datacube