{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Concept of the pipeline\n", "## Basic ideas\n", "- rubix essentially implements a big data transformation pipeline. \n", "\n", "- a pipeline is composed of nodes that are stored in a list in execution order (or, more generally, in a DAG [not supported currently]). Each node is called a transformer. \n", "\n", "- each step in this pipeline (i.e., each transformer) can itself be seen as being composed of other, smaller transformers. This gives us a pattern that can be used to guide the implementation of transformers\n", "\n", "- a simple implementation lives in `rubix.pipeline`\n", "\n", "## Restrictions\n", "- jax is purely functional. Anything that needs to be transformed with jax has to be a pure function. \n", "Anything that comes from the environment must be explicitly copied into the function or be bound to it such that the internal state of the function is self-contained. \n", "\n", "- It's irrelevant what builds these pure functions. Therefore, we use a factory pattern to do all configuration work like reading files, pulling data from the net, providing any function arguments to be used in the pipeline and so on. A factory then produces a pure function that contains all the data we need as static arguments and retains only the values it computes on as traceable arguments. \n", "\n", "- we can leverage [`jax.tree_util.Partial`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.Partial.html) for this, which works like `functools.partial` but is compatible with jax transformations. Note that stateful objects can still be used internally as long as nothing from an outer scope (that may change over time) is read or written. 
This is the user's responsibility. \n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from copy import deepcopy" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "from jax import make_jaxpr\n", "from jax.tree_util import Partial\n", "from jax import jit, grad" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from rubix.pipeline import linear_pipeline as ltp\n", "from rubix.pipeline import transformer as rtr\n", "from rubix.utils import read_yaml" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build some simple decorators for function configuration\n", "- leverages `jax.tree_util.Partial`\n", "- builds a partial object to which jax transformations can be applied \n", "- three cases: \n", "    - build the pure function object: you have to take care of static args/kwargs yourself when calling jit. The decorator only builds the function object\n", "    - jit it right away: the usual. Here you can tell it which args/kwargs to trace or not with the `static_args` and `static_kwargs` keyword arguments\n", "    - build expression: mainly to check what comes out at the end or at intermediate steps. Can build a jax expression (with no arguments) or a jax core expression (when arguments are given as well). Note that for some reason, `jax.make_jaxpr` does not have `static_argnames` like `jit` does. \n", "- With these, we can configure our pipeline transformers. 
\n", "- Not entirely sure right now which are useful or needed\n", "- these decorators/factory functions live in `rubix.pipeline.transformer`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**simple transformer decorator that binds function to arguments** \n", "\n", "this shows the basic implementation; the decorators are available under `rubix.pipeline.transformer` in the package.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def bound_transformer(*args, **kwargs):\n", "    \"\"\"\n", "    bound_transformer Create a jax.tree_util.Partial function from an input\n", "    function and given arguments and keyword arguments.\n", "    The user must take care that arguments are bound starting from the first,\n", "    i.e., leftmost. If specific arguments should be bound, using keyword\n", "    arguments may be advisable.\n", "    \"\"\"\n", "\n", "    def transformer_wrap(kernel):\n", "        return Partial(\n", "            deepcopy(kernel), *deepcopy(args), **deepcopy(kwargs)\n", "        )  # deepcopy to avoid context dependency\n", "\n", "    return transformer_wrap" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@bound_transformer(z=5, k=3.14)\n", "def add(x, y, z: float = 0, k: float = 0):\n", "    return x + y + z + k" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "type(add)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "addjit = jax.jit(add)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "addjit(x, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Compiling transformer to jit individual elements and bind them to traceable partial arguments\n", "\n", "- can be used for the final pipeline or for intermediate 
steps during debugging\n", "- builds a `Partial` to bind arguments, which is then jitted with static args and kwargs. However, bound args and kwargs can **NOT** be static at the same time. In principle, we would want a partial of a jit here, which kind of defeats the purpose of the jit because of the overhead of the wrapper? \n", "- A solution to this would essentially yield a configurable jit factory. \n", "- I am not entirely sure why the below works the way it does\n", "- not even entirely sure it is useful at all... " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def compiled_transformer(\n", "    *args,\n", "    static_args: list = [],\n", "    static_kwargs: list = [],\n", "    **kwargs,\n", "):\n", "\n", "    def transformer_wrap(kernel):\n", "\n", "        return jit(\n", "            Partial(deepcopy(kernel), *deepcopy(args), **deepcopy(kwargs)),\n", "            static_argnums=deepcopy(static_args),\n", "            static_argnames=deepcopy(static_kwargs),\n", "        )\n", "\n", "    return transformer_wrap" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@compiled_transformer(\n", "    z=5,\n", "    k=-3.14,\n", "    static_kwargs=[\n", "        \"k\",\n", "    ],\n", ")\n", "def cond_add(x, y, z: float = 0, k: float = 0):\n", "    if k < 0:\n", "        return x + y + z + k\n", "    else:\n", "        return x + y + z + 2 * k" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cond_add" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cond_add(x, x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def cond_add(x, y, z: float = 0, k: float = 0):\n", "    if k < 0:\n", "        return x + y + z + k\n", "    else:\n", "        return x + y + z + 2 * k" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "use on predefined functions without the decorator syntax" ] }, { "cell_type": "code", "execution_count": null, "metadata": 
{}, "outputs": [], "source": [ "cond_add_plus = compiled_transformer(z=5, static_kwargs=[\"k\"])(cond_add)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cond_add_plus" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cond_add_plus(x, x, k=-3.14)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Problem**: the `compiled_transformer` decorator cannot make args or kwargs static that are bound to the function, i.e., configured parameters are not static here. This only works if the entire pipeline is compiled after assembling it. Not sure how to fix that at the moment, if at all." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Expression based decorator for getting out the intermediate `jaxpr` object for inspection\n", "- `make_jaxpr` has no `static_argnames`, so static keyword arguments are not supported here. Not clear why." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def expression_transformer(\n", "    *args,\n", "    static_args: list = [],\n", "):\n", "\n", "    def transformer_wrap(kernel):\n", "        if len(args) > 0:\n", "            return jax.make_jaxpr(kernel, static_argnums=static_args)(*args)\n", "        else:\n", "            return jax.make_jaxpr(kernel, static_argnums=static_args)\n", "\n", "    return transformer_wrap" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@expression_transformer(x, x, 5, 3.14, static_args=[2, 3])\n", "def cond_add(x, y, z: float = 0, k: float = 0):\n", "    if k < 0:\n", "        return x + y + z + k\n", "    else:\n", "        return x + y + z + 2 * k" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cond_add" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "make sure to use the right `static_args` when doing control flow, or use jax/lax primitives" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"@expression_transformer(x, x, 5, -3.14, static_args=[3])\n", "def cond_add(x, y, z: float = 0, k: float = 0):\n", " if k < 0:\n", " return x + y + z + k\n", " else:\n", " return x + y + z + 2 * k" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cond_add" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "without giving arguments you get out a function that produces an expression when arguments are added" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@expression_transformer(static_args=[2, 3])\n", "def cond_add(x, y, z: float = 0, k: float = 0):\n", " if k < 0:\n", " return x + y + z + k\n", " else:\n", " return x + y + z + 2 * k" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cond_add" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cond_add(x, x, 3, 2.71)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define a number of simple, dump transformers\n", "- we pretend that their second value is something we want to configure from the start and hence it should not be traced\n", "\n", "- we can use the above decorators to bind their second arg to something we know" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def add(x, s: float):\n", " return x + s\n", "\n", "\n", "def mult(x, m: float):\n", " return x * m\n", "\n", "\n", "def div(x, d: float):\n", " return x / d\n", "\n", "\n", "def sub(x, s: float):\n", " return x - s" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Configuration files and pipeline building\n", "\n", "### General remarks about yaml\n", "- yaml format: dictionary \n", "- inside the dictionary one can arbitrarily nest lists, dicts. \n", "- yaml is customizable for node formats that are not provided by default, or for reading in types. Look up yaml-tags for more. 
\n", "- there are yaml libraries for pretty much all common languages\n", "\n", "### Here, we use yaml in the following way: \n", "- the config file essentially builds an adjacency list of a DAG, but currently the design is limited to only one child per node => linear pipelines only, no branching\n", "- consequently, the build algorithm is limited to linear pipelines for the moment. Both must evolve together.\n", "- while a more general abstract base class is provided, we only implement a linear pipeline class `LinearTransformerPipeline` at the moment\n", "- the essential part is the `Transformers` node of the config. This is the actual DAG adjacency list. It needs to adhere to the format outlined below.\n", "- you can add other nodes to configure other parts of your system: data directories and so on.\n", "- The starting point is always defined by a node that does not depend on another node. \n", "- The stop point is just the last element in the pipeline\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Config node structure: \n", "\n", "\n", "```yaml\n", "name_of_pipeline_step:\n", "\n", "  name: name_of_function to use for this step\n", "\n", "  depends_on: name_of_step_immediately_prior_in_pipeline \n", "  \n", "  args: # arguments to the transformer functions that should be bound to it\n", "\n", "    argument1: value1 \n", "\n", "    argument2: value2 \n", "\n", "    argumentN: valueN\n", "\n", "```\n", "\n", "the arguments in `args` will be used to create the `Partial` object, using the transformer decorators above.\n", "\n", "\n", "**Example** \n", "```yaml \n", "  B: \n", "    name: sub\n", "    depends_on: C\n", "    args:\n", "      s: 2\n", "  C:\n", "    name: add\n", "    depends_on: null\n", "    args:\n", "      s: 4\n", "\n", "```\n", "Here, `C` is the starting node, i.e., the first function in the pipeline. \n", "Whatever you do before that with your data does not concern the pipeline and hence has no influence on differentiability etc. 
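\n",
"\n",
"A minimal sketch of how the execution order could be recovered from such a config, in plain Python. This is an illustration only, not the `rubix` implementation: `order_linear_pipeline` is a hypothetical helper, and it assumes exactly one start node and at most one child per node.\n",
"\n",
"```python\n",
"def order_linear_pipeline(transformers_cfg):\n",
"    # map each parent step to the single step that depends on it\n",
"    child_of = {}\n",
"    start = None\n",
"    for step, node in transformers_cfg.items():\n",
"        parent = node[\"depends_on\"]\n",
"        if parent is None:\n",
"            if start is not None:\n",
"                raise ValueError(\"more than one start node\")\n",
"            start = step\n",
"        elif parent in child_of:\n",
"            raise ValueError(f\"{parent} has more than one child: not linear\")\n",
"        else:\n",
"            child_of[parent] = step\n",
"    if start is None:\n",
"        raise ValueError(\"no start node: one `depends_on` must be null\")\n",
"\n",
"    # walk the chain from the start node to the last step\n",
"    order = [start]\n",
"    while order[-1] in child_of:\n",
"        order.append(child_of[order[-1]])\n",
"    if len(order) != len(transformers_cfg):\n",
"        raise ValueError(\"some steps are not reachable from the start node\")\n",
"    return order\n",
"\n",
"\n",
"# the example config from above, written as the dict a yaml reader would return\n",
"cfg = {\n",
"    \"B\": {\"name\": \"sub\", \"depends_on\": \"C\", \"args\": {\"s\": 2}},\n",
"    \"C\": {\"name\": \"add\", \"depends_on\": None, \"args\": {\"s\": 4}},\n",
"}\n",
"print(order_linear_pipeline(cfg))\n",
"```\n",
"\n",
"With the example config this yields `C` first and `B` second. 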
\n", "\n", "For a full example, see the `demo.yml` file.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Read yaml and build pipeline \n", "\n", "`read_yaml` is available from `rubix.utils` and is very simple\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "read_cfg = read_yaml(\"./demo.yml\") # implemented in utils" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "read_cfg" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "type(read_cfg)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "read_cfg[\"Transformers\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "type(read_cfg[\"Transformers\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Transformers need to be registered upon creation. If you have a fixed set or many of them, it may make sense to write a factory function. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tp = ltp.LinearTransformerPipeline(read_cfg, [add, mult, div, sub])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tp.transformers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `transformers` member gives us a dict of `name: function` pairs for the transformers \n", "This currently has to be done before the assembly of the pipeline, or the pipeline will not know what to assemble it from" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tp.assemble()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tp.pipeline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we have a list of jax `Partial`s to which we can apply, assuming the individual elements are well behaved, all jax transformations in principle. If this is true for the elements, then it is true for the composition as long as the function we use for composition is pure functional itself" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The expression that a pipeline builds is a partial object that is bound to the pipeline " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tp.expression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "... it has the same signature as the first function in the pipeline." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "func = tp.compile_expression()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "func" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "func(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "div(mult(add(sub(add(x, s=4), s=2), s=3), m=3), d=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "... output's the same. yay :) " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "expr = tp.get_jaxpr()(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "expr" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "type(expr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def func_manual(x):\n", " return div(mult(add(sub(add(x, s=4), s=2), s=3), m=3), d=4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "make_jaxpr(func_manual)(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "... expressions are too, because JAX is smart enough to trace across loops and we don't have to mess with expression composition ourselves. We hence should end up with something that's jax transformable if its elements are jax transformable. 
yay :) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "just for completeness, we can mess a bit more with the expression stuff" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "jax.jacfwd(tp.compile_expression())(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "jax.jacrev(tp.compile_expression())(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "jax.hessian(tp.compile_expression())(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Query individual elements \n", "\n", "this is mainly useful for debugging" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "compile or get expressions for individual elements" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tp.compile_element(\"A\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tp.get_jaxpr_for_element(\"A\", x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "when building an expression with no arguments, a function is returned that creates an expression once args are added" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f = tp.get_jaxpr_for_element(\"A\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Alternative structures that allow for more complex systems\n", "\n", "- allow to inject new data at intermediate steps: multiple starting points: transforms the pipeline into an inverted tree. \n", "\n", "- allow for a step to depend on multiple other steps: transforms the pipeline into a directed acyclic graph. 
This is a common structure in more general data processing systems. \n", "\n", "=> if possible use something simple like `Partial` to accomplish this " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tentative best practices\n", "\n", "- think in small steps: a more granular pipeline is easier to write in a pure functional style, easier to reason about and probably also easier to optimize. \n", "- A more granular system is also easier to test and extend \n", "- ideally write the pipeline such that it can be compiled all at once with `compile_expression`. \n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary \n", "- the pipeline produces the same jax code as the handwritten version. This seems encouraging.\n", "- at which points do we still need to ensure pure functional behavior?\n", "- how will we enforce transformer compatibility?\n", "- this is a pathologically simple case, hence not representative of real-world scenarios\n", "- when does it break? \n", "- what use cases are not covered?\n", "- what else do you need? \n" ] } ], "metadata": { "kernelspec": { "display_name": "rubix", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }