{ "cells": [ { "cell_type": "markdown", "id": "3350bd62", "metadata": {}, "source": [ "# Gradient through RUBIX\n", "\n", "Recover the age and metallicity of a two-particle stellar population by running the Adam optimizer through the differentiable RUBIX pipeline: a target datacube is simulated from one point of the FSPS grid, and the fit starts from a different grid point." ] },
{ "cell_type": "code", "execution_count": null, "id": "b108191d", "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "# Configure JAX before it runs any computation: expose two virtual CPU devices\n", "# for the sharded pipeline runs below. Use `jax.config` rather than\n", "# `from jax import config`, because the name `config` is reused later for the\n", "# RUBIX pipeline configuration dictionary.\n", "import jax\n", "\n", "# Uncomment for double precision:\n", "# jax.config.update(\"jax_enable_x64\", True)\n", "jax.config.update('jax_num_cpu_devices', 2)" ] },
{ "cell_type": "code", "execution_count": 2, "id": "7a728a04", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[CpuDevice(id=0), CpuDevice(id=1)]\n" ] } ], "source": [ "#NBVAL_SKIP\n", "import jax\n", "\n", "# Sanity check: JAX should now list two CpuDevice entries,\n", "# i.e. [CpuDevice(id=0), CpuDevice(id=1)].\n", "print(jax.devices())" ] },
{ "cell_type": "code", "execution_count": 3, "id": "839c9ebd", "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "import os\n", "\n", "# The FSPS SSP template loaded below needs SPS_HOME (the python-fsps data\n", "# directory). Keep a value that is already set in the environment and only\n", "# fall back to the author's local path otherwise; adjust it for your machine.\n", "if 'SPS_HOME' not in os.environ:\n", "    os.environ['SPS_HOME'] = '/home/annalena_data/sps_fsps'" ] },
{ "cell_type": "markdown", "id": "dcb32323", "metadata": {}, "source": [ "# Load ssp template from FSPS" ] },
{ "cell_type": "code", "execution_count": 4, "id": "ef5f7772", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-07-01 14:19:12,434 - rubix - INFO - \n", " ___ __ _____ _____ __\n", " / _ \\/ / / / _ )/ _/ |/_/\n", " / , _/ /_/ / _ |/ /_> <\n", "/_/|_|\\____/____/___/_/|_|\n", "\n", "\n", "2025-07-01 14:19:12,435 - rubix - INFO - Rubix version: 0.0.post465+g01a25a7.d20250701\n", "2025-07-01 14:19:12,435 - rubix - INFO - JAX version: 0.6.0\n", "2025-07-01 14:19:12,436 - rubix - INFO - Running on [CpuDevice(id=0), CpuDevice(id=1)] devices\n" ] } ], "source": [ "# NBVAL_SKIP\n", "from rubix.spectra.ssp.factory import get_ssp_template\n", "ssp_fsps = get_ssp_template(\"FSPS\")" ] },
{ "cell_type": "code", "execution_count": 5, "id": "eb0688b5", "metadata": {}, "outputs": [ { "name": 
"stdout", "output_type": "stream", "text": [ "(107,)\n", "(12,)\n" ] } ], "source": [ "# NBVAL_SKIP\n", "age_values = ssp_fsps.age\n", "print(age_values.shape)\n", "\n", "metallicity_values = ssp_fsps.metallicity\n", "print(metallicity_values.shape)" ] }, { "cell_type": "code", "execution_count": 6, "id": "9fbf0de6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "start age: 15.848933219909668, start metallicity: 0.025251565501093864\n", "target age: 3.1622776985168457, target metallicity: 0.014199999161064625\n" ] } ], "source": [ "# NBVAL_SKIP\n", "index_age = 90\n", "index_metallicity = 9\n", "\n", "#initial_metallicity_index = 5\n", "#initial_age_index = 70\n", "initial_metallicity_index = 10\n", "initial_age_index = 104\n", "\n", "learning_all = 5e-2\n", "tol = 1e-10\n", "\n", "print(f\"start age: {age_values[initial_age_index]}, start metallicity: {metallicity_values[initial_metallicity_index]}\")\n", "print(f\"target age: {age_values[index_age]}, target metallicity: {metallicity_values[index_metallicity]}\")" ] }, { "cell_type": "markdown", "id": "f9b479b9", "metadata": {}, "source": [ "# Configure pipeline" ] }, { "cell_type": "code", "execution_count": 7, "id": "b01e145b", "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "from rubix.core.pipeline import RubixPipeline\n", "import os\n", "config = {\n", " \"pipeline\":{\"name\": \"calc_gradient\",},\n", " \n", " \"logger\": {\n", " \"log_level\": \"DEBUG\",\n", " \"log_file_path\": None,\n", " \"format\": \"%(asctime)s - %(name)s - %(levelname)s - %(message)s\",\n", " },\n", " \"data\": {\n", " \"name\": \"IllustrisAPI\",\n", " \"args\": {\n", " \"api_key\": os.environ.get(\"ILLUSTRIS_API_KEY\"),\n", " \"particle_type\": [\"stars\"],\n", " \"simulation\": \"TNG50-1\",\n", " \"snapshot\": 99,\n", " \"save_data_path\": \"data\",\n", " },\n", " \n", " \"load_galaxy_args\": {\n", " \"id\": 14,\n", " \"reuse\": True,\n", " },\n", " \n", " \"subset\": {\n", " 
\"use_subset\": True,\n", " \"subset_size\": 2,\n", " },\n", " },\n", " \"simulation\": {\n", " \"name\": \"IllustrisTNG\",\n", " \"args\": {\n", " \"path\": \"data/galaxy-id-14.hdf5\",\n", " },\n", " \n", " },\n", " \"output_path\": \"output\",\n", "\n", " \"telescope\":\n", " {\"name\": \"TESTGRADIENT\",\n", " \"psf\": {\"name\": \"gaussian\", \"size\": 5, \"sigma\": 0.6},\n", " \"lsf\": {\"sigma\": 1.2},\n", " \"noise\": {\"signal_to_noise\": 100,\"noise_distribution\": \"normal\"},\n", " },\n", " \"cosmology\":\n", " {\"name\": \"PLANCK15\"},\n", " \n", " \"galaxy\":\n", " {\"dist_z\": 0.1,\n", " \"rotation\": {\"type\": \"edge-on\"},\n", " },\n", " \n", " \"ssp\": {\n", " \"template\": {\n", " \"name\": \"FSPS\"\n", " },\n", " \"dust\": {\n", " \"extinction_model\": \"Cardelli89\",\n", " \"dust_to_gas_ratio\": 0.01,\n", " \"dust_to_metals_ratio\": 0.4,\n", " \"dust_grain_density\": 3.5,\n", " \"Rv\": 3.1,\n", " },\n", " }, \n", "}" ] }, { "cell_type": "code", "execution_count": 8, "id": "d084be52", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:13,008 - rubix - INFO - Getting rubix data...\n", "2025-07-01 14:19:13,009 - rubix - INFO - Rubix galaxy file already exists, skipping conversion\n", "2025-07-01 14:19:13,045 - rubix - INFO - Centering stars particles\n", "2025-07-01 14:19:13,713 - rubix - WARNING - The Subset value is set in config. 
Using only subset of size 2 for stars\n", "2025-07-01 14:19:13,714 - rubix - INFO - Data loaded with 2 star particles and 0 gas particles.\n", "2025-07-01 14:19:13,715 - rubix - INFO - Setting up the pipeline...\n", "2025-07-01 14:19:13,715 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}\n", "2025-07-01 14:19:13,715 - rubix - DEBUG - Roataion Type found: edge-on\n", "2025-07-01 14:19:13,718 - rubix - INFO - Calculating spatial bin edges...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:13,741 - rubix - INFO - Getting cosmology...\n", "2025-07-01 14:19:13,902 - rubix - INFO - Calculating spatial bin edges...\n", "2025-07-01 14:19:13,911 - rubix - INFO - Getting cosmology...\n", "2025-07-01 14:19:13,951 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 
14:19:14,038 - rubix - DEBUG - SSP Wave: (5994,)\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:14,057 - rubix - INFO - Getting cosmology...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:14,145 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:14,403 - rubix - INFO - Assembling the pipeline...\n", "2025-07-01 14:19:14,403 - rubix - INFO - Compiling the expressions...\n", "2025-07-01 14:19:14,404 - rubix - INFO - Number of devices: 2\n", "2025-07-01 14:19:14,508 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0\n", "2025-07-01 14:19:14,613 - rubix - INFO - Assigning particles to spaxels...\n", "2025-07-01 14:19:14,629 - rubix - INFO - Calculating Data Cube (combined per‐particle)…\n", "2025-07-01 14:19:14,772 - rubix - DEBUG - Datacube shape: (1, 1, 466)\n", "2025-07-01 14:19:14,772 - rubix - INFO - Convolving with PSF...\n", "2025-07-01 14:19:14,776 - rubix - INFO - Convolving with LSF...\n", "2025-07-01 14:19:14,781 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal\n", "2025-07-01 14:19:22,678 - rubix - INFO - Pipeline run completed in 8.96 seconds.\n" ] } ], "source": [ "# 
NBVAL_SKIP\n", "pipe = RubixPipeline(config)\n", "inputdata = pipe.prepare_data()\n", "output = pipe.run_sharded(inputdata)" ] }, { "cell_type": "markdown", "id": "cfb4c047", "metadata": {}, "source": [ "# Set target values" ] }, { "cell_type": "code", "execution_count": 9, "id": "ea23c03b", "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "import jax.numpy as jnp\n", "\n", "inputdata.stars.age = jnp.array([age_values[index_age], age_values[index_age]])\n", "inputdata.stars.metallicity = jnp.array([metallicity_values[index_metallicity], metallicity_values[index_metallicity]])\n", "inputdata.stars.mass = jnp.array([[1.0], [1.0]])\n", "inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])\n", "inputdata.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])" ] }, { "cell_type": "code", "execution_count": 10, "id": "2e911a08", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-07-01 14:19:22,773 - rubix - INFO - Setting up the pipeline...\n", "2025-07-01 14:19:22,774 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}\n", "2025-07-01 14:19:22,775 - rubix - DEBUG - Roataion Type found: edge-on\n", "2025-07-01 14:19:22,776 - rubix - INFO - Calculating spatial bin edges...\n", 
"/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:22,787 - rubix - INFO - Getting cosmology...\n", "2025-07-01 14:19:22,797 - rubix - INFO - Calculating spatial bin edges...\n", "2025-07-01 14:19:22,807 - rubix - INFO - Getting cosmology...\n", "2025-07-01 14:19:22,827 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:22,855 - rubix - DEBUG - SSP Wave: (5994,)\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:22,867 - rubix - INFO - Getting cosmology...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:22,907 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:22,939 - rubix - INFO - Assembling the 
pipeline...\n", "2025-07-01 14:19:22,940 - rubix - INFO - Compiling the expressions...\n", "2025-07-01 14:19:22,941 - rubix - INFO - Number of devices: 2\n", "2025-07-01 14:19:23,052 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0\n", "2025-07-01 14:19:23,131 - rubix - INFO - Assigning particles to spaxels...\n", "2025-07-01 14:19:23,134 - rubix - INFO - Calculating Data Cube (combined per‐particle)…\n", "2025-07-01 14:19:23,303 - rubix - DEBUG - Datacube shape: (1, 1, 466)\n", "2025-07-01 14:19:23,304 - rubix - INFO - Convolving with PSF...\n", "2025-07-01 14:19:23,307 - rubix - INFO - Convolving with LSF...\n", "2025-07-01 14:19:23,310 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal\n", "2025-07-01 14:19:31,497 - rubix - INFO - Pipeline run completed in 8.72 seconds.\n" ] } ], "source": [ "# NBVAL_SKIP\n", "targetdata = pipe.run_sharded(inputdata)" ] }, { "cell_type": "code", "execution_count": 11, "id": "1fd0e182", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(466,)\n" ] } ], "source": [ "# NBVAL_SKIP\n", "print(targetdata[0,0,:].shape)" ] }, { "cell_type": "markdown", "id": "801a48a3", "metadata": {}, "source": [ "# Set initial cube" ] }, { "cell_type": "code", "execution_count": 12, "id": "e2cfc7ca", "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "inputdata.stars.age = jnp.array([age_values[initial_age_index], age_values[initial_age_index]])\n", "inputdata.stars.metallicity = jnp.array([metallicity_values[initial_metallicity_index], metallicity_values[initial_metallicity_index]])\n", "inputdata.stars.mass = jnp.array([[1.0], [1.0]])\n", "inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])\n", "inputdata.stars.coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])" ] }, { "cell_type": "code", "execution_count": 13, "id": "ef032a2f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", 
"text": [ "2025-07-01 14:19:31,563 - rubix - INFO - Setting up the pipeline...\n", "2025-07-01 14:19:31,564 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}\n", "2025-07-01 14:19:31,565 - rubix - DEBUG - Roataion Type found: edge-on\n", "2025-07-01 14:19:31,566 - rubix - INFO - Calculating spatial bin edges...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:31,577 - rubix - INFO - Getting cosmology...\n", "2025-07-01 14:19:31,587 - rubix - INFO - Calculating spatial bin edges...\n", "2025-07-01 14:19:31,596 - rubix - INFO - Getting cosmology...\n", "2025-07-01 14:19:31,610 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:31,638 - rubix - DEBUG - SSP Wave: (5994,)\n", 
"/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:31,650 - rubix - INFO - Getting cosmology...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:31,684 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:31,716 - rubix - INFO - Assembling the pipeline...\n", "2025-07-01 14:19:31,716 - rubix - INFO - Compiling the expressions...\n", "2025-07-01 14:19:31,717 - rubix - INFO - Number of devices: 2\n", "2025-07-01 14:19:31,804 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0\n", "2025-07-01 14:19:31,882 - rubix - INFO - Assigning particles to spaxels...\n", "2025-07-01 14:19:31,884 - rubix - INFO - Calculating Data Cube (combined per‐particle)…\n", "2025-07-01 14:19:31,892 - rubix - DEBUG - Datacube shape: (1, 1, 466)\n", "2025-07-01 14:19:31,892 - rubix - INFO - Convolving with PSF...\n", "2025-07-01 14:19:31,895 - rubix - INFO - Convolving with LSF...\n", "2025-07-01 14:19:31,898 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal\n", "2025-07-01 14:19:38.155820: E external/xla/xla/service/slow_operation_alarm.cc:73] Constant folding an instruction is taking > 1s:\n", "\n", " %gather.185 = 
f32[105,1,12,5994]{3,2,1,0} gather(%constant.3104, %iota.28), offset_dims={1,2,3}, collapsed_slice_dims={}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,12,5994}, metadata={op_name=\"jit()/jit(main)/jit(shmap_body)/jit()/while/body/jit(interp2d)/jit(fun)/jit(_take)/gather\" source_file=\"/tmp/ipykernel_1868132/3725520555.py\" source_line=4}\n", "\n", "This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.\n", "\n", "If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.\n", "2025-07-01 14:19:38.206841: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 1.051189572s\n", "Constant folding an instruction is taking > 1s:\n", "\n", " %gather.185 = f32[105,1,12,5994]{3,2,1,0} gather(%constant.3104, %iota.28), offset_dims={1,2,3}, collapsed_slice_dims={}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,12,5994}, metadata={op_name=\"jit()/jit(main)/jit(shmap_body)/jit()/while/body/jit(interp2d)/jit(fun)/jit(_take)/gather\" source_file=\"/tmp/ipykernel_1868132/3725520555.py\" source_line=4}\n", "\n", "This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. 
XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.\n", "\n", "If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.\n", "2025-07-01 14:19:39,618 - rubix - INFO - Pipeline run completed in 8.05 seconds.\n" ] } ], "source": [ "# NBVAL_SKIP\n", "initialdata = pipe.run_sharded(inputdata)" ] }, { "cell_type": "markdown", "id": "089d273f", "metadata": {}, "source": [ "# Adam optimizer" ] }, { "cell_type": "code", "execution_count": 14, "id": "af72af79", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:39,666 - rubix - INFO - Setting up the pipeline...\n", "2025-07-01 14:19:39,666 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}\n", "2025-07-01 14:19:39,667 - rubix - DEBUG - Roataion Type found: edge-on\n", "2025-07-01 14:19:39,669 - rubix - INFO - Calculating spatial bin edges...\n", 
"/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:39,679 - rubix - INFO - Getting cosmology...\n", "2025-07-01 14:19:39,689 - rubix - INFO - Calculating spatial bin edges...\n", "2025-07-01 14:19:39,699 - rubix - INFO - Getting cosmology...\n", "2025-07-01 14:19:39,726 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:39,768 - rubix - DEBUG - SSP Wave: (5994,)\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:39,780 - rubix - INFO - Getting cosmology...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:19:39,826 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n" ] } ], "source": [ "# NBVAL_SKIP\n", "from rubix.pipeline import 
linear_pipeline as pipeline\n", "\n", "pipeline_instance = RubixPipeline(config)\n", "\n", "pipeline_instance._pipeline = pipeline.LinearTransformerPipeline(\n", " pipeline_instance.pipeline_config, \n", " pipeline_instance._get_pipeline_functions()\n", ")\n", "pipeline_instance._pipeline.assemble()\n", "pipeline_instance.func = pipeline_instance._pipeline.compile_expression()" ] },
{ "cell_type": "code", "execution_count": 15, "id": "18641767", "metadata": {}, "outputs": [], "source": [
"# NBVAL_SKIP\n",
"import copy\n",
"\n",
"import jax.numpy as jnp\n",
"import optax\n",
"\n",
"\n",
"def loss_only_wrt_age_metallicity(age, metallicity, base_data, target,\n",
"                                  age_scale=20, metallicity_scale=0.05, loss_floor=1e-12):\n",
"    \"\"\"Log10 of the cosine distance between the modelled and the target datacube.\n",
"\n",
"    Args:\n",
"        age: normalised stellar ages; the pipeline sees ``age * age_scale``.\n",
"        metallicity: normalised metallicities; the pipeline sees\n",
"            ``metallicity * metallicity_scale``.\n",
"        base_data: RUBIX input data (e.g. from ``pipe.prepare_data()``). It is\n",
"            deep-copied before the stellar properties are overwritten, so the\n",
"            caller's object is never mutated and no JAX tracer leaks into it\n",
"            when this function is differentiated.\n",
"        target: target datacube, compared with ``output.stars.datacube``.\n",
"        age_scale: factor converting the normalised age to the physical age.\n",
"        metallicity_scale: factor converting the normalised metallicity to\n",
"            the physical metallicity.\n",
"        loss_floor: lower bound applied to the loss before taking ``log10``.\n",
"\n",
"    Returns:\n",
"        Scalar ``log10(max(sum(cosine_distance), loss_floor))``.\n",
"    \"\"\"\n",
"    data = copy.deepcopy(base_data)\n",
"    data.stars.age = age * age_scale\n",
"    data.stars.metallicity = metallicity * metallicity_scale\n",
"\n",
"    output = pipeline_instance.func(data)\n",
"    loss = jnp.sum(optax.cosine_distance(output.stars.datacube, target))\n",
"\n",
"    # The cosine distance can round to zero or slightly below near a perfect\n",
"    # fit; log10 would then return -inf/NaN and poison all later Adam steps.\n",
"    return jnp.log10(jnp.maximum(loss, loss_floor))"
] },
{ "cell_type": "code", "execution_count": 16, "id": "a624d440", "metadata": {}, "outputs": [], "source": [
"# NBVAL_SKIP\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"\n",
"\n",
"def _all_finite(tree):\n",
"    \"\"\"Return True if every leaf of the pytree contains only finite values.\"\"\"\n",
"    return all(bool(jnp.all(jnp.isfinite(leaf))) for leaf in jax.tree_util.tree_leaves(tree))\n",
"\n",
"\n",
"def adam_optimization_multi(loss_fn, params_init, data, target, learning=learning_all, tol=1e-3, max_iter=500):\n",
"    \"\"\"\n",
"    Optimizes both age and metallicity with Adam.\n",
"\n",
"    Args:\n",
"        loss_fn: function with signature loss_fn(age, metallicity, data, target)\n",
"            returning a scalar\n",
"        params_init: dict with keys 'age' and 'metallicity', each a JAX array\n",
"        data: base data for the loss function\n",
"        target: target data for the loss function\n",
"        learning: learning rate for Adam (used for both parameters)\n",
"        tol: tolerance for convergence (global norm of the Adam update)\n",
"        max_iter: maximum number of iterations\n",
"\n",
"    Returns:\n",
"        params: final parameters (dict). If a non-finite loss or gradient is\n",
"            met, the loop stops before applying that update, so these are the\n",
"            parameters at which the problem appeared.\n",
"        age_history: params['age'][0] at every evaluated iteration\n",
"        metallicity_history: params['metallicity'][0] at every evaluated iteration\n",
"        loss_history: loss value at every evaluated iteration\n",
"    \"\"\"\n",
"    params = params_init  # e.g. {'age': jnp.array(...), 'metallicity': jnp.array(...)}\n",
"    # One Adam instance per parameter group (labels match the keys of params),\n",
"    # so the two groups could be given different learning rates.\n",
"    optimizer = optax.multi_transform(\n",
"        {'age': optax.adam(learning), 'metallicity': optax.adam(learning)},\n",
"        {'age': 'age', 'metallicity': 'metallicity'},\n",
"    )\n",
"    optimizer_state = optimizer.init(params)\n",
"\n",
"    # Build the value-and-gradient function once instead of on every iteration.\n",
"    loss_and_grad = jax.value_and_grad(\n",
"        lambda p: loss_fn(p['age'], p['metallicity'], data, target)\n",
"    )\n",
"\n",
"    age_history = []\n",
"    metallicity_history = []\n",
"    loss_history = []\n",
"\n",
"    for i in range(max_iter):\n",
"        loss, grads = loss_and_grad(params)\n",
"        loss_history.append(float(loss))\n",
"        age_history.append(float(params['age'][0]))\n",
"        metallicity_history.append(float(params['metallicity'][0]))\n",
"\n",
"        # A NaN/inf loss or gradient would turn every parameter into NaN and the\n",
"        # remaining iterations would be wasted, so stop before applying it.\n",
"        if not (_all_finite(loss) and _all_finite(grads)):\n",
"            print(f\"Non-finite loss or gradient at iteration {i}, stopping.\")\n",
"            break\n",
"\n",
"        updates, optimizer_state = optimizer.update(grads, optimizer_state)\n",
"        params = optax.apply_updates(params, updates)\n",
"\n",
"        # Converged once the Adam step itself becomes negligible.\n",
"        if optax.global_norm(updates) < tol:\n",
"            print(f\"Converged at iteration {i}\")\n",
"            break\n",
"\n",
"    return params, age_history, metallicity_history, loss_history"
] },
{ "cell_type": "code", "execution_count": null, "id": "8bef00b9", "metadata": {}, "outputs": [], "source": [
"# NBVAL_SKIP\n",
"# Loss at the initial guess. The loss function expects *normalised* parameters\n",
"# (physical value / scale) and multiplies them by 20 and 0.05 itself, so pass\n",
"# the physical values divided by those factors rather than scaling them twice.\n",
"loss_only_wrt_age_metallicity(inputdata.stars.age / 20, inputdata.stars.metallicity / 0.05, inputdata, targetdata)"
] },
{ "cell_type": "code", "execution_count": 18, "id": "dd3467df", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial parameters: {'age': Array([0.7924467, 0.7924467], dtype=float32), 'metallicity': Array([0.5050313, 0.5050313], dtype=float32)}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-07-01 14:19:58,613 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0\n", "2025-07-01 14:19:58,727 - rubix - INFO - Assigning particles to spaxels...\n", "2025-07-01 14:19:58,729 - rubix - INFO - Calculating Data Cube (combined per‐particle)…\n", "2025-07-01 14:19:58,834 - rubix - DEBUG - Datacube shape: (1, 1, 466)\n", "2025-07-01 14:19:58,835 - rubix - INFO - Convolving with PSF...\n", "2025-07-01 14:19:58,838 - rubix - INFO - Convolving with LSF...\n", "2025-07-01 
14:19:58,842 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Optimized Age: [nan nan]\n", "Optimized Metallicity: [nan nan]\n" ] } ], "source": [
"# NBVAL_SKIP\n",
"data = inputdata\n",
"target_value = targetdata\n",
"\n",
"# Initial guesses in normalised units: the loss function multiplies age by 20\n",
"# and metallicity by 0.05, so divide the physical grid values by the same factors.\n",
"age_init = jnp.array([age_values[initial_age_index]/20, age_values[initial_age_index]/20])\n",
"metallicity_init = jnp.array([metallicity_values[initial_metallicity_index]/0.05, metallicity_values[initial_metallicity_index]/0.05])\n",
"\n",
"# Pack both initial parameters into a dictionary.\n",
"params_init = {'age': age_init, 'metallicity': metallicity_init}\n",
"print(f\"Initial parameters: {params_init}\")\n",
"\n",
"optimized_params, age_history, metallicity_history, loss_history = adam_optimization_multi(\n",
"    loss_only_wrt_age_metallicity,\n",
"    params_init,\n",
"    data,\n",
"    target_value,\n",
"    learning=learning_all,\n",
"    tol=tol,\n",
"    max_iter=5000,\n",
")\n",
"\n",
"print(f\"Optimized Age: {optimized_params['age']}\")\n",
"print(f\"Optimized Metallicity: {optimized_params['metallicity']}\")\n"
] },
{ "cell_type": "markdown", "id": "1ca494b4", "metadata": {}, "source": [ "# Optimisation history (loss, age, metallicity)" ] },
{ "cell_type": "code", "execution_count": null, "id": "aecb28f6", "metadata": {}, "outputs": [], "source": [
"# NBVAL_SKIP\n",
"import os\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# The histories hold normalised parameters and log10(loss); these are the\n",
"# factors the loss function multiplies by to obtain physical values.\n",
"age_scale = 20\n",
"metallicity_scale = 0.05\n",
"# Only the first iterations are shown (x-axis limit shared by all panels).\n",
"max_plot_iteration = 900\n",
"\n",
"loss_history_np = np.array(loss_history)\n",
"age_history_np = np.array(age_history)\n",
"metallicity_history_np = np.array(metallicity_history)\n",
"\n",
"# All histories have one entry per evaluated iteration.\n",
"iterations = np.arange(len(loss_history_np))\n",
"print(f\"Number of iterations: {len(iterations)}\")\n",
"\n",
"fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharex=True)\n",
"\n",
"# The loss function returns log10(loss); undo the log for display.\n",
"axs[0].plot(iterations, 10**loss_history_np, marker='o', linestyle='-')\n",
"axs[0].set_ylabel(\"Loss\")\n",
"axs[0].set_title(\"Loss History\")\n",
"\n",
"# Age in physical units with the target age as a red reference line.\n",
"# axhline spans the whole panel and, unlike hlines(xmax=iterations[-1]),\n",
"# also works when the history is empty.\n",
"axs[1].plot(iterations, age_history_np * age_scale, marker='o', linestyle='-')\n",
"axs[1].axhline(y=float(age_values[index_age]), color='r', linestyle='-')\n",
"axs[1].set_ylabel(\"Age\")\n",
"axs[1].set_title(\"Age History\")\n",
"\n",
"# Metallicity in physical units with the target metallicity as reference.\n",
"axs[2].plot(iterations, metallicity_history_np * metallicity_scale, marker='o', linestyle='-')\n",
"axs[2].axhline(y=float(metallicity_values[index_metallicity]), color='r', linestyle='-')\n",
"axs[2].set_ylabel(\"Metallicity\")\n",
"axs[2].set_title(\"Metallicity History\")\n",
"\n",
"for ax in axs:\n",
"    ax.set_xlabel(\"Iteration\")\n",
"    ax.grid(True)\n",
"\n",
"# sharex=True links the x-axes, so one call sets the range of every panel.\n",
"axs[0].set_xlim(-5, max_plot_iteration)\n",
"plt.tight_layout()\n",
"\n",
"# savefig does not create missing directories.\n",
"os.makedirs(\"output\", exist_ok=True)\n",
"plt.savefig(\"output/optimisation_history.jpg\", dpi=1000)\n",
"plt.show()"
] },
{ "cell_type": "code", "execution_count": 20, "id": "25aeafb6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:21:55,601 - rubix - INFO - Setting up the pipeline...\n", "2025-07-01 14:21:55,602 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}\n", "2025-07-01 14:21:55,603 - rubix - DEBUG - Roataion Type found: edge-on\n", "2025-07-01 14:21:55,605 - rubix - INFO - Calculating spatial bin edges...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:21:55,615 - rubix - INFO - Getting cosmology...\n", "2025-07-01 14:21:55,631 - rubix - INFO - Calculating spatial bin edges...\n", "2025-07-01 14:21:55,640 - 
rubix - INFO - Getting cosmology...\n", "2025-07-01 14:21:55,659 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:21:55,691 - rubix - DEBUG - SSP Wave: (5994,)\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:21:55,704 - rubix - INFO - Getting cosmology...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:21:55,743 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:21:55,793 - rubix - INFO - Assembling the pipeline...\n", "2025-07-01 14:21:55,794 - rubix - INFO - Compiling the expressions...\n", "2025-07-01 14:21:55,794 - rubix - INFO - Number of devices: 2\n", "2025-07-01 14:21:55,905 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0\n", "2025-07-01 14:21:55,985 - rubix - INFO - Assigning particles to spaxels...\n", "2025-07-01 14:21:55,987 - rubix - INFO - Calculating Data Cube (combined per‐particle)…\n", "2025-07-01 14:21:55,998 - 
rubix - DEBUG - Datacube shape: (1, 1, 466)\n", "2025-07-01 14:21:55,999 - rubix - INFO - Convolving with PSF...\n", "2025-07-01 14:21:56,001 - rubix - INFO - Convolving with LSF...\n", "2025-07-01 14:21:56,005 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal\n", "2025-07-01 14:22:03,786 - rubix - INFO - Pipeline run completed in 8.18 seconds.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(1, 1, 466)\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# NBVAL_SKIP\n", "#run the pipeline with the optimized age\n", "#rubixdata.stars.age = optimized_age\n", "i = 0\n", "inputdata.stars.age = jnp.array([age_history[i]*20, age_history[i]*20])\n", "inputdata.stars.metallicity = jnp.array([metallicity_history[i]*0.05, metallicity_history[i]*0.05])\n", "inputdata.stars.mass = jnp.array([[1.0], [1.0]])\n", "inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])\n", "\n", "pipe = RubixPipeline(config)\n", "rubixdata = pipe.run_sharded(inputdata)\n", "\n", "#plot the target and the optimized spectra\n", "import matplotlib.pyplot as plt\n", "wave = pipe.telescope.wave_seq\n", "\n", "spectra_target = targetdata\n", "spectra_optimitzed = rubixdata\n", "print(rubixdata.shape)\n", "\n", "\n", "plt.plot(wave, spectra_target[0,0,:], label=f\"Target age = {age_values[index_age]:.2f}, metal. = {metallicity_values[index_metallicity]:.4f}\")\n", "plt.plot(wave, spectra_optimitzed[0,0,:], label=f\"Optimized age = {age_history[i]*20:.2f}, metal. 
= {metallicity_history[i]*0.05:.4f}\")\n", "plt.xlabel(\"Wavelength [Å]\")\n", "plt.ylabel(\"Luminosity [L/Å]\")\n", "plt.title(\"Difference between target and optimized spectra\")\n", "#plt.title(f\"Loss {loss_history[i]:.2e}\")\n", "plt.legend()\n", "#plt.ylim(0.00003, 0.00008)\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 21, "id": "e54c9842", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:22:03,947 - rubix - INFO - Setting up the pipeline...\n", "2025-07-01 14:22:03,948 - rubix - DEBUG - Pipeline Configuration: {'Transformers': {'rotate_galaxy': {'name': 'rotate_galaxy', 'depends_on': None, 'args': [], 'kwargs': {}}, 'spaxel_assignment': {'name': 'spaxel_assignment', 'depends_on': 'rotate_galaxy', 'args': [], 'kwargs': {}}, 'calculate_datacube_particlewise': {'name': 'calculate_datacube_particlewise', 'depends_on': 'spaxel_assignment', 'args': [], 'kwargs': {}}, 'convolve_psf': {'name': 'convolve_psf', 'depends_on': 'calculate_datacube_particlewise', 'args': [], 'kwargs': {}}, 'convolve_lsf': {'name': 'convolve_lsf', 'depends_on': 'convolve_psf', 'args': [], 'kwargs': {}}, 'apply_noise': {'name': 'apply_noise', 'depends_on': 'convolve_lsf', 'args': [], 'kwargs': {}}}}\n", "2025-07-01 14:22:03,948 - rubix - DEBUG - Roataion Type found: edge-on\n", "2025-07-01 14:22:03,950 - rubix - INFO - Calculating spatial bin edges...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", 
"2025-07-01 14:22:03,961 - rubix - INFO - Getting cosmology...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-07-01 14:22:03,971 - rubix - INFO - Calculating spatial bin edges...\n", "2025-07-01 14:22:03,980 - rubix - INFO - Getting cosmology...\n", "2025-07-01 14:22:03,999 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:22:04,034 - rubix - DEBUG - SSP Wave: (5994,)\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:22:04,045 - rubix - INFO - Getting cosmology...\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:22:04,083 - rubix - DEBUG - Method not defined, using default method: cubic\n", "/home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/factory.py:26: UserWarning: No telescope config provided, using default stored in /home/annalena/.conda/envs/rubix/lib/python3.12/site-packages/rubix/telescope/telescopes.yaml\n", " warnings.warn(\n", "2025-07-01 14:22:04,125 - rubix - INFO - Assembling the pipeline...\n", "2025-07-01 14:22:04,126 - rubix - INFO - Compiling the expressions...\n", "2025-07-01 14:22:04,127 - rubix - INFO - Number of devices: 2\n", "2025-07-01 14:22:04,228 - rubix - INFO - Rotating galaxy with 
alpha=90.0, beta=0.0, gamma=0.0\n", "2025-07-01 14:22:04,308 - rubix - INFO - Assigning particles to spaxels...\n", "2025-07-01 14:22:04,310 - rubix - INFO - Calculating Data Cube (combined per‐particle)…\n", "2025-07-01 14:22:04,321 - rubix - DEBUG - Datacube shape: (1, 1, 466)\n", "2025-07-01 14:22:04,322 - rubix - INFO - Convolving with PSF...\n", "2025-07-01 14:22:04,324 - rubix - INFO - Convolving with LSF...\n", "2025-07-01 14:22:04,327 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal\n", "2025-07-01 14:22:11,227 - rubix - INFO - Pipeline run completed in 7.28 seconds.\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# NBVAL_SKIP\n", "#run the pipeline with the optimized age\n", "#rubixdata.stars.age = optimized_age\n", "i = 850\n", "inputdata.stars.age = jnp.array([age_history[i]*20, age_history[i]*20])\n", "inputdata.stars.metallicity = jnp.array([metallicity_history[i]*0.05, metallicity_history[i]*0.05])\n", "inputdata.stars.mass = jnp.array([[1.0], [1.0]])\n", "inputdata.stars.velocity = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])\n", "\n", "pipe = RubixPipeline(config)\n", "rubixdata = pipe.run_sharded(inputdata)\n", "\n", "#plot the target and the optimized spectra\n", "import matplotlib.pyplot as plt\n", "wave = pipe.telescope.wave_seq\n", "\n", "spectra_target = targetdata #.stars.datacube\n", "spectra_optimitzed = rubixdata #.stars.datacube\n", "\n", "plt.plot(wave, spectra_target[0,0,:], label=f\"Target age = {age_values[index_age]:.2f}, metal. = {metallicity_values[index_metallicity]:.4f}\")\n", "plt.plot(wave, spectra_optimitzed[0,0,:], label=f\"Optimized age = {age_history[i]*20:.2f}, metal. = {metallicity_history[i]*0.05:.4f}\")\n", "plt.xlabel(\"Wavelength [Å]\")\n", "plt.ylabel(\"Luminosity [L/Å]\")\n", "plt.title(\"Difference between target and optimized spectra\")\n", "#plt.title(f\"Loss {loss_history[i]:.2e}\")\n", "plt.legend()\n", "#plt.ylim(0.00003, 0.00008)\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 22, "id": "280ae7ad", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# NBVAL_SKIP\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "\n", "# Create a figure with two subplots, sharing the x-axis.\n", "fig, (ax1, ax2) = plt.subplots(nrows=2, sharex=True, gridspec_kw={'height_ratios': [4, 1]}, figsize=(7, 5))\n", "\n", "# Plot target and optimized spectra in the upper subplot.\n", "ax1.plot(wave, spectra_target[0, 0, :], label=f\"Target age = {age_values[index_age]:.2f}, metallicity = {metallicity_values[index_metallicity]:.4f}\")\n", "ax1.plot(wave, spectra_optimitzed[0, 0, :], label=f\"Optimized age = {age_history[i]*20:.2f}, metallicity = {metallicity_history[i]*0.05:.4f}\")\n", "ax1.set_ylabel(\"Luminosity [L/Å]\")\n", "#ax1.set_title(\"Target vs Optimized Spectra\")\n", "ax1.legend()\n", "ax1.grid(True)\n", "\n", "# Compute the residual (difference between target and optimized spectra).\n", "residual = (spectra_target[0, 0, :] - spectra_optimitzed[0, 0, :]) #/spectra_target[0, 0, :]\n", "\n", "# Plot the residual in the lower subplot.\n", "ax2.plot(wave, residual, 'k-')\n", "ax2.set_xlabel(\"Wavelength [Å]\")\n", "ax2.set_ylabel(\"Residual\")\n", "ax2.grid(True)\n", "\n", "plt.tight_layout()\n", "plt.savefig(f\"output/optimisation_spectra.jpg\", dpi=1000)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "dfe98919", "metadata": {}, "source": [ "# Calculate loss landscape" ] }, { "cell_type": "code", "execution_count": 23, "id": "77873f8d", "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "import optax\n", "\n", "def loss_only_wrt_age_metallicity(age, metallicity, base_data, target):\n", "\n", " # Create 2D arrays for age and metallicity.\n", " # For example, if there are two stars, you might do:\n", " base_data.stars.age = jnp.array([age*20, age*20])\n", " base_data.stars.metallicity = jnp.array([metallicity*0.05, metallicity*0.05])\n", "\n", " output = pipeline_instance.func(base_data)\n", " #loss = 
jnp.sum((output.stars.datacube - target) ** 2)\n", " loss = jnp.sum(optax.cosine_distance(output.stars.datacube, target))\n", " return loss" ] }, { "cell_type": "code", "execution_count": 24, "id": "3cd1cac4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-07-01 14:22:12,033 - rubix - INFO - Rotating galaxy with alpha=90.0, beta=0.0, gamma=0.0\n", "2025-07-01 14:22:12,100 - rubix - INFO - Assigning particles to spaxels...\n", "2025-07-01 14:22:12,102 - rubix - INFO - Calculating Data Cube (combined per‐particle)…\n", "2025-07-01 14:22:12,448 - rubix - DEBUG - Datacube shape: (1, 1, 466)\n", "2025-07-01 14:22:12,449 - rubix - INFO - Convolving with PSF...\n", "2025-07-01 14:22:12,451 - rubix - INFO - Convolving with LSF...\n", "2025-07-01 14:22:12,454 - rubix - INFO - Applying noise to datacube with signal to noise ratio: 100 and noise distribution: normal\n" ] } ], "source": [ "# NBVAL_SKIP\n", "import jax\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "\n", "# Number of grid points\n", "num_steps = 100\n", "\n", "# Define normalized grids; the loss function rescales them to physical units (age * 20, metallicity * 0.05)\n", "physical_ages = jnp.linspace(0, 1, num_steps) # normalized age 0 to 1 (physical 0 to 20)\n", "physical_metals = jnp.linspace(0, 1, num_steps) # normalized metallicity 0 to 1 (physical 0 to 0.05)\n", "\n", "loss_map = []\n", "\n", "# NOTE: each loss evaluation overwrites inputdata.stars.age and inputdata.stars.metallicity.\n", "for age in physical_ages:\n", " row = []\n", " for metal in physical_metals:\n", " loss = loss_only_wrt_age_metallicity(age, metal, inputdata, targetdata)\n", " row.append(loss)\n", " loss_map.append(jnp.stack(row))\n", "\n", "loss_map = jnp.stack(loss_map)" ] }, { "cell_type": "code", "execution_count": 25, "id": "b8591401", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# NBVAL_SKIP\n", "# Plot the loss landscape using imshow.\n", "import matplotlib.pyplot as plt\n", "import matplotlib.colors as colors\n", "plt.figure(figsize=(5, 4))\n", "plt.imshow(loss_map, origin='lower', extent=[0,1,0,1], aspect='auto', norm=colors.LogNorm())#, vmin=-3.5, vmax=-2.5)#extent=[1e-4, 0.05, 0, 10]\n", "plt.xlabel('Metallicity')\n", "plt.ylabel('Age')\n", "plt.title('Loss Landscape')\n", "plt.colorbar(label='loss')\n", "# Plot a red dot at the desired coordinates.\n", "plt.plot(metallicity_history[:], age_history[:])#, 'bx', markersize=8)\n", "#plt.plot(metallicity_history[::100], age_history[::100], 'bx', markersize=8)\n", "plt.plot(metallicity_values[index_metallicity]/0.05, age_values[index_age]/20, 'ro', markersize=8)\n", "plt.plot(metallicity_values[initial_metallicity_index]/0.05, age_values[initial_age_index]/20, 'ro', markersize=8)\n", "plt.savefig(f\"output/optimisation_losslandscape.jpg\", dpi=1000)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 26, "id": "93071f12", "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "metallicity_history = np.array(metallicity_history)*0.05\n", "age_history = np.array(age_history)*20" ] }, { "cell_type": "code", "execution_count": 27, "id": "bb006589", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# NBVAL_SKIP\n", "import matplotlib.pyplot as plt\n", "import matplotlib.colors as colors\n", "\n", "plt.figure(figsize=(6, 5))\n", "\n", "# Update the extent to the physical values: metallicity from 0 to 0.05 and age from 0 to 20.\n", "plt.imshow(loss_map, origin='lower', extent=[0, 0.05, 0, 20], aspect='auto', norm=colors.LogNorm())\n", "\n", "plt.xlabel('Metallicity')\n", "plt.ylabel('Age')\n", "plt.title('Loss Landscape')\n", "plt.colorbar(label='loss')\n", "\n", "# Plot the history in physical coordinates by multiplying the normalized values.\n", "plt.plot(metallicity_history[:], age_history[:])#, 'bx', markersize=8)\n", "\n", "# Plot the red dots in physical coordinates\n", "plt.plot(metallicity_values[index_metallicity], age_values[index_age], marker='o', color='orange', markersize=8)\n", "plt.plot(metallicity_values[initial_metallicity_index], age_values[initial_age_index], 'wo', markersize=8)\n", "\n", "plt.savefig(\"output/optimisation_losslandscape.jpg\", dpi=1000)\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "rubix", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 5 }