Gradient-based optimization (Adam)
# NBVAL_SKIP
from jax import config
#config.update("jax_enable_x64", True)
#config.update('jax_num_cpu_devices', 2)
#NBVAL_SKIP
import os
# Optional device setup. These variables generally only take effect if they are
# set before JAX initializes its backends (i.e. before the first jax.devices()
# call or computation in this session).
# Uncomment to make XLA fake 3 host CPU devices (e.g. to test sharding on a CPU-only machine):
#os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=3'
# Uncomment to make only GPUs 1 and 2 visible to JAX:
#os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'
# Uncomment to stop JAX from preallocating most of the GPU memory up front:
#os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import jax
# List the devices JAX sees. With the lines above left commented out, a
# CPU-only machine reports a single device, e.g. [CpuDevice(id=0)].
print(jax.devices())
# With the XLA_FLAGS line enabled this would instead list three CPU devices.
[CpuDevice(id=0)]
# NBVAL_SKIP
import os

# SPS_HOME tells FSPS / python-fsps where its data directory lives (see the
# python-fsps installation notes referenced in the warning below).
# Respect a value that is already exported in the environment and only fall
# back to this machine-specific default when it is missing.
# Paths used on other machines:
#   /mnt/storage/annalena_data/sps_fsps
#   /home/annalena/sps_fsps
#   /export/home/aschaibl/fsps
os.environ.setdefault('SPS_HOME', '/Users/annalena/Documents/GitHub/fsps')
Load the SSP template from FSPS
# NBVAL_SKIP
# Load the FSPS single-stellar-population (SSP) template through rubix's
# template factory. If python-fsps is not installed, rubix logs a warning
# (see the output below); the template object is still used by later cells.
from rubix.spectra.ssp.factory import get_ssp_template
ssp_fsps = get_ssp_template("FSPS")
2025-11-11 10:14:51,323 - rubix - INFO -
___ __ _____ _____ __
/ _ \/ / / / _ )/ _/ |/_/
/ , _/ /_/ / _ |/ /_> <
/_/|_|\____/____/___/_/|_|
2025-11-11 10:14:51,324 - rubix - INFO - Rubix version: 0.0.post507+g27646941c
2025-11-11 10:14:51,325 - rubix - INFO - JAX version: 0.4.38
2025-11-11 10:14:51,325 - rubix - INFO - Running on [CpuDevice(id=0)] devices
2025-11-11 10:14:51,326 - rubix - WARNING - python-fsps is not installed. Please install it to use this function. Install using pip install fsps and check the installation page: https://dfm.io/python-fsps/current/installation/ for more details. Especially, make sure to set all necessary environment variables.
# NBVAL_SKIP
# Age and metallicity grids of the SSP template; later cells pick grid
# points from them by index.
age_values = ssp_fsps.age
metallicity_values = ssp_fsps.metallicity
# One shape per line: age grid first, then metallicity grid.
print(age_values.shape, metallicity_values.shape, sep="\n")
(107,)
(12,)
# NBVAL_SKIP
# Grid indices (into age_values / metallicity_values) of the target spectrum
# that the optimizer should recover.
index_age, index_metallicity = 90, 9

# Starting grid indices for the three optimization runs below.
# (An earlier alternative start for run 1 used age index 70, metallicity index 5.)
initial_age_index, initial_metallicity_index = 104, 10
initial_age_index2, initial_metallicity_index2 = 90, 6
initial_age_index3, initial_metallicity_index3 = 99, 11

# Adam learning rate shared by all runs, and the tolerance on the global norm
# of an Adam update used as the convergence criterion.
learning_all = 5e-3
tol = 1e-10

print(f"start age: {age_values[initial_age_index]}, start metallicity: {metallicity_values[initial_metallicity_index]}")
print(f"target age: {age_values[index_age]}, target metallicity: {metallicity_values[index_metallicity]}")
start age: 15.848933219909668, start metallicity: 0.025251565501093864
target age: 3.1622776985168457, target metallicity: 0.014199999161064625
Configure the pipeline
# NBVAL_SKIP
from rubix.core.pipeline import RubixPipeline
import os
# Pipeline configuration used by every RubixPipeline built in this notebook.
# NOTE: this rebinds `config`, shadowing `from jax import config` from the first
# cell; import the JAX config object under another name if it is needed later.
config = {
"pipeline":{"name": "calc_gradient",},
"logger": {
"log_level": "DEBUG",
"log_file_path": None,
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
},
# Galaxy 14 of TNG50-1, snapshot 99, via the Illustris API. The API key is read
# from $ILLUSTRIS_API_KEY; with "reuse": True an existing converted file is
# presumably reused (the log reports "Rubix galaxy file already exists").
"data": {
"name": "IllustrisAPI",
"args": {
"api_key": os.environ.get("ILLUSTRIS_API_KEY"),
"particle_type": ["stars"],
"simulation": "TNG50-1",
"snapshot": 99,
"save_data_path": "data",
},
"load_galaxy_args": {
"id": 14,
"reuse": True,
},
# Keep only 2 star particles; their properties are overwritten in later cells.
"subset": {
"use_subset": True,
"subset_size": 2,
},
},
"simulation": {
"name": "IllustrisTNG",
"args": {
"path": "data/galaxy-id-14.hdf5",
},
},
"output_path": "output",
# No custom telescope config is supplied, so rubix falls back to its bundled
# telescopes.yaml for "TESTGRADIENT" (see the UserWarning in the log).
# Per the pipeline log, PSF convolution, LSF convolution and noise
# (S/N 100, normal distribution) are applied in that order.
"telescope":
{"name": "TESTGRADIENT",
"psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
"lsf": {"sigma": 1.2},
"noise": {"signal_to_noise": 100,"noise_distribution": "normal"},
},
"cosmology":
{"name": "PLANCK15"},
# dist_z is presumably the galaxy redshift; the galaxy is rotated edge-on
# (logged as alpha=90, beta=0, gamma=0).
"galaxy":
{"dist_z": 0.1,
"rotation": {"type": "edge-on"},
},
# FSPS SSP templates plus Cardelli89 dust-extinction parameters.
"ssp": {
"template": {
"name": "FSPS"
},
"dust": {
"extinction_model": "Cardelli89",
"dust_to_gas_ratio": 0.01,
"dust_to_metals_ratio": 0.4,
"dust_grain_density": 3.5,
"Rv": 3.1,
},
},
}
# NBVAL_SKIP
# Build the pipeline, load the (2-particle subset of the) galaxy data and run it
# once end to end. `inputdata` is reused below as a template whose star
# properties are overwritten for the target and initial datacubes.
pipe = RubixPipeline(config)
inputdata = pipe.prepare_data()
output = pipe.run_sharded(inputdata)
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:14:52,007 - rubix - INFO - Getting rubix data...
2025-11-11 10:14:52,008 - rubix - INFO - Rubix galaxy file already exists, skipping conversion
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:188: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
return asarray(x, dtype=self.dtype)
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/core/data.py:537: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
rubixdata.galaxy.center = jnp.array(data["subhalo_center"], dtype=jnp.float64)
2025-11-11 10:14:52,081 - rubix - INFO - Centering stars particles
2025-11-11 10:14:53,068 - rubix - WARNING - The Subset value is set in config. Using only subset of size 2 for stars
2025-11-11 10:14:53,071 - rubix - INFO - Data loaded with 2 star particles and 0 gas particles.
2025-11-11 10:14:53,072 - rubix - INFO - Setting up the pipeline...
2025-11-11 10:14:53,073 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-11-11 10:14:53,074 - rubix - DEBUG - Roataion Type found: edge-on
2025-11-11 10:14:53,078 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:14:53,102 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:14:53,278 - rubix - INFO - Calculating spatial bin edges...
2025-11-11 10:14:53,288 - rubix - INFO - Getting cosmology...
2025-11-11 10:14:53,341 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:14:53,447 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:14:53,465 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:14:53,533 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:14:53,742 - rubix - INFO - Assembling the pipeline...
2025-11-11 10:14:53,743 - rubix - INFO - Compiling the expressions...
2025-11-11 10:14:53,744 - rubix - INFO - Number of devices: 1
2025-11-11 10:14:53,813 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-11-11 10:14:53,814 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG
2025-11-11 10:14:53,814 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.
2025-11-11 10:14:53,904 - rubix - INFO - Assigning particles to spaxels...
2025-11-11 10:14:53,921 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-11-11 10:14:54,059 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-11-11 10:14:54,060 - rubix - INFO - Convolving with PSF...
2025-11-11 10:14:54,063 - rubix - INFO - Convolving with LSF...
2025-11-11 10:14:54,068 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-11-11 10:14:57,373 - rubix - INFO - Pipeline run completed in 4.30 seconds.
Set target values
# NBVAL_SKIP
import jax.numpy as jnp

