# Gradient through RUBIX
# NBVAL_SKIP
# Configure JAX before any device is queried: expose two host CPU devices so
# that the sharded pipeline runs below have more than one device available.
from jax import config
# 64-bit mode is left disabled here; uncomment the next line to enable it.
#config.update("jax_enable_x64", True)
config.update('jax_num_cpu_devices', 2)
#NBVAL_SKIP
import os
# Optional (disabled): restrict which GPUs are visible to JAX. Note that the
# value '1,2' selects the devices numbered 1 and 2.
#os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'
import jax
# Now JAX will list two CpuDevice entries
print(jax.devices())
# → [CpuDevice(id=0), CpuDevice(id=1)]
[CpuDevice(id=0), CpuDevice(id=1)]
# NBVAL_SKIP
import os
# Point SPS_HOME at a local FSPS data directory. The path is machine-specific
# and this assignment overrides any value already present in the environment.
# NOTE(review): presumably read when the FSPS template is loaded below — confirm.
os.environ['SPS_HOME'] = '/home/annalena_data/sps_fsps'
## Load SSP template from FSPS
# NBVAL_SKIP
from rubix.spectra.ssp.factory import get_ssp_template
# Fetch the "FSPS" SSP template; its age and metallicity grids are used below
# to pick the target and the starting point of the optimisation.
ssp_fsps = get_ssp_template("FSPS")
2025-07-01 14:19:12,434 - rubix - INFO -
___ __ _____ _____ __
/ _ \/ / / / _ )/ _/ |/_/
/ , _/ /_/ / _ |/ /_> <
/_/|_|\____/____/___/_/|_|
2025-07-01 14:19:12,435 - rubix - INFO - Rubix version: 0.0.post465+g01a25a7.d20250701
2025-07-01 14:19:12,435 - rubix - INFO - JAX version: 0.6.0
2025-07-01 14:19:12,436 - rubix - INFO - Running on [CpuDevice(id=0), CpuDevice(id=1)] devices
# NBVAL_SKIP
# Pull the age and metallicity grids off the SSP template and report the
# shape of each one.
age_values, metallicity_values = ssp_fsps.age, ssp_fsps.metallicity
for grid in (age_values, metallicity_values):
    print(grid.shape)
