Gradient through RUBIX

# NBVAL_SKIP
from jax import config
#config.update("jax_enable_x64", True)
config.update('jax_num_cpu_devices', 2)
#NBVAL_SKIP
import os

# Optional: restrict which GPUs are visible to JAX. Disabled here, so the
# notebook runs on the CPU devices configured above.
#os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'

import jax

# List the devices JAX can see; the recorded run shows two CpuDevice entries.
print(jax.devices())
# → [CpuDevice(id=0), CpuDevice(id=1)]
[CpuDevice(id=0), CpuDevice(id=1)]
# NBVAL_SKIP
import os

# SPS_HOME is set before the FSPS template is loaded below. Keep a value that
# is already present in the environment so the notebook also runs on machines
# where the hard-coded fallback path does not exist.
os.environ.setdefault('SPS_HOME', '/home/annalena_data/sps_fsps')

Load ssp template from FSPS

# NBVAL_SKIP
from rubix.spectra.ssp.factory import get_ssp_template
# Request the SSP template named "FSPS" from the RUBIX factory; its age and
# metallicity grids are read from it in the following cells.
ssp_fsps = get_ssp_template("FSPS")
2025-07-01 14:33:44,585 - rubix - INFO - 
   ___  __  _____  _____  __
  / _ \/ / / / _ )/  _/ |/_/
 / , _/ /_/ / _  |/ /_>  <
/_/|_|\____/____/___/_/|_|
2025-07-01 14:33:44,586 - rubix - INFO - Rubix version: 0.0.post465+g01a25a7.d20250701
2025-07-01 14:33:44,586 - rubix - INFO - JAX version: 0.6.0
2025-07-01 14:33:44,587 - rubix - INFO - Running on [CpuDevice(id=0), CpuDevice(id=1)] devices
# NBVAL_SKIP
# Pull the age and metallicity grids out of the SSP template and report how
# many grid points each of them has.
age_values, metallicity_values = ssp_fsps.age, ssp_fsps.metallicity

print(age_values.shape)
print(metallicity_values.shape)
(107,)
(12,)
# NBVAL_SKIP
# Optimizer settings shared by the cells below.
learning_all, tol = 5e-2, 1e-10

# Grid indices of the target that the optimization should recover.
index_age, index_metallicity = 90, 9

# Grid indices of the starting point (alternative start kept for reference:
# age index 70, metallicity index 5).
initial_age_index, initial_metallicity_index = 104, 10

print(f"start age: {age_values[initial_age_index]}, start metallicity: {metallicity_values[initial_metallicity_index]}")
print(f"target age: {age_values[index_age]}, target metallicity: {metallicity_values[index_metallicity]}")
start age: 15.848933219909668, start metallicity: 0.025251565501093864
target age: 3.1622776985168457, target metallicity: 0.014199999161064625

Configure pipeline

# NBVAL_SKIP
from rubix.core.pipeline import RubixPipeline
import os
# RUBIX run configuration: one top-level entry per pipeline component.
config = {
    "pipeline": {"name": "calc_gradient"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    # Input data: star particles of galaxy 14, cut down to a 2-particle subset.
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # The API key is taken from the environment (None when unset).
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {"id": 14, "reuse": True},
        "subset": {"use_subset": True, "subset_size": 2},
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {"path": "data/galaxy-id-14.hdf5"},
    },
    "output_path": "output",
    # Mock instrument: PSF, LSF and noise settings.
    "telescope": {
        "name": "TESTGRADIENT",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 1.2},
        "noise": {"signal_to_noise": 100, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },
    # Stellar population template and dust settings.
    "ssp": {
        "template": {"name": "FSPS"},
        "dust": {
            "extinction_model": "Cardelli89",
            "dust_to_gas_ratio": 0.01,
            "dust_to_metals_ratio": 0.4,
            "dust_grain_density": 3.5,
            "Rv": 3.1,
        },
    },
}
# NBVAL_SKIP
# Build the pipeline from the configuration, load the input data and run the
# pipeline once on it.
pipe = RubixPipeline(config)
inputdata = pipe.prepare_data()
output = pipe.run_sharded(inputdata)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:33:45,112 - rubix - INFO - Getting rubix data...
2025-07-01 14:33:45,113 - rubix - INFO - Rubix galaxy file already exists, skipping conversion
2025-07-01 14:33:45,153 - rubix - INFO - Centering stars particles
2025-07-01 14:33:45,894 - rubix - WARNING - The Subset value is set in config. Using only subset of size 2 for stars
2025-07-01 14:33:45,896 - rubix - INFO - Data loaded with 2 star particles and 0 gas particles.
2025-07-01 14:33:45,896 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:33:45,897 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:33:45,897 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:33:45,900 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:33:45,923 - rubix - INFO - Getting cosmology...
2025-07-01 14:33:46,094 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:33:46,186 - rubix - INFO - Getting cosmology...
2025-07-01 14:33:46,231 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:33:46,323 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:33:46,334 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:33:46,390 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:33:46,597 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:33:46,598 - rubix - INFO - Compiling the expressions...
2025-07-01 14:33:46,598 - rubix - INFO - Number of devices: 2
2025-07-01 14:33:46,703 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:33:46,804 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:33:46,818 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:33:46,956 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:33:46,956 - rubix - INFO - Convolving with PSF...
2025-07-01 14:33:46,959 - rubix - INFO - Convolving with LSF...
2025-07-01 14:33:46,964 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:33:54,764 - rubix - INFO - Pipeline run completed in 8.87 seconds.

