DOC: add documentation note about default dtypes

This commit is contained in:
Jake VanderPlas 2025-03-28 15:20:58 -07:00
parent b3a2c5341d
commit dafebd0d7f
4 changed files with 95 additions and 5 deletions

82
docs/default_dtypes.md Normal file
View File

@ -0,0 +1,82 @@
(default-dtypes)=
# Default dtypes and the X64 flag
JAX strives to meet the needs of a range of numerical computing practitioners, who
sometimes have conflicting preferences. When it comes to default dtypes, there are
two different camps:
- Classic scientific computing practitioners (i.e. users of tools like {mod}`numpy` or
{mod}`scipy`) tend to value accuracy of computations foremost: such users would
prefer that computations default to the **widest available representation**: e.g.
floating point values should default to `float64`, integers to `int64`, etc.
- AI researchers (i.e. folks implementing and training neural networks) tend to value
speed over accuracy, to the point where they have developed special data types like
[bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) and others
which deliberately discard the least significant bits in order to speed up computation.
For these users, the mere presence of a float64 value in their computation can lead
to programs that are slow at best, and incompatible with their hardware at worst!
These users would prefer that computations default to `float32` or `int32`.
The main mechanism JAX offers for this is the `jax_enable_x64` flag, which controls
whether 64-bit values can be created at all. By default this flag is set to `False`
(serving the needs of AI researchers and practitioners), but can be set to `True`
by users who value accuracy over computational speed.
## Default setting: 32-bits everywhere
By default `jax_enable_x64` is set to False, and so {mod}`jax.numpy` array creation
functions will default to returning 32-bit values.
For example:
```python
>>> import jax.numpy as jnp
>>> jnp.arange(5)
Array([0, 1, 2, 3, 4], dtype=int32)
>>> jnp.zeros(5)
Array([0., 0., 0., 0., 0.], dtype=float32)
>>> jnp.ones(5, dtype=int)
Array([1, 1, 1, 1, 1], dtype=int32)
```
Beyond defaults, because 64-bit values can be so poisonous to AI workflows, having
this flag set to False prevents you from creating 64-bit arrays at all! For example:
```
>>> jnp.arange(5, dtype='float64') # doctest: +SKIP
UserWarning: Explicitly requested dtype float64 requested in arange is not available, and will be
truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the
JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
Array([0., 1., 2., 3., 4.], dtype=float32)
```
## The X64 flag: enabling 64-bit values
To work in the "other mode" where functions default to producing 64-bit values, you can set the
`jax_enable_x64` flag to `True`:
```python
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)
print(repr(jnp.arange(5)))
print(repr(jnp.zeros(5)))
print(repr(jnp.ones(5, dtype=int)))
```
```
Array([0, 1, 2, 3, 4], dtype=int64)
Array([0., 0., 0., 0., 0.], dtype=float64)
Array([1, 1, 1, 1, 1], dtype=int64)
```
The X64 configuration can also be set via the `JAX_ENABLE_X64` shell environment variable,
for example:
```bash
$ JAX_ENABLE_X64=1 python main.py
```
The X64 flag is intended as a **global setting** that should have one value for your whole
program, set at the top of your main file. A common feature request is for the flag to
be contextually configurable (e.g. enabling X64 just for one section of a long program):
this turns out to be difficult to implement within JAX's programming model, where code
execution may happen in a different context than code compilation. There is ongoing work
exploring the feasibility of relaxing this constraint, so stay tuned!

View File

@ -17,6 +17,10 @@ Memory and computation usage:
Programmer guardrails:
- :doc:`rank_promotion_warning` describes how to configure :mod:`jax.numpy` to avoid implicit rank promotion.
Arrays and data types:
- :doc:`type_promotion` describes JAX's implicit type promotion for functions of two or more values.
- :doc:`default_dtypes` describes how JAX determines the default dtype for array creation functions.
.. toctree::
:hidden:
@ -27,4 +31,6 @@ Programmer guardrails:
async_dispatch
concurrency
gpu_memory_allocation
rank_promotion_warning
rank_promotion_warning
type_promotion
default_dtypes

View File

@ -26,7 +26,6 @@ or deployed codebases.
errors
aot
export/index
type_promotion
transfer_guard
.. toctree::

View File

@ -50,7 +50,8 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *,
Args:
shape: int or sequence of ints specifying the shape of the created array.
dtype: optional dtype for the created array; defaults to floating point.
dtype: optional dtype for the created array; defaults to float32 or float64
depending on the X64 configuration (see :ref:`default-dtypes`).
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
@ -87,7 +88,8 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *,
Args:
shape: int or sequence of ints specifying the shape of the created array.
dtype: optional dtype for the created array; defaults to floating point.
dtype: optional dtype for the created array; defaults to float32 or float64
depending on the X64 configuration (see :ref:`default-dtypes`).
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.
@ -126,7 +128,8 @@ def empty(shape: Any, dtype: DTypeLike | None = None, *,
Args:
shape: int or sequence of ints specifying the shape of the created array.
dtype: optional dtype for the created array; defaults to floating point.
dtype: optional dtype for the created array; defaults to float32 or float64
depending on the X64 configuration (see :ref:`default-dtypes`).
device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.