## The basics: interactive NumPy on GPU and TPU

---



In [0]:
import jax
import jax.numpy as jnp
from jax import random

# Fixed root seed: JAX randomness is explicit, so draws made from this key
# are reproducible across runs.
key = random.key(seed=0)

In [0]:
# Split first and draw from the subkey, so the carried-forward `key` is never
# also consumed by a sampler (reusing a consumed key correlates random draws).
key, subkey = random.split(key)
x = random.normal(subkey, (5000, 5000))

print(x.shape)
print(x.dtype)

In [0]:
# Matrix product on JAX's default device, using the same API as np.dot.
y = jnp.dot(x, x)
print(y[0, 0])

In [0]:
# Last expression of the cell, so Jupyter displays the array's repr.
x

In [0]:
import matplotlib.pyplot as plt

# Plot the first row of the random matrix; x[0] is passed straight to
# matplotlib. The trailing ';' keeps the cell output to just the figure
# (no stray Line2D / Text reprs).
fig, ax = plt.subplots()
ax.plot(x[0])
ax.set(title="First row of x", xlabel="column index", ylabel="value");

In [0]:
# Transposes and other NumPy idioms work unchanged on JAX arrays.
jnp.dot(x, x.T)

In [0]:
# NumPy-style indexing on a JAX result: gather rows 0, 2, 1, 0, insert a new
# axis (jnp.newaxis is None), and reverse the last axis.
print(jnp.dot(x, x * 2)[[0, 2, 1, 0], ..., jnp.newaxis, ::-1])

In [0]:
import numpy as np

# CPU baseline: copy x into a host NumPy array and time plain NumPy matmul.
# -n 1 -r 1 times a single run (presumably to keep the slow baseline short).
x_cpu = np.array(x)
%timeit -n 1 -r 1 np.dot(x_cpu, x_cpu)

In [0]:
# JAX dispatches work asynchronously; block_until_ready() waits for the result
# so the timing covers the device computation, not just the dispatch.
%timeit -n 5 -r 5 jnp.dot(x, x).block_until_ready()

## Automatic differentiation

In [0]:
from jax import grad

In [0]:
def f(x):
  """Piecewise function: 2 * x**3 when x > 0, otherwise 3 * x.

  The branch is picked with ordinary Python control flow on the value of x,
  which `grad` supports when called without `jit`.
  """
  return 2 * x ** 3 if x > 0 else 3 * x

In [0]:
# A scalar sample; evaluating at x and -x exercises both branches of f
# (whenever x != 0).
key = random.key(0)
x = random.normal(key, shape=())

print(grad(f)(x))
print(grad(f)(-x))

In [0]:
# grad composes with itself: second and third derivatives of f at -x.
print(grad(grad(f))(-x))
print(grad(grad(grad(f)))(-x))

In [0]:
def predict(params, inputs):
  """Forward pass of an MLP: tanh on hidden layers, linear output layer.

  Args:
    params: non-empty sequence of (W, b) pairs, one per layer, applied in order.
    inputs: array fed to the first layer as jnp.dot(inputs, W) + b.

  Returns:
    The last layer's affine output jnp.dot(h, W) + b, with no activation.

  Raises:
    ValueError: if params is empty (there is no output layer).
  """
  params = list(params)
  if not params:
    raise ValueError("predict() needs at least one (W, b) layer")
  *hidden_layers, (W_out, b_out) = params
  activations = inputs
  for W, b in hidden_layers:
    activations = jnp.tanh(jnp.dot(activations, W) + b)
  # The output layer stays linear; no tanh is computed for it.
  return jnp.dot(activations, W_out) + b_out

def loss(params, batch):
  """Sum of squared errors of `predict` over a batch.

  Args:
    params: network parameters in the form `predict` accepts.
    batch: an (inputs, targets) pair.

  Returns:
    Scalar sum over every element of (predictions - targets) ** 2.
  """
  inputs, targets = batch
  residuals = predict(params, inputs) - targets
  return jnp.sum(residuals ** 2)