Set target values

# NBVAL_SKIP
import jax.numpy as jnp

# Overwrite both star particles with the target age and metallicity, unit
# mass, and zero velocity and position.
inputdata.stars.age = jnp.array([age_values[index_age]] * 2)
inputdata.stars.metallicity = jnp.array([metallicity_values[index_metallicity]] * 2)
inputdata.stars.mass = jnp.ones((2, 1))
inputdata.stars.velocity = jnp.zeros((2, 3))
inputdata.stars.coords = jnp.zeros((2, 3))
# NBVAL_SKIP
# Run the pipeline on the target parameters; the result is the datacube the
# optimization should reproduce.
targetdata = pipe.run_sharded(inputdata)
2025-07-01 14:33:54,818 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:33:54,819 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:33:54,820 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:33:54,822 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:33:54,833 - rubix - INFO - Getting cosmology...
2025-07-01 14:33:54,843 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:33:54,853 - rubix - INFO - Getting cosmology...
2025-07-01 14:33:54,887 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:33:54,923 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:33:54,962 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:33:55,009 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:33:55,050 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:33:55,051 - rubix - INFO - Compiling the expressions...
2025-07-01 14:33:55,052 - rubix - INFO - Number of devices: 2
2025-07-01 14:33:55,157 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:33:55,238 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:33:55,240 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:33:55,250 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:33:55,251 - rubix - INFO - Convolving with PSF...
2025-07-01 14:33:55,253 - rubix - INFO - Convolving with LSF...
2025-07-01 14:33:55,257 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:34:02,326 - rubix - INFO - Pipeline run completed in 7.51 seconds.
# NBVAL_SKIP
# Shape of the spectrum in the single spaxel (466 wavelength bins in the
# recorded run).
print(targetdata[0,0,:].shape)
(466,)

Set initial cube

# NBVAL_SKIP
# Overwrite both star particles with the starting age and metallicity, unit
# mass, and zero velocity and position.
inputdata.stars.age = jnp.array([age_values[initial_age_index]] * 2)
inputdata.stars.metallicity = jnp.array([metallicity_values[initial_metallicity_index]] * 2)
inputdata.stars.mass = jnp.ones((2, 1))
inputdata.stars.velocity = jnp.zeros((2, 3))
inputdata.stars.coords = jnp.zeros((2, 3))
# NBVAL_SKIP
# Run the pipeline on the starting parameters to obtain the initial datacube.
initialdata = pipe.run_sharded(inputdata)
2025-07-01 14:34:02,383 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:34:02,384 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:34:02,385 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:34:02,387 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:34:02,399 - rubix - INFO - Getting cosmology...
2025-07-01 14:34:02,409 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:34:02,418 - rubix - INFO - Getting cosmology...
2025-07-01 14:34:02,452 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:34:02,508 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:34:02,521 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:34:02,557 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:34:02,587 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:34:02,588 - rubix - INFO - Compiling the expressions...
2025-07-01 14:34:02,589 - rubix - INFO - Number of devices: 2
2025-07-01 14:34:02,681 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:34:02,912 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:34:02,915 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:34:02,923 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:34:02,923 - rubix - INFO - Convolving with PSF...
2025-07-01 14:34:02,926 - rubix - INFO - Convolving with LSF...
2025-07-01 14:34:02,929 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:34:09,670 - rubix - INFO - Pipeline run completed in 7.29 seconds.

