.. currentmodule:: jax

jax package
===========

Subpackages
-----------

.. toctree::
    :maxdepth: 1

    jax.numpy
    jax.scipy
    jax.experimental
    jax.lax
    jax.ops
    jax.random
    jax.tree_util

Just-in-time compilation (:code:`jit`)
--------------------------------------

.. autofunction:: jit
.. autofunction:: disable_jit
.. autofunction:: xla_computation
.. autofunction:: make_jaxpr
.. autofunction:: eval_shape
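
As a rough sketch of how the functions in this section are typically combined (assuming a recent JAX release), the example below compiles a small function with :code:`jit`, prints the jaxpr that :code:`make_jaxpr` produces for it, and uses :code:`eval_shape` to obtain an output shape and dtype without performing any computation. The :code:`selu` function and its constants are hypothetical and chosen only for illustration.

.. code-block:: python

    import jax
    import jax.numpy as jnp


    def selu(x, alpha=1.67, lmbda=1.05):
        # Hypothetical example function: a scaled exponential linear unit
        # built from jax.numpy operations.
        return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


    # jit traces selu for each new input shape/dtype and compiles it with XLA.
    selu_jit = jax.jit(selu)
    x = jnp.arange(5.0)
    y = selu_jit(x)

    # make_jaxpr returns the traced intermediate representation (a jaxpr).
    print(jax.make_jaxpr(selu)(x))

    # eval_shape reports output shape and dtype without running the computation.
    out = jax.eval_shape(selu, jax.ShapeDtypeStruct((1000, 1000), jnp.float32))
    print(out.shape, out.dtype)

Because :code:`eval_shape` only traces the function, it is a cheap way to check shapes for inputs that would be expensive to materialize.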

Automatic differentiation
-------------------------

.. autofunction:: grad
.. autofunction:: value_and_grad
.. autofunction:: jacfwd
.. autofunction:: jacrev
.. autofunction:: hessian
.. autofunction:: jvp
.. autofunction:: linearize
.. autofunction:: vjp
.. autofunction:: custom_transforms
.. autofunction:: defjvp
.. autofunction:: defjvp_all
.. autofunction:: defvjp
.. autofunction:: defvjp_all
.. autofunction:: custom_gradient
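
The following sketch, assuming a recent JAX release, exercises the core transformations from this section on a small hypothetical scalar function :code:`f`. It does not cover custom-derivative helpers such as :code:`custom_transforms` and :code:`defjvp`, which belong to early JAX releases and are not available in current ones.

.. code-block:: python

    import jax
    import jax.numpy as jnp


    def f(x):
        # Hypothetical scalar-valued function: f(x) = sum(sin(x) * x**2).
        return jnp.sum(jnp.sin(x) * x ** 2)


    x = jnp.array([0.5, 1.0, 2.0])

    # grad(f) is a new function computing df/dx; value_and_grad also returns f(x).
    g = jax.grad(f)(x)
    val, g_again = jax.value_and_grad(f)(x)

    # Second derivatives: forward-over-reverse and reverse-over-reverse agree,
    # and hessian(f) computes the same matrix.
    H_fwd = jax.jacfwd(jax.grad(f))(x)
    H_rev = jax.jacrev(jax.grad(f))(x)
    H = jax.hessian(f)(x)

    # jvp pushes a tangent v forward through f: returns (f(x), df/dx . v).
    v = jnp.ones_like(x)
    y, y_dot = jax.jvp(f, (x,), (v,))

    # linearize evaluates f once and returns a reusable JVP function.
    y_lin, f_jvp = jax.linearize(f, x)

    # vjp returns f(x) and a pullback; pulling back a cotangent of 1 gives df/dx.
    y_vjp, f_vjp = jax.vjp(f, x)
    (x_bar,) = f_vjp(jnp.ones_like(y_vjp))

Because :code:`f` is a sum of elementwise terms, its Hessian is diagonal, which makes the results easy to check by hand.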

Vectorization (:code:`vmap`)
----------------------------

.. autofunction:: vmap
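
The sketch below shows basic use of :code:`vmap`, assuming a recent JAX release: a hypothetical :code:`dot` function written for single vectors is applied row by row to a batch, first mapping over both arguments and then, via :code:`in_axes`, mapping over only the first one.

.. code-block:: python

    import jax
    import jax.numpy as jnp


    def dot(a, b):
        # Hypothetical per-example function: written for single vectors,
        # returns a scalar.
        return jnp.dot(a, b)


    xs = jnp.arange(12.0).reshape(4, 3)
    ys = jnp.ones((4, 3))

    # Map over axis 0 of both arguments: one dot product per row.
    batched = jax.vmap(dot)(xs, ys)

    # in_axes=(0, None) maps over xs only and reuses the same y for every row.
    y = jnp.array([1.0, 0.0, -1.0])
    batched_shared = jax.vmap(dot, in_axes=(0, None))(xs, y)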

Parallelization (:code:`pmap`)
------------------------------

.. autofunction:: pmap
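
A minimal :code:`pmap` sketch, assuming a recent JAX release. It sizes the mapped axis with :code:`jax.local_device_count()` so that it also runs on a machine with a single CPU device, and uses :code:`jax.lax.psum` with a named axis to combine per-device values. The :code:`normalize` function and the axis name :code:`"devices"` are illustrative choices, not part of the API.

.. code-block:: python

    import jax
    import jax.numpy as jnp

    # pmap maps over the leading axis, so its size must not exceed the number
    # of available devices (one CPU device is enough for this sketch).
    n = jax.local_device_count()
    xs = jnp.arange(n * 3.0).reshape(n, 3)


    def normalize(x):
        # Hypothetical per-device function: psum sums each device's partial
        # total across all devices on the named axis.
        return x / jax.lax.psum(jnp.sum(x), axis_name="devices")


    out = jax.pmap(normalize, axis_name="devices")(xs)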