Gradient through RUBIX

# NBVAL_SKIP
from jax import config
#config.update("jax_enable_x64", True)
config.update('jax_num_cpu_devices', 2)
#NBVAL_SKIP
import os

# Optional: restrict JAX to specific GPUs. The value '1,2' exposes physical
# GPUs 1 and 2 (JAX then enumerates them as devices 0 and 1). This only takes
# effect if set before any JAX backend is initialised.
#os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'

import jax

# With jax_num_cpu_devices set to 2 in the previous cell, JAX lists two
# CpuDevice entries.
print(jax.devices())
# → [CpuDevice(id=0), CpuDevice(id=1)]
[CpuDevice(id=0), CpuDevice(id=1)]
# NBVAL_SKIP
import os

# Location of the FSPS data files. Respect an SPS_HOME that is already set in
# the environment and only fall back to this machine-specific path otherwise.
os.environ.setdefault('SPS_HOME', '/home/annalena_data/sps_fsps')

Load the SSP template from FSPS

# NBVAL_SKIP
# Load the FSPS single-stellar-population template via RUBIX's factory; its
# age/metallicity grids are inspected in the next cell.
from rubix.spectra.ssp.factory import get_ssp_template
ssp_fsps = get_ssp_template("FSPS")
2025-07-01 14:19:12,434 - rubix - INFO - 
   ___  __  _____  _____  __
  / _ \/ / / / _ )/  _/ |/_/
 / , _/ /_/ / _  |/ /_>  <
/_/|_|\____/____/___/_/|_|


2025-07-01 14:19:12,435 - rubix - INFO - Rubix version: 0.0.post465+g01a25a7.d20250701
2025-07-01 14:19:12,435 - rubix - INFO - JAX version: 0.6.0
2025-07-01 14:19:12,436 - rubix - INFO - Running on [CpuDevice(id=0), CpuDevice(id=1)] devices
# NBVAL_SKIP
# Grids on which the SSP template is tabulated. Print the number of age
# nodes first, then the number of metallicity nodes.
age_values = ssp_fsps.age
metallicity_values = ssp_fsps.metallicity

for grid in (age_values, metallicity_values):
    print(grid.shape)
(107,)
(12,)
# NBVAL_SKIP
# Grid indices of the target SSP, i.e. the spectrum the optimiser should recover.
index_age = 90
index_metallicity = 9

# Grid indices of the optimisation's starting point (an alternative start is
# kept commented out).
#initial_metallicity_index = 5
#initial_age_index = 70
initial_metallicity_index = 10
initial_age_index = 104

# Adam learning rate (shared by age and metallicity) and the convergence
# threshold on the global norm of the Adam update.
learning_all = 5e-2
tol = 1e-10

print(f"start age: {age_values[initial_age_index]}, start metallicity: {metallicity_values[initial_metallicity_index]}")
print(f"target age: {age_values[index_age]}, target metallicity: {metallicity_values[index_metallicity]}")
start age: 15.848933219909668, start metallicity: 0.025251565501093864
target age: 3.1622776985168457, target metallicity: 0.014199999161064625

Configure the pipeline

