rubix.pipeline package
Submodules
rubix.pipeline.abstract_pipeline module
- class rubix.pipeline.abstract_pipeline.AbstractPipeline(cfg: dict, transformers: list)
Bases: ABC
AbstractPipeline Abstract base class for data transformation pipelines. Provides the methods build_pipeline, build_expression and apply, which every derived class must implement and which are responsible for building up the pipeline, for assembling it into a self-contained, pure function, and for applying that function to data, respectively.
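A minimal sketch of this contract, using only Python's abc module. The class name SketchPipeline and the trivial subclass Identity are hypothetical and not part of rubix. Declaring the three methods with @abstractmethod is what prevents a derived class from being instantiated until it implements all of them.

```python
from abc import ABC, abstractmethod


class SketchPipeline(ABC):
    """Hypothetical stand-in mirroring the three abstract methods described above."""

    def __init__(self, cfg: dict, transformers: list):
        self.cfg = cfg
        self.transformers = transformers

    @abstractmethod
    def build_pipeline(self):
        """Arrange the registered transformers according to cfg."""

    @abstractmethod
    def build_expression(self):
        """Compose the arranged pipeline into one pure function."""

    @abstractmethod
    def apply(self, *args, **kwargs):
        """Run the composed function on input data."""


class Identity(SketchPipeline):
    # Trivial concrete subclass: all abstract methods are overridden.
    def build_pipeline(self):
        return []

    def build_expression(self):
        return lambda x: x

    def apply(self, *args, **kwargs):
        return self.build_expression()(*args, **kwargs)


try:
    SketchPipeline({}, [])  # fails: abstract methods are not implemented
except TypeError as err:
    print("cannot instantiate:", err)

print(Identity({}, []).apply(3))  # 3
```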
- assemble()
assemble Assemble the pipeline into a self-contained function with the same signature as the pipeline’s first element. This only works if all functions that make up the pipeline have been registered with it by calling register_transformer.
- Raises:
RuntimeError – When no transformers are registered to build the pipeline out of.
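The assembly step can be sketched as plain function composition. This is a hypothetical illustration, not the rubix implementation: it assumes the registered transformers are held in an insertion-ordered dict of name: function pairs and that every element after the first takes the previous element's output as its single argument.

```python
from functools import reduce


def assemble(transformers: dict):
    """Hypothetical sketch: chain name -> function pairs in insertion order."""
    if not transformers:
        raise RuntimeError("No transformers registered to build the pipeline from")
    funcs = list(transformers.values())
    first, rest = funcs[0], funcs[1:]

    def pipeline(*args, **kwargs):
        # Only the first element sees the caller's arguments, so the assembled
        # function accepts exactly what the first element accepts.
        return reduce(lambda acc, f: f(acc), rest, first(*args, **kwargs))

    return pipeline


steps = {"scale": lambda x, factor=2: x * factor, "shift": lambda x: x + 1}
run = assemble(steps)
print(run(5))            # (5 * 2) + 1 = 11
print(run(5, factor=3))  # (5 * 3) + 1 = 16
```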
- compile_element(name: str, static_args=[], static_kwargs=[])
compile_element Compile the element of the pipeline named ‘name’ with jax.jit, using the provided static_args and static_kwargs.
- Parameters:
name (str) – Name of the element to be compiled
static_args (list, optional) – Static positional argument indices, forwarded to the static_argnums argument of jax.jit, by default []
static_kwargs (list, optional) – Names of the static keyword arguments, forwarded to the static_argnames argument of jax.jit, by default []
- Returns:
The element compiled with jax.jit.
- Return type:
Callable
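For illustration, the sketch below calls jax.jit directly with the same kind of arguments that compile_element forwards; the function element and its parameters are hypothetical. Arguments that drive Python control flow, such as a loop count or a mode string, must be static because jit cannot branch on traced values.

```python
import jax
import jax.numpy as jnp


def element(x, n, mode="sum"):
    # n and mode steer Python control flow, so they must be static under jit.
    y = x
    for _ in range(n):
        y = y * 2
    return jnp.sum(y) if mode == "sum" else jnp.max(y)


# Analogue of static_args=[1], static_kwargs=["mode"] in compile_element.
compiled = jax.jit(element, static_argnums=[1], static_argnames=["mode"])

x = jnp.arange(3.0)                # [0., 1., 2.]
print(compiled(x, 2))              # sum([0, 4, 8]) = 12.0
print(compiled(x, 2, mode="max"))  # max([0, 4, 8]) = 8.0
```

Changing n or mode triggers a new compilation, since static values become part of the cache key.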
- compile_expression(static_args=[], static_kwargs=[])
compile_expression Compile the function that represents an application of this pipeline to input data using jax.jit.
- Parameters:
static_args (list, optional) – Static positional arguments that should not be traced by jit, by default []
static_kwargs (list, optional) – Static keyword arguments that should not be traced by jit, by default []
- Returns:
Compiled pipeline function as PjitFunction
- get_jaxpr(*args, static_args: list = [])
get_jaxpr Get a jax intermediate expression (jaxpr) for the function that represents an application of this pipeline to input data. Note that this only works with static positional arguments: JAX does not currently provide a way to declare static keyword arguments when creating a jaxpr rather than a jitted function. You can use functools.partial to fix keyword arguments before calling this method.
- Parameters:
static_args (list, optional) – Static argument indices. Will be forwarded to the static_argnums argument of jax.make_jaxpr, by default []
- Returns:
Jax intermediate expression (jaxpr) of the pipeline applied to *args.
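A sketch of the functools.partial workaround, using jax.make_jaxpr directly; pipeline_fn is a hypothetical stand-in for an assembled pipeline. Because make_jaxpr only accepts static_argnums, the keyword argument offset is fixed with functools.partial before tracing, so only x appears as an input of the resulting jaxpr.

```python
from functools import partial

import jax
import jax.numpy as jnp


def pipeline_fn(x, power, offset=0.0):
    return x ** power + offset


# Fix the keyword argument first; make_jaxpr has no static_argnames.
fn = partial(pipeline_fn, offset=1.0)
# power (positional index 1) is static; only x is traced.
jaxpr = jax.make_jaxpr(fn, static_argnums=(1,))(jnp.ones(3), 2)
print(jaxpr)
```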
- get_jaxpr_for_element(name: str, *args, static_args: list = [])
get_jaxpr_for_element Create a jax intermediate expression for the element of the pipeline named ‘name’, with static argument indices ‘static_args’ and arguments *args. If no arguments are provided, a function is returned that produces the intermediate representation once it is called with arguments.
- Parameters:
name (str) – Name of the element to be retrieved
static_args (list, optional) – static positional argument indices, by default []
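- Returns:
Jax intermediate expression of the element for *args or, if no *args are given, a function that creates it when called with arguments.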
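The deferred behaviour can be mimicked with jax.make_jaxpr, which itself returns a function that builds the jaxpr once it receives arguments. The helper jaxpr_for and the function element are hypothetical and only sketch the documented semantics.

```python
import jax
import jax.numpy as jnp


def jaxpr_for(func, *args, static_args=()):
    """Hypothetical helper mirroring the documented behaviour."""
    make = jax.make_jaxpr(func, static_argnums=tuple(static_args))
    # Without arguments, return the jaxpr-producing function for later use.
    return make(*args) if args else make


def element(x, scale):
    return jnp.sin(x) * scale


deferred = jaxpr_for(element, static_args=[1])
print(deferred(jnp.zeros(2), 3.0))  # jaxpr is built on this call
print(jaxpr_for(element, jnp.zeros(2), 3.0, static_args=[1]))  # built immediately
```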
- property pipeline: dict
pipeline Get the sequence of functions that make up the pipeline as a dictionary of name: function pairs.
- Returns:
Description of the pipeline as name: function pairs.
- Return type:
dict
- register_transformer(cls)
register_transformer Make a function available to the calling pipeline object. The registered function must be pure so that it can be transformed with jax. The registered transformers are used to build the pipeline.
- Parameters:
cls – function object to register.
- Raises:
ValueError – When the function is already registered with the pipeline
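A minimal sketch of duplicate-checked registration, assuming transformers are keyed by their __name__ attribute. The Registry class is hypothetical, and the key rubix actually uses may differ.

```python
class Registry:
    """Hypothetical sketch of name-keyed transformer registration."""

    def __init__(self):
        self.transformers = {}

    def register_transformer(self, func):
        name = func.__name__  # assumption: functions are keyed by their name
        if name in self.transformers:
            raise ValueError(f"A transformer named '{name}' is already registered")
        self.transformers[name] = func


def double(x):
    return 2 * x


reg = Registry()
reg.register_transformer(double)
try:
    reg.register_transformer(double)
except ValueError as err:
    print(err)  # A transformer named 'double' is already registered
```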
rubix.pipeline.linear_pipeline module
- class rubix.pipeline.linear_pipeline.LinearTransformerPipeline(cfg: dict, transformers: list)
Bases: AbstractPipeline
LinearTransformerPipeline An implementation of a data transformation pipeline in the form of a simple, 1-D chain of composed functions in which each function uses the output of the function before it as arguments.
- Parameters:
apl – Abstract base class for all pipeline implementations
- apply(*args, static_args=[], static_kwargs=[], **kwargs)
apply Apply the pipeline to the positional arguments *args and keyword arguments **kwargs, which must match the signature of the first function in the pipeline. Arguments listed in static_args and static_kwargs are not traced. The method first applies jax.jit to the pre-assembled pipeline and then calls the result with the arguments.
- Parameters:
*args – Positional arguments to apply the pipeline to
static_args (list, optional) – Positional arguments that should not be traced, by default [].
static_kwargs (list, optional) – Keyword arguments that should not be traced, by default [].
**kwargs – Keyword arguments to apply the pipeline to
- Returns:
Result of the application of the pipeline to the provided input as object.
- Raises:
ValueError – _description_
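The two-step behaviour of apply, which jit-compiles first and then calls, can be sketched as follows. apply_sketch and assembled are hypothetical. The sketch treats static_args as positional indices and forwards them, together with static_kwargs, to the static_argnums and static_argnames arguments of jax.jit. How rubix handles them internally is an assumption here.

```python
import jax
import jax.numpy as jnp


def assembled(x, n_repeat, normalize=True):
    # Stand-in for a pre-assembled pipeline: a shift step, then a normalisation step.
    for _ in range(n_repeat):  # Python loop: n_repeat must be static
        x = x + 1.0
    return x / jnp.sum(x) if normalize else x


def apply_sketch(func, *args, static_args=(), static_kwargs=(), **kwargs):
    """Hypothetical: jit the assembled function, then call it on the inputs."""
    compiled = jax.jit(
        func,
        static_argnums=tuple(static_args),     # indices of untraced positional args
        static_argnames=tuple(static_kwargs),  # names of untraced keyword args
    )
    return compiled(*args, **kwargs)


out = apply_sketch(assembled, jnp.zeros(4), 1,
                   static_args=[1], static_kwargs=["normalize"], normalize=True)
print(out)  # [0.25 0.25 0.25 0.25]
```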
- build_expression()
build_expression Compose the assembled pipeline into a single expression that has the same signature as the first element of the pipeline.
- build_pipeline()
build_pipeline Build up the pipeline from the internally stored configuration. This only works when all transformers that make up the pipeline have been registered with it. Multiple different versions (configurations) of the same transformer can be used in one pipeline.
- Raises:
RuntimeError – When there are no transformers to build the pipeline out of.
ValueError – When there are multiple starting points to the pipeline.
ValueError – When branching occurs in the pipeline.
ValueError – When a config node is present that does not have a ‘name’ attribute.
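The validation rules listed above can be sketched over a hypothetical configuration format in which each node id maps to a dict with a 'name' key and an optional 'depends_on' key. The format and the helper order_linear are assumptions for illustration only, since the configuration schema rubix actually uses is not described here. Two nodes may share a name, which mirrors reusing the same transformer in different configurations.

```python
def order_linear(cfg: dict, registered: dict) -> list:
    """Hypothetical: return transformer names in execution order for a 1-D chain."""
    if not registered:
        raise RuntimeError("No transformers registered to build the pipeline from")
    for node_id, node in cfg.items():
        if "name" not in node:
            raise ValueError(f"Config node '{node_id}' has no 'name' attribute")
    # Exactly one node may lack a predecessor: the start of the chain.
    starts = [k for k, v in cfg.items() if v.get("depends_on") is None]
    if len(starts) != 1:
        raise ValueError(f"Expected one starting point, found {starts}")
    # Each node may have at most one successor; a second one would be a branch.
    successor = {}
    for node_id, node in cfg.items():
        parent = node.get("depends_on")
        if parent is None:
            continue
        if parent in successor:
            raise ValueError(f"Branching after config node '{parent}'")
        successor[parent] = node_id
    order, current = [], starts[0]
    while current is not None:
        order.append(cfg[current]["name"])
        current = successor.get(current)
    return order


registered = {"scale": lambda x: 2 * x, "shift": lambda x: x + 1}
cfg = {
    "a": {"name": "scale"},
    "b": {"name": "shift", "depends_on": "a"},
    "c": {"name": "scale", "depends_on": "b"},  # same transformer used twice
}
order = order_linear(cfg, registered)
print(order)  # ['scale', 'shift', 'scale']
```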
rubix.pipeline.transformer module
- rubix.pipeline.transformer.bound_transformer(*args, **kwargs)
bound_transformer creates a jax.tree_util.Partial function from an input function and the given arguments and keyword arguments. The user must take care that positional arguments are bound starting from the first, i.e., leftmost, one. If specific arguments should be bound, using keyword arguments may be advisable.
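A minimal sketch of the binding behaviour, assuming the decorator simply wraps the function in jax.tree_util.Partial. The names bound_transformer_sketch and affine are hypothetical. Because Partial is a pytree, the bound function can itself be passed as an argument to a jitted function.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def bound_transformer_sketch(*args, **kwargs):
    """Hypothetical decorator factory: bind args/kwargs via jax.tree_util.Partial."""
    def decorator(func):
        return Partial(func, *args, **kwargs)
    return decorator


@bound_transformer_sketch(2.0, offset=1.0)
def affine(scale, x, offset=0.0):
    # scale is bound as the first (leftmost) positional argument.
    return scale * x + offset


print(affine(jnp.arange(3.0)))  # [1. 3. 5.]
# Partial is a pytree, so the bound function can be passed through jax.jit.
print(jax.jit(lambda f, x: f(x))(affine, jnp.arange(3.0)))
```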
- rubix.pipeline.transformer.compiled_transformer(*args, static_args: list = [], static_kwargs: list = [], **kwargs)
compiled_transformer creates a function precompiled with jax.jit, with the given arguments and keyword arguments bound to it, similar to using functools.partial with *args and **kwargs.
Note that any array args/kwargs behave as dynamic arguments in the jax jit, while any non-array args/kwargs behave as static. static_args and static_kwargs refer to the remaining, unbound arguments.
*args are counted in order from the first positional argument of the decorated function. Both *args and **kwargs are bound to the decorated function.
- Parameters:
static_args (list, optional) – Indices of static, i.e., untraced arguments of the bound function, by default [].
static_kwargs (list, optional) – Names of static, i.e., untraced, keyword arguments of the bound function, by default [].
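A sketch of the bind-then-compile idea, assuming a functools.partial plus jax.jit approach. The names compiled_transformer_sketch and repeat_add are hypothetical. With functools.partial, the bound values are captured as constants of the compiled function, which may differ from how rubix treats bound array arguments. Here static_args=[1] refers to times, the second of the arguments left open after step is bound.

```python
from functools import partial

import jax
import jax.numpy as jnp


def compiled_transformer_sketch(*args, static_args=(), static_kwargs=(), **kwargs):
    """Hypothetical decorator: bind args/kwargs, then jit over the remaining ones."""
    def decorator(func):
        bound = partial(func, *args, **kwargs)
        # Indices and names refer to the arguments still open after binding.
        return jax.jit(bound, static_argnums=tuple(static_args),
                       static_argnames=tuple(static_kwargs))
    return decorator


@compiled_transformer_sketch(10.0, static_args=[1])
def repeat_add(step, x, times):
    for _ in range(times):  # Python loop needs a static trip count
        x = x + step
    return x


print(repeat_add(jnp.zeros(2), 3))  # [30. 30.]
```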
- rubix.pipeline.transformer.expression_transformer(*args, static_args: list = [])
expression_transformer creates a jax intermediate expression from a function, with the given untraced arguments. Note that this only works with static positional arguments: JAX does not currently provide a way to declare static keyword arguments when creating a jaxpr rather than a jitted function.
- Parameters:
static_args (list, optional) – Indices of static, i.e., untraced arguments to the function, by default [].
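A sketch of the decorator form, assuming it traces the decorated function with jax.make_jaxpr for the given arguments. The names expression_transformer_sketch and scaled_exp are hypothetical. After decoration, the name refers to the jaxpr itself rather than to a callable, and the static k is baked into it as a constant.

```python
import jax
import jax.numpy as jnp


def expression_transformer_sketch(*args, static_args=()):
    """Hypothetical decorator: replace func by its jaxpr for the given args."""
    def decorator(func):
        return jax.make_jaxpr(func, static_argnums=tuple(static_args))(*args)
    return decorator


@expression_transformer_sketch(jnp.ones(3), 2, static_args=[1])
def scaled_exp(x, k):
    return jnp.exp(x) * k


print(scaled_exp)  # a jaxpr with x as its only input; k appears as a constant
```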