SSP interpolation

Load the SSP grid

# NBVAL_SKIP
# Import the BruzualCharlot2003 SSP template grid and print it.
from rubix.spectra.ssp.templates import BruzualCharlot2003

ssp_grid = BruzualCharlot2003
print(ssp_grid)

SSP lookup

# NBVAL_SKIP
import matplotlib.pyplot as plt
from rubix.spectra.ssp.templates import BruzualCharlot2003
from jax import jit

# Evaluate the SSP lookup at a point halfway between neighbouring grid nodes
# (in both age and metallicity) and plot it next to a tabulated grid spectrum.
ssp = BruzualCharlot2003
wave = ssp.wavelength

# Indices of the lower grid node of the interpolation cell.
age_index = 0
met_index = 3

# Midpoint between grid node `index` and `index + 1` along each axis.
target_age = ssp.age[age_index] + 0.5*(ssp.age[age_index+1] - ssp.age[age_index])
print(target_age)
target_met = ssp.metallicity[met_index] + 0.5*(ssp.metallicity[met_index+1] - ssp.metallicity[met_index])
print(target_met)

lookup = ssp.get_lookup_interpolation()

spec_calc = lookup(target_met, target_age)

# Tabulated spectrum at the lower grid node (met_index, age_index); the flux
# array is indexed here as [metallicity, age, wavelength]. NOTE: this is a
# neighbouring grid spectrum, not the exact spectrum at
# (target_met, target_age), so close agreement (not equality) is expected.
spec_true = ssp.flux[met_index, age_index, :]

plt.plot(wave, spec_calc, label='interpolated (midpoint)')
plt.plot(wave, spec_true, label='grid node (met_index, age_index)')

plt.xlabel('wavelength')
plt.ylabel('flux')
plt.legend()
plt.yscale('log')
# NBVAL_SKIP
# Check if it works with jit: compile the lookup with jax.jit and evaluate it
# at the same (target_met, target_age) point as the un-jitted call.

spec_calc = jit(lookup)(target_met, target_age)

plt.plot(wave, spec_calc, label='calc jit')
# Without an explicit legend() call the label above is never displayed; the
# log y-axis matches the un-jitted comparison plot.
plt.legend()
plt.yscale('log')
# NBVAL_SKIP
from rubix.utils import load_galaxy_data

# Load the galaxy file and pull out the per-star mass, metallicity and age
# arrays from the "stars" entry of the particle data.
data, units = load_galaxy_data("output/rubix_galaxy.h5")
stars = data["particle_data"]["stars"]
mass = stars["mass"]
metallicity = stars["metallicity"]
age = stars["age"]

vmap

Vectorize the lookup over the stellar particles with `jax.vmap`.

# NBVAL_SKIP
# Calculate spectra with vmap: map the scalar (metallicity, age) lookup over
# many particles in a single call.
from jax import vmap

lookup = ssp.get_lookup_interpolation()

# Only the first `subset` particles are used, because the full set is too
# big to fit into GPU memory.
subset = 1000

# Clip metallicity and age to the range spanned by the SSP grid so that no
# particle lies outside the tabulated values.
met_lo, met_hi = min(ssp.metallicity), max(ssp.metallicity)
age_lo, age_hi = min(ssp.age), max(ssp.age)
met_subset = metallicity[:subset].clip(met_lo, met_hi)
age_subset = age[:subset].clip(age_lo, age_hi)

spec_calc = vmap(lookup)(met_subset, age_subset)

spec_calc.shape
# NBVAL_SKIP
# Report whether any of the vmapped spectra contain NaN values.
import jax.numpy as jnp
jnp.any(jnp.isnan(spec_calc))

Use a configuration to load the lookup function

#NBVAL_SKIP
# Configuration selecting the SSP template and the interpolation method.
config = {
    "ssp": {
        "template": {"name": "BruzualCharlot2003"},
        "method": "cubic",
    },
}
# NBVAL_SKIP
# Build the lookup function from the `config` dictionary instead of calling
# the template's get_lookup_interpolation() method directly.
from rubix.core.ssp import get_lookup_interpolation

lookup = get_lookup_interpolation(config)
# NBVAL_SKIP
# Check how many particles are outside the range of the SSP: count stars whose
# metallicity is below the first / above the last grid value.
# NOTE(review): using the first/last entries as bounds assumes the metallicity
# grid is sorted ascending; the clipping cells use min()/max() instead.
import numpy as np
met_grid_first, met_grid_last = ssp.metallicity[0], ssp.metallicity[-1]
n_met_below = np.sum(metallicity < met_grid_first)
n_met_above = np.sum(metallicity > met_grid_last)
n_met_below, n_met_above
# NBVAL_SKIP
# Count stars whose age is below the first / above the last SSP age grid value.
# NOTE(review): assumes the age grid is sorted ascending.
age_grid_first, age_grid_last = ssp.age[0], ssp.age[-1]
n_age_below = np.sum(age < age_grid_first)
n_age_above = np.sum(age > age_grid_last)
n_age_below, n_age_above
# NBVAL_SKIP
# Total number of stellar particles, for comparison with the out-of-range counts.
n_stars = len(metallicity)
n_stars
# NBVAL_SKIP
# Clip the particle subset to the SSP grid range, then evaluate the
# configuration-based lookup on it.
# NOTE(review): met_subset/age_subset were already clipped to these same
# bounds in the vmap cell, so this re-clip is idempotent there.
met_bounds = (min(ssp.metallicity), max(ssp.metallicity))
age_bounds = (min(ssp.age), max(ssp.age))
met_subset = met_subset.clip(*met_bounds)
age_subset = age_subset.clip(*age_bounds)
lookup(met_subset, age_subset)