# NBVAL_SKIP
from rubix.core.pipeline import RubixPipeline
import os
# RUBIX pipeline configuration. Only a 2-particle subset of TNG50-1 galaxy 14
# is loaded; the particle properties are overwritten in later cells, so the
# galaxy presumably serves mainly as the data container the pipeline expects.
config = {
    "pipeline":{"name": "calc_gradient",},
    
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # None when ILLUSTRIS_API_KEY is unset; presumably only needed if
            # the galaxy has to be downloaded (see "reuse" below) — confirm.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        
        "load_galaxy_args": {
        "id": 14,
        # Reuse an existing local file if present (the run log reports
        # "Rubix galaxy file already exists, skipping conversion").
        "reuse": True,
        },
        
        "subset": {
            # Keep only 2 star particles.
            "use_subset": True,
            "subset_size": 2,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            # Presumably the file written from save_data_path + galaxy id above.
            "path": "data/galaxy-id-14.hdf5",
        },
    
    },
    "output_path": "output",

    # Only the telescope name is given here; the run log shows RUBIX falling
    # back to its default telescopes.yaml for the telescope definition.
    "telescope":
        {"name": "TESTGRADIENT",
         "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
         "lsf": {"sigma": 1.2},
         "noise": {"signal_to_noise": 100,"noise_distribution": "normal"},
         },
    "cosmology":
        {"name": "PLANCK15"},
        
    "galaxy":
        {"dist_z": 0.1,
         "rotation": {"type": "edge-on"},
        },
        
    "ssp": {
        "template": {
            "name": "FSPS"
        },
        "dust": {
                "extinction_model": "Cardelli89",
                "dust_to_gas_ratio": 0.01,
                "dust_to_metals_ratio": 0.4,
                "dust_grain_density": 3.5,
                "Rv": 3.1,
            },
    },        
}
# NBVAL_SKIP
# Build the pipeline, prepare the input data and run it once on the sharded
# path (the log reports 2 devices). ``inputdata`` is reused and modified in
# the following cells.
pipe = RubixPipeline(config)
inputdata = pipe.prepare_data()
output = pipe.run_sharded(inputdata)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:13,008 - rubix - INFO - Getting rubix data...
2025-07-01 14:19:13,009 - rubix - INFO - Rubix galaxy file already exists, skipping conversion
2025-07-01 14:19:13,045 - rubix - INFO - Centering stars particles
2025-07-01 14:19:13,713 - rubix - WARNING - The Subset value is set in config. Using only subset of size 2 for stars
2025-07-01 14:19:13,714 - rubix - INFO - Data loaded with 2 star particles and 0 gas particles.
2025-07-01 14:19:13,715 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:19:13,715 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:19:13,715 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:19:13,718 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:13,741 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:13,902 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:19:13,911 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:13,951 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:14,038 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:14,057 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:14,145 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:14,403 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:19:14,403 - rubix - INFO - Compiling the expressions...
2025-07-01 14:19:14,404 - rubix - INFO - Number of devices: 2
2025-07-01 14:19:14,508 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:19:14,613 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:19:14,629 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:19:14,772 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:19:14,772 - rubix - INFO - Convolving with PSF...
2025-07-01 14:19:14,776 - rubix - INFO - Convolving with LSF...
2025-07-01 14:19:14,781 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:19:22,678 - rubix - INFO - Pipeline run completed in 8.96 seconds.

Set target values

# NBVAL_SKIP
import jax.numpy as jnp

# Target configuration for the two star particles: both get the target
# age/metallicity from the SSP grid, mass 1.0, zero velocity, and sit at the
# origin.
target_age = age_values[index_age]
target_metallicity = metallicity_values[index_metallicity]

inputdata.stars.age = jnp.array([target_age] * 2)
inputdata.stars.metallicity = jnp.array([target_metallicity] * 2)
inputdata.stars.mass = jnp.array([[1.0]] * 2)
inputdata.stars.velocity = jnp.array([[0.0] * 3] * 2)
inputdata.stars.coords = jnp.array([[0.0] * 3] * 2)
# NBVAL_SKIP
# Datacube for the target parameters; later passed to the loss as the target.
targetdata = pipe.run_sharded(inputdata)
2025-07-01 14:19:22,773 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:19:22,774 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:19:22,775 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:19:22,776 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:22,787 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:22,797 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:19:22,807 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:22,827 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:22,855 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:22,867 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:22,907 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:22,939 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:19:22,940 - rubix - INFO - Compiling the expressions...
2025-07-01 14:19:22,941 - rubix - INFO - Number of devices: 2
2025-07-01 14:19:23,052 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:19:23,131 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:19:23,134 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:19:23,303 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:19:23,304 - rubix - INFO - Convolving with PSF...
2025-07-01 14:19:23,307 - rubix - INFO - Convolving with LSF...
2025-07-01 14:19:23,310 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:19:31,497 - rubix - INFO - Pipeline run completed in 8.72 seconds.
# NBVAL_SKIP
# Spectrum at spatial index (0, 0): one value per wavelength bin.
print(targetdata[0,0,:].shape)
(466,)

Set the initial cube

