{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient vs finite difference" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[CpuDevice(id=0)]\n" ] } ], "source": [ "# NBVAL_SKIP\n", "from jax import config\n", "import os\n", "import jax\n", "\n", "print(jax.devices())" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "import os\n", "os.environ['SPS_HOME'] = '/home/annalena/sps_fsps'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load ssp template from FSPS" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-11-10 17:11:57,608 - rubix - INFO - \n", " ___ __ _____ _____ __\n", " / _ \\/ / / / _ )/ _/ |/_/\n", " / , _/ /_/ / _ |/ /_> <\n", "/_/|_|\\____/____/___/_/|_|\n", "\n", "\n", "2025-11-10 17:11:57,608 - rubix - INFO - Rubix version: 0.0.post626+g42b4b7505.d20251110\n", "2025-11-10 17:11:57,609 - rubix - INFO - JAX version: 0.7.2\n", "2025-11-10 17:11:57,609 - rubix - INFO - Running on [CpuDevice(id=0)] devices\n", "2025-11-10 17:11:57,609 - rubix - WARNING - python-fsps is not installed. Please install it to use this function. Install using pip install fsps and check the installation page: https://dfm.io/python-fsps/current/installation/ for more details. 
Especially, make sure to set all necessary environment variables.\n" ] } ], "source": [ "# NBVAL_SKIP\n", "from rubix.spectra.ssp.factory import get_ssp_template\n", "ssp_fsps = get_ssp_template(\"FSPS\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(107,)\n", "(12,)\n" ] } ], "source": [ "# NBVAL_SKIP\n", "age_values = ssp_fsps.age\n", "print(age_values.shape)\n", "\n", "metallicity_values = ssp_fsps.metallicity\n", "print(metallicity_values.shape)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "start age: 15.848933219909668, start metallicity: 0.025251565501093864\n", "target age: 3.1622776985168457, target metallicity: 0.014199999161064625\n" ] } ], "source": [ "# NBVAL_SKIP\n", "index_age = 90\n", "index_metallicity = 9\n", "\n", "#initial_metallicity_index = 5\n", "#initial_age_index = 70\n", "initial_metallicity_index = 10\n", "initial_age_index = 104\n", "\n", "learning_all = 1e-2\n", "tol = 1e-10\n", "\n", "print(f\"start age: {age_values[initial_age_index]}, start metallicity: {metallicity_values[initial_metallicity_index]}\")\n", "print(f\"target age: {age_values[index_age]}, target metallicity: {metallicity_values[index_metallicity]}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Configure pipeline" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "from rubix.core.pipeline import RubixPipeline\n", "import os\n", "config = {\n", " \"pipeline\":{\"name\": \"calc_gradient\",},\n", " \n", " \"logger\": {\n", " \"log_level\": \"DEBUG\",\n", " \"log_file_path\": None,\n", " \"format\": \"%(asctime)s - %(name)s - %(levelname)s - %(message)s\",\n", " },\n", " \"data\": {\n", " \"name\": \"IllustrisAPI\",\n", " \"args\": {\n", " \"api_key\": os.environ.get(\"ILLUSTRIS_API_KEY\"),\n", " \"particle_type\": 
[\"stars\"],\n", " \"simulation\": \"TNG50-1\",\n", " \"snapshot\": 99,\n", " \"save_data_path\": \"data\",\n", " },\n", " \n", " \"load_galaxy_args\": {\n", " \"id\": 14,\n", " \"reuse\": True,\n", " },\n", " \n", " \"subset\": {\n", " \"use_subset\": True,\n", " \"subset_size\": 2,\n", " },\n", " },\n", " \"simulation\": {\n", " \"name\": \"IllustrisTNG\",\n", " \"args\": {\n", " \"path\": \"data/galaxy-id-14.hdf5\",\n", " },\n", " \n", " },\n", " \"output_path\": \"output\",\n", "\n", " \"telescope\":\n", " {\"name\": \"TESTGRADIENT\",\n", " \"psf\": {\"name\": \"gaussian\", \"size\": 5, \"sigma\": 0.6},\n", " \"lsf\": {\"sigma\": 1.2},\n", " \"noise\": {\"signal_to_noise\": 100,\"noise_distribution\": \"normal\"},\n", " },\n", " \"cosmology\":\n", " {\"name\": \"PLANCK15\"},\n", " \n", " \"galaxy\":\n", " {\"dist_z\": 0.1,\n", " \"rotation\": {\"type\": \"edge-on\"},\n", " },\n", " \n", " \"ssp\": {\n", " \"template\": {\n", " \"name\": \"FSPS\"\n", " },\n", " \"dust\": {\n", " \"extinction_model\": \"Cardelli89\",\n", " \"dust_to_gas_ratio\": 0.01,\n", " \"dust_to_metals_ratio\": 0.4,\n", " \"dust_grain_density\": 3.5,\n", " \"Rv\": 3.1,\n", " },\n", " }, \n", "}" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-11-10 17:11:58,256 - rubix - INFO - Getting rubix data...\n", "2025-11-10 17:11:58,257 - rubix - INFO - Rubix galaxy file already exists, skipping conversion\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/jax/_src/numpy/scalar_types.py:50: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. 
To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.\n", " return asarray(x, dtype=self.dtype)\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/core/data.py:491: UserWarning: Explicitly requested dtype requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.\n", " rubixdata.galaxy.center = jnp.array(data[\"subhalo_center\"], dtype=jnp.float64)\n", "2025-11-10 17:11:58,318 - rubix - INFO - Centering stars particles\n", "2025-11-10 17:11:59,305 - rubix - WARNING - The Subset value is set in config. Using only subset of size 2 for stars\n", "2025-11-10 17:11:59,305 - rubix - INFO - Data loaded with 2 star particles and 0 gas particles.\n", "2025-11-10 17:11:59,306 - rubix - INFO - Data preparation completed in 1.05 seconds.\n", "2025-11-10 17:11:59,306 - rubix - INFO - Setting up the pipeline...\n", "2025-11-10 17:11:59,307 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}\n", "2025-11-10 17:11:59,307 - rubix - DEBUG - Rotation Type found: edge-on\n", "2025-11-10 17:11:59,310 - 
rubix - INFO - Calculating spatial bin edges...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-11-10 17:11:59,337 - rubix - INFO - Getting cosmology...\n", "2025-11-10 17:11:59,547 - rubix - INFO - Calculating spatial bin edges...\n", "2025-11-10 17:11:59,556 - rubix - INFO - Getting cosmology...\n", "2025-11-10 17:11:59,567 - rubix - INFO - Getting cosmology...\n", "2025-11-10 17:11:59,652 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-11-10 17:11:59,807 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-11-10 17:12:00,094 - rubix - INFO - Assembling the pipeline...\n", "2025-11-10 17:12:00,095 - rubix - INFO - Compiling the expressions...\n", "2025-11-10 17:12:00,096 - rubix - INFO - Number of devices: 1\n", "2025-11-10 17:12:00,180 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0\n", "2025-11-10 17:12:00,181 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG\n", "2025-11-10 17:12:00,181 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.\n", "2025-11-10 17:12:00,286 - rubix - INFO - Assigning particles to spaxels...\n", "2025-11-10 17:12:00,318 - rubix - 
INFO - Calculating Data Cube (combined per‐particle)…\n", "2025-11-10 17:12:00,481 - rubix - DEBUG - Datacube shape: (1, 1, 466)\n", "2025-11-10 17:12:00,481 - rubix - INFO - Convolving with PSF...\n", "2025-11-10 17:12:00,486 - rubix - INFO - Convolving with LSF...\n", "2025-11-10 17:12:00,493 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal\n", "2025-11-10 17:12:06,414 - rubix - INFO - Total time for sharded pipeline run: 7.11 seconds.\n" ] } ], "source": [ "# NBVAL_SKIP\n", "pipe = RubixPipeline(config)\n", "inputdata = pipe.prepare_data()\n", "output = pipe.run_sharded(inputdata)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Set target values" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "import jax.numpy as jnp\n", "\n", "inputdata.stars.age = jnp.array([age_values[index_age], age_values[index_age]])\n", "inputdata.stars.metallicity = jnp.array([metallicity_values[index_metallicity], metallicity_values[index_metallicity]])\n", "inputdata.stars.mass = jnp.array([[1.0], [1.0]])\n", "inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])\n", "inputdata.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-11-10 17:12:06,474 - rubix - INFO - Setting up the pipeline...\n", "2025-11-10 17:12:06,475 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 
'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}\n", "2025-11-10 17:12:06,476 - rubix - DEBUG - Rotation Type found: edge-on\n", "2025-11-10 17:12:06,479 - rubix - INFO - Calculating spatial bin edges...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-11-10 17:12:06,491 - rubix - INFO - Getting cosmology...\n", "2025-11-10 17:12:06,502 - rubix - INFO - Calculating spatial bin edges...\n", "2025-11-10 17:12:06,624 - rubix - INFO - Getting cosmology...\n", "2025-11-10 17:12:06,635 - rubix - INFO - Getting cosmology...\n", "2025-11-10 17:12:06,681 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-11-10 17:12:06,757 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-11-10 17:12:06,804 - rubix - INFO - Assembling the pipeline...\n", "2025-11-10 17:12:06,804 - rubix - INFO - Compiling the expressions...\n", "2025-11-10 17:12:06,806 - rubix - INFO - Number of devices: 1\n", "2025-11-10 17:12:06,888 - rubix - INFO - Rotating galaxy 
with alpha=90.0, beta=0.0, gamma=0.0\n", "2025-11-10 17:12:06,889 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG\n", "2025-11-10 17:12:06,889 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.\n", "2025-11-10 17:12:06,961 - rubix - INFO - Assigning particles to spaxels...\n", "2025-11-10 17:12:06,963 - rubix - INFO - Calculating Data Cube (combined per‐particle)…\n", "2025-11-10 17:12:06,971 - rubix - DEBUG - Datacube shape: (1, 1, 466)\n", "2025-11-10 17:12:06,972 - rubix - INFO - Convolving with PSF...\n", "2025-11-10 17:12:06,974 - rubix - INFO - Convolving with LSF...\n", "2025-11-10 17:12:06,977 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal\n", "2025-11-10 17:12:12,833 - rubix - INFO - Total time for sharded pipeline run: 6.36 seconds.\n" ] } ], "source": [ "# NBVAL_SKIP\n", "targetdata = pipe.run_sharded(inputdata)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(466,)\n" ] } ], "source": [ "# NBVAL_SKIP\n", "print(targetdata[0,0,:].shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Set initial datacube" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "inputdata.stars.age = jnp.array([age_values[initial_age_index], age_values[initial_age_index]])\n", "inputdata.stars.metallicity = jnp.array([metallicity_values[initial_metallicity_index], metallicity_values[initial_metallicity_index]])\n", "inputdata.stars.mass = jnp.array([[1.0], [1.0]])\n", "inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])\n", "inputdata.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-11-10 17:12:12,913 - rubix - INFO - Setting up the 
pipeline...\n", "2025-11-10 17:12:12,914 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}\n", "2025-11-10 17:12:12,914 - rubix - DEBUG - Rotation Type found: edge-on\n", "2025-11-10 17:12:12,916 - rubix - INFO - Calculating spatial bin edges...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-11-10 17:12:12,929 - rubix - INFO - Getting cosmology...\n", "2025-11-10 17:12:12,941 - rubix - INFO - Calculating spatial bin edges...\n", "2025-11-10 17:12:12,951 - rubix - INFO - Getting cosmology...\n", "2025-11-10 17:12:12,961 - rubix - INFO - Getting cosmology...\n", "2025-11-10 17:12:13,008 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-11-10 17:12:13,099 - rubix - DEBUG - Method not defined, using default method: cubic\n", 
"/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-11-10 17:12:13,159 - rubix - INFO - Assembling the pipeline...\n", "2025-11-10 17:12:13,160 - rubix - INFO - Compiling the expressions...\n", "2025-11-10 17:12:13,164 - rubix - INFO - Number of devices: 1\n", "2025-11-10 17:12:13,243 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0\n", "2025-11-10 17:12:13,244 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG\n", "2025-11-10 17:12:13,244 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.\n", "2025-11-10 17:12:13,315 - rubix - INFO - Assigning particles to spaxels...\n", "2025-11-10 17:12:13,317 - rubix - INFO - Calculating Data Cube (combined per‐particle)…\n", "2025-11-10 17:12:13,328 - rubix - DEBUG - Datacube shape: (1, 1, 466)\n", "2025-11-10 17:12:13,328 - rubix - INFO - Convolving with PSF...\n", "2025-11-10 17:12:13,331 - rubix - INFO - Convolving with LSF...\n", "2025-11-10 17:12:13,334 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal\n", "2025-11-10 17:12:19,690 - rubix - INFO - Total time for sharded pipeline run: 6.78 seconds.\n" ] } ], "source": [ "# NBVAL_SKIP\n", "initialdata = pipe.run_sharded(inputdata)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Adam optimizer" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", 
"2025-11-10 17:12:19,786 - rubix - INFO - Setting up the pipeline...\n", "2025-11-10 17:12:19,787 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}\n", "2025-11-10 17:12:19,787 - rubix - DEBUG - Rotation Type found: edge-on\n", "2025-11-10 17:12:19,789 - rubix - INFO - Calculating spatial bin edges...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-11-10 17:12:19,798 - rubix - INFO - Getting cosmology...\n", "2025-11-10 17:12:19,808 - rubix - INFO - Calculating spatial bin edges...\n", "2025-11-10 17:12:19,818 - rubix - INFO - Getting cosmology...\n", "2025-11-10 17:12:19,828 - rubix - INFO - Getting cosmology...\n", "2025-11-10 17:12:19,871 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-11-10 17:12:19,938 - rubix - DEBUG - Method not defined, using default method: cubic\n", 
"/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n" ] } ], "source": [ "# NBVAL_SKIP\n", "from rubix.pipeline import linear_pipeline as pipeline\n", "\n", "pipeline_instance = RubixPipeline(config)\n", "\n", "pipeline_instance._pipeline = pipeline.LinearTransformerPipeline(\n", " pipeline_instance.pipeline_config, \n", " pipeline_instance._get_pipeline_functions()\n", ")\n", "pipeline_instance._pipeline.assemble()\n", "pipeline_instance.func = pipeline_instance._pipeline.compile_expression()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "import optax\n", "\n", "def loss_only_wrt_age_metallicity(age, metallicity, base_data, target):\n", " \n", " base_data.stars.age = age*20\n", " base_data.stars.metallicity = metallicity*0.05\n", "\n", " output = pipeline_instance.func(base_data)\n", " #loss = jnp.sum((output.stars.datacube - target) ** 2)\n", " #loss = jnp.sum(optax.l2_loss(output.stars.datacube, target.stars.datacube))\n", " #loss = jnp.sum(optax.huber_loss(output.stars.datacube, target.stars.datacube))\n", " loss = jnp.sum(optax.cosine_distance(output.stars.datacube, target))\n", " \n", " return jnp.log10(loss) #loss#/0.03 #jnp.log10(loss #/5e-5)\n", "\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "#NBVAL_SKIP\n", "import jax\n", "\n", "def compute_gradient(age, metallicity, base_data, target):\n", " loss, grad_fn = jax.value_and_grad(loss_only_wrt_age_metallicity, argnums=(0,1))\n", " grads = grad_fn(age, metallicity, base_data, target)\n", " return grads, loss" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-11-10 17:12:20,133 - 
rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0\n", "2025-11-10 17:12:20,134 - rubix - INFO - Rotating galaxy for simulation: IllustrisTNG\n", "2025-11-10 17:12:20,134 - rubix - WARNING - Gas not found in particle_type, only rotating stellar component.\n", "2025-11-10 17:12:20,223 - rubix - INFO - Assigning particles to spaxels...\n", "2025-11-10 17:12:20,239 - rubix - INFO - Calculating Data Cube (combined per‐particle)…\n", "2025-11-10 17:12:20,375 - rubix - DEBUG - Datacube shape: (1, 1, 466)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Initial parameters: {'age': Array([0.7924467, 0.7924467], dtype=float32), 'metallicity': Array([0.5050313, 0.5050313], dtype=float32)}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-11-10 17:12:20,375 - rubix - INFO - Convolving with PSF...\n", "2025-11-10 17:12:20,378 - rubix - INFO - Convolving with LSF...\n", "2025-11-10 17:12:20,383 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "grads: {'age': Array([5.885172, 5.885172], dtype=float32), 'metallicity': Array([0.27147812, 0.27147812], dtype=float32)}\n", "loss: -2.057193\n" ] } ], "source": [ "#NBVAL_SKIP\n", "#calculate gradient with jax\n", "age_init = jnp.array([age_values[initial_age_index]/20, age_values[initial_age_index]/20])\n", "metallicity_init = jnp.array([metallicity_values[initial_metallicity_index]/0.05, metallicity_values[initial_metallicity_index]/0.05])\n", "\n", "\n", "# Pack both initial parameters into a dictionary.\n", "params_init = {'age': age_init, 'metallicity': metallicity_init}\n", "print(f\"Initial parameters: {params_init}\")\n", "\n", "data = inputdata\n", "target_value = targetdata\n", "\n", "loss, grads = jax.value_and_grad(lambda p: loss_only_wrt_age_metallicity(p['age'], p['metallicity'], data, target_value))(params_init)\n", "\n", "print(\"grads:\", grads)\n", 
"print(\"loss:\", loss)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "grads_fd: {'age': Array([0.35352707, 0.35352707], dtype=float32), 'metallicity': Array([0.25906563, 0.25891066], dtype=float32)}\n" ] } ], "source": [ "#NBVAL_SKIP\n", "#calculate finite differnce\n", "import jax\n", "import jax.numpy as jnp\n", "from jax.flatten_util import ravel_pytree\n", "\n", "# 1) Skalares Loss über das ganze Param-PyTree\n", "f = lambda p: loss_only_wrt_age_metallicity(p['age'], p['metallicity'], data, target_value)\n", "\n", "# 2) Finite-Difference-Gradient (zentral) für beliebiges PyTree\n", "def finite_diff_grad(f, params, eps=1e-5):\n", " flat, unravel = ravel_pytree(params)\n", " def f_flat(x): return f(unravel(x))\n", "\n", " def fd_i(i):\n", " e_i = jnp.zeros_like(flat).at[i].set(1.0)\n", " return (f_flat(flat + eps*e_i) - f_flat(flat - eps*e_i)) / (2*eps)\n", "\n", " g_flat = jax.vmap(fd_i)(jnp.arange(flat.size))\n", " return unravel(g_flat)\n", "\n", "# 3) Anwenden: JAX-Grad + FD-Grad berechnen und vergleichen\n", "grads_fd = finite_diff_grad(f, params_init, eps=1e-2)\n", "print(\"grads_fd:\", grads_fd)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# NBVAL_SKIP\n", "import matplotlib.pyplot as plt\n", "\n", "# eps-Werte, über die wir scannen\n", "eps_values = jnp.logspace(-6, -1, 20) # von 1e-6 bis 1e-1\n", "\n", "age_fd_values = []\n", "metal_fd_values = []\n", "\n", "for eps in eps_values:\n", " g_fd = finite_diff_grad(f, params_init, eps=float(eps))\n", " # g_fd hat die gleiche Struktur wie params_init:\n", " # {'age': array([..,..]), 'metallicity': array([..,..])}\n", " # Beispiel: nimm hier den Mittelwert pro Array\n", " age_fd_values.append(float(jnp.mean(g_fd['age'])))\n", " metal_fd_values.append(float(jnp.mean(g_fd['metallicity'])))\n", "\n", "plt.figure(figsize=(7,5))\n", "plt.semilogx(eps_values, age_fd_values, 'o-', label=\"age grad (FD)\")\n", "plt.semilogx(eps_values, metal_fd_values, 's-', label=\"metallicity grad (FD)\")\n", "\n", "# horizontale Linien = JAX-Gradient\n", "plt.axhline(float(grads['age'][0]), color='C0', linestyle='--', label=\"age grad (JAX)\")\n", "plt.axhline(float(grads['metallicity'][0]), color='C1', linestyle='--', label=\"metalicity grad (JAX)\")\n", "\n", "plt.xlabel(\"Step size\")\n", "plt.ylabel(\"Derivation\")\n", "# plt.title(\"Gradient vs finite difference step size\")\n", "plt.legend()\n", "plt.grid(True)\n", "plt.savefig(\"output/optimisation_finite_diff.jpg\", dpi=1000)\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "rubix", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.0" } }, "nbformat": 4, "nbformat_minor": 2 }