(jax-extend-jep)=
jax.extend: a module for extensions
@froystig, @sharadmv, @jakevdp, @yashk2810
May 2023
import jax.extend as jex
Several projects depend on JAX's codebase internals, often to use its core machinery (e.g. to write a transformation over its IR) or to extend it (e.g. to define new primitives). Two challenges for these dependencies are (a) that our internals aren't all solidly designed for external use, and (b) that circumventing JAX's public API is unsupported. In other words, our internals are often used like a library, but are neither structured nor updated like one.
This proposal considers introducing a jax.extend module that defines a library view of some of JAX's internal components. We would treat this as a second-tier API, still offering essentially no compatibility guarantees, but hopefully making it easier to spot changes when they happen.
The audience for jax.extend includes JAX-adjacent Python libraries like Oryx, jax-triton, and many others, as well as projects experimenting with function transformations, autodiff systems, compiler frontends for numerical programming, etc.
This note gives an overview of how jax.extend might look, now and eventually. It doesn't lay things out in great detail, instead proposing that we begin iteratively developing the module.
Note that jax.extend differs from jax.experimental, which is a staging ground for new features and ideas in progress. Typically, work in jax.experimental eventually makes it into another JAX module or is removed altogether.
No compatibility policy
To keep development overhead low, jax.extend would not follow the public API compatibility policy. It would promise neither deprecation windows nor backwards compatibility between releases. Every release may break existing callers without simple recourse (e.g. without a flag reintroducing prior behavior). We would rely on the changelog to call out such changes.
Callers of jax.extend that need to upgrade their code regularly
alongside JAX releases might find it useful to pin JAX versions as an
intermediate step between releases. This is a common habit among
projects that rely on JAX's internals today. The difference is that it
would now come with the help of changelog announcements and better
intentions regarding library design and naming.
Iterative development
Having no compatibility policy makes it easier to get started on implementation: on day one, we can move a handful of symbols over from internal packages such as jax._src and today's jax.core and jax.interpreters. Then we can iterate to improve things from there.
Possible module overview
We can imagine that eventually jax.extend would include the following modules:
- core: primitives, the Jaxpr IR, etc.
- interpreters: core transformations (e.g. autodiff, batching) and lowerings.
- random: random bit generation, key splitting and folding, key arrays.
- sharding: extra functionality around distributed arrays.
We might also have other symbols in the module at first, such as jex.api_util, as we work to remove or replace them. Others will be decided in time. For instance, jex.lib could offer an entry point to jaxlib (and would do so in the immediate term), but it's not clear whether we want to keep it for long.
Some preliminary thoughts on what each of these might comprise follow.
jax.extend.core
This should enable callers at least to define new JAX primitives and to process the Jaxpr IR (the output of jax.make_jaxpr(...)). Supporting this might involve providing:
- Access to existing core system primitives, such as today's jax._src.lax.add_p.
- Access to IR types, such as the current jax._src.core.ShapedArray.
- Functions for checking and pretty-printing jaxprs.
- Functions for building jaxprs explicitly, rather than by staging Python functions via jax.make_jaxpr (or not!).
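As a rough illustration of what "processing the Jaxpr IR" means in practice, the sketch below stages a small function with jax.make_jaxpr and walks the resulting equations. It relies only on attributes that jaxprs expose today (eqns, primitive.name, invars); the helper name primitive_names is made up for this example, and nothing here is meant as the eventual jex.core surface.

```python
import jax
from jax import lax


def f(x):
    # sin(x) * cos(x), written with lax so each call maps to one primitive.
    return lax.mul(lax.sin(x), lax.cos(x))


# Stage the Python function out to a ClosedJaxpr.
closed_jaxpr = jax.make_jaxpr(f)(1.0)


def primitive_names(jaxpr):
    """Hypothetical helper: list the primitive applied by each equation."""
    return [eqn.primitive.name for eqn in jaxpr.eqns]


names = primitive_names(closed_jaxpr.jaxpr)

print(closed_jaxpr)  # pretty-printed IR
print(names)
```

A transformation over the IR would typically follow the same loop, dispatching on eqn.primitive instead of just collecting names.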
At initialization, this module will contain many more symbols than what's needed to define primitives and rules, including various names used in setting up "final-style transformations", such as the current jax._src.core.Trace and Tracer classes. We can revisit whether jex.core should also support final-style extensions alongside initial-style approaches, and whether it can do so by a narrower API than exposing Trace and Tracer entirely. Oryx might help guide these decisions.
We can also consider relocating make_jaxpr itself to jex.core.
jax.extend.interpreters
This module would provide a means of registering various transformation rules for primitives, that is, defining their behavior under AD, batching, lowering, etc.
It would initially reflect jax._src.interpreters in providing the modules ad, batching, partial_eval (for staging Python to Jaxpr, and for linearization in AD), mlir, pxla, and xla. The first three might be replaceable by a single primitive extension API in jex.core. The latter three, used for lowering, could perhaps be simplified into one module.
Today, to write transformation rules, e.g. for AD and batching, callers may need symbols relating to tracers, e.g. JVPTracer and BatchTracer. This may be avoidable later on, which would allow us to remove tracer types from jex.
This module plus jex.core ought to suffice for replicating today's custom primitive tutorials (e.g. ours and dfm's). For instance, defining a primitive and its behavior under jax.jit would be possible as follows (in the immediate term):
from jax.extend import core	          # Previously: from jax import core
from jax.extend.interpreters import mlir  # ... and similarly

mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)

@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
  return core.ShapedArray(x_sa.shape, x_sa.dtype)

def mul_add_mlir(ctx, xc, yc, zc):
  add = mlir.hlo.AddOp
  mul = mlir.hlo.MulOp
  return add(mul(xc, yc), zc).results

mlir.register_lowering(mul_add_p, mul_add_mlir)

import jax
print(mul_add_p.bind(2, 3, 4))            # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)
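Registering other transformation rules would follow the same pattern. As a hedged sketch, here is a JVP rule for the same mul_add primitive, written against today's jax.interpreters.ad rather than the proposed jex.interpreters. The import fallback for Primitive is an assumption about which JAX release is installed, and the abstract-eval and lowering rules are left out, so this version only runs eagerly (not under jax.jit).

```python
import jax
import jax.numpy as jnp
from jax.interpreters import ad

try:
    # Location in newer JAX releases (assumption about the installed version).
    from jax.extend.core import Primitive
except ImportError:
    # Location in older releases.
    from jax.core import Primitive

mul_add_p = Primitive("mul_add")
mul_add_p.def_impl(lambda x, y, z: x * y + z)

# One rule per positional argument: each maps that argument's tangent g
# (plus all primals) to its contribution to the output tangent.
ad.defjvp(
    mul_add_p,
    lambda g, x, y, z: g * y,  # d/dx (x * y + z) = y
    lambda g, x, y, z: x * g,  # d/dy (x * y + z) = x
    lambda g, x, y, z: g,      # d/dz (x * y + z) = 1
)


def mul_add(x, y, z):
    return mul_add_p.bind(x, y, z)


primals = (jnp.float32(2.0), jnp.float32(3.0), jnp.float32(4.0))
tangents = (jnp.float32(1.0), jnp.float32(1.0), jnp.float32(1.0))
primal_out, tangent_out = jax.jvp(mul_add, primals, tangents)
print(primal_out, tangent_out)
```

ad.defjvp sums the per-argument contributions and skips symbolically zero tangents, which is why no tracer types such as JVPTracer appear in this sketch.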
jax.extend.random
This module could expose our mechanism for defining new RNG implementations, and functions for working with PRNG key internals (see issue #9263), such as the current jax._src.prng.random_wrap and random_unwrap.
It could also expose the keyed hash functions that underlie the built-in RNG implementations, such as jax._src.prng.threefry_2x32.
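To make "working with PRNG key internals" concrete, the sketch below unwraps a typed key into its raw uint32 data and wraps it back. It uses the public jax.random.key_data and jax.random.wrap_key_data functions as stand-ins for random_unwrap and random_wrap; that correspondence is an assumption made for illustration, and the threefry2x32 implementation is named explicitly so the example does not depend on configuration defaults.

```python
import jax
import jax.numpy as jnp

IMPL = "threefry2x32"

# A typed key array, and the raw bits underneath it.
key = jax.random.key(0, impl=IMPL)
raw = jax.random.key_data(key)

# Wrap the raw bits back into a typed key for the same implementation.
rewrapped = jax.random.wrap_key_data(raw, impl=IMPL)

# The round trip preserves the key data, so sampling agrees as well.
roundtrip_ok = bool(jnp.array_equal(jax.random.key_data(rewrapped), raw))
same_sample = bool(jax.random.uniform(key) == jax.random.uniform(rewrapped))

print(raw.dtype, raw.shape, roundtrip_ok, same_sample)
```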
jax.extend.sharding
This module could expose low-level utilities for sharding distributed arrays.
We have only one item in mind for now. The XLA compiler's array sharding format is more expressive than the sharding formats that JAX provides. We could provide this as jex.sharding.XlaOpShardingProto, corresponding to today's jax._src.lib.xla_client.OpSharding internally.