rubix.pipeline package
Submodules
rubix.pipeline.abstract_pipeline module
- class rubix.pipeline.abstract_pipeline.AbstractPipeline(cfg: dict, transformers: list[Callable[..., Any]])
Bases: ABC
Abstract base class for data transformation pipelines.
Derived classes must implement build_pipeline, build_expression, and apply. These helpers build the pipeline, assemble it into a pure function, and apply it to input data.
- Parameters:
cfg (dict) – Configuration dictionary defining the pipeline.
transformers (list[Callable[..., Any]]) – Transformers that will be registered with the pipeline.
- compile_element(name: str, static_args: Sequence[int] | None = None, static_kwargs: Sequence[str] | None = None) → Callable[..., Any]
Compile a specific pipeline element using jax.jit.
- Parameters:
name (str) – Name of the element to compile.
static_args (Optional[Sequence[int]], optional) – Positional indices forwarded to jit as static_argnums. Defaults to None.
static_kwargs (Optional[Sequence[str]], optional) – Keyword names forwarded to jit as static_argnames. Defaults to None.
- Raises:
RuntimeError – When compilation of the element fails.
- Returns:
The compiled transformer.
- Return type:
Callable[…, Any]
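The effect of forwarding static_args and static_kwargs to jit can be illustrated with plain jax.jit, without going through the pipeline class. This is a minimal sketch: the function scale_and_shift and its arguments are hypothetical and are not part of rubix.

```python
import jax
import jax.numpy as jnp


def scale_and_shift(x, n_repeat, mode="add"):
    # Hypothetical transformer: n_repeat drives a Python loop and mode drives
    # a Python branch, so both have to be static (untraced) under jit.
    for _ in range(n_repeat):
        x = x + 1.0 if mode == "add" else x * 2.0
    return x


# Positional index 1 (n_repeat) and the keyword "mode" are marked static,
# mirroring what static_args / static_kwargs are forwarded as.
compiled = jax.jit(scale_and_shift, static_argnums=(1,), static_argnames=("mode",))

out_add = compiled(jnp.ones(3), 2, mode="add")
out_mul = compiled(jnp.ones(3), 2, mode="mul")
```

Each distinct value of a static argument triggers a separate compilation, so static arguments are best reserved for values that change rarely.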
- compile_expression(static_args: Sequence[int] | None = None, static_kwargs: Sequence[str] | None = None) → Callable[..., Any]
Compile the pipeline expression using jax.jit.
- Parameters:
static_args (Optional[Sequence[int]], optional) – Positional indices forwarded to jit as static_argnums. Defaults to None.
static_kwargs (Optional[Sequence[str]], optional) – Keyword names forwarded to jit as static_argnames. Defaults to None.
- Raises:
RuntimeError – When compilation fails.
- Returns:
Compiled pipeline function.
- Return type:
Callable[…, Any]
- get_jaxpr(*args: Any, static_args: Sequence[int] | None = None) → Callable[..., Any] | ClosedJaxpr
Return a JAX intermediate expression for the pipeline.
Note: This only works with static positional arguments. JAX does not currently provide a way to mark keyword arguments as static when creating a jaxpr rather than a jitted function. You can use partial to fix keyword arguments before calling this method.
- Parameters:
*args (Any) – Positional arguments forwarded to the expression whose intermediate representation should be produced.
static_args (Optional[Sequence[int]], optional) – Static positional indices forwarded to jax.make_jaxpr via static_argnums. Defaults to None.
- Returns:
ClosedJaxpr: When *args is provided.
Callable[…, Any]: When *args is empty.
- Return type:
ClosedJaxpr | Callable[…, Any]
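The note about fixing keyword arguments with partial can be sketched directly with jax.make_jaxpr. The function power_sum is hypothetical, and functools.partial is assumed to be the partial the note refers to.

```python
from functools import partial

import jax
import jax.numpy as jnp


def power_sum(x, exponent, offset=0.0):
    # Hypothetical pipeline expression, used only for illustration.
    return jnp.sum(x ** exponent) + offset


# make_jaxpr has no static_argnames, so the keyword argument is fixed first.
fixed = partial(power_sum, offset=1.0)

# Without inputs we only get a function that will produce the jaxpr later.
make_expr = jax.make_jaxpr(fixed, static_argnums=(1,))

# Calling it with concrete inputs yields the intermediate representation.
closed_jaxpr = make_expr(jnp.arange(3.0), 2)
print(closed_jaxpr)
```

Because the exponent is static, only x appears as an input variable of the resulting jaxpr.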
- get_jaxpr_for_element(name: str, *args: Any, static_args: Sequence[int] | None = None) → Callable[..., Any] | ClosedJaxpr
Create a JAX intermediate expression for the pipeline element named name, using the static arguments static_args and the arguments *args. If no arguments are provided, a function is returned that produces the intermediate representation once it is called with arguments.
- Parameters:
name (str) – Name of the element to inspect.
*args (Any) – Positional arguments forwarded to the element.
static_args (Optional[Sequence[int]], optional) – Static positional indices forwarded to expression_transformer. Defaults to None.
- Raises:
RuntimeError – When the expression cannot be created.
- Returns:
ClosedJaxpr: When *args is provided.
Callable[…, Any]: When *args is empty.
- Return type:
ClosedJaxpr | Callable[…, Any]
- property pipeline: dict[str, Callable[..., Any]]
Return the registered pipeline elements as a dictionary of name: function pairs.
- Returns:
Mapping from name to function.
- Return type:
dict[str, Callable[…, Any]]
- register_transformer(cls: Callable[..., Any]) → None
Register a transformer function so that it is available to the calling pipeline object.
Note: The registered function must be a pure function so that JAX can transform it. The registered transformers are used to build a pipeline.
- Parameters:
cls (Callable[..., Any]) – Function to register.
- Raises:
ValueError – When a transformer with the same name is already present.
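The registration contract described above (one entry per name, ValueError on a duplicate) can be sketched with a small stand-in class. TinyRegistry is hypothetical and is not the rubix implementation. In particular, keying the registry by the function's __name__ is an assumption made for this illustration.

```python
from typing import Any, Callable


class TinyRegistry:
    """Hypothetical stand-in for the registration behaviour, not the rubix class."""

    def __init__(self) -> None:
        self._transformers: dict[str, Callable[..., Any]] = {}

    def register(self, fn: Callable[..., Any]) -> None:
        # Assumption: transformers are keyed by their function name.
        name = fn.__name__
        if name in self._transformers:
            raise ValueError(f"A transformer named {name!r} is already registered.")
        self._transformers[name] = fn


def double(x):
    # Pure function: the output depends only on the input, no side effects.
    return 2 * x


registry = TinyRegistry()
registry.register(double)

try:
    registry.register(double)  # second registration under the same name
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
```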
rubix.pipeline.linear_pipeline module
- class rubix.pipeline.linear_pipeline.LinearTransformerPipeline(cfg: dict, transformers: list[Callable[..., Any]])
Bases: AbstractPipeline
Minimal linear pipeline that chains transformer functions.
Implements a data transformation pipeline as a simple, one-dimensional chain of composed functions, in which each function takes the output of the preceding function as its arguments.
Note
The constructor only sets up what is needed to build the pipeline. The pipeline itself must be created by calling assemble() after the transformers have been registered.
- Parameters:
cfg (dict) – Configuration describing the pipeline nodes and their dependencies.
transformers (list[Callable[..., Any]]) – Transformer functions that will be registered before building the pipeline.
- apply(*args: Any, static_args: Sequence[int] | None = None, static_kwargs: Sequence[str] | None = None, **kwargs: Any) → Any
Apply the compiled pipeline to the positional arguments *args and keyword arguments **kwargs, which must match the signature of the first function in the pipeline. Arguments selected by static_args and static_kwargs are not traced by JAX. The method first applies jax.jit to the pre-assembled pipeline and then calls the result with the provided arguments.
- Parameters:
*args (Any) – Positional arguments to feed to the pipeline’s first element.
static_args (Sequence[int] | None, optional) – Positional indices that should not be traced by JAX. Defaults to None.
static_kwargs (Sequence[str] | None, optional) – Keyword names that should not be traced by JAX. Defaults to None.
**kwargs (Any) – Keyword arguments to forward to the pipeline.
- Returns:
The result of running the pipeline.
- Return type:
Any
- Raises:
ValueError – When no positional arguments are provided.
- build_expression() → None
Compose the assembled pipeline into a single callable that has the same signature as the first element of the pipeline.
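Composing a linear chain into one callable with the signature of its first element can be sketched in plain Python. The helper compose_chain and the three example functions are hypothetical and only illustrate the idea. They do not reproduce the rubix implementation.

```python
def compose_chain(functions):
    # The first function receives the caller's arguments unchanged, so the
    # composed callable shares its signature. Every later function receives
    # the previous output as its single argument.
    first, *rest = functions

    def expression(*args, **kwargs):
        result = first(*args, **kwargs)
        for fn in rest:
            result = fn(result)
        return result

    return expression


def add_offset(x, offset=1.0):
    return x + offset


def square(x):
    return x * x


def negate(x):
    return -x


expr = compose_chain([add_offset, square, negate])
result = expr(2.0, offset=1.0)  # negate(square(add_offset(2.0, offset=1.0)))
```

Since every stage here is a pure function, a composed callable of this kind could in principle be passed to jax.jit or jax.make_jaxpr as a single unit.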
- build_pipeline() → None
Construct the pipeline from the stored configuration.
Note
This only works when all transformers the pipeline is composed of have been registered with it. Multiple different versions (configurations) of the same transformer can be used in a pipeline.
- Raises:
RuntimeError – When no transformers are registered.
ValueError – When the configuration contains invalid or branching dependencies or when there are multiple starting points to the pipeline.
- update_pipeline(current_name: str) → None
Add a new pipeline node named current_name to the pipeline, taking internal linear dependencies into account. Mostly used internally to add nodes one by one.
- Parameters:
current_name (str) – Name of the node to append.
- Raises:
RuntimeError – When the requested node is missing from the config.
rubix.pipeline.transformer module
- rubix.pipeline.transformer.bound_transformer(*args, **kwargs)
bound_transformer creates a jax.Partial function from an input function and the given arguments and keyword arguments. The user must take care that positional arguments are bound starting from the first, i.e., leftmost, parameter. If specific arguments should be bound, using keyword arguments may be advisable.
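The leftmost-first binding rule can be seen with jax.tree_util.Partial, which is assumed here to be the Partial the description refers to. The function affine is hypothetical.

```python
import jax.numpy as jnp
from jax.tree_util import Partial


def affine(scale, shift, x):
    # Hypothetical kernel with two parameters in front of the data argument.
    return scale * x + shift


# Positional binding fills parameters from the left: this fixes scale=2.0.
bound_left = Partial(affine, 2.0)
y_left = bound_left(1.0, jnp.ones(2))  # shift=1.0, x=ones(2)

# To fix a parameter that is not leftmost, bind it by keyword instead.
bound_kw = Partial(affine, shift=1.0)
y_kw = bound_kw(2.0, x=jnp.ones(2))  # scale=2.0; x must now be passed by keyword
```

Once shift is bound by keyword, x has to be passed by keyword as well, because a second positional argument would collide with the already bound shift.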
- rubix.pipeline.transformer.compiled_transformer(*args, static_args: list = [], static_kwargs: list = [], **kwargs)
compiled_transformer creates a JAX-compiled function with the given arguments and keyword arguments bound to it, similar to using functools.partial with *args and **kwargs.
Note that any array args/kwargs will behave as dynamic arguments in the jax jit, while any non-array args/kwargs will behave as static. static_args and static_kwargs refer to the remaining arguments.
*args are bound in order, starting from the first positional argument of the decorated function.
*args and **kwargs are bound to the decorated function.
- Parameters:
*args – Positional arguments to bind to the target function.
static_args (list, optional) – Indices of static (untraced) positional arguments of the bound function. Defaults to [].
static_kwargs (list, optional) – Names of static (untraced) keyword arguments of the bound function. Defaults to [].
**kwargs – Keyword arguments to bind to the target function.
- Returns:
A decorator that returns a jitted function with bound args.
- Return type:
callable
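The statement that static_args and static_kwargs refer to the remaining arguments can be made concrete with a sketch that binds first and jits second. It uses functools.partial and jax.jit directly as stand-ins, so it does not reproduce the decorator's internals, and the function repeat_scale is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


def repeat_scale(factor, x, n_times, clip=False):
    # Hypothetical kernel: n_times drives a Python loop and clip a Python
    # branch, so both must be static once the function is jitted.
    for _ in range(n_times):
        x = x * factor
    if clip:
        x = jnp.minimum(x, 10.0)
    return x


# Bind the leftmost argument; the remaining signature is (x, n_times, clip).
bound = partial(repeat_scale, 2.0)

# Index 1 counts within the remaining arguments, so it selects n_times.
compiled = jax.jit(bound, static_argnums=(1,), static_argnames=("clip",))

out = compiled(jnp.array([1.0, 3.0]), 2, clip=True)
```

In the original function n_times sits at index 2. After binding factor it moves to index 1, which is the index that has to be declared static.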
- rubix.pipeline.transformer.expression_transformer(*args, static_args: list = [])
expression_transformer creates a JAX intermediate expression from a function with the given untraced arguments. This only works with static positional arguments: JAX does not currently provide a way to mark keyword arguments as static when creating a jaxpr rather than a jitted function.
- Parameters:
*args – Positional arguments to bind for the expression generation.
static_args (list, optional) – Indices of static (untraced) positional arguments to the function. Defaults to [].
- Returns:
A function (or jaxpr) produced from the provided kernel.
- Return type:
callable