Add documentation to JAX modules.

This commit is contained in:
Peter Hawkins 2019-07-20 14:40:31 +01:00
parent c9eb063c19
commit 4eb1820ae2
5 changed files with 81 additions and 7 deletions

View File

@ -3,13 +3,19 @@ jax.lax package
.. automodule:: jax.lax
`lax` is a library of primitives that underpins libraries such as `jax.numpy`.
:mod:`jax.lax` is a library of primitives operations that underpins libraries
such as :mod:`jax.numpy`. Transformation rules, such as JVP and batching rules,
are typically defined as transformations on :mod:`jax.lax` primitives.
Many of the primitives are thin wrappers around equivalent XLA operations,
described by the `XLA operation semantics
<https://www.tensorflow.org/xla/operation_semantics>`_ documentation.
<https://www.tensorflow.org/xla/operation_semantics>`_ documentation. In a few
cases JAX diverges from XLA, usually to ensure that the set of operations is
closed under the operation of JVP and transpose rules.
Where possible, prefer to use libraries such as `jax.numpy` instead of using `jax.lax` directly.
Where possible, prefer to use libraries such as :mod:`jax.numpy` instead of
using :mod:`jax.lax` directly. The :mod:`jax.numpy` API follows NumPy, and is
therefore more stable and less likely to change than the :mod:`jax.lax` API.
Operators
---------

View File

@ -6,6 +6,29 @@ jax.numpy package
.. automodule:: jax.numpy
Implements the NumPy API, using the primitives in :mod:`jax.lax`.
While JAX tries to follow the NumPy API as closely as possible, sometimes JAX
cannot follow NumPy exactly.
* Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays
in-place cannot be implemented in JAX. However, often JAX is able to provide a
alternative API that is purely functional. For example, instead of in-place
array updates (:code:`x[i] = y`), JAX provides an alternative pure indexed
update function :func:`jax.ops.index_update`.
* NumPy is very aggressive at promoting values to :code:`float64` type. JAX
sometimes is less aggressive about type promotion.
A small number of NumPy operations that have data-dependent output shapes are
incompatible with :func:`jax.jit` compilation. The XLA compiler requires that
shapes of arrays be known at compile time. While it would be possible to provide
a JAX implementation of an API such as :func:`numpy.nonzero`, we would be unable
to JIT-compile it because the shape of its output depends on the contents of the
input data.
Not every function in NumPy is implemented; contributions are welcome!
.. autosummary::
:toctree: _autosummary

View File

@ -11,7 +11,8 @@ Indexed update operators
------------------------
JAX is intended to be used with a functional style of programming, and hence
does not support NumPy-style indexed assignment directly.
does not support NumPy-style indexed assignment directly. Instead, JAX provides
pure alternatives, namely :func:`jax.ops.index_update` and its relatives.
.. autosummary::
:toctree: _autosummary
@ -21,4 +22,11 @@ does not support NumPy-style indexed assignment directly.
index_add
index_min
index_max
Other operators
---------------
.. autosummary::
:toctree: _autosummary
segment_sum

View File

@ -15,10 +15,36 @@ Subpackages
jax.random
jax.tree_util
Module contents
---------------
Just-in-time compilation (:code:`jit`)
--------------------------------------
.. automodule:: jax
:members: jit, disable_jit, grad, value_and_grad, vmap, pmap, jacfwd, jacrev, hessian, jvp, linearize, vjp, make_jaxpr, eval_shape, custom_transforms, defjvp, defjvp_all, defvjp, defvjp_all, custom_gradient, xla_computation
:members: jit, disable_jit, xla_computation, make_jaxpr, eval_shape
:undoc-members:
:show-inheritance:
Automatic differentiation
-------------------------
.. automodule:: jax
:members: grad, value_and_grad, jacfwd, jacrev, hessian, jvp, linearize, vjp, custom_transforms, defjvp, defjvp_all, defvjp, defvjp_all, custom_gradient
:undoc-members:
:show-inheritance:
Vectorization (:code:`vmap`)
----------------------------
.. automodule:: jax
:members: vmap
:undoc-members:
:show-inheritance:
Parallelization (:code:`pmap`)
----------------------------
.. automodule:: jax
:members: pmap
:undoc-members:
:show-inheritance:

View File

@ -12,6 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implements the NumPy API, using the primitives in :mod:`jax.lax`.
NumPy operations are implemented in Python in terms of the primitive operations
in :mod:`jax.lax`. Since NumPy operations are not primitive and instead are
implemented in terms of :mod:`jax.lax` operations, we do not need to define
transformation rules such as gradient or batching rules. Instead,
transformations for NumPy primitives can be derived from the transformation
rules for the underlying :code:`lax` primitives.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function