Source code for rubix.core.ssp

from typing import Callable

import jax

from rubix.logger import get_logger
from rubix.spectra.ssp.factory import get_ssp_template

from jaxtyping import Array, Float, jaxtyped
from beartype import beartype as typechecker


@jaxtyped(typechecker=typechecker)
def get_ssp(config: dict) -> object:
    """Load the simple stellar population (SSP) template named in the configuration.

    The template name is read from ``config["ssp"]["template"]["name"]`` and
    handed to ``get_ssp_template``.

    Args:
        config (dict): Configuration dictionary. It must contain
            ``config["ssp"]["template"]["name"]``.

    Returns:
        The SSP template object produced by ``get_ssp_template``.

    Raises:
        ValueError: If the ``ssp`` section, its ``template`` entry, or the
            template's ``name`` entry is missing. Levels are checked from the
            outside in, so only the first missing key is reported.
    """
    # Each level is checked only after its parent has been confirmed present.
    required_keys = (
        ("ssp", "Configuration does not contain 'ssp' field"),
        ("template", "Configuration does not contain 'template' field"),
        ("name", "Configuration does not contain 'name' field"),
    )

    node = config
    for key, message in required_keys:
        if key not in node:
            raise ValueError(message)
        node = node[key]

    # After the walk, ``node`` holds config["ssp"]["template"]["name"].
    return get_ssp_template(node)
@jaxtyped(typechecker=typechecker)
def get_lookup_interpolation(config: dict) -> Callable:
    """Return the interpolation lookup function of the configured SSP template.

    The template is loaded via ``get_ssp(config)``. Its
    ``get_lookup_interpolation`` method is called with the interpolation
    method from ``config["ssp"]["method"]``, or ``"cubic"`` when that key is
    absent. Per the module's own description, the lookup maps a star's
    metallicity and age to its spectrum. Callers later vmap it over stars and
    pmap it over devices.

    Args:
        config (dict): Configuration dictionary. The optional ``"logger"``
            entry is passed to ``get_logger``. The optional
            ``config["ssp"]["method"]`` entry selects the interpolation method.

    Returns:
        The lookup function returned by the SSP template.

    Raises:
        ValueError: Propagated from ``get_ssp`` if the SSP template
            configuration is incomplete.
    """
    logger = get_logger(config.get("logger", None))

    # get_ssp validates that config["ssp"] exists, so it must run before the
    # "ssp" section is indexed below.
    ssp = get_ssp(config)
    ssp_section = config["ssp"]

    if "method" in ssp_section:
        method = ssp_section["method"]
        logger.debug(f"Using method defined in config: {method}")
    else:
        logger.debug("Method not defined, using default method: cubic")
        method = "cubic"

    return ssp.get_lookup_interpolation(method=method)
@jaxtyped(typechecker=typechecker)
def get_lookup_interpolation_vmap(config: dict) -> Callable:
    """Return the SSP lookup function vectorized with ``jax.vmap``.

    The scalar lookup from ``get_lookup_interpolation(config)`` is wrapped so
    that both positional arguments (metallicities and ages, per the module's
    description) are mapped over their leading axis (``in_axes=(0, 0)``).

    Args:
        config (dict): Configuration dictionary, forwarded unchanged to
            ``get_lookup_interpolation``.

    Returns:
        The vmapped lookup function.
    """
    return jax.vmap(get_lookup_interpolation(config), in_axes=(0, 0))
@jaxtyped(typechecker=typechecker)
def get_lookup_interpolation_pmap(config: dict) -> Callable:
    """Return the vmapped SSP lookup function parallelized with ``jax.pmap``.

    The function from ``get_lookup_interpolation_vmap(config)`` is wrapped in
    ``jax.pmap`` with ``in_axes=(0, 0)``. Following ``jax.pmap`` semantics, the
    leading axis of both positional arguments is split across devices, and
    the remaining axis is handled by the inner vmap.

    Args:
        config (dict): Configuration dictionary, forwarded unchanged to
            ``get_lookup_interpolation_vmap``.

    Returns:
        The pmapped (and internally vmapped) lookup function.
    """
    per_device_lookup = get_lookup_interpolation_vmap(config)
    return jax.pmap(per_device_lookup, in_axes=(0, 0))  # type: ignore