(107,)
(12,)
# NBVAL_SKIP
# Indices into the SSP grids for the target that the optimiser should recover.
index_age = 90
index_metallicity = 9
# Alternative starting point (disabled).
#initial_metallicity_index = 5
#initial_age_index = 70
# Indices into the SSP grids for the starting point of the optimisation.
initial_metallicity_index = 10
initial_age_index = 104
# Learning rate handed to both Adam optimisers, and the threshold on the
# update norm that the optimisation loop uses as its convergence criterion.
learning_all = 5e-2
tol = 1e-10
print(f"start age: {age_values[initial_age_index]}, start metallicity: {metallicity_values[initial_metallicity_index]}")
print(f"target age: {age_values[index_age]}, target metallicity: {metallicity_values[index_metallicity]}")
start age: 15.848933219909668, start metallicity: 0.025251565501093864
target age: 3.1622776985168457, target metallicity: 0.014199999161064625
## Configure pipeline
# NBVAL_SKIP
from rubix.core.pipeline import RubixPipeline
import os
# Pipeline configuration used for every RubixPipeline built in this notebook.
# NOTE: this rebinds the name `config`, which earlier referred to `jax.config`.
config = {
    "pipeline": {"name": "calc_gradient"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # The API key is taken from the environment (None when unset).
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {
            "id": 14,
            "reuse": True,
        },
        # Only a subset of two particles is requested.
        "subset": {
            "use_subset": True,
            "subset_size": 2,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            "path": "data/galaxy-id-14.hdf5",
        },
    },
    "output_path": "output",
    "telescope": {
        "name": "TESTGRADIENT",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 1.2},
        "noise": {"signal_to_noise": 100, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },
    "ssp": {
        "template": {"name": "FSPS"},
        "dust": {
            "extinction_model": "Cardelli89",
            "dust_to_gas_ratio": 0.01,
            "dust_to_metals_ratio": 0.4,
            "dust_grain_density": 3.5,
            "Rv": 3.1,
        },
    },
}
# NBVAL_SKIP
# Build the pipeline from the configuration, load the input data and run it
# once through the sharded entry point.
# NOTE(review): `output` is not referenced again in the visible cells.
pipe = RubixPipeline(config)
inputdata = pipe.prepare_data()
output = pipe.run_sharded(inputdata)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:13,008 - rubix - INFO - Getting rubix data...
2025-07-01 14:19:13,009 - rubix - INFO - Rubix galaxy file already exists, skipping conversion
2025-07-01 14:19:13,045 - rubix - INFO - Centering stars particles
2025-07-01 14:19:13,713 - rubix - WARNING - The Subset value is set in config. Using only subset of size 2 for stars
2025-07-01 14:19:13,714 - rubix - INFO - Data loaded with 2 star particles and 0 gas particles.
2025-07-01 14:19:13,715 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:19:13,715 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:19:13,715 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:19:13,718 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:13,741 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:13,902 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:19:13,911 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:13,951 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:14,038 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:14,057 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:14,145 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:14,403 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:19:14,403 - rubix - INFO - Compiling the expressions...
2025-07-01 14:19:14,404 - rubix - INFO - Number of devices: 2
2025-07-01 14:19:14,508 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:19:14,613 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:19:14,629 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:19:14,772 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:19:14,772 - rubix - INFO - Convolving with PSF...
2025-07-01 14:19:14,776 - rubix - INFO - Convolving with LSF...
2025-07-01 14:19:14,781 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:19:22,678 - rubix - INFO - Pipeline run completed in 8.96 seconds.
## Set target values
# NBVAL_SKIP
import jax.numpy as jnp
# Overwrite the two star particles with identical target properties: age and
# metallicity taken from the SSP grids at the target indices, unit mass, zero
# velocity and coordinates at the origin.
inputdata.stars.age = jnp.array([age_values[index_age], age_values[index_age]])
inputdata.stars.metallicity = jnp.array([metallicity_values[index_metallicity], metallicity_values[index_metallicity]])
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
inputdata.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
# NBVAL_SKIP
# Run the pipeline on the target configuration; the result is the reference
# cube that the optimisation tries to reproduce.
targetdata = pipe.run_sharded(inputdata)
2025-07-01 14:19:22,773 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:19:22,774 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:19:22,775 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:19:22,776 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:22,787 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:22,797 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:19:22,807 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:22,827 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:22,855 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:22,867 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:22,907 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:22,939 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:19:22,940 - rubix - INFO - Compiling the expressions...
2025-07-01 14:19:22,941 - rubix - INFO - Number of devices: 2
2025-07-01 14:19:23,052 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:19:23,131 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:19:23,134 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:19:23,303 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:19:23,304 - rubix - INFO - Convolving with PSF...
2025-07-01 14:19:23,307 - rubix - INFO - Convolving with LSF...
2025-07-01 14:19:23,310 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:19:31,497 - rubix - INFO - Pipeline run completed in 8.72 seconds.
# NBVAL_SKIP
# Shape of the spectrum stored at spatial index (0, 0) of the target cube.
print(targetdata[0,0,:].shape)
(466,)
## Set initial cube
# NBVAL_SKIP
# Overwrite the two star particles with the starting-point properties: age and
# metallicity taken from the SSP grids at the initial indices, unit mass, zero
# velocity and coordinates at the origin.
inputdata.stars.age = jnp.array([age_values[initial_age_index], age_values[initial_age_index]])
inputdata.stars.metallicity = jnp.array([metallicity_values[initial_metallicity_index], metallicity_values[initial_metallicity_index]])
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
inputdata.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
# NBVAL_SKIP
# Run the pipeline on the starting configuration.
initialdata = pipe.run_sharded(inputdata)
2025-07-01 14:19:31,563 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:19:31,564 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:19:31,565 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:19:31,566 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:31,577 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:31,587 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:19:31,596 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:31,610 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:31,638 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:31,650 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:31,684 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:31,716 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:19:31,716 - rubix - INFO - Compiling the expressions...
2025-07-01 14:19:31,717 - rubix - INFO - Number of devices: 2
2025-07-01 14:19:31,804 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:19:31,882 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:19:31,884 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:19:31,892 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:19:31,892 - rubix - INFO - Convolving with PSF...
2025-07-01 14:19:31,895 - rubix - INFO - Convolving with LSF...
2025-07-01 14:19:31,898 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:19:38.155820: E external/xla/xla/service/slow_operation_alarm.cc:73] Constant folding an instruction is taking > 1s:
%gather.185 = f32[105,1,12,5994]{3,2,1,0} gather(%constant.3104, %iota.28), offset_dims={1,2,3}, collapsed_slice_dims={}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,12,5994}, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jit(shmap_body)/jit(<unnamed wrapped function>)/while/body/jit(interp2d)/jit(fun)/jit(_take)/gather" source_file="/tmp/ipykernel_1868132/3725520555.py" source_line=4}
This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.
If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2025-07-01 14:19:38.206841: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 1.051189572s
Constant folding an instruction is taking > 1s:
%gather.185 = f32[105,1,12,5994]{3,2,1,0} gather(%constant.3104, %iota.28), offset_dims={1,2,3}, collapsed_slice_dims={}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,12,5994}, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/jit(shmap_body)/jit(<unnamed wrapped function>)/while/body/jit(interp2d)/jit(fun)/jit(_take)/gather" source_file="/tmp/ipykernel_1868132/3725520555.py" source_line=4}
This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.
If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2025-07-01 14:19:39,618 - rubix - INFO - Pipeline run completed in 8.05 seconds.
## Adam optimizer
# NBVAL_SKIP
from rubix.pipeline import linear_pipeline as pipeline
# Assemble a second pipeline by hand so that the compiled pipeline function is
# available as `pipeline_instance.func` and can be called directly from the
# loss function defined below.
# NOTE(review): this relies on underscore-prefixed members of RubixPipeline
# (`_pipeline`, `_get_pipeline_functions`) — confirm they are stable to use.
pipeline_instance = RubixPipeline(config)
pipeline_instance._pipeline = pipeline.LinearTransformerPipeline(
pipeline_instance.pipeline_config,
pipeline_instance._get_pipeline_functions()
)
pipeline_instance._pipeline.assemble()
pipeline_instance.func = pipeline_instance._pipeline.compile_expression()
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:39,666 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:19:39,666 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:19:39,667 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:19:39,669 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:39,679 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:39,689 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:19:39,699 - rubix - INFO - Getting cosmology...
2025-07-01 14:19:39,726 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:39,768 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:39,780 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:19:39,826 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
# NBVAL_SKIP
import optax
def loss_only_wrt_age_metallicity(age, metallicity, base_data, target):
    """Return log10 of the cosine-distance loss between a model cube and ``target``.

    ``age`` and ``metallicity`` are rescaled parameters: they are multiplied by
    20 and 0.05 respectively before being written onto ``base_data.stars``.
    ``base_data`` is modified in place.

    Args:
        age: rescaled ages (the value written to the data is ``age * 20``).
        metallicity: rescaled metallicities (the value written to the data is
            ``metallicity * 0.05``).
        base_data: input data object whose ``stars.age`` and
            ``stars.metallicity`` are overwritten before the pipeline runs.
        target: reference cube compared against ``output.stars.datacube``.

    Returns:
        ``jnp.log10`` of the summed ``optax.cosine_distance`` between the
        pipeline output cube and ``target``.
    """
    # Undo the rescaling before handing the values to the pipeline.
    base_data.stars.age = age * 20
    base_data.stars.metallicity = metallicity * 0.05

    model_cube = pipeline_instance.func(base_data).stars.datacube

    # Cosine distance is the loss in use; squared-error, L2 and Huber
    # variants were tried previously and are not active.
    distance = optax.cosine_distance(model_cube, target)
    return jnp.log10(jnp.sum(distance))