# Two identical star particles: both at the target age and metallicity from
# the SSP grid, with unit mass, at rest and placed at the origin.
inputdata.stars.age = jnp.array([age_values[index_age]] * 2)
inputdata.stars.metallicity = jnp.array([metallicity_values[index_metallicity]] * 2)
inputdata.stars.mass = jnp.array([[1.0]] * 2)
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0]] * 2)
inputdata.stars.coords = jnp.array([[0.0, 0.0, 0.0]] * 2)
# NBVAL_SKIP
# Datacube produced by the target star properties; the loss compares model
# cubes against it. Logged shape is (1, 1, 466) — presumably (x, y, wavelength),
# TODO confirm axis order.
targetdata = pipe.run_sharded(inputdata)
2025-11-11 10:14:57,494 - rubix - INFO - Setting up the pipeline...
2025-11-11 10:14:57,495 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-11-11 10:14:57,496 - rubix - DEBUG - Roataion Type found: edge-on
2025-11-11 10:14:57,498 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:14:57,511 - rubix - INFO - Getting cosmology...
2025-11-11 10:14:57,523 - rubix - INFO - Calculating spatial bin edges...
2025-11-11 10:14:57,533 - rubix - INFO - Getting cosmology...
2025-11-11 10:14:57,577 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:14:57,617 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:14:57,630 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:14:57,684 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:14:57,731 - rubix - INFO - Assembling the pipeline...
2025-11-11 10:14:57,732 - rubix - INFO - Compiling the expressions...
2025-11-11 10:14:57,733 - rubix - INFO - Number of devices: 1
2025-11-11 10:14:57,809 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-11-11 10:14:57,810 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG
2025-11-11 10:14:57,810 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.
2025-11-11 10:14:57,871 - rubix - INFO - Assigning particles to spaxels...
2025-11-11 10:14:57,873 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-11-11 10:14:57,884 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-11-11 10:14:57,884 - rubix - INFO - Convolving with PSF...
2025-11-11 10:14:57,887 - rubix - INFO - Convolving with LSF...
2025-11-11 10:14:57,889 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-11-11 10:15:00,996 - rubix - INFO - Pipeline run completed in 3.50 seconds.
# NBVAL_SKIP
# Shape of the spectrum stored in spaxel (0, 0) of the target cube
# (the cube is 1x1 spatially in this setup).
print(targetdata[0,0,:].shape)
(466,)
Set the initial datacube
# NBVAL_SKIP
# Same two-particle setup as the target, but at the first run's starting
# age and metallicity grid points.
inputdata.stars.age = jnp.array([age_values[initial_age_index]] * 2)
inputdata.stars.metallicity = jnp.array([metallicity_values[initial_metallicity_index]] * 2)
inputdata.stars.mass = jnp.array([[1.0]] * 2)
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0]] * 2)
inputdata.stars.coords = jnp.array([[0.0, 0.0, 0.0]] * 2)
# NBVAL_SKIP
# Datacube at the first optimization run's starting age and metallicity.
initialdata = pipe.run_sharded(inputdata)
2025-11-11 10:15:01,085 - rubix - INFO - Setting up the pipeline...
2025-11-11 10:15:01,086 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-11-11 10:15:01,086 - rubix - DEBUG - Roataion Type found: edge-on
2025-11-11 10:15:01,088 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:15:01,100 - rubix - INFO - Getting cosmology...
2025-11-11 10:15:01,110 - rubix - INFO - Calculating spatial bin edges...
2025-11-11 10:15:01,119 - rubix - INFO - Getting cosmology...
2025-11-11 10:15:01,152 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:15:01,183 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:15:01,195 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:15:01,258 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:15:01,317 - rubix - INFO - Assembling the pipeline...
2025-11-11 10:15:01,318 - rubix - INFO - Compiling the expressions...
2025-11-11 10:15:01,319 - rubix - INFO - Number of devices: 1
2025-11-11 10:15:01,395 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-11-11 10:15:01,396 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG
2025-11-11 10:15:01,397 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.
2025-11-11 10:15:01,456 - rubix - INFO - Assigning particles to spaxels...
2025-11-11 10:15:01,458 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-11-11 10:15:01,468 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-11-11 10:15:01,469 - rubix - INFO - Convolving with PSF...
2025-11-11 10:15:01,471 - rubix - INFO - Convolving with LSF...
2025-11-11 10:15:01,474 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-11-11 10:15:04,414 - rubix - INFO - Pipeline run completed in 3.33 seconds.
Adam optimizer
# NBVAL_SKIP
from rubix.pipeline import linear_pipeline as pipeline
# Build a second pipeline and expose its compiled function as
# `pipeline_instance.func`, so the loss below can call it directly on a data
# object and differentiate through it with jax.value_and_grad.
# NOTE(review): relies on RubixPipeline private members (_pipeline,
# _get_pipeline_functions); this may break if rubix changes its internals.
pipeline_instance = RubixPipeline(config)
pipeline_instance._pipeline = pipeline.LinearTransformerPipeline(
pipeline_instance.pipeline_config,
pipeline_instance._get_pipeline_functions()
)
pipeline_instance._pipeline.assemble()
pipeline_instance.func = pipeline_instance._pipeline.compile_expression()
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:15:04,493 - rubix - INFO - Setting up the pipeline...
2025-11-11 10:15:04,493 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-11-11 10:15:04,494 - rubix - DEBUG - Roataion Type found: edge-on
2025-11-11 10:15:04,496 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:15:04,505 - rubix - INFO - Getting cosmology...
2025-11-11 10:15:04,516 - rubix - INFO - Calculating spatial bin edges...
2025-11-11 10:15:04,525 - rubix - INFO - Getting cosmology...
2025-11-11 10:15:04,569 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:15:04,602 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:15:04,615 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:15:04,654 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
# NBVAL_SKIP
import optax
def loss_only_wrt_age_metallicity(age, metallicity, base_data, target):
    """Return log10 of the summed cosine distance between a model cube and ``target``.

    ``age`` and ``metallicity`` are *normalized* parameters: they are multiplied
    by 20 and 0.05 respectively before being written into ``base_data.stars``
    (the optimizer is initialized with physical value / 20 and / 0.05).

    Args:
        age: normalized stellar ages (physical age / 20), one per particle.
        metallicity: normalized metallicities (physical metallicity / 0.05).
        base_data: pipeline input data. Only ``stars.age`` and
            ``stars.metallicity`` are overridden for this evaluation; both are
            restored before returning, so the caller's object is left as passed.
        target: target datacube to compare the model datacube against.

    Returns:
        Scalar ``log10(sum(cosine_distance(model_datacube, target)))``.
    """
    original_age = base_data.stars.age
    original_metallicity = base_data.stars.metallicity
    base_data.stars.age = age * 20
    base_data.stars.metallicity = metallicity * 0.05
    try:
        output = pipeline_instance.func(base_data)
    finally:
        # Restore the caller's values: under jax.value_and_grad the assigned
        # values are tracers, which must not leak out of this function, and
        # callers should not find their data silently rescaled.
        base_data.stars.age = original_age
        base_data.stars.metallicity = original_metallicity

    loss = jnp.sum(optax.cosine_distance(output.stars.datacube, target))
    # Cosine distance can round to 0 or slightly below for (near-)identical
    # spectra; log10 of that is -inf/NaN and would poison Adam's state.
    loss = jnp.maximum(loss, jnp.finfo(loss.dtype).tiny)
    return jnp.log10(loss)
