Gradient vs finite difference#

# NBVAL_SKIP
# NOTE(review): the name `config` imported here is rebound to the pipeline
# configuration dict in the "Configure pipeline" cell, so this jax config
# handle is no longer reachable after that cell. If float64 is wanted (the
# pipeline output warns that float64 is truncated to float32), call
# `config.update("jax_enable_x64", True)` here, before any arrays are created.
from jax import config
import os
import jax

# Show the devices JAX will use (a single CPU device in the recorded run).
print(jax.devices())
[CpuDevice(id=0)]
# NBVAL_SKIP
import os

# python-fsps looks up its data directory through the SPS_HOME environment
# variable. Fall back to this machine-specific path only when the variable
# is not already set, so an existing user configuration is not overwritten.
os.environ.setdefault('SPS_HOME', '/home/annalena/sps_fsps')

Load SSP template from FSPS#

# NBVAL_SKIP
from rubix.spectra.ssp.factory import get_ssp_template
# Load the FSPS SSP template. Its age and metallicity grids supply the
# parameter values used throughout this notebook.
# NOTE: in the recorded run python-fsps was not installed (see the warning in
# the output below), yet the template object was still returned and usable.
ssp_fsps = get_ssp_template("FSPS")
2025-11-10 17:11:57,608 - rubix - INFO - 
   ___  __  _____  _____  __
  / _ \/ / / / _ )/  _/ |/_/
 / , _/ /_/ / _  |/ /_>  <
/_/|_|\____/____/___/_/|_|


2025-11-10 17:11:57,608 - rubix - INFO - Rubix version: 0.0.post626+g42b4b7505.d20251110
2025-11-10 17:11:57,609 - rubix - INFO - JAX version: 0.7.2
2025-11-10 17:11:57,609 - rubix - INFO - Running on [CpuDevice(id=0)] devices
2025-11-10 17:11:57,609 - rubix - WARNING - python-fsps is not installed. Please install it to use this function. Install using pip install fsps and check the installation page: https://dfm.io/python-fsps/current/installation/ for more details. Especially, make sure to set all necessary environment variables.
# NBVAL_SKIP
# Parameter grids of the SSP template (107 ages and 12 metallicities in the
# recorded run).
age_values = ssp_fsps.age
metallicity_values = ssp_fsps.metallicity
print(age_values.shape)
print(metallicity_values.shape)
(107,)
(12,)
# NBVAL_SKIP
# Grid indices of the target parameters used to build the target datacube.
index_age = 90
index_metallicity = 9

# Grid indices of the starting parameters.
# An alternative starting point (previously left as commented-out code) is
# metallicity index 5 and age index 70.
initial_metallicity_index = 10
initial_age_index = 104

# Optimiser settings; not used by the cells in this part of the notebook
# (presumably intended for an optimisation loop).
learning_all = 1e-2
tol = 1e-10

print(f"start age: {age_values[initial_age_index]}, start metallicity: {metallicity_values[initial_metallicity_index]}")
print(f"target age: {age_values[index_age]}, target metallicity: {metallicity_values[index_metallicity]}")
start age: 15.848933219909668, start metallicity: 0.025251565501093864
target age: 3.1622776985168457, target metallicity: 0.014199999161064625

Configure pipeline#

