Add TPU precision details to README gotchas

This commit is contained in:
Jamie Townsend 2021-07-16 12:24:19 +02:00 committed by GitHub
parent b744a84fdc
commit 6ca775b10a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -368,6 +368,10 @@ Some standouts:
double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision)
(64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
startup (or set the environment variable `JAX_ENABLE_X64=True`).
On TPU, JAX uses 32-bit values by default for everything _except_ internal
temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`.
Those ops have a `precision` parameter which can be used to simulate
true 32-bit, with a cost of possibly slower runtime.
1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
and NumPy types aren't preserved, namely `np.add(1, np.array([2],
np.float32)).dtype` is `float64` rather than `float32`.