Sharp Bits: add section on Dynamic shapes

This commit is contained in:
Jake VanderPlas 2023-01-19 11:37:03 -08:00
parent 7085699832
commit 9e355a6606
2 changed files with 238 additions and 11 deletions

View File

@ -1352,7 +1352,7 @@
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mConcretizationTypeError\u001b[0m\u001b[0;31m:\u001b[0m Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\nThe problem arose with the `bool` function. \nThe error occurred while tracing the function f at <ipython-input-39-fe5ae3470df9>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n"
"\u001b[0;31mConcretizationTypeError\u001b[0m\u001b[0;31m:\u001b[0m Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\nThe problem arose with the `bool` function. \nThe error occurred while tracing the function f at <ipython-input-31-fe5ae3470df9>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n"
]
}
],
@ -1640,7 +1640,7 @@
"id": "Sd9xrLMXeK3A"
},
"source": [
"#### cond\n",
"#### `cond`\n",
"python equivalent:\n",
"\n",
"```python\n",
@ -1690,7 +1690,7 @@
"id": "xkOFAw24eOMg"
},
"source": [
"#### while_loop\n",
"#### `while_loop`\n",
"\n",
"python equivalent:\n",
"```\n",
@ -1738,7 +1738,7 @@
"id": "apo3n3HAeQY_"
},
"source": [
"#### fori_loop\n",
"#### `fori_loop`\n",
"python equivalent:\n",
"```\n",
"def fori_loop(start, stop, body_fun, init_val):\n",
@ -1813,6 +1813,158 @@
"</center>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OxLsZUyRt_kF"
},
"source": [
"## 🔪 Dynamic Shapes"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1tKXcAMduDR1"
},
"source": [
"JAX code used within transforms like `jax.jit`, `jax.vmap`, `jax.grad`, etc. requires all output arrays and intermediate arrays to have static shape: that is, the shape cannot depend on values within other arrays.\n",
"\n",
"For example, if you were implementing your own version of `jnp.nansum`, you might start with something like this:"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"id": "9GIwgvfLujiD"
},
"outputs": [],
"source": [
"def nansum(x):\n",
" mask = ~jnp.isnan(x) # boolean mask selecting non-nan values\n",
" x_without_nans = x[mask]\n",
" return x_without_nans.sum()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "43S7wYAiupGe"
},
"source": [
"Outside JIT and other transforms, this works as expected:"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ITYoNQEZur4s",
"outputId": "a9a03d25-9c54-43b6-d35e-aea6c448d680"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10.0\n"
]
}
],
"source": [
"x = jnp.array([1, 2, jnp.nan, 3, 4])\n",
"print(nansum(x))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "guup5n8xvGI-"
},
"source": [
"If you attempt to apply `jax.jit` or another transform to this function, it will error:"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 114
},
"id": "nms9KjQEvNTz",
"outputId": "d8ae982f-111d-45b6-99f8-37715e2eaab3",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "NonConcreteBooleanIndexError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n"
]
}
],
"source": [
"jax.jit(nansum)(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r2aGyHDkvauu"
},
"source": [
"The problem is that the size of `x_without_nans` is dependent on the values within `x`, which is another way of saying its size is *dynamic*.\n",
"Often in JAX it is possible to work-around the need for dynamically-sized arrays via other means.\n",
"For example, here it is possible to use the three-argument form of `jnp.where` to replace the NaN values with zeros, thus computing the same result while avoiding dynamic shapes:"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Zbuj7Dg1wnSg",
"outputId": "81a5e356-cd28-4709-b307-07c6254c82de"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10.0\n"
]
}
],
"source": [
"@jax.jit\n",
"def nansum_2(x):\n",
" mask = ~jnp.isnan(x) # boolean mask selecting non-nan values\n",
" return jnp.where(mask, x, 0).sum()\n",
"\n",
"print(nansum_2(x))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uGH-jqK7wxTl"
},
"source": [
"Similar tricks can be played in other situations where dynamically-shaped arrays occur."
]
},
{
"cell_type": "markdown",
"metadata": {
@ -1991,7 +2143,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 45,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -2006,7 +2158,7 @@
"dtype('float32')"
]
},
"execution_count": 41,
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
@ -2059,7 +2211,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 46,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@ -2074,7 +2226,7 @@
"dtype('float32')"
]
},
"execution_count": 42,
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}

View File

@ -825,7 +825,7 @@ There are more options for control flow in JAX. Say you want to avoid re-compila
+++ {"id": "Sd9xrLMXeK3A"}
#### cond
#### `cond`
python equivalent:
```python
@ -854,7 +854,7 @@ lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
+++ {"id": "xkOFAw24eOMg"}
#### while_loop
#### `while_loop`
python equivalent:
```
@ -881,7 +881,7 @@ lax.while_loop(cond_fun, body_fun, init_val)
+++ {"id": "apo3n3HAeQY_"}
#### fori_loop
#### `fori_loop`
python equivalent:
```
def fori_loop(start, stop, body_fun, init_val):
@ -934,6 +934,81 @@ $\ast$ = argument-<b>value</b>-independent loop condition - unrolls the loop
</center>
+++ {"id": "OxLsZUyRt_kF"}
## 🔪 Dynamic Shapes
+++ {"id": "1tKXcAMduDR1"}
JAX code used within transforms like `jax.jit`, `jax.vmap`, `jax.grad`, etc. requires all output arrays and intermediate arrays to have static shape: that is, the shape cannot depend on values within other arrays.
For example, if you were implementing your own version of `jnp.nansum`, you might start with something like this:
```{code-cell} ipython3
:id: 9GIwgvfLujiD
def nansum(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
x_without_nans = x[mask]
return x_without_nans.sum()
```
+++ {"id": "43S7wYAiupGe"}
Outside JIT and other transforms, this works as expected:
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: ITYoNQEZur4s
outputId: a9a03d25-9c54-43b6-d35e-aea6c448d680
---
x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))
```
+++ {"id": "guup5n8xvGI-"}
If you attempt to apply `jax.jit` or another transform to this function, it will error:
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
height: 114
id: nms9KjQEvNTz
outputId: d8ae982f-111d-45b6-99f8-37715e2eaab3
tags: [raises-exception]
---
jax.jit(nansum)(x)
```
+++ {"id": "r2aGyHDkvauu"}
The problem is that the size of `x_without_nans` is dependent on the values within `x`, which is another way of saying its size is *dynamic*.
Often in JAX it is possible to work-around the need for dynamically-sized arrays via other means.
For example, here it is possible to use the three-argument form of `jnp.where` to replace the NaN values with zeros, thus computing the same result while avoiding dynamic shapes:
```{code-cell} ipython3
---
colab:
base_uri: https://localhost:8080/
id: Zbuj7Dg1wnSg
outputId: 81a5e356-cd28-4709-b307-07c6254c82de
---
@jax.jit
def nansum_2(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
return jnp.where(mask, x, 0).sum()
print(nansum_2(x))
```
+++ {"id": "uGH-jqK7wxTl"}
Similar tricks can be played in other situations where dynamically-shaped arrays occur.
+++ {"id": "DKTMw6tRZyK2"}
## 🔪 NaNs