(jax-extend-jep)=
# `jax.extend`: a module for extensions

[@froystig](https://github.com/froystig),
[@sharadmv](https://github.com/sharadmv),
[@jakevdp](https://github.com/jakevdp),
[@yashk2810](https://github.com/yashk2810)

May 2023

```python
import jax.extend as jex
```

Several projects depend on JAX's codebase internals, often to use its
core machinery (e.g. to write a
[transformation over its IR](https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html))
or to extend it (e.g. to
[define new primitives](https://github.com/dfm/extending-jax)).
Two challenges for these dependencies are (a) that our internals
aren't all solidly designed for external use, and (b) that
circumventing JAX's public API is
[unsupported](https://jax.readthedocs.io/en/latest/api_compatibility.html).
In other words, our internals are often used like a library, but are
neither structured nor updated like one.

This proposal considers **introducing a `jax.extend` module that
defines a library view of some of JAX's internal components**. We would
treat this as a second-tier API, still carrying essentially [no
compatibility policy](#no-compatibility-policy), but hopefully making
it easier to spot changes when they happen.

The audience for `jax.extend` includes JAX-adjacent Python libraries
like [Oryx](https://github.com/jax-ml/oryx),
[jax-triton](https://github.com/jax-ml/jax-triton), and many others,
as well as projects experimenting with function transformations,
autodiff systems, compiler frontends for numerical programming, etc.

This note gives an overview of how `jax.extend` might look, now and
eventually. It doesn't lay things out in great detail, instead
proposing that we begin [iteratively developing](#iterative-development)
the module.

Note that `jax.extend` differs from `jax.experimental`, which is a
staging ground for new features and ideas in progress. Typically, work
in `jax.experimental` eventually makes it into another JAX module or is
removed altogether.

## No compatibility policy

To keep development overhead low, `jax.extend` would not follow the
public
[API compatibility](https://jax.readthedocs.io/en/latest/api_compatibility.html)
policy. It would promise neither deprecation windows nor backwards
compatibility between releases. Every release may break existing
callers without simple recourse (e.g. without a flag reintroducing
prior behavior). We would rely on the
[changelog](https://jax.readthedocs.io/en/latest/changelog.html)
to call out such changes.

Callers of `jax.extend` that need to upgrade their code regularly
alongside JAX releases might find it useful to pin JAX versions as an
intermediate step between releases. This is a common habit among
projects that rely on JAX's internals today. The difference is that it
would now come with the help of changelog announcements and better
intentions regarding library design and naming.

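As a rough illustration of the pinning habit described above, a downstream project might guard its use of `jax.extend` with a check against the range of JAX releases it has been tested with. This is only a sketch: the helper names `parse_release` and `is_tested_version` and the version bounds are hypothetical and are not part of JAX or of this proposal.

```python
import re


def parse_release(version: str) -> tuple:
    """Return the leading numeric components of a version string.

    Hypothetical helper: "0.4.14.dev20230830" becomes (0, 4, 14).
    """
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def is_tested_version(version: str, lower: str, upper: str) -> bool:
    """True if lower <= version < upper, comparing numeric components."""
    return parse_release(lower) <= parse_release(version) < parse_release(upper)


# Illustrative bounds only; a real project would choose its own.
print(is_tested_version("0.4.14", "0.4.13", "0.4.15"))  # True
print(is_tested_version("0.4.15", "0.4.13", "0.4.15"))  # False
```

In practice the first argument would be `jax.__version__`, and the same bounds would typically also appear in the project's dependency specification.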
## Iterative development

Having no compatibility policy makes it easier to get started on
implementation: on day one, we can move a handful of symbols over from
internal packages such as `jax._src` and today's `jax.core` and
`jax.interpreters`. Then we can iterate to improve things from there.

## Possible module overview

We can imagine that eventually `jax.extend` would include the
following modules:

* `core` – primitives, the Jaxpr IR, etc.
* `interpreters` – core transformations (e.g. autodiff, batching)
  and lowerings.
* `random` – random bit generation, key splitting and folding, key
  arrays.
* `sharding` – extra functionality around distributed arrays.

We might also have other symbols in the module at first, such as
`jex.api_util`, as we work to remove or replace them. Others will be
decided in time. For instance, `jex.lib` could offer an entry point to
jaxlib (and would do so in the immediate term), but it's not clear
whether we want to keep it for long.

Some preliminary thoughts on what each of these might comprise follow.

### `jax.extend.core`

This should enable callers at least to define new JAX primitives and
to process the Jaxpr IR (the output of
`jax.make_jaxpr(...)`). Supporting this might involve providing:

* Access to existing core system primitives, such as today's
  `jax._src.lax.add_p`.
* Access to IR types, such as the current `jax._src.core.ShapedArray`.
* Functions for checking and pretty-printing jaxprs.
* Functions for building jaxprs explicitly, rather than by staging
  Python functions via `jax.make_jaxpr` (or not!).

At initialization, this module will contain many more symbols than
what's needed to define primitives and rules, including various names
used in setting up
["final-style transformations"](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing),
such as the current `jax._src.core.Trace` and `Tracer` classes. We can
revisit whether `jex.core` should also support final-style extensions
alongside initial-style approaches, and whether it can do so through a
narrower API than exposing `Trace` and `Tracer` entirely.
[Oryx](https://github.com/jax-ml/oryx) might help guide these decisions.

We can also consider relocating `make_jaxpr` itself to `jex.core`.

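To make "processing the Jaxpr IR" concrete, here is a minimal sketch that stages a function with the existing public `jax.make_jaxpr` and walks the resulting equations. It uses only today's public API rather than any proposed `jex.core` symbol, and the function `f` is made up for illustration.

```python
import jax
from jax import lax


def f(x):
    # Call lax functions directly so that each call maps to one equation.
    return lax.add(lax.mul(lax.sin(x), x), x)


# Stage the Python function out to a ClosedJaxpr.
closed_jaxpr = jax.make_jaxpr(f)(1.0)

# Walk the equations in order and record the name of each primitive.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)  # ['sin', 'mul', 'add']

# Jaxprs also pretty-print when converted to a string.
print(closed_jaxpr)
```

A library that transforms or analyzes jaxprs would do the same kind of traversal, dispatching on `eqn.primitive` and reading `eqn.invars`, `eqn.outvars`, and `eqn.params`.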
### `jax.extend.interpreters`

This module would provide a means of registering various
transformation rules for primitives, defining their behavior
under AD, batching, lowering, etc.

It would initially reflect `jax._src.interpreters` in providing
the modules `ad`, `batching`, `partial_eval` (for staging Python to
Jaxpr, and for linearization in AD), `mlir`, `pxla`, and `xla`. The
first three might be replaceable by a single primitive extension API
in `jex.core`. The latter three, used for lowering, could be
simplified into one module, maybe.

Today, to write transformation rules, e.g. for AD and batching,
callers may need symbols relating to tracers, e.g. `JVPTracer` and
`BatchTracer`. This may be avoidable later on, which would allow us to
remove tracer types from `jex`.

This module plus `jex.core` ought to suffice for replicating today's
custom primitive tutorials (e.g.
[ours](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html)
and
[dfm's](https://github.com/dfm/extending-jax)).
For instance, defining a primitive and its behavior under `jax.jit`
would be possible as follows (in the immediate term):

```python
from jax.extend import core	         # Previously: from jax import core
from jax.extend.interpreters import mlir        # ... and similarly

# Define the primitive and its eager (op-by-op) implementation.
mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)

# Abstract evaluation: the output takes the first input's shape and dtype.
@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
  return core.ShapedArray(x_sa.shape, x_sa.dtype)

# Lowering rule: emit a multiply op followed by an add op.
def mul_add_mlir(ctx, xc, yc, zc):
  add = mlir.hlo.AddOp
  mul = mlir.hlo.MulOp
  return add(mul(xc, yc), zc).results

mlir.register_lowering(mul_add_p, mul_add_mlir)

import jax
print(mul_add_p.bind(2, 3, 4))            # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)
```

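The example above registers only a lowering rule. As a hedged sketch of what registering an AD rule looks like today, the following defines a similar primitive and attaches a JVP rule through the existing `jax.interpreters.ad.primitive_jvps` table. The names `mul_add_sketch`, `mul_add`, and `mul_add_jvp` are made up for illustration, and the import fallback is an assumption about where `Primitive` lives in different JAX releases; none of this is the API that `jex.interpreters` would necessarily end up exposing.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import ad

try:
    from jax.extend.core import Primitive  # assumed location in newer releases
except ImportError:
    from jax.core import Primitive  # assumed location in older releases

# A separate primitive, named differently from the one above.
mul_add_p = Primitive("mul_add_sketch")
mul_add_p.def_impl(lambda x, y, z: x * y + z)


def mul_add(x, y, z):
    return mul_add_p.bind(x, y, z)


def mul_add_jvp(primals, tangents):
    # d(x * y + z) = dx * y + x * dy + dz
    # This sketch assumes every tangent is a concrete value; it does not
    # handle symbolic zero tangents.
    x, y, z = primals
    dx, dy, dz = tangents
    return mul_add(x, y, z), dx * y + x * dy + dz


ad.primitive_jvps[mul_add_p] = mul_add_jvp

args = (jnp.float32(2.0), jnp.float32(3.0), jnp.float32(4.0))
tangents = (jnp.float32(1.0), jnp.float32(0.0), jnp.float32(0.0))
primal_out, tangent_out = jax.jvp(mul_add, args, tangents)
print(primal_out, tangent_out)
```

With a tangent of 1 in the `x` direction only, the output tangent equals `y`, which is a quick way to sanity-check the rule.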
### `jax.extend.random`

This module could expose our mechanism for defining new RNG
implementations, and functions for working with PRNG key internals
(see issue [#9263](https://github.com/google/jax/issues/9263)),
such as the current `jax._src.prng.random_wrap` and
`random_unwrap`.

It could also expose the keyed hash functions that underlie the
built-in RNG implementations, such as `jax._src.prng.threefry_2x32`.

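To illustrate what "working with PRNG key internals" can mean, here is a small sketch using the public functions `jax.random.key`, `jax.random.key_data`, and `jax.random.wrap_key_data`, which are available in more recent JAX releases. It assumes these play a role similar to the internal `random_unwrap` and `random_wrap` mentioned above; it does not show what `jex.random` itself would expose.

```python
import jax
import jax.numpy as jnp

# A typed PRNG key using the default RNG implementation.
key = jax.random.key(0)

# Unwrap: view the raw integer data underlying the key.
raw = jax.random.key_data(key)

# Wrap: rebuild a typed key from the raw data.
rewrapped = jax.random.wrap_key_data(raw)

# The round trip preserves the key data.
same = bool(jnp.array_equal(jax.random.key_data(rewrapped), raw))
print(raw.dtype, same)
```

Sampling with `key` and with `rewrapped` should then give identical results, since both carry the same underlying data.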
### `jax.extend.sharding`

This module could expose low-level utilities for sharding distributed
arrays.

We have only one item in mind for now. The XLA compiler's
array sharding format is more expressive than [those provided by
JAX](https://jax.readthedocs.io/en/latest/jax.sharding.html). We could
provide this as `jex.sharding.XlaOpShardingProto`, corresponding to
today's `jax._src.lib.xla_client.OpSharding` internally.