# NBVAL_SKIP
from rubix.core.pipeline import RubixPipeline
import os
# Pipeline configuration for the gradient check: galaxy 14 of TNG50-1,
# reduced to a small stellar subset.
# NOTE: this rebinds the name `config`, which the first cell imported from jax.
config = {
    "pipeline": {"name": "calc_gradient"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    "data": {
        "name": "IllustrisAPI",
        "args": {
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {"id": 14, "reuse": True},
        # Keep only `subset_size` star particles.
        "subset": {"use_subset": True, "subset_size": 2},
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {"path": "data/galaxy-id-14.hdf5"},
    },
    "output_path": "output",
    "telescope": {
        "name": "TESTGRADIENT",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 1.2},
        "noise": {"signal_to_noise": 100, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },
    "ssp": {
        "template": {"name": "FSPS"},
        "dust": {
            "extinction_model": "Cardelli89",
            "dust_to_gas_ratio": 0.01,
            "dust_to_metals_ratio": 0.4,
            "dust_grain_density": 3.5,
            "Rv": 3.1,
        },
    },
}
# NBVAL_SKIP
# Build the pipeline from `config`, prepare the (subset) particle data and
# run the sharded pipeline once on the unmodified particle properties.
pipe = RubixPipeline(config)
inputdata = pipe.prepare_data()
output = pipe.run_sharded(inputdata)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-11-10 17:11:58,256 - rubix - INFO - Getting rubix data...
2025-11-10 17:11:58,257 - rubix - INFO - Rubix galaxy file already exists, skipping conversion
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/jax/_src/numpy/scalar_types.py:50: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/core/data.py:491: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  rubixdata.galaxy.center = jnp.array(data["subhalo_center"], dtype=jnp.float64)
2025-11-10 17:11:58,318 - rubix - INFO - Centering stars particles
2025-11-10 17:11:59,305 - rubix - WARNING - The Subset value is set in config. Using only subset of size 2 for stars
2025-11-10 17:11:59,305 - rubix - INFO - Data loaded with 2 star particles and 0 gas particles.
2025-11-10 17:11:59,306 - rubix - INFO - Data preparation completed in 1.05 seconds.
2025-11-10 17:11:59,306 - rubix - INFO - Setting up the pipeline...
2025-11-10 17:11:59,307 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-11-10 17:11:59,307 - rubix - DEBUG - Rotation Type found: edge-on
2025-11-10 17:11:59,310 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-11-10 17:11:59,337 - rubix - INFO - Getting cosmology...
2025-11-10 17:11:59,547 - rubix - INFO - Calculating spatial bin edges...
2025-11-10 17:11:59,556 - rubix - INFO - Getting cosmology...
2025-11-10 17:11:59,567 - rubix - INFO - Getting cosmology...
2025-11-10 17:11:59,652 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-11-10 17:11:59,807 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-11-10 17:12:00,094 - rubix - INFO - Assembling the pipeline...
2025-11-10 17:12:00,095 - rubix - INFO - Compiling the expressions...
2025-11-10 17:12:00,096 - rubix - INFO - Number of devices: 1
2025-11-10 17:12:00,180 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-11-10 17:12:00,181 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG
2025-11-10 17:12:00,181 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.
2025-11-10 17:12:00,286 - rubix - INFO - Assigning particles to spaxels...
2025-11-10 17:12:00,318 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-11-10 17:12:00,481 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-11-10 17:12:00,481 - rubix - INFO - Convolving with PSF...
2025-11-10 17:12:00,486 - rubix - INFO - Convolving with LSF...
2025-11-10 17:12:00,493 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-11-10 17:12:06,414 - rubix - INFO - Total time for sharded pipeline run: 7.11 seconds.

Set target values#

# NBVAL_SKIP
import jax.numpy as jnp

# Overwrite both star particles with identical, controlled properties: the
# target age and metallicity taken from the SSP grids, unit mass, zero
# velocity and position at the origin.
# NOTE: this mutates `inputdata` in place; the "Set initial datacube" cell
# later overwrites the same fields with the starting values.
inputdata.stars.age = jnp.array([age_values[index_age], age_values[index_age]])
inputdata.stars.metallicity = jnp.array([metallicity_values[index_metallicity], metallicity_values[index_metallicity]])
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
inputdata.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
# NBVAL_SKIP
# Target datacube for the comparison.
# NOTE(review): the last pipeline step is apply_noise (S/N 100, see the log),
# so the target presumably includes noise.
targetdata = pipe.run_sharded(inputdata)
2025-11-10 17:12:06,474 - rubix - INFO - Setting up the pipeline...
2025-11-10 17:12:06,475 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-11-10 17:12:06,476 - rubix - DEBUG - Rotation Type found: edge-on
2025-11-10 17:12:06,479 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-11-10 17:12:06,491 - rubix - INFO - Getting cosmology...
2025-11-10 17:12:06,502 - rubix - INFO - Calculating spatial bin edges...
2025-11-10 17:12:06,624 - rubix - INFO - Getting cosmology...
2025-11-10 17:12:06,635 - rubix - INFO - Getting cosmology...
2025-11-10 17:12:06,681 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-11-10 17:12:06,757 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-11-10 17:12:06,804 - rubix - INFO - Assembling the pipeline...
2025-11-10 17:12:06,804 - rubix - INFO - Compiling the expressions...
2025-11-10 17:12:06,806 - rubix - INFO - Number of devices: 1
2025-11-10 17:12:06,888 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-11-10 17:12:06,889 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG
2025-11-10 17:12:06,889 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.
2025-11-10 17:12:06,961 - rubix - INFO - Assigning particles to spaxels...
2025-11-10 17:12:06,963 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-11-10 17:12:06,971 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-11-10 17:12:06,972 - rubix - INFO - Convolving with PSF...
2025-11-10 17:12:06,974 - rubix - INFO - Convolving with LSF...
2025-11-10 17:12:06,977 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-11-10 17:12:12,833 - rubix - INFO - Total time for sharded pipeline run: 6.36 seconds.
# NBVAL_SKIP
# Length of the spectrum in the first spaxel (presumably one value per
# wavelength bin).
print(targetdata[0, 0].shape)
(466,)

Set initial datacube#

# NBVAL_SKIP
# Reset both star particles to the starting age and metallicity. Mass,
# velocity and position match the target setup, so only age and metallicity
# differ from the target.
inputdata.stars.age = jnp.array([age_values[initial_age_index], age_values[initial_age_index]])
inputdata.stars.metallicity = jnp.array([metallicity_values[initial_metallicity_index], metallicity_values[initial_metallicity_index]])
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
inputdata.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
# NBVAL_SKIP
# Datacube at the starting parameters.
initialdata = pipe.run_sharded(inputdata)
2025-11-10 17:12:12,913 - rubix - INFO - Setting up the pipeline...
2025-11-10 17:12:12,914 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-11-10 17:12:12,914 - rubix - DEBUG - Rotation Type found: edge-on
2025-11-10 17:12:12,916 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-11-10 17:12:12,929 - rubix - INFO - Getting cosmology...
2025-11-10 17:12:12,941 - rubix - INFO - Calculating spatial bin edges...
2025-11-10 17:12:12,951 - rubix - INFO - Getting cosmology...
2025-11-10 17:12:12,961 - rubix - INFO - Getting cosmology...
2025-11-10 17:12:13,008 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-11-10 17:12:13,099 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-11-10 17:12:13,159 - rubix - INFO - Assembling the pipeline...
2025-11-10 17:12:13,160 - rubix - INFO - Compiling the expressions...
2025-11-10 17:12:13,164 - rubix - INFO - Number of devices: 1
2025-11-10 17:12:13,243 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-11-10 17:12:13,244 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG
2025-11-10 17:12:13,244 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.
2025-11-10 17:12:13,315 - rubix - INFO - Assigning particles to spaxels...
2025-11-10 17:12:13,317 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-11-10 17:12:13,328 - rubix - DEBUG - Datacube shape: (1, 1, 466)
2025-11-10 17:12:13,328 - rubix - INFO - Convolving with PSF...
2025-11-10 17:12:13,331 - rubix - INFO - Convolving with LSF...
2025-11-10 17:12:13,334 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
2025-11-10 17:12:19,690 - rubix - INFO - Total time for sharded pipeline run: 6.78 seconds.

Adam optimizer#

# NBVAL_SKIP
from rubix.pipeline import linear_pipeline as pipeline

# Second pipeline instance. Its compiled function `func` is called directly
# by the loss function below, which is differentiated with
# jax.value_and_grad.
# NOTE(review): this relies on private RubixPipeline members (`_pipeline`,
# `_get_pipeline_functions`) and may break if their internals change.
pipeline_instance = RubixPipeline(config)

pipeline_instance._pipeline = pipeline.LinearTransformerPipeline(
    pipeline_instance.pipeline_config, 
    pipeline_instance._get_pipeline_functions()
)
pipeline_instance._pipeline.assemble()
pipeline_instance.func = pipeline_instance._pipeline.compile_expression()
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-11-10 17:12:19,786 - rubix - INFO - Setting up the pipeline...
2025-11-10 17:12:19,787 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}
2025-11-10 17:12:19,787 - rubix - DEBUG - Rotation Type found: edge-on
2025-11-10 17:12:19,789 - rubix - INFO - Calculating spatial bin edges...
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-11-10 17:12:19,798 - rubix - INFO - Getting cosmology...
2025-11-10 17:12:19,808 - rubix - INFO - Calculating spatial bin edges...
2025-11-10 17:12:19,818 - rubix - INFO - Getting cosmology...
2025-11-10 17:12:19,828 - rubix - INFO - Getting cosmology...
2025-11-10 17:12:19,871 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
2025-11-10 17:12:19,938 - rubix - DEBUG - Method not defined, using default method: cubic
/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml
  warnings.warn(
# NBVAL_SKIP
import optax

def loss_only_wrt_age_metallicity(age, metallicity, base_data, target):
    """Return log10 of the summed cosine distance between model and target.

    The optimisation variables are rescaled to be of order one. The pipeline
    receives stellar ages ``age * 20`` and metallicities
    ``metallicity * 0.05``.

    Args:
        age: Scaled stellar ages (one entry per star particle in this
            notebook).
        metallicity: Scaled stellar metallicities, same layout as ``age``.
        base_data: Prepared pipeline input. Its ``stars.age`` and
            ``stars.metallicity`` are replaced for the pipeline call and
            restored afterwards.
        target: Target datacube compared with ``output.stars.datacube``
            (here the array returned by ``pipe.run_sharded``).

    Returns:
        Scalar ``log10(sum(cosine_distance))``. This is ``-inf`` when the
        distance is exactly zero.
    """
    stars = base_data.stars
    saved_age, saved_metallicity = stars.age, stars.metallicity
    # Under jax.value_and_grad / jax.vmap the values written here are
    # tracers. Restoring the originals keeps those tracers from leaking into
    # the caller's object after the transformation has finished.
    stars.age = age * 20
    stars.metallicity = metallicity * 0.05
    try:
        output = pipeline_instance.func(base_data)
    finally:
        stars.age, stars.metallicity = saved_age, saved_metallicity

    loss = jnp.sum(optax.cosine_distance(output.stars.datacube, target))
    # The logarithm compresses the dynamic range of the loss.
    return jnp.log10(loss)
#NBVAL_SKIP
import jax

def compute_gradient(age, metallicity, base_data, target, loss_fn=None):
    """Return the gradient of the loss w.r.t. age and metallicity, and the loss.

    Args:
        age: Scaled ages, passed to the loss as its first argument.
        metallicity: Scaled metallicities, passed as the second argument.
        base_data: Passed unchanged to the loss.
        target: Passed unchanged to the loss.
        loss_fn: Loss with signature
            ``loss_fn(age, metallicity, base_data, target)``. Defaults to
            ``loss_only_wrt_age_metallicity``.

    Returns:
        ``(grads, loss)``. ``grads`` is the tuple
        ``(d loss / d age, d loss / d metallicity)`` and ``loss`` is the
        loss value at the given parameters.
    """
    if loss_fn is None:
        loss_fn = loss_only_wrt_age_metallicity
    # jax.value_and_grad returns a single function. Calling that function
    # yields the pair (value, grads).
    value_and_grad_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))
    loss, grads = value_and_grad_fn(age, metallicity, base_data, target)
    return grads, loss