# NBVAL_SKIP
import jax
import jax.numpy as jnp
import optax
def adam_optimization_multi(loss_fn, params_init, data, target, learning=learning_all, tol=1e-3, max_iter=500):
    """
    Optimize age and metallicity jointly with Adam.

    Args:
        loss_fn: function with signature ``loss_fn(age, metallicity, data, target)``
            returning a scalar loss.
        params_init: dict with keys ``'age'`` and ``'metallicity'``, each a JAX array.
        data: base data passed through to ``loss_fn``.
        target: target passed through to ``loss_fn``.
        learning: Adam learning rate, used for both parameters.
        tol: stop once the global norm of an Adam update falls below this value.
        max_iter: maximum number of iterations.

    Returns:
        params: final parameters (dict with ``'age'`` and ``'metallicity'``).
        age_history: ``params['age'][0]`` at every iteration, before its update.
        metallicity_history: ``params['metallicity'][0]`` at every iteration,
            before its update.
        loss_history: loss value at every iteration.

    If the loss or any gradient becomes non-finite (NaN/inf), the loop stops
    before applying that update, so ``params`` holds the last evaluated
    parameters instead of being overwritten with NaN (a NaN update norm would
    otherwise never satisfy ``< tol`` and the loop would run to ``max_iter``).
    """
    params = params_init  # e.g. {'age': jnp.array(...), 'metallicity': jnp.array(...)}

    # One Adam per parameter group; multi_transform keeps the groups separately
    # configurable even though both currently share the same learning rate.
    optimizers = {
        'age': optax.adam(learning),
        'metallicity': optax.adam(learning),
    }
    # Label pytree matching the structure of params.
    param_labels = {'age': 'age', 'metallicity': 'metallicity'}
    optimizer = optax.multi_transform(optimizers, param_labels)
    optimizer_state = optimizer.init(params)

    # data and target are fixed for the whole run, so build the gradient
    # function once instead of re-wrapping a lambda every iteration.
    value_and_grad_fn = jax.value_and_grad(
        lambda p: loss_fn(p['age'], p['metallicity'], data, target)
    )

    age_history = []
    metallicity_history = []
    loss_history = []
    for i in range(max_iter):
        loss, grads = value_and_grad_fn(params)
        loss_history.append(float(loss))
        # Record the parameters that produced this loss (first particle only).
        age_history.append(float(params['age'][0]))
        metallicity_history.append(float(params['metallicity'][0]))

        finite = bool(jnp.isfinite(loss)) and all(
            bool(jnp.all(jnp.isfinite(g))) for g in jax.tree_util.tree_leaves(grads)
        )
        if not finite:
            print(f"Non-finite loss or gradient at iteration {i}; stopping early")
            break

        updates, optimizer_state = optimizer.update(grads, optimizer_state)
        params = optax.apply_updates(params, updates)
        # Physical bounds could be enforced here, e.g.
        # params['age'] = jnp.clip(params['age'], 0.0, 1.0)

        # Converged once Adam's step becomes negligible.
        if optax.global_norm(updates) < tol:
            print(f"Converged at iteration {i}")
            break
    return params, age_history, metallicity_history, loss_history