# NBVAL_SKIP
# Starting configuration for the optimisation: same particle layout as the
# target (mass 1.0, at rest at the origin) but with the initial grid values.
start_age = age_values[initial_age_index]
start_metallicity = metallicity_values[initial_metallicity_index]

inputdata.stars.age = jnp.array([start_age] * 2)
inputdata.stars.metallicity = jnp.array([start_metallicity] * 2)
inputdata.stars.mass = jnp.array([[1.0]] * 2)
inputdata.stars.velocity = jnp.array([[0.0] * 3] * 2)
inputdata.stars.coords = jnp.array([[0.0] * 3] * 2)
# NBVAL_SKIP
# Datacube at the starting parameters, for comparison with the target.
initialdata = pipe.run_sharded(inputdata)
2025-07-01 14:19:31,563 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:19:31,564 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:19:31,565 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:19:31,566 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:31,577 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:31,587 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:19:31,596 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:31,610 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:31,638 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:31,650 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:31,684 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:31,716 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:19:31,716 - rubix - INFO - Compiling the expressions...
2025-07-01 14:19:31,717 - rubix - INFO - Number of devices: 2
2025-07-01 14:19:31,804 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:19:31,882 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:19:31,884 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:19:31,892 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:19:31,892 - rubix - INFO - Convolving with PSF...
2025-07-01 14:19:31,895 - rubix - INFO - Convolving with LSF...
2025-07-01 14:19:31,898 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:19:38.155820: E external/xla/xla/service/slow_operation_alarm.cc:73] Constant folding an instruction is taking > 1s:

  %gather.185 = f32[105,1,12,5994]{3,2,1,0} gather(%constant.3104, %iota.28), offset_dims={1,2,3}, collapsed_slice_dims={}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,12,5994}, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jit(shmap_body)/jit(<unnamed wrapped function>)/while/body/jit(interp2d)/jit(fun)/jit(_take)/gather" source_file="/tmp/ipykernel_1868132/3725520555.py" source_line=4}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2025-07-01 14:19:38.206841: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 1.051189572s
Constant folding an instruction is taking > 1s:

  %gather.185 = f32[105,1,12,5994]{3,2,1,0} gather(%constant.3104, %iota.28), offset_dims={1,2,3}, collapsed_slice_dims={}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,12,5994}, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jit(shmap_body)/jit(<unnamed wrapped function>)/while/body/jit(interp2d)/jit(fun)/jit(_take)/gather" source_file="/tmp/ipykernel_1868132/3725520555.py" source_line=4}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2025-07-01 14:19:39,618 - rubix - INFO - Pipeline run completed in 8.05 seconds.

Adam optimizer

# NBVAL_SKIP
# Build a second pipeline whose assembled and compiled function is exposed as
# ``pipeline_instance.func``; the loss below calls it directly so that
# jax.value_and_grad can differentiate through the forward model.
# NOTE(review): this reaches into private RubixPipeline members
# (_pipeline, _get_pipeline_functions); presumably it bypasses the sharded
# wrapper used by run_sharded — confirm against the RUBIX source.
from rubix.pipeline import linear_pipeline as pipeline

pipeline_instance = RubixPipeline(config)

pipeline_instance._pipeline = pipeline.LinearTransformerPipeline(
    pipeline_instance.pipeline_config, 
    pipeline_instance._get_pipeline_functions()
)
pipeline_instance._pipeline.assemble()
pipeline_instance.func = pipeline_instance._pipeline.compile_expression()
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:39,666 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:19:39,666 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:19:39,667 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:19:39,669 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:39,679 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:39,689 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:19:39,699 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:39,726 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:39,768 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:39,780 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:19:39,826 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
# NBVAL_SKIP
import optax


