import equinox as eqx
import jax.numpy as jnp
from astropy import units as u
from astropy.io import fits
import os
import h5py
import requests
from rubix import config as rubix_config
from rubix.logger import get_logger
from interpax import interp2d
from jax.tree_util import Partial
from dataclasses import dataclass, fields
from typing import List, Tuple, Union
from jaxtyping import Int, Array, Float, jaxtyped
from beartype import beartype as typechecker
# Target units for each SSP field, taken from the rubix configuration
# (rubix_config["ssp"]["units"]). The loaders index it by the keys of
# config["fields"] (e.g. "age", "flux" in the class examples) and pass the
# result as ``to_units`` to ``SSPGrid.convert_units``.
SSP_UNITS = rubix_config["ssp"]["units"]
@dataclass
class SSPGrid:
    """
    Base class for all SSP models.

    Holds a grid of simple-stellar-population spectra sampled on a
    (metallicity, age) grid, stored as JAX arrays, together with helpers to
    interpolate the grid, convert units and download template files.
    """

    # Shapes are jaxtyping annotations; they are not enforced on assignment
    # because the explicit ``__init__`` below is not type-checked.
    age: Float[Array, " age_bins"]
    metallicity: Float[Array, " metallicity_bins"]
    wavelength: Float[Array, " wavelength_bins"]
    flux: Float[Array, "metallicity_bins age_bins wavelength_bins"]
    # This does not work with jax.jit, gives error that str is not valid Jax type
    # units: Dict[str, str] = eqx.field(default_factory=dict)

    def __init__(self, age, metallicity, wavelength, flux, _logger=None):
        """
        Store the grid components as JAX arrays.

        ``@dataclass`` keeps this explicit ``__init__`` instead of generating one.

        Args:
            age: Age grid points.
            metallicity: Metallicity grid points.
            wavelength: Wavelength sampling of the spectra.
            flux: Spectra, indexed as (metallicity, age, wavelength).
            _logger: Currently unused; accepted for backward compatibility.
        """
        self.age = jnp.asarray(age)
        self.metallicity = jnp.asarray(metallicity)
        self.wavelength = jnp.asarray(wavelength)
        self.flux = jnp.asarray(flux)
        # self.units = SSP_UNITS

    @jaxtyped(typechecker=typechecker)
    def keys(self) -> List[str]:
        """
        Returns the keys of the dataclass.

        Returns:
            List of the dataclass field names, in declaration order.
        """
        return [f.name for f in fields(self)]

    def __iter__(self):
        """Yield the field values in the same order as ``keys()``."""
        yield from (getattr(self, field.name) for field in fields(self))

    @jaxtyped(typechecker=typechecker)
    def get_lookup_interpolation(
        self,
        method: str = "cubic",
        extrap: Union[int, float, bool, Tuple] = 0,
    ) -> Partial:
        """
        Returns a 2D interpolation function for the SSP grid.

        The function can be called with metallicity and age as arguments to get
        the flux at that metallicity and age.

        Parameters
        ----------
        method : str
            The method to use for interpolation. Default is "cubic".
        extrap : float, bool or tuple
            The value to return for points outside the interpolation domain.
            Default is 0.
            See https://interpax.readthedocs.io/en/latest/_api/interpax.Interpolator2D.html#interpax.Interpolator2D

        Returns
        -------
        Partial
            ``interpax.interp2d`` with the grid bound as ``x=metallicity``,
            ``y=age`` and ``f=flux``; call it as ``lookup(metallicity, age)``.

        Example 1
        ----------
        >>> grid = SSPGrid(...)
        >>> lookup = grid.get_lookup_interpolation()
        >>> metallicity = 0.02
        >>> age = 1e9
        >>> flux = lookup(metallicity, age)

        Example 2
        ----------
        >>> import matplotlib.pyplot as plt
        >>> from rubix.spectra.ssp.templates import BruzualCharlot2003
        >>> from jax import jit
        >>> ssp = BruzualCharlot2003
        >>> wave = ssp.wavelength
        >>> age_index = 0
        >>> met_index = 3
        >>> target_age = ssp.age[age_index] + 0.5*(ssp.age[age_index+1] - ssp.age[age_index])
        >>> target_met = ssp.metallicity[met_index] + 0.5*(ssp.metallicity[met_index+1] - ssp.metallicity[met_index])
        >>> lookup = ssp.get_lookup_interpolation()
        >>> spec_calc = lookup(target_met, target_age)
        >>> spec_true = ssp.flux[met_index, age_index, :]
        >>> plt.plot(wave, spec_calc, label='calc')
        >>> plt.plot(wave, spec_true, label='true')
        >>> plt.legend()
        >>> plt.yscale('log')
        """
        # Bind the SSP grid to the interpolation function; the query points
        # (metallicity, age) remain the two positional arguments.
        interp = Partial(
            interp2d,
            method=method,
            x=self.metallicity,
            y=self.age,
            f=self.flux,
            extrap=extrap,
        )
        interp.__doc__ = (
            "Interpolation function for SSP grid, args: f(metallicity, age)"
        )
        return interp

    @jaxtyped(typechecker=typechecker)
    @staticmethod
    def convert_units(
        data: Union[Float[Array, "..."], Int[Array, "..."]],
        from_units: str,
        to_units: str,
    ) -> Float[Array, "..."]:
        """
        Convert the units of the data from `from_units` to `to_units`.

        Args:
            data (array-like): The data to convert.
            from_units (str): astropy unit string of the data ("" for dimensionless).
            to_units (str): astropy unit string to convert to.

        Returns:
            The converted data as a float32 JAX array.
        """
        quantity = u.Quantity(data, from_units)
        return jnp.array(quantity.to(to_units).value, dtype=jnp.float32)

    @jaxtyped(typechecker=typechecker)
    @staticmethod
    def checkout_SSP_template(config: dict, file_location: str) -> str:
        """
        Check if the SSP template exists on disk, if not download it
        from the given URL in the configuration dictionary.

        Parameters
        ----------
        config : dict
            Configuration dictionary. Reads ``file_name`` and, when a
            download is needed, ``source`` (base URL).
        file_location : str
            Directory to look for / save the template file.

        Returns
        -------
        str
            The path to the file.

        Raises
        ------
        FileNotFoundError
            If the download request fails or does not return status 200.
        requests.exceptions.HTTPError
            Propagated from ``response.raise_for_status()`` for 4xx/5xx replies.
        """
        _logger = get_logger()
        file_path = os.path.join(file_location, config["file_name"])
        # A template already on disk needs no download (and no "source" entry).
        if os.path.exists(file_path):
            return file_path

        source = config["source"]
        if not source.endswith("/"):
            source += "/"
        url = source + config["file_name"]
        _logger.info(
            f'[SSPModels] File {file_path} not found. Downloading it from {config["source"]}'
        )
        # NOTE(review): no timeout is passed to requests.get, so a stalled
        # server can block this call indefinitely — consider adding one.
        try:
            response = requests.get(url)
        except requests.exceptions.SSLError as SSLerr:
            _logger.warning(f"[SSPModels] Error: {SSLerr}")
            _logger.warning(
                f"[SSPModels] Trying to download file {config['file_name']} from url {source} without SSL verification."
            )
            # SECURITY NOTE(review): verify=False disables TLS certificate
            # checks, so this fallback is open to man-in-the-middle tampering.
            try:
                response = requests.get(url, verify=False)
            except requests.exceptions.RequestException as err:
                _logger.error(f"[SSPModels] Error: {err}")
                raise FileNotFoundError(
                    f"Could not download file {config['file_name']} from url {source}."
                ) from err
        except requests.exceptions.RequestException as err:
            _logger.error(f"[SSPModels] Error: {err}")
            raise FileNotFoundError(
                f"Could not download file {config['file_name']} from url {source}."
            ) from err

        response.raise_for_status()
        if response.status_code != 200:
            raise FileNotFoundError(
                f"Could not download file {config['file_name']} from url {source}."
            )
        with open(file_path, "wb") as f:
            f.write(response.content)
        _logger.info(
            f'[SSPModels] File {config["file_name"]} downloaded successfully!'
        )
        return file_path

    @jaxtyped(typechecker=typechecker)
    @classmethod
    def from_file(cls, config: dict, file_location: str):
        """
        Template function to load a SSP grid from a file.

        The base implementation reads no file: it creates an empty
        zero-length array for each field in ``config["fields"]``. Actual
        loading of templates needs to be implemented in the subclasses.

        Args:
            config (dict): Configuration dictionary (reads ``fields`` and ``name``).
            file_location (str): Location of the file (unused here).

        Returns:
            The SSP grid SSPGrid with empty fields.
        """
        # Key by the config field name (as the subclasses do): these are the
        # constructor argument names, whereas field_info["name"] is the
        # on-disk dataset name and may differ.
        ssp_data = {field_name: jnp.empty(0) for field_name in config["fields"]}
        grid = cls(**ssp_data)
        # NOTE(review): grid.__class__ is cls, so this renames the class itself
        # for every instance, not just this grid.
        grid.__class__.__name__ = config["name"]
        return grid