Adam optimizer

# NBVAL_SKIP
from rubix.pipeline import linear_pipeline as pipeline

# Separate pipeline instance whose compiled function is called directly by the
# loss function defined below.
pipeline_instance = RubixPipeline(config)

# Assemble the linear transformer pipeline by hand so that the compiled
# expression is available as ``pipeline_instance.func``.
# NOTE(review): this relies on private RubixPipeline attributes
# (``_pipeline``, ``_get_pipeline_functions``) - confirm they are stable.
pipeline_instance._pipeline = pipeline.LinearTransformerPipeline(
    pipeline_instance.pipeline_config, 
    pipeline_instance._get_pipeline_functions()
)
pipeline_instance._pipeline.assemble()
pipeline_instance.func = pipeline_instance._pipeline.compile_expression()
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:34:09,732 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:34:09,733 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:34:09,733 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:34:09,735 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:34:09,745 - rubix - INFO - Getting cosmology...
2025-07-01 14:34:09,755 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:34:09,765 - rubix - INFO - Getting cosmology...
2025-07-01 14:34:09,780 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:34:09,806 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:34:09,818 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:34:09,851 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
# NBVAL_SKIP
import optax

def loss_only_wrt_age_metallicity(age, metallicity, base_data, target,
                                  age_scale=20, metallicity_scale=0.05):
    """Return log10 of the summed cosine distance between model and target.

    The optimizer works on normalized parameters; they are rescaled here
    before being written into the pipeline input.

    Args:
        age: normalized stellar ages; multiplied by ``age_scale`` and stored
            in ``base_data.stars.age``.
        metallicity: normalized metallicities; multiplied by
            ``metallicity_scale`` and stored in ``base_data.stars.metallicity``.
        base_data: pipeline input object. NOTE: it is modified in place, so
            the caller's object holds the rescaled values afterwards.
        target: target datacube that ``output.stars.datacube`` is compared to.
        age_scale: factor converting normalized age to the pipeline's age
            (default 20, the value that was previously hard-coded).
        metallicity_scale: factor converting normalized metallicity to the
            pipeline's metallicity (default 0.05, previously hard-coded).

    Returns:
        ``jnp.log10`` of the sum of ``optax.cosine_distance`` between the
        modelled datacube and ``target``.
    """
    base_data.stars.age = age * age_scale
    base_data.stars.metallicity = metallicity * metallicity_scale

    output = pipeline_instance.func(base_data)

    # Cosine distance is the loss currently in use; a plain squared error,
    # optax.l2_loss and optax.huber_loss were the alternatives noted here.
    loss = jnp.sum(optax.cosine_distance(output.stars.datacube, target))

    return jnp.log10(loss)
# NBVAL_SKIP
import jax
import jax.numpy as jnp
import optax


def adam_optimization_multi(loss_fn, params_init, data, target, learning=learning_all, tol=1e-3, max_iter=500):
    """
    Optimizes both age and metallicity with Adam.

    Args:
        loss_fn: function with signature loss_fn(age, metallicity, data, target)
        params_init: dict with keys 'age' and 'metallicity', each a JAX array
        data: base data for the loss function
        target: target data for the loss function
        learning: learning rate used for both Adam optimizers
        tol: tolerance for convergence (based on the global norm of the update)
        max_iter: maximum number of iterations

    Returns:
        params: final parameters (dict)
        age_history: first element of params['age'] at the start of each iteration
        metallicity_history: first element of params['metallicity'] at the
            start of each iteration
        loss_history: loss value of each iteration
    """
    params = params_init  # e.g., {'age': jnp.array(...), 'metallicity': jnp.array(...)}

    # One Adam instance per parameter, combined through a label pytree that
    # mirrors the structure of ``params``.
    optimizer = optax.multi_transform(
        {'age': optax.adam(learning), 'metallicity': optax.adam(learning)},
        {'age': 'age', 'metallicity': 'metallicity'},
    )
    optimizer_state = optimizer.init(params)

    # The value-and-gradient function does not change between iterations, so
    # build it once instead of re-creating it inside the loop.
    loss_and_grad = jax.value_and_grad(
        lambda p: loss_fn(p['age'], p['metallicity'], data, target)
    )

    age_history = []
    metallicity_history = []
    loss_history = []

    for i in range(max_iter):
        # Loss and gradients with respect to both parameters.
        loss, grads = loss_and_grad(params)

        # Record the state of this iteration as plain Python floats.
        loss_history.append(float(loss))
        age_history.append(float(params['age'][0]))
        metallicity_history.append(float(params['metallicity'][0]))

        # A NaN/inf loss poisons the Adam state and can never satisfy the
        # convergence test below, so stop instead of iterating to max_iter.
        if not bool(jnp.isfinite(loss)):
            print(f"Stopping at iteration {i}: non-finite loss ({loss_history[-1]})")
            break

        # Compute the update and apply it.
        updates, optimizer_state = optimizer.update(grads, optimizer_state, params)
        params = optax.apply_updates(params, updates)

        # Physical bounds could be enforced here, e.g.
        # params['age'] = jnp.clip(params['age'], 0.0, 1.0)

        # Converged once the global norm of the update drops below ``tol``.
        if optax.global_norm(updates) < tol:
            print(f"Converged at iteration {i}")
            break

    return params, age_history, metallicity_history, loss_history
