Source code for rubix.cosmology.utils

from jax import jit
from jax.lax import scan
from typing import Union
import jax.numpy as jnp
from jaxtyping import Array, Float, jaxtyped
from beartype import beartype as typechecker


# Source: https://github.com/ArgonneCPAC/dsps/blob/b81bac59e545e2d68ccf698faba078d87cfa2dd8/dsps/utils.py#L247C1-L256C1
@jaxtyped(typechecker=typechecker)
@jit
def _cumtrapz_scan_func(carryover, el):
    """
    Integral helper function, which uses the formula for trapezoidal integration.

    Args:
        carryover (tuple): Tuple of (a, fa, cumtrapz)
        a: current value of x-coordinate
        fa: current value of function at a
        cumtrapz: cumulative sum of trapezoidal integration so far
        el (tuple): Tuple of (b, fb)
        b: next value of x-coordinate
        fb: next value of function at b

    Returns:
        The carryover tuple, which contain (b, fb, cumtrapz)

        The accumulated integral value
    """
    b, fb = el
    a, fa, cumtrapz = carryover
    cumtrapz = cumtrapz + (b - a) * (fb + fa) / 2.0
    carryover = b, fb, cumtrapz
    accumulated = cumtrapz
    return carryover, accumulated


# Source: https://github.com/ArgonneCPAC/dsps/blob/b81bac59e545e2d68ccf698faba078d87cfa2dd8/dsps/utils.py#L278C1-L298C1
[docs] @jaxtyped(typechecker=typechecker) @jit def trapz( xarr: Union[jnp.ndarray, Float[Array, "n"]], yarr: Union[jnp.ndarray, Float[Array, "n"]], ) -> jnp.ndarray: """ The function performs the trapezoidal integration using the ``_cumtrapz_scan_func`` helper function. Args: xarr (ndarray): The x-coordinates of the data points in shape (n, ). yarr (ndarray): The y-coordinates of the data points in shape (n, ). Returns: The result of the trapezoidal integration. Example ------- >>> from rubix.cosmology.utils import trapz >>> import jax.numpy as jnp >>> x = jnp.array([0, 1, 2, 3]) >>> y = jnp.array([0, 1, 4, 9]) >>> print(trapz(x, y)) """ res_init = xarr[0], yarr[0], 0.0 scan_data = xarr, yarr cumtrapz = scan(_cumtrapz_scan_func, res_init, scan_data)[1] return cumtrapz[-1]