Gradient vs finite difference#
# NBVAL_SKIP
# Notebook setup: import JAX and print the devices it reports.
# NOTE(review): the name `config` imported here is rebound to a plain dict in
# the "Configure pipeline" cell further down.
from jax import config
import os
import jax
print(jax.devices())
# NBVAL_SKIP
import os
# Set SPS_HOME before the FSPS template is loaded below.
# NOTE(review): hard-coded, user-specific path -- presumably the local FSPS
# installation directory; adjust per machine.
os.environ['SPS_HOME'] = '/home/annalena/sps_fsps'
Load ssp template from FSPS#
# NBVAL_SKIP
from rubix.spectra.ssp.factory import get_ssp_template

# Fetch the SSP template registered under the name "FSPS".
ssp_fsps = get_ssp_template("FSPS")

# NBVAL_SKIP
# Age and metallicity grids of the template; shapes are printed as a sanity check.
age_values = ssp_fsps.age
print(age_values.shape)
metallicity_values = ssp_fsps.metallicity
print(metallicity_values.shape)

# NBVAL_SKIP
# Grid indices of the target point (index_*) and of the starting point
# (initial_*_index). A previously tried start was metallicity index 5, age index 70.
index_age, index_metallicity = 90, 9
initial_metallicity_index, initial_age_index = 10, 104
# Not referenced in the visible part of this notebook -- presumably optimizer
# settings used elsewhere; confirm before removing.
learning_all = 1e-2
tol = 1e-10
print(f"start age: {age_values[initial_age_index]}, start metallicity: {metallicity_values[initial_metallicity_index]}")
print(f"target age: {age_values[index_age]}, target metallicity: {metallicity_values[index_metallicity]}")
Configure pipeline#
# NBVAL_SKIP
from rubix.core.pipeline import RubixPipeline
import os
# Configuration dictionary handed to RubixPipeline.
# NOTE(review): this assignment rebinds `config`, which was imported from jax
# at the top of the notebook.
config = {
    "pipeline": {"name": "calc_gradient"},
    "logger": {
        "log_level": "DEBUG",
        "log_file_path": None,
        "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    },
    # Input data: star particles of galaxy 14, restricted to a 2-particle subset.
    "data": {
        "name": "IllustrisAPI",
        "args": {
            # Read from the environment; None when ILLUSTRIS_API_KEY is unset.
            "api_key": os.environ.get("ILLUSTRIS_API_KEY"),
            "particle_type": ["stars"],
            "simulation": "TNG50-1",
            "snapshot": 99,
            "save_data_path": "data",
        },
        "load_galaxy_args": {
            "id": 14,
            "reuse": True,
        },
        "subset": {
            "use_subset": True,
            "subset_size": 2,
        },
    },
    "simulation": {
        "name": "IllustrisTNG",
        "args": {
            "path": "data/galaxy-id-14.hdf5",
        },
    },
    "output_path": "output",
    "telescope": {
        "name": "TESTGRADIENT",
        "psf": {"name": "gaussian", "size": 5, "sigma": 0.6},
        "lsf": {"sigma": 1.2},
        "noise": {"signal_to_noise": 100, "noise_distribution": "normal"},
    },
    "cosmology": {"name": "PLANCK15"},
    "galaxy": {
        "dist_z": 0.1,
        "rotation": {"type": "edge-on"},
    },
    "ssp": {
        "template": {"name": "FSPS"},
        "dust": {
            "extinction_model": "Cardelli89",
            "dust_to_gas_ratio": 0.01,
            "dust_to_metals_ratio": 0.4,
            "dust_grain_density": 3.5,
            "Rv": 3.1,
        },
    },
}
# NBVAL_SKIP
# Build the pipeline from `config`, prepare its input data and run it once.
pipe = RubixPipeline(config)
inputdata = pipe.prepare_data()
# `output` is not referenced again in the visible part of this notebook.
output = pipe.run_sharded(inputdata)
Set target values#
# NBVAL_SKIP
import jax.numpy as jnp

# Overwrite both star particles with the target age/metallicity, unit mass,
# zero velocity and zero coordinates.
target_age = age_values[index_age]
target_metallicity = metallicity_values[index_metallicity]
inputdata.stars.age = jnp.array([target_age, target_age])
inputdata.stars.metallicity = jnp.array([target_metallicity, target_metallicity])
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
])
inputdata.stars.coords = jnp.array([
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
])

