Source code for rubix.galaxy.alignment

import jax.numpy as jnp
from beartype import beartype as typechecker
from beartype.typing import Tuple, Union
from jax.scipy.spatial.transform import Rotation
from jaxtyping import Array, Float, jaxtyped


@jaxtyped(typechecker=typechecker)
def center_particles(rubixdata: object, key: str) -> object:
    """
    Center one particle species on the galaxy center.

    The coordinates are shifted so that ``rubixdata.galaxy.center`` becomes the
    origin, and the velocities are shifted by the median velocity of the
    particles lying within a radius of 10 (kpc, per the original comment) of
    that center. The object is modified in place and also returned.

    Args:
        rubixdata (object): The RubixData object to update. It must expose
            ``galaxy.center`` and, for the chosen key, ``coords`` and
            ``velocity``.
        key (str): Particle key, either "stars" or "gas".

    Returns:
        object: The same RubixData object with centered coordinates and
        velocities.

    Raises:
        ValueError: If ``key`` is neither "stars" nor "gas".
        ValueError: If the galaxy center lies outside the particle bounds.

    Example:
        >>> from rubix.galaxy.alignment import center_particles
        >>> rubixdata = center_particles(rubixdata, "stars")
    """
    # Reject unsupported keys up front; previously this fell through both
    # branches and crashed later with an UnboundLocalError.
    if key not in ("stars", "gas"):
        raise ValueError(
            f"Unknown key: {key} for centering. "
            "Supported keys are 'stars' and 'gas'."
        )

    particles = getattr(rubixdata, key)
    particle_coordinates = particles.coords
    particle_velocities = particles.velocity
    galaxy_center = rubixdata.galaxy.center

    # The center must lie inside the axis-aligned bounding box of the
    # particles (checked on the three spatial components).
    center_xyz = jnp.asarray(galaxy_center)[:3]
    lower_bounds = jnp.min(particle_coordinates[:, :3], axis=0)
    upper_bounds = jnp.max(particle_coordinates[:, :3], axis=0)
    check_bounds = jnp.all((center_xyz >= lower_bounds) & (center_xyz <= upper_bounds))
    if not check_bounds:
        raise ValueError("Center is not within the bounds of the galaxy")

    # Central velocity: median velocity of the particles within 10 kpc of
    # the galaxy center.
    centered_coordinates = particle_coordinates - galaxy_center
    mask = jnp.linalg.norm(centered_coordinates, axis=1) < 10
    central_velocity = jnp.median(particle_velocities[mask], axis=0)

    particles.coords = centered_coordinates
    particles.velocity = particle_velocities - central_velocity

    return rubixdata
