.. currentmodule:: jax

Public API: jax package
=======================

Subpackages
-----------

.. toctree::
    :maxdepth: 1

    jax.numpy
    jax.scipy
    jax.example_libraries
    jax.experimental
    jax.image
    jax.lax
    jax.nn
    jax.ops
    jax.random
    jax.tree_util
    jax.flatten_util
    jax.dlpack
    jax.profiler
    jax.config

.. toctree::
    :hidden:

    jax.lib

.. _jax-jit:

Just-in-time compilation (:code:`jit`)
--------------------------------------

.. autosummary::

    jit
    disable_jit
    ensure_compile_time_eval
    xla_computation
    make_jaxpr
    eval_shape
    device_put
    device_put_replicated
    device_put_sharded
    device_get
    default_backend
    named_call

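
As a rough illustration of how a few of the entries above fit together, the sketch below compiles a small function with :code:`jit`, runs it again under :code:`disable_jit`, and inspects it with :code:`make_jaxpr` and :code:`eval_shape`. The function :code:`sum_of_squares` and the input values are hypothetical and chosen only for this example.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    # Hypothetical example function; it is not part of the jax package.
    return jnp.sum(x ** 2)


x = jnp.arange(4.0)  # [0., 1., 2., 3.]

# jit traces the function and compiles it with XLA on the first call.
fast_sum_of_squares = jax.jit(sum_of_squares)
jit_result = fast_sum_of_squares(x)

# disable_jit is a context manager; inside it the jitted function
# is executed without compilation, which can help when debugging.
with jax.disable_jit():
    eager_result = fast_sum_of_squares(x)

# make_jaxpr returns a function that traces its argument and
# returns the resulting jaxpr instead of a numerical result.
jaxpr = jax.make_jaxpr(sum_of_squares)(x)

# eval_shape reports the output shape and dtype without computing values.
out_spec = jax.eval_shape(sum_of_squares, x)

print(jit_result, eager_result)
print(jaxpr)
print(out_spec.shape, out_spec.dtype)
```

Both calls should agree numerically; only the execution path differs.
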
.. _jax-grad:

Automatic differentiation
-------------------------

.. autosummary::

    grad
    value_and_grad
    jacfwd
    jacrev
    hessian
    jvp
    linearize
    linear_transpose
    vjp
    custom_jvp
    custom_vjp
    closure_convert
    checkpoint

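
The following is a minimal sketch of several of the differentiation transformations listed above, applied to one scalar-valued function. The function :code:`loss` and its input are assumptions made for illustration; for this choice the gradient is analytically ``2 * x`` and the Hessian is twice the identity matrix.

```python
import jax
import jax.numpy as jnp


def loss(x):
    # Hypothetical scalar-valued function used only for this example.
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0, 3.0])

# Reverse-mode gradient of a scalar function.
g = jax.grad(loss)(x)

# Function value and gradient from a single call.
value, g_again = jax.value_and_grad(loss)(x)

# Hessian (second derivatives) as a dense matrix.
hess = jax.hessian(loss)(x)

# Forward mode: push a tangent vector through the function.
tangent_in = jnp.ones_like(x)
primal_out, tangent_out = jax.jvp(loss, (x,), (tangent_in,))

# Reverse mode: pull a cotangent back through the function.
primal_out_vjp, vjp_fn = jax.vjp(loss, x)
(cotangent_in,) = vjp_fn(jnp.ones_like(primal_out_vjp))

print(g, value, tangent_out)
print(hess)
print(cotangent_in)
```

With a unit cotangent, the result of :code:`vjp_fn` coincides with the output of :code:`grad`, since :code:`grad` is built on the same reverse-mode machinery.
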
Vectorization (:code:`vmap`)
----------------------------

.. autosummary::

    vmap
    jax.numpy.vectorize

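
A short sketch of :code:`vmap`, assuming a hypothetical per-example function :code:`dot` that works on single vectors. The transformation maps it over a leading batch axis; the :code:`in_axes` argument selects which arguments are batched.

```python
import jax
import jax.numpy as jnp


def dot(a, b):
    # Hypothetical per-example function operating on 1-D vectors.
    return jnp.dot(a, b)


a = jnp.ones((3, 4))                 # batch of 3 vectors of length 4
b = jnp.arange(12.0).reshape(3, 4)   # a second batch of the same shape

# Map over the leading axis of both arguments.
batched_dot = jax.vmap(dot)
per_row = batched_dot(a, b)          # shape (3,)

# Map over the leading axis of `a` only; `b[0]` is shared by every example.
shared_rhs = jax.vmap(dot, in_axes=(0, None))(a, b[0])

print(per_row)
print(shared_rhs)
```

Because :code:`a` is all ones here, each entry of :code:`per_row` is the sum of the corresponding row of :code:`b`.
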
Parallelization (:code:`pmap`)
------------------------------

.. autosummary::

    pmap
    devices
    local_devices
    process_index
    device_count
    local_device_count
    process_count

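
The sketch below shows one way the device-query functions above can be combined with :code:`pmap`. It assumes a single-process setup and sizes the leading axis of the input to the number of local devices, so it should also run on a machine with only one device. The doubling function is a placeholder chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Query the devices visible to this process and across all processes.
n_local = jax.local_device_count()
n_global = jax.device_count()
print(jax.devices())
print(jax.process_index(), jax.process_count(), n_local, n_global)

# pmap maps over the leading axis, placing one slice on each device,
# so that axis must not be larger than the number of local devices.
x = jnp.arange(n_local * 2.0).reshape(n_local, 2)
doubled = jax.pmap(lambda v: v * 2)(x)

print(doubled)
```
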
.. autofunction:: jit
.. autofunction:: disable_jit
.. autofunction:: xla_computation
.. autofunction:: make_jaxpr
.. autofunction:: eval_shape
.. autofunction:: device_put
.. autofunction:: device_put_replicated
.. autofunction:: device_put_sharded
.. autofunction:: device_get
.. autofunction:: block_until_ready
.. autofunction:: default_backend
.. autofunction:: named_call

.. autofunction:: grad
.. autofunction:: value_and_grad
.. autofunction:: jacfwd
.. autofunction:: jacrev
.. autofunction:: hessian
.. autofunction:: jvp
.. autofunction:: linearize
.. autofunction:: linear_transpose
.. autofunction:: vjp

.. autoclass:: custom_jvp

    .. automethod:: defjvp
    .. automethod:: defjvps

.. autoclass:: custom_vjp

    .. automethod:: defvjp

.. autofunction:: closure_convert
.. autofunction:: checkpoint

.. autofunction:: vmap
.. autofunction:: jax.numpy.vectorize
    :noindex:

.. autofunction:: pmap
.. autofunction:: devices
.. autofunction:: local_devices
.. autofunction:: process_index
.. autofunction:: device_count
.. autofunction:: local_device_count
.. autofunction:: process_count