Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)

Commit e6f6a8af8d (parent 3a5ac487a6)

Move Control Flow text from Sharp Bits into its own tutorial.
@@ -189,8 +189,7 @@ You can mix `jit` and `grad` and any other JAX transformation however you like.

 Using `jit` puts constraints on the kind of Python control flow
 the function can use; see
-the [Gotchas
-Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
+the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html)
 for more.

 ### Auto-vectorization with `vmap`

@@ -369,7 +368,7 @@ Some standouts:
 and NumPy types aren't preserved, namely `np.add(1, np.array([2],
 np.float32)).dtype` is `float64` rather than `float32`.
 1. Some transformations, like `jit`, [constrain how you can use Python control
-   flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
+   flow](https://jax.readthedocs.io/en/latest/control-flow.html).
 You'll always get loud errors if something goes wrong. You might have to use
 [`jit`'s `static_argnums`
 parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
docs/control-flow.md (new file, 361 lines)

@@ -0,0 +1,361 @@
---
jupytext:
  formats: md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.4
kernelspec:
  display_name: Python 3
  language: python
  name: python3
---

+++ {"id": "rg4CpMZ8c3ri"}

(control-flow)=
# Control flow and logical operators with JIT

<!--* freshness: { reviewed: '2024-11-11' } *-->

When executing eagerly (outside of `jit`), JAX code works with Python control flow and logical operators just like NumPy code. Using control flow and logical operators with `jit` is more complicated.

In a nutshell, Python control flow and logical operators are evaluated at JIT compile time, such that the compiled function represents a single path through the [control flow graph](https://en.wikipedia.org/wiki/Control-flow_graph) (logical operators affect the path via short-circuiting). If the path depends on the values of the inputs, the function (by default) cannot be JIT compiled. The path may depend on the shape or dtype of the inputs, and the function is re-compiled every time it is called on an input with a new shape or dtype.
```{code-cell}
from jax import grad, jit
import jax.numpy as jnp
```

For example, this works:

```{code-cell}
:id: OZ_BJX0CplNC
:outputId: 60c902a2-eba1-49d7-c8c8-2f68616d660c

@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3))
```

+++ {"id": "22RzeJ4QqAuX"}

So does this:

```{code-cell}
:id: pinVnmRWp6w6
:outputId: 25e06cf2-474f-4782-af7c-4f5514b64422

@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y

print(g(jnp.array([1., 2., 3.])))
```

+++ {"id": "TStltU2dqf8A"}

But this doesn't, at least by default:
```{code-cell}
:id: 9z38AIKclRNM
:outputId: 38dd2075-92fc-4b81-fee0-b9dff8da1fac
:tags: [raises-exception]

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
f(2)
```

Neither does this:

```{code-cell}
:tags: [raises-exception]

@jit
def g(x):
  return (x > 0) and (x < 3)

# This will fail!
g(2)
```
+++ {"id": "pIbr4TVPqtDN"}

__What gives!?__

When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.

For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.

To get a view of your Python code that is valid for many different argument values, JAX traces it with the `ShapedArray` abstraction as input, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.

But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.

The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnames` (or `static_argnums`) argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:

```{code-cell}
:id: -Tzp0H7Bt1Sn
:outputId: f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnames='x')

print(f(2.))
```

+++ {"id": "MHm1hIQAvBVs"}

Here's another example, this time involving a loop:

```{code-cell}
:id: iwY86_JKvD6b
:outputId: 48f9b51f-bd32-466f-eac1-cd23444ce937

def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

f = jit(f, static_argnames='n')

f(jnp.array([2., 3., 4.]), 2)
```

+++ {"id": "nSPTOX8DvOeO"}

In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation.
+++ {"id": "wWdg8LTYwCW3"}

️⚠️ **functions with argument-__value__ dependent shapes**

These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`.

```{code-cell}
:id: Tqe9uLmUI_Gv
:outputId: 989be121-dfce-4bb3-c78e-a10829c5f883

def example_fun(length, val):
  return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
```

```{code-cell}
:id: fOlR54XRgHpd
:outputId: cf31d798-a4ce-4069-8e3e-8f9631ff4b71
:tags: [raises-exception]

bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
```

```{code-cell}
:id: kH0lOD4GgFyI
:outputId: d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade

# static_argnames tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnames='length')
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
```

+++ {"id": "MStx_r2oKxpp"}

`static_argnames` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot!

Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:
```{code-cell}
:id: m2ABpRd8K094
:outputId: 4f7ebe17-ade4-4e18-bd8c-4b24087c33c3

@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2)
```
+++ {"id": "uCDcWG4MnVn-"}

## Structured control flow primitives

There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:

- `lax.cond` _differentiable_
- `lax.while_loop` __fwd-mode-differentiable__
- `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static.
- `lax.scan` _differentiable_
+++ {"id": "Sd9xrLMXeK3A"}

### `cond`

python equivalent:

```python
def cond(pred, true_fun, false_fun, operand):
  if pred:
    return true_fun(operand)
  else:
    return false_fun(operand)
```

```{code-cell}
:id: SGxz9JOWeiyH
:outputId: 942a8d0e-5ff6-4702-c499-b3941f529ca3

from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
```
+++ {"id": "lIYdn1woOS1n"}

`jax.lax` provides two other functions that allow branching on dynamic predicates:

- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is
  like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays
  rather than as functions.
- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is
  like `lax.cond`, but allows switching between any number of callable choices.

In addition, `jax.numpy` provides several numpy-style interfaces to these functions:

- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with
  three arguments is the numpy-style wrapper of `lax.select`.
- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html)
  is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index.
- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has
  an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather
  than as functions. It is implemented in terms of multiple calls to `lax.select`.
+++ {"id": "xkOFAw24eOMg"}

### `while_loop`

python equivalent:
```
def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
```

```{code-cell}
:id: jM-D39a-c436
:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e

init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
```
|
||||
|
||||
+++ {"id": "apo3n3HAeQY_"}
|
||||
|
||||
### `fori_loop`
|
||||
python equivalent:
|
||||
```
|
||||
def fori_loop(start, stop, body_fun, init_val):
|
||||
val = init_val
|
||||
for i in range(start, stop):
|
||||
val = body_fun(i, val)
|
||||
return val
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
:id: dt3tUpOmeR8u
|
||||
:outputId: 7819ca7c-1433-4d85-b542-f6159b0e8380
|
||||
|
||||
init_val = 0
|
||||
start = 0
|
||||
stop = 10
|
||||
body_fun = lambda i,x: x+i
|
||||
lax.fori_loop(start, stop, body_fun, init_val)
|
||||
# --> array(45, dtype=int32)
|
||||
```
+++ {"id": "SipXS5qiqk8e"}

### Summary

$$
\begin{array} {r|rr}
\hline \
\textrm{construct}
& \textrm{jit}
& \textrm{grad} \\
\hline \
\textrm{if} & ❌ & ✔ \\
\textrm{for} & ✔* & ✔\\
\textrm{while} & ✔* & ✔\\
\textrm{lax.cond} & ✔ & ✔\\
\textrm{lax.while\_loop} & ✔ & \textrm{fwd}\\
\textrm{lax.fori\_loop} & ✔ & \textrm{fwd}\\
\textrm{lax.scan} & ✔ & ✔\\
\hline
\end{array}
$$

<center>

$\ast$ = argument-<b>value</b>-independent loop condition - unrolls the loop

</center>
## Logical operators

`jax.numpy` provides `logical_and`, `logical_or`, and `logical_not`, which operate element-wise on arrays and can be evaluated under `jit` without recompiling. Like their NumPy counterparts, the binary operators do not short-circuit. Bitwise operators (`&`, `|`, `~`) can also be used with `jit`.
+++ {"id": "izLTvT24dAq0"}

## Python control flow + autodiff

Remember that the above constraints on control flow and logical operators are relevant only with `jit`. If you just want to apply `grad` to your Python functions, without `jit`, you can use regular Python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or PyTorch or TF Eager).

```{code-cell}
:id: aAx0T3F8lLtu
:outputId: 383b7bfa-1634-4d23-8497-49cb9452ca52

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

print(grad(f)(2.)) # ok!
print(grad(f)(4.)) # ok!
```
@@ -116,7 +116,7 @@ code in JAX's internal representation, typically because it makes heavy use of
 Python control flow such as ``for`` loops. For a handful of loop iterations,
 Python is OK, but if you need *many* loop iterations, you should rewrite your
 code to make use of JAX's
-`structured control flow primitives <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Structured-control-flow-primitives>`_
+`structured control flow primitives <https://jax.readthedocs.io/en/latest/control-flow.html#Structured-control-flow-primitives>`_
 (such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can
 still use ``jit`` decorated functions *inside* the loop).

@@ -170,7 +170,7 @@ jax.jit(g)(10, 20) # Raises an error
 The problem in both cases is that we tried to condition the trace-time flow of the program using runtime values.
 Traced values within JIT, like `x` and `n` here, can only affect control flow via their static attributes: such as
 `shape` or `dtype`, and not via their values.
-For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
+For more detail on the interaction between Python control flow and JAX, see {ref}`control-flow`.

 One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special {ref}`lax-control-flow` like {func}`jax.lax.cond`. However, sometimes that is not possible or practical.
 In that case, you can consider JIT-compiling only part of the function.
|
@@ -34,7 +34,7 @@
     "outputs": [],
     "source": [
      "import numpy as np\n",
-     "from jax import grad, jit\n",
+     "from jax import jit\n",
      "from jax import lax\n",
      "from jax import random\n",
      "import jax\n",

@@ -1175,610 +1175,14 @@
|
||||
"```\n",
|
||||
"def fori_loop(start, stop, body_fun, init_val):\n",
|
||||
" val = init_val\n",
|
||||
" for i in range(start, stop):\n",
|
||||
" val = body_fun(i, val)\n",
|
||||
" return val\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"metadata": {
|
||||
"id": "dt3tUpOmeR8u",
|
||||
"outputId": "7819ca7c-1433-4d85-b542-f6159b0e8380"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Array(45, dtype=int32, weak_type=True)"
|
||||
]
|
||||
},
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"init_val = 0\n",
|
||||
"start = 0\n",
|
||||
"stop = 10\n",
|
||||
"body_fun = lambda i,x: x+i\n",
|
||||
"lax.fori_loop(start, stop, body_fun, init_val)\n",
|
||||
"# --> array(45, dtype=int32)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "SipXS5qiqk8e"
|
||||
},
|
||||
"source": [
|
||||
"#### Summary\n",
|
||||
"\n",
|
||||
"$$\n",
|
||||
"\\begin{array} {r|rr}\n",
|
||||
"\\hline \\\n",
|
||||
"\\textrm{construct}\n",
|
||||
"& \\textrm{jit}\n",
|
||||
"& \\textrm{grad} \\\\\n",
|
||||
"\\hline \\\n",
|
||||
"\\textrm{if} & ❌ & ✔ \\\\\n",
|
||||
"\\textrm{for} & ✔* & ✔\\\\\n",
|
||||
"\\textrm{while} & ✔* & ✔\\\\\n",
|
||||
"\\textrm{lax.cond} & ✔ & ✔\\\\\n",
|
||||
"\\textrm{lax.while_loop} & ✔ & \\textrm{fwd}\\\\\n",
|
||||
"\\textrm{lax.fori_loop} & ✔ & \\textrm{fwd}\\\\\n",
|
||||
"\\textrm{lax.scan} & ✔ & ✔\\\\\n",
|
||||
"\\hline\n",
|
||||
"\\end{array}\n",
|
||||
"$$\n",
|
||||
"\n",
|
||||
"<center>\n",
|
||||
"\n",
|
||||
"$\\ast$ = argument-<b>value</b>-independent loop condition - unrolls the loop\n",
|
||||
"\n",
|
||||
"</center>"
|
||||
"Moved to {ref}`control-flow`."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -2209,6 +1613,9 @@
|
||||
" ```\n",
|
||||
" This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n",
|
||||
"\n",
|
||||
"## 🔪 Sharp bits covered in tutorials\n",
|
||||
"- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators.\n",
|
||||
"- {ref}`stateful-computations` gives some advice on how to properly handle state in a JAX program, given that JAX transformations can be applied only to pure functions.\n",
|
||||
"\n",
|
||||
"## Fin.\n",
|
||||
"\n",
|
||||
|
@ -31,7 +31,7 @@ JAX works great for many numerical and scientific programs, but __only if they a
|
||||
:id: GoK_PCxPeYcy
|
||||
|
||||
import numpy as np
|
||||
from jax import grad, jit
|
||||
from jax import jit
|
||||
from jax import lax
|
||||
from jax import random
|
||||
import jax
|
||||
@ -536,328 +536,7 @@ for subkey in subkeys:
|
||||
|
||||
## 🔪 Control flow
|
||||
|
||||
+++ {"id": "izLTvT24dAq0"}
|
||||
|
||||
### ✔ Python control flow + autodiff ✔
|
||||
|
||||
If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager).
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: aAx0T3F8lLtu
|
||||
:outputId: 383b7bfa-1634-4d23-8497-49cb9452ca52
|
||||
|
||||
def f(x):
|
||||
if x < 3:
|
||||
return 3. * x ** 2
|
||||
else:
|
||||
return -4 * x
|
||||
|
||||
print(grad(f)(2.)) # ok!
|
||||
print(grad(f)(4.)) # ok!
|
||||
```
|
||||
|
||||
+++ {"id": "hIfPT7WMmZ2H"}
|
||||
|
||||
### Python control flow + JIT
|
||||
|
||||
Using control flow with `jit` is more complicated, and by default it has more constraints.
|
||||
|
||||
This works:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: OZ_BJX0CplNC
|
||||
:outputId: 60c902a2-eba1-49d7-c8c8-2f68616d660c
|
||||
|
||||
@jit
|
||||
def f(x):
|
||||
for i in range(3):
|
||||
x = 2 * x
|
||||
return x
|
||||
|
||||
print(f(3))
|
||||
```
|
||||
|
||||
+++ {"id": "22RzeJ4QqAuX"}
|
||||
|
||||
So does this:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: pinVnmRWp6w6
|
||||
:outputId: 25e06cf2-474f-4782-af7c-4f5514b64422
|
||||
|
||||
@jit
|
||||
def g(x):
|
||||
y = 0.
|
||||
for i in range(x.shape[0]):
|
||||
y = y + x[i]
|
||||
return y
|
||||
|
||||
print(g(jnp.array([1., 2., 3.])))
|
||||
```
|
||||
|
||||
+++ {"id": "TStltU2dqf8A"}
|
||||
|
||||
But this doesn't, at least by default:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: 9z38AIKclRNM
|
||||
:outputId: 38dd2075-92fc-4b81-fee0-b9dff8da1fac
|
||||
:tags: [raises-exception]
|
||||
|
||||
@jit
|
||||
def f(x):
|
||||
if x < 3:
|
||||
return 3. * x ** 2
|
||||
else:
|
||||
return -4 * x
|
||||
|
||||
# This will fail!
|
||||
f(2)
|
||||
```
|
||||
|
||||
+++ {"id": "pIbr4TVPqtDN"}
|
||||
|
||||
__What gives!?__
|
||||
|
||||
When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.
|
||||
|
||||
For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.
|
||||
|
||||
To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.
|
||||
|
||||
By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
|
||||
|
||||
But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), jnp.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.
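One way to inspect the abstract view that tracing produces is `jax.make_jaxpr`, which traces a function on abstract values and returns the resulting jaxpr without running the computation (a minimal sketch; the printed jaxpr format shown in the comment is illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):
    return 2 * x + 1

# Trace f on a ShapedArray((), float32) abstract value; the jaxpr
# records the primitive operations (mul, add) without concrete data.
jaxpr = jax.make_jaxpr(f)(jnp.float32(3.0))
print(jaxpr)
```

A concretization error like the one above means some line of Python tried to force one of these abstract values into a concrete `True`/`False`.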
|
||||
|
||||
The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: -Tzp0H7Bt1Sn
|
||||
:outputId: f7f664cb-2cd0-4fd7-c685-4ec6ba1c4b7a
|
||||
|
||||
def f(x):
|
||||
if x < 3:
|
||||
return 3. * x ** 2
|
||||
else:
|
||||
return -4 * x
|
||||
|
||||
f = jit(f, static_argnums=(0,))
|
||||
|
||||
print(f(2.))
|
||||
```
|
||||
|
||||
+++ {"id": "MHm1hIQAvBVs"}
|
||||
|
||||
Here's another example, this time involving a loop:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: iwY86_JKvD6b
|
||||
:outputId: 48f9b51f-bd32-466f-eac1-cd23444ce937
|
||||
|
||||
def f(x, n):
|
||||
y = 0.
|
||||
for i in range(n):
|
||||
y = y + x[i]
|
||||
return y
|
||||
|
||||
f = jit(f, static_argnums=(1,))
|
||||
|
||||
f(jnp.array([2., 3., 4.]), 2)
|
||||
```
|
||||
|
||||
+++ {"id": "nSPTOX8DvOeO"}
|
||||
|
||||
In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation.
|
||||
|
||||
+++ {"id": "wWdg8LTYwCW3"}
|
||||
|
||||
⚠️ **functions with argument-__value__-dependent shapes**
|
||||
|
||||
These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`.
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: Tqe9uLmUI_Gv
|
||||
:outputId: 989be121-dfce-4bb3-c78e-a10829c5f883
|
||||
|
||||
def example_fun(length, val):
|
||||
return jnp.ones((length,)) * val
|
||||
# un-jit'd works fine
|
||||
print(example_fun(5, 4))
|
||||
```
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: fOlR54XRgHpd
|
||||
:outputId: cf31d798-a4ce-4069-8e3e-8f9631ff4b71
|
||||
:tags: [raises-exception]
|
||||
|
||||
bad_example_jit = jit(example_fun)
|
||||
# this will fail:
|
||||
bad_example_jit(10, 4)
|
||||
```
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: kH0lOD4GgFyI
|
||||
:outputId: d009fcf5-c9f9-4ce6-fc60-22dc2cf21ade
|
||||
|
||||
# static_argnums tells JAX to recompile on changes at these argument positions:
|
||||
good_example_jit = jit(example_fun, static_argnums=(0,))
|
||||
# first compile
|
||||
print(good_example_jit(10, 4))
|
||||
# recompiles
|
||||
print(good_example_jit(5, 4))
|
||||
```
|
||||
|
||||
+++ {"id": "MStx_r2oKxpp"}
|
||||
|
||||
`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot!
|
||||
|
||||
Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: m2ABpRd8K094
|
||||
:outputId: 4f7ebe17-ade4-4e18-bd8c-4b24087c33c3
|
||||
|
||||
@jit
|
||||
def f(x):
|
||||
print(x)
|
||||
y = 2 * x
|
||||
print(y)
|
||||
return y
|
||||
f(2)
|
||||
```
|
||||
|
||||
+++ {"id": "uCDcWG4MnVn-"}
|
||||
|
||||
### Structured control flow primitives
|
||||
|
||||
There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:
|
||||
|
||||
- `lax.cond` _differentiable_
|
||||
- `lax.while_loop` __fwd-mode-differentiable__
|
||||
- `lax.fori_loop` __fwd-mode-differentiable__ in general; __fwd and rev-mode differentiable__ if endpoints are static.
|
||||
- `lax.scan` _differentiable_
|
||||
|
||||
+++ {"id": "Sd9xrLMXeK3A"}
|
||||
|
||||
#### `cond`
|
||||
python equivalent:
|
||||
|
||||
```python
|
||||
def cond(pred, true_fun, false_fun, operand):
|
||||
if pred:
|
||||
return true_fun(operand)
|
||||
else:
|
||||
return false_fun(operand)
|
||||
```
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: SGxz9JOWeiyH
|
||||
:outputId: 942a8d0e-5ff6-4702-c499-b3941f529ca3
|
||||
|
||||
from jax import lax
|
||||
|
||||
operand = jnp.array([0.])
|
||||
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
|
||||
# --> array([1.], dtype=float32)
|
||||
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
|
||||
# --> array([-1.], dtype=float32)
|
||||
```
|
||||
|
||||
+++ {"id": "lIYdn1woOS1n"}
|
||||
|
||||
`jax.lax` provides two other functions that allow branching on dynamic predicates:
|
||||
|
||||
- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is
|
||||
like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays
|
||||
rather than as functions.
|
||||
- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is
|
||||
like `lax.cond`, but allows switching between any number of callable choices.
|
||||
|
||||
In addition, `jax.numpy` provides several numpy-style interfaces to these functions:
|
||||
|
||||
- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with
|
||||
three arguments is the numpy-style wrapper of `lax.select`.
|
||||
- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html)
|
||||
is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index.
|
||||
- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has
|
||||
an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather
|
||||
than as functions. It is implemented in terms of multiple calls to `lax.select`.
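For instance, three-argument `jnp.where` selects elementwise between two pre-computed arrays, so it works under `jit` with a traced (dynamic) condition (a minimal sketch):

```python
import jax
import jax.numpy as jnp

@jax.jit
def abs_flip(x):
    # Elementwise select: take x where x < 3, otherwise -x.
    # Both branches are computed; the condition picks per element.
    return jnp.where(x < 3, x, -x)

print(abs_flip(jnp.arange(5.0)))
# --> [ 0.  1.  2. -3. -4.]
```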
|
||||
|
||||
+++ {"id": "xkOFAw24eOMg"}
|
||||
|
||||
#### `while_loop`
|
||||
|
||||
python equivalent:
|
||||
```
|
||||
def while_loop(cond_fun, body_fun, init_val):
|
||||
val = init_val
|
||||
while cond_fun(val):
|
||||
val = body_fun(val)
|
||||
return val
|
||||
```
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: jM-D39a-c436
|
||||
:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e
|
||||
|
||||
init_val = 0
|
||||
cond_fun = lambda x: x < 10
|
||||
body_fun = lambda x: x+1
|
||||
lax.while_loop(cond_fun, body_fun, init_val)
|
||||
# --> array(10, dtype=int32)
|
||||
```
|
||||
|
||||
+++ {"id": "apo3n3HAeQY_"}
|
||||
|
||||
#### `fori_loop`
|
||||
python equivalent:
|
||||
```
|
||||
def fori_loop(start, stop, body_fun, init_val):
|
||||
val = init_val
|
||||
for i in range(start, stop):
|
||||
val = body_fun(i, val)
|
||||
return val
|
||||
```
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: dt3tUpOmeR8u
|
||||
:outputId: 7819ca7c-1433-4d85-b542-f6159b0e8380
|
||||
|
||||
init_val = 0
|
||||
start = 0
|
||||
stop = 10
|
||||
body_fun = lambda i,x: x+i
|
||||
lax.fori_loop(start, stop, body_fun, init_val)
|
||||
# --> array(45, dtype=int32)
|
||||
```
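The fourth primitive in the list above, `lax.scan`, has no example here; a minimal sketch of its carry-and-output signature (the body maps `(carry, x)` to `(new_carry, y)`, and `scan` returns the final carry plus the stacked `y`s):

```python
import jax.numpy as jnp
from jax import lax

def cumsum_step(carry, x):
    carry = carry + x
    return carry, carry  # (new carry, per-step output)

final, ys = lax.scan(cumsum_step, 0.0, jnp.arange(1.0, 5.0))
# final --> 10.0, ys --> [ 1.  3.  6. 10.]
```

Unlike a Python `for` loop under `jit`, `scan` compiles the body once rather than unrolling it, and it is differentiable in both forward and reverse mode.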
|
||||
|
||||
+++ {"id": "SipXS5qiqk8e"}
|
||||
|
||||
#### Summary
|
||||
|
||||
$$
|
||||
\begin{array} {r|rr}
|
||||
\hline \
|
||||
\textrm{construct}
|
||||
& \textrm{jit}
|
||||
& \textrm{grad} \\
|
||||
\hline \
|
||||
\textrm{if} & ❌ & ✔ \\
|
||||
\textrm{for} & ✔* & ✔\\
|
||||
\textrm{while} & ✔* & ✔\\
|
||||
\textrm{lax.cond} & ✔ & ✔\\
|
||||
\textrm{lax.while_loop} & ✔ & \textrm{fwd}\\
|
||||
\textrm{lax.fori_loop} & ✔ & \textrm{fwd}\\
|
||||
\textrm{lax.scan} & ✔ & ✔\\
|
||||
\hline
|
||||
\end{array}
|
||||
$$
|
||||
|
||||
<center>
|
||||
|
||||
$\ast$ = argument-<b>value</b>-independent loop condition (the loop is unrolled)
|
||||
|
||||
</center>
|
||||
Moved to {ref}`control-flow`.
|
||||
|
||||
+++ {"id": "OxLsZUyRt_kF"}
|
||||
|
||||
@ -1145,6 +824,9 @@ Many such cases are discussed in detail in the sections above; here we list seve
|
||||
```
|
||||
This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.
|
||||
|
||||
## 🔪 Sharp bits covered in tutorials
|
||||
- {ref}`control-flow` discusses how to work with the constraints that `jit` imposes on the use of Python control flow and logical operators.
|
||||
- {ref}`stateful-computations` gives some advice on how to properly handle state in a JAX program, given that JAX transformations can be applied only to pure functions.
|
||||
|
||||
## Fin.
|
||||
|
||||
|
@ -12,6 +12,7 @@ kernelspec:
|
||||
name: python3
|
||||
---
|
||||
|
||||
(stateful-computations)=
|
||||
# Stateful computations
|
||||
|
||||
<!--* freshness: { reviewed: '2024-05-03' } *-->
|
||||
|
@ -16,6 +16,7 @@ Tutorials
|
||||
working-with-pytrees
|
||||
sharded-computation
|
||||
stateful-computations
|
||||
control-flow
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|