@jaxtyped(typechecker=typechecker)
def moment_of_inertia_tensor(
    positions: Float[Array, "..."],
    masses: Float[Array, "..."],
    halfmass_radius: Union[Float[Array, "..."], float],
) -> Float[Array, "..."]:
    """
    Calculate the moment of inertia tensor for particles within the half-mass radius.

    Assumes the galaxy is already centered, so a particle's distance from the
    center is simply the norm of its position vector. Only particles with
    ``distance <= halfmass_radius`` contribute:

        I_ij = sum_k m_k * (|r_k|^2 * delta_ij - r_ki * r_kj)

    Args:
        positions (Float[Array, "..."]): Particle positions, indexed as
            ``positions[particle, axis]`` with three spatial axes.
        masses (Float[Array, "..."]): Corresponding masses, one per particle.
        halfmass_radius (Union[Float[Array, "..."], float]): The half-mass
            radius of the galaxy used to filter particles.

    Returns:
        Float[Array, "..."]: The 3x3 moment of inertia tensor.

    Example:
        >>> from rubix.galaxy.alignment import moment_of_inertia_tensor
        >>> I = moment_of_inertia_tensor(
        ...     rubixdata.stars.coords,
        ...     rubixdata.stars.mass,
        ...     rubixdata.galaxy.halfmassrad_stars,
        ... )
    """
    # Direct calculation since positions are already centered.
    distances = jnp.sqrt(jnp.sum(positions**2, axis=1))
    within_halfmass_radius = distances <= halfmass_radius

    # Select particles by masking instead of gathering indices. The previous
    # ``jnp.where(mask, size=N)[0]`` padded the index list with its default
    # fill value 0, so particle 0 was added once for every excluded particle.
    # Zeroing both mass and position keeps shapes static and makes excluded
    # particles (including non-finite ones) contribute exactly nothing.
    selected_masses = jnp.where(within_halfmass_radius, masses, 0.0)
    selected_positions = jnp.where(within_halfmass_radius[:, None], positions, 0.0)

    # Isotropic part: sum_k m_k |r_k|^2 on the diagonal.
    mass_weighted_r2 = jnp.sum(
        selected_masses * jnp.sum(selected_positions**2, axis=1)
    )
    # Second-moment part: sum_k m_k r_ki r_kj.
    second_moment = jnp.einsum(
        "k,ki,kj->ij", selected_masses, selected_positions, selected_positions
    )

    return mass_weighted_r2 * jnp.eye(3, dtype=second_moment.dtype) - second_moment
@jaxtyped(typechecker=typechecker)
def rotation_matrix_from_inertia_tensor(I: Float[Array, "..."]) -> Float[Array, "..."]:
    """
    Build a rotation matrix from the eigenvectors of the moment of inertia tensor.

    The tensor is diagonalized with ``jnp.linalg.eigh`` and the eigenvectors
    (stored as columns) are arranged by ascending eigenvalue.

    Args:
        I (Float[Array, "..."]): The moment of inertia tensor.

    Returns:
        Float[Array, "..."]: Matrix whose columns are the eigenvectors of ``I``
        ordered from smallest to largest eigenvalue.
    """
    # NOTE(review): no determinant check is made here, so the result could
    # contain a reflection (det = -1) — confirm whether callers rely on a
    # proper rotation.
    principal_moments, principal_axes = jnp.linalg.eigh(I)
    ascending = jnp.argsort(principal_moments)
    return principal_axes[:, ascending]
@jaxtyped(typechecker=typechecker)
def apply_init_rotation(
    positions: Float[Array, "* 3"], rotation_matrix: Float[Array, "3 3"]
) -> Float[Array, "* 3"]:
    """
    Rotate an array of particle vectors with a given rotation matrix.

    The vectors are treated as rows and right-multiplied by the matrix, i.e.
    each row ``p`` becomes ``p @ rotation_matrix``.

    Args:
        positions (Float[Array, "* 3"]): The particle positions (or any other
            3-vectors, one per row).
        rotation_matrix (Float[Array, "3 3"]): The rotation matrix to apply.

    Returns:
        Float[Array, "* 3"]: The rotated vectors.
    """
    rotated = jnp.matmul(positions, rotation_matrix)
    return rotated
@jaxtyped(typechecker=typechecker)
def euler_rotation_matrix(
    alpha: float, beta: float, gamma: float
) -> Float[Array, "3 3"]:
    """
    Create a 3x3 rotation matrix from three Euler angles given in degrees.

    One elementary rotation is built per axis and the three are composed as
    ``R = R_z * R_y * R_x``.

    Args:
        alpha (float): Rotation around the x-axis in degrees.
        beta (float): Rotation around the y-axis in degrees.
        gamma (float): Rotation around the z-axis in degrees.

    Returns:
        Float[Array, "3 3"]: The combined rotation matrix.
    """
    # Elementary rotations about x, y (pitch) and z (yaw), in that order.
    about_x, about_y, about_z = (
        Rotation.from_euler(axis, angle, degrees=True)
        for axis, angle in zip("xyz", (alpha, beta, gamma))
    )

    # Compose in the same order as before: R = R_z * R_y * R_x.
    combined = about_z * about_y * about_x
    return combined.as_matrix()