class HDF5SSPGrid(SSPGrid):
    """
    SSP grid backed by an HDF5 template file.

    Mainly used for custom collection of Bruzual & Charlot 2003 models and MILES models.

    Example
    -------
    >>> config = {
    ...     "name": "Bruzual & Charlot (2003)",
    ...     "format": "HDF5",
    ...     "source": "https://www.bruzual.org/bc03/",
    ...     "file_name": "BC03lr.h5",
    ...     "fields": {
    ...         "age": {
    ...             "name": "age",
    ...             "units": "Gyr",
    ...             "in_log": False
    ...         },
    ...         "metallicity": {
    ...             "name": "metallicity",
    ...             "units": "",
    ...             "in_log": False
    ...         },
    ...         "wavelength": {
    ...             "name": "wavelength",
    ...             "units": "Angstrom",
    ...             "in_log": False
    ...         },
    ...         "flux": {
    ...             "name": "flux",
    ...             "units": "Lsun/Angstrom",
    ...             "in_log": False
    ...         }
    ...     }
    ... }
    >>> from rubix.spectra.ssp.grid import HDF5SSPGrid
    >>> ssp = HDF5SSPGrid.from_file(config, file_location="../rubix/spectra/ssp/templates")
    >>> ssp.age.shape
    >>> ssp.metallicity.shape
    >>> ssp.wavelength.shape
    >>> ssp.flux.shape
    """

    # This subclass is not itself a @dataclass, so dataclasses.fields() keeps
    # using SSPGrid's fields; the annotations below only document the shapes.
    age: Float[Array, " age_bins"]
    metallicity: Float[Array, " metallicity_bins"]
    wavelength: Float[Array, " wavelength_bins"]
    flux: Float[Array, "metallicity_bins age_bins wavelength_bins"]
    # This does not work with jax.jit, gives error that str is not valid Jax type
    # units: Dict[str, str] = eqx.field(default_factory=dict)

    def __init__(self, age, metallicity, wavelength, flux):
        super().__init__(age, metallicity, wavelength, flux)

    @jaxtyped(typechecker=typechecker)
    @classmethod
    def from_file(cls, config: dict, file_location: str) -> SSPGrid:
        """
        Load a SSP grid from an HDF5 file.

        For every entry of ``config["fields"]`` the dataset named
        ``field_info["name"]`` is read, raised to the power of 10 when
        ``field_info["in_log"]`` is true, cast to float32 and converted from
        ``field_info["units"]`` to ``SSP_UNITS[field_name]``.

        Args:
            config (dict): Configuration dictionary; ``format`` must be
                "hdf5" or "fsps" (case-insensitive).
            file_location (str): Directory holding (or receiving) the template file.

        Returns:
            The SSP grid `SSPGrid` in the correct units.

        Raises:
            ValueError: If the configured format is not supported.
        """
        file_format = config.get("format", "").lower()
        if file_format not in ("hdf5", "fsps"):
            raise ValueError("Configured file format is not HDF5.")

        path = cls.checkout_SSP_template(config, file_location)

        loaded = {}
        with h5py.File(path, "r") as h5:
            for key, spec in config["fields"].items():
                values = h5[spec["name"]][:]  # type: ignore
                if spec["in_log"]:
                    values = jnp.power(10, values)  # type: ignore
                values = jnp.array(values, dtype=jnp.float32)
                loaded[key] = cls.convert_units(
                    values, spec["units"], SSP_UNITS[key]
                )

        grid = cls(**loaded)
        grid.__class__.__name__ = config["name"]
        return grid