# NBVAL_SKIP
# The loss function expects normalized parameters and rescales them itself
# (age by 20, metallicity by 0.05). Passing the physical values stored in
# inputdata would apply the scaling twice (age of about 317 instead of about
# 15.8), which is presumably why this call returned nan. Normalize first.
loss_only_wrt_age_metallicity(inputdata.stars.age / 20, inputdata.stars.metallicity / 0.05, inputdata, targetdata)
2025-07-01 14:34:09,976 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:34:10,063 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:34:10,079 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:34:10,242 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:34:10,243 - rubix - INFO - Convolving with PSF...
2025-07-01 14:34:10,246 - rubix - INFO - Convolving with LSF...
2025-07-01 14:34:10,250 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
Array(nan, dtype=float32)
# NBVAL_SKIP
# Inputs of the optimization: the pipeline input object and the target cube.
data, target_value = inputdata, targetdata

# Starting point in normalized units (age / 20, metallicity / 0.05), one entry
# per star particle.
age_init = jnp.array([age_values[initial_age_index] / 20] * 2)
metallicity_init = jnp.array([metallicity_values[initial_metallicity_index] / 0.05] * 2)

# The optimizer expects both parameters in a single dict.
params_init = dict(age=age_init, metallicity=metallicity_init)
print(f"Initial parameters: {params_init}")

# Optimize age and metallicity jointly.
optimized_params, age_history, metallicity_history, loss_history = adam_optimization_multi(
    loss_fn=loss_only_wrt_age_metallicity,
    params_init=params_init,
    data=data,
    target=target_value,
    learning=learning_all,
    tol=tol,
    max_iter=5000,
)

print(f"Optimized Age: {optimized_params['age']}")
print(f"Optimized Metallicity: {optimized_params['metallicity']}")
Initial parameters: {'age': Array([0.7924467, 0.7924467], dtype=float32), 'metallicity': Array([0.5050313, 0.5050313], dtype=float32)}
2025-07-01 14:34:27,504 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:34:27,768 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:34:27,770 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:34:27,874 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:34:27,874 - rubix - INFO - Convolving with PSF...
2025-07-01 14:34:27,878 - rubix - INFO - Convolving with LSF...
2025-07-01 14:34:27,881 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[18], line 16
     13 print(f"Initial parameters: {params_init}")
     15 # Call the new optimizer function that handles both parameters.
---> 16 optimized_params, age_history, metallicity_history, loss_history = adam_optimization_multi(
     17     loss_only_wrt_age_metallicity,
     18     params_init,
     19     data,
     20     target_value,
     21     learning=learning_all,
     22     tol=tol,
     23     max_iter=5000,
     24 )
     26 print(f"Optimized Age: {optimized_params['age']}")
     27 print(f"Optimized Metallicity: {optimized_params['metallicity']}")

Cell In[16], line 54, in adam_optimization_multi(loss_fn, params_init, data, target, learning, tol, max_iter)
     47 metallicity_history.append(float(params['metallicity'][0]))
     48 #params_history.append({
     49 #    'age': params['age'],
     50 #    'metallicity': params['metallicity']
     51 #})
     52 
     53 # Compute updates and apply them
