mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #10092 from jakevdp:sharp-bits-divergences
PiperOrigin-RevId: 438644183
This commit is contained in:
commit
d9403f626b
@ -1933,7 +1933,7 @@
|
||||
"id": "YTktlwTTMgFl"
|
||||
},
|
||||
"source": [
|
||||
"## Double (64bit) precision\n",
|
||||
"## 🔪 Double (64bit) precision\n",
|
||||
"\n",
|
||||
"At the moment, JAX by default enforces single-precision numbers to mitigate the Numpy API's tendency to aggressively promote operands to `double`. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!"
|
||||
]
|
||||
@ -2051,6 +2051,25 @@
|
||||
"id": "WAHjmL0E2XwO"
|
||||
},
|
||||
"source": [
|
||||
"## 🔪 Miscellaneous Divergences from NumPy\n",
|
||||
"\n",
|
||||
"While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.\n",
|
||||
"Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.\n",
|
||||
"\n",
|
||||
"- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details.\n",
|
||||
"- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype).\n",
|
||||
"\n",
|
||||
" Here is an example of an unsafe cast with differing results between NumPy and JAX:\n",
|
||||
" ```python\n",
|
||||
" >>> np.arange(254.0, 258.0).astype('uint8') \n",
|
||||
" array([254, 255, 0, 1], dtype=uint8)\n",
|
||||
"\n",
|
||||
" >>> jnp.arange(254.0, 258.0).astype('uint8') \n",
|
||||
" DeviceArray([254, 255, 255, 255], dtype=uint8)\n",
|
||||
" ```\n",
|
||||
" This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Fin.\n",
|
||||
"\n",
|
||||
"If something's not covered here that has caused you weeping and gnashing of teeth, please let us know and we'll extend these introductory _advisos_!"
|
||||
|
@ -1062,7 +1062,7 @@ When this code sees a nan in the output of an `@jit` function, it calls into the
|
||||
|
||||
+++ {"id": "YTktlwTTMgFl"}
|
||||
|
||||
## Double (64bit) precision
|
||||
## 🔪 Double (64bit) precision
|
||||
|
||||
At the moment, JAX by default enforces single-precision numbers to mitigate the Numpy API's tendency to aggressively promote operands to `double`. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!
|
||||
|
||||
@ -1133,6 +1133,25 @@ x.dtype # --> dtype('float64')
|
||||
|
||||
+++ {"id": "WAHjmL0E2XwO"}
|
||||
|
||||
## 🔪 Miscellaneous Divergences from NumPy
|
||||
|
||||
While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.
|
||||
Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.
|
||||
|
||||
- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details.
|
||||
- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype).
|
||||
|
||||
Here is an example of an unsafe cast with differing results between NumPy and JAX:
|
||||
```python
|
||||
>>> np.arange(254.0, 258.0).astype('uint8')
|
||||
array([254, 255, 0, 1], dtype=uint8)
|
||||
|
||||
>>> jnp.arange(254.0, 258.0).astype('uint8')
|
||||
DeviceArray([254, 255, 255, 255], dtype=uint8)
|
||||
```
|
||||
This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.
|
||||
|
||||
|
||||
## Fin.
|
||||
|
||||
If something's not covered here that has caused you weeping and gnashing of teeth, please let us know and we'll extend these introductory _advisos_!
|
||||
|
Loading…
x
Reference in New Issue
Block a user