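[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com)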

# JAX Colab GPU Test

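This notebook is meant to be run in a [Colab](http://colab.research.google.com) GPU runtime as a basic check for JAX updates. It records the installed JAX/jaxlib versions, confirms that arrays are placed on a GPU, and exercises matrix multiplication, linear algebra (SVD), and XLA compilation via `jax.jit`.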

In [1]:
# Record which Colab host and which JAX / jaxlib versions this run used,
# so the outputs below can be attributed to a specific release.
import jax
import jaxlib

# IPython shell escape (`!`): print the runtime's hostname. The file path is
# Colab-specific and is presumably absent outside a Colab VM.
!cat /var/colab/hostname
print(jax.__version__)
print(jaxlib.__version__)

gpu-t4-s-kbefivsjoreh
0.1.64
0.1.45


## Confirm Device

In [2]:
import jax

# Create a small random array on JAX's default backend and check where it was
# placed; on a correctly configured Colab GPU runtime this must be a GPU.
key = jax.random.PRNGKey(1701)
arr = jax.random.normal(key, (1000,))
devices = arr.devices()  # the device(s) holding `arr`
device = next(iter(devices))
print(f"JAX device type: {device}")
# Check every device holding `arr` (not just an arbitrary one of them) and
# report what was actually found if the runtime is not a GPU.
platforms = sorted({d.platform for d in devices})
assert platforms == ["gpu"], f"unexpected JAX device type: {platforms}"

JAX device type: gpu:0


## Matrix Multiplication

In [3]:
import jax
import numpy as np

# Matrix multiplication on JAX's default backend (the GPU in a Colab GPU runtime).
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (3000, 3000))
result = jax.numpy.dot(x, x.T).mean()

# Cross-check against NumPy on the host, so a silently wrong device matmul
# fails loudly instead of just printing a plausible-looking number.
x_host = np.asarray(x)
expected = np.dot(x_host, x_host.T).mean()
np.testing.assert_allclose(result, expected, rtol=1e-3)
print(result)

1.0216676


## Linear Algebra

In [4]:
import jax.numpy as jnp
import jax.random as rand
import numpy as np

N = 10
M = 20
key = rand.PRNGKey(1701)

X = rand.normal(key, (N, M))
# With the default full_matrices=True, u is (N, N) and vt is (M, M).
u, s, vt = jnp.linalg.svd(X)
k = min(N, M)
assert u.shape == (N, N)
assert s.shape == (k,)
assert vt.shape == (M, M)
# Singular values must be non-negative and sorted in descending order.
assert bool((s >= 0).all())
assert bool((jnp.diff(s) <= 0).all())
# The factorization must reproduce X: X == u[:, :k] @ diag(s) @ vt[:k].
# The product is formed on the host in NumPy so the check does not depend on
# the device's matmul precision.
reconstructed = (np.asarray(u[:, :k]) * np.asarray(s)) @ np.asarray(vt[:k])
np.testing.assert_allclose(reconstructed, np.asarray(X), rtol=1e-4, atol=1e-4)
print(s)

[6.9178247 5.9580336 5.5811076 4.5069666 4.1115823 3.9735446 3.3307252
 2.866489 1.8229384 1.5478926]


## XLA Compilation

In [5]:
import jax
import jax.numpy as jnp

# Define the key here so this cell does not depend on the `key` left over from
# the Linear Algebra cell; same seed (1701), so the sampled input is unchanged.
key = jax.random.PRNGKey(1701)


@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, compiled with XLA via ``jax.jit``.

    Returns ``lmbda * x`` where ``x > 0`` and ``lmbda * alpha * (exp(x) - 1)``
    elsewhere. ``expm1`` is used instead of ``exp(x) - 1`` to avoid
    cancellation error for small negative ``x``.
    """
    return lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(x))


x = jax.random.normal(key, (5000,))
result = selu(x).block_until_ready()

# The compiled (fused) result must agree with op-by-op execution.
with jax.disable_jit():
    expected = selu(x)
assert jnp.allclose(result, expected, rtol=1e-5, atol=1e-6), "jit/eager mismatch"
print(result)

[ 0.34676838 -0.7532232 1.7060698 ... 2.1208055 -0.42621925
 0.13093245]
