mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6919 from marcvanzee:patch-3
PiperOrigin-RevId: 378352041
This commit is contained in:
commit
86d2da44c0
@ -559,12 +559,12 @@ jax2tf.convert(jax.grad(f_jax, allow_int=True))(2))
|
||||
### Different 64-bit precision in JAX and TensorFlow
|
||||
|
||||
JAX behaves somewhat differently than TensorFlow in the handling
|
||||
of 32-bit vs. 64-bit values. However, always the `jax2tf.convert`
|
||||
function behaves like the JAX function.
|
||||
of 32-bit vs. 64-bit values. However, the `jax2tf.convert` function
|
||||
always behaves like the JAX function.
|
||||
|
||||
JAX interprets the type of Python scalars differently based on
|
||||
`JAX_ENABLE_X64` flag.
|
||||
See the [JAX type promotion documentation](https://jax.readthedocs.io/en/latest/type_promotion.html).)
|
||||
`JAX_ENABLE_X64` flag. (See
|
||||
[JAX - The Sharp Bits: Double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).)
|
||||
In the default configuration, the
|
||||
flag is unset, and JAX interprets Python constants as 32-bit,
|
||||
e.g., the type of `3.14` is `float32`. This is also what
|
||||
@ -573,7 +573,7 @@ all explicitly-specified 64-bit values to be interpreted as
|
||||
32-bit:
|
||||
|
||||
```
|
||||
# with JAX_ENABLE_x64=0
|
||||
# with JAX_ENABLE_X64=0
|
||||
jnp.sin(3.14) # Has type float32
|
||||
tf.math.sin(3.14) # Has type float32
|
||||
|
||||
@ -588,11 +588,11 @@ jax2tf.convert(jnp.sin)(np.float64(3.14)) # Has type float32
|
||||
tf.function(jax2tf.convert(jnp.sin))(tf.Variable(3.14, tf.float64))
|
||||
```
|
||||
|
||||
When the `JAX_ENABLE_x64` flas is set, JAX uses 64-bit types
|
||||
When the `JAX_ENABLE_X64` flas is set, JAX uses 64-bit types
|
||||
for Python scalars and respects the explicit 64-bit types:
|
||||
|
||||
```
|
||||
# with JAX_ENABLE_x64=1
|
||||
# with JAX_ENABLE_X64=1
|
||||
jnp.sin(3.14) # Has type float64
|
||||
tf.math.sin(3.14) # Has type float32
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user