Concept of the pipeline

Basic ideas

  • rubix essentially implements a big data transformation pipeline.

  • a pipeline is composed of nodes stored in a list sorted by execution order (or, more generally, arranged in a DAG, which is not currently supported). Each node is called a transformer.

  • each step in this pipeline (i.e., each transformer) can itself ultimately be seen as being composed of other, smaller transformers. This gives us a pattern that can guide the implementation of transformers.

  • a simple implementation is available in rubix.pipeline

Restrictions

  • jax is purely functional. Anything that is to be transformed with jax has to be a pure function. Anything that comes from the environment must be explicitly copied into the function or bound to it, such that the function’s internal state is self-contained.

  • It’s irrelevant what builds these pure functions. Therefore, we use a factory pattern to do all the configuration work, like reading files, pulling data from the net, providing any function arguments to be used in the pipeline and so on. The factory then produces a pure function that holds all the configuration data it needs as static (bound) arguments and keeps only the data it computes on as traceable arguments.

  • we can leverage jax.tree_util.Partial for this, which works like functools.partial but is compatible with jax transformations. Note that stateful objects can still be used internally as long as nothing from an outer scope (which may change over time) is read or written. This is the user’s responsibility.
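A minimal sketch of the factory pattern described above, using only jax: the hypothetical make_scaler factory does the (possibly impure) configuration work and returns a Partial of a pure kernel, so only x stays traceable. The names scale_and_shift, make_scaler and the config keys are illustrative and not part of rubix.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def scale_and_shift(x, factor, offset):
    # pure kernel: the result depends only on the arguments
    return x * factor + offset


def make_scaler(config: dict) -> Partial:
    """Hypothetical factory: all configuration work (reading files, parsing,
    validation, ...) happens here, outside of any jax transformation."""
    factor = float(config["factor"])
    offset = float(config.get("offset", 0.0))
    # bind the configuration; only x remains a free (traceable) argument
    return Partial(scale_and_shift, factor=factor, offset=offset)


scaler = make_scaler({"factor": 2.0, "offset": 1.0})
x = jnp.array([1.0, 2.0, 3.0])

y = jax.jit(scaler)(x)                         # [3. 5. 7.]
dy = jax.grad(lambda v: scaler(v).sum())(x)    # [2. 2. 2.]
```

Because the factory returns an ordinary callable, jit and grad can be applied to it just like to any hand-written pure function.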

from copy import deepcopy
import jax
import jax.numpy as jnp
from jax import make_jaxpr
from jax.tree_util import Partial
from jax import jit, grad
from rubix.pipeline import linear_pipeline as ltp
from rubix.pipeline import transformer as rtr
from rubix.utils import read_yaml

Build some simple decorators for function configuration

  • leverages jax.tree_util.Partial

  • builds a partial object to which jax transformations can be applied

  • three cases:

    • build the pure function object: you have to take care of static args/kwargs yourself when calling jit. The decorator only builds the function object

    • jit it right away: the usual. Here you can tell it which args/kwargs not to trace via the static_args and static_kwargs keyword arguments

    • build expression: mainly to check what comes out at the end or at intermediate steps. Without arguments it returns a function that builds the jaxpr once it is called; with arguments given as well, it returns the jaxpr itself (a jax core ClosedJaxpr). Note that, for some reason, jax.make_jaxpr does not have static_argnames like jit does.

  • With these, we can configure our pipeline transformers.

  • Not entirely sure right now which are useful or needed

  • these decorators/factory functions live in rubix.pipeline.transformer

a simple transformer decorator that binds arguments to a function

this shows the basic implementation; the decorators are available under rubix.pipeline.transformer in the package.

def bound_transformer(*args, **kwargs):
    """
    bound_transformer: create a jax.tree_util.Partial from an input
    function and the given positional and keyword arguments.
    Positional arguments are bound starting from the first, i.e., leftmost,
    one. If specific arguments should be bound, using keyword arguments
    may be advisable.
    """

    def transformer_wrap(kernel):
        return Partial(
            deepcopy(kernel), *deepcopy(args), **deepcopy(kwargs)
        )  # deepcopy to avoid context dependency

    return transformer_wrap
@bound_transformer(z=5, k=3.14)
def add(x, y, z: float = 0, k: float = 0):
    return x + y + z + k
type(add)
jax._src.tree_util.Partial
addjit = jax.jit(add)
x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)
addjit(x, x)
Array([14.14, 12.14, 10.14], dtype=float32)

Compiling transformer: jit individual elements and bind arguments to them with Partial

  • can be used for the final pipeline or for intermediate steps during debugging, or whatever else

  • combines a Partial that binds arguments with a jit that takes static args and kwargs. However, bound args and kwargs can NOT be static at the same time. In principle, we would want a partial of a jit here, but wouldn’t the overhead of the wrapper kind of defeat the purpose of the jit?

  • A solution to this would essentially yield a configurable jit factory.

  • I am not entirely sure why the code below works the way it does

  • not even entirely sure it is useful at all…

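def compiled_transformer(
    *args,
    static_args: tuple = (),
    static_kwargs: tuple = (),
    **kwargs,
):
    """
    compiled_transformer: bind args/kwargs to the kernel with a
    jax.tree_util.Partial and jit the result. static_args and static_kwargs
    are forwarded to jit as static_argnums and static_argnames.
    """

    def transformer_wrap(kernel):
        # deepcopy to avoid depending on the surrounding context
        return jit(
            Partial(deepcopy(kernel), *deepcopy(args), **deepcopy(kwargs)),
            static_argnums=deepcopy(static_args),
            static_argnames=deepcopy(static_kwargs),
        )

    return transformer_wrap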
@compiled_transformer(
    z=5,
    k=-3.14,
    static_kwargs=[
        "k",
    ],
)
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k
cond_add
<PjitFunction of Partial(<function cond_add at 0x7f0c9404e9e0>, z=5, k=-3.14)>
cond_add(x, x)
Array([7.8599997, 5.8599997, 3.86     ], dtype=float32)
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k

use on predefined functions without the decorator syntax

cond_add_plus = compiled_transformer(z=5, static_kwargs=["k"])(cond_add)
cond_add_plus
<PjitFunction of Partial(<function cond_add at 0x7f0c9404fac0>, z=5)>
cond_add_plus(x, x, k=-3.14)
Array([7.8599997, 5.8599997, 3.86     ], dtype=float32)

Problem: the compiled_transformer decorator cannot make args or kwargs that are bound to the function static, i.e., the configured parameters are not static here. This only works if the entire pipeline is compiled after assembling it. Not sure how to fix that at the moment, if at all.
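A small experiment that may help explain the cond_add behavior above (a sketch, not rubix code): when a Partial is itself the function being traced, its bound values are plain Python constants at trace time and end up as literals in the jaxpr. When the same Partial is instead passed as an argument into a traced function, jax treats it as a pytree whose bound values are leaves, so they get traced like any other input.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def cond_add(x, y, z=0.0, k=0.0):
    # Python control flow on k only works if k is a concrete value
    if k < 0:
        return x + y + z + k
    return x + y + z + 2 * k


x = jnp.array([3.0, 2.0, 1.0])
bound = Partial(cond_add, z=5.0, k=-3.14)

# Case 1: the Partial itself is traced. z and k are ordinary Python floats
# at trace time, so `if k < 0` works and only x and y are jaxpr inputs.
closed = jax.make_jaxpr(bound)(x, x)
n_inputs = len(closed.jaxpr.invars)  # 2


# Case 2: the Partial is passed *into* a traced function. Its bound values
# are pytree leaves, so z and k are traced and the Python `if` fails.
def apply(fn, a, b):
    return fn(a, b)


leaves = jax.tree_util.tree_leaves(bound)  # the bound values -3.14 and 5.0
try:
    jax.jit(apply)(bound, x, x)
    traced_k = False
except jax.errors.ConcretizationTypeError:
    traced_k = True
```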

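Expression-based decorator for getting the intermediate jaxpr object for inspection

  • make_jaxpr does not support static kwargs: it has static_argnums but no static_argnames, and keyword arguments passed at call time are traced. god knows why?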

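def expression_transformer(
    *args,
    static_args: tuple = (),
):
    """
    expression_transformer: build a jaxpr for the kernel with jax.make_jaxpr.
    If args are given, the jaxpr for these arguments is returned right away;
    otherwise, a function is returned that builds the jaxpr once it is called.
    """

    def transformer_wrap(kernel):
        if len(args) > 0:
            return jax.make_jaxpr(kernel, static_argnums=static_args)(*args)
        else:
            return jax.make_jaxpr(kernel, static_argnums=static_args)

    return transformer_wrap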
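As a quick check of the make_jaxpr behavior (a sketch assuming a reasonably recent jax version): keyword arguments given at call time are accepted but traced, so they show up as extra jaxpr inputs. To treat them as static, one option is to bind them first, here with functools.partial, which turns them into trace-time constants.

```python
import functools

import jax
import jax.numpy as jnp


def cond_add(x, y, z=0.0, k=0.0):
    if k < 0:
        return x + y + z + k
    return x + y + z + 2 * k


x = jnp.array([3.0, 2.0, 1.0])

# z passed as a keyword at call time is traced: it becomes a third input
expr_traced = jax.make_jaxpr(cond_add)(x, x, z=5.0)

# binding z and k first makes them constants, so the Python `if` on k
# works and only x and y remain inputs of the jaxpr
expr_bound = jax.make_jaxpr(functools.partial(cond_add, z=5.0, k=-3.14))(x, x)
print(expr_bound)
```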
@expression_transformer(x, x, 5, 3.14, static_args=[2, 3])
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k
cond_add
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = add a b
    d:f32[3] = add c 5.0
    e:f32[3] = add d 6.28000020980835
  in (e,) }

make sure to use the right static_args when using Python control flow, or use jax/lax control-flow primitives instead

@expression_transformer(x, x, 5, -3.14, static_args=[3])
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k
cond_add
{ lambda ; a:f32[3] b:f32[3] c:i32[]. let
    d:f32[3] = add a b
    e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
    f:f32[3] = add d e
    g:f32[3] = add f -3.140000104904175
  in (g,) }

without giving arguments, you get a function that produces an expression once arguments are supplied

@expression_transformer(static_args=[2, 3])
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0:
        return x + y + z + k
    else:
        return x + y + z + 2 * k
cond_add
<function jax.make_jaxpr(cond_add)(x, y, z: float = 0, k: float = 0)>
cond_add(x, x, 3, 2.71)
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = add a b
    d:f32[3] = add c 3.0
    e:f32[3] = add d 5.420000076293945
  in (e,) }

Define a number of simple, dumb transformers

  • we pretend that their second argument is something we want to configure from the start, and hence it should not be traced

  • we can use the above decorators to bind their second argument to a known value

def add(x, s: float):
    return x + s


def mult(x, m: float):
    return x * m


def div(x, d: float):
    return x / d


def sub(x, s: float):
    return x - s

Configuration files and pipeline building

General remarks about yaml

  • yaml format: here, the top level of the document is a dictionary

  • inside the dictionary one can arbitrarily nest lists and dicts.

  • yaml can be customized for node formats that are not provided by default, or for reading in custom types. Look up yaml tags for more.

  • there are yaml libraries for pretty much all common languages

Here, we use yaml in the following way:

  • the config file essentially builds the adjacency list of a DAG, but the current design is limited to one child per node => linear pipelines only, no branching

  • consequently, the build algorithm is limited to linear pipelines for the moment. Both must evolve together.

  • while a more general abstract base class is provided, we only implement a linear pipeline class LinearTransformerPipeline at the moment

  • the essential part is the Transformers node of the config. This is the actual DAG adjacency list, and it needs to adhere to the format outlined below.

  • you can add other nodes to configure other parts of your system: data directories and so on.

  • The starting point is always defined by a node that does not depend on another node (depends_on: null).

  • The end point is simply the last element in the pipeline
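Because every node names exactly one predecessor, the execution order of a linear pipeline can be recovered by starting at the node with depends_on: null and repeatedly following the unique child. The helper below is a hypothetical sketch of that idea, not the actual LinearTransformerPipeline implementation; the test data reuses the depends_on entries of the demo config that is read further down.

```python
def order_linear_pipeline(transformers: dict) -> list:
    """Hypothetical helper: return step names in execution order."""
    children = {}  # parent name -> its single child
    start = None
    for name, node in transformers.items():
        parent = node["depends_on"]
        if parent is None:
            if start is not None:
                raise ValueError("more than one starting node")
            start = name
        elif parent in children:
            raise ValueError(f"{parent!r} has more than one child (branching)")
        else:
            children[parent] = name
    if start is None:
        raise ValueError("no starting node (depends_on: null)")

    order = [start]
    while order[-1] in children:
        order.append(children[order[-1]])
    if len(order) != len(transformers):
        raise ValueError("config is not a single linear chain")
    return order


demo = {
    "A": {"name": "add", "depends_on": "B"},
    "X": {"name": "mult", "depends_on": "A"},
    "Z": {"name": "div", "depends_on": "X"},
    "B": {"name": "sub", "depends_on": "C"},
    "C": {"name": "add", "depends_on": None},
}
print(order_linear_pipeline(demo))  # ['C', 'B', 'A', 'X', 'Z']
```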

Config node structure:

name_of_pipeline_step:
    name: name_of_function_to_use_for_this_step
    depends_on: name_of_step_immediately_prior_in_pipeline
    args: # positional arguments that should be bound to the transformer function
      - value1
    kwargs: # keyword arguments that should be bound to the transformer function
      argument1: value1
      argument2: value2
      argumentN: valueN

the entries in args and kwargs will be used to create the Partial object, using the transformer decorators above.

Example

  Transformers:
    B:
      name: sub
      depends_on: C
      args: []
      kwargs:
        s: 2
    C:
      name: add
      depends_on: null
      args: []
      kwargs:
        s: 4

Here, C is the starting node, i.e., the first function in the pipeline. Whatever you do before that with your data does not concern the pipeline and hence has no influence on differentiability etc.

For a full example, see the demo.yml file.

Read yaml and build pipeline

read_yaml is available from rubix.utils and is very simple

read_cfg = read_yaml("./demo.yml")  # implemented in utils
read_cfg
{'Transformers': {'A': {'name': 'add',
   'depends_on': 'B',
   'args': [],
   'kwargs': {'s': 3.0}},
  'X': {'name': 'mult', 'depends_on': 'A', 'args': [], 'kwargs': {'m': 3}},
  'Z': {'name': 'div', 'depends_on': 'X', 'args': [], 'kwargs': {'d': 4}},
  'B': {'name': 'sub', 'depends_on': 'C', 'args': [], 'kwargs': {'s': 2}},
  'C': {'name': 'add', 'depends_on': None, 'args': [], 'kwargs': {'s': 4}}}}
type(read_cfg)
dict
read_cfg["Transformers"]
{'A': {'name': 'add', 'depends_on': 'B', 'args': [], 'kwargs': {'s': 3.0}},
 'X': {'name': 'mult', 'depends_on': 'A', 'args': [], 'kwargs': {'m': 3}},
 'Z': {'name': 'div', 'depends_on': 'X', 'args': [], 'kwargs': {'d': 4}},
 'B': {'name': 'sub', 'depends_on': 'C', 'args': [], 'kwargs': {'s': 2}},
 'C': {'name': 'add', 'depends_on': None, 'args': [], 'kwargs': {'s': 4}}}
type(read_cfg["Transformers"])
dict

Transformer functions need to be registered when the pipeline is created. If you have a fixed set of them, or many of them, it may make sense to write a factory function.

tp = ltp.LinearTransformerPipeline(read_cfg, [add, mult, div, sub])
tp.transformers
{'add': <function __main__.add(x, s: float)>,
 'mult': <function __main__.mult(x, m: float)>,
 'div': <function __main__.div(x, d: float)>,
 'sub': <function __main__.sub(x, s: float)>}

The transformers member gives us a dict of name: function pairs for the registered transformers. This registration currently has to happen before the pipeline is assembled, or the pipeline will not know what to assemble it from.
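One possible shape for such a factory, as a sketch: a hypothetical make_registry that keys each function by its __name__, mirroring the name: function pairs in tp.transformers. The constructor above takes a plain list, so list(registry.values()) would be one way to pass the result on; that usage is an assumption, not documented rubix API.

```python
from typing import Callable, Dict


def make_registry(*funcs: Callable) -> Dict[str, Callable]:
    """Hypothetical factory: map each transformer's __name__ to the function."""
    registry: Dict[str, Callable] = {}
    for func in funcs:
        name = func.__name__
        if name in registry:
            # two different kernels with the same name would shadow each other
            raise ValueError(f"duplicate transformer name: {name!r}")
        registry[name] = func
    return registry


def add(x, s: float):
    return x + s


def mult(x, m: float):
    return x * m


registry = make_registry(add, mult)
print(list(registry))  # ['add', 'mult']
```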

tp.assemble()
tp.pipeline
{'C': Partial(<function add at 0x7f0c900581f0>, s=4),
 'B': Partial(<function sub at 0x7f0c90058310>, s=2),
 'A': Partial(<function add at 0x7f0c900581f0>, s=3.0),
 'X': Partial(<function mult at 0x7f0c90058820>, m=3),
 'Z': Partial(<function div at 0x7f0c90058940>, d=4)}

Now we have a dict of jax Partials, in execution order, to which we can in principle apply all jax transformations, assuming the individual elements are well behaved. If this holds for the elements, it also holds for their composition, as long as the function used for the composition is itself purely functional.

x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)

The expression that a pipeline builds is a Partial object with the pipeline bound to it

tp.expression
Partial(<function LinearTransformerPipeline.build_expression.<locals>.expr at 0x7f0c90058700>, pipeline=[Partial(<function add at 0x7f0c900581f0>, s=4), Partial(<function sub at 0x7f0c90058310>, s=2), Partial(<function add at 0x7f0c900581f0>, s=3.0), Partial(<function mult at 0x7f0c90058820>, m=3), Partial(<function div at 0x7f0c90058940>, d=4)])

… it has the same signature as the first function in the pipeline.
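The printed expression suggests the idea: a function that applies the bound steps in order, wrapped in a Partial with the step list bound as pipeline. A minimal sketch under that assumption (the expr function here is a stand-in, not rubix’s build_expression) reproduces the numbers of the compiled pipeline. Since the result is again a Partial taking only x, it can itself serve as a step inside a larger pipeline, which is the "transformers made of smaller transformers" pattern from the beginning.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def add(x, s):
    return x + s


def sub(x, s):
    return x - s


def mult(x, m):
    return x * m


def div(x, d):
    return x / d


def expr(x, pipeline):
    """Stand-in for the composed expression: apply the steps in order."""
    for step in pipeline:
        x = step(x)
    return x


steps = [Partial(add, s=4), Partial(sub, s=2), Partial(add, s=3.0),
         Partial(mult, m=3), Partial(div, d=4)]
expression = Partial(expr, pipeline=steps)  # called as expression(x)

# a composed expression is itself a transformer, so it can be nested
inner = Partial(expr, pipeline=steps[:3])
nested = Partial(expr, pipeline=[inner] + steps[3:])

x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)
func = jax.jit(expression)
print(func(x))             # [6.   5.25 4.5 ]
print(jax.jit(nested)(x))  # same values
```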

func = tp.compile_expression()
func
<PjitFunction of Partial(<function LinearTransformerPipeline.build_expression.<locals>.expr at 0x7f0c90058700>, pipeline=[Partial(<function add at 0x7f0c900581f0>, s=4), Partial(<function sub at 0x7f0c90058310>, s=2), Partial(<function add at 0x7f0c900581f0>, s=3.0), Partial(<function mult at 0x7f0c90058820>, m=3), Partial(<function div at 0x7f0c90058940>, d=4)])>
func(x)
Array([6.  , 5.25, 4.5 ], dtype=float32)
div(mult(add(sub(add(x, s=4), s=2), s=3), m=3), d=4)
Array([6.  , 5.25, 4.5 ], dtype=float32)

… the output is the same. yay :)

expr = tp.get_jaxpr()(x)
expr
{ lambda ; a:f32[3]. let
    b:f32[3] = add a 4.0
    c:f32[3] = sub b 2.0
    d:f32[3] = add c 3.0
    e:f32[3] = mul d 3.0
    f:f32[3] = div e 4.0
  in (f,) }
type(expr)
jax._src.core.ClosedJaxpr
def func_manual(x):
    return div(mult(add(sub(add(x, s=4), s=2), s=3), m=3), d=4)
make_jaxpr(func_manual)(x)
{ lambda ; a:f32[3]. let
    b:f32[3] = add a 4.0
    c:f32[3] = sub b 2.0
    d:f32[3] = add c 3.0
    e:f32[3] = mul d 3.0
    f:f32[3] = div e 4.0
  in (f,) }

… and so are the expressions: JAX traces through the Python loop, so we don’t have to compose expressions ourselves. We should hence end up with something that is jax-transformable as long as its elements are jax-transformable. yay :)

just for completeness, we can apply a few more transformations to the compiled expression
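Instead of comparing printed jaxprs by eye, one can compare them programmatically. A small sketch with a two-step pipeline built from the same kind of add/mult kernels: the hypothetical primitives helper lists the primitive names of the traced equations via ClosedJaxpr.jaxpr.eqns, for both the loop-based composition and the hand-nested version.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def add(x, s):
    return x + s


def mult(x, m):
    return x * m


def looped(x, pipeline):
    for step in pipeline:
        x = step(x)
    return x


def manual(x):
    return mult(add(x, s=4), m=3)


def primitives(closed_jaxpr):
    """Hypothetical helper: names of the primitives in a traced jaxpr."""
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]


x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)
pipe = Partial(looped, pipeline=[Partial(add, s=4), Partial(mult, m=3)])

# the Python loop is unrolled during tracing, so both give the same equations
p_pipe = primitives(jax.make_jaxpr(pipe)(x))
p_manual = primitives(jax.make_jaxpr(manual)(x))
print(p_pipe, p_manual)  # ['add', 'mul'] ['add', 'mul']
```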

jax.jacfwd(tp.compile_expression())(x)
Array([[0.75, 0.  , 0.  ],
       [0.  , 0.75, 0.  ],
       [0.  , 0.  , 0.75]], dtype=float32)
jax.jacrev(tp.compile_expression())(x)
Array([[0.75, 0.  , 0.  ],
       [0.  , 0.75, 0.  ],
       [0.  , 0.  , 0.75]], dtype=float32)
jax.hessian(tp.compile_expression())(x)
Array([[[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]], dtype=float32)
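Why these numbers: with the configuration printed above, the demo pipeline is affine and acts element-wise, so its derivatives can be worked out by hand.

```latex
f(x) = \frac{3}{4}\,\big((x + 4) - 2 + 3\big) = \frac{3}{4}\,x + \frac{15}{4}

\frac{\partial f_i}{\partial x_j} = \frac{3}{4}\,\delta_{ij},
\qquad
\frac{\partial^2 f_i}{\partial x_j\,\partial x_k} = 0
```

For x = (3, 2, 1) this gives f(x) = (6, 5.25, 4.5), the same values func(x) returned, a Jacobian of 0.75 times the identity, and a Hessian that vanishes identically.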

Query individual elements

this is mainly useful for debugging

compile or get expressions for individual elements

tp.compile_element("A")
<PjitFunction of Partial(_HashableCallableShim(Partial(<function add at 0x7f0c900581f0>, s=3.0)))>
tp.get_jaxpr_for_element("A", x)
{ lambda ; a:f32[3]. let b:f32[3] = add a 3.0 in (b,) }

when building an expression without arguments, a function is returned that creates the expression once arguments are supplied

f = tp.get_jaxpr_for_element("A")
f
<function jax.<unnamed function>(x, *, s: float = 3.0)>
f(x)
{ lambda ; a:f32[3]. let b:f32[3] = add a 3.0 in (b,) }
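For debugging it can also help to look at intermediate values rather than expressions. Since tp.pipeline (printed above) is a dict of Partials in execution order, a hypothetical helper could run the steps one by one and record each output. The sketch below rebuilds the first three steps by hand instead of relying on the rubix API.

```python
import jax.numpy as jnp
from jax.tree_util import Partial


def add(x, s):
    return x + s


def sub(x, s):
    return x - s


def run_with_intermediates(named_steps, x):
    """Hypothetical debugging helper: apply steps in order, record each output."""
    outputs = {}
    for name, step in named_steps.items():
        x = step(x)
        outputs[name] = x
    return outputs


steps = {"C": Partial(add, s=4), "B": Partial(sub, s=2), "A": Partial(add, s=3.0)}
x = jnp.array([3.0, 2.0, 1.0])
outs = run_with_intermediates(steps, x)
for name, value in outs.items():
    print(name, value)  # C [7. 6. 5.], then B [5. 4. 3.], then A [8. 7. 6.]
```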

Alternative structures that allow for more complex systems

  • allow injecting new data at intermediate steps (multiple starting points): this turns the pipeline into an inverted tree.

  • allow a step to depend on multiple other steps: this turns the pipeline into a directed acyclic graph (DAG), a common structure in more general data processing systems.

=> if possible, use something simple like Partial to accomplish this
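A minimal sketch of how both extensions could still be built from plain Partials, assuming a hypothetical DAG config in which each node lists its parents and each starting node receives its own injected data. The evaluation order comes from graphlib.TopologicalSorter from the standard library (Python 3.9+). This is not something rubix currently supports.

```python
from graphlib import TopologicalSorter

import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def add(x, s):
    return x + s


def mult(x, m):
    return x * m


def combine(a, b):
    # a step with two parents: inputs arrive in the order listed in deps
    return a + b


# hypothetical DAG config: node -> (bound transformer, list of parent nodes);
# nodes without parents are starting points that receive injected data
dag = {
    "left": (Partial(add, s=1.0), []),
    "right": (Partial(mult, m=2.0), []),
    "merge": (Partial(combine), ["left", "right"]),
}


def run_dag(dag, sources):
    graph = {name: deps for name, (_, deps) in dag.items()}
    results = {}
    for name in TopologicalSorter(graph).static_order():
        fn, deps = dag[name]
        inputs = [results[d] for d in deps] if deps else [sources[name]]
        results[name] = fn(*inputs)
    return results


sources = {"left": jnp.array([1.0, 2.0, 3.0]), "right": jnp.array([0.5, 1.0, 1.5])}
# closing over the (static) graph keeps only the injected data traced
out = jax.jit(lambda s: run_dag(dag, s))(sources)
print(out["merge"])  # [3. 5. 7.]
```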

Tentative best practices

  • think in small steps: a more granular pipeline is easier to write in a purely functional style, easier to reason about and probably also easier to optimize.

  • a more granular system is also easier to test and extend

  • ideally, write the pipeline such that it can be compiled all at once with compile_expression.

Summary

  • the pipeline produces the same jax code as the handwritten version. This seems encouraging.

  • at which points do we still need to ensure purely functional behavior?

  • how will we enforce transformer compatibility?

  • this is a pathologically simple case, hence not representative of real-world scenarios

  • when does it break?

  • what use cases are not covered?

  • what else do you need?