Source code for rubix.core.telescope
import jax.numpy as jnp
from rubix.telescope.utils import (
calculate_spatial_bin_edges,
square_spaxel_assignment,
mask_particles_outside_aperture,
)
from rubix.telescope.base import BaseTelescope
from rubix.telescope.factory import TelescopeFactory
from rubix.logger import get_logger
from .cosmology import get_cosmology
from typing import Callable
from typing import Union
from jaxtyping import Array, Float, jaxtyped
from beartype import beartype as typechecker
from unittest.mock import patch, MagicMock
[docs]
@jaxtyped(typechecker=typechecker)
def get_telescope(config: Union[str, dict]) -> BaseTelescope:
    """
    Build the telescope object named in the configuration.

    Args:
        config (Union[str, dict]): Configuration mapping. The telescope name
            is looked up as ``config["telescope"]["name"]``.
            NOTE(review): the annotation admits ``str``, but a string cannot
            be indexed with ``"telescope"`` and raises ``TypeError``.
            Confirm whether string configs are meant to be supported.

    Returns:
        BaseTelescope: The telescope produced by ``TelescopeFactory``.

    Raises:
        KeyError: If ``config`` lacks the ``["telescope"]["name"]`` entry.
        TypeError: If the factory returns an object that is not a
            ``BaseTelescope``.

    Example
    -------
    >>> from rubix.core.telescope import get_telescope
    >>> config = {
    ...     "telescope":
    ...         {"name": "MUSE"},
    ... }
    >>> telescope = get_telescope(config)
    >>> print(telescope)
    """
    # TODO: this currently only loads telescope that are supported.
    # add support for custom telescopes
    telescope_factory = TelescopeFactory()
    requested_name = config["telescope"]["name"]
    built = telescope_factory.create_telescope(requested_name)

    # Guard against the factory handing back an unexpected type.
    if isinstance(built, BaseTelescope):
        return built
    raise TypeError(f"Expected type BaseTelescope, but got {type(built)}")
[docs]
@jaxtyped(typechecker=typechecker)
def get_spatial_bin_edges(config: dict) -> jnp.ndarray:
    """
    Compute the spatial bin edges for the configured telescope and galaxy.

    The telescope's field of view (``fov``) and number of spatial bins
    (``sbin``) are combined with the galaxy distance
    ``config["galaxy"]["dist_z"]`` and the configured cosmology via
    ``calculate_spatial_bin_edges``.

    Args:
        config (dict): Configuration dictionary. It must contain the telescope
            entry used by ``get_telescope`` and ``config["galaxy"]["dist_z"]``.
            An optional ``"logger"`` entry configures logging.

    Returns:
        jnp.ndarray: The spatial bin edges. The bin size that
        ``calculate_spatial_bin_edges`` also returns is discarded.
    """
    log = get_logger(config.get("logger", None))
    log.info("Calculating spatial bin edges...")

    scope = get_telescope(config)
    distance_z = config["galaxy"]["dist_z"]
    cosmo = get_cosmology(config)

    # TODO: check if we need the spatial bin size somewhere? For now we dont use it
    edges, _unused_bin_size = calculate_spatial_bin_edges(
        fov=scope.fov,
        spatial_bins=scope.sbin,
        dist_z=distance_z,
        cosmology=cosmo,
    )
    return edges
