DOC: re-organize notes section

This commit is contained in:
Jake VanderPlas 2023-01-30 11:57:06 -08:00
parent 65ef487a82
commit 01b5799239
3 changed files with 38 additions and 11 deletions

View File

@ -2,14 +2,33 @@
Notes
-----
This section contains shorter notes on topics relevant to using JAX; see also the
longer design discussions in :doc:`jep/index`.
Dependencies and version compatibility:
- :doc:`api_compatibility` outlines JAX's policies with regard to API compatibility across releases.
- :doc:`deprecation` outlines JAX's policies with regard to compatibility with Python and NumPy.
Migrations and deprecations:
- :doc:`jax_array_migration` summarizes the changes to the default array type in jax v 0.4.1
Memory and computation usage:
- :doc:`async_dispatch` describes JAX's asynchronous dispatch model.
- :doc:`concurrency` describes how JAX interacts with other Python concurrency.
- :doc:`gpu_memory_allocation` describes how JAX interacts with memory allocation on GPU.
Programmer guardrails:
- :doc:`rank_promotion_warning` describes how to configure :mod:`jax.numpy` to avoid implicit rank promotion.
.. toctree::
:hidden:
:maxdepth: 1
api_compatibility
deprecation
jax_array_migration
async_dispatch
concurrency
gpu_memory_allocation
rank_promotion_warning
jax_array_migration
type_promotion
rank_promotion_warning

View File

@ -23,19 +23,27 @@ expressions requiring rank promotion can lead to a warning, error, or can be
allowed just like regular NumPy. The configuration option is named
:code:`jax_numpy_rank_promotion` and it can take on string values
:code:`allow`, :code:`warn`, and :code:`raise`. The default setting is
:code:`warn`, which raises a warning on the first occurrence of rank promotion.
The :code:`raise` setting raises an error on rank promotion, and :code:`allow`
allows rank promotion without warning or error.
:code:`allow`, which allows rank promotion without warning or error.
The :code:`raise` setting raises an error on rank promotion, and :code:`warn`
raises a warning on the first occurrence of rank promotion.
As with most other JAX configuration options, you can set this option in
several ways. One is by using :code:`jax.config` in your code:
Rank promotion can be enabled or disabled locally with the :func:`jax.numpy_rank_promotion`
context manager:
.. code-block:: python
with jax.numpy_rank_promotion("warn"):
z = x + y
This configuration can also be set globally in several ways.
One is by using :code:`jax.config` in your code:
.. code-block:: python
from jax.config import config
config.update("jax_numpy_rank_promotion", "allow")
config.update("jax_numpy_rank_promotion", "warn")
You can also set the option using the environment variable
:code:`JAX_NUMPY_RANK_PROMOTION`, for example as
:code:`JAX_NUMPY_RANK_PROMOTION='raise'`. Finally, when using :code:`absl-py`
:code:`JAX_NUMPY_RANK_PROMOTION='warn'`. Finally, when using :code:`absl-py`
the option can be set with a command-line flag.

View File

@ -7,7 +7,6 @@ User Guides
.. toctree::
:maxdepth: 1
async_dispatch
aot
jaxpr
pytrees
@ -16,4 +15,5 @@ User Guides
profiling
device_memory_profiling
transfer_guard
type_promotion
notebooks/external_callbacks