# NBVAL_SKIP
import jax
import jax.numpy as jnp
import optax
def adam_optimization_multi(loss_fn, params_init, data, target, learning=learning_all, tol=1e-3, max_iter=500):
    """
    Optimizes both age and metallicity with Adam.

    Args:
        loss_fn: function with signature loss_fn(age, metallicity, data, target)
        params_init: dict with keys 'age' and 'metallicity', each a JAX array
        data: base data for the loss function
        target: target data for the loss function
        learning: learning rate used for both Adam optimizers
        tol: tolerance for convergence (based on the global norm of the updates)
        max_iter: maximum number of iterations

    Returns:
        params: final parameters (dict)
        age_history: first entry of params['age'] (float) at each iteration
        metallicity_history: first entry of params['metallicity'] (float) at
            each iteration
        loss_history: loss value (float) at each iteration

    The loop stops early when the update norm drops below ``tol`` or when the
    loss is not finite. In the latter case the parameters that produced the
    non-finite loss are returned unchanged and the offending loss value is the
    last entry of ``loss_history``.
    """
    params = params_init  # e.g., {'age': jnp.array(...), 'metallicity': jnp.array(...)}

    # One Adam transform per parameter, routed by a label pytree that matches
    # the structure of params.
    optimizer = optax.multi_transform(
        {'age': optax.adam(learning), 'metallicity': optax.adam(learning)},
        {'age': 'age', 'metallicity': 'metallicity'},
    )
    optimizer_state = optimizer.init(params)

    # Build the differentiated objective once instead of on every iteration.
    def objective(p):
        return loss_fn(p['age'], p['metallicity'], data, target)

    loss_and_grad = jax.value_and_grad(objective)

    age_history = []
    metallicity_history = []
    loss_history = []

    for i in range(max_iter):
        # Loss and gradients with respect to both parameters.
        loss, grads = loss_and_grad(params)
        loss_history.append(float(loss))

        # Record the parameters that produced this loss (first entry only).
        age_history.append(float(params['age'][0]))
        metallicity_history.append(float(params['metallicity'][0]))

        # A non-finite loss cannot recover: re-evaluating at the same
        # parameters gives the same value, and applying the update would
        # poison the Adam state. The NaN comparison against tol below would
        # also never be true, so the loop would otherwise run to max_iter.
        if not jnp.isfinite(loss):
            print(f"Stopping at iteration {i}: loss is not finite ({float(loss)})")
            break

        updates, optimizer_state = optimizer.update(grads, optimizer_state)
        params = optax.apply_updates(params, updates)

        # Parameter clipping to physical bounds is intentionally not applied;
        # add jnp.clip calls on params['age'] / params['metallicity'] here if
        # constraints are needed.

        # Convergence check based on the global norm of the updates.
        if optax.global_norm(updates) < tol:
            print(f"Converged at iteration {i}")
            break

    return params, age_history, metallicity_history, loss_history
