mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00

The `jax.host_ids` function has long been deprecated, but the suggested alternative of `list(range(jax.process_count()))` relies on the current behavior that the list of process indices is always dense. In the future we may want to allow dynamic addition and removal of processes, in which case `jax.process_count` and `jax.process_indices` would need to be updated, so it is useful for users to have this forward-compatible interface now. PiperOrigin-RevId: 662142636
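The difference between the two idioms described above can be sketched as follows. This is a minimal illustration, not code from the change itself: the helper name `all_process_indices` is hypothetical, and the `hasattr` fallback is an assumption about how a caller might stay compatible with JAX releases that predate `jax.process_indices`.

```python
import jax


def all_process_indices():
    """Return the indices of all known processes (hypothetical helper).

    Prefers jax.process_indices(), which does not require callers to assume
    that indices form a dense range. Falls back to the older dense-range
    idiom on JAX versions that do not provide it (an assumption made for
    illustration only).
    """
    if hasattr(jax, "process_indices"):
        return list(jax.process_indices())
    # Older idiom: only correct while process indices are dense.
    return list(range(jax.process_count()))


indices = all_process_indices()
# The calling process should be one of the known processes.
assert jax.process_index() in indices
print(f"process {jax.process_index()} sees {len(indices)} process(es)")
```

With the current dense behavior both branches should return the same list; the first form is the one that would keep working if process indices ever stopped being contiguous.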
164 lines
2.4 KiB
ReStructuredText
.. currentmodule:: jax

Public API: jax package
=======================

Subpackages
-----------

.. toctree::
    :maxdepth: 1

    jax.numpy
    jax.scipy
    jax.lax
    jax.random
    jax.sharding
    jax.debug
    jax.dlpack
    jax.distributed
    jax.dtypes
    jax.flatten_util
    jax.image
    jax.nn
    jax.ops
    jax.profiler
    jax.stages
    jax.tree
    jax.tree_util
    jax.typing
    jax.export
    jax.extend
    jax.example_libraries
    jax.experimental

.. toctree::
    :hidden:

    jax.lib

Configuration
-------------

.. autosummary::
    :toctree: _autosummary

    config
    check_tracer_leaks
    checking_leaks
    debug_nans
    debug_infs
    default_device
    default_matmul_precision
    default_prng_impl
    enable_checks
    enable_custom_prng
    enable_custom_vjp_by_custom_transpose
    log_compiles
    numpy_rank_promotion
    transfer_guard

.. _jax-jit:

Just-in-time compilation (:code:`jit`)
--------------------------------------

.. autosummary::
    :toctree: _autosummary

    jit
    disable_jit
    ensure_compile_time_eval
    xla_computation
    make_jaxpr
    eval_shape
    ShapeDtypeStruct
    device_put
    device_put_replicated
    device_put_sharded
    device_get
    default_backend
    named_call
    named_scope
    block_until_ready

.. _jax-grad:

Automatic differentiation
-------------------------

.. autosummary::
    :toctree: _autosummary

    grad
    value_and_grad
    jacfwd
    jacrev
    hessian
    jvp
    linearize
    linear_transpose
    vjp
    custom_jvp
    custom_vjp
    custom_gradient
    closure_convert
    checkpoint

jax.Array (:code:`jax.Array`)
-----------------------------

.. autosummary::
    :toctree: _autosummary

    Array
    make_array_from_callback
    make_array_from_single_device_arrays
    make_array_from_process_local_data

Vectorization (:code:`vmap`)
----------------------------

.. autosummary::
    :toctree: _autosummary

    vmap
    numpy.vectorize

Parallelization (:code:`pmap`)
------------------------------

.. autosummary::
    :toctree: _autosummary

    pmap
    devices
    local_devices
    process_index
    device_count
    local_device_count
    process_count
    process_indices

Callbacks
---------

.. autosummary::
    :toctree: _autosummary

    pure_callback
    experimental.io_callback
    debug.callback
    debug.print

Miscellaneous
-------------

.. autosummary::
    :toctree: _autosummary

    Device
    print_environment_info
    live_arrays
    clear_caches