Update README.md

This commit is contained in:
Marc van Zee 2021-06-08 07:21:40 +02:00 committed by GitHub
parent 7a3a160b26
commit 80e69d456e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -545,12 +545,12 @@ jax2tf.convert(jax.grad(f_jax, allow_int=True))(2))
### Different 64-bit precision in JAX and TensorFlow
JAX behaves somewhat differently than TensorFlow in the handling
of 32-bit vs. 64-bit values. However, always the `jax2tf.convert`
function behaves like the JAX function.
of 32-bit vs. 64-bit values. However, the `jax2tf.convert` function
always behaves like the JAX function.
JAX interprets the type of Python scalars differently based on
`JAX_ENABLE_X64` flag.
See the [JAX type promotion documentation](https://jax.readthedocs.io/en/latest/type_promotion.html).)
`JAX_ENABLE_X64` flag. (See
[JAX - The Sharp Bits: Double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).)
In the default configuration, the
flag is unset, and JAX interprets Python constants as 32-bit,
e.g., the type of `3.14` is `float32`. This is also what
@ -559,7 +559,7 @@ all explicitly-specified 64-bit values to be interpreted as
32-bit:
```
# with JAX_ENABLE_x64=0
# with JAX_ENABLE_X64=0
jnp.sin(3.14) # Has type float32
tf.math.sin(3.14) # Has type float32
@ -574,11 +574,11 @@ jax2tf.convert(jnp.sin)(np.float64(3.14)) # Has type float32
tf.function(jax2tf.convert(jnp.sin))(tf.Variable(3.14, tf.float64))
```
When the `JAX_ENABLE_x64` flas is set, JAX uses 64-bit types
When the `JAX_ENABLE_X64` flas is set, JAX uses 64-bit types
for Python scalars and respects the explicit 64-bit types:
```
# with JAX_ENABLE_x64=1
# with JAX_ENABLE_X64=1
jnp.sin(3.14) # Has type float64
tf.math.sin(3.14) # Has type float32