# NBVAL_SKIP
# Pipeline result for the target parameters; used later as the loss target.
targetdata = pipe.run_sharded(inputdata)

# NBVAL_SKIP
# Shape of the last axis at index (0, 0) of the result.
print(targetdata[0, 0, :].shape)
Set initial datacube#
# NBVAL_SKIP
# Overwrite both star particles with the starting age/metallicity, unit mass,
# zero velocity and zero coordinates.
start_age = age_values[initial_age_index]
start_metallicity = metallicity_values[initial_metallicity_index]
inputdata.stars.age = jnp.array([start_age, start_age])
inputdata.stars.metallicity = jnp.array([start_metallicity, start_metallicity])
inputdata.stars.mass = jnp.array([[1.0], [1.0]])
inputdata.stars.velocity = jnp.array([
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
])
inputdata.stars.coords = jnp.array([
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
])

# NBVAL_SKIP
# Pipeline result for the starting parameters.
initialdata = pipe.run_sharded(inputdata)
Adam optimizer#
# NBVAL_SKIP
# Build a second RubixPipeline and assemble/compile its linear pipeline by hand,
# storing the compiled callable on `pipeline_instance.func`.
from rubix.pipeline import linear_pipeline as pipeline
pipeline_instance = RubixPipeline(config)
# NOTE(review): the lines below assign and call underscore-prefixed members of
# RubixPipeline; they may change between rubix versions -- confirm on upgrade.
pipeline_instance._pipeline = pipeline.LinearTransformerPipeline(
pipeline_instance.pipeline_config,
pipeline_instance._get_pipeline_functions()
)
pipeline_instance._pipeline.assemble()
# `func` is the callable that loss_only_wrt_age_metallicity applies to its input data.
pipeline_instance.func = pipeline_instance._pipeline.compile_expression()
# NBVAL_SKIP
import optax
def loss_only_wrt_age_metallicity(age, metallicity, base_data, target):
    """Return log10 of the summed cosine distance between model and target.

    Args:
        age: Scaled stellar ages; multiplied by 20 before being written to
            ``base_data.stars.age``.
        metallicity: Scaled stellar metallicities; multiplied by 0.05 before
            being written to ``base_data.stars.metallicity``.
        base_data: Pipeline input object. NOTE: it is modified in place, so
            the caller's object keeps the values written here.
        target: Array compared against ``output.stars.datacube``.

    Returns:
        ``jnp.log10`` of the sum of ``optax.cosine_distance`` between the
        pipeline's datacube and ``target``.
    """
    # Undo the parameter scaling and inject the values into the input data.
    base_data.stars.age = age * 20
    base_data.stars.metallicity = metallicity * 0.05

    predicted = pipeline_instance.func(base_data)

    # Earlier experiments used a plain squared error, optax.l2_loss and
    # optax.huber_loss here; cosine distance is the variant currently in use.
    distance = optax.cosine_distance(predicted.stars.datacube, target)
    return jnp.log10(jnp.sum(distance))
#NBVAL_SKIP
import jax
def compute_gradient(age, metallicity, base_data, target):
    """Evaluate the loss and its gradient w.r.t. age and metallicity.

    Args:
        age: Scaled ages, forwarded to ``loss_only_wrt_age_metallicity``.
        metallicity: Scaled metallicities, forwarded likewise.
        base_data: Pipeline input object, forwarded likewise.
        target: Target array, forwarded likewise.

    Returns:
        ``(grads, loss)`` where ``grads`` is the tuple
        ``(d loss / d age, d loss / d metallicity)`` and ``loss`` is the
        scalar loss value.
    """
    # jax.value_and_grad returns ONE callable; calling it yields (value, grads).
    # The previous code tried to unpack the callable itself, which raises
    # TypeError before anything is evaluated.
    value_and_grad_fn = jax.value_and_grad(
        loss_only_wrt_age_metallicity, argnums=(0, 1)
    )
    loss, grads = value_and_grad_fn(age, metallicity, base_data, target)
    # Keep the (grads, loss) return order of the original signature.
    return grads, loss
