Merge pull request #14038 from jakevdp:sharp-bits-exceptions

PiperOrigin-RevId: 503196077
This commit is contained in:
jax authors 2023-01-19 10:08:54 -08:00
commit 7085699832
2 changed files with 261 additions and 199 deletions

View File

@ -21,7 +21,7 @@
"\n",
"When walking about the countryside of Italy, the people will not hesitate to tell you that __JAX__ has [_\"una anima di pura programmazione funzionale\"_](https://www.sscardapane.it/iaml-backup/jax-intro/).\n",
"\n",
"__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU). \n",
"__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU).\n",
"JAX works great for many numerical and scientific programs, but __only if they are written with certain constraints__ that we describe below."
]
},
@ -38,13 +38,7 @@
"from jax import lax\n",
"from jax import random\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import matplotlib as mpl\n",
"from matplotlib import pyplot as plt\n",
"from matplotlib import rcParams\n",
"rcParams['image.interpolation'] = 'nearest'\n",
"rcParams['image.cmap'] = 'viridis'\n",
"rcParams['axes.grid'] = False"
"import jax.numpy as jnp"
]
},
{
@ -62,7 +56,7 @@
"id": "2oHigBkW2dPT"
},
"source": [
"JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results. A pure function will always return the same result if invoked with the same inputs. \n",
"JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results. A pure function will always return the same result if invoked with the same inputs.\n",
"\n",
"Here are some examples of functions that are not functionally pure for which JAX behaves differently than the Python interpreter. Note that these behaviors are not guaranteed by the JAX system; the proper way to use JAX is to use it only on functionally pure Python functions."
]
@ -92,10 +86,10 @@
],
"source": [
"def impure_print_side_effect(x):\n",
" print(\"Executing function\") # This is a side-effect \n",
" print(\"Executing function\") # This is a side-effect\n",
" return x\n",
"\n",
"# The side-effects appear during the first run \n",
"# The side-effects appear during the first run\n",
"print (\"First call: \", jit(impure_print_side_effect)(4.))\n",
"\n",
"# Subsequent runs with parameters of same type and shape may not show the side-effect\n",
@ -256,11 +250,11 @@
"\n",
"# lax.scan\n",
"def func11(arr, extra):\n",
" ones = jnp.ones(arr.shape) \n",
" ones = jnp.ones(arr.shape)\n",
" def body(carry, aelems):\n",
" ae1, ae2 = aelems\n",
" return (carry + ae1 * ae2 + extra, carry)\n",
" return lax.scan(body, 0., (arr, ones)) \n",
" return lax.scan(body, 0., (arr, ones))\n",
"make_jaxpr(func11)(jnp.arange(16), 5.)\n",
"# make_jaxpr(func11)(iter(range(16)), 5.) # throws error\n",
"\n",
@ -338,6 +332,29 @@
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iOscaa_GecEK",
"outputId": "26fdb703-a476-4b7f-97ba-d28997ef750c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exception reporting mode: Minimal\n"
]
}
],
"source": [
"%xmode Minimal"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -350,10 +367,11 @@
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exception '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/jax.ops.html\n"
"ename": "TypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n"
]
}
],
@ -361,10 +379,7 @@
"jax_array = jnp.zeros((3,3), dtype=jnp.float32)\n",
"\n",
"# In place update of JAX's array will yield an error!\n",
"try:\n",
" jax_array[1, :] = 1.0\n",
"except Exception as e:\n",
" print(\"Exception {}\".format(e))"
"jax_array[1, :] = 1.0"
]
},
{
@ -407,7 +422,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -443,7 +458,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -496,7 +511,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -563,7 +578,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -576,18 +591,16 @@
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exception index 11 is out of bounds for axis 0 with size 10\n"
"ename": "IndexError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mIndexError\u001b[0m\u001b[0;31m:\u001b[0m index 11 is out of bounds for axis 0 with size 10\n"
]
}
],
"source": [
"try:\n",
" np.arange(10)[11]\n",
"except Exception as e:\n",
" print(\"Exception {}\".format(e))"
"np.arange(10)[11]"
]
},
{
@ -601,7 +614,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -616,7 +629,7 @@
"DeviceArray(9, dtype=int32)"
]
},
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@ -649,7 +662,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -664,7 +677,7 @@
"6"
]
},
"execution_count": 14,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@ -684,28 +697,29 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DFEGcENSsmEc",
"outputId": "08535679-6c1f-4dd9-a414-d8b59310d1ee"
"outputId": "08535679-6c1f-4dd9-a414-d8b59310d1ee",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.\n"
"ename": "TypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m sum requires ndarray or scalar arguments, got <class 'list'> at position 0.\n"
]
}
],
"source": [
"try:\n",
" jnp.sum([1, 2, 3])\n",
"except TypeError as e:\n",
" print(f\"TypeError: {e}\")"
"jnp.sum([1, 2, 3])"
]
},
{
@ -721,7 +735,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -736,7 +750,7 @@
"DeviceArray(45, dtype=int32)"
]
},
"execution_count": 16,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
@ -760,7 +774,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -772,35 +786,34 @@
{
"data": {
"text/plain": [
"{ lambda ; a b c d e f g h i j.\n",
" let k = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] a\n",
" l = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] b\n",
" m = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] c\n",
" n = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] d\n",
" o = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] e\n",
" p = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] f\n",
" q = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] g\n",
" r = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] h\n",
" s = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] i\n",
" t = broadcast_in_dim[ broadcast_dimensions=( )\n",
" shape=(1,) ] j\n",
" u = concatenate[ dimension=0 ] k l m n o p q r s t\n",
" v = convert_element_type[ new_dtype=int32\n",
" weak_type=False ] u\n",
" w = reduce_sum[ axes=(0,) ] v\n",
" in (w,) }"
"{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]\n",
" j:i32[]. let\n",
" k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a\n",
" l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b\n",
" m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c\n",
" n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d\n",
" o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e\n",
" p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f\n",
" q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g\n",
" r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h\n",
" s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i\n",
" t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j\n",
" u:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k\n",
" v:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] l\n",
" w:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] m\n",
" x:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] n\n",
" y:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] o\n",
" z:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] p\n",
" ba:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] q\n",
" bb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] r\n",
" bc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] s\n",
" bd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] t\n",
" be:i32[10] = concatenate[dimension=0] u v w x y z ba bb bc bd\n",
" bf:i32[] = reduce_sum[axes=(0,)] be\n",
" in (bf,) }"
]
},
"execution_count": 17,
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
@ -822,7 +835,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 19,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -837,7 +850,7 @@
"DeviceArray(45, dtype=int32)"
]
},
"execution_count": 18,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@ -861,8 +874,8 @@
"id": "O8vvaVt3MRG2"
},
"source": [
"> _If all scientific papers whose results are in doubt because of bad \n",
"> `rand()`s were to disappear from library shelves, there would be a \n",
"> _If all scientific papers whose results are in doubt because of bad\n",
"> `rand()`s were to disappear from library shelves, there would be a\n",
"> gap on each shelf about as big as your fist._ - Numerical Recipes"
]
},
@ -878,7 +891,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 20,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -891,9 +904,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"0.07022903604194575\n",
"0.11575983097278075\n",
"0.15620432311959775\n"
"0.2726690048900553\n",
"0.6304191979771206\n",
"0.6933648856441533\n"
]
}
],
@ -914,7 +927,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 21,
"metadata": {
"id": "7Pyp2ajzfPO2"
},
@ -922,9 +935,9 @@
"source": [
"np.random.seed(0)\n",
"rng_state = np.random.get_state()\n",
"#print(rng_state)\n",
"# print(rng_state)\n",
"# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,\n",
"# 2481403966, 4042607538, 337614300, ... 614 more numbers..., \n",
"# 2481403966, 4042607538, 337614300, ... 614 more numbers...,\n",
"# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)"
]
},
@ -939,7 +952,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 22,
"metadata": {
"id": "GAHaDCYafpAF"
},
@ -947,7 +960,7 @@
"source": [
"_ = np.random.uniform()\n",
"rng_state = np.random.get_state()\n",
"#print(rng_state) \n",
"#print(rng_state)\n",
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
"# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)\n",
"\n",
@ -955,15 +968,15 @@
"for i in range(311):\n",
" _ = np.random.uniform()\n",
"rng_state = np.random.get_state()\n",
"#print(rng_state) \n",
"#print(rng_state)\n",
"# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n",
"# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)\n",
"\n",
"# Next call iterates the RNG state for a new batch of fake \"entropy\".\n",
"_ = np.random.uniform()\n",
"rng_state = np.random.get_state()\n",
"# print(rng_state) \n",
"# --> ('MT19937', array([1499117434, 2949980591, 2242547484, \n",
"# print(rng_state)\n",
"# --> ('MT19937', array([1499117434, 2949980591, 2242547484,\n",
"# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)"
]
},
@ -1000,7 +1013,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 23,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1015,7 +1028,7 @@
"DeviceArray([0, 0], dtype=uint32)"
]
},
"execution_count": 22,
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
@ -1032,14 +1045,14 @@
"id": "XjYyWYNfq0hW"
},
"source": [
"JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state! \n",
"JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!\n",
"\n",
"Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 24,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1078,7 +1091,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 25,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1116,7 +1129,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 26,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1154,7 +1167,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 27,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1201,7 +1214,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 28,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1245,7 +1258,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 29,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1283,7 +1296,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 30,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1322,24 +1335,24 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 31,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9z38AIKclRNM",
"outputId": "38dd2075-92fc-4b81-fee0-b9dff8da1fac"
"outputId": "38dd2075-92fc-4b81-fee0-b9dff8da1fac",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exception Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\n",
"The problem arose with the `bool` function. \n",
"While tracing the function f at <ipython-input-30-b42e45c0293f>:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.\n",
"\n",
"See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n"
"ename": "ConcretizationTypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mConcretizationTypeError\u001b[0m\u001b[0;31m:\u001b[0m Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\nThe problem arose with the `bool` function. \nThe error occurred while tracing the function f at <ipython-input-39-fe5ae3470df9>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n"
]
}
],
@ -1352,10 +1365,7 @@
" return -4 * x\n",
"\n",
"# This will fail!\n",
"try:\n",
" f(2)\n",
"except Exception as e:\n",
" print(\"Exception {}\".format(e))"
"f(2)"
]
},
{
@ -1381,7 +1391,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 32,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1421,7 +1431,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 33,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1436,7 +1446,7 @@
"DeviceArray(5., dtype=float32)"
]
},
"execution_count": 32,
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
@ -1475,7 +1485,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 34,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1488,10 +1498,6 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[4. 4. 4. 4. 4.]\n",
"Exception Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>,).\n",
"If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.\n",
"[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]\n",
"[4. 4. 4. 4. 4.]\n"
]
}
@ -1500,14 +1506,59 @@
"def example_fun(length, val):\n",
" return jnp.ones((length,)) * val\n",
"# un-jit'd works fine\n",
"print(example_fun(5, 4))\n",
"\n",
"print(example_fun(5, 4))"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fOlR54XRgHpd",
"outputId": "cf31d798-a4ce-4069-8e3e-8f9631ff4b71",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "TypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>,).\nIf using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.\n"
]
}
],
"source": [
"bad_example_jit = jit(example_fun)\n",
"# this will fail:\n",
"try:\n",
" print(bad_example_jit(10, 4))\n",
"except Exception as e:\n",
" print(\"Exception {}\".format(e))\n",
"bad_example_jit(10, 4)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kH0lOD4GgFyI",
"outputId": "d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]\n",
"[4. 4. 4. 4. 4.]\n"
]
}
],
"source": [
"# static_argnums tells JAX to recompile on changes at these argument positions:\n",
"good_example_jit = jit(example_fun, static_argnums=(0,))\n",
"# first compile\n",
@ -1522,14 +1573,14 @@
"id": "MStx_r2oKxpp"
},
"source": [
"`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot! \n",
"`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot!\n",
"\n",
"Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 37,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1552,7 +1603,7 @@
"DeviceArray(4, dtype=int32, weak_type=True)"
]
},
"execution_count": 34,
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
@ -1603,7 +1654,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 38,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1618,7 +1669,7 @@
"DeviceArray([-1.], dtype=float32)"
]
},
"execution_count": 35,
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
@ -1653,7 +1704,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 39,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1668,7 +1719,7 @@
"DeviceArray(10, dtype=int32, weak_type=True)"
]
},
"execution_count": 36,
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
@ -1700,7 +1751,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 40,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1715,7 +1766,7 @@
"DeviceArray(45, dtype=int32, weak_type=True)"
]
},
"execution_count": 37,
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
@ -1738,10 +1789,10 @@
"#### Summary\n",
"\n",
"$$\n",
"\\begin{array} {r|rr} \n",
"\\begin{array} {r|rr}\n",
"\\hline \\\n",
"\\textrm{construct} \n",
"& \\textrm{jit} \n",
"\\textrm{construct}\n",
"& \\textrm{jit}\n",
"& \\textrm{grad} \\\\\n",
"\\hline \\\n",
"\\textrm{if} & ❌ & ✔ \\\\\n",
@ -1940,7 +1991,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 41,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -1955,7 +2006,7 @@
"dtype('float32')"
]
},
"execution_count": 38,
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
@ -1971,7 +2022,7 @@
"id": "VcvqzobxNPbd"
},
"source": [
"To use double-precision numbers, you need to set the `jax_enable_x64` configuration variable __at startup__. \n",
"To use double-precision numbers, you need to set the `jax_enable_x64` configuration variable __at startup__.\n",
"\n",
"There are a few ways to do this:\n",
"\n",
@ -2008,7 +2059,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 42,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -2023,7 +2074,7 @@
"dtype('float32')"
]
},
"execution_count": 39,
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
@ -2061,10 +2112,10 @@
"\n",
" Here is an example of an unsafe cast with differing results between NumPy and JAX:\n",
" ```python\n",
" >>> np.arange(254.0, 258.0).astype('uint8') \n",
" >>> np.arange(254.0, 258.0).astype('uint8')\n",
" array([254, 255, 0, 1], dtype=uint8)\n",
"\n",
" >>> jnp.arange(254.0, 258.0).astype('uint8') \n",
" >>> jnp.arange(254.0, 258.0).astype('uint8')\n",
" DeviceArray([254, 255, 255, 255], dtype=uint8)\n",
" ```\n",
" This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n",
@ -2079,7 +2130,6 @@
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "Common Gotchas in JAX",
"provenance": [],
"toc_visible": true
@ -2103,6 +2153,9 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
},
"mystnb": {
"render_error_lexer": "none"
}
},
"nbformat": 4,

