import jax.numpy as jnp
from jaxtyping import Float, Array
from typing import Tuple, Union
from jax.scipy.spatial.transform import Rotation
from jaxtyping import Array, Float, jaxtyped
from beartype import beartype as typechecker
[docs]
@jaxtyped(typechecker=typechecker)
def center_particles(rubixdata: object, key: str) -> object:
    """
    Center the particles of one component around the galaxy center.

    The coordinates are shifted so that ``rubixdata.galaxy.center`` becomes the
    origin. The velocities are shifted by the median velocity of the particles
    that lie within a distance of 10 of the center (10 kpc, assuming the
    coordinates are in kpc).

    Args:
        rubixdata (object): The RubixData object. It must provide
            ``rubixdata.galaxy.center`` and a ``rubixdata.<key>`` component
            with ``coords`` and ``velocity`` attributes.
        key (str): The particle component to center, either ``"stars"`` or
            ``"gas"``.

    Returns:
        The same RubixData object. The ``coords`` and ``velocity`` attributes
        of the selected component are replaced by the centered jnp.ndarrays.

    Raises:
        ValueError: If ``key`` is neither ``"stars"`` nor ``"gas"``, or if the
            galaxy center lies outside the bounding box of the particles.

    Example
    -------
    >>> from rubix.galaxy.alignment import center_particles
    >>> rubixdata = center_particles(rubixdata, "stars")
    """
    # Reject unknown keys explicitly. Previously they fell through both
    # branches and crashed later with an UnboundLocalError.
    if key not in ("stars", "gas"):
        raise ValueError(f"Unknown particle key {key!r}; expected 'stars' or 'gas'")

    particles = getattr(rubixdata, key)
    coords = particles.coords
    velocities = particles.velocity
    galaxy_center = rubixdata.galaxy.center

    # The center must lie inside the axis-aligned bounding box of the
    # particles. This is a Python-level check, so it needs concrete
    # (non-traced) arrays.
    within_bounds = jnp.all(
        (galaxy_center >= jnp.min(coords, axis=0))
        & (galaxy_center <= jnp.max(coords, axis=0))
    )
    if not within_bounds:
        raise ValueError("Center is not within the bounds of the galaxy")

    # Bulk velocity: the median velocity of the particles within 10 (kpc) of
    # the center.
    # NOTE(review): if no particle lies inside this radius, the median is taken
    # over an empty selection (presumably NaN); confirm whether that needs
    # guarding.
    mask = jnp.linalg.norm(coords - galaxy_center, axis=1) < 10
    central_velocity = jnp.median(velocities[mask], axis=0)

    particles.coords = coords - galaxy_center
    particles.velocity = velocities - central_velocity
    return rubixdata
[docs]
@jaxtyped(typechecker=typechecker)
def moment_of_inertia_tensor(
    positions: Float[Array, "..."],
    masses: Float[Array, "..."],
    halfmass_radius: Union[Float[Array, "..."], float],
) -> Float[Array, "..."]:
    """
    Calculate the moment of inertia tensor of the particles within the half-mass radius.

    Assumes the galaxy is already centered, so the positions are measured
    relative to the galaxy center. Only particles whose distance from the
    origin is less than or equal to ``halfmass_radius`` contribute.

    Args:
        positions (jnp.ndarray): The particle positions, one row per particle
            with the x, y, z components in the columns.
        masses (jnp.ndarray): The particle masses, one entry per particle.
        halfmass_radius (float or jnp.ndarray): The selection radius (the
            half-mass radius of the galaxy).

    Returns:
        The 3x3 moment of inertia tensor
        ``I_ij = sum_k m_k (|r_k|^2 delta_ij - x_ki x_kj)`` as a jnp.ndarray.

    Example
    -------
    >>> from rubix.galaxy.alignment import moment_of_inertia_tensor
    >>> I = moment_of_inertia_tensor(rubixdata.stars.coords, rubixdata.stars.mass, rubixdata.galaxy.half_light_radius)
    """
    # Positions are already centered, so the squared distance is the squared norm.
    squared_radii = jnp.sum(positions**2, axis=1)
    distances = jnp.sqrt(squared_radii)

    # Select particles by zeroing the mass of those outside the radius rather
    # than gathering indices. This keeps shapes static (jit-friendly). It also
    # avoids a bug in the gather approach: jnp.where(mask, size=N) pads the
    # index list with 0, so particle 0 was counted many extra times.
    weights = jnp.where(distances <= halfmass_radius, masses, 0.0)

    # Diagonal term: sum_k m_k |r_k|^2 on every diagonal entry.
    diagonal = jnp.sum(weights * squared_radii) * jnp.eye(3)
    # Mass-weighted outer product: sum_k m_k x_ki x_kj.
    outer = jnp.einsum("n,ni,nj->ij", weights, positions, positions)
    return diagonal - outer
