Add an example on logical operators to the tutorial.

This commit is contained in:
Emily Fertig 2024-11-15 11:26:52 -08:00
parent d8085008b7
commit 5f1e3f5644

View File

@ -340,6 +340,39 @@ $\ast$ = argument-<b>value</b>-independent loop condition - unrolls the loop
`jax.numpy` provides `logical_and`, `logical_or`, and `logical_not`, which operate element-wise on arrays and can be evaluated under `jit` without recompiling. Like their Numpy counterparts, the binary operators do not short circuit. Bitwise operators (`&`, `|`, `~`) can also be used with `jit`.
For example, consider a function that checks if its input is a positive even integer. The pure Python and JAX versions give the same answer when the input is scalar.
```{code-cell}
def python_check_positive_even(x):
is_even = x % 2 == 0
# `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
return is_even and (x > 0)
@jit
def jax_check_positive_even(x):
is_even = x % 2 == 0
# `logical_and` does not short circuit, so `x > 0` is always evaluated.
return jnp.logical_and(is_even, x > 0)
print(python_check_positive_even(24))
print(jax_check_positive_even(24))
```
When the JAX version with `logical_and` is applied to an array, it returns elementwise values.
```{code-cell}
x = jnp.array([-1, 2, 5])
print(jax_check_positive_even(x))
```
Python logical operators error when applied to JAX arrays of more than one element, even without `jit`. This replicates NumPy's behavior.
```{code-cell}
:tags: [raises-exception]
print(python_check_positive_even(x))
```
+++ {"id": "izLTvT24dAq0"}
## Python control flow + autodiff