mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
DOC: fix deprecated cond() signature in Common_Gotchas notebook
This commit is contained in:
parent
02cf04b60b
commit
39fb1b7f00
@ -246,9 +246,9 @@
|
||||
"\n",
|
||||
"# lax.cond\n",
|
||||
"array_operand = jnp.array([0.])\n",
|
||||
"lax.cond(True, array_operand, lambda x: x+1, array_operand, lambda x: x-1)\n",
|
||||
"lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)\n",
|
||||
"iter_operand = iter(range(10))\n",
|
||||
"# lax.cond(True, iter_operand, lambda x: next(x)+1, iter_operand, lambda x: next(x)-1) # throws error"
|
||||
"# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1272,12 +1272,12 @@
|
||||
"#### cond\n",
|
||||
"python equivalent:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"def cond(pred, true_operand, true_fun, false_operand, false_fun):\n",
|
||||
"```python\n",
|
||||
"def cond(pred, true_fun, false_fun, operand):\n",
|
||||
" if pred:\n",
|
||||
" return true_fun(true_operand)\n",
|
||||
" return true_fun(operand)\n",
|
||||
" else:\n",
|
||||
" return false_fun(false_operand)\n",
|
||||
" return false_fun(operand)\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
@ -1306,9 +1306,9 @@
|
||||
"from jax import lax\n",
|
||||
"\n",
|
||||
"operand = jnp.array([0.])\n",
|
||||
"lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)\n",
|
||||
"lax.cond(True, lambda x: x+1, lambda x: x-1, operand)\n",
|
||||
"# --> array([1.], dtype=float32)\n",
|
||||
"lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)\n",
|
||||
"lax.cond(False, lambda x: x+1, lambda x: x-1, operand)\n",
|
||||
"# --> array([-1.], dtype=float32)"
|
||||
]
|
||||
},
|
||||
|
@ -154,9 +154,9 @@ make_jaxpr(func11)(jnp.arange(16), 5.)
|
||||
|
||||
# lax.cond
|
||||
array_operand = jnp.array([0.])
|
||||
lax.cond(True, array_operand, lambda x: x+1, array_operand, lambda x: x-1)
|
||||
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
|
||||
iter_operand = iter(range(10))
|
||||
# lax.cond(True, iter_operand, lambda x: next(x)+1, iter_operand, lambda x: next(x)-1) # throws error
|
||||
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
|
||||
```
|
||||
|
||||
+++ {"id": "oBdKtkVW8Lha"}
|
||||
@ -669,12 +669,12 @@ There are more options for control flow in JAX. Say you want to avoid re-compila
|
||||
#### cond
|
||||
python equivalent:
|
||||
|
||||
```
|
||||
def cond(pred, true_operand, true_fun, false_operand, false_fun):
|
||||
```python
|
||||
def cond(pred, true_fun, false_fun, operand):
|
||||
if pred:
|
||||
return true_fun(true_operand)
|
||||
return true_fun(operand)
|
||||
else:
|
||||
return false_fun(false_operand)
|
||||
return false_fun(operand)
|
||||
```
|
||||
|
||||
```{code-cell} ipython3
|
||||
@ -685,9 +685,9 @@ outputId: b29da06c-037f-4b05-dbd8-ba52ac35a8cf
|
||||
from jax import lax
|
||||
|
||||
operand = jnp.array([0.])
|
||||
lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)
|
||||
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
|
||||
# --> array([1.], dtype=float32)
|
||||
lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)
|
||||
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
|
||||
# --> array([-1.], dtype=float32)
|
||||
```
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user