.. currentmodule:: jax

Public API: jax package
=======================

Subpackages
-----------

.. toctree::
   :maxdepth: 1

   jax.numpy
   jax.scipy
   jax.config
   jax.debug
   jax.dlpack
   jax.distributed
   jax.example_libraries
   jax.experimental
   jax.flatten_util
   jax.image
   jax.lax
   jax.nn
   jax.ops
   jax.profiler
   jax.random
   jax.tree_util

.. toctree::
   :hidden:

   jax.lib

.. _jax-jit:

Just-in-time compilation (:code:`jit`)
--------------------------------------

.. autosummary::
   :toctree: _autosummary

   jit
   disable_jit
   ensure_compile_time_eval
   xla_computation
   make_jaxpr
   eval_shape
   device_put
   device_put_replicated
   device_put_sharded
   device_get
   default_backend
   named_call
   named_scope
   block_until_ready

.. _jax-grad:

Automatic differentiation
-------------------------

.. autosummary::
   :toctree: _autosummary

   grad
   value_and_grad
   jacfwd
   jacrev
   hessian
   jvp
   linearize
   linear_transpose
   vjp
   custom_jvp
   custom_vjp
   closure_convert
   checkpoint

Vectorization (:code:`vmap`)
----------------------------

.. autosummary::
   :toctree: _autosummary

   vmap
   numpy.vectorize

Parallelization (:code:`pmap`)
------------------------------

.. autosummary::
   :toctree: _autosummary

   pmap
   devices
   local_devices
   process_index
   device_count
   local_device_count
   process_count

Callbacks
---------

.. autosummary::
   :toctree: _autosummary

   pure_callback
   debug.callback