View File

@ -24,7 +24,7 @@ kernelspec:
When walking about the countryside of Italy, the people will not hesitate to tell you that __JAX__ has [_"una anima di pura programmazione funzionale"_](https://www.sscardapane.it/iaml-backup/jax-intro/).
__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU).
__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU).
JAX works great for many numerical and scientific programs, but __only if they are written with certain constraints__ that we describe below.
```{code-cell} ipython3
@ -36,12 +36,6 @@ from jax import lax
from jax import random
import jax
import jax.numpy as jnp
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams['image.interpolation'] = 'nearest'
rcParams['image.cmap'] = 'viridis'
rcParams['axes.grid'] = False
```
+++ {"id": "gX8CZU1g2agP"}
@ -50,7 +44,7 @@ rcParams['axes.grid'] = False
+++ {"id": "2oHigBkW2dPT"}
JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results. A pure function will always return the same result if invoked with the same inputs.
JAX transformation and compilation are designed to work only on Python functions that are functionally pure: all the input data is passed through the function parameters, all the results are output through the function results. A pure function will always return the same result if invoked with the same inputs.
Here are some examples of functions that are not functionally pure for which JAX behaves differently than the Python interpreter. Note that these behaviors are not guaranteed by the JAX system; the proper way to use JAX is to use it only on functionally pure Python functions.
@ -62,10 +56,10 @@ id: A6R-pdcm4u3v
outputId: 25dcb191-14d4-4620-bcb2-00492d2f24e1
---
def impure_print_side_effect(x):
print("Executing function") # This is a side-effect
print("Executing function") # This is a side-effect
return x
# The side-effects appear during the first run
# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))
# Subsequent runs with parameters of same type and shape may not show the side-effect
@ -160,11 +154,11 @@ print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result
# lax.scan
def func11(arr, extra):
ones = jnp.ones(arr.shape)
ones = jnp.ones(arr.shape)
def body(carry, aelems):
ae1, ae2 = aelems
return (carry + ae1 * ae2 + extra, carry)
return lax.scan(body, 0., (arr, ones))
return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error
@ -204,6 +198,16 @@ print(numpy_array)
If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉)
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: iOscaa_GecEK
outputId: 26fdb703-a476-4b7f-97ba-d28997ef750c
---
%xmode Minimal
```
```{code-cell} ipython3
---
colab:
@ -215,10 +219,7 @@ tags: [raises-exception]
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
# In place update of JAX's array will yield an error!
try:
jax_array[1, :] = 1.0
except Exception as e:
print("Exception {}".format(e))
jax_array[1, :] = 1.0
```
+++ {"id": "7mo76sS25Wco"}
@ -312,10 +313,7 @@ id: 5_ZM-BJUypdO
outputId: c9c41ae8-2653-4219-e6dc-09b03faa3b95
tags: [raises-exception]
---
try:
np.arange(10)[11]
except Exception as e:
print("Exception {}".format(e))
np.arange(10)[11]
```
+++ {"id": "eoXrGARWypdR"}
@ -364,11 +362,9 @@ colab:
base_uri: https://localhost:8080/
id: DFEGcENSsmEc
outputId: 08535679-6c1f-4dd9-a414-d8b59310d1ee
tags: [raises-exception]
---
try:
jnp.sum([1, 2, 3])
except TypeError as e:
print(f"TypeError: {e}")
jnp.sum([1, 2, 3])
```
+++ {"id": "QPliLUZztxJt"}
@ -427,8 +423,8 @@ jnp.sum(jnp.array(x))
+++ {"id": "O8vvaVt3MRG2"}
> _If all scientific papers whose results are in doubt because of bad
> `rand()`s were to disappear from library shelves, there would be a
> _If all scientific papers whose results are in doubt because of bad
> `rand()`s were to disappear from library shelves, there would be a
> gap on each shelf about as big as your fist._ - Numerical Recipes
+++ {"id": "Qikt9pPW9L5K"}
@ -457,9 +453,9 @@ Underneath the hood, numpy uses the [Mersenne Twister](https://en.wikipedia.org/
np.random.seed(0)
rng_state = np.random.get_state()
#print(rng_state)
# print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
# 2481403966, 4042607538, 337614300, ... 614 more numbers...,
# 2481403966, 4042607538, 337614300, ... 614 more numbers...,
# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
```
@ -472,7 +468,7 @@ This pseudorandom state vector is automagically updated behind the scenes every
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)
@ -480,15 +476,15 @@ rng_state = np.random.get_state()
for i in range(311):
_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)
# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)
```
@ -522,7 +518,7 @@ key
+++ {"id": "XjYyWYNfq0hW"}
JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!
JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!
Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:
@ -674,6 +670,7 @@ colab:
base_uri: https://localhost:8080/
id: 9z38AIKclRNM
outputId: 38dd2075-92fc-4b81-fee0-b9dff8da1fac
tags: [raises-exception]
---
@jit
def f(x):
@ -683,10 +680,7 @@ def f(x):
return -4 * x
# This will fail!
try:
f(2)
except Exception as e:
print("Exception {}".format(e))
f(2)
```
+++ {"id": "pIbr4TVPqtDN"}
@ -766,13 +760,28 @@ def example_fun(length, val):
return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
```
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: fOlR54XRgHpd
outputId: cf31d798-a4ce-4069-8e3e-8f9631ff4b71
tags: [raises-exception]
---
bad_example_jit = jit(example_fun)
# this will fail:
try:
print(bad_example_jit(10, 4))
except Exception as e:
print("Exception {}".format(e))
bad_example_jit(10, 4)
```
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: kH0lOD4GgFyI
outputId: d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade
---
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
@ -783,7 +792,7 @@ print(good_example_jit(5, 4))
+++ {"id": "MStx_r2oKxpp"}
`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot!
`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot!
Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:
@ -902,10 +911,10 @@ lax.fori_loop(start, stop, body_fun, init_val)
#### Summary
$$
\begin{array} {r|rr}
\begin{array} {r|rr}
\hline \
\textrm{construct}
& \textrm{jit}
\textrm{construct}
& \textrm{jit}
& \textrm{grad} \\
\hline \
\textrm{if} && ✔ \\
@ -1079,7 +1088,7 @@ x.dtype
+++ {"id": "VcvqzobxNPbd"}
To use double-precision numbers, you need to set the `jax_enable_x64` configuration variable __at startup__.
To use double-precision numbers, you need to set the `jax_enable_x64` configuration variable __at startup__.
There are a few ways to do this:
@ -1143,10 +1152,10 @@ Many such cases are discussed in detail in the sections above; here we list seve
Here is an example of an unsafe cast with differing results between NumPy and JAX:
```python
>>> np.arange(254.0, 258.0).astype('uint8')
>>> np.arange(254.0, 258.0).astype('uint8')
array([254, 255, 0, 1], dtype=uint8)
>>> jnp.arange(254.0, 258.0).astype('uint8')
>>> jnp.arange(254.0, 258.0).astype('uint8')
DeviceArray([254, 255, 255, 255], dtype=uint8)
```
This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.