# Source code for rubix.pipeline.linear_pipeline
from . import abstract_pipeline as apl
from .transformer import bound_transformer, compiled_transformer
from jax.tree_util import Partial
from copy import deepcopy
import warnings
# [docs]
class LinearTransformerPipeline(apl.AbstractPipeline):
    """
    Data transformation pipeline in the form of a simple, 1-D chain.

    The pipeline is a chain of composed functions in which each function
    uses the output of the function before it as its argument. The chain is
    described by the ``"Transformers"`` section of the config: every node
    names a registered transformer (``name``), the positional and keyword
    arguments to bind to it (``args``, ``kwargs``) and the key of the node
    it follows (``depends_on``, ``None`` for the first node).

    Attributes set by this class
    ----------------------------
    _pipeline : list
        Bound transformer functions in execution order
        (set by `build_pipeline` / `update_pipeline`).
    _names : list
        Config keys of the nodes in ``_pipeline``, in the same order.
    expression :
        Single callable composing ``_pipeline`` (set by `build_expression`).
    """

    def __init__(self, cfg: dict, transformers: list):
        """
        Build a new LinearTransformerPipeline instance.

        This only forwards the arguments to the base class. The pipeline
        itself has to be created afterwards from the registered
        transformers (see `build_pipeline`).

        Parameters
        ----------
        cfg : dict
            Read config file defining the pipeline.
        transformers : list
            Transformer functions to use.
        """
        super().__init__(cfg, transformers)

    def update_pipeline(self, current_name):
        """
        Append the node(s) that depend on ``current_name`` to the pipeline.

        Every config node whose ``depends_on`` equals ``current_name`` is
        bound to its arguments, appended to the pipeline, and then used as
        the new ``current_name`` in a recursive call, so the whole chain
        following ``current_name`` is added. Mostly used internally for
        adding nodes one by one.

        Parameters
        ----------
        current_name :
            Config key of the node whose successors should be added.

        Raises
        ------
        RuntimeError
            When ``current_name`` is not a key of the "Transformers" config.
        """
        if current_name not in self.config["Transformers"]:
            raise RuntimeError(f"Node '{current_name}' not found in the config")

        for key, node in self.config["Transformers"].items():
            if current_name == node["depends_on"]:
                func = bound_transformer(*node["args"], **node["kwargs"])(
                    self.transformers[node["name"]]
                )
                self._pipeline.append(func)
                self._names.append(key)
                # Continue down the chain from the node that was just added.
                self.update_pipeline(key)

    def build_pipeline(self):
        """
        Build up the pipeline from the internally stored configuration.

        This only works when all transformers the pipeline is composed of
        have been registered with it. Multiple different versions
        (configurations) of the same transformer can be used in a pipeline.

        Raises
        ------
        RuntimeError
            When there are no transformers to build the pipeline out of.
        RuntimeError
            When no node has ``depends_on`` set to ``None``, i.e. there is
            no starting point.
        ValueError
            When there are multiple starting points to the pipeline.
        ValueError
            When branching occurs in the pipeline.
        ValueError
            When a config node lacks one of the 'name', 'args', 'kwargs' or
            'depends_on' entries.
        """
        if len(self.transformers) == 0:
            raise RuntimeError("No registered transformers present")

        # Sanity check: make sure that dependencies are not there multiple
        # times, as branching is not allowed. A set gives quick and easy
        # lookup. This check also captures multiple end points.
        dependencies = set()
        for key, node in self.config["Transformers"].items():
            if "name" not in node:
                raise ValueError(
                    "Each node of a pipeline must have a config node containing 'name'"
                )
            if "args" not in node:
                raise ValueError("Config node must have a possibly empty args element")
            if "kwargs" not in node:
                raise ValueError(
                    "Config node must have a possible empty kwargs element"
                )
            if "depends_on" not in node:
                raise ValueError(
                    "Config node must have a possibly 'null' valued node depends_on"
                )

            dep = node["depends_on"]
            if dep is None:
                continue
            if dep in dependencies:
                raise ValueError(
                    f"Dependencies must be unique in a linear pipeline as branching is not allowed. Found {dep} at least twice"
                )
            dependencies.add(dep)

        # Find the starting point: the single node that depends on nothing.
        start_func = None
        start_name = None
        for key, node in self.config["Transformers"].items():
            if node["depends_on"] is not None:
                continue
            if start_name is not None:
                raise ValueError("There can only be one starting point.")
            start_func = bound_transformer(*node["args"], **node["kwargs"])(
                self.transformers[node["name"]]
            )
            start_name = key

        # Without this guard a ``None`` entry would be stored in the pipeline
        # and the failure would surface later as the misleading
        # "Node 'None' not found in the config". The exception type is kept
        # as RuntimeError so existing handlers still match.
        if start_name is None:
            raise RuntimeError(
                "No starting point found: exactly one node must have "
                "'depends_on' set to null"
            )

        self._pipeline = [start_func]
        self._names = [start_name]
        self.update_pipeline(start_name)

    def build_expression(self):
        """
        Compose the assembled pipeline into a single expression.

        The result is stored in ``self.expression``. It takes one input,
        feeds it to the first element of the pipeline and passes each
        result on to the next element, returning the final result.
        """

        def expr(input, pipeline=()):
            res = input
            for f in pipeline:
                res = f(res)
            return res

        # deepcopy is needed to isolate the expr-function instance from the
        # class, since in principle it's a closure that pulls in the
        # surrounding scope, which includes `self`.
        self.expression = Partial(expr, pipeline=deepcopy(self._pipeline))

    def apply(self, *args, static_args=None, static_kwargs=None, **kwargs):
        """
        Apply the pipeline to the given input.

        If no compiled expression exists yet (``self.compiled_expression``
        is ``None``), it is created first by calling
        ``self.compile_expression`` with the static argument lists; the
        original documentation describes this step as applying the jax jit
        to the pre-assembled pipeline. The compiled expression is then
        called with ``*args`` and ``**kwargs``.

        Note that ``static_args`` / ``static_kwargs`` are only used when the
        expression is compiled here; once a compiled expression exists they
        are ignored.

        Parameters
        ----------
        *args :
            Positional arguments to apply the pipeline to. They must match
            the signature of the first method in the pipeline.
        static_args : list, optional
            Positional arguments that should not be traced. ``None``
            (the default) is treated as an empty list.
        static_kwargs : list, optional
            Keyword arguments that should not be traced. ``None``
            (the default) is treated as an empty list.
        **kwargs :
            Keyword arguments to apply the pipeline to.

        Returns
        -------
        Result of the application of the pipeline to the provided input.

        Raises
        ------
        ValueError
            When no positional arguments are given.
        """
        if len(args) == 0:
            raise ValueError("Cannot apply the pipeline to an empty list of arguments")

        if self.compiled_expression is None:
            self.compile_expression(
                static_args=[] if static_args is None else static_args,
                static_kwargs=[] if static_kwargs is None else static_kwargs,
            )
        return self.compiled_expression(*args, **kwargs)  # type: ignore