DOC: document strict dtype promotion mode

This commit is contained in:
Jake VanderPlas 2024-04-12 14:05:05 -07:00
parent 386be2d307
commit 2be17dc778

View File

@ -209,4 +209,45 @@ strongly-typed array value:
.. code-block:: python
>>> jnp.asarray(2, dtype='int32')
Array(2, dtype=int32)
Array(2, dtype=int32)
.. _strict-dtype-promotion:
Strict dtype promotion
----------------------
In some contexts it can be useful to disable implicit type promotion behavior, and
instead require all promotions to be explicit. This can be done in JAX by setting the
``jax_numpy_dtype_promtion`` flag to ``'strict'``. Locally, it can be done with a\
context manager:
.. code-block:: python
>>> x = jnp.float32(1)
>>> y = jnp.int32(1)
>>> with jax.numpy_dtype_promotion('strict'):
... z = x + y # doctest: +IGNORE_EXCEPTION_DETAIL
...
Traceback (most recent call last):
TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting
inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.
For convenience, strict promotion mode will still allow safe weakly-typed promotions,
so you can still write code code that mixes JAX arrays and Python scalars:
.. code-block:: python
>>> with jax.numpy_dtype_promotion('strict'):
... z = x + 1
>>> print(z)
2.0
If you would prefer to set the configuration globally, you can do so using the standard
configuration update::
jax.config.update('jax_numpy_dtype_promotion', 'strict')
To restore the default standard type promotion, set this configuration to ``'standard'``::
jax.config.update('jax_numpy_dtype_promotion', 'standard')