rubix.pipeline package

Submodules

rubix.pipeline.abstract_pipeline module

class rubix.pipeline.abstract_pipeline.AbstractPipeline(cfg: dict, transformers: list)

Bases: ABC

AbstractPipeline Abstract base class for data transformation pipelines. Declares the methods build_pipeline, build_expression and apply, which every derived class must implement. They are responsible for three steps: building up the pipeline, assembling it into a self-contained pure function, and applying that function to data.
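The contract above can be sketched with Python's abc module. This is a hypothetical, minimal stand-in: the class names PipelineSketch and ChainSketch and the cfg key "order" are assumptions, not part of rubix. It only shows how a derived class supplies the three methods:

```python
from abc import ABC, abstractmethod
from functools import reduce
from typing import Any, Callable, Optional


class PipelineSketch(ABC):
    """Hypothetical stand-in for AbstractPipeline's abstract interface."""

    def __init__(self, cfg: dict, transformers: list):
        self.cfg = cfg
        # Assumption: transformers are keyed by their function name.
        self.transformers = {f.__name__: f for f in transformers}
        self._pipeline: dict = {}
        self.expression: Optional[Callable] = None

    @abstractmethod
    def build_pipeline(self) -> None:
        """Populate self._pipeline from self.cfg."""

    @abstractmethod
    def build_expression(self) -> None:
        """Assemble self._pipeline into a single pure function."""

    @abstractmethod
    def apply(self, *args, **kwargs) -> Any:
        """Apply the assembled function to data."""


class ChainSketch(PipelineSketch):
    def build_pipeline(self) -> None:
        # Hypothetical cfg layout: an ordered list of transformer names.
        self._pipeline = {name: self.transformers[name] for name in self.cfg["order"]}

    def build_expression(self) -> None:
        funcs = list(self._pipeline.values())
        # Each step consumes the output of the previous one.
        self.expression = lambda x: reduce(lambda acc, f: f(acc), funcs, x)

    def apply(self, x):
        return self.expression(x)


def double(x):
    return 2 * x


def inc(x):
    return x + 1


p = ChainSketch({"order": ["double", "inc"]}, [double, inc])
p.build_pipeline()
p.build_expression()
print(p.apply(3))  # (3 * 2) + 1 = 7
```

Because the three methods are abstract, trying to instantiate PipelineSketch directly raises a TypeError. Only a subclass that implements all of them can be created.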

assemble()

assemble Assemble the pipeline into a self-contained function with the same signature as the pipeline’s first element. This only works if every function that makes up the pipeline has been registered with it via register_transformer.

Raises:

RuntimeError – When no transformers are registered to build the pipeline out of.

compile_element(name: str, static_args=[], static_kwargs=[])

compile_element Compile the pipeline element named ‘name’ with jax.jit, using the provided static_args and static_kwargs.

Parameters:
  • name (str) – Name of the element to be compiled

  • static_args (list, optional) – Static positional argument indices. Will be forwarded to the static_argnums argument of jit, by default []

  • static_kwargs (list, optional) – Names of the static keyword arguments. Will be forwarded to the static_argnames argument of jit, by default []

Returns:

_description_

Return type:

_type_
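The sketch below shows how static_args and static_kwargs are forwarded to jax.jit for a single element. The element scale is hypothetical. The two jax.jit calls stand in for what compile_element is described as doing and are not its actual source:

```python
import jax
import jax.numpy as jnp


def scale(x, factor, mode):
    # 'mode' drives Python-level control flow, so it must be static (untraced).
    if mode == "square":
        return (x * factor) ** 2
    return x * factor


# Roughly what compile_element("scale", static_args=[2]) would forward to jit.
scale_by_index = jax.jit(scale, static_argnums=[2])

# Roughly what compile_element("scale", static_kwargs=["mode"]) would forward to jit.
scale_by_name = jax.jit(scale, static_argnames=["mode"])

x = jnp.arange(3.0)
print(scale_by_index(x, 2.0, "square"))
print(scale_by_name(x, 2.0, mode="square"))
```

Each distinct value of a static argument triggers a separate compilation, so static arguments suit small sets of hashable configuration values rather than data.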

compile_expression(static_args=[], static_kwargs=[])

compile_expression Compile the function that represents an application of this pipeline to input data using jax.jit.

Parameters:
  • static_args (list, optional) – Static positional arguments that should not be traced by jit, by default []

  • static_kwargs (list, optional) – Static keyword arguments that should not be traced by jit, by default []

Returns:

Compiled pipeline function as a PjitFunction.
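The same idea applies to the assembled expression of the whole pipeline. In this hypothetical sketch, expression stands in for the assembled pipeline, with two steps fused into one function. normalize is an assumed keyword argument that is made static through static_argnames:

```python
import jax
import jax.numpy as jnp


def expression(x, *, normalize):
    # Step 1 of a hypothetical pipeline: center the data.
    y = x - jnp.mean(x)
    # Step 2: optional rescaling. A Python `if` on an argument needs that argument static.
    if normalize:
        y = y / jnp.std(y)
    return y


compiled = jax.jit(expression, static_argnames=["normalize"])

data = jnp.array([1.0, 2.0, 3.0])
print(compiled(data, normalize=False))
print(compiled(data, normalize=True))
```

In recent JAX versions, the object returned by jax.jit has the type PjitFunction, which is consistent with the return type documented above.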

get_jaxpr(*args, static_args: list = [])

get_jaxpr Get a jax intermediate expression (jaxpr) for the function that represents an application of this pipeline to input data. Note that this only works with static positional arguments. JAX currently provides no way to declare static keyword arguments when creating a jaxpr, as opposed to a jitted function. You can use partial (e.g. functools.partial) to fix keyword arguments before calling this method.

Parameters:

static_args (list, optional) – Static argument indices. Will be forwarded to the static_argnums argument of jax.make_jaxpr, by default []

Returns:

  • jax.ClosedJaxpr – If *args is not empty: the jax intermediate representation that results from applying the pipeline to the provided arguments.

  • Callable – If *args is empty: a function that returns the jax intermediate expression when called with the desired arguments.
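The workaround described above combines two steps: fix keyword arguments with partial, then mark positional arguments as static. Here is that workaround using jax.make_jaxpr directly. The function step and its arguments are hypothetical:

```python
from functools import partial

import jax
import jax.numpy as jnp


def step(x, power, *, offset=0.0):
    return x ** power + offset


# make_jaxpr has no static_argnames, so bind the keyword argument first.
step_fixed = partial(step, offset=1.0)

# static_argnums=(1,) keeps 'power' a Python int; it is baked into the jaxpr, not traced.
closed = jax.make_jaxpr(step_fixed, static_argnums=(1,))(jnp.ones(3), 2)
print(closed)
```

Only x appears as an input variable of the resulting jaxpr. Both power and offset are fixed at trace time.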

get_jaxpr_for_element(name: str, *args, static_args: list = [])

get_jaxpr_for_element Create a jax intermediate expression for the pipeline element named ‘name’, with static arguments ‘static_args’ and arguments *args. If no arguments are provided, a function is returned that produces the intermediate representation once it is called with arguments.

Parameters:
  • name (str) – Name of the element to be retrieved

  • static_args (list, optional) – static positional argument indices, by default []

Returns:

  • jax.ClosedJaxpr – If *args is not empty: intermediate expression representing the computation carried out when calling the element with the given arguments.

  • Callable – If *args is empty: function that returns a jax.ClosedJaxpr once called with appropriate arguments.
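The two documented return modes can be sketched with a small helper: a ClosedJaxpr when arguments are given, and a deferred callable otherwise. jaxpr_for and element are hypothetical names used only for illustration:

```python
import jax
import jax.numpy as jnp


def jaxpr_for(fn, *args, static_args=()):
    # jax.make_jaxpr returns a function; calling it with arguments yields a ClosedJaxpr.
    maker = jax.make_jaxpr(fn, static_argnums=tuple(static_args))
    if args:
        return maker(*args)  # arguments given: trace immediately
    return maker  # no arguments: defer tracing to the caller


def element(x):
    return jnp.sin(x) * 2.0


deferred = jaxpr_for(element)
print(callable(deferred))
print(deferred(jnp.zeros(2)))
print(jaxpr_for(element, jnp.zeros(2)))
```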

property pipeline: dict

pipeline Get the sequence of functions that make up the pipeline as a dictionary of name: function pairs.

Returns:

Description of the pipeline as name: function pairs.

Return type:

dict

register_transformer(cls)

register_transformer Make a function available to the calling pipeline object. The registered function must be pure so that it can be transformed with jax. The registered transformers are used to build the pipeline.

Parameters:

cls – function object to register.

Raises:

ValueError – When the function is already registered with the pipeline
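The sketch below illustrates the registration contract: functions are stored by name, and registering the same one twice raises ValueError. RegistrySketch and add_one are hypothetical. Keying by __name__ is an assumption about how rubix identifies transformers:

```python
from typing import Callable, Dict


class RegistrySketch:
    """Hypothetical sketch of the register_transformer contract."""

    def __init__(self) -> None:
        self.transformers: Dict[str, Callable] = {}

    def register_transformer(self, cls: Callable) -> None:
        # Assumption: functions are keyed by their __name__.
        name = cls.__name__
        if name in self.transformers:
            raise ValueError(f"A transformer named '{name}' is already registered.")
        self.transformers[name] = cls


def add_one(x):
    return x + 1  # pure: output depends only on input, no side effects


registry = RegistrySketch()
registry.register_transformer(add_one)
try:
    registry.register_transformer(add_one)
except ValueError as err:
    print(err)
```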

rubix.pipeline.linear_pipeline module

class rubix.pipeline.linear_pipeline.LinearTransformerPipeline(cfg: dict, transformers: list)

Bases: AbstractPipeline

LinearTransformerPipeline An implementation of a data transformation pipeline in the form of a simple, 1-D chain of composed functions in which each function uses the output of the function before it as arguments.

Parameters:

apl – Abstract base class for all pipeline implementations

apply(*args, static_args=[], static_kwargs=[], **kwargs)

apply Apply the pipeline to positional arguments *args and keyword arguments **kwargs that match the signature of the first function in the pipeline. Arguments listed in static_args and static_kwargs are not traced. The method first applies jax.jit to the pre-assembled pipeline and then calls the result with the arguments.

Parameters:
  • *args – Positional arguments to apply the pipeline to

  • static_args (list, optional) – Positional arguments that should not be traced, by default [].

  • static_kwargs (list, optional) – Keyword arguments that should not be traced, by default [].

  • **kwargs – Keyword arguments to apply the pipeline to

Returns:

Result of the application of the pipeline to the provided input as object.

Raises:

ValueError – _description_

build_expression()

build_expression Compose the assembled pipeline into a single expression that has the same signature as the first element of the pipeline.
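One way to compose a linear chain into a single function that reports the first element's signature is functools.reduce combined with functools.wraps. This is a hedged sketch: compose_linear, load, shift and square are hypothetical, and rubix may implement the composition differently:

```python
import inspect
from functools import reduce, wraps
from typing import Callable, Sequence


def compose_linear(funcs: Sequence[Callable]) -> Callable:
    first, rest = funcs[0], list(funcs[1:])

    # wraps() copies __name__/__doc__ and sets __wrapped__, which
    # inspect.signature follows, so the result reports first's signature.
    @wraps(first)
    def expression(*args, **kwargs):
        # Feed the output of each step into the next one.
        return reduce(lambda acc, f: f(acc), rest, first(*args, **kwargs))

    return expression


def load(x, scale=1.0):
    return x * scale


def shift(x):
    return x + 1.0


def square(x):
    return x * x


expr = compose_linear([load, shift, square])
print(expr(2.0, scale=3.0))  # ((2.0 * 3.0) + 1.0) ** 2 = 49.0
print(inspect.signature(expr))  # (x, scale=1.0)
```

Only the first function receives the caller's arguments. Every later step takes exactly one input, the previous step's output, which matches the 1-D chain described for LinearTransformerPipeline.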

build_pipeline()

build_pipeline Build up the pipeline from the internally stored configuration. This only works when all transformers the pipeline is composed of have been registered with it. Multiple versions (configurations) of the same transformer can be used in one pipeline.

Raises:
  • RuntimeError – When there are no transformers to build the pipeline out of.

  • ValueError – When there are multiple starting points to the pipeline.

  • ValueError – When branching occurs in the pipeline.

  • ValueError – When a config node is present that does not have a ‘name’ attribute.
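The checks listed above can be sketched as an ordering function over a configuration dictionary. The config layout, with node keys mapping to a "name" and a "depends_on" entry, is a hypothetical assumption, not rubix's documented format. The two "scale" nodes illustrate reusing one transformer in several configurations:

```python
from typing import Dict, List


def order_linear(cfg: Dict[str, dict]) -> List[str]:
    """Return transformer names in execution order for a hypothetical cfg layout.

    Assumed layout: {node_key: {"name": transformer_name, "depends_on": node_key or None}}.
    """
    if not cfg:
        raise RuntimeError("There are no transformers to build the pipeline out of.")
    for key, node in cfg.items():
        if "name" not in node:
            raise ValueError(f"Config node '{key}' has no 'name' attribute.")

    starts = [key for key, node in cfg.items() if node.get("depends_on") is None]
    if len(starts) != 1:
        raise ValueError(f"Expected exactly one starting point, found {len(starts)}.")

    # Map each node to the nodes that consume its output.
    children: Dict[str, List[str]] = {}
    for key, node in cfg.items():
        parent = node.get("depends_on")
        if parent is not None:
            children.setdefault(parent, []).append(key)
    if any(len(kids) > 1 for kids in children.values()):
        raise ValueError("Branching occurs in the pipeline.")

    # Walk the chain from the single starting point.
    order = [starts[0]]
    while order[-1] in children:
        order.append(children[order[-1]][0])
    if len(order) != len(cfg):
        raise ValueError("Some config nodes are not connected to the chain.")
    return [cfg[key]["name"] for key in order]


cfg = {
    "load": {"name": "load_data", "depends_on": None},
    "scale_a": {"name": "scale", "depends_on": "load"},
    "scale_b": {"name": "scale", "depends_on": "scale_a"},
}
print(order_linear(cfg))  # ['load_data', 'scale', 'scale']
```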

update_pipeline(current_name)

update_pipeline Add a new node named ‘current_name’ to the pipeline, taking the internal linear dependencies between nodes into account. Mostly used internally to add nodes one by one.

Parameters:

current_name – Name of the node to add

rubix.pipeline.transformer module

rubix.pipeline.transformer.bound_transformer(*args, **kwargs)

bound_transformer Create a jax.Partial function from an input function and the given arguments and keyword arguments. Positional arguments are bound starting from the first, i.e., leftmost, argument, so the user must take care of their order. To bind specific arguments out of order, using keyword arguments may be advisable.
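The sketch below shows the binding behaviour described above as a decorator factory. It assumes that jax.Partial refers to jax.tree_util.Partial and that bound_transformer is used as a decorator; both are assumptions. bound_transformer_sketch, rescale and run are hypothetical names:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def bound_transformer_sketch(*args, **kwargs):
    """Hypothetical decorator factory that binds args/kwargs via jax.tree_util.Partial."""
    def wrap(func):
        return Partial(func, *args, **kwargs)
    return wrap


# 'offset' is the leftmost parameter, so it is the one bound positionally;
# 'scale' is bound by keyword, independent of its position.
@bound_transformer_sketch(jnp.array([1.0, 2.0]), scale=3.0)
def rescale(offset, x, *, scale):
    return (x + offset) * scale


@jax.jit
def run(f, x):
    # A Partial is a pytree, so it can be passed straight into a jitted function.
    return f(x)


x = jnp.array([0.0, 1.0])
print(rescale(x))
print(run(rescale, x))
```

Unlike functools.partial, jax.tree_util.Partial is registered as a pytree. Its bound arguments are therefore treated as leaves, and the bound function can be passed as an argument to jitted code.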

rubix.pipeline.transformer.compiled_transformer(*args, static_args: list = [], static_kwargs: list = [], **kwargs)

compiled_transformer Create a function precompiled with jax, with the given arguments and keyword arguments bound to it, similar to using functools.partial with *args and **kwargs.

Note that any array args/kwargs will behave as dynamic arguments in the jax jit, while any non-array args/kwargs will behave as static. static_args and static_kwargs refer to the remaining arguments.

*args are matched in order, starting from the first positional argument of the decorated function. Both *args and **kwargs are bound to the decorated function.

Parameters:
  • static_args (list, optional) – Indices of static, i.e., untraced arguments of the bound function, by default [].

  • static_kwargs (list, optional) – Names of static, i.e., untraced, keyword arguments of the bound function, by default [].
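The sketch below assumes compiled_transformer is used as a decorator. It binds the given arguments with functools.partial and jits the result, with static indices counted over the arguments that remain after binding. This simplification closes over the bound values as compile-time constants, so it does not reproduce the dynamic/static distinction described in the note above, and rubix's implementation may differ. compiled_transformer_sketch and shift_pow are hypothetical:

```python
from functools import partial

import jax
import jax.numpy as jnp


def compiled_transformer_sketch(*args, static_args=(), static_kwargs=(), **kwargs):
    """Hypothetical decorator factory: bind args/kwargs, then jit the bound function."""
    def wrap(func):
        bound = partial(func, *args, **kwargs)
        # Indices/names refer to the arguments that remain after binding.
        return jax.jit(
            bound,
            static_argnums=tuple(static_args),
            static_argnames=tuple(static_kwargs),
        )
    return wrap


# 'offset' is bound, so the remaining positional arguments are (x, power);
# index 1 therefore refers to 'power', which is marked static.
@compiled_transformer_sketch(jnp.array([1.0, 2.0]), static_args=[1])
def shift_pow(offset, x, power):
    return (x + offset) ** power


print(shift_pow(jnp.array([0.0, 1.0]), 2))
```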

rubix.pipeline.transformer.expression_transformer(*args, static_args: list = [])

expression_transformer Create a jax intermediate expression (jaxpr) from a function, with the given untraced arguments. Note that this only works with static positional arguments. JAX currently provides no way to declare static keyword arguments when creating a jaxpr, as opposed to a jitted function.

Parameters:

static_args (list, optional) – Indices of static, i.e., untraced arguments to the function, by default [].
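The sketch below shows how a static positional argument shapes the resulting jaxpr, using a decorator built on jax.make_jaxpr. It is a simplified, hypothetical stand-in: expression_transformer_sketch and repeat_add are illustrative names, and the sketch omits the *args binding of the documented signature:

```python
import jax
import jax.numpy as jnp


def expression_transformer_sketch(static_args=()):
    """Hypothetical decorator factory that turns a function into a jaxpr builder."""
    def wrap(func):
        return jax.make_jaxpr(func, static_argnums=tuple(static_args))
    return wrap


@expression_transformer_sketch(static_args=[1])
def repeat_add(x, n):
    # 'n' is static, so this Python loop is unrolled at trace time.
    for _ in range(n):
        x = x + 1.0
    return x


jaxpr_2 = repeat_add(jnp.zeros(3), 2)
jaxpr_3 = repeat_add(jnp.zeros(3), 3)
print(jaxpr_2)
print(jaxpr_3)
```

Because n is untraced, each value of n yields a different, fully unrolled jaxpr. Only x remains an input variable.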

Module contents