rocm_jax/docs/notebooks/Common_Gotchas_in_JAX.ipynb
2024-09-20 07:52:33 -07:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "hjM_sV_AepYf"
},
"source": [
"# 🔪 JAX - The Sharp Bits 🔪\n",
"\n",
"<!--* freshness: { reviewed: '2024-06-03' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4k5PVzEo2uJO"
},
"source": [
"When walking about the countryside of Italy, the people will not hesitate to tell you that __JAX__ has [_\"una anima di pura programmazione funzionale\"_](https://www.sscardapane.it/iaml-backup/jax-intro/).\n",
"\n",
"__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU).\n",
"JAX works great for many numerical and scientific programs, but __only if they are written with certain constraints__ that we describe below."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "GoK_PCxPeYcy"
},
"outputs": [],
"source": [
"import numpy as np\n",
"from jax import grad, jit\n",
"from jax import lax\n",
"from jax import random\n",
"import jax\n",
"import jax.numpy as jnp"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gX8CZU1g2agP"
},
"source": [
"## 🔪 Pure functions"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2oHigBkW2dPT"
},
"source": [
"JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all input data is passed in through the function parameters, and all results are returned through the function results. A pure function always returns the same result when invoked with the same inputs.\n",
"\n",
"Here are some examples of impure functions for which JAX behaves differently from the Python interpreter. Note that these behaviors are not guaranteed by the JAX system; the proper way to use JAX is to apply it only to functionally pure Python functions."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "A6R-pdcm4u3v",
"outputId": "25dcb191-14d4-4620-bcb2-00492d2f24e1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Executing function\n",
"First call: 4.0\n",
"Second call: 5.0\n",
"Executing function\n",
"Third call, different type: [5.]\n"
]
}
],
"source": [
"def impure_print_side_effect(x):\n",
" print(\"Executing function\") # This is a side-effect\n",
" return x\n",
"\n",
"# The side-effects appear during the first run\n",
"print(\"First call: \", jit(impure_print_side_effect)(4.))\n",
"\n",
"# Subsequent runs with arguments of the same type and shape may not show the side-effect\n",
"# This is because JAX now invokes a cached compilation of the function\n",
"print(\"Second call: \", jit(impure_print_side_effect)(5.))\n",
"\n",
"# JAX re-runs the Python function when the type or shape of the argument changes\n",
"print(\"Third call, different type: \", jit(impure_print_side_effect)(jnp.array([5.])))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "-N8GhitI2bhD",
"outputId": "fd3624c9-197d-42cb-d97f-c5e0ef885467"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"First call: 4.0\n",
"Second call: 5.0\n",
"Third call, different type: [14.]\n"
]
}
],
"source": [
"g = 0.\n",
"def impure_uses_globals(x):\n",
" return x + g\n",
"\n",
"# JAX captures the value of the global during the first run\n",
"print(\"First call: \", jit(impure_uses_globals)(4.))\n",
"g = 10. # Update the global\n",
"\n",
"# Subsequent runs may silently use the cached value of the globals\n",
"print(\"Second call: \", jit(impure_uses_globals)(5.))\n",
"\n",
"# JAX re-runs the Python function when the type or shape of the argument changes\n",
"# This will end up reading the latest value of the global\n",
"print(\"Third call, different type: \", jit(impure_uses_globals)(jnp.array([4.])))"
]
},
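{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the pure alternative, assuming the intent is simply to add an adjustable offset: pass the value in as an explicit argument instead of reading a global. Because the offset is now an input to the traced function, each call sees its current value, and since its type and shape do not change, the cached compilation is reused. The name `pure_uses_arg` is our own, chosen for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"def pure_uses_arg(x, offset):\n",
"  # The offset is a traced argument rather than a captured global,\n",
"  # so changing its value between calls does not require retracing.\n",
"  return x + offset\n",
"\n",
"pure_jit = jax.jit(pure_uses_arg)\n",
"print(\"First call: \", pure_jit(4., 0.))\n",
"print(\"Second call: \", pure_jit(5., 10.))  # sees the new offset"
]
},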
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "RTB6iFgu4DL6",
"outputId": "16697bcd-3623-49b1-aabb-c54614aeadea"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"First call: 4.0\n",
"Saved global: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\n"
]
}
],
"source": [
"g = 0.\n",
"def impure_saves_global(x):\n",
" global g\n",
" g = x\n",
" return x\n",
"\n",
"# JAX runs the transformed function once, with special Traced values as arguments\n",
"print(\"First call: \", jit(impure_saves_global)(4.))\n",
"print(\"Saved global: \", g) # Saved global has an internal JAX value"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mlc2pQlp6v-9"
},
"source": [
"A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "TP-Mqf_862C0",
"outputId": "78d55886-54de-483c-e7c4-bafd1d2c7219"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"50.0\n"
]
}
],
"source": [
"def pure_uses_internal_state(x):\n",
" state = dict(even=0, odd=0)\n",
" for i in range(10):\n",
" state['even' if i % 2 == 0 else 'odd'] += x\n",
" return state['even'] + state['odd']\n",
"\n",
"print(jit(pure_uses_internal_state)(5.))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cDpQ5u63Ba_H"
},
"source": [
"It is not recommended to use iterators in any JAX function you want to `jit` or in any control-flow primitive. The reason is that an iterator is a Python object that introduces state in order to retrieve the next element, which makes it incompatible with JAX's functional programming model. The code below shows some incorrect attempts to use iterators with JAX. Most of them raise an error, but some give unexpected results."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "w99WXa6bBa_H",
"outputId": "52d885fd-0239-4a08-f5ce-0c38cc008903"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"45\n",
"0\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"from jax import make_jaxpr\n",
"\n",
"# lax.fori_loop\n",
"array = jnp.arange(10)\n",
"print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45\n",
"iterator = iter(range(10))\n",
"print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0\n",
"\n",
"# lax.scan\n",
"def func11(arr, extra):\n",
" ones = jnp.ones(arr.shape)\n",
" def body(carry, aelems):\n",
" ae1, ae2 = aelems\n",
" return (carry + ae1 * ae2 + extra, carry)\n",
" return lax.scan(body, 0., (arr, ones))\n",
"make_jaxpr(func11)(jnp.arange(16), 5.)\n",
"# make_jaxpr(func11)(iter(range(16)), 5.) # throws error\n",
"\n",
"# lax.cond\n",
"array_operand = jnp.array([0.])\n",
"lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)\n",
"iter_operand = iter(range(10))\n",
"# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oBdKtkVW8Lha"
},
"source": [
"## 🔪 In-place updates"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JffAqnEW4JEb"
},
"source": [
"In NumPy you're used to doing this:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "om4xV7_84N9j",
"outputId": "88b0074a-4440-41f6-caa7-031ac2d1a96f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"original array:\n",
"[[0. 0. 0.]\n",
" [0. 0. 0.]\n",
" [0. 0. 0.]]\n",
"updated array:\n",
"[[0. 0. 0.]\n",
" [1. 1. 1.]\n",
" [0. 0. 0.]]\n"
]
}
],
"source": [
"numpy_array = np.zeros((3,3), dtype=np.float32)\n",
"print(\"original array:\")\n",
"print(numpy_array)\n",
"\n",
"# In place, mutating update\n",
"numpy_array[1, :] = 1.0\n",
"print(\"updated array:\")\n",
"print(numpy_array)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "go3L4x3w4-9p"
},
"source": [
"If we try to update a JAX array in place, however, we get an __error__! (☉_☉)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "iOscaa_GecEK",
"outputId": "26fdb703-a476-4b7f-97ba-d28997ef750c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exception reporting mode: Minimal\n"
]
}
],
"source": [
"%xmode Minimal"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "2AxeCufq4wAp",
"outputId": "fa4a87ad-1a84-471a-a3c5-a1396c432c85",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "TypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n"
]
}
],
"source": [
"jax_array = jnp.zeros((3,3), dtype=jnp.float32)\n",
"\n",
"# In place update of JAX's array will yield an error!\n",
"jax_array[1, :] = 1.0"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7mo76sS25Wco"
},
"source": [
"Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions.\n",
"\n",
"Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hfloZ1QXCS_J"
},
"source": [
"⚠️ Inside `jit`'d code and `lax.while_loop` or `lax.fori_loop`, the __size__ of slices can't be a function of argument _values_, only of argument _shapes_; the slice start indices have no such restriction. See the __Control flow__ section below for more information on this limitation."
]
},
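{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an illustration of this rule, the sketch below uses `lax.dynamic_slice`, whose start index may be a traced value while the slice size must be a static Python integer. The function name `take_window` and the window size of 3 are our own choices for the example. Note that `lax.dynamic_slice` clamps the start index so that the slice stays in bounds, another instance of the out-of-bounds behavior discussed later in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jax import jit, lax\n",
"import jax.numpy as jnp\n",
"\n",
"@jit\n",
"def take_window(x, start):\n",
"  # `start` is traced, so its value may differ between calls;\n",
"  # the slice size (3,) is a static tuple fixed at trace time.\n",
"  return lax.dynamic_slice(x, (start,), (3,))\n",
"\n",
"values = jnp.arange(10)\n",
"print(take_window(values, 2))  # elements 2, 3, 4\n",
"print(take_window(values, 9))  # start clamped to 7: elements 7, 8, 9"
]
},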
{
"cell_type": "markdown",
"metadata": {
"id": "X2Xjjvd-l8NL"
},
"source": [
"### Array updates: `x.at[idx].set(y)`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SHLY52KQEiuX"
},
"source": [
"For example, the update above can be written as:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "PBGI-HIeCP_s",
"outputId": "de13f19a-2066-4df1-d503-764c34585529"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"updated array:\n",
" [[0. 0. 0.]\n",
" [1. 1. 1.]\n",
" [0. 0. 0.]]\n"
]
}
],
"source": [
"updated_array = jax_array.at[1, :].set(1.0)\n",
"print(\"updated array:\\n\", updated_array)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zUANAw9sCmgu"
},
"source": [
"JAX's array update functions, unlike their NumPy versions, operate out-of-place. That is, the updated array is returned as a new array and the original array is not modified by the update."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "dbB0UmMhCe8f",
"outputId": "55d46fa1-d0de-4c43-996c-f3bbc87b7175"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"original array unchanged:\n",
" [[0. 0. 0.]\n",
" [0. 0. 0.]\n",
" [0. 0. 0.]]\n"
]
}
],
"source": [
"print(\"original array unchanged:\\n\", jax_array)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eM6MyndXL2NY"
},
"source": [
"However, inside __jit__-compiled code, if the __input value__ `x` of `x.at[idx].set(y)` is not reused, the compiler will optimize the array update to occur _in-place_."
]
},
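{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch of this pattern, using a hypothetical helper `fill_rows`: each intermediate array is consumed by the next `.at[...].set(...)` call and never used again, so the compiler is free to reuse its buffer. Whether a copy is actually avoided is an XLA optimization decision; the result is the same either way."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jax import jit\n",
"import jax.numpy as jnp\n",
"\n",
"@jit\n",
"def fill_rows(x):\n",
"  # The Python loop is unrolled at trace time; each update consumes the\n",
"  # previous intermediate array, which is never reused afterwards.\n",
"  for i in range(x.shape[0]):\n",
"    x = x.at[i, :].set(i)\n",
"  return x\n",
"\n",
"print(fill_rows(jnp.zeros((3, 3))))"
]
},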
{
"cell_type": "markdown",
"metadata": {
"id": "7to-sF8EmC_y"
},
"source": [
"### Array updates with other operations"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZY5l3tAdDmsJ"
},
"source": [
"Indexed array updates are not limited simply to overwriting values. For example, we can perform indexed addition as follows:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "tsw2svao8FUp",
"outputId": "3c62a3b1-c12d-46f0-da74-791ec4b61e0b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"original array:\n",
"[[1. 1. 1. 1. 1. 1.]\n",
" [1. 1. 1. 1. 1. 1.]\n",
" [1. 1. 1. 1. 1. 1.]\n",
" [1. 1. 1. 1. 1. 1.]\n",
" [1. 1. 1. 1. 1. 1.]]\n",
"new array post-addition:\n",
"[[1. 1. 1. 8. 8. 8.]\n",
" [1. 1. 1. 1. 1. 1.]\n",
" [1. 1. 1. 8. 8. 8.]\n",
" [1. 1. 1. 1. 1. 1.]\n",
" [1. 1. 1. 8. 8. 8.]]\n"
]
}
],
"source": [
"print(\"original array:\")\n",
"jax_array = jnp.ones((5, 6))\n",
"print(jax_array)\n",
"\n",
"new_jax_array = jax_array.at[::2, 3:].add(7.)\n",
"print(\"new array post-addition:\")\n",
"print(new_jax_array)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sTjJ3WuaDyqU"
},
"source": [
"For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)."
]
},
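{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sketch of a few other update methods described in that documentation (`mul` and `min`), plus one detail worth knowing: when an index is repeated, JAX applies every update, while NumPy's in-place `+=` applies only one of them. The variable names `vec` and `np_vec` are ours."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax.numpy as jnp\n",
"\n",
"vec = jnp.arange(5.0)\n",
"print(vec.at[1].mul(10.))  # element 1 is scaled by 10\n",
"print(vec.at[2].min(-1.))  # element 2 becomes min(2., -1.)\n",
"\n",
"# Repeated indices: JAX applies both additions to element 0...\n",
"print(vec.at[jnp.array([0, 0])].add(1.))\n",
"\n",
"# ...whereas NumPy's in-place += applies only one of them.\n",
"np_vec = np.arange(5.0)\n",
"np_vec[[0, 0]] += 1.\n",
"print(np_vec)"
]
},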
{
"cell_type": "markdown",
"metadata": {
"id": "oZ_jE2WAypdL"
},
"source": [
"## 🔪 Out-of-bounds indexing"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "btRFwEVzypdN"
},
"source": [
"In NumPy, you are used to errors being raised when you index an array outside of its bounds, like this:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "5_ZM-BJUypdO",
"outputId": "c9c41ae8-2653-4219-e6dc-09b03faa3b95",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "IndexError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mIndexError\u001b[0m\u001b[0;31m:\u001b[0m index 11 is out of bounds for axis 0 with size 10\n"
]
}
],
"source": [
"np.arange(10)[11]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eoXrGARWypdR"
},
"source": [
"However, raising an error from code running on an accelerator can be difficult or impossible. Therefore, JAX must choose some non-error behavior for out-of-bounds indexing (akin to how invalid floating point arithmetic results in `NaN`). When the indexing operation is an array index update (e.g. `x.at[idx].add(y)` or other `scatter`-like primitives), updates at out-of-bounds indices will be skipped; when the operation is an array index retrieval (e.g. NumPy indexing or `gather`-like primitives), the index is clamped to the bounds of the array, since __something__ must be returned. For example, the last value of the array will be returned from this indexing operation:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "cusaAD0NypdR",
"outputId": "af1708aa-b50b-4da8-f022-7f2fa67030a8"
},
"outputs": [
{
"data": {
"text/plain": [
"Array(9, dtype=int32)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jnp.arange(10)[11]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NAcXJNAcDi_v"
},
"source": [
"If you would like finer-grained control over the behavior for out-of-bounds indices, you can use the optional parameters of [`ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "-0-MaFddO-xy",
"outputId": "746c4b2b-a90e-4ec9-de56-ed6682d451e5"
},
"outputs": [
{
"data": {
"text/plain": [
"Array(9., dtype=float32)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jnp.arange(10.0).at[11].get()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "g5JEJtIUPBXi",
"outputId": "4a0f6854-1165-47f2-e1ac-5a21fa2b8516"
},
"outputs": [
{
"data": {
"text/plain": [
"Array(nan, dtype=float32)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "J8uO8yevBa_M"
},
"source": [
"Note that, in keeping with this non-error behavior, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting entirely of NaNs, whereas NumPy raises an error.\n",
"\n",
"Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out-of-bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)."
]
},
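{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the `nanargmax` difference concretely, here is a short sketch comparing the two libraries on an all-NaN input (the variable name `all_nan` is ours). NumPy raises a `ValueError`, while JAX returns `-1`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax.numpy as jnp\n",
"\n",
"all_nan = [np.nan, np.nan, np.nan]\n",
"\n",
"try:\n",
"  np.nanargmax(np.array(all_nan))\n",
"except ValueError as err:\n",
"  print('NumPy raised:', err)\n",
"\n",
"print('JAX returned:', jnp.nanargmax(jnp.array(all_nan)))"
]
},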
{
"cell_type": "markdown",
"metadata": {
"id": "LwB07Kx5sgHu"
},
"source": [
"## 🔪 Non-array inputs: NumPy vs. JAX\n",
"\n",
"NumPy is generally happy accepting Python lists or tuples as inputs to its API functions:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "sErQES14sjCG",
"outputId": "601485ff-4cda-48c5-f76c-2789073c4591"
},
"outputs": [
{
"data": {
"text/plain": [
"6"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.sum([1, 2, 3])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZJ1Wt1bTtrSA"
},
"source": [
"JAX departs from this, generally returning a helpful error:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "DFEGcENSsmEc",
"outputId": "08535679-6c1f-4dd9-a414-d8b59310d1ee",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "TypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m sum requires ndarray or scalar arguments, got <class 'list'> at position 0.\n"
]
}
],
"source": [
"jnp.sum([1, 2, 3])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QPliLUZztxJt"
},
"source": [
"This is a deliberate design choice, because passing lists or tuples to traced functions can lead to silent performance degradation that might otherwise be difficult to detect.\n",
"\n",
"For example, consider the following permissive version of `jnp.sum` that allows list inputs:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "jhe-L_TwsvKd",
"outputId": "ab2ee183-d9ec-45cc-d6be-5009347e1bc5"
},
"outputs": [
{
"data": {
"text/plain": [
"Array(45, dtype=int32)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def permissive_sum(x):\n",
" return jnp.sum(jnp.array(x))\n",
"\n",
"x = list(range(10))\n",
"permissive_sum(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m0XZLP7nuYdE"
},
"source": [
"The output is what we would expect, but this hides potential performance issues under the hood. In JAX's tracing and JIT compilation model, each element in a Python list or tuple is treated as a separate JAX variable, and individually processed and pushed to device. This can be seen in the jaxpr for the ``permissive_sum`` function above:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "k81u6DQ7vAjQ",
"outputId": "869fc3b9-feda-4aa9-d2e5-5b5107de102d"
},
"outputs": [
{
"data": {
"text/plain": [
"{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]\n",
" j:i32[]. let\n",
" k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a\n",
" l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b\n",
" m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c\n",
" n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d\n",
" o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e\n",
" p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f\n",
" q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g\n",
" r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h\n",
" s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i\n",
" t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j\n",
" u:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k\n",
" v:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] l\n",
" w:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] m\n",
" x:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] n\n",
" y:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] o\n",
" z:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] p\n",
" ba:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] q\n",
" bb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] r\n",
" bc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] s\n",
" bd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] t\n",
" be:i32[10] = concatenate[dimension=0] u v w x y z ba bb bc bd\n",
" bf:i32[] = reduce_sum[axes=(0,)] be\n",
" in (bf,) }"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"make_jaxpr(permissive_sum)(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "C0_dpCfpvCts"
},
"source": [
"Each entry of the list is handled as a separate input, resulting in a tracing & compilation overhead that grows linearly with the size of the list. To prevent surprises like this, JAX avoids implicit conversions of lists and tuples to arrays.\n",
"\n",
"If you would like to pass a tuple or list to a JAX function, you can do so by first explicitly converting it to an array:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"id": "nFf_DydixG8v",
"outputId": "e31b43b3-05f7-4300-fdd2-40e3896f6f8f"
},
"outputs": [
{
"data": {
"text/plain": [
"Array(45, dtype=int32)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jnp.sum(jnp.array(x))"
]
},
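{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to confirm the difference is to count the inputs of each traced function. A minimal sketch using the public `jax.make_jaxpr` API; `permissive_sum` is redefined exactly as above so the cell runs on its own, and the name `xs` is ours:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def permissive_sum(x):\n",
"  return jnp.sum(jnp.array(x))\n",
"\n",
"xs = list(range(10))\n",
"list_jaxpr = jax.make_jaxpr(permissive_sum)(xs)\n",
"array_jaxpr = jax.make_jaxpr(jnp.sum)(jnp.array(xs))\n",
"\n",
"# One traced input per list element, versus one input for the whole array.\n",
"print(len(list_jaxpr.jaxpr.invars))\n",
"print(len(array_jaxpr.jaxpr.invars))"
]
},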
{
"cell_type": "markdown",
"metadata": {
"id": "MUycRNh6e50W"
},
"source": [
"## 🔪 Random numbers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "O8vvaVt3MRG2"
},
"source": [
"> _If all scientific papers whose results are in doubt because of bad\n",
"> `rand()`s were to disappear from library shelves, there would be a\n",
"> gap on each shelf about as big as your fist._ - Numerical Recipes"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qikt9pPW9L5K"
},
"source": [
"### RNGs and state\n",
"You're used to _stateful_ pseudorandom number generators (PRNGs) from NumPy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "rr9FeP41fynt",
"outputId": "df0ceb15-96ec-4a78-e327-c77f7ea3a745"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.2726690048900553\n",
"0.6304191979771206\n",
"0.6933648856441533\n"
]
}
],
"source": [
"print(np.random.random())\n",
"print(np.random.random())\n",
"print(np.random.random())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ORMVVGZJgSVi"
},
"source": [
"Under the hood, NumPy's global `np.random` functions use the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"id": "7Pyp2ajzfPO2"
},
"outputs": [],
"source": [
"np.random.seed(0)\n",
"rng_state = np.random.get_state()\n",
"# print(rng_state)\n",
"# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,\n",
"# 2481403966, 4042607538, 337614300, ... 614 more numbers...,\n",
"# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aJIxHVXCiM6m"
},
"source": [
"This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, \"consuming\" 2 of the uint32s in the Mersenne twister state vector:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"id": "GAHaDCYafpAF"
},
"outputs": [],
"source": [
"_ = np.random.uniform()\n",
"rng_state = np.random.get_state()\n",
"#print(rng_state)\n",
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
"# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)\n",
"\n",
"# Let's exhaust the entropy in this PRNG statevector\n",
"for i in range(311):\n",
" _ = np.random.uniform()\n",
"rng_state = np.random.get_state()\n",
"#print(rng_state)\n",
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
"# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)\n",
"\n",
"# Next call iterates the RNG state for a new batch of fake \"entropy\".\n",
"_ = np.random.uniform()\n",
"rng_state = np.random.get_state()\n",
"# print(rng_state)\n",
"# --> ('MT19937', array([1499117434, 2949980591, 2242547484,\n",
"# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N_mWnleNogps"
},
"source": [
"The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up when the details of entropy production and consumption are hidden from the end user.\n",
"\n",
"The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems: it has a large 2.5 kB state, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928); it [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) some of the modern BigCrush tests; and it is generally slow."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Uvq7nV-j4vKK"
},
"source": [
"### JAX PRNG"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "COjzGBpO4tzL"
},
"source": [
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
"\n",
"The random state is described by a special array element that we call a __key__:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "yPHE7KTWgAWs",
"outputId": "ae8af0ee-f19e-474e-81b6-45e894eb2fc3"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([0, 0], dtype=uint32)"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"key = random.key(0)\n",
"key"
]
},
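{
"cell_type": "markdown",
"metadata": {},
"source": [
"`random.key` creates a _typed_ key array: a scalar whose dtype records the PRNG implementation. A short sketch, assuming a JAX version recent enough to provide both `random.key` and `random.key_data`, and the default Threefry implementation: the raw `uint32` words inside the key can be inspected with `random.key_data`, while the older `random.PRNGKey` API returns those raw words directly as an ordinary `uint32` array."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jax import random\n",
"import jax.numpy as jnp\n",
"\n",
"typed_key = random.key(0)\n",
"print(typed_key.shape)             # a typed key is a scalar\n",
"print(random.key_data(typed_key))  # its underlying uint32 words\n",
"\n",
"legacy_key = random.PRNGKey(0)     # older API: raw uint32 array\n",
"print(legacy_key.shape, legacy_key.dtype)"
]
},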
{
"cell_type": "markdown",
"metadata": {
"id": "XjYyWYNfq0hW"
},
"source": [
"JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!\n",
"\n",
"Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "7zUdQMynoE5e",
"outputId": "23a86b72-dfb9-410a-8e68-22b48dc10805"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.20584226]\n",
"[0 0]\n",
"[-0.20584226]\n",
"[0 0]\n"
]
}
],
"source": [
"print(random.normal(key, shape=(1,)))\n",
"print(key)\n",
"# No no no!\n",
"print(random.normal(key, shape=(1,)))\n",
"print(key)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hQN9van8rJgd"
},
"source": [
"Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"id": "ASj0_rSzqgGh",
"outputId": "2f13f249-85d1-47bb-d503-823eca6961aa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"old key [0 0]\n",
" \\---SPLIT --> new key [4146024105 967050713]\n",
" \\--> new subkey [2718843009 1272950319] --> normal [-1.2515389]\n"
]
}
],
"source": [
"print(\"old key\", key)\n",
"key, subkey = random.split(key)\n",
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
"print(r\" \\---SPLIT --> new key \", key)\n",
"print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tqtFVE4MthO3"
},
"source": [
"We propagate the __key__ and make new __subkeys__ whenever we need a new random number:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"id": "jbC34XLor2Ek",
"outputId": "4059a2e2-0205-40bc-ad55-17709d538871"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"old key [4146024105 967050713]\n",
" \\---SPLIT --> new key [2384771982 3928867769]\n",
" \\--> new subkey [1278412471 2182328957] --> normal [-0.58665055]\n"
]
}
],
"source": [
"print(\"old key\", key)\n",
"key, subkey = random.split(key)\n",
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
"print(r\" \\---SPLIT --> new key \", key)\n",
"print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0KLYUluz3lN3"
},
"source": [
"We can generate more than one __subkey__ at a time:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"id": "lEi08PJ4tfkX",
"outputId": "1f280560-155d-4c04-98e8-c41d72ee5b01"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.37533438]\n",
"[0.98645043]\n",
"[0.14553197]\n"
]
}
],
"source": [
"key, *subkeys = random.split(key, 4)\n",
"for subkey in subkeys:\n",
" print(random.normal(subkey, shape=(1,)))"
]
},
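{
"cell_type": "markdown",
"metadata": {},
"source": [
"In practice, a common pattern is to split a key into as many subkeys as there are independent samples and map over them. A minimal sketch, where the helper name `sample_noise` is hypothetical: `random.split(key, n)` returns `n` subkeys, and `jax.vmap` draws one standard normal per subkey. Calling it again with the same key reproduces the same values, which is the explicit, reproducible behavior described above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"from jax import random\n",
"\n",
"def sample_noise(key, n):\n",
"  # One subkey per sample; the input key itself is never used to draw numbers.\n",
"  subkeys = random.split(key, n)\n",
"  return jax.vmap(lambda k: random.normal(k))(subkeys)\n",
"\n",
"base_key = random.key(42)\n",
"print(sample_noise(base_key, 4))\n",
"print(sample_noise(base_key, 4))  # same key, same numbers"
]
},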
{
"cell_type": "markdown",
"metadata": {
"id": "rg4CpMZ8c3ri"
},
"source": [
"## 🔪 Control flow"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "izLTvT24dAq0"
},
"source": [
"### ✔ Python control_flow + autodiff ✔\n",
"\n",
"If you just want to apply `grad` to your Python functions, you can use regular Python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or PyTorch or TF Eager)."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"id": "aAx0T3F8lLtu",
"outputId": "383b7bfa-1634-4d23-8497-49cb9452ca52"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12.0\n",
"-4.0\n"
]
}
],
"source": [
"def f(x):\n",
" if x < 3:\n",
" return 3. * x ** 2\n",
" else:\n",
" return -4 * x\n",
"\n",
"print(grad(f)(2.)) # ok!\n",
"print(grad(f)(4.)) # ok!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hIfPT7WMmZ2H"
},
"source": [
"### Python control flow + JIT\n",
"\n",
"Using control flow with `jit` is more complicated, and by default it has more constraints.\n",
"\n",
"This works:"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"id": "OZ_BJX0CplNC",
"outputId": "60c902a2-eba1-49d7-c8c8-2f68616d660c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"24\n"
]
}
],
"source": [
"@jit\n",
"def f(x):\n",
" for i in range(3):\n",
" x = 2 * x\n",
" return x\n",
"\n",
"print(f(3))"
]
},
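{
"cell_type": "markdown",
"metadata": {},
"source": [
"This works because a Python `for` loop over a static range is simply unrolled while JAX traces the function: the traced program contains three separate multiplications and no loop at all. A short sketch to check this with `jax.make_jaxpr`, using an undecorated copy of the function (the names `times_eight` and `unrolled_jaxpr` are ours):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"def times_eight(x):\n",
"  for i in range(3):\n",
"    x = 2 * x\n",
"  return x\n",
"\n",
"unrolled_jaxpr = jax.make_jaxpr(times_eight)(3)\n",
"print(unrolled_jaxpr)\n",
"# Count the multiplication primitives recorded during tracing.\n",
"print([eqn.primitive.name for eqn in unrolled_jaxpr.jaxpr.eqns].count('mul'))"
]
},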
{
"cell_type": "markdown",
"metadata": {
"id": "22RzeJ4QqAuX"
},
"source": [
"So does this:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"id": "pinVnmRWp6w6",
"outputId": "25e06cf2-474f-4782-af7c-4f5514b64422"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"6.0\n"
]
}
],
"source": [
"@jit\n",
"def g(x):\n",
" y = 0.\n",
" for i in range(x.shape[0]):\n",
" y = y + x[i]\n",
" return y\n",
"\n",
"print(g(jnp.array([1., 2., 3.])))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TStltU2dqf8A"
},
"source": [
"But this doesn't, at least by default:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"id": "9z38AIKclRNM",
"outputId": "38dd2075-92fc-4b81-fee0-b9dff8da1fac",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "ConcretizationTypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mConcretizationTypeError\u001b[0m\u001b[0;31m:\u001b[0m Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\nThe problem arose with the `bool` function. \nThe error occurred while tracing the function f at <ipython-input-31-fe5ae3470df9>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n"
]
}
],
"source": [
"@jit\n",
"def f(x):\n",
" if x < 3:\n",
" return 3. * x ** 2\n",
" else:\n",
" return -4 * x\n",
"\n",
"# This will fail!\n",
"f(2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pIbr4TVPqtDN"
},
"source": [
"__What gives!?__\n",
"\n",
"When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.\n",
"\n",
"For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n",
"\n",
"To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n",
"\n",
"By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n",
"\n",
"But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.\n",
"\n",
"The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"id": "-Tzp0H7Bt1Sn",
"outputId": "f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12.0\n"
]
}
],
"source": [
"def f(x):\n",
" if x < 3:\n",
" return 3. * x ** 2\n",
" else:\n",
" return -4 * x\n",
"\n",
"f = jit(f, static_argnums=(0,))\n",
"\n",
"print(f(2.))"
]
},
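{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `x` is now static, `jit` traces and compiles a separate version of `f` for each distinct value of `x` it is called with, and reuses that version when the same value comes back. Below is a minimal sketch of this caching behavior. It uses `static_argnames` (the keyword-based counterpart of `static_argnums`) and a hypothetical `trace_count` list that is appended to only while JAX traces the function:\n",
"\n",
"```python\n",
"import jax\n",
"\n",
"trace_count = []  # hypothetical: one entry per trace, not per call\n",
"\n",
"def f(x):\n",
"  trace_count.append(x)  # Python side effect: runs only during tracing\n",
"  if x < 3:\n",
"    return 3. * x ** 2\n",
"  else:\n",
"    return -4 * x\n",
"\n",
"f_jit = jax.jit(f, static_argnames=('x',))\n",
"\n",
"print(f_jit(x=2.))  # first time x=2.0 is seen: trace and compile\n",
"print(f_jit(x=2.))  # same static value: reuse the cached executable\n",
"print(f_jit(x=5.))  # new static value: trace and compile again\n",
"print(len(trace_count))  # 2 traces for 3 calls\n",
"```\n",
"\n",
"Static arguments are part of the compilation cache key, so they must be hashable."
]
},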
{
"cell_type": "markdown",
"metadata": {
"id": "MHm1hIQAvBVs"
},
"source": [
"Here's another example, this time involving a loop:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"id": "iwY86_JKvD6b",
"outputId": "48f9b51f-bd32-466f-eac1-cd23444ce937"
},
"outputs": [
{
"data": {
"text/plain": [
"Array(5., dtype=float32)"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def f(x, n):\n",
" y = 0.\n",
" for i in range(n):\n",
" y = y + x[i]\n",
" return y\n",
"\n",
"f = jit(f, static_argnums=(1,))\n",
"\n",
"f(jnp.array([2., 3., 4.]), 2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nSPTOX8DvOeO"
},
"source": [
"In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation"
]
},
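{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to see the unrolling is to inspect the traced program with `jax.make_jaxpr`, which accepts the same `static_argnums` option. This sketch reuses the loop example above with an arbitrary 8-element input. The number of equations in the jaxpr grows with `n`, because each Python iteration is recorded separately:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def f(x, n):\n",
"  y = 0.\n",
"  for i in range(n):\n",
"    y = y + x[i]\n",
"  return y\n",
"\n",
"x = jnp.arange(8.)  # arbitrary example input\n",
"\n",
"# Trace the same function for two different static loop lengths.\n",
"jaxpr_2 = jax.make_jaxpr(f, static_argnums=(1,))(x, 2)\n",
"jaxpr_4 = jax.make_jaxpr(f, static_argnums=(1,))(x, 4)\n",
"\n",
"print(jaxpr_2)\n",
"# Each unrolled iteration contributes its own equations.\n",
"print(len(jaxpr_2.jaxpr.eqns), len(jaxpr_4.jaxpr.eqns))\n",
"```"
]
},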
{
"cell_type": "markdown",
"metadata": {
"id": "wWdg8LTYwCW3"
},
"source": [
"️⚠️ **functions with argument-__value__ dependent shapes**\n",
"\n",
"These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"id": "Tqe9uLmUI_Gv",
"outputId": "989be121-dfce-4bb3-c78e-a10829c5f883"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[4. 4. 4. 4. 4.]\n"
]
}
],
"source": [
"def example_fun(length, val):\n",
" return jnp.ones((length,)) * val\n",
"# un-jit'd works fine\n",
"print(example_fun(5, 4))"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"id": "fOlR54XRgHpd",
"outputId": "cf31d798-a4ce-4069-8e3e-8f9631ff4b71",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "TypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>,).\nIf using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.\n"
]
}
],
"source": [
"bad_example_jit = jit(example_fun)\n",
"# this will fail:\n",
"bad_example_jit(10, 4)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"id": "kH0lOD4GgFyI",
"outputId": "d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]\n",
"[4. 4. 4. 4. 4.]\n"
]
}
],
"source": [
"# static_argnums tells JAX to recompile on changes at these argument positions:\n",
"good_example_jit = jit(example_fun, static_argnums=(0,))\n",
"# first compile\n",
"print(good_example_jit(10, 4))\n",
"# recompiles\n",
"print(good_example_jit(5, 4))"
]
},
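{
"cell_type": "markdown",
"metadata": {},
"source": [
"If `length` changes often but has a known upper bound, one common workaround is to keep the output shape fixed and encode `length` as a mask instead. This is a sketch under that assumption. `MAX_LENGTH` and `padded_example_fun` are hypothetical names. Unlike `example_fun`, the result is always `MAX_LENGTH` long, with zeros past `length`:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"MAX_LENGTH = 16  # hypothetical upper bound on `length`, fixed in advance\n",
"\n",
"@jax.jit\n",
"def padded_example_fun(length, val):\n",
"  # The output shape is always (MAX_LENGTH,), so `length` can stay a traced\n",
"  # (non-static) argument: it only decides which entries are filled.\n",
"  idx = jnp.arange(MAX_LENGTH)\n",
"  return jnp.where(idx < length, val, 0.)\n",
"\n",
"print(padded_example_fun(10, 4))\n",
"print(padded_example_fun(5, 4))  # same input shapes and dtypes: no recompilation\n",
"```"
]
},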
{
"cell_type": "markdown",
"metadata": {
"id": "MStx_r2oKxpp"
},
"source": [
"`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot!\n",
"\n",
"Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"id": "m2ABpRd8K094",
"outputId": "4f7ebe17-ade4-4e18-bd8c-4b24087c33c3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\n",
"Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\n"
]
},
{
"data": {
"text/plain": [
"Array(4, dtype=int32, weak_type=True)"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@jit\n",
"def f(x):\n",
" print(x)\n",
" y = 2 * x\n",
" print(y)\n",
" return y\n",
"f(2)"
]
},
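{
"cell_type": "markdown",
"metadata": {},
"source": [
"The prints above show tracers because Python's `print` runs once, while `f` is being traced. It never runs on later calls that hit the compilation cache. To print runtime values from inside a jitted function, `jax.debug.print` stages the print into the computation itself. Here is a minimal sketch, using a hypothetical `trace_log` list to count traces:\n",
"\n",
"```python\n",
"import jax\n",
"\n",
"trace_log = []  # hypothetical: one entry per trace of f\n",
"\n",
"@jax.jit\n",
"def f(x):\n",
"  trace_log.append('traced')    # trace-time Python side effect\n",
"  jax.debug.print('x = {}', x)  # runtime print of the concrete value\n",
"  y = 2 * x\n",
"  jax.debug.print('y = {}', y)\n",
"  return y\n",
"\n",
"f(2)\n",
"f(3)  # same shape and dtype as f(2): no re-trace, but the debug prints still run\n",
"print(len(trace_log))  # 1\n",
"```"
]
},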
{
"cell_type": "markdown",
"metadata": {
"id": "uCDcWG4MnVn-"
},
"source": [
"### Structured control flow primitives\n",
"\n",
"There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:\n",
"\n",
" - `lax.cond` _differentiable_\n",
" - `lax.while_loop` __fwd-mode-differentiable__\n",
" - `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static.\n",
" - `lax.scan` _differentiable_"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Sd9xrLMXeK3A"
},
"source": [
"#### `cond`\n",
"python equivalent:\n",
"\n",
"```python\n",
"def cond(pred, true_fun, false_fun, operand):\n",
" if pred:\n",
" return true_fun(operand)\n",
" else:\n",
" return false_fun(operand)\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"id": "SGxz9JOWeiyH",
"outputId": "942a8d0e-5ff6-4702-c499-b3941f529ca3"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([-1.], dtype=float32)"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from jax import lax\n",
"\n",
"operand = jnp.array([0.])\n",
"lax.cond(True, lambda x: x+1, lambda x: x-1, operand)\n",
"# --> array([1.], dtype=float32)\n",
"lax.cond(False, lambda x: x+1, lambda x: x-1, operand)\n",
"# --> array([-1.], dtype=float32)"
]
},
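{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the predicate of `lax.cond` may be a traced value, it can replace the Python `if` in the function that failed under `jit` earlier. Here is a sketch of that rewrite. Both branch functions must return values of the same shape and dtype, so the second branch uses `-4.` rather than `-4`. That keeps both branches returning floats even for integer inputs:\n",
"\n",
"```python\n",
"import jax\n",
"from jax import lax\n",
"\n",
"@jax.jit\n",
"def f(x):\n",
"  # Both branches are traced; the runtime value of `x < 3` selects one.\n",
"  return lax.cond(x < 3, lambda x: 3. * x ** 2, lambda x: -4. * x, x)\n",
"\n",
"print(f(2.))\n",
"print(f(4.))\n",
"print(jax.grad(f)(2.))  # derivative of 3x^2 at x=2\n",
"```"
]
},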
{
"cell_type": "markdown",
"id": "e6622244",
"metadata": {
"id": "lIYdn1woOS1n"
},
"source": [
"`jax.lax` provides two other functions that allow branching on dynamic predicates:\n",
"\n",
"- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is\n",
" like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays\n",
" rather than as functions.\n",
"- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is\n",
" like `lax.cond`, but allows switching between any number of callable choices.\n",
"\n",
"In addition, `jax.numpy` provides several numpy-style interfaces to these functions:\n",
"\n",
"- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with\n",
" three arguments is the numpy-style wrapper of `lax.select`.\n",
"- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html)\n",
" is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index.\n",
"- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has\n",
" an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather\n",
" than as functions. It is implemented in terms of multiple calls to `lax.select`."
]
},
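{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a short sketch of these alternatives on an arbitrary three-element array. The branch functions and conditions are illustrative only:\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"from jax import lax\n",
"\n",
"x = jnp.array([-2., 0.5, 3.])\n",
"\n",
"# lax.switch: an integer index (which may be traced) picks one callable.\n",
"branches = [lambda v: v + 1., lambda v: v * 2., lambda v: -v]\n",
"print(lax.switch(1, branches, x))  # applies the second branch\n",
"\n",
"# jnp.where: elementwise choice between two precomputed arrays.\n",
"print(jnp.where(x > 0, x, 0.))\n",
"\n",
"# jnp.piecewise: boolean conditions select which function applies where.\n",
"print(jnp.piecewise(x, [x < 0, x >= 0], [lambda v: -v, lambda v: v ** 2]))\n",
"```\n",
"\n",
"`lax.switch` clamps an out-of-range index to the valid range instead of raising an error."
]
},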
{
"cell_type": "markdown",
"metadata": {
"id": "xkOFAw24eOMg"
},
"source": [
"#### `while_loop`\n",
"\n",
"python equivalent:\n",
"```\n",
"def while_loop(cond_fun, body_fun, init_val):\n",
" val = init_val\n",
" while cond_fun(val):\n",
" val = body_fun(val)\n",
" return val\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"id": "jM-D39a-c436",
"outputId": "552fe42f-4d32-4e25-c8c2-b951160a3f4e"
},
"outputs": [
{
"data": {
"text/plain": [
"Array(10, dtype=int32, weak_type=True)"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"init_val = 0\n",
"cond_fun = lambda x: x < 10\n",
"body_fun = lambda x: x+1\n",
"lax.while_loop(cond_fun, body_fun, init_val)\n",
"# --> array(10, dtype=int32)"
]
},
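{
"cell_type": "markdown",
"metadata": {},
"source": [
"`lax.while_loop` is most useful when the number of iterations depends on traced values, which a Python `while` loop cannot handle under `jit`. In this sketch, the hypothetical `doublings_to_exceed` counts how many times `x` must be doubled before it exceeds `limit`. The loop carry is a tuple, and its structure and dtypes must stay the same across iterations:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax import lax\n",
"\n",
"@jax.jit\n",
"def doublings_to_exceed(x, limit):\n",
"  # The trip count depends on the traced *values* of x and limit; a Python\n",
"  # `while` loop would fail here, but lax.while_loop does not.\n",
"  def cond_fun(state):\n",
"    value, _ = state\n",
"    return value <= limit\n",
"\n",
"  def body_fun(state):\n",
"    value, count = state\n",
"    return value * 2, count + 1\n",
"\n",
"  _, count = lax.while_loop(cond_fun, body_fun, (x, jnp.int32(0)))\n",
"  return count\n",
"\n",
"print(doublings_to_exceed(1., 10.))   # 1 -> 2 -> 4 -> 8 -> 16: 4 doublings\n",
"print(doublings_to_exceed(3., 100.))  # same shapes and dtypes: no recompilation\n",
"```\n",
"\n",
"Keep in mind that `lax.while_loop` is only forward-mode differentiable, so `jax.grad` cannot be applied through it."
]
},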
{
"cell_type": "markdown",
"metadata": {
"id": "apo3n3HAeQY_"
},
"source": [
"#### `fori_loop`\n",
"python equivalent:\n",
"```\n",
"def fori_loop(start, stop, body_fun, init_val):\n",
" val = init_val\n",
" for i in range(start, stop):\n",
" val = body_fun(i, val)\n",
" return val\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"id": "dt3tUpOmeR8u",
"outputId": "7819ca7c-1433-4d85-b542-f6159b0e8380"
},
"outputs": [
{
"data": {
"text/plain": [
"Array(45, dtype=int32, weak_type=True)"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"init_val = 0\n",
"start = 0\n",
"stop = 10\n",
"body_fun = lambda i,x: x+i\n",
"lax.fori_loop(start, stop, body_fun, init_val)\n",
"# --> array(45, dtype=int32)"
]
},
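{
"cell_type": "markdown",
"metadata": {},
"source": [
"When `start` and `stop` are concrete Python integers, as in the example above, JAX can implement `fori_loop` with `lax.scan`, which is what makes it reverse-mode differentiable. Here is a sketch with a hypothetical `cube` function:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax import lax\n",
"\n",
"def cube(x):\n",
"  # Static (Python int) bounds: this fori_loop can be lowered to a scan,\n",
"  # so reverse-mode differentiation through it works.\n",
"  return lax.fori_loop(0, 3, lambda i, acc: acc * x, jnp.ones_like(x))\n",
"\n",
"print(cube(2.))\n",
"print(jax.grad(cube)(2.))  # derivative of x^3 at x=2\n",
"```\n",
"\n",
"If the bounds were traced values instead (for example, arguments of a jitted function), the loop would be lowered to a `while_loop`, and `jax.grad` would fail."
]
},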
{
"cell_type": "markdown",
"metadata": {
"id": "SipXS5qiqk8e"
},
"source": [
"#### Summary\n",
"\n",
"$$\n",
"\\begin{array} {r|rr}\n",
"\\hline \\\n",
"\\textrm{construct}\n",
"& \\textrm{jit}\n",
"& \\textrm{grad} \\\\\n",
"\\hline \\\n",
"\\textrm{if} & ❌ & ✔ \\\\\n",
"\\textrm{for} & ✔* & ✔\\\\\n",
"\\textrm{while} & ✔* & ✔\\\\\n",
"\\textrm{lax.cond} & ✔ & ✔\\\\\n",
"\\textrm{lax.while_loop} & ✔ & \\textrm{fwd}\\\\\n",
"\\textrm{lax.fori_loop} & ✔ & \\textrm{fwd}\\\\\n",
"\\textrm{lax.scan} & ✔ & ✔\\\\\n",
"\\hline\n",
"\\end{array}\n",
"$$\n",
"\n",
"<center>\n",
"\n",
"$\\ast$ = argument-<b>value</b>-independent loop condition - unrolls the loop\n",
"\n",
"</center>"
]
},
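{
"cell_type": "markdown",
"metadata": {},
"source": [
"The summary lists `lax.scan`, which has no example above. Below, its Python equivalent is written in the same style as the earlier sections, followed by a small cumulative-sum sketch. `scan_reference` and `cumsum_step` are hypothetical helper names:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax import lax\n",
"\n",
"# Python equivalent of lax.scan (ignoring its `length`, `reverse` and\n",
"# `unroll` options):\n",
"def scan_reference(f, init, xs):\n",
"  carry = init\n",
"  ys = []\n",
"  for x in xs:\n",
"    carry, y = f(carry, x)\n",
"    ys.append(y)\n",
"  return carry, jnp.stack(ys)\n",
"\n",
"def cumsum_step(carry, x):\n",
"  new_carry = carry + x\n",
"  return new_carry, new_carry  # (next carry, output for this step)\n",
"\n",
"xs = jnp.array([1., 2., 3., 4.])\n",
"init = jnp.zeros((), xs.dtype)\n",
"\n",
"total, partial_sums = lax.scan(cumsum_step, init, xs)\n",
"print(total, partial_sums)\n",
"print(scan_reference(cumsum_step, init, xs))\n",
"\n",
"# lax.scan also supports reverse-mode differentiation:\n",
"print(jax.grad(lambda xs: lax.scan(cumsum_step, init, xs)[0])(xs))\n",
"```"
]
},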
{
"cell_type": "markdown",
"metadata": {
"id": "OxLsZUyRt_kF"
},
"source": [
"## 🔪 Dynamic shapes"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1tKXcAMduDR1"
},
"source": [
"JAX code used within transforms like `jax.jit`, `jax.vmap`, `jax.grad`, etc. requires all output arrays and intermediate arrays to have static shape: that is, the shape cannot depend on values within other arrays.\n",
"\n",
"For example, if you were implementing your own version of `jnp.nansum`, you might start with something like this:"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"id": "9GIwgvfLujiD"
},
"outputs": [],
"source": [
"def nansum(x):\n",
" mask = ~jnp.isnan(x) # boolean mask selecting non-nan values\n",
" x_without_nans = x[mask]\n",
" return x_without_nans.sum()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "43S7wYAiupGe"
},
"source": [
"Outside JIT and other transforms, this works as expected:"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"id": "ITYoNQEZur4s",
"outputId": "a9a03d25-9c54-43b6-d35e-aea6c448d680"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10.0\n"
]
}
],
"source": [
"x = jnp.array([1, 2, jnp.nan, 3, 4])\n",
"print(nansum(x))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "guup5n8xvGI-"
},
"source": [
"If you attempt to apply `jax.jit` or another transform to this function, it will error:"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"id": "nms9KjQEvNTz",
"outputId": "d8ae982f-111d-45b6-99f8-37715e2eaab3",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "NonConcreteBooleanIndexError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n"
]
}
],
"source": [
"jax.jit(nansum)(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r2aGyHDkvauu"
},
"source": [
"The problem is that the size of `x_without_nans` is dependent on the values within `x`, which is another way of saying its size is *dynamic*.\n",
"Often in JAX it is possible to work-around the need for dynamically-sized arrays via other means.\n",
"For example, here it is possible to use the three-argument form of `jnp.where` to replace the NaN values with zeros, thus computing the same result while avoiding dynamic shapes:"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"id": "Zbuj7Dg1wnSg",
"outputId": "81a5e356-cd28-4709-b307-07c6254c82de"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10.0\n"
]
}
],
"source": [
"@jax.jit\n",
"def nansum_2(x):\n",
" mask = ~jnp.isnan(x) # boolean mask selecting non-nan values\n",
" return jnp.where(mask, x, 0).sum()\n",
"\n",
"print(nansum_2(x))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uGH-jqK7wxTl"
},
"source": [
"Similar tricks can be played in other situations where dynamically-shaped arrays occur."
]
},
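{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here are two more sketches of the same idea, using the NaN-containing `x` from above. The hypothetical `nanmean` divides the masked sum by the number of valid entries. The hypothetical `valid_indices` uses the `size` argument of `jnp.nonzero` to give its output a static shape, padding unused slots with `fill_value`:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"x = jnp.array([1, 2, jnp.nan, 3, 4])\n",
"\n",
"@jax.jit\n",
"def nanmean(x):\n",
"  mask = ~jnp.isnan(x)\n",
"  # Zero out NaNs, then divide by the (traced) count of valid entries.\n",
"  return jnp.where(mask, x, 0).sum() / mask.sum()\n",
"\n",
"@jax.jit\n",
"def valid_indices(x):\n",
"  # `size` must be static; it fixes the output shape regardless of the data.\n",
"  return jnp.nonzero(~jnp.isnan(x), size=x.shape[0], fill_value=-1)[0]\n",
"\n",
"print(nanmean(x))        # (1 + 2 + 3 + 4) / 4\n",
"print(valid_indices(x))  # indices of non-NaN entries, padded with -1\n",
"```"
]
},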
{
"cell_type": "markdown",
"metadata": {
"id": "DKTMw6tRZyK2"
},
"source": [
"## 🔪 NaNs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ncS0NI4jZrwy"
},
"source": [
"### Debugging NaNs\n",
"\n",
"If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:\n",
"\n",
"* setting the `JAX_DEBUG_NANS=True` environment variable;\n",
"\n",
"* adding `jax.config.update(\"jax_debug_nans\", True)` near the top of your main file;\n",
"\n",
"* adding `jax.config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`;\n",
"\n",
"This will cause computations to error-out immediately on production of a NaN. Switching this option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jit`. For code under an `@jit`, the output of every `@jit` function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jit` at a time.\n",
"\n",
"There could be tricky situations that arise, like nans that only occur under a `@jit` but don't get produced in de-optimized mode. In that case you'll see a warning message print out but your code will continue to execute.\n",
"\n",
"If the nans are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. In the example below, we started an ipython repl with the command line `env JAX_DEBUG_NANS=True ipython`, then ran this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p6ZtDHPbBa_W"
},
"source": [
"```\n",
"In [1]: import jax.numpy as jnp\n",
"\n",
"In [2]: jnp.divide(0., 0.)\n",
"---------------------------------------------------------------------------\n",
"FloatingPointError Traceback (most recent call last)\n",
"<ipython-input-2-f2e2c413b437> in <module>()\n",
"----> 1 jnp.divide(0., 0.)\n",
"\n",
".../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)\n",
" 343 return floor_divide(x1, x2)\n",
" 344 else:\n",
"--> 345 return true_divide(x1, x2)\n",
" 346\n",
" 347\n",
"\n",
".../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)\n",
" 332 x1, x2 = _promote_shapes(x1, x2)\n",
" 333 return lax.div(lax.convert_element_type(x1, result_dtype),\n",
"--> 334 lax.convert_element_type(x2, result_dtype))\n",
" 335\n",
" 336\n",
"\n",
".../jax/jax/lax.pyc in div(x, y)\n",
" 244 def div(x, y):\n",
" 245 r\"\"\"Elementwise division: :math:`x \\over y`.\"\"\"\n",
"--> 246 return div_p.bind(x, y)\n",
" 247\n",
" 248 def rem(x, y):\n",
"\n",
"... stack trace ...\n",
"\n",
".../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)\n",
" 103 py_val = device_buffer.to_py()\n",
" 104 if np.any(np.isnan(py_val)):\n",
"--> 105 raise FloatingPointError(\"invalid value\")\n",
" 106 else:\n",
" 107 return Array(device_buffer, *result_shape)\n",
"\n",
"FloatingPointError: invalid value\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_NCnVt_GBa_W"
},
"source": [
"The nan generated was caught. By running `%debug`, we can get a post-mortem debugger. This also works with functions under `@jit`, as the example below shows."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pf8RF6eiBa_W"
},
"source": [
"```\n",
"In [4]: from jax import jit\n",
"\n",
"In [5]: @jit\n",
" ...: def f(x, y):\n",
" ...: a = x * y\n",
" ...: b = (x + y) / (x - y)\n",
" ...: c = a + 2\n",
" ...: return a + b * c\n",
" ...:\n",
"\n",
"In [6]: x = jnp.array([2., 0.])\n",
"\n",
"In [7]: y = jnp.array([3., 0.])\n",
"\n",
"In [8]: f(x, y)\n",
"Invalid value encountered in the output of a jit function. Calling the de-optimized version.\n",
"---------------------------------------------------------------------------\n",
"FloatingPointError Traceback (most recent call last)\n",
"<ipython-input-8-811b7ddb3300> in <module>()\n",
"----> 1 f(x, y)\n",
"\n",
" ... stack trace ...\n",
"\n",
"<ipython-input-5-619b39acbaac> in f(x, y)\n",
" 2 def f(x, y):\n",
" 3 a = x * y\n",
"----> 4 b = (x + y) / (x - y)\n",
" 5 c = a + 2\n",
" 6 return a + b * c\n",
"\n",
".../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)\n",
" 343 return floor_divide(x1, x2)\n",
" 344 else:\n",
"--> 345 return true_divide(x1, x2)\n",
" 346\n",
" 347\n",
"\n",
".../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)\n",
" 332 x1, x2 = _promote_shapes(x1, x2)\n",
" 333 return lax.div(lax.convert_element_type(x1, result_dtype),\n",
"--> 334 lax.convert_element_type(x2, result_dtype))\n",
" 335\n",
" 336\n",
"\n",
".../jax/jax/lax.pyc in div(x, y)\n",
" 244 def div(x, y):\n",
" 245 r\"\"\"Elementwise division: :math:`x \\over y`.\"\"\"\n",
"--> 246 return div_p.bind(x, y)\n",
" 247\n",
" 248 def rem(x, y):\n",
"\n",
" ... stack trace ...\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6ur2yArDBa_W"
},
"source": [
"When this code sees a nan in the output of an `@jit` function, it calls into the de-optimized code, so we still get a clear stack trace. And we can run a post-mortem debugger with `%debug` to inspect all the values to figure out the error.\n",
"\n",
"⚠️ You shouldn't have the NaN-checker on if you're not debugging, as it can introduce lots of device-host round-trips and performance regressions!\n",
"\n",
"⚠️ The NaN-checker doesn't work with `pmap`. To debug nans in `pmap` code, one thing to try is replacing `pmap` with `vmap`."
]
},
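{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `jax_debug_nans` flag can also be toggled from Python, which is convenient in a notebook. The sketch below enables the checker, triggers the same `0. / 0.` division as above, and then switches the checker off again:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"jax.config.update('jax_debug_nans', True)\n",
"try:\n",
"  jnp.divide(0., 0.)\n",
"  caught = False\n",
"except FloatingPointError as err:\n",
"  caught = True\n",
"  print('caught:', err)\n",
"finally:\n",
"  # Switch the checker off again so later code avoids its overhead.\n",
"  jax.config.update('jax_debug_nans', False)\n",
"\n",
"print(jnp.divide(0., 0.))  # with the checker off, the NaN is simply returned\n",
"```"
]
},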
{
"cell_type": "markdown",
"metadata": {
"id": "YTktlwTTMgFl"
},
"source": [
"## 🔪 Double (64bit) precision\n",
"\n",
"At the moment, JAX by default enforces single-precision numbers to mitigate the Numpy API's tendency to aggressively promote operands to `double`. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"id": "CNNGtzM3NDkO",
"outputId": "b422bb23-a784-44dc-f8c9-57f3b6c861b8"
},
"outputs": [
{
"data": {
"text/plain": [
"dtype('float32')"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n",
"x.dtype"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VcvqzobxNPbd"
},
"source": [
"To use double-precision numbers, you need to set the `jax_enable_x64` configuration variable __at startup__.\n",
"\n",
"There are a few ways to do this:\n",
"\n",
"1. You can enable 64-bit mode by setting the environment variable `JAX_ENABLE_X64=True`.\n",
"\n",
"2. You can manually set the `jax_enable_x64` configuration flag at startup:\n",
"\n",
" ```python\n",
" # again, this only works on startup!\n",
" import jax\n",
" jax.config.update(\"jax_enable_x64\", True)\n",
" ```\n",
"\n",
"3. You can parse command-line flags with `absl.app.run(main)`\n",
"\n",
" ```python\n",
" import jax\n",
" jax.config.config_with_absl()\n",
" ```\n",
"\n",
"4. If you want JAX to run absl parsing for you, i.e. you don't want to do `absl.app.run(main)`, you can instead use\n",
"\n",
" ```python\n",
" import jax\n",
" if __name__ == '__main__':\n",
" # calls jax.config.config_with_absl() *and* runs absl parsing\n",
" jax.config.parse_flags_with_absl()\n",
" ```\n",
"\n",
"Note that #2-#4 work for _any_ of JAX's configuration options.\n",
"\n",
"We can then confirm that `x64` mode is enabled, for example:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax import random\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)\n",
"x.dtype # --> dtype('float64')\n",
"```"
]
},
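{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the flag must be set at startup, one way to try option 1 without restarting the current session is to launch a fresh interpreter with the environment variable set. This is a sketch that assumes JAX is installed for `sys.executable`. With 64-bit mode on, `jnp.zeros` defaults to `float64`:\n",
"\n",
"```python\n",
"import os\n",
"import subprocess\n",
"import sys\n",
"\n",
"# A new process reads JAX_ENABLE_X64 when jax is first imported.\n",
"env = dict(os.environ, JAX_ENABLE_X64='True')\n",
"result = subprocess.run(\n",
"    [sys.executable, '-c', 'import jax.numpy as jnp; print(jnp.zeros(3).dtype)'],\n",
"    env=env, capture_output=True, text=True, check=True,\n",
")\n",
"print(result.stdout.strip())\n",
"```"
]
},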
{
"cell_type": "markdown",
"metadata": {
"id": "6Cks2_gKsXaW"
},
"source": [
"### Caveats\n",
"⚠️ XLA doesn't support 64-bit convolutions on all backends!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WAHjmL0E2XwO"
},
"source": [
"## 🔪 Miscellaneous divergences from NumPy\n",
"\n",
"While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.\n",
"Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.\n",
"\n",
"- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details.\n",
"- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype).\n",
"\n",
" Here is an example of an unsafe cast with differing results between NumPy and JAX:\n",
" ```python\n",
" >>> np.arange(254.0, 258.0).astype('uint8')\n",
" array([254, 255, 0, 1], dtype=uint8)\n",
"\n",
" >>> jnp.arange(254.0, 258.0).astype('uint8')\n",
" Array([254, 255, 255, 255], dtype=uint8)\n",
"\n",
" ```\n",
" This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n",
"\n",
"\n",
"## Fin.\n",
"\n",
"If something's not covered here that has caused you weeping and gnashing of teeth, please let us know and we'll extend these introductory _advisos_!"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "Common Gotchas in JAX",
"provenance": [],
"toc_visible": true
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2 (v3.8.2:7b3ab5921f, Feb 24 2020, 17:52:18) \n[Clang 6.0 (clang-600.0.57)]"
},
"mystnb": {
"render_error_lexer": "none"
},
"vscode": {
"interpreter": {
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}