# NBVAL_SKIP
# loss_only_wrt_age_metallicity expects *normalized* parameters (it multiplies
# age by 20 and metallicity by 0.05), so convert the physical values stored in
# inputdata first. Passing them unscaled evaluates ages ~20x too large, and
# that call returned NaN in the recorded output.
loss_only_wrt_age_metallicity(inputdata.stars.age / 20, inputdata.stars.metallicity / 0.05, inputdata, targetdata)
2025-11-11 10:15:04,803 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-11-11 10:15:04,804 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG
2025-11-11 10:15:04,804 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.
2025-11-11 10:15:04,888 - rubix - INFO - Assigning particles to spaxels...
2025-11-11 10:15:04,903 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-11-11 10:15:05,278 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-11-11 10:15:05,279 - rubix - INFO - Convolving with PSF...
2025-11-11 10:15:05,282 - rubix - INFO - Convolving with LSF...
2025-11-11 10:15:05,287 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
Array(nan, dtype=float32)
# NBVAL_SKIP
# Run 1: start from grid point (initial_age_index, initial_metallicity_index).
data = inputdata  # Replace with your actual data if needed
target_value = targetdata  # Replace with your actual target

# Starting guesses in the optimizer's normalized units (age / 20,
# metallicity / 0.05 -- the loss multiplies them back), one per star particle.
age_init = jnp.array([age_values[initial_age_index] / 20] * 2)
metallicity_init = jnp.array([metallicity_values[initial_metallicity_index] / 0.05] * 2)
params_init = dict(age=age_init, metallicity=metallicity_init)
print(f"Initial parameters: {params_init}")

optimized_params, age_history, metallicity_history, loss_history = adam_optimization_multi(
    loss_only_wrt_age_metallicity,
    params_init,
    data,
    target_value,
    learning=learning_all,
    tol=tol,
    max_iter=5000,
)
print(f"Optimized Age: {optimized_params['age']}")
print(f"Optimized Metallicity: {optimized_params['metallicity']}")
Initial parameters: {'age': Array([0.7924467, 0.7924467], dtype=float32), 'metallicity': Array([0.5050313, 0.5050313], dtype=float32)}
Optimized Age: [nan nan]
Optimized Metallicity: [nan nan]
# NBVAL_SKIP
# Run 2: freshly prepared input data, started from grid point
# (initial_age_index2, initial_metallicity_index2).
inputdata2 = pipe.prepare_data()
inputdata2.stars.age = jnp.array([age_values[initial_age_index2]] * 2)
inputdata2.stars.metallicity = jnp.array([metallicity_values[initial_metallicity_index2]] * 2)
inputdata2.stars.mass = jnp.array([[1.0]] * 2)
inputdata2.stars.velocity = jnp.array([[0.0, 0.0, 0.0]] * 2)
inputdata2.stars.coords = jnp.array([[0.0, 0.0, 0.0]] * 2)
# Datacube at this run's starting point.
initialdata2 = pipe.run_sharded(inputdata2)

data2 = inputdata2  # Replace with your actual data if needed
target_value = targetdata  # Replace with your actual target

# Normalized starting guesses (age / 20, metallicity / 0.05), one per particle.
age_init2 = jnp.array([age_values[initial_age_index2] / 20] * 2)
metallicity_init2 = jnp.array([metallicity_values[initial_metallicity_index2] / 0.05] * 2)
params_init2 = dict(age=age_init2, metallicity=metallicity_init2)
print(f"Initial parameters: {params_init2}")

optimized_params2, age_history2, metallicity_history2, loss_history2 = adam_optimization_multi(
    loss_only_wrt_age_metallicity,
    params_init2,
    data2,
    target_value,
    learning=learning_all,
    tol=tol,
    max_iter=5000,
)
2025-11-11 10:17:53,403 - rubix - INFO - Getting rubix data...
2025-11-11 10:17:53,405 - rubix - INFO - Rubix galaxy file already exists, skipping conversion
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:188: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
return asarray(x, dtype=self.dtype)
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/core/data.py:537: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
rubixdata.galaxy.center = jnp.array(data["subhalo_center"], dtype=jnp.float64)
2025-11-11 10:17:53,446 - rubix - INFO - Centering stars particles
2025-11-11 10:17:53,785 - rubix - WARNING - The Subset value is set in config. Using only subset of size 2 for stars
2025-11-11 10:17:53,786 - rubix - INFO - Data loaded with 2 star particles and 0 gas particles.
2025-11-11 10:17:53,801 - rubix - INFO - Setting up the pipeline...
2025-11-11 10:17:53,802 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-11-11 10:17:53,803 - rubix - DEBUG - Roataion Type found: edge-on
2025-11-11 10:17:53,806 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:17:53,840 - rubix - INFO - Getting cosmology...
2025-11-11 10:17:53,855 - rubix - INFO - Calculating spatial bin edges...
2025-11-11 10:17:53,865 - rubix - INFO - Getting cosmology...
2025-11-11 10:17:53,906 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:17:53,960 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:17:53,977 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:17:54,039 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:17:54,094 - rubix - INFO - Assembling the pipeline...
2025-11-11 10:17:54,095 - rubix - INFO - Compiling the expressions...
2025-11-11 10:17:54,096 - rubix - INFO - Number of devices: 1
2025-11-11 10:17:54,177 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-11-11 10:17:54,177 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG
2025-11-11 10:17:54,178 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.
2025-11-11 10:17:54,236 - rubix - INFO - Assigning particles to spaxels...
2025-11-11 10:17:54,238 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-11-11 10:17:54,251 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-11-11 10:17:54,251 - rubix - INFO - Convolving with PSF...
2025-11-11 10:17:54,253 - rubix - INFO - Convolving with LSF...
2025-11-11 10:17:54,256 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-11-11 10:17:57,635 - rubix - INFO - Pipeline run completed in 3.83 seconds.
Initial parameters: {'age': Array([0.15811388, 0.15811388], dtype=float32), 'metallicity': Array([0.05050315, 0.05050315], dtype=float32)}
#NBVAL_SKIP
# Third Adam run: two identical star particles at the origin, started from the
# SSP grid point (initial_age_index3, initial_metallicity_index3).
inputdata3 = pipe.prepare_data()
start_age3 = age_values[initial_age_index3]
start_metallicity3 = metallicity_values[initial_metallicity_index3]

inputdata3.stars.age = jnp.array([start_age3] * 2)
inputdata3.stars.metallicity = jnp.array([start_metallicity3] * 2)
inputdata3.stars.mass = jnp.array([[1.0], [1.0]])
inputdata3.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
inputdata3.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
initialdata3 = pipe.run_sharded(inputdata3)

data3 = inputdata3  # Replace with your actual data if needed
target_value = targetdata  # Replace with your actual target

# The optimiser works in normalised units: age / 20 and metallicity / 0.05.
params_init3 = {
    'age': jnp.array([start_age3 / 20] * 2),
    'metallicity': jnp.array([start_metallicity3 / 0.05] * 2),
}
print(f"Initial parameters: {params_init3}")