#NBVAL_SKIP
#calculate gradient with jax
# Starting parameters in the scaled units used by the loss (age / 20,
# metallicity / 0.05), one entry for each of the two star particles.
age_init = jnp.array([age_values[initial_age_index]/20, age_values[initial_age_index]/20])
metallicity_init = jnp.array([metallicity_values[initial_metallicity_index]/0.05, metallicity_values[initial_metallicity_index]/0.05])


# Pack both initial parameters into a dictionary.
params_init = {'age': age_init, 'metallicity': metallicity_init}
print(f"Initial parameters: {params_init}")

data = inputdata
target_value = targetdata

# Differentiate with respect to the whole params dict; `grads` has the same
# structure as `params_init`.
loss, grads = jax.value_and_grad(lambda p: loss_only_wrt_age_metallicity(p['age'], p['metallicity'], data, target_value))(params_init)

print("grads:", grads)
print("loss:", loss)
2025-11-10 17:12:20,133 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0
2025-11-10 17:12:20,134 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG
2025-11-10 17:12:20,134 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.
2025-11-10 17:12:20,223 - rubix - INFO - Assigning particles to spaxels...
2025-11-10 17:12:20,239 - rubix - INFO - Calculating Data Cube (combined per‐particle)…
2025-11-10 17:12:20,375 - rubix - DEBUG - Datacube shape: (1, 1, 466)
Initial parameters: {'age': Array([0.7924467, 0.7924467], dtype=float32), 'metallicity': Array([0.5050313, 0.5050313], dtype=float32)}
2025-11-10 17:12:20,375 - rubix - INFO - Convolving with PSF...
2025-11-10 17:12:20,378 - rubix - INFO - Convolving with LSF...
2025-11-10 17:12:20,383 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal
grads: {'age': Array([5.885172, 5.885172], dtype=float32), 'metallicity': Array([0.27147812, 0.27147812], dtype=float32)}
loss: -2.057193
#NBVAL_SKIP
# Calculate the finite-difference gradient for comparison with the JAX one.
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# 1) Scalar loss as a function of the whole parameter PyTree.
f = lambda p: loss_only_wrt_age_metallicity(p['age'], p['metallicity'], data, target_value)

