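Source code for rubix.telescope.apertures

"""Aperture masks for the observation of a galaxy.

This module defines functions (not classes) that build hexagonal, square and
circular aperture masks on an ``sbin`` x ``sbin`` pixel grid and return each
mask flattened to a 1D array.
"""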

from typing import Union

import jax.numpy as jnp
import numpy as np
from beartype import beartype as typechecker
from jaxtyping import Array, Float
from jaxtyping import Array, Float, jaxtyped

__all__ = ["HEXAGONAL_APERTURE", "SQUARE_APERTURE", "CIRCULAR_APERTURE"]


@jaxtyped(typechecker=typechecker)
def HEXAGONAL_APERTURE(sbin: Union[int, np.integer]) -> Float[Array, "..."]:
    """Creates a hexagonal aperture mask for the observation of a galaxy.

    A pixel at row offset ``xx`` and column offset ``yy`` from the grid centre
    is inside the aperture when all of the following hold:

    * ``w * h * 2 - w * |yy| - h * |xx| >= 0``, where ``w = sbin / 4`` and
      ``h = sbin * sqrt(3) / 4``;
    * ``|xx| < sbin / 2``;
    * ``|yy| < h``.

    Together these describe a hexagon with two vertices on the row axis at
    ``xx = +-sbin / 2`` and two flat edges at ``yy = +-h``.

    Args:
        sbin (int): The size of the spatial bin in each direction for the
            aperture mask. Python ints and NumPy integer scalars are accepted.

    Returns:
        A jnp.ndarray 1D array of length ``sbin * sbin``: the row-major
        flattened mask, 1 inside the hexagon and 0 outside.

    Raises:
        ValueError: If ``sbin`` is negative.
    """
    sbin = int(sbin)  # Ensure that the input is a plain Python integer
    if sbin < 0:
        raise ValueError(f"sbin must be non-negative, got {sbin}")

    # Pixel indices run 1..sbin; subtracting the grid centre gives offsets that
    # are symmetric about zero.
    centre = sbin / 2 + 0.5
    coords = jnp.arange(1, sbin + 1)
    xx = (coords - centre)[:, None]  # offset along the first (row) axis
    yy = (coords - centre)[None, :]  # offset along the second (column) axis

    w = sbin / 4
    h = sbin * jnp.sqrt(3) / 4

    # Evaluate the hexagon inequalities for every pixel at once via
    # broadcasting. This avoids a Python loop in which every `.at[].set`
    # copied the whole array. The operation order matches the per-pixel
    # formula, so the mask is unchanged.
    rr = 2 * w * h - w * jnp.abs(yy) - h * jnp.abs(xx)
    inside = (rr >= 0) & (jnp.abs(xx) < sbin / 2) & (jnp.abs(yy) < h)

    # `astype(float)` yields JAX's default float dtype, the same dtype as
    # `jnp.zeros`.
    return inside.astype(float).flatten()
@jaxtyped(typechecker=typechecker)
def SQUARE_APERTURE(sbin: Union[int, np.integer]) -> Float[Array, "..."]:
    """Creates a square aperture mask for the observation of a galaxy.

    Every pixel of the ``sbin`` x ``sbin`` grid lies inside the aperture.

    Args:
        sbin (int): The size of the spatial bin in each direction for the
            aperture mask. Python ints and NumPy integer scalars are accepted.

    Returns:
        A jnp.ndarray 1D array of ``sbin * sbin`` ones: the flattened mask.

    Raises:
        ValueError: If ``sbin`` is negative.
    """
    sbin = int(sbin)  # Ensure that the input is a plain Python integer
    if sbin < 0:
        raise ValueError(f"sbin must be non-negative, got {sbin}")

    return jnp.ones((sbin, sbin)).flatten()
@jaxtyped(typechecker=typechecker)
def CIRCULAR_APERTURE(sbin: Union[int, np.integer]) -> Float[Array, "..."]:
    """Creates a circular aperture mask for the observation of a galaxy.

    A pixel is inside the aperture when the Euclidean distance from its
    centre to the grid centre is at most ``sbin / 2``.

    Args:
        sbin (int): The size of the spatial bin in each direction for the
            aperture mask. Python ints and NumPy integer scalars are accepted.

    Returns:
        A jnp.ndarray 1D array of length ``sbin * sbin``: the row-major
        flattened mask, 1 inside the circle and 0 outside.

    Raises:
        ValueError: If ``sbin`` is negative.
    """
    sbin = int(sbin)  # Ensure that the input is a plain Python integer
    if sbin < 0:
        raise ValueError(f"sbin must be non-negative, got {sbin}")

    centre = sbin / 2 + 0.5
    # Column offsets increase left to right. Row coordinates run sbin..1 from
    # the first row to the last. Broadcasting a row vector against a column
    # vector builds the full grid without the explicit `jnp.tile` copies.
    xx = jnp.arange(1, sbin + 1)[None, :] - centre
    yy = jnp.arange(sbin, 0, -1)[:, None] - centre
    rr = jnp.sqrt(xx**2 + yy**2)

    # `astype(float)` yields JAX's default float dtype, the same dtype as
    # `jnp.zeros`.
    return (rr <= sbin / 2).astype(float).flatten()