mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add documentation to JAX modules.
This commit is contained in:
parent
c9eb063c19
commit
4eb1820ae2
@ -3,13 +3,19 @@ jax.lax package
|
||||
|
||||
.. automodule:: jax.lax
|
||||
|
||||
`lax` is a library of primitives that underpins libraries such as `jax.numpy`.
|
||||
:mod:`jax.lax` is a library of primitives operations that underpins libraries
|
||||
such as :mod:`jax.numpy`. Transformation rules, such as JVP and batching rules,
|
||||
are typically defined as transformations on :mod:`jax.lax` primitives.
|
||||
|
||||
Many of the primitives are thin wrappers around equivalent XLA operations,
|
||||
described by the `XLA operation semantics
|
||||
<https://www.tensorflow.org/xla/operation_semantics>`_ documentation.
|
||||
<https://www.tensorflow.org/xla/operation_semantics>`_ documentation. In a few
|
||||
cases JAX diverges from XLA, usually to ensure that the set of operations is
|
||||
closed under the operation of JVP and transpose rules.
|
||||
|
||||
Where possible, prefer to use libraries such as `jax.numpy` instead of using `jax.lax` directly.
|
||||
Where possible, prefer to use libraries such as :mod:`jax.numpy` instead of
|
||||
using :mod:`jax.lax` directly. The :mod:`jax.numpy` API follows NumPy, and is
|
||||
therefore more stable and less likely to change than the :mod:`jax.lax` API.
|
||||
|
||||
Operators
|
||||
---------
|
||||
|
@ -6,6 +6,29 @@ jax.numpy package
|
||||
|
||||
.. automodule:: jax.numpy
|
||||
|
||||
Implements the NumPy API, using the primitives in :mod:`jax.lax`.
|
||||
|
||||
While JAX tries to follow the NumPy API as closely as possible, sometimes JAX
|
||||
cannot follow NumPy exactly.
|
||||
|
||||
* Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays
|
||||
in-place cannot be implemented in JAX. However, often JAX is able to provide a
|
||||
alternative API that is purely functional. For example, instead of in-place
|
||||
array updates (:code:`x[i] = y`), JAX provides an alternative pure indexed
|
||||
update function :func:`jax.ops.index_update`.
|
||||
|
||||
* NumPy is very aggressive at promoting values to :code:`float64` type. JAX
|
||||
sometimes is less aggressive about type promotion.
|
||||
|
||||
A small number of NumPy operations that have data-dependent output shapes are
|
||||
incompatible with :func:`jax.jit` compilation. The XLA compiler requires that
|
||||
shapes of arrays be known at compile time. While it would be possible to provide
|
||||
a JAX implementation of an API such as :func:`numpy.nonzero`, we would be unable
|
||||
to JIT-compile it because the shape of its output depends on the contents of the
|
||||
input data.
|
||||
|
||||
Not every function in NumPy is implemented; contributions are welcome!
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
|
@ -11,7 +11,8 @@ Indexed update operators
|
||||
------------------------
|
||||
|
||||
JAX is intended to be used with a functional style of programming, and hence
|
||||
does not support NumPy-style indexed assignment directly.
|
||||
does not support NumPy-style indexed assignment directly. Instead, JAX provides
|
||||
pure alternatives, namely :func:`jax.ops.index_update` and its relatives.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
@ -21,4 +22,11 @@ does not support NumPy-style indexed assignment directly.
|
||||
index_add
|
||||
index_min
|
||||
index_max
|
||||
|
||||
Other operators
|
||||
---------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
segment_sum
|
||||
|
32
docs/jax.rst
32
docs/jax.rst
@ -15,10 +15,36 @@ Subpackages
|
||||
jax.random
|
||||
jax.tree_util
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
Just-in-time compilation (:code:`jit`)
|
||||
--------------------------------------
|
||||
|
||||
.. automodule:: jax
|
||||
:members: jit, disable_jit, grad, value_and_grad, vmap, pmap, jacfwd, jacrev, hessian, jvp, linearize, vjp, make_jaxpr, eval_shape, custom_transforms, defjvp, defjvp_all, defvjp, defvjp_all, custom_gradient, xla_computation
|
||||
:members: jit, disable_jit, xla_computation, make_jaxpr, eval_shape
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Automatic differentiation
|
||||
-------------------------
|
||||
|
||||
.. automodule:: jax
|
||||
:members: grad, value_and_grad, jacfwd, jacrev, hessian, jvp, linearize, vjp, custom_transforms, defjvp, defjvp_all, defvjp, defvjp_all, custom_gradient
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Vectorization (:code:`vmap`)
|
||||
----------------------------
|
||||
|
||||
.. automodule:: jax
|
||||
:members: vmap
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
Parallelization (:code:`pmap`)
|
||||
----------------------------
|
||||
|
||||
.. automodule:: jax
|
||||
:members: pmap
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
@ -12,6 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Implements the NumPy API, using the primitives in :mod:`jax.lax`.
|
||||
|
||||
NumPy operations are implemented in Python in terms of the primitive operations
|
||||
in :mod:`jax.lax`. Since NumPy operations are not primitive and instead are
|
||||
implemented in terms of :mod:`jax.lax` operations, we do not need to define
|
||||
transformation rules such as gradient or batching rules. Instead,
|
||||
transformations for NumPy primitives can be derived from the transformation
|
||||
rules for the underlying :code:`lax` primitives.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
Loading…
x
Reference in New Issue
Block a user