diff --git a/docs/jax.config.rst b/docs/jax.config.rst deleted file mode 100644 index 198086a8f..000000000 --- a/docs/jax.config.rst +++ /dev/null @@ -1,22 +0,0 @@ -.. currentmodule:: jax - -JAX configuration -================= - -.. autosummary:: - :toctree: _autosummary - - config - check_tracer_leaks - checking_leaks - debug_nans - debug_infs - default_device - default_matmul_precision - default_prng_impl - enable_checks - enable_custom_prng - enable_custom_vjp_by_custom_transpose - log_compiles - numpy_rank_promotion - transfer_guard diff --git a/docs/jax.debug.rst b/docs/jax.debug.rst index 6269ad5a7..e0d91fed9 100644 --- a/docs/jax.debug.rst +++ b/docs/jax.debug.rst @@ -1,6 +1,6 @@ -jax.debug package -================= +``jax.debug`` module +==================== .. currentmodule:: jax.debug diff --git a/docs/jax.distributed.rst b/docs/jax.distributed.rst index b3024e4cc..aa1938ee7 100644 --- a/docs/jax.distributed.rst +++ b/docs/jax.distributed.rst @@ -1,5 +1,5 @@ -jax.distributed module -====================== +``jax.distributed`` module +========================== .. currentmodule:: jax.distributed diff --git a/docs/jax.dlpack.rst b/docs/jax.dlpack.rst index 996ee3f0e..4a6790527 100644 --- a/docs/jax.dlpack.rst +++ b/docs/jax.dlpack.rst @@ -1,5 +1,5 @@ -jax.dlpack module -================= +``jax.dlpack`` module +===================== .. currentmodule:: jax.dlpack diff --git a/docs/jax.example_libraries.optimizers.rst b/docs/jax.example_libraries.optimizers.rst index 8d354566f..885891a91 100644 --- a/docs/jax.example_libraries.optimizers.rst +++ b/docs/jax.example_libraries.optimizers.rst @@ -1,5 +1,5 @@ -jax.example_libraries.optimizers module -======================================= +``jax.example_libraries.optimizers`` module +=========================================== .. automodule:: jax.example_libraries.optimizers :members: diff --git a/docs/jax.example_libraries.rst b/docs/jax.example_libraries.rst index cf75b8796..7314a717c 100644 --- a/docs/jax.example_libraries.rst +++ b/docs/jax.example_libraries.rst @@ -1,5 +1,5 @@ -jax.example_libraries package -============================= +``jax.example_libraries`` module +================================ JAX provides some small, experimental libraries for machine learning. These libraries are in part about providing tools and in part about serving as diff --git a/docs/jax.example_libraries.stax.rst b/docs/jax.example_libraries.stax.rst index ed1ff8cef..0352f16bc 100644 --- a/docs/jax.example_libraries.stax.rst +++ b/docs/jax.example_libraries.stax.rst @@ -1,5 +1,5 @@ -jax.example_libraries.stax module -================================= +``jax.example_libraries.stax`` module +===================================== .. automodule:: jax.example_libraries.stax :members: diff --git a/docs/jax.experimental.checkify.rst b/docs/jax.experimental.checkify.rst index ecf253aa6..5f1d60b06 100644 --- a/docs/jax.experimental.checkify.rst +++ b/docs/jax.experimental.checkify.rst @@ -1,5 +1,5 @@ -jax.experimental.checkify module -===================================== +``jax.experimental.checkify`` module +==================================== .. automodule:: jax.experimental.checkify diff --git a/docs/jax.experimental.global_device_array.rst b/docs/jax.experimental.global_device_array.rst index bbd1e251a..40fb1c95a 100644 --- a/docs/jax.experimental.global_device_array.rst +++ b/docs/jax.experimental.global_device_array.rst @@ -1,5 +1,5 @@ -jax.experimental.global_device_array module -=========================================== +``jax.experimental.global_device_array`` module +=============================================== .. automodule:: jax.experimental.global_device_array diff --git a/docs/jax.experimental.host_callback.rst b/docs/jax.experimental.host_callback.rst index 6cc168611..8ac26b2c3 100644 --- a/docs/jax.experimental.host_callback.rst +++ b/docs/jax.experimental.host_callback.rst @@ -1,5 +1,5 @@ -jax.experimental.host_callback module -===================================== +``jax.experimental.host_callback`` module +========================================= .. automodule:: jax.experimental.host_callback diff --git a/docs/jax.experimental.jet.rst b/docs/jax.experimental.jet.rst index c674f0709..3c5432f88 100644 --- a/docs/jax.experimental.jet.rst +++ b/docs/jax.experimental.jet.rst @@ -1,5 +1,5 @@ -jax.experimental.jet module -=========================== +``jax.experimental.jet`` module +=============================== .. automodule:: jax.experimental.jet diff --git a/docs/jax.experimental.maps.rst b/docs/jax.experimental.maps.rst index 2dd34e89e..938ffdf49 100644 --- a/docs/jax.experimental.maps.rst +++ b/docs/jax.experimental.maps.rst @@ -1,5 +1,5 @@ -jax.experimental.maps module -============================ +``jax.experimental.maps`` module +================================ .. automodule:: jax.experimental.maps diff --git a/docs/jax.experimental.pjit.rst b/docs/jax.experimental.pjit.rst index 64f6b2c02..34fe95ef0 100644 --- a/docs/jax.experimental.pjit.rst +++ b/docs/jax.experimental.pjit.rst @@ -1,5 +1,5 @@ -jax.experimental.pjit module -============================ +``jax.experimental.pjit`` module +================================ .. automodule:: jax.experimental.pjit diff --git a/docs/jax.experimental.rst b/docs/jax.experimental.rst index 8c4af5a94..da51168db 100644 --- a/docs/jax.experimental.rst +++ b/docs/jax.experimental.rst @@ -1,7 +1,7 @@ .. currentmodule:: jax.experimental -jax.experimental package -======================== +``jax.experimental`` module +=========================== ``jax.experimental.optix`` has been moved into its own Python package (https://github.com/deepmind/optax). diff --git a/docs/jax.experimental.sparse.rst b/docs/jax.experimental.sparse.rst index dcbe62b1c..916b2e636 100644 --- a/docs/jax.experimental.sparse.rst +++ b/docs/jax.experimental.sparse.rst @@ -1,5 +1,5 @@ -jax.experimental.sparse module -============================== +``jax.experimental.sparse`` module +================================== .. automodule:: jax.experimental.sparse diff --git a/docs/jax.flatten_util.rst b/docs/jax.flatten_util.rst index 734aac0d4..df6c09345 100644 --- a/docs/jax.flatten_util.rst +++ b/docs/jax.flatten_util.rst @@ -1,5 +1,5 @@ -jax.flatten_util package -======================== +``jax.flatten_util`` module +=========================== .. currentmodule:: jax.flatten_util diff --git a/docs/jax.image.rst b/docs/jax.image.rst index 8d4aad343..078d2ff1e 100644 --- a/docs/jax.image.rst +++ b/docs/jax.image.rst @@ -1,5 +1,5 @@ -jax.image package -================= +``jax.image`` module +==================== .. currentmodule:: jax.image diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index d69fcffdc..ed6d24b20 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -1,5 +1,5 @@ -jax.lax package -=============== +``jax.lax`` module +================== .. automodule:: jax.lax diff --git a/docs/jax.lib.rst b/docs/jax.lib.rst index 31509cd9f..cf6ddcc86 100644 --- a/docs/jax.lib.rst +++ b/docs/jax.lib.rst @@ -1,5 +1,5 @@ -jax.lib package -=============== +``jax.lib`` module +================== The `jax.lib` package is a set of internal tools and types for bridging between JAX's Python frontend and its XLA backend. diff --git a/docs/jax.nn.initializers.rst b/docs/jax.nn.initializers.rst index e500ef7a8..9b864289f 100644 --- a/docs/jax.nn.initializers.rst +++ b/docs/jax.nn.initializers.rst @@ -1,6 +1,5 @@ - -jax.nn.initializers package -=========================== +``jax.nn.initializers`` module +============================== .. currentmodule:: jax.nn.initializers diff --git a/docs/jax.nn.rst b/docs/jax.nn.rst index 3c5d3f1a4..31ad050df 100644 --- a/docs/jax.nn.rst +++ b/docs/jax.nn.rst @@ -1,5 +1,5 @@ -jax.nn package +``jax.nn`` module ================= .. currentmodule:: jax.nn @@ -13,7 +13,7 @@ jax.nn package Activation functions ------------------------- +-------------------- .. autosummary:: :toctree: _autosummary diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index f1b98d892..9d5f42dad 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -1,5 +1,5 @@ -jax.numpy package -================= +``jax.numpy`` module +==================== .. currentmodule:: jax.numpy diff --git a/docs/jax.ops.rst b/docs/jax.ops.rst index 80f9064f0..7356fcf42 100644 --- a/docs/jax.ops.rst +++ b/docs/jax.ops.rst @@ -1,6 +1,5 @@ - -jax.ops package -=============== +``jax.ops`` module +================== .. currentmodule:: jax.ops diff --git a/docs/jax.profiler.rst b/docs/jax.profiler.rst index 7e99fc388..8e3c29ba8 100644 --- a/docs/jax.profiler.rst +++ b/docs/jax.profiler.rst @@ -1,7 +1,7 @@ .. currentmodule:: jax.profiler -jax.profiler module -=================== +``jax.profiler`` module +======================= .. automodule:: jax.profiler diff --git a/docs/jax.random.rst b/docs/jax.random.rst index e1e12193f..a263978e3 100644 --- a/docs/jax.random.rst +++ b/docs/jax.random.rst @@ -1,5 +1,5 @@ -jax.random package -================== +``jax.random`` module +===================== .. automodule:: jax.random diff --git a/docs/jax.rst b/docs/jax.rst index a309b3ca0..1d3eaaf89 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -7,32 +7,52 @@ Subpackages ----------- .. toctree:: - :maxdepth: 1 + :maxdepth: 1 - jax.numpy - jax.scipy - jax.sharding - jax.config - jax.debug - jax.dlpack - jax.distributed - jax.example_libraries - jax.experimental - jax.flatten_util - jax.image - jax.lax - jax.nn - jax.ops - jax.profiler - jax.random - jax.stages - jax.tree_util + jax.numpy + jax.scipy + jax.lax + jax.random + jax.sharding + jax.debug + jax.dlpack + jax.distributed + jax.flatten_util + jax.image + jax.nn + jax.ops + jax.profiler + jax.stages + jax.tree_util + jax.example_libraries + jax.experimental .. toctree:: :hidden: jax.lib +Configuration +------------- + +.. autosummary:: + :toctree: _autosummary + + config + check_tracer_leaks + checking_leaks + debug_nans + debug_infs + default_device + default_matmul_precision + default_prng_impl + enable_checks + enable_custom_prng + enable_custom_vjp_by_custom_transpose + log_compiles + numpy_rank_promotion + transfer_guard + .. _jax-jit: Just-in-time compilation (:code:`jit`) diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index 58414c3f5..633547cac 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -1,5 +1,5 @@ -jax.scipy package -================= +``jax.scipy`` module +==================== jax.scipy.fft ------------- diff --git a/docs/jax.sharding.rst b/docs/jax.sharding.rst index 9bf66cd90..b066639a6 100644 --- a/docs/jax.sharding.rst +++ b/docs/jax.sharding.rst @@ -1,5 +1,5 @@ -jax.sharding package -==================== +``jax.sharding`` module +======================= .. automodule:: jax.sharding diff --git a/docs/jax.stages.rst b/docs/jax.stages.rst index 892e1e12e..dddbc1135 100644 --- a/docs/jax.stages.rst +++ b/docs/jax.stages.rst @@ -1,5 +1,5 @@ -jax.stages package -================== +``jax.stages`` module +===================== .. automodule:: jax.stages diff --git a/docs/jax.tree_util.rst b/docs/jax.tree_util.rst index de1a3f3a1..e83c50fef 100644 --- a/docs/jax.tree_util.rst +++ b/docs/jax.tree_util.rst @@ -1,5 +1,5 @@ -jax.tree_util package -===================== +``jax.tree_util`` module +======================== .. currentmodule:: jax.tree_util