Concept of the pipeline#
Basic ideas#
rubix essentially implements a big data transformation pipeline.
A pipeline is composed of nodes that are stored in a list in execution order (or, more generally, in a DAG, which is not supported currently). Each node is called a transformer.
Each step in this pipeline (i.e., each transformer) can itself be seen as being composed of other, smaller transformers. This gives us a pattern that can be used to guide the implementation of transformers.
A simple implementation lives in rubix.pipeline.
Restrictions#
JAX is purely functional: anything that is to be transformed with JAX has to be a pure function. Anything that comes from the environment must be explicitly copied into the function or be bound to it, such that the internal state of the function is self-contained.
It is irrelevant what builds these pure functions. Therefore, we use a factory pattern to do all configuration work, like reading files, pulling data from the net, providing any function arguments to be used in the pipeline and so on. A factory then produces a pure function that contains all the data we need as static arguments and retains only the values it computes on as traceable arguments.
We can leverage jax.tree_util.Partial for this, which works like functools.partial but is compatible with jax transformations. Note that stateful objects can still be used internally as long as no stuff from an outer scope (that may change over time) is read or written. This is the user's responsibility.
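The factory pattern described above can be sketched as follows. This is a minimal illustration and not rubix code: `make_scaler`, `scale` and the config key `factor` are hypothetical names chosen for the example. The factory does the configuration work once and returns a `jax.tree_util.Partial` whose only remaining argument is the data to compute on.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def make_scaler(config):
    """Hypothetical factory: does the impure setup, returns a pure function."""
    # Configuration work (reading files, downloads, ...) would happen here, once.
    factor = float(config["factor"])

    def scale(x, factor):
        # Pure kernel: the result depends only on the arguments.
        return x * factor

    # Bind the configured value; only `x` remains as a call-time argument.
    return Partial(scale, factor=factor)


scale_by_two = make_scaler({"factor": 2.0})
x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)
result = jax.jit(scale_by_two)(x)
```

Because the returned object no longer reads anything from its surroundings, later changes to the config dict have no effect on `scale_by_two`.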
from copy import deepcopy
import jax
import jax.numpy as jnp
from jax import make_jaxpr
from jax.tree_util import Partial
from jax import jit, grad
from rubix.pipeline import linear_pipeline as ltp
from rubix.pipeline import transformer as rtr
from rubix.utils import read_yaml
Build some simple decorator for function configuration#
These decorators leverage jax.tree_util.Partial.
Each one builds a partial object to which jax transformations can be applied.
There are three cases:
build the pure function object: you have to take care of static args/kwargs yourself when calling jit. The decorator only builds the function object.
jit it right away: the usual case. Here you can tell it which args/kwargs to trace and which to treat as static with the static_args and static_kwargs keyword arguments.
build expression: mainly to check what comes out at the end or at intermediate steps. This returns either a function that builds a jaxpr (when no arguments are given) or the jaxpr itself (when arguments are given as well). Note that, for some reason, jax.make_jaxpr does not have static_argnames like jit does.
With these, we can configure our pipeline transformers.
It is not entirely clear right now which of them are useful or needed.
These decorators/factory functions live in rubix.pipeline.transformer.
Simple transformer decorator that binds a function to arguments
This shows the basic implementation; the decorators are available under rubix.pipeline.transformer in the package.
def bound_transformer(*args, **kwargs):
    """
    bound_transformer Create a jax.tree_util.Partial function from an input
    function and given arguments and keyword arguments.
    The user must take care that arguments are bound starting from the first,
    i.e., leftmost. If specific arguments should be bound, using keyword
    arguments may be advisable.
    """

    def transformer_wrap(kernel):
        return Partial(
            deepcopy(kernel), *deepcopy(args), **deepcopy(kwargs)
        )  # deepcopy to avoid context dependency

    return transformer_wrap


@bound_transformer(z=5, k=3.14)
def add(x, y, z: float = 0, k: float = 0):
    return x + y + z + k
type(add)
jax.tree_util.Partial
addjit = jax.jit(add)
x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)
addjit(x, x)
Array([14.14, 12.14, 10.14], dtype=float32)
Compiling transformer to jit individual elements and bind them to traceable partial arguments#
This can be used for the final pipeline or for intermediate steps during debugging.
It combines a Partial, which binds arguments, with a jit that takes static args and kwargs. However, bound args and kwargs can NOT be static at the same time. In principle, we would want a partial of a jit here, but that may defeat the purpose of the jit because of the overhead of the wrapper. A solution to this would essentially yield a configurable jit factory.
It is not entirely clear why the code below works the way it does, or whether it is useful at all.
def compiled_transformer(
    *args,
    static_args: list = [],
    static_kwargs: list = [],
    **kwargs,
):
    def transformer_wrap(kernel):
        return jit(
            Partial(deepcopy(kernel), *deepcopy(args), **deepcopy(kwargs)),
            static_argnums=deepcopy(static_args),
            static_argnames=deepcopy(static_kwargs),
        )

    return transformer_wrap


@compiled_transformer(
    z=5,
    k=-3.14,
    static_kwargs=[
        "k",
    ],
)
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k
cond_add
<PjitFunction of Partial(<function cond_add at 0x7f87f830aef0>, z=5, k=-3.14)>
cond_add(x, x)
Array([7.8599997, 5.8599997, 3.86 ], dtype=float32)
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k
use on predefined functions without the decorator syntax
cond_add_plus = compiled_transformer(z=5, static_kwargs=["k"])(cond_add)
cond_add_plus
<PjitFunction of Partial(<function cond_add at 0x7f87f830b400>, z=5)>
cond_add_plus(x, x, k=-3.14)
Array([7.8599997, 5.8599997, 3.86 ], dtype=float32)
Problem: the compiled_transformer decorator cannot declare args or kwargs as static if they are already bound to the function, i.e., configured parameters are not static here. This only works if the entire pipeline is compiled after assembling it. It is not clear at the moment how to fix that, if at all.
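One possible way to sidestep this, sketched under the assumption that configured values never change after the function is built, is to bind them with `functools.partial` before calling `jit`. The bound values then live in the callable itself rather than in its call-time arguments, so they are ordinary Python values while JAX traces the function, and Python control flow on them works without any static declaration. `configured_jit` is a hypothetical helper and not part of rubix.

```python
import functools

import jax
import jax.numpy as jnp


def cond_add(x, y, z=0.0, k=0.0):
    # Python-level branch: `k` must be a concrete value while tracing.
    if k < 0:
        return x + y + z + k
    return x + y + z + 2 * k


def configured_jit(kernel, **bound):
    """Hypothetical helper: jit a kernel with its configuration baked in."""
    return jax.jit(functools.partial(kernel, **bound))


cond_add_cfg = configured_jit(cond_add, z=5, k=-3.14)
x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)
out = cond_add_cfg(x, x)
```

The price is that changing a configured value means building, and compiling, a new function.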
Expression based decorator for getting out the intermediate jaxpr object for inspection#
make_jaxpr has no static_argnames parameter, so static keyword arguments are not supported here. The reason for this is unclear.
def expression_transformer(
    *args,
    static_args: list = [],
):
    def transformer_wrap(kernel):
        if len(args) > 0:
            return jax.make_jaxpr(kernel, static_argnums=static_args)(*args)
        else:
            return jax.make_jaxpr(kernel, static_argnums=static_args)

    return transformer_wrap


@expression_transformer(x, x, 5, 3.14, static_args=[2, 3])
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k
cond_add
{ lambda ; a:f32[3] b:f32[3]. let
c:f32[3] = add a b
d:f32[3] = add c 5.0
e:f32[3] = add d 6.28000020980835
in (e,) }
Make sure to use the right static_args when doing control flow, or use jax/lax primitives.
@expression_transformer(x, x, 5, -3.14, static_args=[3])
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k
cond_add
{ lambda ; a:f32[3] b:f32[3] c:i32[]. let
d:f32[3] = add a b
e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
f:f32[3] = add d e
g:f32[3] = add f -3.140000104904175
in (g,) }
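The alternative mentioned above, using jax primitives instead of Python control flow, might look like the sketch below. `cond_add_traced` is a hypothetical variant of `cond_add`; it uses `jnp.where`, which computes both branches and selects between them, so `k` can stay a traced argument and no static arguments are needed.

```python
import jax
import jax.numpy as jnp


def cond_add_traced(x, y, z=0.0, k=0.0):
    base = x + y + z
    # Both branches are evaluated; the condition only selects the result.
    return jnp.where(k < 0, base + k, base + 2 * k)


x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)

# All four arguments are traced here, including k.
jaxpr = jax.make_jaxpr(cond_add_traced)(x, x, 5.0, -3.14)

fn = jax.jit(cond_add_traced)
neg = fn(x, x, 5.0, -3.14)  # takes the `base + k` values
pos = fn(x, x, 5.0, 3.14)   # takes the `base + 2 * k` values
```

Since `k` is traced, the same compiled function serves both signs of `k`, at the cost of computing both branches.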
without giving arguments you get out a function that produces an expression when arguments are added
@expression_transformer(static_args=[2, 3])
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k
cond_add
<function jax.make_jaxpr(cond_add)(x, y, z: float = 0, k: float = 0)>
cond_add(x, x, 3, 2.71)
{ lambda ; a:f32[3] b:f32[3]. let
c:f32[3] = add a b
d:f32[3] = add c 3.0
e:f32[3] = add d 5.420000076293945
in (e,) }
Define a number of simple, dumb transformers#
we pretend that their second value is something we want to configure from the start and hence it should not be traced
we can use the above decorators to bind their second arg to something we know
def add(x, s: float):
    return x + s


def mult(x, m: float):
    return x * m


def div(x, d: float):
    return x / d


def sub(x, s: float):
    return x - s
Configuration files and pipeline building#
General remarks about yaml#
yaml format: the top level is a dictionary
Inside the dictionary one can arbitrarily nest lists and dicts.
yaml is customizable for node formats that are not provided by default, or for reading in types. Look up yaml tags for more.
There are yaml libraries for pretty much all common languages.
Here, we use yaml in the following way:#
the config file builds an adjacency list of a DAG essentially, but currently the design is limited to only one child per node => linear pipelines only, no branching
consequently, the build algorithm is limited to linear pipelines for the moment. Both must evolve together.
While a more general abstract base class is provided, we only implement a linear pipeline class, LinearTransformerPipeline, at the moment.
The essential part is the Transformers node of the config. This is the actual DAG adjacency list. It needs to adhere to the format outlined below.
You can add other nodes to configure other parts of your system: data directories and so on.
The starting point is always defined by a node that does not depend on another node.
The stop point is just the last element in the pipeline
Config node structure:#
name_of_pipeline_step:
  name: name_of_function to use for this step
  depends_on: name_of_step_immediately_prior_in_pipeline
  args: # arguments to the transformer functions that should be bound to it
    argument1: value1
    argument2: value2
    argumentN: valueN
The arguments in args will be used to create the Partial object, using the transformer decorators above.
Example
B:
  name: sub
  depends_on: C
  args:
    s: 2
C:
  name: add
  depends_on: null
  args:
    s: 4
Here, C is the starting node, i.e., the first function in the pipeline.
Whatever you do before that with your data does not concern the pipeline and hence has no influence on differentiability etc.
For a full example, see the demo.yml file.
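To make the format concrete, here is a minimal sketch of how a linear execution order could be recovered from the depends_on entries. It is an illustration only and not necessarily how LinearTransformerPipeline does it; `order_linear_pipeline` is a hypothetical helper, and the config is written as a plain dict mirroring the two-step example.

```python
def order_linear_pipeline(transformers):
    """Return step names in execution order for a linear depends_on chain."""
    # The starting point is the one step that depends on nothing.
    starts = [n for n, node in transformers.items() if node["depends_on"] is None]
    if len(starts) != 1:
        raise ValueError(f"expected exactly one starting node, found {len(starts)}")

    # Invert the relation: parent name -> the single step that depends on it.
    child_of = {}
    for name, node in transformers.items():
        parent = node["depends_on"]
        if parent is None:
            continue
        if parent not in transformers:
            raise ValueError(f"step {name!r} depends on unknown step {parent!r}")
        if parent in child_of:
            raise ValueError(f"step {parent!r} has several children; no branching allowed")
        child_of[parent] = name

    # Walk the chain from the start until a step has no child.
    order = [starts[0]]
    while order[-1] in child_of:
        order.append(child_of[order[-1]])
    if len(order) != len(transformers):
        raise ValueError("some steps are not reachable from the starting node")
    return order


config = {
    "B": {"name": "sub", "depends_on": "C", "args": {"s": 2}},
    "C": {"name": "add", "depends_on": None, "args": {"s": 4}},
}
order = order_linear_pipeline(config)
```

For the two-step example this yields C first and B second, regardless of the order in which the steps appear in the file.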
Read yaml and build pipeline#
read_yaml is available from rubix.utils and is very simple.
read_cfg = read_yaml("./demo.yml") # implemented in utils
read_cfg
{'Transformers': {'A': {'name': 'add',
'depends_on': 'B',
'args': [],
'kwargs': {'s': 3.0}},
'X': {'name': 'mult', 'depends_on': 'A', 'args': [], 'kwargs': {'m': 3}},
'Z': {'name': 'div', 'depends_on': 'X', 'args': [], 'kwargs': {'d': 4}},
'B': {'name': 'sub', 'depends_on': 'C', 'args': [], 'kwargs': {'s': 2}},
'C': {'name': 'add', 'depends_on': None, 'args': [], 'kwargs': {'s': 4}}}}
type(read_cfg)
dict
read_cfg["Transformers"]
{'A': {'name': 'add', 'depends_on': 'B', 'args': [], 'kwargs': {'s': 3.0}},
'X': {'name': 'mult', 'depends_on': 'A', 'args': [], 'kwargs': {'m': 3}},
'Z': {'name': 'div', 'depends_on': 'X', 'args': [], 'kwargs': {'d': 4}},
'B': {'name': 'sub', 'depends_on': 'C', 'args': [], 'kwargs': {'s': 2}},
'C': {'name': 'add', 'depends_on': None, 'args': [], 'kwargs': {'s': 4}}}
type(read_cfg["Transformers"])
dict
Transformers need to be registered upon creation. If you have fixed ones or many of them, maybe it makes sense to write a factory function.
tp = ltp.LinearTransformerPipeline(read_cfg, [add, mult, div, sub])
tp.transformers
{'add': <function __main__.add(x, s: float)>,
'mult': <function __main__.mult(x, m: float)>,
'div': <function __main__.div(x, d: float)>,
'sub': <function __main__.sub(x, s: float)>}
The transformers member gives us a dict of name: function pairs for the transformers.
Registration currently has to be done before the assembly of the pipeline, or the pipeline will not know what to assemble it from.
tp.assemble()
tp.pipeline
{'C': Partial(<function add at 0x7f87f837c0d0>, s=4),
'B': Partial(<function sub at 0x7f87f837c040>, s=2),
'A': Partial(<function add at 0x7f87f837c0d0>, s=3.0),
'X': Partial(<function mult at 0x7f87f837c310>, m=3),
'Z': Partial(<function div at 0x7f87f837c160>, d=4)}
Now we have a collection of jax Partials to which, assuming the individual elements are well behaved, we can in principle apply all jax transformations. If this is true for the elements, then it is true for the composition as long as the function we use for composition is itself purely functional.
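A minimal sketch of such a pure composition, assuming the bound steps are simply applied one after another. `compose` is a hypothetical stand-in for what the pipeline builds internally; the actual rubix implementation may differ.

```python
from functools import reduce

import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def add(x, s):
    return x + s


def sub(x, s):
    return x - s


def mult(x, m):
    return x * m


def div(x, d):
    return x / d


# The five bound steps of the demo pipeline, in execution order.
steps = [
    Partial(add, s=4),
    Partial(sub, s=2),
    Partial(add, s=3.0),
    Partial(mult, m=3),
    Partial(div, d=4),
]


def compose(steps):
    """Hypothetical composition: feed each step's output into the next."""

    def expr(x):
        return reduce(lambda value, step: step(value), steps, x)

    return expr


x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)
out = jax.jit(compose(steps))(x)
# The composition stays differentiable: gradient of the summed output.
slope = jax.grad(lambda v: compose(steps)(v).sum())(x)
```

The composed function reads nothing but its argument and the bound steps, so it is as pure as the steps themselves.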
x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)
The expression that a pipeline builds is a partial object that is bound to the pipeline
tp.expression
Partial(<function LinearTransformerPipeline.build_expression.<locals>.expr at 0x7f87fc008790>, pipeline=[Partial(<function add at 0x7f87f837c0d0>, s=4), Partial(<function sub at 0x7f87f837c040>, s=2), Partial(<function add at 0x7f87f837c0d0>, s=3.0), Partial(<function mult at 0x7f87f837c310>, m=3), Partial(<function div at 0x7f87f837c160>, d=4)])
… it has the same signature as the first function in the pipeline.
func = tp.compile_expression()
func
<PjitFunction of Partial(<function LinearTransformerPipeline.build_expression.<locals>.expr at 0x7f87fc008790>, pipeline=[Partial(<function add at 0x7f87f837c0d0>, s=4), Partial(<function sub at 0x7f87f837c040>, s=2), Partial(<function add at 0x7f87f837c0d0>, s=3.0), Partial(<function mult at 0x7f87f837c310>, m=3), Partial(<function div at 0x7f87f837c160>, d=4)])>
func(x)
Array([6. , 5.25, 4.5 ], dtype=float32)
div(mult(add(sub(add(x, s=4), s=2), s=3), m=3), d=4)
Array([6. , 5.25, 4.5 ], dtype=float32)
… output’s the same. yay :)
expr = tp.get_jaxpr()(x)
expr
{ lambda ; a:f32[3]. let
b:f32[3] = add a 4.0
c:f32[3] = sub b 2.0
d:f32[3] = add c 3.0
e:f32[3] = mul d 3.0
f:f32[3] = div e 4.0
in (f,) }
type(expr)
jax._src.core.ClosedJaxpr
def func_manual(x):
    return div(mult(add(sub(add(x, s=4), s=2), s=3), m=3), d=4)
make_jaxpr(func_manual)(x)
{ lambda ; a:f32[3]. let
b:f32[3] = add a 4.0
c:f32[3] = sub b 2.0
d:f32[3] = add c 3.0
e:f32[3] = mul d 3.0
f:f32[3] = div e 4.0
in (f,) }
… expressions are too, because JAX is smart enough to trace across loops and we don’t have to mess with expression composition ourselves. We hence should end up with something that’s jax transformable if its elements are jax transformable. yay :)
just for completeness, we can mess a bit more with the expression stuff
jax.jacfwd(tp.compile_expression())(x)
Array([[0.75, 0. , 0. ],
[0. , 0.75, 0. ],
[0. , 0. , 0.75]], dtype=float32)
jax.jacrev(tp.compile_expression())(x)
Array([[0.75, 0. , 0. ],
[0. , 0.75, 0. ],
[0. , 0. , 0.75]], dtype=float32)
jax.hessian(tp.compile_expression())(x)
Array([[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]], dtype=float32)
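These values can be checked by hand. Writing out the five bound steps of the demo pipeline (add 4, subtract 2, add 3, multiply by 3, divide by 4) for a single element gives an affine function, so the Jacobian is constant and diagonal and the Hessian vanishes:

```latex
\begin{aligned}
f(x_i) &= \frac{3\,\bigl((x_i + 4) - 2 + 3\bigr)}{4} = \frac{3}{4}\,(x_i + 5), \\
\frac{\partial f(x_i)}{\partial x_j} &= \frac{3}{4}\,\delta_{ij} = 0.75\,\delta_{ij}, \\
\frac{\partial^2 f(x_i)}{\partial x_j\,\partial x_k} &= 0 .
\end{aligned}
```

For x = (3, 2, 1) this gives f = (6, 5.25, 4.5), in agreement with the pipeline output.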
Query individual elements#
this is mainly useful for debugging
compile or get expressions for individual elements
tp.compile_element("A")
<PjitFunction of Partial(_HashableCallableShim(Partial(<function add at 0x7f87f837c0d0>, s=3.0)))>
tp.get_jaxpr_for_element("A", x)
{ lambda ; a:f32[3]. let b:f32[3] = add a 3.0 in (b,) }
when building an expression with no arguments, a function is returned that creates an expression once args are added
f = tp.get_jaxpr_for_element("A")
f
<function jax.<unnamed function>(x, *, s: float = 3.0)>
f(x)
{ lambda ; a:f32[3]. let b:f32[3] = add a 3.0 in (b,) }
Alternative structures that allow for more complex systems#
allow injecting new data at intermediate steps, i.e., multiple starting points: this transforms the pipeline into an inverted tree.
allow a step to depend on multiple other steps: this transforms the pipeline into a directed acyclic graph, a common structure in more general data processing systems.
=> if possible, use something simple like Partial to accomplish this
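As a rough illustration of the second option, the sketch below evaluates a small DAG in which a step may depend on several others. Everything here is hypothetical and not part of rubix: depends_on is assumed to become a list, `run_dag` is an invented helper, and externally injected data is passed in by name. The ordering comes from `graphlib.TopologicalSorter` in the Python standard library (Python 3.9+).

```python
from graphlib import TopologicalSorter


def run_dag(nodes, functions, inputs):
    """Evaluate each step after its parents; return all results by name."""
    # Map every step to the set of steps (or inputs) it depends on.
    graph = {name: set(node["depends_on"]) for name, node in nodes.items()}
    results = dict(inputs)
    for name in TopologicalSorter(graph).static_order():
        if name in results:
            continue  # externally injected data, nothing to compute
        if name not in nodes:
            raise KeyError(f"{name!r} is neither a step nor an injected input")
        node = nodes[name]
        parents = [results[p] for p in node["depends_on"]]
        results[name] = functions[node["name"]](*parents, **node.get("kwargs", {}))
    return results


functions = {
    "add": lambda x, s: x + s,
    "mult": lambda x, m: x * m,
    "combine": lambda a, b: a + b,
}
nodes = {
    "shifted": {"name": "add", "depends_on": ["raw"], "kwargs": {"s": 4}},
    "scaled": {"name": "mult", "depends_on": ["raw"], "kwargs": {"m": 3}},
    "merged": {"name": "combine", "depends_on": ["shifted", "scaled"]},
}
results = run_dag(nodes, functions, {"raw": 2.0})
```

A cycle in the config makes `static_order` raise `graphlib.CycleError`, which gives the acyclicity check for free.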
Tentative best practices#
think in small steps: a more granular pipeline is easier to write in a pure functional style, easier to reason about and probably also better to optimize.
A more granular system also is easier to test and extend
ideally write the pipeline such that it can be compiled all at once with compile_expression.
Summary#
pipeline produces same jax code as handwritten stuff. This seems encouraging.
at which points do we still need to ensure pure functional behavior?
how will we enforce transformer compatibility?
this is a pathologically simple case, hence not representative for real-world scenarios
when does it break?
what use cases are not covered?
what else do you need?