---> 54 updates, optimizer_state = optimizer.update(grads, optimizer_state)
     55 params = optax.apply_updates(params, updates)
     57 # Optionally clip the parameters to enforce physical constraints:
     58 #params['age'] = jnp.clip(params['age'], 0.0, 1.0)
     59 #params['metallicity'] = jnp.clip(params['metallicity'], 0.0, 1.0)
   (...)     62 
     63 # Check convergence based on the global norm of updates

File ~/.conda/envs/rubix/lib/python3.12/site-packages/optax/transforms/_combining.py:262, in partition.<locals>.update_fn(updates, state, params, **extra_args)
    256 for group, tx in transforms.items():
    257   masked_tx = wrappers.masked(
    258       tx,
    259       make_mask(labels, group),
    260       mask_compatible_extra_args=mask_compatible_extra_args,
    261   )
--> 262   updates, new_inner_state[group] = masked_tx.update(
    263       updates, state.inner_states[group], params, **extra_args
    264   )
    265 return updates, PartitionState(new_inner_state)

File ~/.conda/envs/rubix/lib/python3.12/site-packages/optax/transforms/_masking.py:137, in masked.<locals>.update_fn(updates, state, params, **extra_args)
    134 masked_updates = mask_pytree(updates, mask_tree)
    135 masked_params = None if params is None else mask_pytree(params, mask_tree)
--> 137 new_masked_updates, new_inner_state = inner.update(
    138     masked_updates, state.inner_state, masked_params, **masked_extra_args
    139 )
    141 new_updates = jax.tree.map(
    142     lambda m, new_u, old_u: new_u if m else old_u,
    143     mask_tree,
    144     new_masked_updates,
    145     updates,
    146 )
    147 return new_updates, MaskedState(inner_state=new_inner_state)

File ~/.conda/envs/rubix/lib/python3.12/site-packages/optax/transforms/_combining.py:75, in chain.<locals>.update_fn(updates, state, params, **extra_args)
     73 new_state = []
     74 for s, fn in zip(state, update_fns):
---> 75   updates, new_s = fn(updates, s, params, **extra_args)
     76   new_state.append(new_s)
     77 return updates, tuple(new_state)

File ~/.conda/envs/rubix/lib/python3.12/site-packages/optax/_src/base.py:333, in with_extra_args_support.<locals>.update(***failed resolving arguments***)
    331 def update(updates, state, params=None, **extra_args):
    332   del extra_args
--> 333   return tx.update(updates, state, params)

File ~/.conda/envs/rubix/lib/python3.12/site-packages/optax/_src/transform.py:298, in scale_by_adam.<locals>.update_fn(***failed resolving arguments***)
    294   mu_hat = otu.tree_bias_correction(mu, b1, count_inc)
    295 # Dozat 2016 https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ
    296 # Algorithm 2 further multiplies Adam's standard nu_hat by b2. It is
    297 # unclear why. Other Nadam implementations also omit the extra b2 factor.
--> 298 nu_hat = otu.tree_bias_correction(nu, b2, count_inc)
    299 updates = jax.tree.map(
    300     lambda m, v: None if m is None else m / (jnp.sqrt(v + eps_root) + eps),
    301     mu_hat,
    302     nu_hat,
    303     is_leaf=lambda x: x is None,
    304 )
    305 mu = otu.tree_cast(mu, mu_dtype)

File <string>:1, in <lambda>(_cls)

KeyboardInterrupt: 

Loss history

# NBVAL_SKIP
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

import os

# Convert the recorded histories to NumPy arrays for plotting.
loss_history_np = np.array(loss_history)
age_history_np = np.array(age_history)
metallicity_history_np = np.array(metallicity_history)

# One x value per recorded iteration (the histories are appended together, so
# they all have the same length).
iterations = np.arange(len(loss_history_np))
print(f"Number of iterations: {len(iterations)}")

# Right end of the target lines; fall back to 0 when nothing was recorded.
last_iteration = iterations[-1] if len(iterations) else 0

# Three panels in one row with a shared x-axis.
fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharex=True)

# (y values, target level or None, y label, title) for each panel.
panels = [
    # The optimizer records log10(loss); undo the logarithm for display.
    (10**loss_history_np, None, "Loss", "Loss History"),
    # Parameters are recorded in normalized units; rescale by 20 and 0.05.
    (age_history_np * 20, age_values[index_age], "Age", "Age History"),
    (metallicity_history_np * 0.05, metallicity_values[index_metallicity],
     "Metallicity", "Metallicity History"),
]