optimized_params3, age_history3, metallicity_history3, loss_history3 = adam_optimization_multi(
    loss_only_wrt_age_metallicity,
    params_init3,
    data3,
    target_value,
    learning=learning_all,
    tol=tol,
    max_iter=5000,
)
2025-11-11 10:20:28,016 - rubix - INFO - Getting rubix data...
2025-11-11 10:20:28,019 - rubix - INFO - Rubix galaxy file already exists, skipping conversion
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:188: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
return asarray(x, dtype=self.dtype)
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/core/data.py:537: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
rubixdata.galaxy.center = jnp.array(data["subhalo_center"], dtype=jnp.float64)
2025-11-11 10:20:28,056 - rubix - INFO - Centering stars particles
2025-11-11 10:20:28,372 - rubix - WARNING - The Subset value is set in config. Using only subset of size 2 for stars
2025-11-11 10:20:28,373 - rubix - INFO - Data loaded with 2 star particles and 0 gas particles.
2025-11-11 10:20:28,386 - rubix - INFO - Setting up the pipeline...
2025-11-11 10:20:28,387 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-11-11 10:20:28,388 - rubix - DEBUG - Roataion Type found: edge-on
2025-11-11 10:20:28,390 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:20:28,424 - rubix - INFO - Getting cosmology...
2025-11-11 10:20:28,440 - rubix - INFO - Calculating spatial bin edges...
2025-11-11 10:20:28,450 - rubix - INFO - Getting cosmology...
2025-11-11 10:20:28,537 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:20:28,594 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:20:28,606 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:20:28,674 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:20:28,731 - rubix - INFO - Assembling the pipeline...
2025-11-11 10:20:28,733 - rubix - INFO - Compiling the expressions...
2025-11-11 10:20:28,734 - rubix - INFO - Number of devices: 1
2025-11-11 10:20:28,815 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-11-11 10:20:28,816 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG
2025-11-11 10:20:28,817 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.
2025-11-11 10:20:28,875 - rubix - INFO - Assigning particles to spaxels...
2025-11-11 10:20:28,877 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-11-11 10:20:28,889 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-11-11 10:20:28,890 - rubix - INFO - Convolving with PSF...
2025-11-11 10:20:28,892 - rubix - INFO - Convolving with LSF...
2025-11-11 10:20:28,894 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-11-11 10:20:32,060 - rubix - INFO - Pipeline run completed in 3.67 seconds.
Initial parameters: {'age': Array([0.44562545, 0.44562545], dtype=float32), 'metallicity': Array([0.8980868, 0.8980868], dtype=float32)}
Loss history#
# NBVAL_SKIP
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# Optional: LaTeX text rendering (e.g. to match a paper's typography):
# mpl.rcParams.update({
#     "text.usetex": True,
#     "font.family": "serif",
#     "font.size": 16,
#     "text.latex.preamble": r"\usepackage{txfonts}",
# })

# Optimisation histories of the first Adam run as NumPy arrays.
loss_history_np = np.array(loss_history)
age_history_np = np.array(age_history)
metallicity_history_np = np.array(metallicity_history)

# One x value per recorded iteration (all three histories share this length).
iterations = np.arange(len(loss_history_np))
print(f"Number of iterations: {len(iterations)}")

# Three side-by-side panels sharing the iteration axis.
fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharex=True)

# Loss: converted back with 10**, i.e. the history is assumed to hold log10(loss).
axs[0].plot(iterations, 10**loss_history_np, marker='o', linestyle='-')
axs[0].set_xlabel("Iteration")
axs[0].set_ylabel("Loss")
axs[0].set_title("Loss History")
axs[0].grid(True)

# Age: optimised in normalised units; x20 gives the physical value.
# axhline spans the whole axis, so the target line does not depend on
# iterations[-1] (which would fail for an empty history).
axs[1].plot(iterations, age_history_np * 20, marker='o', linestyle='-')
axs[1].axhline(y=age_values[index_age], color='r', linestyle='-')
axs[1].set_xlabel("Iteration")
axs[1].set_ylabel("Age")
axs[1].set_title("Age History")
axs[1].grid(True)

# Metallicity: optimised in normalised units; x0.05 gives the physical value.
axs[2].plot(iterations, metallicity_history_np * 0.05, marker='o', linestyle='-')
axs[2].axhline(y=metallicity_values[index_metallicity], color='r', linestyle='-')
axs[2].set_xlabel("Iteration")
axs[2].set_ylabel("Metallicity")
axs[2].set_title("Metallicity History")
axs[2].grid(True)

# Show only the first 900 iterations in every panel.
for ax in axs:
    ax.set_xlim(-5, 900)

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("output", exist_ok=True)
plt.savefig("output/optimisation_history.jpg", dpi=1000)
plt.show()
Number of iterations: 5000
# NBVAL_SKIP
# Forward-model the parameters recorded at step i of the first Adam run
# (i = 0: the starting point) and overlay the result on the target spectrum.
import matplotlib.pyplot as plt

i = 0
# The histories are still normalised at this point; rescale to physical units.
age_at_step = age_history[i] * 20
metallicity_at_step = metallicity_history[i] * 0.05

# Two identical star particles at rest with mass 1.0.
inputdata.stars.age = jnp.array([age_at_step] * 2)
inputdata.stars.metallicity = jnp.array([metallicity_at_step] * 2)
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

pipe = RubixPipeline(config)
rubixdata = pipe.run_sharded(inputdata)

wave = pipe.telescope.wave_seq
spectra_target = targetdata
spectra_optimitzed = rubixdata  # (sic) name is reused by later cells
print(rubixdata.shape)

target_label = (
    f"Target age = {age_values[index_age]:.2f}, "
    f"metal. = {metallicity_values[index_metallicity]:.4f}"
)
optimized_label = (
    f"Optimized age = {age_at_step:.2f}, "
    f"metal. = {metallicity_at_step:.4f}"
)
plt.plot(wave, spectra_target[0, 0, :], label=target_label)
plt.plot(wave, spectra_optimitzed[0, 0, :], label=optimized_label)
plt.xlabel("Wavelength [Å]")
plt.ylabel("Luminosity [L/Å]")
plt.title("Difference between target and optimized spectra")
plt.legend()
plt.grid(True)
plt.show()
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:23:09,750 - rubix - INFO - Setting up the pipeline...
2025-11-11 10:23:09,751 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-11-11 10:23:09,752 - rubix - DEBUG - Roataion Type found: edge-on
2025-11-11 10:23:09,753 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:23:09,764 - rubix - INFO - Getting cosmology...
2025-11-11 10:23:09,779 - rubix - INFO - Calculating spatial bin edges...
2025-11-11 10:23:09,789 - rubix - INFO - Getting cosmology...
2025-11-11 10:23:09,842 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:23:09,890 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:23:09,902 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:23:09,954 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:23:10,022 - rubix - INFO - Assembling the pipeline...
2025-11-11 10:23:10,023 - rubix - INFO - Compiling the expressions...
2025-11-11 10:23:10,024 - rubix - INFO - Number of devices: 1
2025-11-11 10:23:10,108 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-11-11 10:23:10,109 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG
2025-11-11 10:23:10,109 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.
2025-11-11 10:23:10,168 - rubix - INFO - Assigning particles to spaxels...
2025-11-11 10:23:10,170 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-11-11 10:23:10,182 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-11-11 10:23:10,183 - rubix - INFO - Convolving with PSF...
2025-11-11 10:23:10,185 - rubix - INFO - Convolving with LSF...
2025-11-11 10:23:10,188 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-11-11 10:23:13,509 - rubix - INFO - Pipeline run completed in 3.76 seconds.
(1, 1, 466)
# NBVAL_SKIP
# Forward-model the parameters recorded at step i = 850 of the first Adam run
# and overlay the resulting spectrum on the target spectrum.
import matplotlib.pyplot as plt

