mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add an example on logical operators to the tutorial.
This commit is contained in:
parent
d8085008b7
commit
5f1e3f5644
@ -340,6 +340,39 @@ $\ast$ = argument-<b>value</b>-independent loop condition - unrolls the loop
|
||||
|
||||
`jax.numpy` provides `logical_and`, `logical_or`, and `logical_not`, which operate element-wise on arrays and can be evaluated under `jit` without recompiling. Like their Numpy counterparts, the binary operators do not short circuit. Bitwise operators (`&`, `|`, `~`) can also be used with `jit`.
|
||||
|
||||
For example, consider a function that checks if its input is a positive even integer. The pure Python and JAX versions give the same answer when the input is scalar.
|
||||
|
||||
```{code-cell}
|
||||
def python_check_positive_even(x):
|
||||
is_even = x % 2 == 0
|
||||
# `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
|
||||
return is_even and (x > 0)
|
||||
|
||||
@jit
|
||||
def jax_check_positive_even(x):
|
||||
is_even = x % 2 == 0
|
||||
# `logical_and` does not short circuit, so `x > 0` is always evaluated.
|
||||
return jnp.logical_and(is_even, x > 0)
|
||||
|
||||
print(python_check_positive_even(24))
|
||||
print(jax_check_positive_even(24))
|
||||
```
|
||||
|
||||
When the JAX version with `logical_and` is applied to an array, it returns elementwise values.
|
||||
|
||||
```{code-cell}
|
||||
x = jnp.array([-1, 2, 5])
|
||||
print(jax_check_positive_even(x))
|
||||
```
|
||||
|
||||
Python logical operators error when applied to JAX arrays of more than one element, even without `jit`. This replicates NumPy's behavior.
|
||||
|
||||
```{code-cell}
|
||||
:tags: [raises-exception]
|
||||
|
||||
print(python_check_positive_even(x))
|
||||
```
|
||||
|
||||
+++ {"id": "izLTvT24dAq0"}
|
||||
|
||||
## Python control flow + autodiff
|
||||
|
Loading…
x
Reference in New Issue
Block a user