Add DeprecationWarnings to jax.ops.index_... operators.

Remove uses of index_... in Common Gotchas notebook.
Peter Hawkins 2021-10-05 16:12:52 -04:00
parent d8fe8bf598
commit 104a46594b
5 changed files with 611 additions and 364 deletions

File diff suppressed because it is too large.


@ -55,9 +55,12 @@ JAX transformation and compilation are designed to work only on Python functions
Here are some examples of functions that are not functionally pure for which JAX behaves differently than the Python interpreter. Note that these behaviors are not guaranteed by the JAX system; the proper way to use JAX is to use it only on functionally pure Python functions.
```{code-cell} ipython3
:id: A6R-pdcm4u3v
:outputId: 389605df-a4d5-4d4b-8d74-64e9d5d39456
---
colab:
base_uri: https://localhost:8080/
id: A6R-pdcm4u3v
outputId: 25dcb191-14d4-4620-bcb2-00492d2f24e1
---
def impure_print_side_effect(x):
print("Executing function") # This is a side-effect
return x
@ -74,9 +77,12 @@ print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([
```
```{code-cell} ipython3
:id: -N8GhitI2bhD
:outputId: f16ce914-1387-43b4-9b8a-1d6e3b97b11d
---
colab:
base_uri: https://localhost:8080/
id: -N8GhitI2bhD
outputId: fd3624c9-197d-42cb-d97f-c5e0ef885467
---
g = 0.
def impure_uses_globals(x):
return x + g
@ -94,9 +100,12 @@ print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.]))
```
```{code-cell} ipython3
:id: RTB6iFgu4DL6
:outputId: e93d2a70-1c18-477a-d69d-d09ed556305a
---
colab:
base_uri: https://localhost:8080/
id: RTB6iFgu4DL6
outputId: 16697bcd-3623-49b1-aabb-c54614aeadea
---
g = 0.
def impure_saves_global(x):
global g
@ -113,9 +122,12 @@ print ("Saved global: ", g) # Saved global has an internal JAX value
A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state:
```{code-cell} ipython3
:id: TP-Mqf_862C0
:outputId: 78df2d95-2c6f-41c9-84a9-feda6329e75e
---
colab:
base_uri: https://localhost:8080/
id: TP-Mqf_862C0
outputId: 78d55886-54de-483c-e7c4-bafd1d2c7219
---
def pure_uses_internal_state(x):
state = dict(even=0, odd=0)
for i in range(10):
@ -125,9 +137,17 @@ def pure_uses_internal_state(x):
print(jit(pure_uses_internal_state)(5.))
```
+++ {"id": "cDpQ5u63Ba_H"}
It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a Python object that introduces state in order to retrieve the next element, which makes it incompatible with JAX's functional programming model. The code below shows some examples of incorrect attempts to use iterators with JAX; most of them raise an error, but some give unexpected results.
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: w99WXa6bBa_H
outputId: 52d885fd-0239-4a08-f5ce-0c38cc008903
---
import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr
@ -164,9 +184,12 @@ iter_operand = iter(range(10))
In Numpy you're used to doing this:
```{code-cell} ipython3
:id: om4xV7_84N9j
:outputId: 733f901e-d433-4dc8-b5bb-0c23bf2b1306
---
colab:
base_uri: https://localhost:8080/
id: om4xV7_84N9j
outputId: 88b0074a-4440-41f6-caa7-031ac2d1a96f
---
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)
@ -182,10 +205,13 @@ print(numpy_array)
If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉)
```{code-cell} ipython3
:id: 2AxeCufq4wAp
:outputId: d5d873db-cee0-49dc-981d-ec852347f7ca
:tags: [raises-exception]
---
colab:
base_uri: https://localhost:8080/
id: 2AxeCufq4wAp
outputId: fa4a87ad-1a84-471a-a3c5-a1396c432c85
tags: [raises-exception]
---
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
# In place update of JAX's array will yield an error!
@ -197,66 +223,79 @@ except Exception as e:
+++ {"id": "7mo76sS25Wco"}
__What gives?!__
Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions.
Allowing mutation of variables in-place makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program.
Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at).
Instead, JAX offers the _functional_ update functions: [__index_update__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update), [__index_add__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_add.html#jax.ops.index_add), [__index_min__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_min.html#jax.ops.index_min), [__index_max__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_max.html#jax.ops.index_max), and the [__index__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index.html#jax.ops.index) helper.
+++ {"id": "hfloZ1QXCS_J"}
⚠️ Inside `jit`'d code and `lax.while_loop` or `lax.fori_loop`, the __size__ of slices can't be a function of argument _values_, only a function of argument _shapes_ -- the slice start indices have no such restriction. See the __Control Flow__ section below for more information on this limitation.
```{code-cell} ipython3
:id: m5lg1RYq5D9p
from jax.ops import index, index_add, index_update
```
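To illustrate the slice-size restriction noted above, here is a minimal sketch (not part of the notebook; `take_two` and `take_n` are hypothetical names): a traced _start_ index is fine via `lax.dynamic_slice`, while a slice _size_ that depends on a traced value fails under `jit`.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def take_two(x, start):
    # The slice size (2,) is a static Python constant;
    # the start index may be a traced value.
    return lax.dynamic_slice(x, (start,), (2,))

print(take_two(jnp.arange(10.), 3))  # [3. 4.]

@jax.jit
def take_n(x, n):
    return x[:n]  # slice size depends on a traced value

try:
    take_n(jnp.arange(10.), 3)
except Exception as e:  # JAX rejects the value-dependent shape at trace time
    print("error:", type(e).__name__)
```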
+++ {"id": "X2Xjjvd-l8NL"}
### index_update
### Array updates: `x.at[idx].set(y)`
+++ {"id": "SHLY52KQEiuX"}
For example, the update above can be written as:
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: PBGI-HIeCP_s
outputId: de13f19a-2066-4df1-d503-764c34585529
---
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)
```
+++ {"id": "zUANAw9sCmgu"}
JAX's array update functions, unlike their NumPy versions, operate out-of-place. That is, the updated array is returned as a new array and the original array is not modified by the update.
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: dbB0UmMhCe8f
outputId: 55d46fa1-d0de-4c43-996c-f3bbc87b7175
---
print("original array unchanged:\n", jax_array)
```
+++ {"id": "eM6MyndXL2NY"}
If the __input values__ of __index_update__ aren't reused, __jit__-compiled code will perform these operations _in-place_.
```{code-cell} ipython3
:id: ygUJT49b7BBk
:outputId: 1a3511c4-a480-472f-cccb-5e01620cbe99
jax_array = jnp.zeros((3, 3))
print("original array:")
print(jax_array)
new_jax_array = index_update(jax_array, index[1, :], 1.)
print("old array unchanged:")
print(jax_array)
print("new array:")
print(new_jax_array)
```
However, inside __jit__-compiled code, if the __input value__ `x` of `x.at[idx].set(y)` is not reused, the compiler will optimize the array update to occur _in-place_.
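As a sanity check, one can inspect the purely functional form the compiler starts from (a sketch; `set_row` is a hypothetical name, and the in-place buffer reuse itself happens inside XLA and is not visible at the jaxpr level):

```python
import jax
import jax.numpy as jnp

def set_row(x):
    # A single functional scatter; XLA may compile this into an
    # in-place update when the input buffer is not reused.
    return x.at[1, :].set(1.)

print(jax.make_jaxpr(set_row)(jnp.zeros((3, 3))))
```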
+++ {"id": "7to-sF8EmC_y"}
### index_add
### Array updates with other operations
+++ {"id": "iI5cLY1xMBLs"}
+++ {"id": "ZY5l3tAdDmsJ"}
If the __input values__ of __index_update__ aren't reused, __jit__-compiled code will perform these operations _in-place_.
Indexed array updates are not limited simply to overwriting values. For example, we can perform indexed addition as follows:
```{code-cell} ipython3
:id: tsw2svao8FUp
:outputId: 874acd15-a493-4d63-efe4-9f440d5d2a12
---
colab:
base_uri: https://localhost:8080/
id: tsw2svao8FUp
outputId: 3c62a3b1-c12d-46f0-da74-791ec4b61e0b
---
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)
new_jax_array = index_add(jax_array, index[::2, 3:], 7.)
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
```
+++ {"id": "sTjJ3WuaDyqU"}
For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at).
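Beyond `set` and `add`, the `.at` property exposes the other indexed operators named in the deprecation messages later in this commit; a brief sketch:

```python
import jax.numpy as jnp

x = jnp.ones((3, 3))
print(x.at[0, :].mul(3.))  # multiply row 0 by 3
print(x.at[:, 0].max(2.))  # elementwise maximum against 2 in column 0
print(x.at[1, 1].min(0.))  # elementwise minimum at one position
```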
+++ {"id": "oZ_jE2WAypdL"}
## 🔪 Out-of-Bounds Indexing
@ -266,10 +305,13 @@ print(new_jax_array)
In Numpy, you are used to errors being thrown when you index an array outside of its bounds, like this:
```{code-cell} ipython3
:id: 5_ZM-BJUypdO
:outputId: 461f38cd-9452-4bcc-a44f-a07ddfa12f42
:tags: [raises-exception]
---
colab:
base_uri: https://localhost:8080/
id: 5_ZM-BJUypdO
outputId: c9c41ae8-2653-4219-e6dc-09b03faa3b95
tags: [raises-exception]
---
try:
np.arange(10)[11]
except Exception as e:
@ -281,12 +323,17 @@ except Exception as e:
However, raising an error from code running on an accelerator can be difficult or impossible. Therefore, JAX must choose some non-error behavior for out-of-bounds indexing (akin to how invalid floating-point arithmetic results in `NaN`). When the indexing operation is an array index update (e.g. `index_add` or `scatter`-like primitives), updates at out-of-bounds indices will be skipped; when the operation is an array index retrieval (e.g. NumPy indexing or `gather`-like primitives), the index is clamped to the bounds of the array, since __something__ must be returned. For example, the last value of the array will be returned from this indexing operation:
```{code-cell} ipython3
:id: cusaAD0NypdR
:outputId: 48428ad6-6cde-43ad-c12d-2eb9b9fe59cf
---
colab:
base_uri: https://localhost:8080/
id: cusaAD0NypdR
outputId: af1708aa-b50b-4da8-f022-7f2fa67030a8
---
jnp.arange(10)[11]
```
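The update half of the rule can be checked the same way; a short sketch showing that an out-of-bounds _update_ is skipped rather than clamped:

```python
import jax.numpy as jnp

x = jnp.arange(5.)
print(x.at[10].set(99.))  # index 10 is out of bounds; the update is dropped
# [0. 1. 2. 3. 4.]
```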
+++ {"id": "J8uO8yevBa_M"}
Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs, whereas NumPy would throw an error.
Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out-of-bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior).
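For instance, the `nanarg` behavior mentioned above (a quick sketch):

```python
import jax.numpy as jnp

print(jnp.nanargmax(jnp.array([jnp.nan, jnp.nan])))  # -1 in JAX; NumPy raises
```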
@ -298,9 +345,12 @@ Note also that, as the two behaviors described above are not inverses of each ot
NumPy is generally happy accepting Python lists or tuples as inputs to its API functions:
```{code-cell} ipython3
:id: sErQES14sjCG
:outputId: 6bc29168-624a-4d51-eef1-220aeaf49985
---
colab:
base_uri: https://localhost:8080/
id: sErQES14sjCG
outputId: 601485ff-4cda-48c5-f76c-2789073c4591
---
np.sum([1, 2, 3])
```
@ -309,9 +359,12 @@ np.sum([1, 2, 3])
JAX departs from this, generally returning a helpful error:
```{code-cell} ipython3
:id: DFEGcENSsmEc
:outputId: 86105261-0aec-41e0-c8a6-16eec437e2a8
---
colab:
base_uri: https://localhost:8080/
id: DFEGcENSsmEc
outputId: 08535679-6c1f-4dd9-a414-d8b59310d1ee
---
try:
jnp.sum([1, 2, 3])
except TypeError as e:
@ -325,9 +378,12 @@ This is a deliberate design choice, because passing lists or tuples to traced fu
For example, consider the following permissive version of `jnp.sum` that allows list inputs:
```{code-cell} ipython3
:id: jhe-L_TwsvKd
:outputId: 24ef84d4-79e5-42de-f8d4-34e6701c2576
---
colab:
base_uri: https://localhost:8080/
id: jhe-L_TwsvKd
outputId: ab2ee183-d9ec-45cc-d6be-5009347e1bc5
---
def permissive_sum(x):
return jnp.sum(jnp.array(x))
@ -340,9 +396,12 @@ permissive_sum(x)
The output is what we would expect, but this hides potential performance issues under the hood. In JAX's tracing and JIT compilation model, each element in a Python list or tuple is treated as a separate JAX variable, and individually processed and pushed to device. This can be seen in the jaxpr for the ``permissive_sum`` function above:
```{code-cell} ipython3
:id: k81u6DQ7vAjQ
:outputId: 52847378-ba8c-4e84-fb8b-dabbaded6a00
---
colab:
base_uri: https://localhost:8080/
id: k81u6DQ7vAjQ
outputId: 869fc3b9-feda-4aa9-d2e5-5b5107de102d
---
make_jaxpr(permissive_sum)(x)
```
@ -353,9 +412,12 @@ Each entry of the list is handled as a separate input, resulting in a tracing &
If you would like to pass a tuple or list to a JAX function, you can do so by first explicitly converting it to an array:
```{code-cell} ipython3
:id: nFf_DydixG8v
:outputId: 5e4392b6-37eb-4a24-ce4f-43518e61d9b1
---
colab:
base_uri: https://localhost:8080/
id: nFf_DydixG8v
outputId: e31b43b3-05f7-4300-fdd2-40e3896f6f8f
---
jnp.sum(jnp.array(x))
```
@ -375,9 +437,12 @@ jnp.sum(jnp.array(x))
You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:
```{code-cell} ipython3
:id: rr9FeP41fynt
:outputId: 849d84cf-04ad-4e8b-9505-a92f6c0d7a39
---
colab:
base_uri: https://localhost:8080/
id: rr9FeP41fynt
outputId: df0ceb15-96ec-4a78-e327-c77f7ea3a745
---
print(np.random.random())
print(np.random.random())
print(np.random.random())
@ -444,9 +509,12 @@ JAX instead implements an _explicit_ PRNG where entropy production and consumpti
The random state is described by two unsigned-int32s that we call a __key__:
```{code-cell} ipython3
:id: yPHE7KTWgAWs
:outputId: 329e7757-2461-434c-a08c-fde80a2d10c9
---
colab:
base_uri: https://localhost:8080/
id: yPHE7KTWgAWs
outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3
---
from jax import random
key = random.PRNGKey(0)
key
@ -459,9 +527,12 @@ JAX's random functions produce pseudorandom numbers from the PRNG state, but __d
Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:
```{code-cell} ipython3
:id: 7zUdQMynoE5e
:outputId: 50617324-b887-42f2-a7ff-2a10f92d876a
---
colab:
base_uri: https://localhost:8080/
id: 7zUdQMynoE5e
outputId: 23a86b72-dfb9-410a-8e68-22b48dc10805
---
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
@ -474,9 +545,12 @@ print(key)
Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:
```{code-cell} ipython3
:id: ASj0_rSzqgGh
:outputId: bcc2ed60-2e41-4ef8-e84f-c724654aa198
---
colab:
base_uri: https://localhost:8080/
id: ASj0_rSzqgGh
outputId: 2f13f249-85d1-47bb-d503-823eca6961aa
---
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
@ -489,9 +563,12 @@ print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
We propagate the __key__ and make new __subkeys__ whenever we need a new random number:
```{code-cell} ipython3
:id: jbC34XLor2Ek
:outputId: 6834a812-7160-4646-ee19-a246f683905a
---
colab:
base_uri: https://localhost:8080/
id: jbC34XLor2Ek
outputId: 4059a2e2-0205-40bc-ad55-17709d538871
---
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
@ -504,9 +581,12 @@ print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
We can generate more than one __subkey__ at a time:
```{code-cell} ipython3
:id: lEi08PJ4tfkX
:outputId: 3bb513de-8d14-4d37-ae57-51d6f5eaa762
---
colab:
base_uri: https://localhost:8080/
id: lEi08PJ4tfkX
outputId: 1f280560-155d-4c04-98e8-c41d72ee5b01
---
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
print(random.normal(subkey, shape=(1,)))
@ -523,9 +603,12 @@ for subkey in subkeys:
If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager).
```{code-cell} ipython3
:id: aAx0T3F8lLtu
:outputId: 808cfa77-d924-4586-af19-35a8fd7d2238
---
colab:
base_uri: https://localhost:8080/
id: aAx0T3F8lLtu
outputId: 383b7bfa-1634-4d23-8497-49cb9452ca52
---
def f(x):
if x < 3:
return 3. * x ** 2
@ -545,9 +628,12 @@ Using control flow with `jit` is more complicated, and by default it has more co
This works:
```{code-cell} ipython3
:id: OZ_BJX0CplNC
:outputId: 48ce004c-536a-44f5-b020-9267825e7e4d
---
colab:
base_uri: https://localhost:8080/
id: OZ_BJX0CplNC
outputId: 60c902a2-eba1-49d7-c8c8-2f68616d660c
---
@jit
def f(x):
for i in range(3):
@ -562,9 +648,12 @@ print(f(3))
So does this:
```{code-cell} ipython3
:id: pinVnmRWp6w6
:outputId: e3e6f2f7-ba59-4a98-cdfc-905c91b38ed1
---
colab:
base_uri: https://localhost:8080/
id: pinVnmRWp6w6
outputId: 25e06cf2-474f-4782-af7c-4f5514b64422
---
@jit
def g(x):
y = 0.
@ -580,9 +669,12 @@ print(g(jnp.array([1., 2., 3.])))
But this doesn't, at least by default:
```{code-cell} ipython3
:id: 9z38AIKclRNM
:outputId: 466730dd-df8b-4b80-ac5e-e55b5ea85ec7
---
colab:
base_uri: https://localhost:8080/
id: 9z38AIKclRNM
outputId: 38dd2075-92fc-4b81-fee0-b9dff8da1fac
---
@jit
def f(x):
if x < 3:
@ -614,9 +706,12 @@ But there's a tradeoff here: if we trace a Python function on a `ShapedArray((),
The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:
```{code-cell} ipython3
:id: -Tzp0H7Bt1Sn
:outputId: aba57a88-d8eb-40b0-ff22-7c266d892b13
---
colab:
base_uri: https://localhost:8080/
id: -Tzp0H7Bt1Sn
outputId: f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a
---
def f(x):
if x < 3:
return 3. * x ** 2
@ -633,9 +728,12 @@ print(f(2.))
Here's another example, this time involving a loop:
```{code-cell} ipython3
:id: iwY86_JKvD6b
:outputId: 1ec847ea-df2b-438d-c0a1-fabf7b93b73d
---
colab:
base_uri: https://localhost:8080/
id: iwY86_JKvD6b
outputId: 48f9b51f-bd32-466f-eac1-cd23444ce937
---
def f(x, n):
y = 0.
for i in range(n):
@ -658,9 +756,12 @@ In effect, the loop gets statically unrolled. JAX can also trace at _higher_ le
These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`.
```{code-cell} ipython3
:id: Tqe9uLmUI_Gv
:outputId: fe319758-9959-434c-ab9d-0926e599dbc0
---
colab:
base_uri: https://localhost:8080/
id: Tqe9uLmUI_Gv
outputId: 989be121-dfce-4bb3-c78e-a10829c5f883
---
def example_fun(length, val):
return jnp.ones((length,)) * val
# un-jit'd works fine
@ -687,9 +788,12 @@ print(good_example_jit(5, 4))
Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:
```{code-cell} ipython3
:id: m2ABpRd8K094
:outputId: 64da37a0-aa06-46a3-e975-88c676c5b9fa
---
colab:
base_uri: https://localhost:8080/
id: m2ABpRd8K094
outputId: 4f7ebe17-ade4-4e18-bd8c-4b24087c33c3
---
@jit
def f(x):
print(x)
@ -724,9 +828,12 @@ def cond(pred, true_fun, false_fun, operand):
```
```{code-cell} ipython3
:id: SGxz9JOWeiyH
:outputId: b29da06c-037f-4b05-dbd8-ba52ac35a8cf
---
colab:
base_uri: https://localhost:8080/
id: SGxz9JOWeiyH
outputId: 942a8d0e-5ff6-4702-c499-b3941f529ca3
---
from jax import lax
operand = jnp.array([0.])
@ -750,9 +857,12 @@ def while_loop(cond_fun, body_fun, init_val):
```
```{code-cell} ipython3
:id: jM-D39a-c436
:outputId: b9c97167-fecf-4559-9ca7-1cb0235d8ad2
---
colab:
base_uri: https://localhost:8080/
id: jM-D39a-c436
outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e
---
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
@ -773,9 +883,12 @@ def fori_loop(start, stop, body_fun, init_val):
```
```{code-cell} ipython3
:id: dt3tUpOmeR8u
:outputId: 864f2959-2429-4666-b364-4baf90a57482
---
colab:
base_uri: https://localhost:8080/
id: dt3tUpOmeR8u
outputId: 7819ca7c-1433-4d85-b542-f6159b0e8380
---
init_val = 0
start = 0
stop = 10
@ -834,7 +947,7 @@ There could be tricky situations that arise, like nans that only occur under a `
If the nans are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. In the example below, we started an ipython repl with the command line `env JAX_DEBUG_NANS=True ipython`, then ran this:
+++
+++ {"id": "p6ZtDHPbBa_W"}
```
In [1]: import jax.numpy as jnp
@ -878,11 +991,11 @@ FloatingPointError Traceback (most recent call last)
FloatingPointError: invalid value
```
+++
+++ {"id": "_NCnVt_GBa_W"}
The nan generated was caught. By running `%debug`, we can get a post-mortem debugger. This also works with functions under `@jit`, as the example below shows.
+++
+++ {"id": "pf8RF6eiBa_W"}
```
In [4]: from jax import jit
@ -939,7 +1052,7 @@ FloatingPointError Traceback (most recent call last)
... stack trace ...
```
+++
+++ {"id": "6ur2yArDBa_W"}
When this code sees a nan in the output of an `@jit` function, it calls into the de-optimized code, so we still get a clear stack trace. And we can run a post-mortem debugger with `%debug` to inspect all the values to figure out the error.
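The same check can also be enabled from inside a program rather than via the `JAX_DEBUG_NANS=True` environment variable used above; a sketch:

```python
import jax.numpy as jnp
from jax.config import config

config.update("jax_debug_nans", True)  # equivalent to JAX_DEBUG_NANS=True

try:
    jnp.divide(0., 0.)  # produces a nan, so it now raises immediately
except FloatingPointError as e:
    print("caught:", e)
```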
@ -954,9 +1067,12 @@ When this code sees a nan in the output of an `@jit` function, it calls into the
At the moment, JAX by default enforces single-precision numbers to mitigate the Numpy API's tendency to aggressively promote operands to `double`. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!
```{code-cell} ipython3
:id: CNNGtzM3NDkO
:outputId: d1384021-d9bf-450f-a9ae-82024fa5fc1a
---
colab:
base_uri: https://localhost:8080/
id: CNNGtzM3NDkO
outputId: b422bb23-a784-44dc-f8c9-57f3b6c861b8
---
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype
```
@ -998,9 +1114,12 @@ Note that #2-#4 work for _any_ of JAX's configuration options.
We can then confirm that `x64` mode is enabled:
```{code-cell} ipython3
:id: HqGbBa9Rr-2g
:outputId: cd241d63-3d00-4fd7-f9c0-afc6af01ecf4
---
colab:
base_uri: https://localhost:8080/
id: HqGbBa9Rr-2g
outputId: 5aa72952-08cc-4569-9b51-a10311ae9e81
---
import jax.numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)


@ -14,7 +14,7 @@
# Helpers for indexed updates.
import warnings
import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Union
@ -178,6 +178,8 @@ def index_add(x: Array,
[1., 1., 1., 7., 7., 7.],
[1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
warnings.warn("index_add is deprecated. Use x.at[idx].add(y) instead.",
DeprecationWarning)
return _scatter_update(
x, idx, y, lax.scatter_add, indices_are_sorted, unique_indices)
@ -228,6 +230,8 @@ def index_mul(x: Array,
[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
warnings.warn("index_mul is deprecated. Use x.at[idx].mul(y) instead.",
DeprecationWarning)
return _scatter_update(x, idx, y, lax.scatter_mul,
indices_are_sorted, unique_indices)
@ -276,6 +280,8 @@ def index_min(x: Array,
[1., 1., 1., 0., 0., 0.],
[1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
warnings.warn("index_min is deprecated. Use x.at[idx].min(y) instead.",
DeprecationWarning)
return _scatter_update(
x, idx, y, lax.scatter_min, indices_are_sorted, unique_indices)
@ -323,6 +329,8 @@ def index_max(x: Array,
[1., 1., 1., 6., 6., 6.],
[1., 1., 1., 1., 1., 1.]], dtype=float32)
"""
warnings.warn("index_max is deprecated. Use x.at[idx].max(y) instead.",
DeprecationWarning)
return _scatter_update(
x, idx, y, lax.scatter_max, indices_are_sorted, unique_indices)
@ -371,6 +379,8 @@ def index_update(x: Array,
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 6., 6., 6.]], dtype=float32)
"""
warnings.warn("index_update is deprecated. Use x.at[idx].set(y) instead.",
DeprecationWarning)
return _scatter_update(
x, idx, y, lax.scatter, indices_are_sorted, unique_indices)
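Taken together, the deprecated `jax.ops` functions map onto the `.at` property one-for-one. A migration sketch based on the warning messages above:

```python
import jax.numpy as jnp

x = jnp.ones((5, 6))

# Deprecated jax.ops call             ->  preferred .at equivalent
# index_update(x, index[::2, 3:], y)  ->  x.at[::2, 3:].set(y)
# index_add(x, index[::2, 3:], y)     ->  x.at[::2, 3:].add(y)
# index_mul(x, index[::2, 3:], y)     ->  x.at[::2, 3:].mul(y)
# index_min(x, index[::2, 3:], y)     ->  x.at[::2, 3:].min(y)
# index_max(x, index[::2, 3:], y)     ->  x.at[::2, 3:].max(y)

print(x.at[::2, 3:].set(7.))
```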


@ -12,5 +12,6 @@ filterwarnings =
# jax2tf tests due to mix of JAX and TF
ignore:numpy.ufunc size changed
ignore:.*experimental feature
ignore:index.*is deprecated.*:DeprecationWarning
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-glob="*.rst"
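Outside this test suite, the same suppression can be reproduced with the standard `warnings` module; a sketch mirroring the ini entry above:

```python
import warnings

warnings.filterwarnings(
    "ignore",
    message=r"index.*is deprecated.*",
    category=DeprecationWarning,
)
```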


@ -912,17 +912,6 @@ class UpdateOps(enum.Enum):
return x
def jax_fn(op, indexer, x, y, indices_are_sorted=False,
unique_indices=False):
return {
UpdateOps.UPDATE: ops.index_update,
UpdateOps.ADD: ops.index_add,
UpdateOps.MUL: ops.index_mul,
UpdateOps.MIN: ops.index_min,
UpdateOps.MAX: ops.index_max,
}[op](x, indexer, y, indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices)
def sugar_fn(op, indexer, x, y, indices_are_sorted=False,
unique_indices=False, mode=None):
x = jnp.array(x)
return {
@ -977,7 +966,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y, mode=mode)
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, mode=mode)
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op))
self._CompileAndCheck(jax_fn, args_maker)
@ -999,7 +988,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y,
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y,
unique_indices=True)
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op))
self._CompileAndCheck(jax_fn, args_maker)
@ -1022,7 +1011,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
jax_fn = lambda x, y: UpdateOps.sugar_fn(
jax_fn = lambda x, y: UpdateOps.jax_fn(
op, indexer, x, y, indices_are_sorted=True, unique_indices=True)
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, check_dtypes=True,
tol=_update_tol(op))
@ -1046,7 +1035,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op))
self._CompileAndCheck(jax_fn, args_maker)
@ -1071,7 +1060,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
indexer, op, mode):
rng = jtu.rand_default(self.rng())
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y, mode=mode)
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, mode=mode)
x = rng(shape, dtype)
y = rng(update_shape, update_dtype)
check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.)
@ -1092,7 +1081,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
def testAdvancedIndexingGrads(self, shape, dtype, update_shape, update_dtype,
indexer, op):
rng = jtu.rand_default(self.rng())
jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y,
jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y,
unique_indices=True)
x = rng(shape, dtype)
y = rng(update_shape, update_dtype)