def loss_only_wrt_age_metallicity(
    age,
    metallicity,
    base_data,
    target,
    age_scale=20,
    metallicity_scale=0.05,
    eps=1e-12,
):
    """Return log10 of the cosine-distance loss between model and target cubes.

    The optimiser works on dimensionless parameters; they are multiplied by
    ``age_scale`` / ``metallicity_scale`` before entering the pipeline.

    Args:
        age: dimensionless stellar ages (physical age = age * age_scale).
        metallicity: dimensionless metallicities
            (physical metallicity = metallicity * metallicity_scale).
        base_data: RUBIX input data. Its ``stars.age`` and
            ``stars.metallicity`` are replaced only for the forward pass and
            restored afterwards, so no traced values leak into the caller's
            object when this runs under ``jax.grad``.
        target: reference datacube compared with ``output.stars.datacube``.
        age_scale: factor mapping ``age`` to physical age.
        metallicity_scale: factor mapping ``metallicity`` to physical metallicity.
        eps: lower bound applied to the loss before the log, so rounding to
            zero or slightly below does not produce -inf/NaN.

    Returns:
        Scalar ``log10(max(sum(cosine_distance), eps))``.
    """
    original_age = base_data.stars.age
    original_metallicity = base_data.stars.metallicity
    try:
        base_data.stars.age = age * age_scale
        base_data.stars.metallicity = metallicity * metallicity_scale
        output = pipeline_instance.func(base_data)
    finally:
        # Restore the caller's values even if the forward pass raises.
        base_data.stars.age = original_age
        base_data.stars.metallicity = original_metallicity

    loss = jnp.sum(optax.cosine_distance(output.stars.datacube, target))
    # Cosine distance is >= 0 mathematically, but float error can push it to
    # 0 or just below; jnp.maximum still propagates a genuine NaN.
    return jnp.log10(jnp.maximum(loss, eps))
# NBVAL_SKIP
import jax
import jax.numpy as jnp
import optax


def adam_optimization_multi(
    loss_fn,
    params_init,
    data,
    target,
    learning=learning_all,
    tol=1e-3,
    max_iter=500,
    stop_on_nonfinite=True,
):
    """Jointly optimise age and metallicity with Adam.

    Args:
        loss_fn: callable ``loss_fn(age, metallicity, data, target)`` returning
            a scalar.
        params_init: dict with keys ``'age'`` and ``'metallicity'``, each a
            1-D JAX array.
        data: base data passed through to ``loss_fn``.
        target: target passed through to ``loss_fn``.
        learning: Adam learning rate, shared by both parameters.
        tol: stop once the global norm of the Adam update drops below this.
        max_iter: maximum number of iterations.
        stop_on_nonfinite: stop as soon as the loss or the gradient becomes
            NaN/inf, instead of iterating on meaningless values.

    Returns:
        Tuple ``(params, age_history, metallicity_history, loss_history)``:
        the final parameter dict, then per-iteration lists of the first age
        entry, the first metallicity entry and the loss value.
    """
    # One Adam instance per parameter via multi_transform, so the two
    # learning rates can be tuned independently if needed.
    optimizer = optax.multi_transform(
        {'age': optax.adam(learning), 'metallicity': optax.adam(learning)},
        {'age': 'age', 'metallicity': 'metallicity'},
    )
    params = params_init
    optimizer_state = optimizer.init(params)

    # Build the value-and-gradient function once rather than every iteration.
    value_and_grad_fn = jax.value_and_grad(
        lambda p: loss_fn(p['age'], p['metallicity'], data, target)
    )

    age_history = []
    metallicity_history = []
    loss_history = []

    for i in range(max_iter):
        loss, grads = value_and_grad_fn(params)
        loss_history.append(float(loss))
        age_history.append(float(params['age'][0]))
        metallicity_history.append(float(params['metallicity'][0]))

        # With a NaN loss the update norm is NaN and ``NaN < tol`` is False,
        # so without this check the loop would run to max_iter on NaNs.
        if stop_on_nonfinite and not bool(
            jnp.isfinite(loss) & jnp.isfinite(optax.global_norm(grads))
        ):
            print(f"Non-finite loss or gradient at iteration {i}; stopping.")
            break

        updates, optimizer_state = optimizer.update(grads, optimizer_state)
        params = optax.apply_updates(params, updates)

        # Convergence: the step taken by Adam has become negligible.
        if optax.global_norm(updates) < tol:
            print(f"Converged at iteration {i}")
            break

    return params, age_history, metallicity_history, loss_history