for ax, (values, target_level, ylabel, title) in zip(axs, panels):
    ax.plot(iterations, values, marker='o', linestyle='-')
    if target_level is not None:
        # Horizontal red line marking the target value.
        ax.hlines(y=target_level, xmin=0, xmax=last_iteration, color='r', linestyle='-')
    ax.set_xlabel("Iteration")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)
    ax.set_xlim(-5, 900)

plt.tight_layout()
# savefig does not create missing directories, so make sure "output" exists.
os.makedirs("output", exist_ok=True)
plt.savefig("output/optimisation_history.jpg", dpi=1000)
plt.show()
Number of iterations: 5000
../_images/6133ff8625f2add75bf1ef30b0758a2b4eb5a405d071f494569846f2f932bb81.png
# NBVAL_SKIP
# Re-run the pipeline with the parameters of one optimizer iteration and
# compare the resulting spectrum with the target.
# Use the last recorded iteration: the labels and the title below announce the
# optimized result, whereas index 0 would show the starting point.
i = -1
inputdata.stars.age = jnp.array([age_history[i]*20, age_history[i]*20])
inputdata.stars.metallicity = jnp.array([metallicity_history[i]*0.05, metallicity_history[i]*0.05])
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
# NOTE(review): coordinates are reset here for consistency with the target and
# initial set-up cells, which also place both particles at the origin - confirm
# the comparison is not meant to reuse coordinates left by an earlier run.
inputdata.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

pipe = RubixPipeline(config)
rubixdata = pipe.run_sharded(inputdata)

# Plot the target and the optimized spectra over the telescope wavelengths.
import matplotlib.pyplot as plt
wave = pipe.telescope.wave_seq

spectra_target = targetdata
spectra_optimitzed = rubixdata
print(rubixdata.shape)


plt.plot(wave, spectra_target[0,0,:], label=f"Target age = {age_values[index_age]:.2f}, metal. = {metallicity_values[index_metallicity]:.4f}")
plt.plot(wave, spectra_optimitzed[0,0,:], label=f"Optimized age = {age_history[i]*20:.2f}, metal. = {metallicity_history[i]*0.05:.4f}")
plt.xlabel("Wavelength [Å]")
plt.ylabel("Luminosity [L/Å]")
plt.title("Difference between target and optimized spectra")
plt.legend()
plt.grid(True)
plt.show()
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:21:55,601 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:21:55,602 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:21:55,603 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:21:55,605 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:21:55,615 - rubix - INFO - Getting cosmology...
2025-07-01 14:21:55,631 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:21:55,640 - rubix - INFO - Getting cosmology...
2025-07-01 14:21:55,659 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:21:55,691 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:21:55,704 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:21:55,743 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:21:55,793 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:21:55,794 - rubix - INFO - Compiling the expressions...
2025-07-01 14:21:55,794 - rubix - INFO - Number of devices: 2
2025-07-01 14:21:55,905 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:21:55,985 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:21:55,987 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:21:55,998 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:21:55,999 - rubix - INFO - Convolving with PSF...
2025-07-01 14:21:56,001 - rubix - INFO - Convolving with LSF...
2025-07-01 14:21:56,005 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:22:03,786 - rubix - INFO - Pipeline run completed in 8.18 seconds.
(1, 1, 466)
../_images/0dd87d23b6c1851fe279dc7dec8af30abac5e769698219e956e438e2d04e1e00.png
# NBVAL_SKIP
# Re-run the pipeline with the parameters stored at history index i and
# overlay the resulting spectrum on the target spectrum.
import matplotlib.pyplot as plt

i = 850  # index into age_history / metallicity_history

# The stored history values are multiplied by 20 (age) and 0.05 (metallicity)
# before being written to the input data; both entries get the same value.
age_i = age_history[i]*20
metallicity_i = metallicity_history[i]*0.05
inputdata.stars.age = jnp.array([age_i, age_i])
inputdata.stars.metallicity = jnp.array([metallicity_i, metallicity_i])
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

# A fresh pipeline is built from the config and run on the modified input.
pipe = RubixPipeline(config)
rubixdata = pipe.run_sharded(inputdata)

# Wavelength axis taken from the pipeline's telescope.
wave = pipe.telescope.wave_seq

# The pipeline output and the target are indexed directly as arrays below.
spectra_target, spectra_optimitzed = targetdata, rubixdata