# NBVAL_SKIP
# Sanity-check the loss at the current starting point. The loss function
# expects rescaled parameters (it multiplies age by 20 and metallicity by
# 0.05), so the physical values stored on `inputdata` are divided by the same
# factors before being passed in.
loss_only_wrt_age_metallicity(inputdata.stars.age / 20, inputdata.stars.metallicity / 0.05, inputdata, targetdata)
2025-07-01 14:19:39,970 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:19:40,063 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:19:40,081 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:19:40,253 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:19:40,254 - rubix - INFO - Convolving with PSF...
2025-07-01 14:19:40,257 - rubix - INFO - Convolving with LSF...
2025-07-01 14:19:40,263 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
Array(nan, dtype=float64)
# NBVAL_SKIP
# Inputs handed to the optimiser: the data object that the loss function
# rewrites, and the reference cube it is compared against.
data, target_value = inputdata, targetdata

# Starting point expressed in the rescaled units used by the loss function
# (age divided by 20, metallicity divided by 0.05), one entry per particle.
scaled_start_age = age_values[initial_age_index] / 20
scaled_start_metallicity = metallicity_values[initial_metallicity_index] / 0.05
age_init = jnp.array([scaled_start_age] * 2)
metallicity_init = jnp.array([scaled_start_metallicity] * 2)