# NBVAL_SKIP
# Sanity check: evaluate the loss once at the starting point. The loss expects
# dimensionless inputs (age / 20, metallicity / 0.05) and multiplies them back
# up, so rescale the physical values stored in ``inputdata`` first.
loss_only_wrt_age_metallicity(
    inputdata.stars.age / 20,
    inputdata.stars.metallicity / 0.05,
    inputdata,
    targetdata,
)
2025-07-01 14:19:39,970 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:19:40,063 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:19:40,081 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:19:40,253 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:19:40,254 - rubix - INFO - Convolving with PSF...
2025-07-01 14:19:40,257 - rubix - INFO - Convolving with LSF...
2025-07-01 14:19:40,263 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
Array(nan, dtype=float64)
# NBVAL_SKIP
data = inputdata  # Replace with your actual data if needed
target_value = targetdata  # Replace with your actual target

# Both particles start from the same grid point. The optimiser works in
# dimensionless units: the loss multiplies age by 20 and metallicity by 0.05.
start_age_scaled = age_values[initial_age_index] / 20
start_metallicity_scaled = metallicity_values[initial_metallicity_index] / 0.05

age_init = jnp.array([start_age_scaled] * 2)
metallicity_init = jnp.array([start_metallicity_scaled] * 2)

params_init = dict(age=age_init, metallicity=metallicity_init)
print(f"Initial parameters: {params_init}")

optimized_params, age_history, metallicity_history, loss_history = adam_optimization_multi(
    loss_fn=loss_only_wrt_age_metallicity,
    params_init=params_init,
    data=data,
    target=target_value,
    learning=learning_all,
    tol=tol,
    max_iter=5000,
)

for name, label in (('age', 'Age'), ('metallicity', 'Metallicity')):
    print(f"Optimized {label}: {optimized_params[name]}")
Initial parameters: {'age': Array([0.7924467, 0.7924467], dtype=float32), 'metallicity': Array([0.5050313, 0.5050313], dtype=float32)}
2025-07-01 14:19:58,613 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:19:58,727 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:19:58,729 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:19:58,834 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:19:58,835 - rubix - INFO - Convolving with PSF...
2025-07-01 14:19:58,838 - rubix - INFO - Convolving with LSF...
2025-07-01 14:19:58,842 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
Optimized Age: [nan nan]
Optimized Metallicity: [nan nan]

Loss history

# NBVAL_SKIP
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# Convert histories to NumPy arrays for plotting.
loss_history_np = np.array(loss_history)
age_history_np = np.array(age_history)
metallicity_history_np = np.array(metallicity_history)

# One x value per recorded iteration (all histories have the same length).
iterations = np.arange(len(loss_history_np))
print(f"Number of iterations: {len(iterations)}")
# Right end of the red reference lines; avoids IndexError on an empty history.
last_iteration = iterations[-1] if iterations.size else 0

fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharex=True)

# The loss function returns log10(loss), so undo the log for display.
axs[0].plot(iterations, 10**loss_history_np, marker='o', linestyle='-')
axs[0].set_xlabel("Iteration")
axs[0].set_ylabel("Loss")
axs[0].set_title("Loss History")
axs[0].grid(True)

# Age history, multiplied by 20 to get back to the physical scale.
axs[1].plot(iterations, age_history_np * 20, marker='o', linestyle='-')
axs[1].hlines(y=age_values[index_age], xmin=0, xmax=last_iteration, color='r', linestyle='-')
axs[1].set_xlabel("Iteration")
axs[1].set_ylabel("Age")
axs[1].set_title("Age History")
axs[1].grid(True)

# Metallicity history, multiplied by 0.05 to get back to the physical scale.
axs[2].plot(iterations, metallicity_history_np * 0.05, marker='o', linestyle='-')
axs[2].hlines(y=metallicity_values[index_metallicity], xmin=0, xmax=last_iteration, color='r', linestyle='-')
axs[2].set_xlabel("Iteration")
axs[2].set_ylabel("Metallicity")
axs[2].set_title("Metallicity History")
axs[2].grid(True)

# The x-axis is shared (sharex=True), so one call sets the range of all panels.
axs[0].set_xlim(-5, 900)
plt.tight_layout()
os.makedirs("output", exist_ok=True)
plt.savefig("output/optimisation_history.jpg", dpi=1000)
plt.show()
Number of iterations: 5000
../_images/6133ff8625f2add75bf1ef30b0758a2b4eb5a405d071f494569846f2f932bb81.png
# NBVAL_SKIP
import matplotlib.pyplot as plt
import numpy as np

