Source code for rubix.core.rotation
from typing import Dict
import jax
from rubix.logger import get_logger
from rubix.galaxy.alignment import rotate_galaxy as rotate_galaxy_core
from jaxtyping import Array, Float, jaxtyped
from beartype import beartype as typechecker
[docs]
@jaxtyped(typechecker=typechecker)
def get_galaxy_rotation(config: dict):
    """
    Get the function to rotate the galaxy based on the configuration.

    The rotation is read from ``config["galaxy"]["rotation"]`` and can be
    given in one of two ways:

    * a preset ``"type"``: ``"face-on"`` (alpha = beta = gamma = 0) or
      ``"edge-on"`` (alpha = 90, beta = gamma = 0);
    * the three explicit angles ``"alpha"``, ``"beta"`` and ``"gamma"``.

    If ``"type"`` is present it takes precedence and any explicit angles in
    the configuration are ignored.

    Args:
        config (dict): Configuration dictionary. ``config["galaxy"]`` is read
            when this function is called; ``config["data"]["args"]
            ["particle_type"]`` is read each time the returned function is
            called, to decide which components ("stars", "gas") to rotate.

    Returns:
        The function to rotate the galaxy. It takes the ``rubixdata`` object,
        updates ``coords`` and ``velocity`` of each selected component in
        place and returns the same object.

    Raises:
        ValueError: If ``"rotation"`` is missing from ``config["galaxy"]``,
            if ``"type"`` is neither ``"face-on"`` nor ``"edge-on"``, or if
            no ``"type"`` is given and one of ``"alpha"``, ``"beta"``,
            ``"gamma"`` is missing.

    Example
    --------
    >>> config = {
    ...     ...
    ...     "galaxy":
    ...         {"dist_z": 0.1,
    ...          "rotation": {"type": "edge-on"},
    ...         },
    ...     ...
    ... }

    >>> from rubix.core.rotation import get_galaxy_rotation
    >>> rotate_galaxy = get_galaxy_rotation(config)
    >>> rubixdata = rotate_galaxy(rubixdata)
    """
    # Check if rotation information is provided under galaxy config
    if "rotation" not in config["galaxy"]:
        raise ValueError("Rotation information not provided in galaxy config")

    logger = get_logger()
    rotation_config = config["galaxy"]["rotation"]

    if "type" in rotation_config:
        # A preset orientation was requested: only face-on or edge-on exist
        if rotation_config["type"] not in ["face-on", "edge-on"]:
            raise ValueError("Invalid type provided in rotation information")

        # face-on: alpha = beta = gamma = 0
        # edge-on: alpha = 90, beta = gamma = 0
        if rotation_config["type"] == "face-on":
            logger.debug("Rotation Type found: Face-on")
            alpha = 0.0
        else:
            logger.debug("Rotation Type found: edge-on")
            alpha = 90.0
        beta = 0.0
        gamma = 0.0
    else:
        # Without a preset type, all three angles must come from the user
        for key in ["alpha", "beta", "gamma"]:
            if key not in rotation_config:
                raise ValueError(f"{key} not provided in rotation information")

        alpha = rotation_config["alpha"]
        beta = rotation_config["beta"]
        gamma = rotation_config["gamma"]

    @jaxtyped(typechecker=typechecker)
    def rotate_galaxy(rubixdata: object, type: str = "face-on") -> object:
        # NOTE: ``type`` is not used here; the angles were fixed when the
        # enclosing factory ran. The parameter is kept so existing callers
        # that pass it keep working.
        logger.info(f"Rotating galaxy with alpha={alpha}, beta={beta}, gamma={gamma}")

        requested_particles = config["data"]["args"]["particle_type"]

        # Stars and gas are handled identically, stars first.
        for component_name in ("stars", "gas"):
            if component_name not in requested_particles:
                continue

            component = getattr(rubixdata, component_name)

            # NOTE(review): the stellar half-mass radius is passed for the
            # gas component as well - presumably intentional so both
            # components share one reference radius; confirm.
            coords, velocities = rotate_galaxy_core(
                positions=component.coords,
                velocities=component.velocity,
                masses=component.mass,
                halfmass_radius=rubixdata.galaxy.halfmassrad_stars,
                alpha=alpha,
                beta=beta,
                gamma=gamma,
            )

            # Write the rotated arrays back onto the component
            setattr(component, "coords", coords)
            setattr(component, "velocity", velocities)

        return rubixdata

    return rotate_galaxy