(jax-extend-jep)=
# `jax.extend`: a module for extensions
[@froystig](https://github.com/froystig),
[@sharadmv](https://github.com/sharadmv),
[@jakevdp](https://github.com/jakevdp),
[@yashk2810](https://github.com/yashk2810)

May 2023
```python
import jax.extend as jex
```
Several projects depend on JAX's codebase internals, often to use its
core machinery (e.g. to write a
[transformation over its IR](https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html))
or to extend it (e.g. to
[define new primitives](https://github.com/dfm/extending-jax)).
Two challenges for these dependencies are (a) that our internals
aren't all solidly designed for external use, and (b) that
circumventing JAX's public API is
[unsupported](https://jax.readthedocs.io/en/latest/api_compatibility.html).
In other words, our internals are often used like a library, but are
neither structured nor updated like one.

This proposal considers **introducing a `jax.extend` module that
defines a library view of some of JAX's internal components**. We would
treat this as a second-tier API, still guaranteeing essentially [no
compatibility policy](#no-compatibility-policy), but hopefully making
it easier to spot changes when they happen.

The audience for `jax.extend` includes JAX-adjacent Python libraries
like [Oryx](https://github.com/jax-ml/oryx),
[jax-triton](https://github.com/jax-ml/jax-triton), and many others,
as well as projects experimenting with function transformations,
autodiff systems, compiler frontends for numerical programming, etc.

This note gives an overview of how `jax.extend` might look, now and
eventually. It doesn't lay things out in great detail, instead
proposing that we begin [iteratively developing](#iterative-development)
the module.

Note that `jax.extend` differs from `jax.experimental`, which is a
staging ground for new features and ideas in progress. Typically, work
in `jax.experimental` eventually makes it into another JAX module or is
removed altogether.
## No compatibility policy
To keep development overhead low, `jax.extend` would not follow the
public
[API compatibility](https://jax.readthedocs.io/en/latest/api_compatibility.html)
policy. It would promise neither deprecation windows nor backwards
compatibility between releases. Every release may break existing
callers without simple recourse (e.g. without a flag reintroducing
prior behavior). We would rely on the
[changelog](https://jax.readthedocs.io/en/latest/changelog.html)
to call out such changes.

Callers of `jax.extend` that need to upgrade their code regularly
alongside JAX releases might find it useful to pin JAX versions as an
intermediate step between releases. This is a common habit among
projects that rely on JAX's internals today. The difference is that it
would now come with the help of changelog announcements and better
intentions regarding library design and naming.
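
One lightweight way to make such a pin visible in code is a version guard at import time. The following is only a sketch and not part of the proposal: the helper name `check_jax_pin` and the version bounds are hypothetical, and it assumes that `jax.__version_info__` is a tuple of integers.

```python
import warnings

import jax


def check_jax_pin(version_info, lower, upper):
  """Return True if lower <= version_info < upper (all tuples of ints)."""
  return tuple(lower) <= tuple(version_info) < tuple(upper)


# Hypothetical bounds, chosen only for illustration: the range of JAX
# releases that this (imaginary) downstream library has been tested against.
TESTED_FROM = (0, 4, 0)
TESTED_BELOW = (1, 0, 0)

if not check_jax_pin(jax.__version_info__, TESTED_FROM, TESTED_BELOW):
  warnings.warn(
      f"jax {jax.__version__} is outside the tested range; "
      "symbols imported from jax.extend may have changed.")
```

A warning rather than an error keeps the guard advisory; a stricter project might pin the version in its packaging metadata instead.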
## Iterative development
Having no compatibility policy makes it easier to get started on
implementation: on day one, we can move a handful of symbols over from
internal packages such as `jax._src` and today's `jax.core` and
`jax.interpreters`. Then we can iterate to improve things from there.
## Possible module overview
We can imagine that eventually `jax.extend` would include the
following modules:
* `core`: primitives, the Jaxpr IR, etc.
* `interpreters`: core transformations (e.g. autodiff, batching)
  and lowerings.
* `random`: random bit generation, key splitting and folding, key
  arrays.
* `sharding`: extra functionality around distributed arrays.

We might also have other symbols in the module at first, such as
`jex.api_util`, as we work to remove or replace them. Others will be
decided in time. For instance, `jex.lib` could offer an entry point to
jaxlib (and would do so in the immediate term), but it's not clear
whether we want to keep it for long.

Some preliminary thoughts on what each of these might comprise follow.
### `jax.extend.core`
This should enable callers at least to define new JAX primitives and
to process the Jaxpr IR (the output of
`jax.make_jaxpr(...)`). Supporting this might involve providing:
* Access to existing core system primitives, such as today's
  `jax._src.lax.add_p`.
* Access to IR types, such as the current `jax._src.core.ShapedArray`.
* Functions for checking and pretty-printing jaxprs.
* Functions for building jaxprs explicitly, rather than by staging
  Python functions via `jax.make_jaxpr` (or not!).

At initialization, this module will contain many more symbols than
what's needed to define primitives and rules, including various names
used in setting up
["final-style transformations"](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing),
such as the current `jax._src.core.Trace` and `Tracer` classes. We can
revisit whether `jex.core` should also support final-style extensions
alongside initial-style approaches, and whether it can do so through a
narrower API than exposing `Trace` and `Tracer` entirely.
[Oryx](https://github.com/jax-ml/oryx) might help guide these decisions.

We can also consider relocating `make_jaxpr` itself to `jex.core`.
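
To make "processing the Jaxpr IR" concrete, here is a minimal sketch that stages a small function with today's public `jax.make_jaxpr` and walks the resulting equations. The function `f` is a made-up example. The sketch relies only on the `jaxpr.eqns` and `eqn.primitive.name` attributes of the staged result, and it does not use any `jax.extend` symbol.

```python
import jax
from jax import lax


def f(x):
  # Two primitive applications: sin, then mul by a constant.
  return lax.mul(lax.sin(x), 2.0)


# Stage the Python function out to a (closed) jaxpr.
closed_jaxpr = jax.make_jaxpr(f)(1.0)

# Walk the equations and record which primitive each one applies.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]

print(closed_jaxpr)
print(primitive_names)
```

A custom interpreter would typically replace the list comprehension with a loop that looks up a rule for each `eqn.primitive` and applies it to the equation's inputs.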
### `jax.extend.interpreters`
This module would provide a means of registering various
transformation rules for primitives---defining their behavior
under AD, batching, lowering, etc.

It would initially reflect `jax._src.interpreters` in providing
the modules `ad`, `batching`, `partial_eval` (for staging Python to
Jaxpr, and for linearization in AD), `mlir`, `pxla`, and `xla`. The
first three might be replaceable by a single primitive extension API
in `jex.core`. The latter three, used for lowering, could be
simplified into one module, maybe.

Today, to write transformation rules, e.g. for AD and batching,
callers may need symbols relating to tracers, e.g. `JVPTracer` and
`BatchTracer`. This may be avoidable later on, which would allow us to
remove tracer types from `jex`.

This module plus `jex.core` ought to suffice for replicating today's
custom primitive tutorials (e.g.
[ours](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html)
and
[dfm's](https://github.com/dfm/extending-jax)).

For instance, defining a primitive and its behavior under `jax.jit`
would be possible as follows (in the immediate term):
```python
from jax.extend import core	         # Previously: from jax import core
from jax.extend.interpreters import mlir        # ... and similarly

mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)

@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
  return core.ShapedArray(x_sa.shape, x_sa.dtype)

def mul_add_mlir(ctx, xc, yc, zc):
  add = mlir.hlo.AddOp
  mul = mlir.hlo.MulOp
  return add(mul(xc, yc), zc).results

mlir.register_lowering(mul_add_p, mul_add_mlir)

import jax
print(mul_add_p.bind(2, 3, 4))            # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)
```
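
Registering AD and batching rules for the same kind of primitive might look roughly as follows. This is a sketch under several assumptions, not the proposed API: it uses today's `jax.interpreters.ad` and `jax.interpreters.batching` modules instead of `jax.extend.interpreters`, takes `Primitive` from whichever of `jax.extend.core` or `jax.core` the installed release provides, and omits the lowering rule, so it runs only outside `jax.jit`. The rule names `mul_add_jvp` and `mul_add_batch` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.core import ShapedArray
from jax.interpreters import ad, batching

try:  # Assumption: newer releases expose Primitive here ...
  from jax.extend.core import Primitive
except ImportError:  # ... while older ones expose it in jax.core.
  from jax.core import Primitive

mul_add_p = Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: jnp.add(jnp.multiply(x, y), z))
mul_add_p.def_abstract_eval(lambda x, y, z: ShapedArray(x.shape, x.dtype))

def mul_add(x, y, z):
  return mul_add_p.bind(x, y, z)

def mul_add_jvp(primals, tangents):
  x, y, z = primals
  # Replace symbolic zero tangents with explicit zero arrays.
  tx, ty, tz = (jnp.zeros_like(p) if isinstance(t, ad.Zero) else t
                for p, t in zip(primals, tangents))
  primal_out = mul_add(x, y, z)
  # d(x * y + z) = tx * y + x * ty + tz
  tangent_out = mul_add(tx, y, mul_add(x, ty, tz))
  return primal_out, tangent_out

ad.primitive_jvps[mul_add_p] = mul_add_jvp

def mul_add_batch(batched_args, batch_dims):
  # Size of the mapped axis, read from any argument that is batched.
  size = next(jnp.shape(a)[d]
              for a, d in zip(batched_args, batch_dims) if d is not None)
  # Move every batch axis to the front; broadcast unbatched arguments.
  moved = [jnp.broadcast_to(a, (size,) + jnp.shape(a)) if d is None
           else jnp.moveaxis(a, d, 0)
           for a, d in zip(batched_args, batch_dims)]
  return mul_add(*moved), 0

batching.primitive_batchers[mul_add_p] = mul_add_batch

x, y, z = jnp.float32(2.0), jnp.float32(3.0), jnp.float32(4.0)
one, zero = jnp.float32(1.0), jnp.float32(0.0)
primal, tangent = jax.jvp(mul_add, (x, y, z), (one, zero, zero))
batched = jax.vmap(mul_add, in_axes=(0, None, None))(jnp.arange(3.0), y, z)
print(primal, tangent, batched)
```

Note that neither rule touches `JVPTracer` or `BatchTracer` directly; both are written in terms of plain values and re-binding the primitive.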
### `jax.extend.random`
This module could expose our mechanism for defining new RNG
implementations, and functions for working with PRNG key internals
(see issue [#9263](https://github.com/google/jax/issues/9263)),
such as the current `jax._src.prng.random_wrap` and
`random_unwrap`.

It could also expose the keyed hash functions that underlie the
built-in RNG implementations, such as `jax._src.prng.threefry_2x32`.
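
As an illustration of what working with PRNG key internals involves, the sketch below round-trips a typed key through its raw bits. It assumes a JAX release recent enough to provide the public functions `jax.random.key`, `jax.random.key_data`, and `jax.random.wrap_key_data`, which play a role similar to the internal `random_unwrap` and `random_wrap`. It does not use any `jax.extend` symbol.

```python
import jax
import jax.numpy as jnp

# A typed key array backed by the threefry2x32 implementation.
key = jax.random.key(0, impl='threefry2x32')

# Unwrap: expose the key's underlying bits as a plain uint32 array.
raw = jax.random.key_data(key)

# Wrap: rebuild a typed key from raw bits, naming the implementation.
rewrapped = jax.random.wrap_key_data(raw, impl='threefry2x32')

# The round trip preserves both the bits and the random stream.
same_bits = bool(jnp.array_equal(jax.random.key_data(rewrapped), raw))
same_draws = bool(jnp.array_equal(jax.random.uniform(key, (3,)),
                                  jax.random.uniform(rewrapped, (3,))))
print(raw.shape, raw.dtype, same_bits, same_draws)
```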
### `jax.extend.sharding`
This module could expose low-level utilities for sharding distributed
arrays.

We have only one item in mind for now. The XLA compiler's
array sharding format is more expressive than [those provided by
JAX](https://jax.readthedocs.io/en/latest/jax.sharding.html). We could
provide this as `jex.sharding.XlaOpShardingProto`, corresponding to
today's `jax._src.lib.xla_client.OpSharding` internally.