mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
update readme gotchas about pure functions
This commit is contained in:
parent
6b5ef898dc
commit
a61bcff54d
11
README.md
11
README.md
@ -343,22 +343,23 @@ we highly recommend reading the [Gotchas
|
||||
Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
|
||||
Some standouts:
|
||||
|
||||
1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`.
|
||||
1. [In-place mutating updates of
|
||||
arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-In-Place-Updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
|
||||
2. [Random numbers are
|
||||
1. [Random numbers are
|
||||
different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers), but for [good reasons](https://github.com/google/jax/blob/master/design_notes/prng.md).
|
||||
3. If you're looking for [convolution
|
||||
1. If you're looking for [convolution
|
||||
operators](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Convolutions),
|
||||
they're in the `jax.lax` package.
|
||||
4. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and
|
||||
1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and
|
||||
[to enable
|
||||
double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision)
|
||||
(64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
|
||||
startup (or set the environment variable `JAX_ENABLE_X64=True`).
|
||||
5. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
|
||||
1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
|
||||
and NumPy types aren't preserved, namely `np.add(1, np.array([2],
|
||||
np.float32)).dtype` is `float64` rather than `float32`.
|
||||
6. Some transformations, like `jit`, [constrain how you can use Python control
|
||||
1. Some transformations, like `jit`, [constrain how you can use Python control
|
||||
flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Control-Flow).
|
||||
You'll always get loud errors if something goes wrong. You might have to use
|
||||
[`jit`'s `static_argnums`
|
||||
|
Loading…
x
Reference in New Issue
Block a user