# 2) Central finite-difference gradient for an arbitrary PyTree.
def finite_diff_grad(f, params, eps=1e-5):
    """Approximate the gradient of ``f`` at ``params`` by central differences.

    All leaves of ``params`` are flattened into one vector ``x``. For every
    unit direction ``e_i`` the quotient
    ``(f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)`` estimates the i-th
    partial derivative. All directions are evaluated in one ``jax.vmap``.

    Args:
        f: Scalar-valued function of a PyTree shaped like ``params``.
        params: PyTree of arrays at which the gradient is evaluated.
        eps: Step size. Very small steps amplify floating-point round-off,
            notably in float32.

    Returns:
        A PyTree with the same structure as ``params`` holding the estimates.
    """
    x0, rebuild = ravel_pytree(params)

    def central_difference(direction):
        step = eps * direction
        upper = f(rebuild(x0 + step))
        lower = f(rebuild(x0 - step))
        return (upper - lower) / (2 * eps)

    # Row i of the identity matrix is the unit vector e_i.
    unit_directions = jnp.eye(x0.size, dtype=x0.dtype)
    return rebuild(jax.vmap(central_difference)(unit_directions))

# 3) Apply: compute the finite-difference gradient so it can be compared
#    with the JAX gradient from the previous cell.
# NOTE(review): eps=1e-2 is much larger than the function default (1e-5),
# presumably because the parameters are float32 and smaller steps suffer
# from round-off. The eps scan in the next cell explores this dependence.
grads_fd = finite_diff_grad(f, params_init, eps=1e-2)
print("grads_fd:", grads_fd)
grads_fd: {'age': Array([0.35352707, 0.35352707], dtype=float32), 'metallicity': Array([0.25906563, 0.25891066], dtype=float32)}
# NBVAL_SKIP
import matplotlib.pyplot as plt

