DOC: fix deprecated cond() signature in Common_Gotchas notebook

This commit is contained in:
Jake VanderPlas 2021-03-04 13:04:30 -08:00
parent 02cf04b60b
commit 39fb1b7f00
2 changed files with 16 additions and 16 deletions

View File

@ -246,9 +246,9 @@
"\n",
"# lax.cond\n",
"array_operand = jnp.array([0.])\n",
"lax.cond(True, array_operand, lambda x: x+1, array_operand, lambda x: x-1)\n",
"lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)\n",
"iter_operand = iter(range(10))\n",
"# lax.cond(True, iter_operand, lambda x: next(x)+1, iter_operand, lambda x: next(x)-1) # throws error"
"# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error"
]
},
{
@ -1272,12 +1272,12 @@
"#### cond\n",
"python equivalent:\n",
"\n",
"```\n",
"def cond(pred, true_operand, true_fun, false_operand, false_fun):\n",
"```python\n",
"def cond(pred, true_fun, false_fun, operand):\n",
" if pred:\n",
" return true_fun(true_operand)\n",
" return true_fun(operand)\n",
" else:\n",
" return false_fun(false_operand)\n",
" return false_fun(operand)\n",
"```"
]
},
@ -1306,9 +1306,9 @@
"from jax import lax\n",
"\n",
"operand = jnp.array([0.])\n",
"lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)\n",
"lax.cond(True, lambda x: x+1, lambda x: x-1, operand)\n",
"# --> array([1.], dtype=float32)\n",
"lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)\n",
"lax.cond(False, lambda x: x+1, lambda x: x-1, operand)\n",
"# --> array([-1.], dtype=float32)"
]
},

View File

@ -154,9 +154,9 @@ make_jaxpr(func11)(jnp.arange(16), 5.)
# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, array_operand, lambda x: x+1, array_operand, lambda x: x-1)
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, iter_operand, lambda x: next(x)+1, iter_operand, lambda x: next(x)-1) # throws error
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
```
+++ {"id": "oBdKtkVW8Lha"}
@ -669,12 +669,12 @@ There are more options for control flow in JAX. Say you want to avoid re-compila
#### cond
python equivalent:
```
def cond(pred, true_operand, true_fun, false_operand, false_fun):
```python
def cond(pred, true_fun, false_fun, operand):
if pred:
return true_fun(true_operand)
return true_fun(operand)
else:
return false_fun(false_operand)
return false_fun(operand)
```
```{code-cell} ipython3
@ -685,9 +685,9 @@ outputId: b29da06c-037f-4b05-dbd8-ba52ac35a8cf
from jax import lax
operand = jnp.array([0.])
lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
```