class pyPipe3DSSPGrid(SSPGrid):
    """
    Class for all SSP models supported by the pyPipe3D project.

    See http://ifs.astroscu.unam.mx/pyPipe3D/templates/ for more information.

    Example
    -------
    >>> config = {
    ...     "name": "Mastar Charlot & Bruzual (2019)",
    ...     "format": "pyPipe3D",
    ...     "source": "https://ifs.astroscu.unam.mx/pyPipe3D/templates/",
    ...     "file_name": "MaStar_CB19.slog_1_5.fits.gz",
    ...     "fields": {
    ...         "age": {
    ...             "name": "age",
    ...             "units": "Gyr",
    ...             "in_log": False
    ...         },
    ...         "metallicity": {
    ...             "name": "metallicity",
    ...             "units": "",
    ...             "in_log": False
    ...         },
    ...         "wavelength": {
    ...             "name": "wavelength",
    ...             "units": "Angstrom",
    ...             "in_log": False
    ...         },
    ...         "flux": {
    ...             "name": "flux",
    ...             "units": "Lsun/Angstrom",
    ...             "in_log": False
    ...         }
    ...     }
    ... }
    >>> from rubix.spectra.ssp.grid import pyPipe3DSSPGrid
    >>> ssp = pyPipe3DSSPGrid.from_file(config, file_location="../rubix/spectra/ssp/templates")
    """

    age: Float[Array, " age_bins"]
    metallicity: Float[Array, " metallicity_bins"]
    wavelength: Float[Array, " wavelength_bins"]
    flux: Float[Array, "metallicity_bins age_bins wavelength_bins"]
    # This does not work with jax.jit, gives error that str is not valid Jax type
    # units: Dict[str, str] = eqx.field(default_factory=dict)

    def __init__(self, age, metallicity, wavelength, flux):
        super().__init__(age, metallicity, wavelength, flux)

    # @staticmethod
    # def get_normalization_wavelength(header, wavelength, flux_models, n_models):
    #     """
    #     Search for the normalization wavelength at the FITS header.
    #     If the key WAVENORM does not exists in the header, sweeps all the
    #     models looking for the wavelengths where the flux is closer to 1,
    #     calculates the median of those wavelengths and returns it.
    #
    #     TODO: defines a better normalization wavelength if it's not present
    #     in the header.
    #
    #     adapted from https://github.com/reginasar/TNG_MaNGA_mocks/blob/3229dd47b441aef380ef7dbfdf110f39e5c5a77c/sin_ifu_clean.py#L1466
    #
    #     Parameters
    #     ----------
    #     header : :class:`astropy.io.fits.header.Header`
    #         FITS header with spectral data.
    #
    #     wavelength : array like, wavelength of the model SSPs.
    #
    #     flux_models : array like, flux of the model SSPs.
    #
    #     n_models : int, number of models in the SSP grid.
    #
    #     Returns
    #     -------
    #     float
    #         The normalization wavelength.
    #     """
    #     try:
    #         wave_norm = header['WAVENORM']
    #     except Exception as ex:
    #         _closer = 1e-6
    #         probable_wavenorms = jnp.hstack([wavelength[(jnp.abs(flux_models[i] - 1) < _closer)]
    #                                         for i in range(n_models)])
    #         wave_norm = jnp.median(probable_wavenorms)
    #         print(f'[SSPModels] {ex}')
    #         print(f'[SSPModels] setting normalization wavelength to {wave_norm} A')
    #     return wave_norm

    @jaxtyped(typechecker=typechecker)
    @staticmethod
    def get_wavelength_from_header(
        header, wave_axis: int = 1
    ) -> Union[Float[Array, "..."], Int[Array, "..."]]:
        """
        Build the wavelength array of axis ``wave_axis`` from the linear
        FITS WCS keywords of ``header``:

            wavelength = CRVAL + CDELT * ([0, 1, ..., NAXIS - 1] + 1 - CRPIX)

        Parameters
        ----------
        header : :class:`astropy.io.fits.header.Header`
            FITS header with spectral data.
        wave_axis : int
            FITS axis number holding the wavelength sampling. Default is 1,
            since the pyPipe3D files store one model per row along axis 2.

        Returns
        -------
        array like
            Wavelength of every pixel along ``wave_axis``, in header units.
        """
        crval = header[f"CRVAL{wave_axis}"]
        cdelt = header[f"CDELT{wave_axis}"]
        naxis = header[f"NAXIS{wave_axis}"]
        crpix = header[f"CRPIX{wave_axis}"]
        if not cdelt:
            # A zero/empty step would collapse every pixel onto CRVAL;
            # fall back to unit spacing.
            cdelt = 1
        return crval + cdelt * (jnp.arange(naxis) + 1 - crpix)

    @jaxtyped(typechecker=typechecker)
    @staticmethod
    def get_tZ_models(
        header, n_models: int
    ) -> Tuple[Float[Array, "..."], Float[Array, "..."], Float[Array, "..."]]:
        """
        Reads the values of age, metallicity and mass-to-light at the
        normalization flux from the SSP models FITS file.

        Each model ``i`` is described by the header keywords ``NAME{i}`` and
        ``NORM{i}``. ``NAME{i}`` is split on ``"_"``; after dropping the first
        two parts (e.g. ``spec_ssp``), the next part is the age, optionally
        suffixed with ``Gyr`` or ``Myr`` (no suffix means Gyr), and the one
        after it encodes the metallicity as ``z<digits>`` (read as ``0.<digits>``).

        adapted from https://github.com/reginasar/TNG_MaNGA_mocks/blob/3229dd47b441aef380ef7dbfdf110f39e5c5a77c/sin_ifu_clean.py#L1466

        Parameters
        ----------
        header : :class:`astropy.io.fits.header.Header`
            FITS header with spectral data.
        n_models : int
            Number of models in the SSP grid.

        Returns
        -------
        array like
            Sorted unique ages, in Gyr.
        array like
            Sorted unique metallicities.
        array like
            Mass-to-light value of every model (``1 / NORM{i}``, or 1 where
            ``NORM{i}`` is 0), in the order the models appear in the FITS data.

        Raises
        ------
        KeyError
            If an age carries a ``*yr`` suffix other than ``Gyr``/``Myr``.
        """
        age_unit_to_gyr = {"Gyr": 1, "Myr": 1 / 1000}
        ages = []
        Zs = []
        mtol = []
        for i in range(n_models):
            # removes 'spec_ssp_' from the name
            name_parts = header[f"NAME{i}"].split("_")[2:]
            age_str = name_parts[0]
            if "yr" in age_str:
                age_factor = age_unit_to_gyr[age_str[-3:]]  # Gyr or Myr
                age_str = age_str[:-3]
            else:
                age_factor = 1  # Gyr
            ages.append(age_factor * jnp.float32(age_str))

            z_str = name_parts[1].split(".")[0]
            Zs.append(jnp.float32(z_str.replace("z", "0.")))

            norm = jnp.float32(header[f"NORM{i}"])
            mtol.append(1 / norm if norm != 0 else 1)

        # Build each array once instead of copying it on every .at[].set().
        ages = jnp.array(ages, dtype=jnp.float32)
        Zs = jnp.array(Zs, dtype=jnp.float32)
        mtol = jnp.array(mtol, dtype=jnp.float32)
        return jnp.unique(ages), jnp.unique(Zs), mtol

    @jaxtyped(typechecker=typechecker)
    @classmethod
    def from_file(cls, config: dict, file_location: str) -> SSPGrid:
        """
        Load a SSP grid from a fits file in pyPipe3D format.

        Args:
            config (dict): Configuration dictionary; ``format`` must be
                "pypipe3d" (case-insensitive).
            file_location (str): Location of the file.

        Returns:
            The SSP grid SSPGrid in the correct units.

        Raises:
            ValueError: If the format is not pyPipe3D or a configured field
                is not one of flux, wavelength, age, metallicity.
        """
        if config.get("format", "").lower() != "pypipe3d":
            raise ValueError("Configured file format is not fits.")
        file_path = cls.checkout_SSP_template(config, file_location)

        with fits.open(file_path) as f:
            _header = f[0].header
            # n_wave = _header['NAXIS1']
            n_models = _header["NAXIS2"]
            # pyPIPE3D uses the key WAVENORM to store the normalization wavelength
            # not sure what this is actually used for in the end.
            # Here we enable reading it, but we should make sure we understand what it is used for.
            # normalization_wavelength = get_normalization_wavelength(_header, wavelength, flux_models, n_models)
            ages, metallicities, m2l = cls.get_tZ_models(_header, n_models)
            wavelength = cls.get_wavelength_from_header(_header)
            # Divide each model (row) by its mass-to-light value, i.e. multiply
            # by NORM{i}; intended to give the flux in Lsun/Msun,
            # see also eq. A1 here https://arxiv.org/pdf/1811.04856.pdf
            template_flux = jnp.array(f[0].data, dtype=jnp.float32) / m2l[:, None]

        # Reshape into (metallicity, age, wavelength) to conform with the
        # SSPGrid dataclass.
        # NOTE(review): this assumes the FITS rows are ordered metallicity-major
        # with ages ascending inside each metallicity block — confirm per library.
        flux_models = template_flux.reshape(
            len(metallicities), len(ages), len(wavelength)
        )

        available = {
            "flux": flux_models,
            "wavelength": wavelength,
            "age": ages,
            "metallicity": metallicities,
        }
        ssp_data = {}
        for field_name, field_info in config["fields"].items():
            if field_name not in available:
                raise ValueError(f"Field {field_name} not recognized")
            data = available[field_name]
            data = jnp.power(10, data) if field_info["in_log"] else data  # type: ignore
            ssp_data[field_name] = cls.convert_units(
                data, field_info["units"], SSP_UNITS[field_name]
            )

        grid = cls(**ssp_data)
        grid.__class__.__name__ = config["name"]
        return grid
# TODO: build another class that handles the eMILES and sMILES templates, which are also used by the GECKOS survey.
# Those templates depend on alpha enhancement as well as on metallicity, so the interpolation function further
# down the pipeline might need some changes.