params_init = {'age': age_init, 'metallicity': metallicity_init}
print(f"Initial parameters: {params_init}")

# Optimise both parameters together and unpack the final values and histories.
optimisation_result = adam_optimization_multi(
    loss_only_wrt_age_metallicity,
    params_init,
    data,
    target_value,
    learning=learning_all,
    tol=tol,
    max_iter=5000,
)
optimized_params, age_history, metallicity_history, loss_history = optimisation_result

print(f"Optimized Age: {optimized_params['age']}")
print(f"Optimized Metallicity: {optimized_params['metallicity']}")
Initial parameters: {'age': Array([0.7924467, 0.7924467], dtype=float32), 'metallicity': Array([0.5050313, 0.5050313], dtype=float32)}
2025-07-01 14:19:58,613 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:19:58,727 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:19:58,729 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:19:58,834 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:19:58,835 - rubix - INFO - Convolving with PSF...
2025-07-01 14:19:58,838 - rubix - INFO - Convolving with LSF...
2025-07-01 14:19:58,842 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
Optimized Age: [nan nan]
Optimized Metallicity: [nan nan]
## Loss history
# NBVAL_SKIP
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# Convert the recorded histories to NumPy arrays for plotting.
loss_history_np = np.array(loss_history)
age_history_np = np.array(age_history)
metallicity_history_np = np.array(metallicity_history)

# x-axis: one point per recorded iteration (same length for all histories).
iterations = np.arange(len(loss_history_np))
print(f"Number of iterations: {len(iterations)}")
# Right end of the target lines; guarded so an empty history does not raise.
last_iteration = iterations[-1] if len(iterations) else 0

# Three panels in one row with a shared x-axis.
fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharex=True)

# Loss panel: the history holds log10(loss), so undo the logarithm.
axs[0].plot(iterations, 10**loss_history_np, marker='o', linestyle='-')

# Age panel: rescale by 20 to physical units and mark the target age.
axs[1].plot(iterations, age_history_np * 20, marker='o', linestyle='-')
axs[1].hlines(y=age_values[index_age], xmin=0, xmax=last_iteration, color='r', linestyle='-')

# Metallicity panel: rescale by 0.05 to physical units and mark the target.
axs[2].plot(iterations, metallicity_history_np * 0.05, marker='o', linestyle='-')
axs[2].hlines(y=metallicity_values[index_metallicity], xmin=0, xmax=last_iteration, color='r', linestyle='-')

# Settings common to all three panels.
panel_labels = (
    ("Loss", "Loss History"),
    ("Age", "Age History"),
    ("Metallicity", "Metallicity History"),
)
for ax, (ylabel, title) in zip(axs, panel_labels):
    ax.set_xlabel("Iteration")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)
    ax.set_xlim(-5, 900)