# Re-run the pipeline with the best parameters found by the optimiser: the
# iteration with the lowest finite loss (NaN/inf losses are never selected
# unless every entry is non-finite, in which case index 0 is used).
finite_losses = np.where(np.isfinite(loss_history), loss_history, np.inf)
i = int(np.argmin(finite_losses))

# Histories store dimensionless values; convert back to physical units.
best_age = age_history[i] * 20
best_metallicity = metallicity_history[i] * 0.05

inputdata.stars.age = jnp.array([best_age, best_age])
inputdata.stars.metallicity = jnp.array([best_metallicity, best_metallicity])
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

pipe = RubixPipeline(config)
rubixdata = pipe.run_sharded(inputdata)

# Plot the target and the optimized spectra at spatial index (0, 0).
wave = pipe.telescope.wave_seq

spectra_target = targetdata
spectra_optimitzed = rubixdata
print(rubixdata.shape)

plt.plot(wave, spectra_target[0,0,:], label=f"Target age = {age_values[index_age]:.2f}, metal. = {metallicity_values[index_metallicity]:.4f}")
plt.plot(wave, spectra_optimitzed[0,0,:], label=f"Optimized age = {best_age:.2f}, metal. = {best_metallicity:.4f}")
plt.xlabel("Wavelength [Å]")
plt.ylabel("Luminosity [L/Å]")
plt.title("Difference between target and optimized spectra")
plt.legend()
plt.grid(True)
plt.show()
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:21:55,601 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:21:55,602 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:21:55,603 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:21:55,605 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:21:55,615 - rubix - INFO - Getting cosmology...
2025-07-01 14:21:55,631 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:21:55,640 - rubix - INFO - Getting cosmology...
2025-07-01 14:21:55,659 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:21:55,691 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:21:55,704 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:21:55,743 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:21:55,793 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:21:55,794 - rubix - INFO - Compiling the expressions...
2025-07-01 14:21:55,794 - rubix - INFO - Number of devices: 2
2025-07-01 14:21:55,905 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:21:55,985 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:21:55,987 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:21:55,998 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:21:55,999 - rubix - INFO - Convolving with PSF...
2025-07-01 14:21:56,001 - rubix - INFO - Convolving with LSF...
2025-07-01 14:21:56,005 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:22:03,786 - rubix - INFO - Pipeline run completed in 8.18 seconds.
(1, 1, 466)
../_images/0dd87d23b6c1851fe279dc7dec8af30abac5e769698219e956e438e2d04e1e00.png
# NBVAL_SKIP
# Re-run the pipeline with the parameters reached at iteration 850 of the
# optimisation and overlay the resulting spectrum on the target spectrum.
i = 850
age_at_i = age_history[i] * 20              # normalised -> physical age
metal_at_i = metallicity_history[i] * 0.05   # normalised -> physical metallicity

inputdata.stars.age = jnp.array([age_at_i, age_at_i])
inputdata.stars.metallicity = jnp.array([metal_at_i, metal_at_i])
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

pipe = RubixPipeline(config)
rubixdata = pipe.run_sharded(inputdata)

import matplotlib.pyplot as plt

wave = pipe.telescope.wave_seq
spectra_target, spectra_optimitzed = targetdata, rubixdata

# Draw target first, then the spectrum at iteration i, each from spaxel [0, 0].
curves = (
    (spectra_target, f"Target age = {age_values[index_age]:.2f}, metal. = {metallicity_values[index_metallicity]:.4f}"),
    (spectra_optimitzed, f"Optimized age = {age_at_i:.2f}, metal. = {metal_at_i:.4f}"),
)
for spectrum, curve_label in curves:
    plt.plot(wave, spectrum[0, 0, :], label=curve_label)

