## The basics: interactive NumPy on GPU and TPU

---



In [None]:
import jax
import jax.numpy as jnp
from jax import random

In [None]:
# Create a root PRNG key and split it: `subkey` is consumed by the sampling
# call below, while `key` is kept untouched so it can be split again later.
key = random.key(0)
key, subkey = random.split(key)
# Sample with `subkey`, not `key`, so `key` is never reused for another draw.
x = random.normal(subkey, (5000, 5000))

print(x.shape)
print(x.dtype)

In [None]:
# Multiply x by itself as a matrix product and print the top-left entry.
y = jnp.dot(x, x)
print(y[0, 0])

In [None]:
# Bare expression as the last line of the cell, so the notebook displays x.
x

In [None]:
import matplotlib.pyplot as plt

# Plot the first row of x.
plt.plot(x[0])

In [None]:
# Print the matrix product of x with its own transpose.
print(jnp.dot(x, x.T))

In [None]:
# Index the product NumPy-style: pick rows 0, 2, 1, 0 with an integer list,
# insert a new axis with None, and reverse the last axis with ::-1.
print(jnp.dot(x, 2 * x)[[0, 2, 1, 0], ..., None, ::-1])

In [None]:
import numpy as np

# Copy x into a NumPy array, then time the NumPy matrix product
# (5 loops per run, 2 runs).
x_cpu = np.array(x)
%timeit -n 5 -r 2 np.dot(x_cpu, x_cpu)

In [None]:
# Time the same product with jnp.dot (5 loops per run, 5 runs);
# block_until_ready() makes each timed call wait for the result to be computed.
%timeit -n 5 -r 5 jnp.dot(x, x).block_until_ready()

## Automatic differentiation

In [None]:
from jax import grad

In [None]:
def f(x):
    """Piecewise function: `2 * x ** 3` for positive `x`, `3 * x` otherwise."""
    # Positive inputs take the cubic branch; everything else falls through
    # to the linear branch.
    if x > 0:
        return 2 * x ** 3
    return 3 * x

In [None]:
# Sample a scalar (shape ()) from a fresh key, then print the derivative of f
# at x and at -x. For nonzero x the two points have opposite signs, so the two
# calls take different branches of f.
key = random.key(0)
x = random.normal(key, ())

print(grad(f)(x))
print(grad(f)(-x))

In [None]:
# Second and third derivatives of f at -x, built by nesting grad.
print(grad(grad(f))(-x))
print(grad(grad(grad(f)))(-x))

Other JAX autodiff highlights:

* Forward- and reverse-mode, totally composable
* Fast Jacobians and Hessians
* Complex number support (holomorphic and non-holomorphic)
* Jacobian pre-accumulation for elementwise operations (like `gelu`)


For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html).

## End-to-end compilation with XLA using `jit`

In [None]:
from jax import jit

In [None]:
# Rebind x to a (5000, 5000) normal sample drawn from a fresh key.
key = random.key(0)
x = random.normal(key, (5000, 5000))

In [None]:
def f(x):
    """Apply `y - 0.1 * y + 3.` to `x` ten times and return the top-left corner.

    The returned slice covers at most the first 100 rows and 100 columns.
    """
    result = x
    steps_left = 10
    while steps_left:
        result = result - 0.1 * result + 3.
        steps_left -= 1
    return result[:100, :100]

# Run the un-jitted f once on x and display its result.
f(x)

In [None]:
# Wrap f with jit and call the wrapped version once on x.
g = jit(f)
g(x)

In [None]:
# Time the plain (un-jitted) f, waiting for each result to be ready.
%timeit f(x).block_until_ready()

In [None]:
# Time the jit-wrapped g (100 loops per run), waiting for each result to be ready.
%timeit -n 100 g(x).block_until_ready()

In [None]:
# grad and jit nest freely: third derivative of tanh evaluated at 1.0, with a
# jit wrapper between each pair of grad calls.
grad(jit(grad(jit(grad(jnp.tanh)))))(1.0)

## Parallelization over multiple accelerators with pmap

In [None]:
# Display the number of devices reported by JAX.
jax.device_count()

In [None]:
from jax import pmap

In [None]:
# Square one element per device. The input is sized by the number of local
# devices because pmap places one slice of the leading axis on each device, so
# a hard-coded length only works on hosts with at least that many devices.
y = pmap(lambda x: x ** 2)(jnp.arange(jax.local_device_count()))
print(y)

In [None]:
# Bare expression as the last line of the cell, so the notebook displays y.
y

In [None]:
import matplotlib.pyplot as plt
# Plot the values of y.
plt.plot(y)

### Collective communication operations

In [None]:
from functools import partial
from jax.lax import psum

@partial(pmap, axis_name='i')
def f(x):
    """Return `x` divided by its sum over the mapped axis 'i', and that sum."""
    # psum adds up x across every position of the named axis 'i'.
    axis_sum = psum(x, 'i')
    normalized = x / axis_sum
    return normalized, axis_sum

# Apply the pmapped f to the values 0.0 .. 7.0 and print both outputs.
# NOTE(review): the length-8 input presumably needs at least 8 devices —
# confirm on the target runtime.
normalized, total = f(jnp.arange(8.))

print(f"normalized:\n{normalized}\n")
print("total:", total)

For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb).

## Automatic parallelization with sharded_jit (new!)

In [None]:
from jax.experimental import sharded_jit, PartitionSpec as P

In [None]:
from jax import lax

def conv(image, kernel):
    """Convolve `image` with `kernel` using unit strides and 'SAME' padding.

    Thin wrapper around `lax.conv`; both arguments are forwarded unchanged.
    """
    return lax.conv(image, kernel, (1, 1), 'SAME')

In [None]:
# Inputs for the convolution: an all-ones float32 image of shape
# (1, 8, 2000, 1000) and a random float32 kernel of shape (8, 8, 5, 5).
# NOTE(review): np.random.random is not seeded here, so the kernel (and the
# printed result) differs between runs.
image = jnp.ones((1, 8, 2000, 1000)).astype(np.float32)
kernel = jnp.array(np.random.random((8, 8, 5, 5)).astype(np.float32))

# Limit NumPy's summarized array printing to one item at each edge of a
# dimension, then run the convolution and display the result.
np.set_printoptions(edgeitems=1)
conv(image, kernel)

In [None]:
# Time the unsharded convolution, waiting for each result to be ready.
%timeit conv(image, kernel).block_until_ready()

In [None]:
# Build a sharded version of conv and run it once on the same inputs.
# The same partition spec is used for the image input and for the output; the
# kernel gets None in in_parts.
# NOTE(review): P(1, 1, 4, 2) is presumably the number of partitions per image
# axis, and None presumably leaves the kernel unpartitioned — confirm against
# the sharded_jit documentation.
# NOTE(review): sharded_jit and jax.experimental.PartitionSpec may not exist in
# the installed JAX version (other cells here use random.key) — confirm, and
# consider jax.jit with in_shardings/out_shardings if they are unavailable.
image_partitions = P(1, 1, 4, 2)
sharded_conv = sharded_jit(conv,
 in_parts=(image_partitions, None),
 out_parts=image_partitions)

sharded_conv(image, kernel)

In [None]:
# Time the sharded convolution (10 loops per run), waiting for each result to be ready.
%timeit -n 10 sharded_conv(image, kernel).block_until_ready()