such as {func}`jax.numpy.zeros`, {func}`jax.numpy.linspace`, {func}`jax.numpy.arange`, etc.
```{code-cell}
import jax
import jax.numpy as jnp
x = jnp.arange(5)
isinstance(x, jax.Array)
```
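The same holds for the other creation functions listed above. Below is a minimal sketch that builds a few arrays with different constructors and checks that each one is a {class}`jax.Array`. The shapes and values are arbitrary illustrations.

```python
import jax
import jax.numpy as jnp

# A few of the creation functions listed above; each returns a jax.Array
zeros = jnp.zeros((2, 3))          # 2x3 array filled with 0.0
grid = jnp.linspace(0.0, 1.0, 5)   # 5 evenly spaced points from 0.0 to 1.0
ints = jnp.arange(5)               # 0, 1, 2, 3, 4

print(all(isinstance(a, jax.Array) for a in (zeros, grid, ints)))  # True
```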
If you use Python type annotations in your code, {class}`jax.Array` is the appropriate
annotation for JAX array objects (see {mod}`jax.typing` for more discussion).
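As a sketch of how this looks in practice, the hypothetical helper `scaled_sum` below annotates its input with `jax.typing.ArrayLike` (which covers lists, NumPy arrays, Python scalars, and JAX arrays) and its return value with `jax.Array`. The helper name and its `scale` parameter are illustrative assumptions, not part of the JAX API.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

# Hypothetical helper: accepts any array-like input and always
# returns a concrete jax.Array.
def scaled_sum(x: ArrayLike, scale: float = 2.0) -> jax.Array:
    x = jnp.asarray(x)  # convert the input to a jax.Array
    return jnp.sum(x) * scale

out = scaled_sum([1.0, 2.0, 3.0])
print(isinstance(out, jax.Array))  # True
```

Accepting the broader `ArrayLike` for inputs while promising `jax.Array` for outputs keeps call sites flexible without weakening what callers can rely on.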
### Array devices and sharding
JAX Array objects have a `devices` method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device:
```{code-cell}
x.devices()
```
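You can also choose where an array lives. The following sketch uses {func}`jax.device_put` to place an array on the first device reported by {func}`jax.devices`. Which device that is (CPU, GPU, or TPU) depends on your installation.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)
# Explicitly place the array on the first available device
dev = jax.devices()[0]
y = jax.device_put(x, dev)
print(y.devices())  # a set containing only `dev`
```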
In general, an array may be *sharded* across multiple devices, in a manner that can be inspected via the `sharding` attribute:
```{code-cell}
x.sharding
```
Here the array lives on a single device, but the same `sharding` attribute also describes
arrays that are split across multiple devices, or even multiple hosts.
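To make this concrete, here is a minimal sketch that shards a 1-D array across every device visible to a single process, using `Mesh`, `NamedSharding`, and `PartitionSpec` from {mod}`jax.sharding`. It assumes a single-host setup; the axis name `"data"` is an arbitrary choice. On a machine with one CPU, the mesh simply has one device and the array has one shard.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n = len(devices)

# A 1-D mesh over all devices, with a single named axis "data"
mesh = Mesh(np.array(devices), ("data",))
# Split the array's first dimension across the "data" mesh axis
sharding = NamedSharding(mesh, P("data"))

x = jnp.arange(4 * n)  # length chosen to be divisible by the device count
y = jax.device_put(x, sharding)

print(y.sharding)
for shard in y.addressable_shards:
    print(shard.device, shard.data)  # each device holds 4 elements
```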
## Pseudorandom numbers
Generally, JAX strives to be compatible with NumPy, but pseudorandom number generation is a notable exception. NumPy generates pseudorandom numbers from a hidden global state, which can be set using {func}`numpy.random.seed`. Global random state interacts poorly with JAX's compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random `key`:
```{code-cell}
from jax import random
key = random.key(43)
print(key)
```
The key is effectively a stand-in for NumPy's hidden state object, but we pass it explicitly to {mod}`jax.random` functions.
Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated.
```{code-cell}
print(random.normal(key))
print(random.normal(key))
```
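The contrast with NumPy's global state can be seen side by side. In this sketch, two consecutive NumPy draws differ because each call silently advances the global state, while two JAX draws from the same key are identical because there is no hidden state to advance. The seed `0` is arbitrary.

```python
import numpy as np
from jax import random

# NumPy: one hidden global state, advanced by every call
np.random.seed(0)
a = np.random.normal()
b = np.random.normal()
print(a == b)  # False: the global state moved between calls

# JAX: no hidden state, so the same key always yields the same sample
key = random.key(0)
c = random.normal(key)
d = random.normal(key)
print(bool(c == d))  # True
```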
**The rule of thumb is: never reuse keys (unless you want identical outputs).**
In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function:
```{code-cell}
for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.
```
Note that this code is thread-safe, since the explicit, local random state eliminates possible race conditions involving global state. {func}`jax.random.split` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys.
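As a sketch of both properties, the example below asks {func}`jax.random.split` for three subkeys in a single call (the `num` argument, passed here positionally), draws one sample from each, and then splits the same key again to show the result is reproducible. Key contents are compared through {func}`jax.random.key_data`, which exposes the raw key bits.

```python
from jax import random

key = random.key(0)
# One call to split can produce several subkeys at once; the result is
# an array of keys with shape (3,).
subkeys = random.split(key, 3)
samples = [float(random.normal(k)) for k in subkeys]
print(samples)  # three different values

# split is deterministic: splitting the same key again gives the same subkeys
again = random.split(key, 3)
print(bool((random.key_data(subkeys) == random.key_data(again)).all()))  # True
```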
For more on pseudorandom numbers in JAX, see the {ref}`pseudorandom-numbers` tutorial.