plt.xlabel("Wavelength [Å]")
plt.ylabel("Luminosity [L/Å]")
plt.title("Difference between target and optimized spectra")
plt.legend()
plt.grid(True)
plt.show()
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:22:03,947 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:22:03,948 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:22:03,948 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:22:03,950 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:22:03,961 - rubix - INFO - Getting cosmology...
2025-07-01 14:22:03,971 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:22:03,980 - rubix - INFO - Getting cosmology...
2025-07-01 14:22:03,999 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:22:04,034 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:22:04,045 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:22:04,083 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-07-01 14:22:04,125 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:22:04,126 - rubix - INFO - Compiling the expressions...
2025-07-01 14:22:04,127 - rubix - INFO - Number of devices: 2
2025-07-01 14:22:04,228 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:22:04,308 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:22:04,310 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:22:04,321 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:22:04,322 - rubix - INFO - Convolving with PSF...
2025-07-01 14:22:04,324 - rubix - INFO - Convolving with LSF...
2025-07-01 14:22:04,327 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:22:11,227 - rubix - INFO - Pipeline run completed in 7.28 seconds.
../_images/217431e942a9d8f0dbeaf3528df39b1da3bf5343bc8778c735285ae7b5ac4835.png
# NBVAL_SKIP
import os

import matplotlib as mpl
import matplotlib.pyplot as plt

# Two stacked panels sharing the wavelength axis: spectra on top (4/5 of the
# height) and the residual underneath (1/5).
fig, (ax1, ax2) = plt.subplots(nrows=2, sharex=True, gridspec_kw={'height_ratios': [4, 1]}, figsize=(7, 5))

# Upper panel: target spectrum and the spectrum from the most recent pipeline
# run (parameters taken from iteration `i` set in an earlier cell).
ax1.plot(wave, spectra_target[0, 0, :], label=f"Target age = {age_values[index_age]:.2f}, metallicity = {metallicity_values[index_metallicity]:.4f}")
ax1.plot(wave, spectra_optimitzed[0, 0, :], label=f"Optimized age = {age_history[i]*20:.2f}, metallicity = {metallicity_history[i]*0.05:.4f}")
ax1.set_ylabel("Luminosity [L/Å]")
ax1.legend()
ax1.grid(True)

# Absolute residual (target - optimized); not normalised by the target flux.
residual = spectra_target[0, 0, :] - spectra_optimitzed[0, 0, :]

# Lower panel: residual.
ax2.plot(wave, residual, 'k-')
ax2.set_xlabel("Wavelength [Å]")
ax2.set_ylabel("Residual")
ax2.grid(True)

plt.tight_layout()
# savefig does not create missing directories, so make sure output/ exists.
os.makedirs("output", exist_ok=True)
plt.savefig("output/optimisation_spectra.jpg", dpi=1000)
plt.show()
../_images/3f649625e33b85dcadbe1eeb42adf03440739fab64e3907efa2af258bcc3971a.png

Calculate loss landscape#

# NBVAL_SKIP
import optax

def loss_only_wrt_age_metallicity(age, metallicity, base_data, target):
    """Return the summed cosine distance between the pipeline output and ``target``.

    ``age`` and ``metallicity`` are given in the optimiser's normalised units and
    are rescaled here (age x 20, metallicity x 0.05). The same rescaled value is
    written to both entries of ``base_data.stars.age`` / ``.metallicity``, and the
    result is run through the global ``pipeline_instance.func`` (defined in an
    earlier cell).

    NOTE: ``base_data.stars`` is modified in place, so after the call the
    caller's object holds the last age/metallicity evaluated.
    """
    scaled_age = age * 20
    scaled_metallicity = metallicity * 0.05
    base_data.stars.age = jnp.array([scaled_age, scaled_age])
    base_data.stars.metallicity = jnp.array([scaled_metallicity, scaled_metallicity])

    prediction = pipeline_instance.func(base_data).stars.datacube
    return jnp.sum(optax.cosine_distance(prediction, target))
# NBVAL_SKIP
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Number of grid points per axis; the scan makes num_steps**2 pipeline calls.
num_steps = 100

# Grids in the optimiser's *normalised* units, both spanning [0, 1].
# loss_only_wrt_age_metallicity rescales them internally, so this covers
# age 0-20 and metallicity 0-0.05 in physical units.
physical_ages = jnp.linspace(0, 1, num_steps)
physical_metals = jnp.linspace(0, 1, num_steps)

