DOC: add section to Sharp Bits discussing implicit list conversions

Jake VanderPlas 2021-05-26 09:03:42 -07:00
parent 62603fde67
commit d844609c6d
2 changed files with 365 additions and 118 deletions


@ -568,6 +568,209 @@
"Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LwB07Kx5sgHu"
},
"source": [
"## 🔪 Non-array inputs: NumPy vs. JAX\n",
"\n",
"NumPy is generally happy accepting Python lists or tuples as inputs to its API functions:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "sErQES14sjCG",
"outputId": "6bc29168-624a-4d51-eef1-220aeaf49985"
},
"outputs": [
{
"data": {
"text/plain": [
"6"
]
},
"execution_count": 16,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"np.sum([1, 2, 3])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZJ1Wt1bTtrSA"
},
"source": [
"JAX departs from this, generally returning a helpful error:"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"id": "DFEGcENSsmEc",
"outputId": "86105261-0aec-41e0-c8a6-16eec437e2a8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.\n"
]
}
],
"source": [
"try:\n",
" jnp.sum([1, 2, 3])\n",
"except TypeError as e:\n",
" print(f\"TypeError: {e}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QPliLUZztxJt"
},
"source": [
"This is a deliberate design choice, because passing lists or tuples to traced functions can lead to silent performance degradation that might otherwise be difficult to detect.\n",
"\n",
"For example, consider the following permissive version of `jnp.sum` that allows list inputs:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"id": "jhe-L_TwsvKd",
"outputId": "24ef84d4-79e5-42de-f8d4-34e6701c2576"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(45, dtype=int32)"
]
},
"execution_count": 32,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def permissive_sum(x):\n",
" return jnp.sum(jnp.array(x))\n",
"\n",
"x = list(range(10))\n",
"permissive_sum(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m0XZLP7nuYdE"
},
"source": [
"The output is what we would expect, but this hides potential performance issues under the hood. In JAX's tracing and JIT compilation model, each element in a Python list or tuple is treated as a separate JAX variable, and individually processed and pushed to device. This can be seen in the jaxpr for the ``permissive_sum`` function above:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"id": "k81u6DQ7vAjQ",
"outputId": "52847378-ba8c-4e84-fb8b-dabbaded6a00"
},
"outputs": [
{
"data": {
"text/plain": [
"{ lambda ; a b c d e f g h i j.\n",
" let k = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] a\n",
" l = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] b\n",
" m = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] c\n",
" n = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] d\n",
" o = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] e\n",
" p = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] f\n",
" q = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] g\n",
" r = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] h\n",
" s = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] i\n",
" t = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] j\n",
" u = concatenate[ dimension=0 ] k l m n o p q r s t\n",
" v = convert_element_type[ new_dtype=int32\n",
" weak_type=False ] u\n",
" w = reduce_sum[ axes=(0,) ] v\n",
" in (w,) }"
]
},
"execution_count": 31,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"make_jaxpr(permissive_sum)(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "C0_dpCfpvCts"
},
"source": [
"Each entry of the list is handled as a separate input, resulting in a tracing & compilation overhead that grows linearly with the size of the list. To prevent surprises like this, JAX avoids implicit conversions of lists and tuples to arrays.\n",
"\n",
"If you would like to pass a tuple or list to a JAX function, you can do so by first explicitly converting it to an array:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"id": "nFf_DydixG8v",
"outputId": "5e4392b6-37eb-4a24-ce4f-43518e61d9b1"
},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(45, dtype=int32)"
]
},
"execution_count": 33,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jnp.sum(jnp.array(x))"
]
},
{
"cell_type": "markdown",
"metadata": {


@ -55,10 +55,9 @@ JAX transformation and compilation are designed to work only on Python functions
Here are some examples of functions that are not functionally pure, for which JAX behaves differently than the Python interpreter. Note that these behaviors are not guaranteed by the JAX system; the proper way to use JAX is to use it only on functionally pure Python functions.
```{code-cell} ipython3
----
-id: A6R-pdcm4u3v
-outputId: 389605df-a4d5-4d4b-8d74-64e9d5d39456
----
+:id: A6R-pdcm4u3v
+:outputId: 389605df-a4d5-4d4b-8d74-64e9d5d39456
def impure_print_side_effect(x):
print("Executing function") # This is a side-effect
return x
@ -75,10 +74,9 @@ print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([
```
```{code-cell} ipython3
----
-id: -N8GhitI2bhD
-outputId: f16ce914-1387-43b4-9b8a-1d6e3b97b11d
----
+:id: -N8GhitI2bhD
+:outputId: f16ce914-1387-43b4-9b8a-1d6e3b97b11d
g = 0.
def impure_uses_globals(x):
return x + g
@ -96,10 +94,9 @@ print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.]))
```
```{code-cell} ipython3
----
-id: RTB6iFgu4DL6
-outputId: e93d2a70-1c18-477a-d69d-d09ed556305a
----
+:id: RTB6iFgu4DL6
+:outputId: e93d2a70-1c18-477a-d69d-d09ed556305a
g = 0.
def impure_saves_global(x):
global g
@ -116,10 +113,9 @@ print ("Saved global: ", g) # Saved global has an internal JAX value
A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state:
```{code-cell} ipython3
----
-id: TP-Mqf_862C0
-outputId: 78df2d95-2c6f-41c9-84a9-feda6329e75e
----
+:id: TP-Mqf_862C0
+:outputId: 78df2d95-2c6f-41c9-84a9-feda6329e75e
def pure_uses_internal_state(x):
state = dict(even=0, odd=0)
for i in range(10):
@ -168,10 +164,9 @@ iter_operand = iter(range(10))
In Numpy you're used to doing this:
```{code-cell} ipython3
----
-id: om4xV7_84N9j
-outputId: 733f901e-d433-4dc8-b5bb-0c23bf2b1306
----
+:id: om4xV7_84N9j
+:outputId: 733f901e-d433-4dc8-b5bb-0c23bf2b1306
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)
@ -187,11 +182,10 @@ print(numpy_array)
If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉)
```{code-cell} ipython3
----
-id: 2AxeCufq4wAp
-outputId: d5d873db-cee0-49dc-981d-ec852347f7ca
-tags: [raises-exception]
----
+:id: 2AxeCufq4wAp
+:outputId: d5d873db-cee0-49dc-981d-ec852347f7ca
+:tags: [raises-exception]
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
# In place update of JAX's array will yield an error!
@ -226,10 +220,9 @@ from jax.ops import index, index_add, index_update
If the __input values__ of __index_update__ aren't reused, __jit__-compiled code will perform these operations _in-place_.
```{code-cell} ipython3
----
-id: ygUJT49b7BBk
-outputId: 1a3511c4-a480-472f-cccb-5e01620cbe99
----
+:id: ygUJT49b7BBk
+:outputId: 1a3511c4-a480-472f-cccb-5e01620cbe99
jax_array = jnp.zeros((3, 3))
print("original array:")
print(jax_array)
@ -252,10 +245,9 @@ print(new_jax_array)
If the __input values__ of __index_update__ aren't reused, __jit__-compiled code will perform these operations _in-place_.
```{code-cell} ipython3
----
-id: tsw2svao8FUp
-outputId: 874acd15-a493-4d63-efe4-9f440d5d2a12
----
+:id: tsw2svao8FUp
+:outputId: 874acd15-a493-4d63-efe4-9f440d5d2a12
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)
@ -274,11 +266,10 @@ print(new_jax_array)
In Numpy, you are used to errors being thrown when you index an array outside of its bounds, like this:
```{code-cell} ipython3
----
-id: 5_ZM-BJUypdO
-outputId: 461f38cd-9452-4bcc-a44f-a07ddfa12f42
-tags: [raises-exception]
----
+:id: 5_ZM-BJUypdO
+:outputId: 461f38cd-9452-4bcc-a44f-a07ddfa12f42
+:tags: [raises-exception]
try:
np.arange(10)[11]
except Exception as e:
@ -290,10 +281,9 @@ except Exception as e:
However, raising an error from code running on an accelerator can be difficult or impossible. Therefore, JAX must choose some non-error behavior for out of bounds indexing (akin to how invalid floating point arithmetic results in `NaN`). When the indexing operation is an array index update (e.g. `index_add` or `scatter`-like primitives), updates at out-of-bounds indices will be skipped; when the operation is an array index retrieval (e.g. NumPy indexing or `gather`-like primitives) the index is clamped to the bounds of the array since __something__ must be returned. For example, the last value of the array will be returned from this indexing operation:
```{code-cell} ipython3
----
-id: cusaAD0NypdR
-outputId: 48428ad6-6cde-43ad-c12d-2eb9b9fe59cf
----
+:id: cusaAD0NypdR
+:outputId: 48428ad6-6cde-43ad-c12d-2eb9b9fe59cf
jnp.arange(10)[11]
```
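The asymmetry between updates and retrievals can be seen directly. A quick sketch (not part of the original; it uses the `ndarray.at` property rather than the `jax.ops` helpers shown above):

```python
import jax.numpy as jnp

x = jnp.arange(10)

# Retrieval: the out-of-bounds index is clamped, returning the last element.
print(x[11])

# Update: the out-of-bounds index is dropped, leaving the array unchanged.
print(x.at[11].set(23))
```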
@ -301,6 +291,79 @@ Note that due to this behavior for index retrieval, functions like `jnp.nanargmi
Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior).
+++ {"id": "LwB07Kx5sgHu"}
## 🔪 Non-array inputs: NumPy vs. JAX
NumPy is generally happy accepting Python lists or tuples as inputs to its API functions:
```{code-cell} ipython3
---
id: sErQES14sjCG
outputId: 6bc29168-624a-4d51-eef1-220aeaf49985
---
np.sum([1, 2, 3])
```
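NumPy's implicit conversion extends to tuples and nested lists as well; roughly speaking, it calls `np.asarray` on the input before operating. A quick sketch, not part of the original notebook:

```python
import numpy as np

# Tuples and nested lists are implicitly converted to arrays:
print(np.sum((1, 2, 3)))         # tuple input works like a list
print(np.sum([[1, 2], [3, 4]]))  # a nested list becomes a 2D array
```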
+++ {"id": "ZJ1Wt1bTtrSA"}
JAX departs from this, generally returning a helpful error:
```{code-cell} ipython3
---
id: DFEGcENSsmEc
outputId: 86105261-0aec-41e0-c8a6-16eec437e2a8
---
try:
jnp.sum([1, 2, 3])
except TypeError as e:
print(f"TypeError: {e}")
```
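As the error message suggests, ndarray and scalar arguments pass through fine; it is only the implicit sequence conversion that is disallowed (a sketch, not in the original):

```python
import jax.numpy as jnp

# Arrays and scalars are accepted directly:
print(jnp.sum(jnp.array([1, 2, 3])))  # explicit array argument
print(jnp.sum(3))                     # scalar argument
```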
+++ {"id": "QPliLUZztxJt"}
This is a deliberate design choice, because passing lists or tuples to traced functions can lead to silent performance degradation that might otherwise be difficult to detect.
For example, consider the following permissive version of `jnp.sum` that allows list inputs:
```{code-cell} ipython3
---
id: jhe-L_TwsvKd
outputId: 24ef84d4-79e5-42de-f8d4-34e6701c2576
---
def permissive_sum(x):
return jnp.sum(jnp.array(x))
x = list(range(10))
permissive_sum(x)
```
+++ {"id": "m0XZLP7nuYdE"}
The output is what we would expect, but this hides potential performance issues under the hood. In JAX's tracing and JIT compilation model, each element in a Python list or tuple is treated as a separate JAX variable, and individually processed and pushed to device. This can be seen in the jaxpr for the ``permissive_sum`` function above:
```{code-cell} ipython3
---
id: k81u6DQ7vAjQ
outputId: 52847378-ba8c-4e84-fb8b-dabbaded6a00
---
make_jaxpr(permissive_sum)(x)
```
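To make the scaling concrete — a sketch, not part of the original — the number of jaxpr inputs can be counted via the `invars` attribute of the traced jaxpr; it grows with the list length, while an array input is always a single variable:

```python
import jax.numpy as jnp
from jax import make_jaxpr

def permissive_sum(x):
    return jnp.sum(jnp.array(x))

# Each list element is traced as a separate input variable...
for n in [2, 5, 10]:
    closed = make_jaxpr(permissive_sum)(list(range(n)))
    print(f"list of length {n}: {len(closed.jaxpr.invars)} jaxpr inputs")

# ...whereas an array of any length is a single input variable.
closed = make_jaxpr(permissive_sum)(jnp.arange(10))
print(f"array of length 10: {len(closed.jaxpr.invars)} jaxpr input")
```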
+++ {"id": "C0_dpCfpvCts"}
Each entry of the list is handled as a separate input, resulting in a tracing & compilation overhead that grows linearly with the size of the list. To prevent surprises like this, JAX avoids implicit conversions of lists and tuples to arrays.
If you would like to pass a tuple or list to a JAX function, you can do so by first explicitly converting it to an array:
```{code-cell} ipython3
---
id: nFf_DydixG8v
outputId: 5e4392b6-37eb-4a24-ce4f-43518e61d9b1
---
jnp.sum(jnp.array(x))
```
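When the consuming function is jit-compiled, it is best to do this conversion once at the call boundary, so that the list never reaches the tracer. A sketch (`loss` is a hypothetical function, not from the original):

```python
import jax.numpy as jnp
from jax import jit

@jit
def loss(x):
    # Expects an array; the conversion happens before tracing.
    return jnp.sum(x ** 2)

data = [1.0, 2.0, 3.0]
result = loss(jnp.array(data))  # convert once, outside the traced function
print(result)
```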
+++ {"id": "MUycRNh6e50W"}
## 🔪 Random Numbers
@ -317,10 +380,9 @@ Note also that, as the two behaviors described above are not inverses of each ot
You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:
```{code-cell} ipython3
----
-id: rr9FeP41fynt
-outputId: 849d84cf-04ad-4e8b-9505-a92f6c0d7a39
----
+:id: rr9FeP41fynt
+:outputId: 849d84cf-04ad-4e8b-9505-a92f6c0d7a39
print(np.random.random())
print(np.random.random())
print(np.random.random())
@ -387,10 +449,9 @@ JAX instead implements an _explicit_ PRNG where entropy production and consumpti
The random state is described by two unsigned-int32s that we call a __key__:
```{code-cell} ipython3
----
-id: yPHE7KTWgAWs
-outputId: 329e7757-2461-434c-a08c-fde80a2d10c9
----
+:id: yPHE7KTWgAWs
+:outputId: 329e7757-2461-434c-a08c-fde80a2d10c9
from jax import random
key = random.PRNGKey(0)
key
@ -403,10 +464,9 @@ JAX's random functions produce pseudorandom numbers from the PRNG state, but __d
Reusing the same state will cause __sadness__ and __monotony__, depriving the enduser of __lifegiving chaos__:
```{code-cell} ipython3
----
-id: 7zUdQMynoE5e
-outputId: 50617324-b887-42f2-a7ff-2a10f92d876a
----
+:id: 7zUdQMynoE5e
+:outputId: 50617324-b887-42f2-a7ff-2a10f92d876a
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
@ -419,10 +479,9 @@ print(key)
Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:
```{code-cell} ipython3
----
-id: ASj0_rSzqgGh
-outputId: bcc2ed60-2e41-4ef8-e84f-c724654aa198
----
+:id: ASj0_rSzqgGh
+:outputId: bcc2ed60-2e41-4ef8-e84f-c724654aa198
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
@ -435,10 +494,9 @@ print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
We propagate the __key__ and make new __subkeys__ whenever we need a new random number:
```{code-cell} ipython3
----
-id: jbC34XLor2Ek
-outputId: 6834a812-7160-4646-ee19-a246f683905a
----
+:id: jbC34XLor2Ek
+:outputId: 6834a812-7160-4646-ee19-a246f683905a
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
@ -451,10 +509,9 @@ print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
We can generate more than one __subkey__ at a time:
```{code-cell} ipython3
----
-id: lEi08PJ4tfkX
-outputId: 3bb513de-8d14-4d37-ae57-51d6f5eaa762
----
+:id: lEi08PJ4tfkX
+:outputId: 3bb513de-8d14-4d37-ae57-51d6f5eaa762
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
print(random.normal(subkey, shape=(1,)))
@ -471,10 +528,9 @@ for subkey in subkeys:
If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager).
```{code-cell} ipython3
----
-id: aAx0T3F8lLtu
-outputId: 808cfa77-d924-4586-af19-35a8fd7d2238
----
+:id: aAx0T3F8lLtu
+:outputId: 808cfa77-d924-4586-af19-35a8fd7d2238
def f(x):
if x < 3:
return 3. * x ** 2
@ -494,10 +550,9 @@ Using control flow with `jit` is more complicated, and by default it has more co
This works:
```{code-cell} ipython3
----
-id: OZ_BJX0CplNC
-outputId: 48ce004c-536a-44f5-b020-9267825e7e4d
----
+:id: OZ_BJX0CplNC
+:outputId: 48ce004c-536a-44f5-b020-9267825e7e4d
@jit
def f(x):
for i in range(3):
@ -512,10 +567,9 @@ print(f(3))
So does this:
```{code-cell} ipython3
----
-id: pinVnmRWp6w6
-outputId: e3e6f2f7-ba59-4a98-cdfc-905c91b38ed1
----
+:id: pinVnmRWp6w6
+:outputId: e3e6f2f7-ba59-4a98-cdfc-905c91b38ed1
@jit
def g(x):
y = 0.
@ -531,10 +585,9 @@ print(g(jnp.array([1., 2., 3.])))
But this doesn't, at least by default:
```{code-cell} ipython3
----
-id: 9z38AIKclRNM
-outputId: 466730dd-df8b-4b80-ac5e-e55b5ea85ec7
----
+:id: 9z38AIKclRNM
+:outputId: 466730dd-df8b-4b80-ac5e-e55b5ea85ec7
@jit
def f(x):
if x < 3:
@ -566,10 +619,9 @@ But there's a tradeoff here: if we trace a Python function on a `ShapedArray((),
The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:
```{code-cell} ipython3
----
-id: -Tzp0H7Bt1Sn
-outputId: aba57a88-d8eb-40b0-ff22-7c266d892b13
----
+:id: -Tzp0H7Bt1Sn
+:outputId: aba57a88-d8eb-40b0-ff22-7c266d892b13
def f(x):
if x < 3:
return 3. * x ** 2
@ -586,10 +638,9 @@ print(f(2.))
Here's another example, this time involving a loop:
```{code-cell} ipython3
----
-id: iwY86_JKvD6b
-outputId: 1ec847ea-df2b-438d-c0a1-fabf7b93b73d
----
+:id: iwY86_JKvD6b
+:outputId: 1ec847ea-df2b-438d-c0a1-fabf7b93b73d
def f(x, n):
y = 0.
for i in range(n):
@ -612,10 +663,9 @@ In effect, the loop gets statically unrolled. JAX can also trace at _higher_ le
These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`.
```{code-cell} ipython3
----
-id: Tqe9uLmUI_Gv
-outputId: fe319758-9959-434c-ab9d-0926e599dbc0
----
+:id: Tqe9uLmUI_Gv
+:outputId: fe319758-9959-434c-ab9d-0926e599dbc0
def example_fun(length, val):
return jnp.ones((length,)) * val
# un-jit'd works fine
@ -642,10 +692,9 @@ print(good_example_jit(5, 4))
Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:
```{code-cell} ipython3
----
-id: m2ABpRd8K094
-outputId: 64da37a0-aa06-46a3-e975-88c676c5b9fa
----
+:id: m2ABpRd8K094
+:outputId: 64da37a0-aa06-46a3-e975-88c676c5b9fa
@jit
def f(x):
print(x)
@ -680,10 +729,9 @@ def cond(pred, true_fun, false_fun, operand):
```
```{code-cell} ipython3
----
-id: SGxz9JOWeiyH
-outputId: b29da06c-037f-4b05-dbd8-ba52ac35a8cf
----
+:id: SGxz9JOWeiyH
+:outputId: b29da06c-037f-4b05-dbd8-ba52ac35a8cf
from jax import lax
operand = jnp.array([0.])
@ -707,10 +755,9 @@ def while_loop(cond_fun, body_fun, init_val):
```
```{code-cell} ipython3
----
-id: jM-D39a-c436
-outputId: b9c97167-fecf-4559-9ca7-1cb0235d8ad2
----
+:id: jM-D39a-c436
+:outputId: b9c97167-fecf-4559-9ca7-1cb0235d8ad2
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
@ -731,10 +778,9 @@ def fori_loop(start, stop, body_fun, init_val):
```
```{code-cell} ipython3
----
-id: dt3tUpOmeR8u
-outputId: 864f2959-2429-4666-b364-4baf90a57482
----
+:id: dt3tUpOmeR8u
+:outputId: 864f2959-2429-4666-b364-4baf90a57482
init_val = 0
start = 0
stop = 10
@ -906,10 +952,9 @@ When this code sees a nan in the output of an `@jit` function, it calls into the
At the moment, JAX by default enforces single-precision numbers to mitigate the Numpy API's tendency to aggressively promote operands to `double`. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!
```{code-cell} ipython3
----
-id: CNNGtzM3NDkO
-outputId: d1384021-d9bf-450f-a9ae-82024fa5fc1a
----
+:id: CNNGtzM3NDkO
+:outputId: d1384021-d9bf-450f-a9ae-82024fa5fc1a
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype
```
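To see the contrast with NumPy directly — a sketch, not part of the original; JAX will also emit a warning about the unavailable dtype:

```python
import numpy as np
import jax.numpy as jnp

# NumPy honors the float64 request; JAX silently demotes to float32
# unless x64 mode is enabled.
print(np.arange(3, dtype=np.float64).dtype)
print(jnp.arange(3, dtype=jnp.float64).dtype)
```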
@ -951,10 +996,9 @@ Note that #2-#4 work for _any_ of JAX's configuration options.
We can then confirm that `x64` mode is enabled:
```{code-cell} ipython3
----
-id: HqGbBa9Rr-2g
-outputId: cd241d63-3d00-4fd7-f9c0-afc6af01ecf4
----
+:id: HqGbBa9Rr-2g
+:outputId: cd241d63-3d00-4fd7-f9c0-afc6af01ecf4
import jax.numpy as jnp
from jax import random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)