From 4eb1820ae266a805a2bafda858b1ee4aaaeee929 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sat, 20 Jul 2019 14:40:31 +0100 Subject: [PATCH] Add documentation to JAX modules. --- docs/jax.lax.rst | 12 +++++++++--- docs/jax.numpy.rst | 23 +++++++++++++++++++++++ docs/jax.ops.rst | 10 +++++++++- docs/jax.rst | 32 +++++++++++++++++++++++++++++--- jax/numpy/lax_numpy.py | 11 +++++++++++ 5 files changed, 81 insertions(+), 7 deletions(-) diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 3f93ed5d9..3e6dc4506 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -3,13 +3,19 @@ jax.lax package .. automodule:: jax.lax -`lax` is a library of primitives that underpins libraries such as `jax.numpy`. +:mod:`jax.lax` is a library of primitives operations that underpins libraries +such as :mod:`jax.numpy`. Transformation rules, such as JVP and batching rules, +are typically defined as transformations on :mod:`jax.lax` primitives. Many of the primitives are thin wrappers around equivalent XLA operations, described by the `XLA operation semantics -`_ documentation. +`_ documentation. In a few +cases JAX diverges from XLA, usually to ensure that the set of operations is +closed under the operation of JVP and transpose rules. -Where possible, prefer to use libraries such as `jax.numpy` instead of using `jax.lax` directly. +Where possible, prefer to use libraries such as :mod:`jax.numpy` instead of +using :mod:`jax.lax` directly. The :mod:`jax.numpy` API follows NumPy, and is +therefore more stable and less likely to change than the :mod:`jax.lax` API. Operators --------- diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index ddd9c6b80..ff6921ba5 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -6,6 +6,29 @@ jax.numpy package .. automodule:: jax.numpy +Implements the NumPy API, using the primitives in :mod:`jax.lax`. + +While JAX tries to follow the NumPy API as closely as possible, sometimes JAX +cannot follow NumPy exactly. + +* Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays + in-place cannot be implemented in JAX. However, often JAX is able to provide a + alternative API that is purely functional. For example, instead of in-place + array updates (:code:`x[i] = y`), JAX provides an alternative pure indexed + update function :func:`jax.ops.index_update`. + +* NumPy is very aggressive at promoting values to :code:`float64` type. JAX + sometimes is less aggressive about type promotion. + +A small number of NumPy operations that have data-dependent output shapes are +incompatible with :func:`jax.jit` compilation. The XLA compiler requires that +shapes of arrays be known at compile time. While it would be possible to provide +a JAX implementation of an API such as :func:`numpy.nonzero`, we would be unable +to JIT-compile it because the shape of its output depends on the contents of the +input data. + +Not every function in NumPy is implemented; contributions are welcome! + .. autosummary:: :toctree: _autosummary diff --git a/docs/jax.ops.rst b/docs/jax.ops.rst index 33b250b17..1b084708f 100644 --- a/docs/jax.ops.rst +++ b/docs/jax.ops.rst @@ -11,7 +11,8 @@ Indexed update operators ------------------------ JAX is intended to be used with a functional style of programming, and hence -does not support NumPy-style indexed assignment directly. +does not support NumPy-style indexed assignment directly. Instead, JAX provides +pure alternatives, namely :func:`jax.ops.index_update` and its relatives. .. autosummary:: :toctree: _autosummary @@ -21,4 +22,11 @@ does not support NumPy-style indexed assignment directly. index_add index_min index_max + +Other operators +--------------- + +.. autosummary:: + :toctree: _autosummary + segment_sum diff --git a/docs/jax.rst b/docs/jax.rst index 6221ab91a..c1ec9d9b1 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -15,10 +15,36 @@ Subpackages jax.random jax.tree_util -Module contents ---------------- +Just-in-time compilation (:code:`jit`) +-------------------------------------- .. automodule:: jax - :members: jit, disable_jit, grad, value_and_grad, vmap, pmap, jacfwd, jacrev, hessian, jvp, linearize, vjp, make_jaxpr, eval_shape, custom_transforms, defjvp, defjvp_all, defvjp, defvjp_all, custom_gradient, xla_computation + :members: jit, disable_jit, xla_computation, make_jaxpr, eval_shape + :undoc-members: + :show-inheritance: + +Automatic differentiation +------------------------- + +.. automodule:: jax + :members: grad, value_and_grad, jacfwd, jacrev, hessian, jvp, linearize, vjp, custom_transforms, defjvp, defjvp_all, defvjp, defvjp_all, custom_gradient + :undoc-members: + :show-inheritance: + + +Vectorization (:code:`vmap`) +---------------------------- + +.. automodule:: jax + :members: vmap + :undoc-members: + :show-inheritance: + + +Parallelization (:code:`pmap`) +---------------------------- + +.. automodule:: jax + :members: pmap :undoc-members: :show-inheritance: diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 169462085..d63e4ca4a 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -12,6 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Implements the NumPy API, using the primitives in :mod:`jax.lax`. + +NumPy operations are implemented in Python in terms of the primitive operations +in :mod:`jax.lax`. Since NumPy operations are not primitive and instead are +implemented in terms of :mod:`jax.lax` operations, we do not need to define +transformation rules such as gradient or batching rules. Instead, +transformations for NumPy primitives can be derived from the transformation +rules for the underlying :code:`lax` primitives. +""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function