From 5f1e3f5644b6705b21b5e030d241a514c244c2c4 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Fri, 15 Nov 2024 11:26:52 -0800 Subject: [PATCH] Add an example on logical operators to the tutorial. --- docs/control-flow.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/control-flow.md b/docs/control-flow.md index 04eb3cac8..7cb959f3e 100644 --- a/docs/control-flow.md +++ b/docs/control-flow.md @@ -340,6 +340,39 @@ $\ast$ = argument-value-independent loop condition - unrolls the loop `jax.numpy` provides `logical_and`, `logical_or`, and `logical_not`, which operate element-wise on arrays and can be evaluated under `jit` without recompiling. Like their Numpy counterparts, the binary operators do not short circuit. Bitwise operators (`&`, `|`, `~`) can also be used with `jit`. +For example, consider a function that checks if its input is a positive even integer. The pure Python and JAX versions give the same answer when the input is scalar. + +```{code-cell} +def python_check_positive_even(x): + is_even = x % 2 == 0 + # `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated. + return is_even and (x > 0) + +@jit +def jax_check_positive_even(x): + is_even = x % 2 == 0 + # `logical_and` does not short circuit, so `x > 0` is always evaluated. + return jnp.logical_and(is_even, x > 0) + +print(python_check_positive_even(24)) +print(jax_check_positive_even(24)) +``` + +When the JAX version with `logical_and` is applied to an array, it returns elementwise values. + +```{code-cell} +x = jnp.array([-1, 2, 5]) +print(jax_check_positive_even(x)) +``` + +Python logical operators error when applied to JAX arrays of more than one element, even without `jit`. This replicates NumPy's behavior. + +```{code-cell} +:tags: [raises-exception] + +print(python_check_positive_even(x)) +``` + +++ {"id": "izLTvT24dAq0"} ## Python control flow + autodiff