def init_layer(key, n_in, n_out):
  """Draw standard-normal weights and biases for one dense layer.

  Args:
    key: PRNG key; it is split once, into a weight key and a bias key.
    n_in: number of input features.
    n_out: number of output features.

  Returns:
    (W, b) with W of shape (n_in, n_out) and b of shape (n_out,).
  """
  w_key, b_key = random.split(key)
  return random.normal(w_key, (n_in, n_out)), random.normal(b_key, (n_out,))

layer_sizes = [5, 2, 3]

key = random.key(0)
# One key per layer (len(layer_sizes) - 1 of them) plus a carried-forward `key`.
key, *keys = random.split(key, len(layer_sizes))
params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

# Toy regression batch of 8 examples. Input and target widths come from
# layer_sizes, so the data keeps matching the network if the sizes change.
key, *keys = random.split(key, 3)
inputs = random.normal(keys[0], (8, layer_sizes[0]))
targets = random.normal(keys[1], (8, layer_sizes[-1]))
batch = (inputs, targets)

In [0]:
# Loss of the randomly initialised network on the batch, before any training.
print(loss(params, batch))

In [0]:
step_size = 1e-2

# 20 steps of plain gradient descent on the full batch.
for _ in range(20):
  grads = grad(loss)(params, batch)
  # Update every leaf of the list-of-(W, b) pytree: p <- p - step_size * g.
  params = jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)

In [0]:
# Loss after the 20 update steps above; expected to be lower than before training.
print(loss(params, batch))

Other JAX autodiff highlights:

* Forward- and reverse-mode, totally composable
* Fast Jacobians and Hessians
* Complex number support (holomorphic and non-holomorphic)
* Jacobian pre-accumulation for elementwise operations (like `gelu`)


For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html).

## End-to-end compilation with XLA using `jit`

In [0]:
from jax import jit

In [0]:
# Fixed key and a fresh 5000x5000 input for the jit examples below.
key = random.key(0)
x = random.normal(key, (5000, 5000))

In [0]:
def f(x):
  """Apply y -> y - 0.1 * y + 3 ten times, then keep the top-left 100x100 block.

  A chain of small elementwise steps, used below to compare un-jitted and
  jitted execution.
  """
  out = x
  for _step in range(10):
    out = out - 0.1 * out + 3.
  return out[:100, :100]

# Evaluate the un-jitted function on x.
f(x)

In [0]:
# Compile f with XLA. The first call traces and compiles it; later calls with
# the same shapes/dtypes reuse the compiled program.
g = jit(f)
g(x)

In [0]:
# Baseline: the un-jitted function.
%timeit f(x).block_until_ready()

In [0]:
# The jitted version (already compiled by the g(x) call above).
%timeit g(x).block_until_ready()

In [0]:
# grad and jit nest freely: third derivative of tanh at 1.0, with jit in between.
grad(jit(grad(jit(grad(jnp.tanh)))))(1.0)

### Constraints that come with using `jit`

In [0]:
def f(x):
  """Return 2 * x**2 when x > 0, else 3 * x.

  The branch is chosen with Python control flow on the value of x, so this
  works on concrete inputs but not when x is traced by `jit`.
  """
  return 2 * x ** 2 if x > 0 else 3 * x

# jit traces f with abstract values, so the value-dependent `if` cannot be resolved.
g = jit(f)

In [0]:
# Without jit, Python control flow on a concrete value works as usual.
f(2)

In [0]:
# Show the error jit raises for a Python `if` on a traced value. Only JAX's
# concretization error is caught, so unrelated failures still surface.
try:
  g(2)
except jax.errors.ConcretizationTypeError as e:
  print(e)

In [0]:
def f(x, n):
  """Square `x` repeatedly, `n` times, with a Python while-loop.

  The loop condition `i < n` is evaluated in Python, so `n` must be a
  concrete value; under `jit` it has to be marked static, otherwise the
  comparison fails on a tracer.
  """
  i = 0
  while i < n:
    x = x * x
    i += 1
  return x

# By default every argument is traced, including n, so `i < n` has no concrete value.
g = jit(f)

