mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
1121 lines
49 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "7XNMxdTwURqI"
|
|
},
|
|
"source": [
|
|
"# External Callbacks in JAX"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "h6lXo6bSUYGq"
|
|
},
|
|
"source": [
|
|
"This guide is a work-in-progress outlining the uses of various callback functions, which allow JAX code to execute certain commands on the host, even while running under `jit`, `vmap`, `grad`, or another transformation."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "Xi_nhfpnlmbm"
|
|
},
|
|
"source": [
|
|
"## Why callbacks?\n",
|
|
"\n",
|
|
"A callback routine is a way to perform **host-side** execution of code at runtime.\n",
|
|
"As a simple example, suppose you'd like to print the *value* of some variable during the course of a computation.\n",
|
|
"Using a simple Python `print` statement, it looks like this:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"id": "lz8rEL1Amb4r",
|
|
"outputId": "bbd37102-19f2-46d2-b794-3d4952c6fe97"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"intermediate value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import jax\n",
|
|
"\n",
|
|
"@jax.jit\n",
|
|
"def f(x):\n",
|
|
" y = x + 1\n",
|
|
" print(\"intermediate value: {}\".format(y))\n",
|
|
" return y * 2\n",
|
|
"\n",
|
|
"result = f(2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "yEy41sFAmxOp"
|
|
},
|
|
"source": [
|
|
"What is printed is not the runtime value, but the trace-time abstract value (if you're not familiar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)).\n",
|
|
"\n",
|
|
"To print the value at runtime we need a callback, for example `jax.debug.print`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {
|
|
"id": "wFfHmoQxnKDF",
|
|
"outputId": "6bea21d9-9bb1-4d4d-f3ec-fcf1c691a46a"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"intermediate value: 3\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"@jax.jit\n",
|
|
"def f(x):\n",
|
|
" y = x + 1\n",
|
|
" jax.debug.print(\"intermediate value: {}\", y)\n",
|
|
" return y * 2\n",
|
|
"\n",
|
|
"result = f(2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "CvWv3pudn9X5"
|
|
},
|
|
"source": [
|
|
"This works by passing the runtime value represented by `y` back to the host process, where the host can print the value."
|
|
]
|
|
},
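To make the host round-trip concrete, here is a minimal sketch (not part of the original notebook) that hands the runtime value to an ordinary Python function via `jax.debug.callback`. The names `seen` and `record` are hypothetical, and the call to `jax.effects_barrier()` assumes that callbacks may run asynchronously, so it waits for them before the host-side list is inspected.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical host-side storage; it lives in the Python process, not on the device.
seen = []

def record(y):
    # Runs on the host and receives the concrete runtime value of `y`.
    seen.append(float(np.asarray(y)))

@jax.jit
def f(x):
    y = x + 1
    # Ship the runtime value of `y` back to the host.
    jax.debug.callback(record, y)
    return y * 2

result = f(jnp.float32(2.0))
result.block_until_ready()
# Assumption: wait for any outstanding callback effects before reading `seen`.
jax.effects_barrier()
print(seen, result)
```

Unlike the plain `print` version, `record` sees a concrete number rather than a tracer, because it is invoked at runtime instead of trace time.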
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "X0vR078znuT-"
|
|
},
|
|
"source": [
|
|
"## Flavors of Callback\n",
|
|
"\n",
|
|
"In earlier versions of JAX, there was only one kind of callback available, implemented in `jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:\n",
|
|
"\n",
|
|
"- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effect.\n",
|
|
"- {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk.\n",
|
|
"- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler.\n",
|
|
"\n",
|
|
"(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`).\n",
|
|
"\n",
|
|
"From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.\n",
|
|
"\n",
|
|
"|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |\n",
|
|
"|-------------------------------------|----|----|----|----|----|----|\n",
|
|
"|`jax.pure_callback` | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ |\n",
|
|
"|`jax.experimental.io_callback` | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ |\n",
|
|
"|`jax.debug.callback` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |\n",
|
|
"\n",
|
|
"¹ `jax.pure_callback` can be used with `custom_jvp` to make it compatible with autodiff\n",
|
|
"\n",
|
|
"² `jax.experimental.io_callback` is compatible with `vmap` only if `ordered=False`.\n",
|
|
"\n",
|
|
"³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "hE_M8DaPvoym"
|
|
},
|
|
"source": [
|
|
"### Exploring `jax.pure_callback`\n",
|
|
"\n",
|
|
"`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).\n",
|
|
"\n",
|
|
"The function you pass to `jax.pure_callback` need not actually be pure, but it will be assumed pure by JAX's transformations and higher-order functions, which means that it may be silently elided or called multiple times."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {
|
|
"id": "4lQDzXy6t_-k",
|
|
"outputId": "279e4daf-0540-4eab-f535-d3bcbac74c44"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import jax\n",
|
|
"import jax.numpy as jnp\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"def f_host(x):\n",
|
|
" # call a numpy (not jax.numpy) operation:\n",
|
|
" return np.sin(x).astype(x.dtype)\n",
|
|
"\n",
|
|
"def f(x):\n",
|
|
" result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
|
|
" return jax.pure_callback(f_host, result_shape, x)\n",
|
|
"\n",
|
|
"x = jnp.arange(5.0)\n",
|
|
"f(x)"
|
|
]
|
|
},
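The second argument to `jax.pure_callback` describes the shape and dtype of what the host function returns. As a hedged sketch under the assumption that a tuple of `jax.ShapeDtypeStruct` objects is used for a host function returning a tuple, the hypothetical `sincos_host` below returns two arrays from one host call:

```python
import jax
import jax.numpy as jnp
import numpy as np

def sincos_host(x):
    # Hypothetical host function: plain NumPy, returning a tuple of two arrays.
    x = np.asarray(x)
    return np.sin(x).astype(x.dtype), np.cos(x).astype(x.dtype)

@jax.jit
def sincos(x):
    # One ShapeDtypeStruct per returned array; the tuple structure must
    # match the structure of what the host function returns.
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(sincos_host, (spec, spec), x)

s, c = sincos(jnp.arange(3.0))
print(s, c)
```

The declared shapes and dtypes are what JAX uses while tracing, so the host function is responsible for returning arrays that actually match them (hence the explicit `astype`).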
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "q7YCIr8qMrDs"
|
|
},
|
|
"source": [
|
|
"Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {
|
|
"id": "bgoZ0fxsuoWV",
|
|
"outputId": "901443bd-5cb4-4923-ce53-6f832ac22ca9"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.jit(f)(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {
|
|
"id": "ajBRGWGfupu2",
|
|
"outputId": "b28e31ee-7457-4b92-872b-52d819f53ddf"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.vmap(f)(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {
|
|
"id": "xe7AOGexvC13",
|
|
"outputId": "8fa77977-1f2b-41c5-cc5e-11993ee5aa3e"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def body_fun(_, x):\n",
|
|
" return _, f(x)\n",
|
|
"jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "tMzAVs2VNj5G"
|
|
},
|
|
"source": [
|
|
"However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {
|
|
"id": "4QAF4VhUu5bb",
|
|
"outputId": "f8a06d02-47e9-4240-8077-d7be81e5a480"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Exception reporting mode: Minimal\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%xmode minimal"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {
|
|
"id": "qUpKPxlOurfY",
|
|
"outputId": "11a665e8-40eb-4b0e-dc2e-a544a25fc57e",
|
|
"tags": [
|
|
"raises-exception"
|
|
]
|
|
},
|
|
"outputs": [
|
|
{
|
|
"ename": "ValueError",
|
|
"evalue": "ignored",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.grad(f)(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "y9DAibV4Nwpo"
|
|
},
|
|
"source": [
|
|
"For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pure_callback` with `custom_jvp`* below."
|
|
]
|
|
},
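As a smaller preview of that pattern, the following sketch (an illustration, not the notebook's own example) wraps a host-side `np.sin` in `jax.custom_jvp` and supplies the derivative through a second callback to `np.cos`. The helper names `_host_sin`, `_host_cos`, and `host_sin` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def _host_sin(x):
    # Host-side NumPy implementation of the primal function.
    return np.sin(x).astype(x.dtype)

def _host_cos(x):
    # Host-side NumPy implementation of the derivative.
    return np.cos(x).astype(x.dtype)

@jax.custom_jvp
def host_sin(x):
    return jax.pure_callback(_host_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@host_sin.defjvp
def _host_sin_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # d/dx sin(x) = cos(x), also evaluated on the host.
    cos_x = jax.pure_callback(_host_cos, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    return host_sin(x), cos_x * t

x = jnp.float32(0.5)
grad_val = jax.grad(host_sin)(x)
print(grad_val)
```

Because the JVP rule is linear in the tangent `t`, reverse-mode `jax.grad` can be derived from it; JAX never needs to differentiate through the callbacks themselves.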
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "LrvdAloMZbIe"
|
|
},
|
|
"source": [
|
|
"By design, functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {
|
|
"id": "mmFc_zawZrBq",
|
|
"outputId": "a4df7568-3f64-4b2f-9a2c-7adb2e0815e0"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"printing something\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def print_something():\n",
|
|
" print('printing something')\n",
|
|
" return np.int32(0)\n",
|
|
"\n",
|
|
"@jax.jit\n",
|
|
"def f1():\n",
|
|
" return jax.pure_callback(print_something, np.int32(0))\n",
|
|
"f1();"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {
|
|
"id": "tTwE4kpmaNei"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"@jax.jit\n",
|
|
"def f2():\n",
|
|
" jax.pure_callback(print_something, np.int32(0))\n",
|
|
" return 1.0\n",
|
|
"f2();"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "qfyGYbw4Z5U3"
|
|
},
|
|
"source": [
|
|
"In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output.\n",
|
|
"In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "JHcJybr7OEBM"
|
|
},
|
|
"source": [
|
|
"### Exploring `jax.experimental.io_callback`\n",
|
|
"\n",
|
|
"In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.\n",
|
|
"\n",
|
|
"As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated. (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {
|
|
"id": "eAg5xIhrOiWV",
|
|
"outputId": "e3cfec21-d843-4852-a49d-69a69fba9fc1"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"generating float32[5]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from jax.experimental import io_callback\n",
|
|
"from functools import partial\n",
|
|
"\n",
|
|
"global_rng = np.random.default_rng(0)\n",
|
|
"\n",
|
|
"def host_side_random_like(x):\n",
|
|
" \"\"\"Generate a random array like x using the global_rng state\"\"\"\n",
|
|
" # We have two side-effects here:\n",
|
|
" # - printing the shape and dtype\n",
|
|
" # - calling global_rng, thus updating its state\n",
|
|
" print(f'generating {x.dtype}{list(x.shape)}')\n",
|
|
" return global_rng.uniform(size=x.shape).astype(x.dtype)\n",
|
|
"\n",
|
|
"@jax.jit\n",
|
|
"def numpy_random_like(x):\n",
|
|
" return io_callback(host_side_random_like, x, x)\n",
|
|
"\n",
|
|
"x = jnp.zeros(5)\n",
|
|
"numpy_random_like(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "mAIF31MlXj33"
|
|
},
|
|
"source": [
|
|
"The `io_callback` is compatible with `vmap` by default:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {
|
|
"id": "NY3o5dG6Vg6u",
|
|
"outputId": "a67a8a98-214e-40ca-ad98-a930cd3db85e"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"generating float32[]\n",
|
|
"generating float32[]\n",
|
|
"generating float32[]\n",
|
|
"generating float32[]\n",
|
|
"generating float32[]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array([0.91275555, 0.60663575, 0.72949654, 0.543625 , 0.9350724 ], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.vmap(numpy_random_like)(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "XXvSeeOXXquZ"
|
|
},
|
|
"source": [
|
|
"Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run.\n",
|
|
"\n",
|
|
"If it is important that the order of callbacks be preserved, you can set `ordered=True`, in which case attempting to `vmap` will raise an error:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {
|
|
"id": "3aNmRsDrX3-2",
|
|
"outputId": "a8ff4b77-f4cb-442f-8cfb-ea7251c66274",
|
|
"tags": [
|
|
"raises-exception"
|
|
]
|
|
},
|
|
"outputs": [
|
|
{
|
|
"ename": "ValueError",
|
|
"evalue": "ignored",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m\u001b[0;31m:\u001b[0m ValueError: Cannot `vmap` ordered IO callback.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------\n",
|
|
"\nThe above exception was the direct cause of the following exception:\n",
|
|
"\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Cannot `vmap` ordered IO callback.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"@jax.jit\n",
|
|
"def numpy_random_like_ordered(x):\n",
|
|
" return io_callback(host_side_random_like, x, x, ordered=True)\n",
|
|
"\n",
|
|
"jax.vmap(numpy_random_like_ordered)(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "fD2FTHlUYAZH"
|
|
},
|
|
"source": [
|
|
"On the other hand, `scan` and `while_loop` work with `io_callback` regardless of whether ordering is enforced:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {
|
|
"id": "lMVzZlIEWL7F",
|
|
"outputId": "f9741c18-a30d-4d46-b706-8102849286b5"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"generating float32[]\n",
|
|
"generating float32[]\n",
|
|
"generating float32[]\n",
|
|
"generating float32[]\n",
|
|
"generating float32[]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544], dtype=float32)"
|
|
]
|
|
},
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def body_fun(_, x):\n",
|
|
" return _, numpy_random_like_ordered(x)\n",
|
|
"jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]"
|
|
]
|
|
},
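To see what ordering means in practice, here is a minimal sketch (an illustration with hypothetical names `log` and `append_to_log`) in which an ordered `io_callback` inside `scan` appends each loop index to a host-side list. It assumes, as described above, that `ordered=True` preserves the sequence of callback invocations.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

# Hypothetical host-side log that records the order of callback invocations.
log = []

def append_to_log(i):
    # Side effect: mutate host state, then echo the value back to JAX.
    i = np.asarray(i)
    log.append(int(i))
    return i.astype(np.int32)

@jax.jit
def run(xs):
    def body(carry, i):
        out = io_callback(
            append_to_log,
            jax.ShapeDtypeStruct((), jnp.int32),  # scalar int32 result
            i,
            ordered=True,
        )
        return carry, out
    return jax.lax.scan(body, None, xs)[1]

out = run(jnp.arange(4, dtype=jnp.int32))
out.block_until_ready()
# Assumption: wait for outstanding callback effects before reading `log`.
jax.effects_barrier()
print(log, out)
```

With `ordered=True` the host sees the indices in loop order, which is the property to rely on when the callback writes to a file or another sequential resource.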
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "w_sf8mCbbo8K"
|
|
},
|
|
"source": [
|
|
"Like `pure_callback`, `io_callback` fails under automatic differentiation if it is passed a differentiated variable:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {
|
|
"id": "Cn6_RG4JcKZm",
|
|
"outputId": "336ae5d2-e35b-4fe5-cbfb-14a7aef28c07",
|
|
"tags": [
|
|
"raises-exception"
|
|
]
|
|
},
|
|
"outputs": [
|
|
{
|
|
"ename": "ValueError",
|
|
"evalue": "ignored",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m\u001b[0;31m:\u001b[0m ValueError: IO callbacks do not support JVP.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------\n",
|
|
"\nThe above exception was the direct cause of the following exception:\n",
|
|
"\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m IO callbacks do not support JVP.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.grad(numpy_random_like)(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "plvfn9lWcKu4"
|
|
},
|
|
"source": [
|
|
"However, if the callback is not dependent on a differentiated variable, it will execute:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {
|
|
"id": "wxgfDmDfb5bx",
|
|
"outputId": "d8c0285c-cd04-4b4d-d15a-1b07f778882d"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"hello\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"@jax.jit\n",
|
|
"def f(x):\n",
|
|
" io_callback(lambda: print('hello'), None)\n",
|
|
" return x\n",
|
|
"\n",
|
|
"jax.grad(f)(1.0);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "STLI40EZcVIY"
|
|
},
|
|
"source": [
|
|
"Unlike `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "pkkM1ZmqclV-"
|
|
},
|
|
"source": [
|
|
"### Exploring `debug.callback`\n",
|
|
"\n",
|
|
"Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {
|
|
"id": "74TdWyu9eqBa",
|
|
"outputId": "d8551dab-2e61-492e-9ac3-dc3db51b2c18"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"log: 1.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from jax import debug\n",
|
|
"\n",
|
|
"def log_value(x):\n",
|
|
" # This could be an actual logging call; we'll use\n",
|
|
" # print() for demonstration\n",
|
|
" print(\"log:\", x)\n",
|
|
"\n",
|
|
"@jax.jit\n",
|
|
"def f(x):\n",
|
|
" debug.callback(log_value, x)\n",
|
|
" return x\n",
|
|
"\n",
|
|
"f(1.0);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "P848STlsfzmW"
|
|
},
|
|
"source": [
|
|
"The debug callback is compatible with `vmap`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {
|
|
"id": "2sSNsPB-fGVI",
|
|
"outputId": "fff58575-d94c-48fb-b88a-c1c395595fd0"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"log: 0.0\n",
|
|
"log: 1.0\n",
|
|
"log: 2.0\n",
|
|
"log: 3.0\n",
|
|
"log: 4.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"x = jnp.arange(5.0)\n",
|
|
"jax.vmap(f)(x);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "VDMacqpXf3La"
|
|
},
|
|
"source": [
|
|
"It is also compatible with `grad` and other autodiff transformations:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {
|
|
"id": "wkFRle-tfTDe",
|
|
"outputId": "4e8a81d0-5012-4c51-d843-3fbdc498df31"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"log: 1.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.grad(f)(1.0);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "w8t-SDZ3gRzE"
|
|
},
|
|
"source": [
|
|
"This can make `debug.callback` more useful for general-purpose debugging than either `pure_callback` or `io_callback`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "dF7hoWGQUneJ"
|
|
},
|
|
"source": [
|
|
"## Example: `pure_callback` with `custom_jvp`\n",
|
|
"\n",
|
|
"One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`).\n",
|
|
"Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the `jax.scipy` or `jax.numpy` wrappers.\n",
|
|
"\n",
|
|
"Here, we'll consider creating a wrapper for the Bessel function of the first kind, implemented in `scipy.special.jv`.\n",
|
|
"We can start by defining a straightforward `pure_callback`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "Ge4fNPZdVSJY"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import jax\n",
|
|
"import jax.numpy as jnp\n",
|
|
"import scipy.special\n",
|
|
"\n",
|
|
"def jv(v, z):\n",
|
|
" v, z = jnp.asarray(v), jnp.asarray(z)\n",
|
|
"\n",
|
|
" # Require the order v to be integer type: this simplifies\n",
|
|
" # the JVP rule below.\n",
|
|
" assert jnp.issubdtype(v.dtype, jnp.integer)\n",
|
|
"\n",
|
|
" # Promote the input to inexact (float/complex).\n",
|
|
" # Note that jnp.result_type() accounts for the enable_x64 flag.\n",
|
|
" z = z.astype(jnp.result_type(float, z.dtype))\n",
|
|
"\n",
|
|
" # Wrap scipy function to return the expected dtype.\n",
|
|
" _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)\n",
|
|
"\n",
|
|
" # Define the expected shape & dtype of output.\n",
|
|
" result_shape_dtype = jax.ShapeDtypeStruct(\n",
|
|
" shape=jnp.broadcast_shapes(v.shape, z.shape),\n",
|
|
" dtype=z.dtype)\n",
|
|
"\n",
|
|
" # We use vectorized=True because scipy.special.jv handles broadcasted inputs.\n",
|
|
" return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "vyjQj-0QVuoN"
|
|
},
|
|
"source": [
|
|
"This lets us call into `scipy.special.jv` from transformed JAX code, including when transformed by `jit` and `vmap`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "f4e46670f4e4"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from functools import partial\n",
|
|
"j1 = partial(jv, 1)\n",
|
|
"z = jnp.arange(5.0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "6svImqFHWBwj",
|
|
"outputId": "bc8c778a-6c10-443b-9be2-c0f28e2ac1a9"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(j1(z))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "d48eb4f2d48e"
|
|
},
|
|
"source": [
|
|
"Here is the same result with `jit`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "txvRqR9DWGdC",
|
|
"outputId": "d25f3476-23b1-48e4-dda1-3c06d32c3b87"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(jax.jit(j1)(z))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "d861a472d861"
|
|
},
|
|
"source": [
|
|
"And here is the same result again with `vmap`:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "BS-Ve5u_WU0C",
|
|
"outputId": "08cecd1f-6953-4853-e9db-25a03eb5b000"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(jax.vmap(j1)(z))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "SCH2ii_dWXP6"
|
|
},
|
|
"source": [
|
|
"However, if we call `jax.grad`, we see an error because there is no autodiff rule defined for this function:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "q3qh_4DrWxdQ",
|
|
"outputId": "c46b0bfa-96f3-4629-b9af-a4d4f3ccb870",
|
|
"tags": [
|
|
"raises-exception"
|
|
]
|
|
},
|
|
"outputs": [
|
|
{
|
|
"ename": "ValueError",
|
|
"evalue": "ignored",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mUnfilteredStackTrace\u001b[0m Traceback (most recent call last)",
|
|
"\u001b[0;32m<ipython-input-5-fde6421013cd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1090\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1091\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1092\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1166\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1167\u001b[0;31m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce_axes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreduce_axes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1168\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m_vjp\u001b[0;34m(fun, has_aux, reduce_axes, *primals)\u001b[0m\n\u001b[1;32m 2655\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun_nokwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2656\u001b[0;31m out_primal, out_vjp = ad.vjp(\n\u001b[0m\u001b[1;32m 2657\u001b[0m flat_fun, primals_flat, reduce_axes=reduce_axes)\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mvjp\u001b[0;34m(traceable, primals, has_aux, reduce_axes)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0mout_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinearize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mlinearize\u001b[0;34m(traceable, *primals, **kwargs)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr_nounits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 125\u001b[0m \u001b[0mout_primals_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tangents_pvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_unflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mTraceAnnotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdecorator_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 314\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 315\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_nounits\u001b[0;34m(fun, pvals, instantiate)\u001b[0m\n\u001b[1;32m 766\u001b[0m \u001b[0mfun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_nounits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minstantiate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 767\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 768\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 168\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m<ipython-input-1-9b5b54cddb29>\u001b[0m in \u001b[0;36mjv\u001b[0;34m(v, z)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# We use vectorize=True because scipy.special.jv handles broadcasted inputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_scipy_jv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectorized\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3425\u001b[0m \"\"\"\n\u001b[0;32m-> 3426\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallback\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtypes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3427\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, vectorized, *args, **kwargs)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0mflat_result_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m out_flat = pure_callback_p.bind(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_flat_callback\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, *args, **params)\u001b[0m\n\u001b[1;32m 328\u001b[0m all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args\n\u001b[0;32m--> 329\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfind_top_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 330\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind_with_trace\u001b[0;34m(self, trace, args, params)\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 332\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_primitive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 333\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiple_results\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mfull_lower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mprocess_primitive\u001b[0;34m(self, primitive, tracers, params)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 310\u001b[0;31m \u001b[0mprimal_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtangent_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjvp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprimals_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtangents_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 311\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiple_results\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback_jvp_rule\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;34m\"Pure callbacks do not support JVP. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;31mUnfilteredStackTrace\u001b[0m: ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------",
|
|
"\nThe above exception was the direct cause of the following exception:\n",
|
|
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
|
"\u001b[0;32m<ipython-input-5-fde6421013cd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
"\u001b[0;32m<ipython-input-1-9b5b54cddb29>\u001b[0m in \u001b[0;36mjv\u001b[0;34m(v, z)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# We use vectorize=True because scipy.special.jv handles broadcasted inputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_scipy_jv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectorized\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, vectorized, *args, **kwargs)\u001b[0m\n\u001b[1;32m 129\u001b[0m lambda x: core.ShapedArray(x.shape, x.dtype), result_shape_dtypes)\n\u001b[1;32m 130\u001b[0m \u001b[0mflat_result_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m out_flat = pure_callback_p.bind(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_flat_callback\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m result_avals=tuple(flat_result_avals), vectorized=vectorized)\n",
|
|
"\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback_jvp_rule\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpure_callback_jvp_rule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;34m\"Pure callbacks do not support JVP. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \"Please use `jax.custom_jvp` to use callbacks while taking gradients.\")\n",
|
|
"\u001b[0;31mValueError\u001b[0m: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients."
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.grad(j1)(z)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "PtYeJ_xUW09v"
|
|
},
|
|
"source": [
|
|
"Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`:\n",
|
|
"\n",
|
|
"$$\n",
|
|
"\\frac{d J_\\nu(z)}{dz} = \\left\\{\n",
|
|
"\\begin{eqnarray}\n",
|
|
"-J_1(z),\\ &\\nu=0\\\\\n",
|
|
"[J_{\\nu - 1}(z) - J_{\\nu + 1}(z)]/2,\\ &\\nu\\ne 0\n",
|
|
"\\end{eqnarray}\\right.\n",
|
|
"$$\n",
|
|
"\n",
|
|
"The gradient with respect to $\\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example.\n",
|
|
"\n",
|
|
"We can use `jax.custom_jvp` to define this automatic differentiation rule for our callback function:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "BOVQnt05XvLs"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"jv = jax.custom_jvp(jv)\n",
|
|
"\n",
|
|
"@jv.defjvp\n",
|
|
"def _jv_jvp(primals, tangents):\n",
|
|
" v, z = primals\n",
|
|
" _, z_dot = tangents # Note: v_dot is always 0 because v is integer.\n",
|
|
" jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)\n",
|
|
" djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))\n",
|
|
" return jv(v, z), z_dot * djv_dz"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "W1SxcvQSX44c"
|
|
},
|
|
"source": [
|
|
"Now computing the gradient of our function will work correctly:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "sCGceBs-X8nL",
|
|
"outputId": "71c5589f-f996-44a0-f09a-ca8bb40c167a"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"-0.06447162\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"j1 = partial(jv, 1)\n",
|
|
"print(jax.grad(j1)(2.0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "gWQ4phN5YB26"
|
|
},
|
|
"source": [
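|
|
"As a quick sanity check on the printed gradient, the recurrence for $\\nu = 1$ can be evaluated by hand. The numerical values below are standard tabulated values of $J_0(2)$ and $J_2(2)$, rounded to six digits; they are an assumption brought in for illustration and are not computed in this notebook:\n",
|
|
"\n",
|
|
"$$\n",
|
|
"\\left.\\frac{d J_1(z)}{dz}\\right|_{z=2} = \\frac{J_0(2) - J_2(2)}{2} \\approx \\frac{0.223891 - 0.352834}{2} \\approx -0.064472\n",
|
|
"$$\n",
|
|
"\n",
|
|
"This agrees with the value returned by `jax.grad(j1)(2.0)` to within float32 precision.\n",
|
|
"\n",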
|
|
"Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "QTe5mRAvYQBh",
|
|
"outputId": "d58ecff3-9419-422a-fd0e-14a7d9cf2cc3"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"DeviceArray(-0.4003078, dtype=float32, weak_type=True)"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"jax.hessian(j1)(2.0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "QEXGxU4uYZii"
|
|
},
|
|
"source": [
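|
|
"The second derivative returned by `jax.hessian(j1)(2.0)` can be checked by hand in the same spirit. Applying the derivative recurrence to each term of $[J_0(z) - J_2(z)]/2$ gives the expression below; the numerical values of $J_1(2)$ and $J_3(2)$ are standard tabulated values rounded to six digits, assumed here for illustration rather than computed in this notebook:\n",
|
|
"\n",
|
|
"$$\n",
|
|
"\\frac{d^2 J_1(z)}{dz^2} = \\frac{1}{2}\\left[-J_1(z) - \\frac{J_1(z) - J_3(z)}{2}\\right] = \\frac{J_3(z) - 3 J_1(z)}{4}\n",
|
|
"$$\n",
|
|
"\n",
|
|
"At $z = 2$ this evaluates to approximately $(0.128943 - 3 \\times 0.576725)/4 \\approx -0.400308$, consistent with the float32 result.\n",
|
|
"\n",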
|
|
"Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of `scipy.special.jv` from the host back to the device.\n",
|
|
"When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called.\n",
|
|
"However, if you are running JAX on a single CPU (where the \"host\" and \"device\" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX's capabilities."
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": []
|
|
},
|
|
"jupytext": {
|
|
"formats": "ipynb,md:myst"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"name": "python",
|
|
"version": "3.10.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|