#NBVAL_SKIP
# Compute the loss and its gradient with JAX autodiff.
# The starting values are divided by 20 and 0.05 respectively, the same factors
# loss_only_wrt_age_metallicity multiplies back in.
start_age_scaled = age_values[initial_age_index] / 20
start_metallicity_scaled = metallicity_values[initial_metallicity_index] / 0.05
age_init = jnp.array([start_age_scaled, start_age_scaled])
metallicity_init = jnp.array([start_metallicity_scaled, start_metallicity_scaled])

# Both initial parameter arrays packed into one dictionary (a pytree).
params_init = {'age': age_init, 'metallicity': metallicity_init}
print(f"Initial parameters: {params_init}")

data = inputdata
target_value = targetdata


def _loss_of_params(p):
    """Loss as a function of the parameter dictionary only."""
    return loss_only_wrt_age_metallicity(p['age'], p['metallicity'], data, target_value)


loss, grads = jax.value_and_grad(_loss_of_params)(params_init)
print("grads:", grads)
print("loss:", loss)
#NBVAL_SKIP
#calculate finite difference
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
# 1) Scalar loss as a function of the whole parameter pytree.
def f(p):
    """Evaluate the loss for the parameter dictionary ``p``."""
    return loss_only_wrt_age_metallicity(p['age'], p['metallicity'], data, target_value)
# 2) Central finite-difference gradient for an arbitrary pytree.
def finite_diff_grad(f, params, eps=1e-5):
    """Approximate the gradient of ``f`` at ``params`` by central differences.

    Args:
        f: Callable mapping a pytree shaped like ``params`` to a scalar.
        params: Pytree of arrays at which the gradient is evaluated.
        eps: Step size applied to one flattened coordinate at a time.

    Returns:
        A pytree with the same structure as ``params`` holding
        ``(f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)`` for every
        flattened coordinate ``i``.
    """
    theta, rebuild = ravel_pytree(params)

    def central_difference(direction):
        # `direction` is one unit vector of the flattened parameter space.
        forward = f(rebuild(theta + eps * direction))
        backward = f(rebuild(theta - eps * direction))
        return (forward - backward) / (2 * eps)

    # Rows of the identity matrix are the unit vectors; vmap evaluates all
    # coordinates in one batched call.
    basis = jnp.eye(theta.size, dtype=theta.dtype)
    return rebuild(jax.vmap(central_difference)(basis))
# 3) Apply: compute the finite-difference gradient with step size 1e-2 and
# print it so it can be compared with the JAX gradient.
grads_fd = finite_diff_grad(f, params_init, eps=1e-2)
print("grads_fd:", grads_fd)
# NBVAL_SKIP
import matplotlib.pyplot as plt
# Step sizes to scan: 20 log-spaced values from 1e-6 to 1e-1.
eps_values = jnp.logspace(-6, -1, 20)

# Each finite-difference gradient has the same structure as params_init
# ({'age': array, 'metallicity': array}); every array is reduced to its mean.
fd_grads = [finite_diff_grad(f, params_init, eps=float(step)) for step in eps_values]
age_fd_values = [float(jnp.mean(g['age'])) for g in fd_grads]
metal_fd_values = [float(jnp.mean(g['metallicity'])) for g in fd_grads]

fig, ax = plt.subplots(figsize=(7,5))
ax.semilogx(eps_values, age_fd_values, 'o-', label="age grad (FD)")
ax.semilogx(eps_values, metal_fd_values, 's-', label="metallicity grad (FD)")

# Horizontal dashed lines mark the JAX gradient (first element of each array).
ax.axhline(float(grads['age'][0]), color='C0', linestyle='--', label="age grad (JAX)")
ax.axhline(float(grads['metallicity'][0]), color='C1', linestyle='--', label="metalicity grad (JAX)")

ax.set_xlabel("Step size")
ax.set_ylabel("Derivation")
# A title "Gradient vs finite difference step size" was tried and left out.
ax.legend()
ax.grid(True)
fig.savefig("output/optimisation_finite_diff.jpg", dpi=1000)
plt.show()