In [0]:
# Un-jitted: n is a plain Python int, so the loop runs normally.
f(jnp.array([1., 2., 3.]), 5)

In [0]:
# Show the error jit raises when the loop bound n is traced. Only JAX's
# concretization error is caught, so unrelated failures still surface.
try:
  g(jnp.array([1., 2., 3.]), 5)
except jax.errors.ConcretizationTypeError as e:
  print(e)

In [0]:
# Mark n (argument index 1) static: jit specialises on its concrete value.
g = jit(f, static_argnums=(1,))

In [0]:
# Works now: with n = 5 known at trace time, the Python loop runs during tracing.
g(jnp.array([1., 2., 3.]), 5)

## Vectorization with `vmap`

In [0]:
from jax import vmap

In [0]:
# vmap maps a function over the leading axis of its input.
print(vmap(lambda x: x**2)(jnp.arange(8)))

In [0]:
from jax import make_jaxpr

# Inspect the traced program (jaxpr) of an unbatched vector dot product.
make_jaxpr(jnp.dot)(jnp.ones(8), jnp.ones(8))

In [0]:
# vmap over the leading axis of both arguments; compare with the unbatched jaxpr.
make_jaxpr(vmap(jnp.dot))(jnp.ones((10, 8)), jnp.ones((10, 8)))

In [0]:
# Nested vmap maps over the two leading axes.
make_jaxpr(vmap(vmap(jnp.dot)))(jnp.ones((10, 10, 8)), jnp.ones((10, 10, 8)))

In [0]:
# Per-example gradients: params are shared (in_axes None), while every leaf of
# batch (inputs and targets) is mapped over its leading axis.
perex_grads = vmap(grad(loss), in_axes=(None, 0))
make_jaxpr(perex_grads)(params, batch)

## Parallel accelerators with pmap

In [0]:
# List the devices available to pmap.
jax.devices()

In [0]:
from jax import pmap

In [0]:
# pmap maps the leading axis across devices; its size may not exceed the
# device count, so this length-8 axis needs at least 8 devices.
y = pmap(lambda x: x ** 2)(jnp.arange(8))
print(y)

In [0]:
# Display the pmap result's repr.
y

In [0]:
# Ordinary array operations accept the pmap output directly.
z = y / 2
print(z)

In [0]:
import matplotlib.pyplot as plt

# y holds the squares of 0..7 computed with pmap above. The trailing ';'
# keeps the cell output to just the figure (no stray Line2D / Text reprs).
fig, ax = plt.subplots()
ax.plot(y)
ax.set(title="Squares computed with pmap", xlabel="index", ylabel="value");

In [0]:
# Eight keys -> one 5000x5000 matrix generated on each device; then an
# independent matmul and mean per device.
keys = random.split(random.key(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)
result = pmap(jnp.dot)(mats, mats)
print(pmap(jnp.mean)(result))

In [0]:
timeit -n 5 -r 5 pmap(jnp.dot)(mats, mats).block_until_ready()

### Collective communication operations

In [0]:
from functools import partial
from jax.lax import psum

@partial(pmap, axis_name='i')
def normalize(x):
  """Divide each mapped element by the sum of x across the 'i' axis."""
  total = psum(x, axis_name='i')
  return x / total

# Each entry becomes x / sum(x) = x / 28 (0 + 1 + ... + 7 = 28).
print(normalize(jnp.arange(8.)))

In [0]:
@partial(pmap, axis_name='rows')
@partial(pmap, axis_name='cols')
def f(x):
  """Reduce x over the 'rows' axis, the 'cols' axis, and both at once.

  Returns (row_sum, col_sum, total_sum). Note that psum over 'rows' sums
  across rows, so each entry of row_sum holds its column's total; likewise
  col_sum holds each row's total.
  """
  return psum(x, 'rows'), psum(x, 'cols'), psum(x, ('rows', 'cols'))

# 4 x 2 input: the outer pmap maps the 4 rows, the inner pmap the 2 columns
# (presumably needing 4 * 2 = 8 devices). psum over 'rows' gives column
# totals, so the "row sum" printout shows each column's total.
x = jnp.arange(8.).reshape((4, 2))
a, b, c = f(x)

print("input:\n", x)
print("row sum:\n", a)
print("col sum:\n", b)
print("total sum:\n", c)

<img src="https://raw.githubusercontent.com/jax-ml/jax/main/cloud_tpu_colabs/images/nested_pmap.png" width="70%"/>

For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb).