plt.tight_layout()
# Make sure the target directory exists before writing the figure.
os.makedirs("output", exist_ok=True)
plt.savefig("output/optimisation_history.jpg", dpi=1000)
plt.show()
Number of iterations: 5000
# NBVAL_SKIP
# Re-run the pipeline with the parameters recorded at iteration `i` of the
# optimisation history and plot the resulting spectrum against the target.
# With i = 0 this is the starting point of the optimisation, even though the
# legend label below reads "Optimized".
#rubixdata.stars.age = optimized_age
i = 0
# History values are stored in rescaled units; convert back with the factors
# 20 (age) and 0.05 (metallicity).
inputdata.stars.age = jnp.array([age_history[i]*20, age_history[i]*20])
inputdata.stars.metallicity = jnp.array([metallicity_history[i]*0.05, metallicity_history[i]*0.05])
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
pipe = RubixPipeline(config)
rubixdata = pipe.run_sharded(inputdata)
# Plot the target spectrum and the spectrum for iteration `i` at spatial
# index (0, 0). Both spectra are drawn; no difference is computed.
import matplotlib.pyplot as plt
wave = pipe.telescope.wave_seq
spectra_target = targetdata
spectra_optimitzed = rubixdata
print(rubixdata.shape)
plt.plot(wave, spectra_target[0,0,:], label=f"Target age = {age_values[index_age]:.2f}, metal. = {metallicity_values[index_metallicity]:.4f}")
plt.plot(wave, spectra_optimitzed[0,0,:], label=f"Optimized age = {age_history[i]*20:.2f}, metal. = {metallicity_history[i]*0.05:.4f}")
plt.xlabel("Wavelength [Å]")
plt.ylabel("Luminosity [L/Å]")
plt.title("Difference between target and optimized spectra")
#plt.title(f"Loss {loss_history[i]:.2e}")
plt.legend()
#plt.ylim(0.00003, 0.00008)
plt.grid(True)
plt.show()
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:21:55,601 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:21:55,602 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:21:55,603 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:21:55,605 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:21:55,615 - rubix - INFO - Getting cosmology...
2025-07-01 14:21:55,631 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:21:55,640 - rubix - INFO - Getting cosmology...
2025-07-01 14:21:55,659 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:21:55,691 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:21:55,704 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:21:55,743 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:21:55,793 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:21:55,794 - rubix - INFO - Compiling the expressions...
2025-07-01 14:21:55,794 - rubix - INFO - Number of devices: 2
2025-07-01 14:21:55,905 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:21:55,985 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:21:55,987 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:21:55,998 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:21:55,999 - rubix - INFO - Convolving with PSF...
2025-07-01 14:21:56,001 - rubix - INFO - Convolving with LSF...
2025-07-01 14:21:56,005 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:22:03,786 - rubix - INFO - Pipeline run completed in 8.18 seconds.
(1, 1, 466)
# NBVAL_SKIP
# Re-run the pipeline with the parameters recorded late in the optimisation
# history and plot the resulting spectrum against the target.
#rubixdata.stars.age = optimized_age
# Iteration 850 is the one of interest; fall back to the last recorded
# iteration when the optimisation stopped earlier, instead of raising an
# IndexError on the history lookups below.
i = min(850, len(age_history) - 1)
# History values are stored in rescaled units; convert back with the factors
# 20 (age) and 0.05 (metallicity).
inputdata.stars.age = jnp.array([age_history[i]*20, age_history[i]*20])
inputdata.stars.metallicity = jnp.array([metallicity_history[i]*0.05, metallicity_history[i]*0.05])
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
pipe = RubixPipeline(config)
rubixdata = pipe.run_sharded(inputdata)
# Plot the target spectrum and the spectrum for iteration `i` at spatial
# index (0, 0). Both spectra are drawn; no difference is computed.
import matplotlib.pyplot as plt
wave = pipe.telescope.wave_seq
spectra_target = targetdata #.stars.datacube
spectra_optimitzed = rubixdata #.stars.datacube
plt.plot(wave, spectra_target[0,0,:], label=f"Target age = {age_values[index_age]:.2f}, metal. = {metallicity_values[index_metallicity]:.4f}")
plt.plot(wave, spectra_optimitzed[0,0,:], label=f"Optimized age = {age_history[i]*20:.2f}, metal. = {metallicity_history[i]*0.05:.4f}")
plt.xlabel("Wavelength [Å]")
plt.ylabel("Luminosity [L/Å]")
plt.title("Difference between target and optimized spectra")
#plt.title(f"Loss {loss_history[i]:.2e}")
plt.legend()
#plt.ylim(0.00003, 0.00008)
plt.grid(True)
plt.show()
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:22:03,947 - rubix - INFO - Setting up the pipeline...
2025-07-01 14:22:03,948 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-07-01 14:22:03,948 - rubix - DEBUG - Roataion Type found: edge-on
2025-07-01 14:22:03,950 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:22:03,961 - rubix - INFO - Getting cosmology...
2025-07-01 14:22:03,971 - rubix - INFO - Calculating spatial bin edges...
2025-07-01 14:22:03,980 - rubix - INFO - Getting cosmology...
2025-07-01 14:22:03,999 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:22:04,034 - rubix - DEBUG - SSP Wave: (5994,)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:22:04,045 - rubix - INFO - Getting cosmology...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:22:04,083 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
warnings.warn(
2025-07-01 14:22:04,125 - rubix - INFO - Assembling the pipeline...
2025-07-01 14:22:04,126 - rubix - INFO - Compiling the expressions...
2025-07-01 14:22:04,127 - rubix - INFO - Number of devices: 2
2025-07-01 14:22:04,228 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:22:04,308 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:22:04,310 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:22:04,321 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:22:04,322 - rubix - INFO - Convolving with PSF...
2025-07-01 14:22:04,324 - rubix - INFO - Convolving with LSF...
2025-07-01 14:22:04,327 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-07-01 14:22:11,227 - rubix - INFO - Pipeline run completed in 7.28 seconds.
# NBVAL_SKIP
# Target vs. optimized spectrum with a residual panel underneath; the figure
# is written to output/optimisation_spectra.jpg.
import os
import matplotlib as mpl
import matplotlib.pyplot as plt

# Two stacked panels sharing the wavelength axis; the upper panel gets
# four times the height of the lower one.
fig, (ax1, ax2) = plt.subplots(nrows=2, sharex=True, gridspec_kw={'height_ratios': [4, 1]}, figsize=(7, 5))

# Upper panel: target and optimized spectra of the spaxel at index (0, 0).
ax1.plot(wave, spectra_target[0, 0, :], label=f"Target age = {age_values[index_age]:.2f}, metallicity = {metallicity_values[index_metallicity]:.4f}")
ax1.plot(wave, spectra_optimitzed[0, 0, :], label=f"Optimized age = {age_history[i]*20:.2f}, metallicity = {metallicity_history[i]*0.05:.4f}")
ax1.set_ylabel("Luminosity [L/Å]")
ax1.legend()
ax1.grid(True)

# Lower panel: absolute (not relative) difference, target minus optimized.
residual = spectra_target[0, 0, :] - spectra_optimitzed[0, 0, :]
ax2.plot(wave, residual, 'k-')
ax2.set_xlabel("Wavelength [Å]")
ax2.set_ylabel("Residual")
ax2.grid(True)

plt.tight_layout()
# savefig does not create missing directories, so make sure the target
# folder exists before writing.
os.makedirs("output", exist_ok=True)
plt.savefig("output/optimisation_spectra.jpg", dpi=1000)
plt.show()
Calculate loss landscape#
# NBVAL_SKIP
import optax
def loss_only_wrt_age_metallicity(age, metallicity, base_data, target):
    """Return the cosine-distance loss for one (age, metallicity) pair.

    ``age`` is multiplied by 20 and ``metallicity`` by 0.05; each scaled
    value is written twice into a 1-D array on ``base_data.stars``. Note
    that ``base_data`` is modified in place. The module-level
    ``pipeline_instance.func`` is then evaluated on ``base_data`` and the
    summed ``optax.cosine_distance`` between the resulting
    ``stars.datacube`` and ``target`` is returned.
    """
    scaled_age = age * 20
    scaled_metallicity = metallicity * 0.05

    # Two entries per property, both carrying the same value.
    base_data.stars.age = jnp.array([scaled_age, scaled_age])
    base_data.stars.metallicity = jnp.array([scaled_metallicity, scaled_metallicity])

    result = pipeline_instance.func(base_data)
    return jnp.sum(optax.cosine_distance(result.stars.datacube, target))
# NBVAL_SKIP
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
# Grid resolution along each parameter axis.
num_steps = 100

# Both axes sample the scaled parameters on [0, 1]; the loss function
# applies the factors 20 (age) and 0.05 (metallicity) itself.
physical_ages = jnp.linspace(0, 1, num_steps)
physical_metals = jnp.linspace(0, 1, num_steps)

# loss_map[a, m] holds the loss at (physical_ages[a], physical_metals[m]):
# rows run along age, columns along metallicity.
loss_map = jnp.stack([
    jnp.stack([
        loss_only_wrt_age_metallicity(age, metal, inputdata, targetdata)
        for metal in physical_metals
    ])
    for age in physical_ages
])
2025-07-01 14:22:12,033 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-07-01 14:22:12,100 - rubix - INFO - Assigning particles to spaxels...
2025-07-01 14:22:12,102 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-07-01 14:22:12,448 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-07-01 14:22:12,449 - rubix - INFO - Convolving with PSF...
2025-07-01 14:22:12,451 - rubix - INFO - Convolving with LSF...
2025-07-01 14:22:12,454 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
# NBVAL_SKIP
# Loss landscape in the scaled [0, 1] x [0, 1] parameter coordinates, with
# the optimisation path and the target/start points drawn on top.
import os
import matplotlib.pyplot as plt
import matplotlib.colors as colors

plt.figure(figsize=(5, 4))
# Rows of loss_map run along age and columns along metallicity, so with
# origin='lower' metallicity ends up on the x-axis and age on the y-axis.
# The colour scale is logarithmic.
plt.imshow(loss_map, origin='lower', extent=[0, 1, 0, 1], aspect='auto', norm=colors.LogNorm())
plt.xlabel('Metallicity')
plt.ylabel('Age')
plt.title('Loss Landscape')
plt.colorbar(label='loss')

# Path taken by the optimiser (histories are still in scaled coordinates).
plt.plot(metallicity_history[:], age_history[:])

# Target point and starting point, both drawn as red dots. The grid values
# are divided by 0.05 / 20 to bring them into the scaled coordinates.
plt.plot(metallicity_values[index_metallicity]/0.05, age_values[index_age]/20, 'ro', markersize=8)
plt.plot(metallicity_values[initial_metallicity_index]/0.05, age_values[initial_age_index]/20, 'ro', markersize=8)

# savefig does not create missing directories, so make sure the target
# folder exists before writing.
os.makedirs("output", exist_ok=True)
# NOTE: the physical-units landscape plot further down saves to this same
# path and will overwrite this figure.
plt.savefig("output/optimisation_losslandscape.jpg", dpi=1000)
plt.show()
# NBVAL_SKIP
# Convert the recorded optimisation histories to NumPy arrays and apply the
# scale factors 0.05 (metallicity) and 20 (age).
# NOTE: the original names are rebound to the scaled arrays, so running this
# cell more than once applies the factors again.
metallicity_history = np.array(metallicity_history)*0.05
age_history = np.array(age_history)*20
# NBVAL_SKIP
# Loss landscape drawn in physical coordinates (metallicity 0-0.05,
# age 0-20), with the optimisation path and the target/start points.
import os
import matplotlib.pyplot as plt
import matplotlib.colors as colors

plt.figure(figsize=(6, 5))
# The same loss_map is shown; only the axis extent is relabelled to the
# physical ranges. The colour scale is logarithmic.
plt.imshow(loss_map, origin='lower', extent=[0, 0.05, 0, 20], aspect='auto', norm=colors.LogNorm())
plt.xlabel('Metallicity')
plt.ylabel('Age')
plt.title('Loss Landscape')
plt.colorbar(label='loss')

# Optimisation path. The histories are plotted as-is, so they must already
# have been converted to physical values before this cell runs.
plt.plot(metallicity_history[:], age_history[:])

# Target point (orange) and starting point (white), taken directly from
# the SSP grid values.
plt.plot(metallicity_values[index_metallicity], age_values[index_age], marker='o', color='orange', markersize=8)
plt.plot(metallicity_values[initial_metallicity_index], age_values[initial_age_index], 'wo', markersize=8)

# savefig does not create missing directories, so make sure the target
# folder exists before writing.
os.makedirs("output", exist_ok=True)
plt.savefig("output/optimisation_losslandscape.jpg", dpi=1000)
plt.show()