target_label = f"Target age = {age_values[index_age]:.2f}, metal. = {metallicity_values[index_metallicity]:.4f}"
optimized_label = f"Optimized age = {age_history[i]*20:.2f}, metal. = {metallicity_history[i]*0.05:.4f}"

# Only the spectrum at index [0, 0] of each cube is drawn.
fig, ax = plt.subplots()
ax.plot(wave, spectra_target[0,0,:], label=target_label)
ax.plot(wave, spectra_optimitzed[0,0,:], label=optimized_label)
ax.set_xlabel("Wavelength [Å]")
ax.set_ylabel("Luminosity [L/Å]")
ax.set_title("Difference between target and optimized spectra")
ax.legend()
ax.grid(True)
plt.show()
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:22:03,947 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:22:03,948 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:22:03,948 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:22:03,950 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:22:03,961 - rubix - INFO - Getting cosmology...
2025-07-01 14:22:03,971 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:22:03,980 - rubix - INFO - Getting cosmology...
2025-07-01 14:22:03,999 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:22:04,034 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:22:04,045 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:22:04,083 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:22:04,125 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:22:04,126 - rubix - INFO - Compiling the expressions...
2025-07-01 14:22:04,127 - rubix - INFO - Number of devices: 2
2025-07-01 14:22:04,228 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:22:04,308 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:22:04,310 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:22:04,321 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:22:04,322 - rubix - INFO - Convolving with PSF...
2025-07-01 14:22:04,324 - rubix - INFO - Convolving with LSF...
2025-07-01 14:22:04,327 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:22:11,227 - rubix - INFO - Pipeline run completed in 7.28 seconds.
../_images/217431e942a9d8f0dbeaf3528df39b1da3bf5343bc8778c735285ae7b5ac4835.png
# NBVAL_SKIP
import matplotlib as mpl
import matplotlib.pyplot as plt

# Spectra at index [0, 0] of the target and of the most recent pipeline run.
# Uses wave, spectra_target, spectra_optimitzed and i as they were last set.
target_spectrum = spectra_target[0, 0, :]
model_spectrum = spectra_optimitzed[0, 0, :]

# Absolute (not relative) difference between target and model.
residual = (target_spectrum - model_spectrum)

# Tall upper panel for the spectra, short lower panel for the residual; the
# two panels share the wavelength axis.
fig, (ax1, ax2) = plt.subplots(
    nrows=2,
    sharex=True,
    figsize=(7, 5),
    gridspec_kw={'height_ratios': [4, 1]},
)

# Upper panel: target and optimised spectra.
ax1.plot(
    wave,
    target_spectrum,
    label=f"Target age = {age_values[index_age]:.2f}, metallicity = {metallicity_values[index_metallicity]:.4f}",
)
ax1.plot(
    wave,
    model_spectrum,
    label=f"Optimized age = {age_history[i]*20:.2f}, metallicity = {metallicity_history[i]*0.05:.4f}",
)
ax1.set_ylabel("Luminosity [L/Å]")
ax1.legend()
ax1.grid(True)

# Lower panel: residual as a solid black line.
ax2.plot(wave, residual, 'k-')
ax2.set_xlabel("Wavelength [Å]")
ax2.set_ylabel("Residual")
ax2.grid(True)

plt.tight_layout()
plt.savefig(f"output/optimisation_spectra.jpg", dpi=1000)
plt.show()
../_images/3f649625e33b85dcadbe1eeb42adf03440739fab64e3907efa2af258bcc3971a.png

Calculate loss landscape#

# NBVAL_SKIP
import optax

def loss_only_wrt_age_metallicity(age, metallicity, base_data, target):
    """Return the summed cosine distance between the modelled datacube and ``target``.

    ``age`` is multiplied by 20 and ``metallicity`` by 0.05, and each scaled
    value is written twice into a two-element 1-D array on
    ``base_data.stars``. The model output comes from the module-level
    ``pipeline_instance``.

    Note: ``base_data.stars.age`` and ``base_data.stars.metallicity`` are
    overwritten in place, so the caller's object is modified.
    """
    scaled_age = age*20
    scaled_metallicity = metallicity*0.05
    base_data.stars.age = jnp.array([scaled_age, scaled_age])
    base_data.stars.metallicity = jnp.array([scaled_metallicity, scaled_metallicity])

    prediction = pipeline_instance.func(base_data)
    return jnp.sum(optax.cosine_distance(prediction.stars.datacube, target))