### Compose pmap with other transforms!

In [0]:
@pmap
def f(x):
  # Outer pmap: one call per slice along x's leading axis.
  y = jnp.sin(x)
  @pmap
  def g(z):
    # Inner pmap; closes over y and x from the outer mapped computation.
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  # Differentiate through the inner pmap with respect to its input, at x.
  return grad(lambda w: jnp.sum(g(w)))(x)

# x is the 4 x 2 array from the collective-operations cell.
f(x)

In [0]:
# Differentiate through the outer pmap as well.
grad(lambda x: jnp.sum(f(x)))(x)

### Compose everything

In [0]:
from jax import jvp, vjp # forward and reverse-mode

# curry(fn)(a)(b) == fn(a, b): lets the Jacobian builders be used as jacfwd(fun)(x).
curry = lambda f: partial(partial, f)

@curry
def jacfwd(fun, x):
  """Forward-mode Jacobian of `fun` at `x`, built from `jvp` + `vmap`.

  Every standard-basis tangent of `x` is pushed through `jvp` in one
  vectorized call. Use as ``jacfwd(fun)(x)``.

  Returns:
    Array of shape ``jnp.shape(fun(x)) + jnp.shape(x)``.
  """
  pushfwd = partial(jvp, fun, (x,))  # jvp!
  # jvp takes its tangents as a tuple (one entry per primal), so the batch of
  # basis vectors is wrapped in an explicit 1-tuple.
  std_basis = (jnp.eye(jnp.size(x)).reshape((-1,) + jnp.shape(x)),)
  # The primal output y does not depend on the mapped tangents, so it is
  # returned unbatched (None); tangent outputs are stacked on the last axis.
  y, jac_flat = vmap(pushfwd, out_axes=(None, -1))(std_basis)  # vmap!
  return jac_flat.reshape(jnp.shape(y) + jnp.shape(x))

@curry
def jacrev(fun, x):
  """Reverse-mode Jacobian of `fun` at `x`, built from `vjp` + `vmap`.

  Every standard-basis cotangent of the output is pulled back in one
  vectorized call. Use as ``jacrev(fun)(x)``.

  Returns:
    Array of shape ``jnp.shape(fun(x)) + jnp.shape(x)``.
  """
  y, pullback = vjp(fun, x)  # vjp!
  std_basis = jnp.eye(jnp.size(y)).reshape((-1,) + jnp.shape(y))
  # pullback returns a 1-tuple of cotangents (one per primal input).
  (jac_flat,) = vmap(pullback)(std_basis)  # vmap!
  return jac_flat.reshape(jnp.shape(y) + jnp.shape(x))

def hessian(fun):
  """Second-derivative array of `fun` (forward-over-reverse), jit-compiled."""
  return jit(jacfwd(jacrev(fun)))  # jit!

In [0]:
def input_hess(example_input, example_target=None):
  """Hessian of the squared-error loss with respect to the network input.

  Args:
    example_input: the input to differentiate with respect to.
    example_target: target(s) the prediction is compared against. Defaults to
      the global `targets` batch, which reproduces the previous behaviour of
      comparing one input against every target.

  Returns:
    Hessian of loss(params, (example_input, example_target)) with respect to
    example_input.
  """
  if example_target is None:
    example_target = targets
  return hessian(lambda inp: loss(params, (inp, example_target)))(example_input)

# Map over examples and pair each input with its *own* target, so the result
# really is a per-example Hessian. The previous version broadcast each input's
# prediction against all rows of `targets`.
per_example_hess = pmap(input_hess)  # pmap!
per_example_hess(inputs, targets)