Source code for rubix.pipeline.transformer
from copy import deepcopy
from typing import Sequence

from jax import jit, make_jaxpr
from jax.tree_util import Partial
[docs]
def bound_transformer(*args, **kwargs):
    """
    Build a decorator that binds ``*args`` and ``**kwargs`` to a function,
    producing a ``jax.tree_util.Partial``.

    Positional arguments are bound from the left, i.e. starting with the
    first parameter of the decorated function. If specific (non-leading)
    parameters should be bound, passing them as keyword arguments is
    advisable.

    Args:
        *args: Positional arguments to bind, in order.
        **kwargs: Keyword arguments to bind.

    Returns:
        callable: A decorator that, applied to a function, returns a
        ``Partial`` of a deep copy of that function with deep copies of the
        given arguments bound.
    """

    def transformer_wrap(kernel):
        # Snapshot the function and every bound value so the result does not
        # depend on the caller's objects being left unmodified afterwards.
        frozen_kernel = deepcopy(kernel)
        frozen_args = deepcopy(args)
        frozen_kwargs = deepcopy(kwargs)
        return Partial(frozen_kernel, *frozen_args, **frozen_kwargs)

    return transformer_wrap
[docs]
def compiled_transformer(
    *args,
    static_args: Sequence[int] = (),
    static_kwargs: Sequence[str] = (),
    **kwargs,
):
    """
    compiled_transformer creates a function precompiled with ``jax.jit``,
    with the given arguments and keyword arguments bound to it, similar to
    using ``functools.partial`` with ``*args`` and ``**kwargs``.

    ``*args`` are bound starting from the first positional parameter of the
    decorated function, in order. The bound ``*args``/``**kwargs`` are stored
    inside the ``Partial`` that is handed to ``jax.jit`` as the function, so
    they are not call arguments of the jitted function and stay fixed for
    its lifetime.

    ``static_args`` and ``static_kwargs`` refer to the remaining (unbound)
    parameters, i.e. index 0 in ``static_args`` is the first positional
    argument passed when calling the returned jitted function.

    Args:
        *args: Positional arguments to bind to the target function.
        static_args (Sequence[int], optional): Indices of static (untraced)
            positional arguments of the bound function. Defaults to ``()``.
        static_kwargs (Sequence[str], optional): Names of static (untraced)
            keyword arguments of the bound function. Defaults to ``()``.
        **kwargs: Keyword arguments to bind to the target function.

    Returns:
        callable: A decorator that returns a jitted function with bound args.
    """

    def transformer_wrap(kernel):
        # deepcopy everything so later mutation of the caller's objects does
        # not leak into the compiled function (avoids context dependency).
        return jit(
            Partial(deepcopy(kernel), *deepcopy(args), **deepcopy(kwargs)),
            static_argnums=deepcopy(static_args),
            static_argnames=deepcopy(static_kwargs),
        )

    return transformer_wrap
[docs]
def expression_transformer(
    *args,
    static_args: Sequence[int] = (),
):
    """
    expression_transformer creates a jax intermediate expression (jaxpr)
    from a function, with the given untraced arguments. Please note that
    this only works with static positional arguments: JAX does currently
    not provide a way to have static keyword arguments when creating a
    jaxpr rather than a jitted function.

    Args:
        *args: Positional example arguments used to trace the function. If
            none are given, tracing is deferred to the caller.
        static_args (Sequence[int], optional): Indices of static (untraced)
            positional arguments to the function. Defaults to ``()``.

    Returns:
        callable: A decorator. Applied to a function, it returns the jaxpr
        produced by ``jax.make_jaxpr`` when ``*args`` were given. Otherwise
        it returns the ``make_jaxpr``-wrapped function, which produces the
        jaxpr when called with arguments.
    """

    def transformer_wrap(kernel):
        # Snapshot the kernel and static indices so later mutation by the
        # caller cannot change what gets traced.
        expression = make_jaxpr(
            deepcopy(kernel), static_argnums=deepcopy(static_args)
        )
        if args:
            # Trace immediately using (copies of) the bound example arguments.
            return expression(*deepcopy(args))
        # Nothing bound: return the tracing function itself.
        return expression

    return transformer_wrap