Merge pull request #24918 from emilyfertig:emilyaf-logical-op-example

PiperOrigin-RevId: 696989966
This commit is contained in:
jax authors 2024-11-15 13:48:14 -08:00
commit 605c605181

View File

@ -340,6 +340,39 @@ $\ast$ = argument-<b>value</b>-independent loop condition - unrolls the loop
`jax.numpy` provides `logical_and`, `logical_or`, and `logical_not`, which operate element-wise on arrays and can be evaluated under `jit` without recompiling. Like their Numpy counterparts, the binary operators do not short circuit. Bitwise operators (`&`, `|`, `~`) can also be used with `jit`.
For example, consider a function that checks if its input is a positive even integer. The pure Python and JAX versions give the same answer when the input is scalar.
```{code-cell}
def python_check_positive_even(x):
is_even = x % 2 == 0
# `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
return is_even and (x > 0)
@jit
def jax_check_positive_even(x):
is_even = x % 2 == 0
# `logical_and` does not short circuit, so `x > 0` is always evaluated.
return jnp.logical_and(is_even, x > 0)
print(python_check_positive_even(24))
print(jax_check_positive_even(24))
```
When the JAX version with `logical_and` is applied to an array, it returns elementwise values.
```{code-cell}
x = jnp.array([-1, 2, 5])
print(jax_check_positive_even(x))
```
Python logical operators error when applied to JAX arrays of more than one element, even without `jit`. This replicates NumPy's behavior.
```{code-cell}
:tags: [raises-exception]
print(python_check_positive_even(x))
```
+++ {"id": "izLTvT24dAq0"}
## Python control flow + autodiff