mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Sharp Bits: add section on Dynamic shapes
parent
7085699832
commit
9e355a6606
@ -1352,7 +1352,7 @@
|
||||
"evalue": "ignored",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31mConcretizationTypeError\u001b[0m\u001b[0;31m:\u001b[0m Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\nThe problem arose with the `bool` function. \nThe error occurred while tracing the function f at <ipython-input-39-fe5ae3470df9>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n"
|
||||
"\u001b[0;31mConcretizationTypeError\u001b[0m\u001b[0;31m:\u001b[0m Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\nThe problem arose with the `bool` function. \nThe error occurred while tracing the function f at <ipython-input-31-fe5ae3470df9>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1640,7 +1640,7 @@
|
||||
"id": "Sd9xrLMXeK3A"
|
||||
},
|
||||
"source": [
|
||||
"#### cond\n",
|
||||
"#### `cond`\n",
|
||||
"python equivalent:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
@ -1690,7 +1690,7 @@
|
||||
"id": "xkOFAw24eOMg"
|
||||
},
|
||||
"source": [
|
||||
"#### while_loop\n",
|
||||
"#### `while_loop`\n",
|
||||
"\n",
|
||||
"python equivalent:\n",
|
||||
"```\n",
|
||||
@ -1738,7 +1738,7 @@
|
||||
"id": "apo3n3HAeQY_"
|
||||
},
|
||||
"source": [
|
||||
"#### fori_loop\n",
|
||||
"#### `fori_loop`\n",
|
||||
"python equivalent:\n",
|
||||
"```\n",
|
||||
"def fori_loop(start, stop, body_fun, init_val):\n",
|
||||
@ -1813,6 +1813,158 @@
|
||||
"</center>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "OxLsZUyRt_kF"
|
||||
},
|
||||
"source": [
|
||||
"## 🔪 Dynamic Shapes"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "1tKXcAMduDR1"
|
||||
},
|
||||
"source": [
|
||||
"JAX code used within transforms like `jax.jit`, `jax.vmap`, `jax.grad`, etc. requires all output arrays and intermediate arrays to have static shape: that is, the shape cannot depend on values within other arrays.\n",
|
||||
"\n",
|
||||
"For example, if you were implementing your own version of `jnp.nansum`, you might start with something like this:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"metadata": {
|
||||
"id": "9GIwgvfLujiD"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def nansum(x):\n",
|
||||
" mask = ~jnp.isnan(x) # boolean mask selecting non-nan values\n",
|
||||
" x_without_nans = x[mask]\n",
|
||||
" return x_without_nans.sum()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "43S7wYAiupGe"
|
||||
},
|
||||
"source": [
|
||||
"Outside JIT and other transforms, this works as expected:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "ITYoNQEZur4s",
|
||||
"outputId": "a9a03d25-9c54-43b6-d35e-aea6c448d680"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"10.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"x = jnp.array([1, 2, jnp.nan, 3, 4])\n",
|
||||
"print(nansum(x))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "guup5n8xvGI-"
|
||||
},
|
||||
"source": [
|
||||
"If you attempt to apply `jax.jit` or another transform to this function, it raises an error:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 114
|
||||
},
|
||||
"id": "nms9KjQEvNTz",
|
||||
"outputId": "d8ae982f-111d-45b6-99f8-37715e2eaab3",
|
||||
"tags": [
|
||||
"raises-exception"
|
||||
]
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NonConcreteBooleanIndexError",
|
||||
"evalue": "ignored",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"jax.jit(nansum)(x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "r2aGyHDkvauu"
|
||||
},
|
||||
"source": [
|
||||
"The problem is that the size of `x_without_nans` is dependent on the values within `x`, which is another way of saying its size is *dynamic*.\n",
|
||||
"Often in JAX it is possible to work around the need for dynamically-sized arrays via other means.\n",
|
||||
"For example, here it is possible to use the three-argument form of `jnp.where` to replace the NaN values with zeros, thus computing the same result while avoiding dynamic shapes:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Zbuj7Dg1wnSg",
|
||||
"outputId": "81a5e356-cd28-4709-b307-07c6254c82de"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"10.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@jax.jit\n",
|
||||
"def nansum_2(x):\n",
|
||||
" mask = ~jnp.isnan(x) # boolean mask selecting non-nan values\n",
|
||||
" return jnp.where(mask, x, 0).sum()\n",
|
||||
"\n",
|
||||
"print(nansum_2(x))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "uGH-jqK7wxTl"
|
||||
},
|
||||
"source": [
|
||||
"Similar tricks can be played in other situations where dynamically-shaped arrays occur."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
@ -1991,7 +2143,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"execution_count": 45,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
@ -2006,7 +2158,7 @@
|
||||
"dtype('float32')"
|
||||
]
|
||||
},
|
||||
"execution_count": 41,
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -2059,7 +2211,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"execution_count": 46,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
@ -2074,7 +2226,7 @@
|
||||
"dtype('float32')"
|
||||
]
|
||||
},
|
||||
"execution_count": 42,
|
||||
"execution_count": 46,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -825,7 +825,7 @@ There are more options for control flow in JAX. Say you want to avoid re-compila
|
||||
|
||||
+++ {"id": "Sd9xrLMXeK3A"}
|
||||
|
||||
#### cond
|
||||
#### `cond`
|
||||
python equivalent:
|
||||
|
||||
```python
|
||||
@ -854,7 +854,7 @@ lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
|
||||
|
||||
+++ {"id": "xkOFAw24eOMg"}
|
||||
|
||||
#### while_loop
|
||||
#### `while_loop`
|
||||
|
||||
python equivalent:
|
||||
```
|
||||
@ -881,7 +881,7 @@ lax.while_loop(cond_fun, body_fun, init_val)
|
||||
|
||||
+++ {"id": "apo3n3HAeQY_"}
|
||||
|
||||
#### fori_loop
|
||||
#### `fori_loop`
|
||||
python equivalent:
|
||||
```
|
||||
def fori_loop(start, stop, body_fun, init_val):
|
||||
@ -934,6 +934,81 @@ $\ast$ = argument-<b>value</b>-independent loop condition - unrolls the loop
|
||||
|
||||
</center>
|
||||
|
||||
+++ {"id": "OxLsZUyRt_kF"}
|
||||
|
||||
## 🔪 Dynamic Shapes
|
||||
|
||||
+++ {"id": "1tKXcAMduDR1"}
|
||||
|
||||
JAX code used within transforms like `jax.jit`, `jax.vmap`, `jax.grad`, etc. requires all output arrays and intermediate arrays to have static shape: that is, the shape cannot depend on values within other arrays.
|
||||
|
||||
For example, if you were implementing your own version of `jnp.nansum`, you might start with something like this:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: 9GIwgvfLujiD
|
||||
|
||||
def nansum(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  x_without_nans = x[mask]
  return x_without_nans.sum()
|
||||
```
|
||||
|
||||
+++ {"id": "43S7wYAiupGe"}
|
||||
|
||||
Outside JIT and other transforms, this works as expected:
|
||||
|
||||
```{code-cell} ipython3
|
||||
---
|
||||
colab:
|
||||
base_uri: https://localhost:8080/
|
||||
id: ITYoNQEZur4s
|
||||
outputId: a9a03d25-9c54-43b6-d35e-aea6c448d680
|
||||
---
|
||||
x = jnp.array([1, 2, jnp.nan, 3, 4])
|
||||
print(nansum(x))
|
||||
```
|
||||
|
||||
+++ {"id": "guup5n8xvGI-"}
|
||||
|
||||
If you attempt to apply `jax.jit` or another transform to this function, it raises an error:
|
||||
|
||||
```{code-cell} ipython3
|
||||
---
|
||||
colab:
|
||||
base_uri: https://localhost:8080/
|
||||
height: 114
|
||||
id: nms9KjQEvNTz
|
||||
outputId: d8ae982f-111d-45b6-99f8-37715e2eaab3
|
||||
tags: [raises-exception]
|
||||
---
|
||||
jax.jit(nansum)(x)
|
||||
```
|
||||
|
||||
+++ {"id": "r2aGyHDkvauu"}
|
||||
|
||||
The problem is that the size of `x_without_nans` is dependent on the values within `x`, which is another way of saying its size is *dynamic*.
|
||||
Often in JAX it is possible to work around the need for dynamically-sized arrays via other means.
|
||||
For example, here it is possible to use the three-argument form of `jnp.where` to replace the NaN values with zeros, thus computing the same result while avoiding dynamic shapes:
|
||||
|
||||
```{code-cell} ipython3
|
||||
---
|
||||
colab:
|
||||
base_uri: https://localhost:8080/
|
||||
id: Zbuj7Dg1wnSg
|
||||
outputId: 81a5e356-cd28-4709-b307-07c6254c82de
|
||||
---
|
||||
@jax.jit
def nansum_2(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  return jnp.where(mask, x, 0).sum()

print(nansum_2(x))
|
||||
```
|
||||
|
||||
+++ {"id": "uGH-jqK7wxTl"}
|
||||
|
||||
Similar tricks can be played in other situations where dynamically-shaped arrays occur.
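As a rough illustration of such tricks (not part of the original commit), the sketch below applies the same masking idea to a mean, and uses the `size` and `fill_value` arguments of `jnp.nonzero` to get a fixed-length index array. The function names `nanmean_sketch` and `non_nan_indices` are hypothetical, chosen only for this example.

```python
import jax
import jax.numpy as jnp

@jax.jit
def nanmean_sketch(x):
  # Hypothetical helper: mean over non-NaN entries with static shapes only.
  mask = ~jnp.isnan(x)                  # boolean mask, same shape as x
  total = jnp.where(mask, x, 0).sum()   # NaNs contribute zero to the sum
  count = mask.sum()                    # number of non-NaN entries
  return total / count                  # NaN if every entry is NaN

@jax.jit
def non_nan_indices(x):
  # Hypothetical helper: indices of non-NaN entries, padded to a static length.
  # `size` must be a static integer; x.shape[0] is known at trace time.
  (idx,) = jnp.nonzero(~jnp.isnan(x), size=x.shape[0], fill_value=-1)
  return idx

x = jnp.array([1, 2, jnp.nan, 3, 4])
mean_value = nanmean_sketch(x)
indices = non_nan_indices(x)
print(mean_value)  # 2.5
print(indices)     # [ 0  1  3  4 -1]
```

In both cases the output shape depends only on the shape of `x`, never on its values; the padding value `-1` is an arbitrary choice that the caller has to handle.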
|
||||
|
||||
+++ {"id": "DKTMw6tRZyK2"}
|
||||
|
||||
## 🔪 NaNs
|
||||
|