diff --git a/docs/notes.rst b/docs/notes.rst index aead9e23e..082656380 100644 --- a/docs/notes.rst +++ b/docs/notes.rst @@ -2,14 +2,33 @@ Notes ----- +This section contains shorter notes on topics relevant to using JAX; see also the +longer design discussions in :doc:`jep/index`. + +Dependencies and version compatibility: + - :doc:`api_compatibility` outlines JAX's policies with regard to API compatibility across releases. + - :doc:`deprecation` outlines JAX's policies with regard to compatibility with Python and NumPy. + +Migrations and deprecations: + - :doc:`jax_array_migration` summarizes the changes to the default array type in jax v 0.4.1 + +Memory and computation usage: + - :doc:`async_dispatch` describes JAX's asynchronous dispatch model. + - :doc:`concurrency` describes how JAX interacts with other Python concurrency. + - :doc:`gpu_memory_allocation` describes how JAX interacts with memory allocation on GPU. + +Programmer guardrails: + - :doc:`rank_promotion_warning` describes how to configure :mod:`jax.numpy` to avoid implicit rank promotion. + .. toctree:: + :hidden: :maxdepth: 1 api_compatibility deprecation + jax_array_migration + async_dispatch concurrency gpu_memory_allocation - rank_promotion_warning - jax_array_migration - type_promotion \ No newline at end of file + rank_promotion_warning \ No newline at end of file diff --git a/docs/rank_promotion_warning.rst b/docs/rank_promotion_warning.rst index 0c29af2f4..500cbbc42 100644 --- a/docs/rank_promotion_warning.rst +++ b/docs/rank_promotion_warning.rst @@ -23,19 +23,27 @@ expressions requiring rank promotion can lead to a warning, error, or can be allowed just like regular NumPy. The configuration option is named :code:`jax_numpy_rank_promotion` and it can take on string values :code:`allow`, :code:`warn`, and :code:`raise`. The default setting is -:code:`warn`, which raises a warning on the first occurrence of rank promotion. -The :code:`raise` setting raises an error on rank promotion, and :code:`allow` -allows rank promotion without warning or error. +:code:`allow`, which allows rank promotion without warning or error. +The :code:`raise` setting raises an error on rank promotion, and :code:`warn` +raises a warning on the first occurrence of rank promotion. -As with most other JAX configuration options, you can set this option in -several ways. One is by using :code:`jax.config` in your code: +Rank promotion can be enabled or disabled locally with the :func:`jax.numpy_rank_promotion` +context manager: + +.. code-block:: python + + with jax.numpy_rank_promotion("warn"): + z = x + y + +This configuration can also be set globally in several ways. +One is by using :code:`jax.config` in your code: .. code-block:: python from jax.config import config - config.update("jax_numpy_rank_promotion", "allow") + config.update("jax_numpy_rank_promotion", "warn") You can also set the option using the environment variable :code:`JAX_NUMPY_RANK_PROMOTION`, for example as -:code:`JAX_NUMPY_RANK_PROMOTION='raise'`. Finally, when using :code:`absl-py` +:code:`JAX_NUMPY_RANK_PROMOTION='warn'`. Finally, when using :code:`absl-py` the option can be set with a command-line flag. diff --git a/docs/user_guides.rst b/docs/user_guides.rst index dc5825614..82c2eea03 100644 --- a/docs/user_guides.rst +++ b/docs/user_guides.rst @@ -7,7 +7,6 @@ User Guides .. toctree:: :maxdepth: 1 - async_dispatch aot jaxpr pytrees @@ -16,4 +15,5 @@ User Guides profiling device_memory_profiling transfer_guard + type_promotion notebooks/external_callbacks