@jaxtyped(typechecker=typechecker)
def apply_rotation(
    positions: Float[Array, "* 3"], alpha: float, beta: float, gamma: float
) -> Float[Array, "* 3"]:
    """
    Rotate particle vectors by the given Euler angles.

    The combined matrix ``R`` comes from ``euler_rotation_matrix`` and the
    vectors are right-multiplied by it, i.e. each row ``p`` becomes ``p @ R``
    (equivalently ``R.T`` applied to ``p`` as a column vector).

    Args:
        positions (Float[Array, "* 3"]): The positions of the particles.
        alpha (float): Rotation around the x-axis in degrees.
        beta (float): Rotation around the y-axis in degrees.
        gamma (float): Rotation around the z-axis in degrees.

    Returns:
        Float[Array, "* 3"]: The rotated positions.
    """
    euler_matrix = euler_rotation_matrix(alpha, beta, gamma)
    return jnp.matmul(positions, euler_matrix)
@jaxtyped(typechecker=typechecker)
def rotate_galaxy(
    positions: Float[Array, "..."],
    velocities: Float[Array, "..."],
    positions_stars: Float[Array, "..."],
    masses_stars: Float[Array, "..."],
    halfmass_radius: Union[Float[Array, "..."], float],
    alpha: float,
    beta: float,
    gamma: float,
    key: str,
) -> Tuple[Float[Array, "* 3"], Float[Array, "* 3"]]:
    """
    Orient the galaxy by rotating particle positions and velocities.

    For "IllustrisTNG" the particles are first rotated into the frame given by
    the eigenvectors of the stellar moment of inertia tensor; for "NIHAO" this
    step is skipped. In both cases the user-specified Euler rotation is then
    applied to positions and velocities.

    Args:
        positions (Float[Array, "..."]): Particle positions.
        velocities (Float[Array, "..."]): Particle velocities.
        positions_stars (Float[Array, "..."]): Star particle positions.
        masses_stars (Float[Array, "..."]): Star particle masses.
        halfmass_radius (Union[Float[Array, "..."], float]): Radius used for
            the moment of inertia calculation.
        alpha (float): Rotation around the x-axis in degrees.
        beta (float): Rotation around the y-axis in degrees.
        gamma (float): Rotation around the z-axis in degrees.
        key (str): Dataset key ("IllustrisTNG" or "NIHAO").

    Returns:
        Tuple[Float[Array, "* 3"], Float[Array, "* 3"]]: Rotated positions and
        velocities.

    Raises:
        ValueError: If `key` is not supported.
    """
    # We have to distinguish between IllustrisTNG and NIHAO.
    # The NIHAO galaxies are already oriented face-on in the pynbody input
    # handler. The IllustrisTNG galaxies are not, so for them the moment of
    # inertia tensor is computed and the resulting rotation matrix is applied
    # to the positions and velocities first. After that both simulations are
    # treated the same way and the user-specified rotation is applied.
    if key not in ("IllustrisTNG", "NIHAO"):
        raise ValueError(
            f"Unknown key: {key} for the rotation. Supported keys are 'IllustrisTNG' and 'NIHAO'."
        )

    aligned_positions, aligned_velocities = positions, velocities
    if key == "IllustrisTNG":
        inertia = moment_of_inertia_tensor(
            positions_stars, masses_stars, halfmass_radius
        )
        principal_frame = rotation_matrix_from_inertia_tensor(inertia)
        aligned_positions = apply_init_rotation(positions, principal_frame)
        aligned_velocities = apply_init_rotation(velocities, principal_frame)

    pos_final = apply_rotation(aligned_positions, alpha, beta, gamma)
    vel_final = apply_rotation(aligned_velocities, alpha, beta, gamma)
    return pos_final, vel_final