SSP interpolation
Load the SSP grid
# NBVAL_SKIP
# Load the BruzualCharlot2003 SSP template grid and print it to inspect
# what the template object contains.
from rubix.spectra.ssp.templates import BruzualCharlot2003
print(BruzualCharlot2003)
SSP lookup
# NBVAL_SKIP
# Interpolate the SSP grid at a point halfway between two neighbouring grid
# nodes (in both metallicity and age) and plot the result next to the spectra
# stored at the two bracketing grid nodes. Neither node is the "true" spectrum
# at the target point; they are shown only as a visual reference.
import matplotlib.pyplot as plt
from rubix.spectra.ssp.templates import BruzualCharlot2003
from jax import jit

ssp = BruzualCharlot2003
wave = ssp.wavelength
age_index = 0
met_index = 3

# Midpoints between grid node i and i + 1 along each axis.
target_age = ssp.age[age_index] + 0.5 * (ssp.age[age_index + 1] - ssp.age[age_index])
target_met = ssp.metallicity[met_index] + 0.5 * (
    ssp.metallicity[met_index + 1] - ssp.metallicity[met_index]
)
print(target_age)
print(target_met)

lookup = ssp.get_lookup_interpolation()
spec_calc = lookup(target_met, target_age)

# Grid spectra at the lower and upper bracketing nodes; flux is indexed
# [metallicity index, age index, :] as elsewhere in this notebook.
spec_true = ssp.flux[met_index, age_index, :]
spec_upper = ssp.flux[met_index + 1, age_index + 1, :]

plt.plot(wave, spec_calc, label='calc (interpolated)')
plt.plot(wave, spec_true, label='grid (lower node)')
plt.plot(wave, spec_upper, label='grid (upper node)')
plt.legend()
plt.yscale('log')
# NBVAL_SKIP
# Check that the lookup also works under jit and gives the same spectrum as
# the eager (non-jitted) call, both numerically and visually.
import jax.numpy as jnp

spec_eager = lookup(target_met, target_age)
spec_calc = jit(lookup)(target_met, target_age)
print("jit matches eager:", bool(jnp.allclose(spec_calc, spec_eager)))

plt.plot(wave, spec_eager, label='calc')
# Dashed so the eager curve stays visible underneath when the two agree.
plt.plot(wave, spec_calc, label='calc jit', linestyle='--')
plt.legend()
plt.yscale('log')
# NBVAL_SKIP
# Load the galaxy data file and pull out the per-particle stellar fields
# (mass, metallicity, age) used by the lookups below.
from rubix.utils import load_galaxy_data
data, units = load_galaxy_data("output/rubix_galaxy.h5")
mass, metallicity, age = (
    data["particle_data"]["stars"][field] for field in ("mass", "metallicity", "age")
)
vmap
Use `jax.vmap` to apply the lookup to each stellar particle.
# NBVAL_SKIP
# Calculate spectra with vmap: map the scalar lookup over the stellar
# particles, one (metallicity, age) pair per particle.
from jax import vmap

lookup = ssp.get_lookup_interpolation()

# Use only a subset of the particles because the full set is too big to fit
# into GPU memory.
subset = 1000
met_subset = metallicity[:subset]
age_subset = age[:subset]

# Clip the metallicity and age values to the range covered by the SSP grid.
# The array reductions .min()/.max() are used instead of the builtins
# min()/max(), which would iterate over the grid element by element in Python.
met_subset = met_subset.clip(ssp.metallicity.min(), ssp.metallicity.max())
age_subset = age_subset.clip(ssp.age.min(), ssp.age.max())

spec_calc = vmap(lookup)(met_subset, age_subset)
spec_calc.shape
# NBVAL_SKIP
# Sanity check: report whether any of the vmapped spectra contain NaN values.
import jax.numpy as jnp
jnp.any(jnp.isnan(spec_calc))
Use a configuration to load the lookup function
# NBVAL_SKIP
# Minimal configuration: which SSP template to load and which interpolation
# method to request ("cubic"); it is passed to get_lookup_interpolation below.
config = {
    "ssp": {
        "template": {
            "name": "BruzualCharlot2003",
        },
        "method": "cubic",
    }
}
# NBVAL_SKIP
# Build the lookup function from the configuration dict rather than calling
# the template's get_lookup_interpolation() method directly.
from rubix.core.ssp import get_lookup_interpolation
lookup = get_lookup_interpolation(config)
# NBVAL_SKIP
# Check how many particles are outside the metallicity range of the SSP grid:
# (count below the minimum, count above the maximum). The bounds come from
# .min()/.max() so they match the clipping bounds without assuming the grid
# is sorted.
import numpy as np
np.sum(metallicity < ssp.metallicity.min()), np.sum(metallicity > ssp.metallicity.max())
# NBVAL_SKIP
# Same check for age: (count below the grid minimum, count above the grid
# maximum), using .min()/.max() so the grid need not be sorted.
np.sum(age < ssp.age.min()), np.sum(age > ssp.age.max())
# NBVAL_SKIP
# Total number of stellar particles, for comparison with the out-of-range
# counts.
len(metallicity)
# NBVAL_SKIP
# Clip the particle *subset* (not the full arrays) to the range of the SSP
# grid before the lookup. Clipping is idempotent, so this is a no-op if the
# subset was already clipped earlier; it keeps this cell valid on its own.
met_subset = met_subset.clip(ssp.metallicity.min(), ssp.metallicity.max())
age_subset = age_subset.clip(ssp.age.min(), ssp.age.max())
# The config-based lookup is called on whole arrays here, without vmap.
# NOTE(review): assumes it is vectorised over the query points — confirm.
lookup(met_subset, age_subset)