rubix.pipeline package#

Submodules#

rubix.pipeline.abstract_pipeline module#

class rubix.pipeline.abstract_pipeline.AbstractPipeline(cfg: dict, transformers: list[Callable[[...], Any]])[source]#

Bases: ABC

Abstract base class for data transformation pipelines.

Derived classes must implement build_pipeline, build_expression, and apply. These helpers build the pipeline, assemble it into a pure function, and apply it to input data.

Parameters:
  • cfg (dict) – Configuration dictionary defining the pipeline.

  • transformers (list[Callable[..., Any]]) – Transformers that will be registered with the pipeline.

assemble() None[source]#

Assemble the pipeline into a self-contained function.

compile_element(name: str, static_args: Sequence[int] | None = None, static_kwargs: Sequence[str] | None = None) Callable[[...], Any][source]#

Compile a specific pipeline element using jax.jit.

Parameters:
  • name (str) – Name of the element to compile.

  • static_args (Optional[Sequence[int]], optional) – Positional indices forwarded to jit as static_argnums. Defaults to None.

  • static_kwargs (Optional[Sequence[str]], optional) – Keyword names forwarded to jit as static_argnames. Defaults to None.

Raises:

RuntimeError – When compilation of the element fails.

Returns:

The compiled transformer.

Return type:

Callable[…, Any]
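As a rough illustration of what static_args and static_kwargs mean once they reach jax.jit, the sketch below jits a stand-alone function directly. The function scale_and_shift is a hypothetical transformer made up for this example and is not part of rubix.

```python
import jax
import jax.numpy as jnp


# Hypothetical transformer used only for illustration.
def scale_and_shift(x, n_repeat, mode="add"):
    # n_repeat drives a Python loop and mode drives a Python branch,
    # so both have to be known at trace time, i.e. static.
    for _ in range(n_repeat):
        x = x + 1.0 if mode == "add" else x * 2.0
    return x


# static_argnums takes positional indices, static_argnames takes keyword names.
compiled = jax.jit(scale_and_shift, static_argnums=(1,), static_argnames=("mode",))

added = compiled(jnp.zeros(3), 2, mode="add")
doubled = compiled(jnp.ones(3), 3, mode="mul")
```

Each new value of a static argument triggers a fresh trace and compilation, so static arguments are best kept to a small set of distinct values.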

compile_expression(static_args: Sequence[int] | None = None, static_kwargs: Sequence[str] | None = None) Callable[[...], Any][source]#

Compile the pipeline expression using jax.jit.

Parameters:
  • static_args (Optional[Sequence[int]], optional) – Positional indices forwarded to jit as static_argnums. Defaults to None.

  • static_kwargs (Optional[Sequence[str]], optional) – Keyword names forwarded to jit as static_argnames. Defaults to None.

Raises:

RuntimeError – When compilation fails.

Returns:

Compiled pipeline function.

Return type:

Callable[…, Any]

get_jaxpr(*args: Any, static_args: Sequence[int] | None = None) Callable[[...], Any] | ClosedJaxpr[source]#

Return a JAX intermediate expression for the pipeline.

Note: This only works with static positional arguments. JAX currently provides no way to mark keyword arguments as static when creating a jaxpr rather than a jitted function. You can use partial to fix keyword arguments before calling this method.

Parameters:
  • *args (Any) – Positional arguments forwarded to the expression whose intermediate representation should be produced.

  • static_args (Optional[Sequence[int]], optional) – Static positional indices forwarded to jax.make_jaxpr via static_argnums. Defaults to None.

Returns:

A ClosedJaxpr when *args is provided, or a Callable[…, Any] when *args is empty.

Return type:

ClosedJaxpr | Callable[…, Any]
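The two return forms and the workaround for keyword arguments can be sketched with jax.make_jaxpr directly. This is only an illustration under the assumption that get_jaxpr behaves like a thin wrapper around make_jaxpr; power_sum is a hypothetical function standing in for a pipeline expression.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical function standing in for a pipeline expression.
def power_sum(x, exponent, normalize=False):
    total = jnp.sum(x**exponent)
    return total / x.size if normalize else total


x = jnp.arange(4.0)

# Fix the keyword argument first, since make_jaxpr only accepts static_argnums.
fixed = partial(power_sum, normalize=True)

# Without example inputs we get a callable that builds the jaxpr later.
make = jax.make_jaxpr(fixed, static_argnums=(1,))

# Calling it with example inputs yields the ClosedJaxpr itself.
closed = make(x, 2)
print(closed)
```

Only x shows up as an input variable of the resulting jaxpr, because the exponent is static and the keyword argument was fixed beforehand.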

get_jaxpr_for_element(name: str, *args: Any, static_args: Sequence[int] | None = None) Callable[[...], Any] | ClosedJaxpr[source]#

Create a JAX intermediate expression for the pipeline element named name, using the positional arguments *args and the static argument indices static_args. If no arguments are provided, a function is returned that produces the intermediate representation once it is called with arguments.

Parameters:
  • name (str) – Name of the element to inspect.

  • *args (Any) – Positional arguments forwarded to the element.

  • static_args (Optional[Sequence[int]], optional) – Static positional indices forwarded to expression_transformer. Defaults to None.

Raises:

RuntimeError – When the expression cannot be created.

Returns:

A ClosedJaxpr when *args is provided, or a Callable[…, Any] when *args is empty.

Return type:

ClosedJaxpr | Callable[…, Any]

property pipeline: dict[str, Callable[[...], Any]]#

Return the registered pipeline elements as a dictionary of name: function pairs.

Returns:

Mapping from name to function.

Return type:

dict[str, Callable[…, Any]]

register_transformer(cls: Callable[[...], Any]) None[source]#

Register a transformer function so that it is available to the calling pipeline object.

Note: The registered function must be a pure function so that JAX can transform it. The registered transformers are the building blocks from which the pipeline is built.

Parameters:

cls (Callable[..., Any]) – Function to register.

Raises:

ValueError – When a transformer with the same name is already present.
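The registration rule can be sketched with a small stand-alone registry. TransformerRegistry is a hypothetical class written for this example, and keying the entries by the function's __name__ is an assumption, not something the reference above confirms.

```python
from typing import Any, Callable


class TransformerRegistry:
    """Hypothetical stand-in for the registration logic of a pipeline."""

    def __init__(self) -> None:
        self._transformers: dict[str, Callable[..., Any]] = {}

    def register(self, func: Callable[..., Any]) -> None:
        # Assumption: transformers are identified by their function name.
        name = func.__name__
        if name in self._transformers:
            raise ValueError(f"A transformer named '{name}' is already registered.")
        self._transformers[name] = func


# Pure: the result depends only on the input and nothing is mutated.
def double(x):
    return 2 * x


registry = TransformerRegistry()
registry.register(double)

# Registering the same name a second time is rejected.
try:
    registry.register(double)
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
```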

rubix.pipeline.linear_pipeline module#

class rubix.pipeline.linear_pipeline.LinearTransformerPipeline(cfg: dict, transformers: list[Callable[[...], Any]])[source]#

Bases: AbstractPipeline

Minimal linear pipeline that chains transformer functions.

Implements a data transformation pipeline as a simple, 1-D chain of composed functions in which each function takes the output of the function before it as its arguments.

Note

This only sets up what is needed to build the pipeline. The pipeline itself must be created by calling assemble() after the transformers have been registered.

Parameters:
  • cfg (dict) – Configuration describing the pipeline nodes and their dependencies.

  • transformers (list[Callable[..., Any]]) – Transformer functions that will be registered before building the pipeline.
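The idea of a linear chain of composed functions can be sketched without the pipeline class at all. The helper compose and the three transformers below are hypothetical and only illustrate the composition pattern, not how LinearTransformerPipeline builds its expression internally.

```python
from functools import reduce

import jax.numpy as jnp


def compose(first, second):
    # The output of `first` becomes the single argument of `second`.
    return lambda *args, **kwargs: second(first(*args, **kwargs))


# Hypothetical transformers.
def center(x):
    return x - jnp.mean(x)


def square(x):
    return x**2


def total(x):
    return jnp.sum(x)


# The chain has the same signature as its first element, `center`.
chain = reduce(compose, [center, square, total])
result = chain(jnp.array([1.0, 2.0, 3.0]))
```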

apply(*args: Any, static_args: Sequence[int] | None = None, static_kwargs: Sequence[str] | None = None, **kwargs: Any) Any[source]#

Apply the compiled pipeline to the positional arguments *args and keyword arguments **kwargs, which must match the signature of the first element in the pipeline. Arguments selected by static_args and static_kwargs are not traced by JAX. The method first applies jax.jit to the pre-assembled pipeline, then calls the result with the provided arguments.

Parameters:
  • *args (Any) – Positional arguments to feed to the pipeline’s first element.

  • static_args (Sequence[int] | None, optional) – Positional indices that should not be traced by JAX. Defaults to None.

  • static_kwargs (Sequence[str] | None, optional) – Keyword names that should not be traced by JAX. Defaults to None.

  • **kwargs (Any) – Keyword arguments to forward to the pipeline.

Returns:

The result of running the pipeline.

Return type:

Any

Raises:

ValueError – When no positional arguments are provided.

build_expression() None[source]#

Compose the assembled pipeline into a single callable that has the same signature as the first element of the pipeline.

build_pipeline() None[source]#

Construct the pipeline from the stored configuration.

Note

This only works when all transformers the pipeline is composed of have been registered with it. Multiple different versions (configurations) of the same transformer can be used in a pipeline.

Raises:
  • RuntimeError – When no transformers are registered.

  • ValueError – When the configuration contains invalid or branching dependencies or when there are multiple starting points to the pipeline.
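The validation described above (a single starting point, no branching) can be sketched as a small ordering routine. The configuration layout used here, a mapping from node name to a dict with a depends_on key, is an assumption made for the example and may differ from the format rubix actually expects.

```python
def linear_order(nodes: dict[str, dict]) -> list[str]:
    """Order the nodes of a hypothetical config into a single chain."""
    starts = [name for name, spec in nodes.items() if spec["depends_on"] is None]
    if len(starts) != 1:
        raise ValueError(f"Expected exactly one starting point, found {len(starts)}.")

    # Map each parent to its single dependent node.
    children: dict[str, str] = {}
    for name, spec in nodes.items():
        parent = spec["depends_on"]
        if parent is None:
            continue
        if parent not in nodes:
            raise ValueError(f"'{name}' depends on unknown node '{parent}'.")
        if parent in children:
            raise ValueError(f"Branching: '{parent}' has more than one dependent.")
        children[parent] = name

    # Walk the chain from the starting point.
    order = [starts[0]]
    while order[-1] in children:
        order.append(children[order[-1]])
    if len(order) != len(nodes):
        raise ValueError("Some nodes are not reachable from the starting point.")
    return order


config = {
    "normalize": {"depends_on": "load"},
    "load": {"depends_on": None},
    "integrate": {"depends_on": "normalize"},
}
order = linear_order(config)
```

The order of the entries in the configuration does not matter here; only the dependencies determine the position of each node in the chain.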

update_pipeline(current_name: str) None[source]#

Add a new pipeline node named current_name to the pipeline, taking internal linear dependencies into account. Mostly used internally to add nodes one by one.

Parameters:

current_name (str) – Name of the node to append.

Raises:

RuntimeError – When the requested node is missing from the config.

rubix.pipeline.transformer module#

rubix.pipeline.transformer.bound_transformer(*args, **kwargs)[source]#

bound_transformer creates a jax.Partial function from an input function and the given positional and keyword arguments. Positional arguments are bound starting from the first, i.e., leftmost, parameter, and the user must take care of this ordering. If only specific parameters should be bound, using keyword arguments may be advisable.
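The difference between positional and keyword binding can be sketched with jax.tree_util.Partial, which is assumed here to be the class the description refers to. The kernel affine is hypothetical.

```python
import jax.numpy as jnp
from jax.tree_util import Partial


# Hypothetical kernel.
def affine(scale, offset, x):
    return scale * x + offset


# Positional binding fills parameters from the left: scale=2.0, offset=1.0.
bound_positional = Partial(affine, 2.0, 1.0)

# Keyword binding fixes `offset` alone and leaves `scale` and `x` open.
bound_keyword = Partial(affine, offset=1.0)

x = jnp.array([0.0, 1.0])
a = bound_positional(x)
# `x` must be passed by keyword, since the second positional slot is `offset`.
b = bound_keyword(3.0, x=x)
```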

rubix.pipeline.transformer.compiled_transformer(*args, static_args: list = [], static_kwargs: list = [], **kwargs)[source]#

compiled_transformer creates a function precompiled with JAX. The given arguments and keyword arguments are bound to the decorated function, similar to using functools.partial with *args and **kwargs.

Note that any array args/kwargs will behave as dynamic arguments in the jax jit, while any non-array args/kwargs will behave as static. static_args and static_kwargs refer to the remaining arguments.

*args are bound in order, starting from the first positional argument of the decorated function.

Parameters:
  • *args – Positional arguments to bind to the target function.

  • static_args (list, optional) – Indices of static (untraced) positional arguments of the bound function. Defaults to [].

  • static_kwargs (list, optional) – Names of static (untraced) keyword arguments of the bound function. Defaults to [].

  • **kwargs – Keyword arguments to bind to the target function.

Returns:

A decorator that returns a jitted function with bound args.

Return type:

callable
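A decorator of this shape could look roughly like the sketch below. compiled_sketch is a hypothetical re-creation written for illustration, not the rubix implementation; it binds the given arguments with jax.tree_util.Partial and jits what remains, so static_args indexes the parameters left over after binding.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def compiled_sketch(*args, static_args=(), static_kwargs=(), **kwargs):
    """Hypothetical decorator: bind arguments, then jit what remains."""

    def wrap(kernel):
        bound = Partial(kernel, *args, **kwargs)
        return jax.jit(bound, static_argnums=static_args, static_argnames=static_kwargs)

    return wrap


# 2.0 is bound to `scale`; index 0 of the remaining parameters is `n_repeat`.
@compiled_sketch(2.0, static_args=(0,))
def repeat_scale(scale, n_repeat, x):
    # n_repeat controls a Python loop, so it has to be static.
    for _ in range(n_repeat):
        x = scale * x
    return x


out = repeat_scale(3, jnp.ones(2))
```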

rubix.pipeline.transformer.expression_transformer(*args, static_args: list = [])[source]#

expression_transformer creates a JAX intermediate expression from a function with the given untraced arguments. This only works with static positional arguments: JAX currently provides no way to mark keyword arguments as static when creating a jaxpr rather than a jitted function.

Parameters:
  • *args – Positional arguments to bind for the expression generation.

  • static_args (list, optional) – Indices of static (untraced) positional arguments to the function. Defaults to [].

Returns:

A function (or jaxpr) produced from the provided kernel.

Return type:

callable

Module contents#