mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add TPU precision details to README gotchas
This commit is contained in:
parent
b744a84fdc
commit
6ca775b10a
@ -368,6 +368,10 @@ Some standouts:
|
|||||||
double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision)
|
double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision)
|
||||||
(64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
|
(64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
|
||||||
startup (or set the environment variable `JAX_ENABLE_X64=True`).
|
startup (or set the environment variable `JAX_ENABLE_X64=True`).
|
||||||
|
On TPU, JAX uses 32-bit values by default for everything _except_ internal
|
||||||
|
temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`.
|
||||||
|
Those ops have a `precision` parameter which can be used to simulate
|
||||||
|
true 32-bit, with a cost of possibly slower runtime.
|
||||||
1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
|
1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
|
||||||
and NumPy types aren't preserved, namely `np.add(1, np.array([2],
|
and NumPy types aren't preserved, namely `np.add(1, np.array([2],
|
||||||
np.float32)).dtype` is `float64` rather than `float32`.
|
np.float32)).dtype` is `float64` rather than `float32`.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user