# Step sizes to scan: 20 values logarithmically spaced from 1e-6 to 1e-1.
eps_values = jnp.logspace(-6, -1, 20)

age_fd_values = []
metal_fd_values = []

for eps in eps_values:
    g_fd = finite_diff_grad(f, params_init, eps=float(eps))
    # g_fd has the same structure as params_init:
    # {'age': array([.., ..]), 'metallicity': array([.., ..])}.
    # Reduce each array to one number by averaging over the star particles.
    age_fd_values.append(float(jnp.mean(g_fd['age'])))
    metal_fd_values.append(float(jnp.mean(g_fd['metallicity'])))

plt.figure(figsize=(7, 5))
plt.semilogx(eps_values, age_fd_values, 'o-', label="age grad (FD)")
plt.semilogx(eps_values, metal_fd_values, 's-', label="metallicity grad (FD)")

# Horizontal lines: the JAX (autodiff) gradient, reduced with the same mean
# over star particles as the FD curves so both show the same quantity.
plt.axhline(float(jnp.mean(grads['age'])), color='C0', linestyle='--', label="age grad (JAX)")
plt.axhline(float(jnp.mean(grads['metallicity'])), color='C1', linestyle='--', label="metallicity grad (JAX)")

plt.xlabel("Step size")
plt.ylabel("Derivative")
plt.legend()
plt.grid(True)
plt.savefig("output/optimisation_finite_diff.jpg", dpi=1000)
plt.show()
../_images/32129ca04e13f0a62b428f5b98dfec5949ba22f28778fc6a987c2699f079e2a2.png