mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
DOC: re-organize notes section
This commit is contained in:
parent
65ef487a82
commit
01b5799239
@ -2,14 +2,33 @@
|
||||
|
||||
Notes
|
||||
-----
|
||||
This section contains shorter notes on topics relevant to using JAX; see also the
|
||||
longer design discussions in :doc:`jep/index`.
|
||||
|
||||
Dependencies and version compatibility:
|
||||
- :doc:`api_compatibility` outlines JAX's policies with regard to API compatibility across releases.
|
||||
- :doc:`deprecation` outlines JAX's policies with regard to compatibility with Python and NumPy.
|
||||
|
||||
Migrations and deprecations:
|
||||
- :doc:`jax_array_migration` summarizes the changes to the default array type in jax v 0.4.1
|
||||
|
||||
Memory and computation usage:
|
||||
- :doc:`async_dispatch` describes JAX's asynchronous dispatch model.
|
||||
- :doc:`concurrency` describes how JAX interacts with other Python concurrency.
|
||||
- :doc:`gpu_memory_allocation` describes how JAX interacts with memory allocation on GPU.
|
||||
|
||||
Programmer guardrails:
|
||||
- :doc:`rank_promotion_warning` describes how to configure :mod:`jax.numpy` to avoid implicit rank promotion.
|
||||
|
||||
|
||||
.. toctree::
|
||||
:hidden:
|
||||
:maxdepth: 1
|
||||
|
||||
api_compatibility
|
||||
deprecation
|
||||
jax_array_migration
|
||||
async_dispatch
|
||||
concurrency
|
||||
gpu_memory_allocation
|
||||
rank_promotion_warning
|
||||
jax_array_migration
|
||||
type_promotion
|
||||
rank_promotion_warning
|
@ -23,19 +23,27 @@ expressions requiring rank promotion can lead to a warning, error, or can be
|
||||
allowed just like regular NumPy. The configuration option is named
|
||||
:code:`jax_numpy_rank_promotion` and it can take on string values
|
||||
:code:`allow`, :code:`warn`, and :code:`raise`. The default setting is
|
||||
:code:`warn`, which raises a warning on the first occurrence of rank promotion.
|
||||
The :code:`raise` setting raises an error on rank promotion, and :code:`allow`
|
||||
allows rank promotion without warning or error.
|
||||
:code:`allow`, which allows rank promotion without warning or error.
|
||||
The :code:`raise` setting raises an error on rank promotion, and :code:`warn`
|
||||
raises a warning on the first occurrence of rank promotion.
|
||||
|
||||
As with most other JAX configuration options, you can set this option in
|
||||
several ways. One is by using :code:`jax.config` in your code:
|
||||
Rank promotion can be enabled or disabled locally with the :func:`jax.numpy_rank_promotion`
|
||||
context manager:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
with jax.numpy_rank_promotion("warn"):
|
||||
z = x + y
|
||||
|
||||
This configuration can also be set globally in several ways.
|
||||
One is by using :code:`jax.config` in your code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from jax.config import config
|
||||
config.update("jax_numpy_rank_promotion", "allow")
|
||||
config.update("jax_numpy_rank_promotion", "warn")
|
||||
|
||||
You can also set the option using the environment variable
|
||||
:code:`JAX_NUMPY_RANK_PROMOTION`, for example as
|
||||
:code:`JAX_NUMPY_RANK_PROMOTION='raise'`. Finally, when using :code:`absl-py`
|
||||
:code:`JAX_NUMPY_RANK_PROMOTION='warn'`. Finally, when using :code:`absl-py`
|
||||
the option can be set with a command-line flag.
|
||||
|
@ -7,7 +7,6 @@ User Guides
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
async_dispatch
|
||||
aot
|
||||
jaxpr
|
||||
pytrees
|
||||
@ -16,4 +15,5 @@ User Guides
|
||||
profiling
|
||||
device_memory_profiling
|
||||
transfer_guard
|
||||
type_promotion
|
||||
notebooks/external_callbacks
|
||||
|
Loading…
x
Reference in New Issue
Block a user