# loss_map[a, m] is the loss at (physical_ages[a], physical_metals[m]):
# rows follow the age grid, columns the metallicity grid.
# NOTE: every call writes into inputdata.stars, so after this cell inputdata
# holds the last grid point (normalised age 1, metallicity 1).
loss_map = jnp.stack([
    jnp.stack([
        loss_only_wrt_age_metallicity(age, metal, inputdata, targetdata)
        for metal in physical_metals
    ])
    for age in physical_ages
])
2025-07-01 14:22:12,033 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:22:12,100 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:22:12,102 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:22:12,448 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:22:12,449 - rubix - INFO - Convolving with PSF...
2025-07-01 14:22:12,451 - rubix - INFO - Convolving with LSF...
2025-07-01 14:22:12,454 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
# NBVAL_SKIP
# Plot the loss landscape in the optimiser's normalised coordinates
# (metallicity / 0.05 on x, age / 20 on y) with the optimisation path on top.
import os

import matplotlib.pyplot as plt
import matplotlib.colors as colors
plt.figure(figsize=(5, 4))
# loss_map rows follow the age grid and columns the metallicity grid, so with
# origin='lower' age runs along y and metallicity along x.
plt.imshow(loss_map, origin='lower', extent=[0, 1, 0, 1], aspect='auto', norm=colors.LogNorm())
plt.xlabel('Metallicity')
plt.ylabel('Age')
plt.title('Loss Landscape')
plt.colorbar(label='loss')
# Optimisation path (the histories are still in normalised units here).
plt.plot(metallicity_history[:], age_history[:])
# Target (orange) and starting point (white), converted to normalised units.
# The colours match the physical-units landscape plot so the two markers can be told apart.
plt.plot(metallicity_values[index_metallicity]/0.05, age_values[index_age]/20, marker='o', color='orange', markersize=8)
plt.plot(metallicity_values[initial_metallicity_index]/0.05, age_values[initial_age_index]/20, 'wo', markersize=8)
# savefig does not create missing directories, so make sure output/ exists.
os.makedirs("output", exist_ok=True)
# Distinct filename: the physical-units plot saves to
# output/optimisation_losslandscape.jpg and would otherwise overwrite this one.
plt.savefig("output/optimisation_losslandscape_normalised.jpg", dpi=1000)
plt.show()
../_images/4a22fa48097686bd009f16175584324d2982483538139b38d5a28bc75fe2dc65.png
# NBVAL_SKIP
# Convert the optimisation histories from normalised to physical units
# (metallicity x 0.05, age x 20). NOTE: the names are rebound in place, so
# re-running this cell rescales the histories a second time.
metallicity_history = 0.05 * np.array(metallicity_history)
age_history = 20 * np.array(age_history)
# NBVAL_SKIP
import os

import matplotlib.pyplot as plt
import matplotlib.colors as colors

plt.figure(figsize=(6, 5))

# Same loss map, with the axes mapped to physical units: the normalised grid
# [0, 1] x [0, 1] corresponds to metallicity 0-0.05 (x) and age 0-20 (y).
plt.imshow(loss_map, origin='lower', extent=[0, 0.05, 0, 20], aspect='auto', norm=colors.LogNorm())

plt.xlabel('Metallicity')
plt.ylabel('Age')
plt.title('Loss Landscape')
plt.colorbar(label='loss')

# Optimisation path. The histories were already converted to physical units
# by the previous cell, so no rescaling happens here.
plt.plot(metallicity_history[:], age_history[:])

# Target (orange) and starting point (white) in physical units.
plt.plot(metallicity_values[index_metallicity], age_values[index_age], marker='o', color='orange', markersize=8)
plt.plot(metallicity_values[initial_metallicity_index], age_values[initial_age_index], 'wo', markersize=8)

# savefig does not create missing directories, so make sure output/ exists.
os.makedirs("output", exist_ok=True)
plt.savefig("output/optimisation_losslandscape.jpg", dpi=1000)
plt.show()
../_images/1a97fb8f74253f805458df4257ca41b5215efa2c94b9c2801be06cc2bcecd8eb.png