i = 850
# The histories are still normalised at this point; rescale to physical units.
age_at_step = age_history[i] * 20
metallicity_at_step = metallicity_history[i] * 0.05

# Two identical star particles at rest with mass 1.0.
inputdata.stars.age = jnp.array([age_at_step] * 2)
inputdata.stars.metallicity = jnp.array([metallicity_at_step] * 2)
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

pipe = RubixPipeline(config)
rubixdata = pipe.run_sharded(inputdata)

wave = pipe.telescope.wave_seq
# Both are plain datacubes returned by the pipeline (not .stars.datacube).
spectra_target = targetdata
spectra_optimitzed = rubixdata  # (sic) name is reused by later cells

target_label = (
    f"Target age = {age_values[index_age]:.2f}, "
    f"metal. = {metallicity_values[index_metallicity]:.4f}"
)
optimized_label = (
    f"Optimized age = {age_at_step:.2f}, "
    f"metal. = {metallicity_at_step:.4f}"
)
plt.plot(wave, spectra_target[0, 0, :], label=target_label)
plt.plot(wave, spectra_optimitzed[0, 0, :], label=optimized_label)
plt.xlabel("Wavelength [Å]")
plt.ylabel("Luminosity [L/Å]")
plt.title("Difference between target and optimized spectra")
plt.legend()
plt.grid(True)
plt.show()
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:23:13,726 - rubix - INFO - Setting up the pipeline...
2025-11-11 10:23:13,728 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-11-11 10:23:13,729 - rubix - DEBUG - Roataion Type found: edge-on
2025-11-11 10:23:13,730 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:23:13,740 - rubix - INFO - Getting cosmology...
2025-11-11 10:23:13,750 - rubix - INFO - Calculating spatial bin edges...
2025-11-11 10:23:13,759 - rubix - INFO - Getting cosmology...
2025-11-11 10:23:13,789 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:23:13,848 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:23:13,866 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:23:13,916 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:23:13,982 - rubix - INFO - Assembling the pipeline...
2025-11-11 10:23:13,982 - rubix - INFO - Compiling the expressions...
2025-11-11 10:23:13,983 - rubix - INFO - Number of devices: 1
2025-11-11 10:23:14,061 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-11-11 10:23:14,062 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG
2025-11-11 10:23:14,062 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.
2025-11-11 10:23:14,119 - rubix - INFO - Assigning particles to spaxels...
2025-11-11 10:23:14,121 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-11-11 10:23:14,134 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-11-11 10:23:14,135 - rubix - INFO - Convolving with PSF...
2025-11-11 10:23:14,137 - rubix - INFO - Convolving with LSF...
2025-11-11 10:23:14,139 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-11-11 10:23:17,643 - rubix - INFO - Pipeline run completed in 3.92 seconds.
# NBVAL_SKIP
import matplotlib as mpl
import matplotlib.pyplot as plt

# Two stacked panels sharing the wavelength axis: spectra on top (4/5 of the
# height) and the residual target - optimized underneath (1/5).
fig, (ax1, ax2) = plt.subplots(
    nrows=2,
    sharex=True,
    gridspec_kw={'height_ratios': [4, 1]},
    figsize=(7, 5),
)

target_spectrum = spectra_target[0, 0, :]
optimized_spectrum = spectra_optimitzed[0, 0, :]

ax1.plot(
    wave,
    target_spectrum,
    label=f"Target age = {age_values[index_age]:.2f}, metallicity = {metallicity_values[index_metallicity]:.4f}",
)
ax1.plot(
    wave,
    optimized_spectrum,
    label=f"Optimized age = {age_history[i]*20:.2f}, metallicity = {metallicity_history[i]*0.05:.4f}",
)
ax1.set_ylabel("Luminosity [L/Å]")
ax1.legend()
ax1.grid(True)

# Absolute (not relative) residual.
residual = target_spectrum - optimized_spectrum
ax2.plot(wave, residual, 'k-')
ax2.set_xlabel("Wavelength [Å]")
ax2.set_ylabel("Residual")
ax2.grid(True)

plt.tight_layout()
plt.savefig("output/optimisation_spectra.jpg", dpi=1000)
plt.show()
Calculate loss landscape#
# NBVAL_SKIP
import optax


def loss_only_wrt_age_metallicity(age, metallicity, base_data, target,
                                  age_scale=20, metallicity_scale=0.05):
    """Return the cosine-distance loss of the forward model at one (age, metallicity).

    Args:
        age: Normalised stellar age; multiplied by ``age_scale`` before it is
            written into ``base_data``.
        metallicity: Normalised metallicity; multiplied by ``metallicity_scale``
            before it is written into ``base_data``.
        base_data: Pipeline input data (e.g. ``inputdata``). Its
            ``stars.age`` and ``stars.metallicity`` are OVERWRITTEN IN PLACE
            with two identical entries, so the caller's object is modified.
        target: Target datacube the pipeline output is compared against.
        age_scale: Normalised-to-physical age factor (default 20, the factor
            used throughout this notebook).
        metallicity_scale: Normalised-to-physical metallicity factor
            (default 0.05).

    Returns:
        ``optax.cosine_distance(output.stars.datacube, target)`` (evaluated
        along the last axis), summed over all remaining entries.

    Relies on the global ``pipeline_instance`` defined elsewhere in the notebook.

    NOTE(review): this redefines the name passed to ``adam_optimization_multi``
    in the optimisation cells above; confirm the expected call signature before
    re-running those cells after this one.
    """
    physical_age = age * age_scale
    physical_metallicity = metallicity * metallicity_scale
    # Exactly two star particles, matching the two-particle setup of this notebook.
    base_data.stars.age = jnp.array([physical_age, physical_age])
    base_data.stars.metallicity = jnp.array([physical_metallicity, physical_metallicity])
    output = pipeline_instance.func(base_data)
    # Alternative tried earlier: squared-error loss jnp.sum((datacube - target) ** 2).
    return jnp.sum(optax.cosine_distance(output.stars.datacube, target))
