Merge pull request #6919 from marcvanzee:patch-3

PiperOrigin-RevId: 378352041
This commit is contained in:
jax authors 2021-06-09 01:53:27 -07:00
commit 86d2da44c0

View File

@ -559,12 +559,12 @@ jax2tf.convert(jax.grad(f_jax, allow_int=True))(2))
### Different 64-bit precision in JAX and TensorFlow
JAX behaves somewhat differently than TensorFlow in the handling
of 32-bit vs. 64-bit values. However, always the `jax2tf.convert`
function behaves like the JAX function.
of 32-bit vs. 64-bit values. However, the `jax2tf.convert` function
always behaves like the JAX function.
JAX interprets the type of Python scalars differently based on
`JAX_ENABLE_X64` flag.
See the [JAX type promotion documentation](https://jax.readthedocs.io/en/latest/type_promotion.html).)
`JAX_ENABLE_X64` flag. (See
[JAX - The Sharp Bits: Double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).)
In the default configuration, the
flag is unset, and JAX interprets Python constants as 32-bit,
e.g., the type of `3.14` is `float32`. This is also what
@ -573,7 +573,7 @@ all explicitly-specified 64-bit values to be interpreted as
32-bit:
```
# with JAX_ENABLE_x64=0
# with JAX_ENABLE_X64=0
jnp.sin(3.14) # Has type float32
tf.math.sin(3.14) # Has type float32
@ -588,11 +588,11 @@ jax2tf.convert(jnp.sin)(np.float64(3.14)) # Has type float32
tf.function(jax2tf.convert(jnp.sin))(tf.Variable(3.14, tf.float64))
```
When the `JAX_ENABLE_x64` flas is set, JAX uses 64-bit types
When the `JAX_ENABLE_X64` flas is set, JAX uses 64-bit types
for Python scalars and respects the explicit 64-bit types:
```
# with JAX_ENABLE_x64=1
# with JAX_ENABLE_X64=1
jnp.sin(3.14) # Has type float64
tf.math.sin(3.14) # Has type float32