mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add TPU precision details to README gotchas
This commit is contained in:
parent
b744a84fdc
commit
6ca775b10a
@ -368,6 +368,10 @@ Some standouts:
|
||||
double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision)
|
||||
(64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
|
||||
startup (or set the environment variable `JAX_ENABLE_X64=True`).
|
||||
On TPU, JAX uses 32-bit values by default for everything _except_ internal
|
||||
temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`.
|
||||
Those ops have a `precision` parameter which can be used to simulate
|
||||
true 32-bit, with a cost of possibly slower runtime.
|
||||
1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
|
||||
and NumPy types aren't preserved, namely `np.add(1, np.array([2],
|
||||
np.float32)).dtype` is `float64` rather than `float32`.
|
||||
|
Loading…
x
Reference in New Issue
Block a user