# NBVAL_SKIP
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Resolution of the square grid (num_steps x num_steps loss evaluations).
num_steps = 100
# Despite their names, both axes are NORMALISED to [0, 1]:
# loss_only_wrt_age_metallicity multiplies age by 20 and metallicity by 0.05,
# so the grid covers ages 0-20 and metallicities 0-0.05 in physical units.
physical_ages = jnp.linspace(0, 1, num_steps)
physical_metals = jnp.linspace(0, 1, num_steps)

# Evaluate the loss at every grid point, one forward pass at a time.
# loss_map[a, m] is the loss at physical_ages[a], physical_metals[m]
# (rows = age, columns = metallicity).
# NOTE: the loss function writes into inputdata.stars, so afterwards inputdata
# holds the values of the last grid point.
loss_map = jnp.stack([
    jnp.stack([
        loss_only_wrt_age_metallicity(age, metal, inputdata, targetdata)
        for metal in physical_metals
    ])
    for age in physical_ages
])
# NBVAL_SKIP
# Loss landscape in normalised coordinates (both axes in [0, 1]) with the
# optimisation path of the first run, the target and the starting point.
import os

import matplotlib.pyplot as plt
import matplotlib.colors as colors

plt.figure(figsize=(5, 4))
# loss_map rows index age (y axis) and columns metallicity (x axis);
# the colour scale is logarithmic.
plt.imshow(loss_map, origin='lower', extent=[0, 1, 0, 1], aspect='auto', norm=colors.LogNorm())
plt.xlabel('Metallicity')
plt.ylabel('Age')
plt.title('Loss Landscape')
plt.colorbar(label='loss')

# Optimisation path of the first run (still in normalised units here).
plt.plot(metallicity_history, age_history)
# Target and starting point, normalised with the same factors (0.05 and 20).
plt.plot(metallicity_values[index_metallicity] / 0.05, age_values[index_age] / 20, 'ro', markersize=8)
plt.plot(metallicity_values[initial_metallicity_index] / 0.05, age_values[initial_age_index] / 20, 'ro', markersize=8)

# Distinct file name: the physical-unit landscape cell saves to
# output/optimisation_losslandscape.jpg and previously overwrote this figure.
os.makedirs("output", exist_ok=True)
plt.savefig("output/optimisation_losslandscape_normalized.jpg", dpi=1000)
plt.show()
# NBVAL_SKIP
# Convert the three optimisation histories from normalised to physical units
# (metallicity x 0.05, age x 20). The names are rebound in place, so this cell
# must run exactly once: running it again would rescale a second time.
metallicity_history, metallicity_history2, metallicity_history3 = (
    np.array(history) * 0.05
    for history in (metallicity_history, metallicity_history2, metallicity_history3)
)
age_history, age_history2, age_history3 = (
    np.array(history) * 20
    for history in (age_history, age_history2, age_history3)
)
# NBVAL_SKIP
import matplotlib.pyplot as plt
import matplotlib.colors as colors

# Loss landscape in physical units: metallicity 0-0.05 on x, age 0-20 on y,
# with a logarithmic colour scale.
plt.figure(figsize=(6, 5))
plt.imshow(
    loss_map,
    origin='lower',
    extent=[0, 0.05, 0, 20],
    aspect='auto',
    norm=colors.LogNorm(),
)
plt.xlabel('Metallicity')
plt.ylabel('Age')
plt.title('Loss Landscape')
plt.colorbar(label='loss')

# Optimisation paths of the three runs (histories already in physical units).
plt.plot(metallicity_history, age_history)
plt.plot(metallicity_history2, age_history2)
plt.plot(metallicity_history3, age_history3)

# Target as a yellow star, the three starting points as white circles.
plt.plot(
    metallicity_values[index_metallicity], age_values[index_age],
    marker='*', color='yellow', markersize=8,
)
plt.plot(metallicity_values[initial_metallicity_index], age_values[initial_age_index], 'wo', markersize=8)
plt.plot(metallicity_values[initial_metallicity_index2], age_values[initial_age_index2], 'wo', markersize=8)
plt.plot(metallicity_values[initial_metallicity_index3], age_values[initial_age_index3], 'wo', markersize=8)

plt.savefig("output/optimisation_losslandscape.jpg", dpi=1000)
plt.show()
#NBVAL_SKIP
# Loss histories of all three runs, plotted as stored (log10 of the loss,
# given that other cells convert them back with 10**).
import os

loss_history_np = np.array(loss_history)
loss_history2 = np.array(loss_history2)
loss_history3 = np.array(loss_history3)
# Iteration axis of run 1 (the name stays defined as before).
iterations = np.arange(len(loss_history_np))

plt.figure(figsize=(6, 4))
# Each run is plotted against its own iteration count: adam_optimization_multi
# is called with a tolerance, so runs may presumably stop after different
# numbers of steps, and reusing run 1's axis would then raise a length mismatch.
for history, run_label in (
    (loss_history_np, 'Run 1'),
    (loss_history2, 'Run 2'),
    (loss_history3, 'Run 3'),
):
    plt.plot(np.arange(len(history)), history, label=run_label)
plt.xlabel('Iteration')
plt.ylabel('log(Loss)')
plt.title('Loss History for Three Runs')
plt.legend()
plt.grid(True)
# savefig does not create missing directories.
os.makedirs("output", exist_ok=True)
plt.savefig("output/optimisation_loglosshistory.jpg", dpi=1000)
plt.show()
#NBVAL_SKIP
# Loss histories of all three runs on a linear loss scale (the stored values
# are converted back with 10**, i.e. they are assumed to hold log10(loss)).
import os

loss_history_np = np.array(loss_history)
loss_history2 = np.array(loss_history2)
loss_history3 = np.array(loss_history3)
# Iteration axis of run 1 (the name stays defined as before).
iterations = np.arange(len(loss_history_np))

plt.figure(figsize=(6, 4))
# Each run is plotted against its own iteration count: adam_optimization_multi
# is called with a tolerance, so runs may presumably stop after different
# numbers of steps, and reusing run 1's axis would then raise a length mismatch.
for history, run_label in (
    (loss_history_np, 'Run 1'),
    (loss_history2, 'Run 2'),
    (loss_history3, 'Run 3'),
):
    plt.plot(np.arange(len(history)), 10**history, label=run_label)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Loss History for Three Runs')
