Source code for rubix.pipeline.linear_pipeline
from copy import deepcopy
from typing import Any, Callable, Sequence
from jax.tree_util import Partial
from . import abstract_pipeline as apl
from .transformer import bound_transformer
class LinearTransformerPipeline(apl.AbstractPipeline):
"""Minimal linear pipeline that chains transformer functions.
An implementation of a data transformation pipeline in the form
of a simple, one-dimensional chain of composed functions, in which
each function takes the output of the previous one as its input.
Note:
This only sets up what is needed to build the pipeline. The
pipeline itself has to be created by calling `assemble()` after the
transformers to use have been registered.
Args:
cfg (dict): Configuration describing the pipeline nodes and their
dependencies.
transformers (list[Callable[..., Any]]): Transformer functions that
will be registered before building the pipeline.
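Example:
A minimal sketch of a two-node configuration, assuming the base class
registers the passed transformers under their function names
(`add` and `scale` are hypothetical transformer functions)::
    def add(x, offset=0.0):
        return x + offset
    def scale(x, factor=1.0):
        return x * factor
    cfg = {
        "Transformers": {
            "add_offset": {
                "name": "add",
                "args": [],
                "kwargs": {"offset": 1.0},
                "depends_on": None,
            },
            "rescale": {
                "name": "scale",
                "args": [],
                "kwargs": {"factor": 2.0},
                "depends_on": "add_offset",
            },
        }
    }
    pipeline = LinearTransformerPipeline(cfg, transformers=[add, scale])
    pipeline.assemble()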
"""
def __init__(
self,
cfg: dict,
transformers: list[Callable[..., Any]],
) -> None:
super().__init__(cfg, transformers)
def update_pipeline(self, current_name: str) -> None:
"""Adds a new pipeline node with name 'current_name' to
the pipeline, taking into account internal linear dependencies.
Mostly used internally for adding nodes one by one.
Args:
current_name (str): Name of the node to append.
Raises:
RuntimeError: When the requested node is missing from the config.
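Example:
A minimal sketch, assuming the internal chain already ends with a node
named 'add_offset' and a config node 'rescale' that depends on it::
    pipeline.update_pipeline("add_offset")
    # the bound 'rescale' transformer is appended to the internal chain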
"""
if current_name not in self.config["Transformers"]:
raise RuntimeError(f"Node '{current_name}' not found in the config")
for key, node in self.config["Transformers"].items():
if current_name == node["depends_on"]:
func = bound_transformer(*node["args"], **node["kwargs"])(
self.transformers[node["name"]]
)
self._pipeline.append(func)
self._names.append(key)
self.update_pipeline(key)
def build_pipeline(self) -> None:
"""Construct the pipeline from the stored configuration.
Note:
This only works when all transformers the pipeline is composed of
have been registered with it. Multiple different versions
(configurations) of the same transformer can be used in a pipeline.
Raises:
RuntimeError: When no transformers are registered.
ValueError: When the configuration contains invalid or branching
dependencies or when there are multiple starting points to the
pipeline.
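Example:
A minimal sketch, assuming a config like the one shown in the class
docstring and transformers registered under matching names::
    pipeline.build_pipeline()
    # the internal chain now holds the bound start node followed by
    # its successors in dependency order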
"""
if len(self.transformers) == 0:
raise RuntimeError("No registered transformers present")
# sanity check: make sure that no dependency appears more than once,
# since branching is not allowed. A set is used for quick lookup.
# This check also catches multiple end points.
dependencies: set[str] = set()
for key, node in self.config["Transformers"].items():
if "name" not in node:
raise ValueError(
(
"Each node of a pipeline must have a config node "
"containing 'name'"
)
)
if "args" not in node:
raise ValueError("Config node must have a possibly empty args element")
if "kwargs" not in node:
raise ValueError(
"Config node must have a possible empty kwargs element"
)
if "depends_on" not in node:
raise ValueError(
(
"Config node must have a possibly 'null' valued node "
"depends_on"
)
)
dep = node["depends_on"]
if dep is None:
continue
if dep in dependencies:
raise ValueError(
(
"Dependencies must be unique in a linear pipeline as "
"branching is not allowed. "
f"Found {dep} at least twice"
)
)
else:
dependencies.add(dep)
# find the starting point
start_func = None
start_name = None
for key, node in self.config["Transformers"].items():
if node["depends_on"] is None and start_name is None:
start_func = bound_transformer(*node["args"], **node["kwargs"])(
self.transformers[node["name"]]
)
start_name = key
elif node["depends_on"] is None and start_name is not None:
raise ValueError("There can only be one starting point.")
else:
continue
self._pipeline = [start_func]
self._names = [start_name]
self.update_pipeline(start_name)
def build_expression(self) -> None:
"""Compose the assembled pipeline into a single callable that has
the same signature as the first element of the pipeline."""
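# What the composed expression amounts to, sketched for a hypothetical
# two-stage pipeline [first, second]: self.expression(x) == second(first(x))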
def expr(input, pipeline=()):
res = input
for f in pipeline:
res = f(res)
return res
# deepcopy is needed to isolate the expr-function instance from the
# class, since in principle it's a closure that pulls in the
# surrounding scope, which includes `self`
self.expression = Partial(expr, pipeline=deepcopy(self._pipeline))
def apply(
self,
*args: Any,
static_args: Sequence[int] | None = None,
static_kwargs: Sequence[str] | None = None,
**kwargs: Any,
) -> Any:
"""Apply the compiled pipeline to the provided positional arguments
`*args` and keyword arguments `**kwargs` that match the signature of
the first method in the pipeline with static keyword arguments that are
not traced by JAX. First applies the jax jit to the pre-assembled
pipeline, then applies the result to the provided arguments.
Args:
*args (Any): Positional arguments to feed to the pipeline's first
element.
static_args (Sequence[int] | None, optional): Positional indices
that should not be traced by JAX. Defaults to ``None``.
static_kwargs (Sequence[str] | None, optional): Keyword names
that should not be traced by JAX. Defaults to ``None``.
**kwargs (Any): Keyword arguments to forward to the pipeline.
Returns:
Any: The result of running the pipeline.
Raises:
ValueError: When no positional arguments are provided.
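Example:
A minimal sketch, assuming the assembled pipeline's first stage takes
a single array `x`::
    out = pipeline.apply(x)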
"""
if len(args) == 0:
raise ValueError("Cannot apply the pipeline to an empty list of arguments")
if self.compiled_expression is None:
self.compile_expression(
static_args=static_args,
static_kwargs=static_kwargs,
)
return self.compiled_expression(*args, **kwargs) # type: ignore