[docs]
@jaxtyped(typechecker=typechecker)
def get_spaxel_assignment(config: dict) -> Callable:
    """
    Build a function that assigns star and gas particles to spaxels.

    The telescope and the spatial bin edges are resolved once, when this
    factory is called. The returned closure reuses them on every call.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        Callable: A function ``spaxel_assignment(rubixdata)``. For each of
        ``rubixdata.stars`` and ``rubixdata.gas`` (in that order) whose
        ``coords`` is not ``None``, it sets ``pixel_assignment`` (from
        ``square_spaxel_assignment``) and ``spatial_bin_edges``. It mutates
        ``rubixdata`` in place and returns it.

    Raises:
        ValueError: If the telescope's ``pixel_type`` is not ``"square"``.

    Example
    -------
    >>> from rubix.core.telescope import get_spaxel_assignment
    >>> bin_particles = get_spaxel_assignment(config)
    >>> rubixdata = bin_particles(rubixdata)
    >>> print(rubixdata.stars.pixel_assignment)
    >>> print(rubixdata.stars.spatial_bin_edges)
    """
    log = get_logger(config.get("logger", None))
    scope = get_telescope(config)

    supported_pixel_types = ("square",)
    if scope.pixel_type not in supported_pixel_types:
        raise ValueError(f"Pixel type {scope.pixel_type} not supported")

    edges = get_spatial_bin_edges(config)

    def spaxel_assignment(rubixdata: object) -> object:
        log.info("Assigning particles to spaxels...")
        # Components are fetched lazily, in order, so "gas" is only accessed
        # after "stars" has been processed.
        for component_name in ("stars", "gas"):
            component = getattr(rubixdata, component_name)
            if component.coords is None:
                continue
            component.pixel_assignment = square_spaxel_assignment(
                component.coords, edges
            )
            component.spatial_bin_edges = edges
        return rubixdata

    return spaxel_assignment
[docs]
@jaxtyped(typechecker=typechecker)
def get_filter_particles(config: dict) -> Callable:
    """
    Get the function to filter particles outside the aperture.

    The spatial bin edges are computed once, when this factory is called.
    The returned closure reuses them on every call.

    Args:
        config (dict): Configuration dictionary. The particle components to
            filter are read from ``config["data"]["args"]["particle_type"]``
            (membership of ``"stars"`` / ``"gas"``).

    Returns:
        Callable: A function ``filter_particles(rubixdata)``. For each selected
        component it:

        - computes a mask with ``mask_particles_outside_aperture`` from the
          component's ``coords`` and the spatial bin edges;
        - replaces with 0 the entries of every array attribute (except
          ``coords`` and ``velocity``) where the mask is False;
        - stores the mask as the component's ``mask`` attribute.

        ``rubixdata`` is mutated in place and returned.

    Example
    -------
    >>> from rubix.core.telescope import get_filter_particles
    >>> filter_particles = get_filter_particles(config)
    >>> rubixdata = filter_particles(rubixdata)
    """
    logger = get_logger(config.get("logger", None))
    spatial_bin_edges = get_spatial_bin_edges(config)

    def filter_particles(rubixdata: object) -> object:
        logger.info("Filtering particles outside the aperture...")
        particle_types = config["data"]["args"]["particle_type"]
        # Same order as before: stars first, then gas.
        for component_name in ("stars", "gas"):
            if component_name not in particle_types:
                continue
            component = getattr(rubixdata, component_name)
            mask = jnp.asarray(
                mask_particles_outside_aperture(component.coords, spatial_bin_edges)
            )
            _zero_outside_mask(component, mask)
        return rubixdata

    return filter_particles


def _zero_outside_mask(component: object, mask: jnp.ndarray) -> None:
    """
    Zero a component's array attributes where ``mask`` is False, then store the mask.

    Attributes are discovered with ``dir``. Dunder names, callables and the
    ``coords`` / ``velocity`` attributes are skipped. Only values that are
    ``jnp.ndarray`` instances are modified.

    Args:
        component: A particle container (e.g. ``rubixdata.stars``) whose
            attributes are updated in place.
        mask: Boolean mask. Entries where it is True are kept.
            NOTE(review): assumed to be per-particle along axis 0 — confirm
            against ``mask_particles_outside_aperture``.
    """
    attributes = [
        attr
        for attr in dir(component)
        if not attr.startswith("__")
        and not callable(getattr(component, attr))
        and attr not in ("coords", "velocity")
    ]
    for attr in attributes:
        value = getattr(component, attr)
        if not isinstance(value, jnp.ndarray):
            continue
        # jnp.where broadcasts from the trailing axis, so a bare (n,) mask
        # against an (n, k) array would fail (or expand to (n, n) when k == 1).
        # Append singleton axes so the mask lines up with axis 0 instead. For
        # values with ndim <= mask.ndim this is the unchanged mask.
        extra_axes = max(value.ndim - mask.ndim, 0)
        aligned_mask = mask.reshape(mask.shape + (1,) * extra_axes)
        setattr(component, attr, jnp.where(aligned_mask, value, 0))
    component.mask = mask