"""Source code for rubix.core.pipeline."""
import time
from typing import Union
import jax
import jax.numpy as jnp
import sys
from rubix.logger import get_logger
from rubix.pipeline import linear_pipeline as pipeline
from rubix.pipeline import transformer as transformer
from rubix.utils import get_config, get_pipeline_config
from .data import get_reshape_data, get_rubix_data
from .ifu import (
get_calculate_spectra,
get_doppler_shift_and_resampling,
get_scale_spectrum_by_mass,
get_calculate_datacube,
)
from .rotation import get_galaxy_rotation
from .ssp import get_ssp
from .telescope import get_spaxel_assignment, get_telescope, get_filter_particles
from .psf import get_convolve_psf
from .lsf import get_convolve_lsf
from .noise import get_apply_noise
from jaxtyping import Array, Float, jaxtyped
from beartype import beartype as typechecker
# [docs]
class RubixPipeline:
    """
    RubixPipeline is responsible for setting up and running the data processing pipeline.

    Args:
        user_config (dict or str): User configuration, either already parsed
            (dict) or a path to a YAML config file (str).

    Attributes:
        user_config (dict): Parsed user configuration for the pipeline.
        pipeline_config (dict): Configuration for the pipeline.
        logger (Logger): Logger instance for logging messages.
        ssp (object): Stellar population synthesis model.
        telescope (object): Telescope configuration.
        data (object): Container holding the loaded particle data.
        func (callable): Compiled pipeline function to process data
            (``None`` until :meth:`run` is called).

    Example
    --------
    >>> from rubix.core.pipeline import RubixPipeline
    >>> config = "path/to/config.yml"
    >>> pipeline = RubixPipeline(config)
    >>> output = pipeline.run()
    >>> ssp_model = pipeline.ssp
    >>> telescope = pipeline.telescope
    """

    def __init__(self, user_config: Union[dict, str]):
        self.user_config = get_config(user_config)
        self.pipeline_config = get_pipeline_config(self.user_config["pipeline"]["name"])
        self.logger = get_logger(self.user_config["logger"])
        self.ssp = get_ssp(self.user_config)
        self.telescope = get_telescope(self.user_config)
        self.data = self._prepare_data()
        # Compiled pipeline function; populated by run().
        self.func = None

    def _prepare_data(self):
        """
        Prepares and loads the data for the pipeline.

        Returns:
            Particle data container exposing ``stars`` and ``gas`` attributes
            with fields such as 'coords', 'velocities', 'metallicity', 'mass',
            and 'age'.
        """
        # Get the data
        self.logger.info("Getting rubix data...")
        rubixdata = get_rubix_data(self.user_config)

        # Either particle family may be missing from the input, in which case
        # its coords attribute is None.
        star_count = (
            len(rubixdata.stars.coords) if rubixdata.stars.coords is not None else 0
        )
        gas_count = len(rubixdata.gas.coords) if rubixdata.gas.coords is not None else 0
        self.logger.info(
            f"Data loaded with {star_count} star particles and {gas_count} gas particles."
        )
        # sys.getsizeof returns the shallow in-memory size of the container in
        # bytes (it does not follow references) — it is NOT a property count,
        # which is what the previous log message claimed.
        self.logger.info(
            f"Loaded data container with shallow size of {sys.getsizeof(rubixdata)} bytes."
        )

        # TODO: This is a temporary solution, we need to figure out a better way to handle the data
        # This works, because JAX can trace through the data dictionary
        # Other option may be named tuples or data classes to have fixed keys
        self.logger.debug("Data: %s", rubixdata)
        return rubixdata

    @jaxtyped(typechecker=typechecker)
    def _get_pipeline_functions(self) -> list:
        """
        Sets up the pipeline functions.

        Returns:
            List of configured transformer functions, in the order in which
            the pipeline applies them.
        """
        self.logger.info("Setting up the pipeline...")
        self.logger.debug("Pipeline Configuration: %s", self.pipeline_config)

        # TODO: maybe there is a nicer way to load the functions from the yaml config?
        # Each factory binds the user configuration into a function that the
        # linear pipeline can compose.
        rotate_galaxy = get_galaxy_rotation(self.user_config)
        filter_particles = get_filter_particles(self.user_config)
        spaxel_assignment = get_spaxel_assignment(self.user_config)
        calculate_spectra = get_calculate_spectra(self.user_config)
        reshape_data = get_reshape_data(self.user_config)
        scale_spectrum_by_mass = get_scale_spectrum_by_mass(self.user_config)
        doppler_shift_and_resampling = get_doppler_shift_and_resampling(
            self.user_config
        )
        calculate_datacube = get_calculate_datacube(self.user_config)
        convolve_psf = get_convolve_psf(self.user_config)
        convolve_lsf = get_convolve_lsf(self.user_config)
        apply_noise = get_apply_noise(self.user_config)

        functions = [
            rotate_galaxy,
            filter_particles,
            spaxel_assignment,
            calculate_spectra,
            reshape_data,
            scale_spectrum_by_mass,
            doppler_shift_and_resampling,
            calculate_datacube,
            convolve_psf,
            convolve_lsf,
            apply_noise,
        ]
        return functions

    # TODO: currently returns dict, but later should return only the IFU cube
    def run(self):
        """
        Runs the data processing pipeline.

        Assembles the configured pipeline steps, compiles them into a single
        function, applies it to the loaded data, and logs the elapsed time.

        Returns
        -------
        dict
            Output of the pipeline after processing the input data.
        """
        # perf_counter is monotonic, so it is safe for measuring elapsed time
        # (time.time can jump backwards on system-clock adjustments).
        time_start = time.perf_counter()

        # Create the pipeline
        functions = self._get_pipeline_functions()
        self._pipeline = pipeline.LinearTransformerPipeline(
            self.pipeline_config, functions
        )

        # Assembling the pipeline
        self.logger.info("Assembling the pipeline...")
        self._pipeline.assemble()

        # Compiling the expressions
        self.logger.info("Compiling the expressions...")
        self.func = self._pipeline.compile_expression()

        # Running the pipeline
        self.logger.info("Running the pipeline on the input data...")
        output = self.func(self.data)
        # JAX dispatches work asynchronously; wait for completion so the
        # timing below reflects the actual computation, not just dispatch.
        jax.block_until_ready(output)

        time_end = time.perf_counter()
        self.logger.info(
            "Pipeline run completed in %.2f seconds.", time_end - time_start
        )
        return output

    # TODO: implement gradient calculation
    def gradient(self):
        """
        This function will calculate the gradient of the pipeline, but is yet not implemented.

        Raises:
            NotImplementedError: Always; gradient support is planned but not
                yet available.
        """
        raise NotImplementedError("Gradient calculation is not implemented yet")