plt.legend()
plt.grid(True)
# savefig does not create missing directories.
os.makedirs("output", exist_ok=True)
plt.savefig("output/optimisation_losshistory.jpg", dpi=1000)
plt.show()
#NBVAL_SKIP
import os

import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np

# Loss histories of the three runs (assumed to hold log10(loss); 10** is applied below).
loss_history_np = np.array(loss_history)
loss_history2 = np.array(loss_history2)
loss_history3 = np.array(loss_history3)
# Iteration axis of run 1 (the name stays defined as before).
iterations = np.arange(len(loss_history_np))

fig, axs = plt.subplots(1, 2, figsize=(8, 3))

# --- Left: loss landscape in physical units with the three optimisation paths ---
im = axs[0].imshow(
    loss_map,
    origin='lower',
    extent=[0, 0.05, 0, 20],
    aspect='auto',
    norm=colors.LogNorm()
)
axs[0].set_xlabel('Metallicity')
axs[0].set_ylabel('Age (Gyrs)')
axs[0].set_xlim(0, 0.045)
# loss_map holds the raw loss; LogNorm only makes the colour SCALE logarithmic,
# the colorbar ticks are loss values, not log(loss).
fig.colorbar(im, ax=axs[0], label='loss')
# Optimisation paths (histories already converted to physical units).
axs[0].plot(metallicity_history, age_history, color='orange')
axs[0].plot(metallicity_history2, age_history2, color='purple')
axs[0].plot(metallicity_history3, age_history3, color='red')
# Target as a yellow star, the three starting points as white circles.
axs[0].plot(metallicity_values[index_metallicity], age_values[index_age], marker='*', color='yellow', markersize=8)
axs[0].plot(metallicity_values[initial_metallicity_index], age_values[initial_age_index], 'wo', markersize=8)
axs[0].plot(metallicity_values[initial_metallicity_index2], age_values[initial_age_index2], 'wo', markersize=8)
axs[0].plot(metallicity_values[initial_metallicity_index3], age_values[initial_age_index3], 'wo', markersize=8)

# --- Right: loss history ---
# Each run is plotted against its own iteration count: adam_optimization_multi
# is called with a tolerance, so runs may presumably stop after different
# numbers of steps, and reusing run 1's axis would then raise a length mismatch.
for history, run_label, run_color in (
    (loss_history_np, 'Run 1', 'orange'),
    (loss_history2, 'Run 2', 'purple'),
    (loss_history3, 'Run 3', 'red'),
):
    axs[1].plot(np.arange(len(history)), 10**history, label=run_label, color=run_color)
axs[1].set_xlabel('Iteration')
axs[1].set_ylabel('Loss')
axs[1].legend()
axs[1].grid(True)

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("output", exist_ok=True)
plt.savefig("output/optimisation_landscape_and_history.jpg", dpi=1000)
plt.show()
# NBVAL_SKIP
# Forward-model the parameters recorded at step i = 200 of the first Adam run
# and compare the resulting spectrum with the target (spectra + residual).
# NOTE: age_history / metallicity_history were already converted to physical
# units (x20 / x0.05) by the rescaling cell above, so they are used as-is here.
# Rescaling them again fed 20x the age and 1/20 of the metallicity into the
# pipeline and mislabelled the legend.
import os

import matplotlib.pyplot as plt

i = 200
age_at_step = age_history[i]
metallicity_at_step = metallicity_history[i]

# Two identical star particles at rest with mass 1.0.
inputdata.stars.age = jnp.array([age_at_step, age_at_step])
inputdata.stars.metallicity = jnp.array([metallicity_at_step, metallicity_at_step])
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
pipe = RubixPipeline(config)
rubixdata = pipe.run_sharded(inputdata)

wave = pipe.telescope.wave_seq
spectra_target = targetdata
spectra_optimitzed = rubixdata  # (sic) name is shared with earlier cells
print(rubixdata.shape)

# Two stacked panels sharing the wavelength axis: spectra (top), residual (bottom).
fig, (ax1, ax2) = plt.subplots(nrows=2, sharex=True, gridspec_kw={'height_ratios': [4, 1]}, figsize=(7, 5))
ax1.plot(wave, spectra_target[0, 0, :], label=f"Target age = {age_values[index_age]:.2f}, metallicity = {metallicity_values[index_metallicity]:.4f}")
ax1.plot(wave, spectra_optimitzed[0, 0, :], label=f"Optimized age = {age_at_step:.2f}, metallicity = {metallicity_at_step:.4f}")
ax1.set_ylabel("Luminosity [L/Å]")
ax1.legend()
ax1.grid(True)

# Absolute (not relative) residual between target and optimized spectra.
residual = spectra_target[0, 0, :] - spectra_optimitzed[0, 0, :]
ax2.plot(wave, residual, 'k-')
ax2.set_xlabel("Wavelength [Å]")
ax2.set_ylabel("Residual")
ax2.grid(True)

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("output", exist_ok=True)
# NOTE(review): same file name as the earlier i = 850 spectra figure, which
# this cell overwrites — confirm that is intended.
plt.savefig("output/optimisation_spectra.jpg", dpi=1000)
plt.show()
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:28:16,917 - rubix - INFO - Setting up the pipeline...
2025-11-11 10:28:16,918 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-11-11 10:28:16,919 - rubix - DEBUG - Roataion Type found: edge-on
2025-11-11 10:28:16,920 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:28:16,930 - rubix - INFO - Getting cosmology...
2025-11-11 10:28:16,940 - rubix - INFO - Calculating spatial bin edges...
2025-11-11 10:28:16,949 - rubix - INFO - Getting cosmology...
2025-11-11 10:28:16,981 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:28:17,040 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:28:17,060 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:28:17,114 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubixcpu2/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-11-11 10:28:17,176 - rubix - INFO - Assembling the pipeline...
2025-11-11 10:28:17,177 - rubix - INFO - Compiling the expressions...
2025-11-11 10:28:17,178 - rubix - INFO - Number of devices: 1
2025-11-11 10:28:17,251 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-11-11 10:28:17,252 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG
2025-11-11 10:28:17,252 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.
2025-11-11 10:28:17,309 - rubix - INFO - Assigning particles to spaxels...
2025-11-11 10:28:17,311 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-11-11 10:28:17,322 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-11-11 10:28:17,323 - rubix - INFO - Convolving with PSF...
2025-11-11 10:28:17,325 - rubix - INFO - Convolving with LSF...
2025-11-11 10:28:17,328 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-11-11 10:28:20,803 - rubix - INFO - Pipeline run completed in 3.89 seconds.
(1, 1, 466)