[docs]
@jaxtyped(typechecker=typechecker)
def rotation_matrix_from_inertia_tensor(I: Float[Array, "..."]) -> Float[Array, "..."]:
    """
    Calculate a 3x3 rotation matrix by diagonalizing the moment of inertia tensor.

    The columns of the result are the eigenvectors (principal axes) of ``I``,
    ordered by ascending eigenvalue. They are oriented so that the matrix has
    determinant +1, which makes it a proper rotation rather than a reflection.

    Args:
        I (jnp.ndarray): The symmetric moment of inertia tensor.

    Returns:
        The rotation matrix as a jnp.ndarray.
    """
    eigen_values, eigen_vectors = jnp.linalg.eigh(I)
    # eigh already returns ascending eigenvalues. The explicit sort keeps the
    # intended axis ordering (smallest moment first) visible.
    order = jnp.argsort(eigen_values)
    rotation_matrix = eigen_vectors[:, order]
    # Eigenvector signs are arbitrary, so the basis can come out left-handed
    # (det = -1). Applying that would mirror the galaxy and flip its sense of
    # rotation. Flip the last axis in that case so the result is a proper
    # rotation.
    handedness = jnp.where(jnp.linalg.det(rotation_matrix) < 0, -1.0, 1.0)
    rotation_matrix = rotation_matrix.at[:, -1].multiply(handedness)
    return rotation_matrix
[docs]
@jaxtyped(typechecker=typechecker)
def apply_init_rotation(
positions: Float[Array, "..."], rotation_matrix: Float[Array, "..."]
) -> Float[Array, "..."]:
"""
Apply a rotation matrix to a set of positions.
Args:
positions (jnp.ndarray): The positions of the particles.
rotation_matrix (jnp.ndarray): The rotation matrix.
Returns:
The rotated positions as a jnp.ndarray.
"""
return jnp.dot(positions, rotation_matrix)
[docs]
@jaxtyped(typechecker=typechecker)
def euler_rotation_matrix(
    alpha: float, beta: float, gamma: float
) -> Float[Array, "3 3"]:
    """
    Build a 3x3 rotation matrix from three elemental rotations given in degrees.

    The x-rotation acts first, then the y-rotation, then the z-rotation, so the
    returned matrix equals ``R_z @ R_y @ R_x``.

    Args:
        alpha (float): Rotation angle about the x-axis, in degrees.
        beta (float): Rotation angle about the y-axis, in degrees.
        gamma (float): Rotation angle about the z-axis, in degrees.

    Returns:
        The 3x3 rotation matrix as a jnp.ndarray.
    """
    about_x, about_y, about_z = (
        Rotation.from_euler(axis, angle, degrees=True)
        for axis, angle in (("x", alpha), ("y", beta), ("z", gamma))
    )
    # In the composition ``p * q``, q is applied first, so this product applies
    # x, then y, then z.
    combined = about_z * about_y * about_x
    return combined.as_matrix()
# @jaxtyped(typechecker=typechecker)
[docs]
def apply_rotation(
    positions: Float[Array, "* 3"], alpha: float, beta: float, gamma: float
) -> Float[Array, "* 3"]:
    """
    Rotate a set of vectors using the Euler-angle matrix from ``euler_rotation_matrix``.

    The vectors are stored as rows and right-multiplied by the matrix
    (``positions @ R``). This is equivalent to applying ``R.T`` to each vector.

    Args:
        positions (jnp.ndarray): The vectors to rotate (for example particle
            positions), one per row.
        alpha (float): Rotation around the x-axis in degrees.
        beta (float): Rotation around the y-axis in degrees.
        gamma (float): Rotation around the z-axis in degrees.

    Returns:
        The rotated vectors as a jnp.ndarray.
    """
    return jnp.dot(positions, euler_rotation_matrix(alpha, beta, gamma))
# @jaxtyped(typechecker=typechecker)
[docs]
def rotate_galaxy(
    positions: Float[Array, "* 3"],
    velocities: Float[Array, "* 3"],
    masses: Float[Array, "..."],
    halfmass_radius: Float[Array, "..."],
    alpha: float,
    beta: float,
    gamma: float,
) -> Tuple[Float[Array, "* 3"], Float[Array, "* 3"]]:
    """
    Orient the galaxy using its principal axes and then apply the given Euler angles.

    First, the principal-axis rotation is derived from the moment of inertia
    tensor of the particles within ``halfmass_radius``. That rotation is
    applied to both positions and velocities. Then the Euler-angle rotation
    (alpha, beta, gamma) is applied to both.

    Args:
        positions (jnp.ndarray): The positions of the particles.
        velocities (jnp.ndarray): The velocities of the particles.
        masses (jnp.ndarray): The masses of the particles.
        halfmass_radius (float): The half-mass radius of the galaxy.
        alpha (float): Rotation around the x-axis in degrees.
        beta (float): Rotation around the y-axis in degrees.
        gamma (float): Rotation around the z-axis in degrees.

    Returns:
        A tuple ``(positions, velocities)`` of rotated jnp.ndarrays.
    """
    inertia = moment_of_inertia_tensor(positions, masses, halfmass_radius)
    principal_axes = rotation_matrix_from_inertia_tensor(inertia)

    # Align positions, then velocities, with the principal axes.
    aligned = [
        apply_init_rotation(vectors, principal_axes)
        for vectors in (positions, velocities)
    ]
    # Then apply the user-requested orientation to both, in the same order.
    pos_final, vel_final = [
        apply_rotation(vectors, alpha, beta, gamma) for vectors in aligned
    ]
    return pos_final, vel_final