# NBVAL_SKIP
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Grid resolution along each parameter axis.
num_steps = 100

# Both grids run over [0, 1]; loss_only_wrt_age_metallicity multiplies them
# by 20 (age) and 0.05 (metallicity) internally.
physical_ages = jnp.linspace(0, 1, num_steps)
physical_metals = jnp.linspace(0, 1, num_steps)

# Evaluate the loss at every grid point. Rows follow physical_ages and
# columns follow physical_metals, giving shape (num_steps, num_steps).
loss_map = jnp.stack([
    jnp.stack([
        loss_only_wrt_age_metallicity(age, metal, inputdata, targetdata)
        for metal in physical_metals
    ])
    for age in physical_ages
])
2025-07-01 14:22:12,033 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:22:12,100 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:22:12,102 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:22:12,448 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:22:12,449 - rubix - INFO - Convolving with PSF...
2025-07-01 14:22:12,451 - rubix - INFO - Convolving with LSF...
2025-07-01 14:22:12,454 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
# NBVAL_SKIP
# Loss landscape drawn on the [0, 1] x [0, 1] grid the loss map was
# evaluated on.
import matplotlib.pyplot as plt
import matplotlib.colors as colors

plt.figure(figsize=(5, 4))
# Rows of loss_map follow age and columns follow metallicity, so with
# origin='lower' the x-axis is metallicity and the y-axis is age. The colour
# scale is logarithmic.
plt.imshow(loss_map, origin='lower', extent=[0, 1, 0, 1], aspect='auto', norm=colors.LogNorm())
plt.xlabel('Metallicity')
plt.ylabel('Age')
plt.title('Loss Landscape')
plt.colorbar(label='loss')

# Path of the optimisation. NOTE(review): this expects the history values to
# still be unscaled (same units as the grid) when this cell runs.
plt.plot(metallicity_history[:], age_history[:])

# Target (orange) and starting point (white) get distinct markers so they can
# be told apart; the grid values are divided by 0.05 / 20 to place them on
# the [0, 1] axes.
plt.plot(metallicity_values[index_metallicity]/0.05, age_values[index_age]/20, marker='o', color='orange', markersize=8)
plt.plot(metallicity_values[initial_metallicity_index]/0.05, age_values[initial_age_index]/20, 'wo', markersize=8)

# Own file name: output/optimisation_losslandscape.jpg is written again by the
# physical-units version of this figure, which would overwrite this one.
plt.savefig("output/optimisation_losslandscape_normalised.jpg", dpi=1000)
plt.show()
../_images/4a22fa48097686bd009f16175584324d2982483538139b38d5a28bc75fe2dc65.png
# NBVAL_SKIP
# Scale the recorded histories by 20 (age) and 0.05 (metallicity).
# NOTE: the names are rebound in place, so running this cell a second time
# applies the factors again.
age_history = 20 * np.array(age_history)
metallicity_history = 0.05 * np.array(metallicity_history)
# NBVAL_SKIP
import matplotlib.pyplot as plt
import matplotlib.colors as colors

fig, ax = plt.subplots(figsize=(6, 5))

# Same loss map as before, but with the axes labelled in scaled units:
# metallicity from 0 to 0.05 and age from 0 to 20. Logarithmic colour scale.
image = ax.imshow(
    loss_map,
    origin='lower',
    extent=[0, 0.05, 0, 20],
    aspect='auto',
    norm=colors.LogNorm(),
)

ax.set_xlabel('Metallicity')
ax.set_ylabel('Age')
ax.set_title('Loss Landscape')
fig.colorbar(image, ax=ax, label='loss')

# Optimisation path. NOTE(review): assumes metallicity_history / age_history
# have already been multiplied by 0.05 / 20 before this cell runs.
ax.plot(metallicity_history[:], age_history[:])

# Target in orange, starting point in white.
ax.plot(
    metallicity_values[index_metallicity],
    age_values[index_age],
    marker='o',
    color='orange',
    markersize=8,
)
ax.plot(
    metallicity_values[initial_metallicity_index],
    age_values[initial_age_index],
    'wo',
    markersize=8,
)

fig.savefig("output/optimisation_losslandscape.jpg", dpi=1000)
plt.show()
../_images/1a97fb8f74253f805458df4257ca41b5215efa2c94b9c2801be06cc2bcecd8eb.png